新增Google Vertex AI的access token获取接口

This commit is contained in:
shuohigh@gmail.com 2025-06-17 16:41:48 +08:00
parent 79c462d0e4
commit ad41b80f4f
3 changed files with 71 additions and 15 deletions

View File

@ -1,5 +1,5 @@
from typing import Union, Any from typing import Union, Any
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator, ConfigDict from pydantic import BaseModel, Field, computed_field, field_validator, model_validator
from pydantic.json_schema import JsonSchemaValue from pydantic.json_schema import JsonSchemaValue
from ..utils.TimeUtils import TimeDelta from ..utils.TimeUtils import TimeDelta

View File

@ -7,15 +7,17 @@ import sentry_sdk
from fastapi.responses import Response from fastapi.responses import Response
from loguru import logger from loguru import logger
import httpx import httpx
from fastapi import APIRouter, UploadFile, Header, HTTPException from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from starlette import status from starlette import status
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from BowongModalFunctions.config import WorkerConfig from BowongModalFunctions.config import WorkerConfig
from BowongModalFunctions.middleware.authorization import verify_token
from BowongModalFunctions.models.web_model import SentryTransactionInfo, GeminiResultResponse, GeminiRequest, \ from BowongModalFunctions.models.web_model import SentryTransactionInfo, GeminiResultResponse, GeminiRequest, \
ModalTaskResponse, MakeGridGeminiRequest ModalTaskResponse, MakeGridGeminiRequest
from BowongModalFunctions.utils.ModalUtils import ModalUtils from BowongModalFunctions.utils.ModalUtils import ModalUtils
from BowongModalFunctions.utils.HTTPUtils import GoogleAuthUtils
config = WorkerConfig() config = WorkerConfig()
@ -25,6 +27,7 @@ router = APIRouter(prefix="/google", tags=["Google"])
class GoogleAPIKeyHeaders(BaseModel): class GoogleAPIKeyHeaders(BaseModel):
x_google_api_key: Optional[str] = Field(description="Google API Key") x_google_api_key: Optional[str] = Field(description="Google API Key")
class BundleHeaders(BaseModel): class BundleHeaders(BaseModel):
x_google_api_key: Optional[str] = Field(description="Google API Key") x_google_api_key: Optional[str] = Field(description="Google API Key")
x_trace_id: str = Field(description="Sentry Transaction ID", default=None) x_trace_id: str = Field(description="Sentry Transaction ID", default=None)
@ -101,6 +104,7 @@ async def delete_file(filename: str, headers: Annotated[GoogleAPIKeyHeaders, Hea
response.raise_for_status() response.raise_for_status()
return JSONResponse(content=response.json(), status_code=response.status_code) return JSONResponse(content=response.json(), status_code=response.status_code)
@router.delete('/delete_all', summary="删除所有已上传的文件/第一页") @router.delete('/delete_all', summary="删除所有已上传的文件/第一页")
async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]): async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY") google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
@ -120,6 +124,7 @@ async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
return JSONResponse(content={"msg": f"删除文件{filename}失败"}, status_code=resp.status_code) return JSONResponse(content={"msg": f"删除文件{filename}失败"}, status_code=resp.status_code)
return JSONResponse(content=response.json(), status_code=response.status_code) return JSONResponse(content=response.json(), status_code=response.status_code)
@router.get('/list', summary="列出已上传的文件") @router.get('/list', summary="列出已上传的文件")
async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]): async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY") google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
@ -132,16 +137,33 @@ async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
return JSONResponse(content={}, status_code=response.status_code) return JSONResponse(content={}, status_code=response.status_code)
return JSONResponse(content=response.json() if len(response.text) > 0 else "", status_code=response.status_code) return JSONResponse(content=response.json() if len(response.text) > 0 else "", status_code=response.status_code)
@router.get('/access-token', summary="获取一个新的Google Access Token", dependencies=[Depends(verify_token)])
async def get_access_token() -> GoogleAuthUtils.GoogleAuthResponse:
service_account_json = os.environ.get("GOOGLE_AUTH_JSON", None)
if not service_account_json:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google IAM Service Account JSON")
service_account_info = json.loads(service_account_json)
access_token = await GoogleAuthUtils.get_google_auth_jwt(service_account_info=service_account_info,
scopes=['https://www.googleapis.com/auth/cloud-platform'])
return access_token
@router.post("/make_grid_gemini", summary="将输入图拼为网格上传到Gemini网盘") @router.post("/make_grid_gemini", summary="将输入图拼为网格上传到Gemini网盘")
async def make_grid_gemini_upload(data: MakeGridGeminiRequest, headers: Annotated[BundleHeaders, Header()]): async def make_grid_gemini_upload(data: MakeGridGeminiRequest, headers: Annotated[BundleHeaders, Header()]):
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY") google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
if not google_api_key: if not google_api_key:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
fn = modal.Function.from_name(config.modal_app_name,"make_image_grid_upload", environment_name=config.modal_environment) fn = modal.Function.from_name(config.modal_app_name, "make_image_grid_upload",
image_grid_uri = await fn.remote.aio(data.pic_info_list, data.image_size, data.text_height, data.font_size, data.padding, data.separator, google_api_key, environment_name=config.modal_environment)
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), x_baggage=sentry_sdk.get_baggage()) image_grid_uri = await fn.remote.aio(data.pic_info_list, data.image_size, data.text_height, data.font_size,
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage)) data.padding, data.separator, google_api_key,
return JSONResponse(content={"uri":image_grid_uri}, status_code=200) SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(),
x_baggage=sentry_sdk.get_baggage())
if headers.x_trace_id is None else SentryTransactionInfo(
x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage))
return JSONResponse(content={"uri": image_grid_uri}, status_code=200)
@router.post('/inference_gemini', summary="使用Gemini推理hls视频流指定时间段的打点情况") @router.post('/inference_gemini', summary="使用Gemini推理hls视频流指定时间段的打点情况")
async def inference_gemini( async def inference_gemini(
@ -151,10 +173,15 @@ async def inference_gemini(
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY") google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
if not google_api_key: if not google_api_key:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
fn = modal.Function.from_name(config.modal_app_name,"video_hls_slice_inference", environment_name=config.modal_environment) fn = modal.Function.from_name(config.modal_app_name, "video_hls_slice_inference",
fn_call = fn.spawn(data.media_hls_url, google_api_key, data.product_cover_grid_uri_list, data.product_list, data.start_time, data.end_time, data.options, environment_name=config.modal_environment)
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), x_baggage=sentry_sdk.get_baggage()) fn_call = fn.spawn(data.media_hls_url, google_api_key, data.product_cover_grid_uri_list, data.product_list,
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage), data.webhook, 2, data.scale) data.start_time, data.end_time, data.options,
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(),
x_baggage=sentry_sdk.get_baggage())
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id,
x_baggage=headers.x_baggage),
data.webhook, 2, data.scale)
return ModalTaskResponse(success=True, taskId=fn_call.object_id) return ModalTaskResponse(success=True, taskId=fn_call.object_id)
@ -166,10 +193,11 @@ async def gemini_status(task_id: str, response: Response) -> GeminiResultRespons
response.headers["x-baggage"] = task_info.transaction.x_baggage response.headers["x-baggage"] = task_info.transaction.x_baggage
try: try:
return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value), return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value),
error=task_info.error_reason, result=json.dumps(task_info.results, ensure_ascii=False) error=task_info.error_reason,
result=json.dumps(task_info.results, ensure_ascii=False)
if task_info.results is not None else "") if task_info.results is not None else "")
except Exception as e: except Exception as e:
logger.exception(f"获取Gemini状态发生错误 {e}") logger.exception(f"获取Gemini状态发生错误 {e}")
return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value), return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value),
error=task_info.error_reason + f"获取Gemini状态发生错误 {e}" error=task_info.error_reason + f"获取Gemini状态发生错误 {e}"
if task_info.error_reason else f"获取Gemini状态发生错误 {e}", result="") if task_info.error_reason else f"获取Gemini状态发生错误 {e}", result="")

View File

@ -1,5 +1,5 @@
from typing import Optional, Any, Union from typing import Union, Dict, Any
from functools import wraps
import backoff import backoff
import httpx import httpx
import asyncio import asyncio
@ -7,6 +7,8 @@ from loguru import logger
import aiofiles import aiofiles
from pathlib import Path from pathlib import Path
from pydantic import BaseModel
class HTTPDownloadUtils: class HTTPDownloadUtils:
# 创建一个类级别的 Semaphore 来控制并发 # 创建一个类级别的 Semaphore 来控制并发
@ -98,3 +100,29 @@ class HTTPDownloadUtils:
tasks.append(task) tasks.append(task)
return await asyncio.gather(*tasks, return_exceptions=True) return await asyncio.gather(*tasks, return_exceptions=True)
from google.oauth2 import service_account
class GoogleAuthUtils:
class GoogleAuthResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
@staticmethod
async def get_google_auth_jwt(service_account_info: Dict[str, str], scopes: list[str]) -> GoogleAuthResponse:
credentials: service_account.Credentials = service_account.Credentials.from_service_account_info(
service_account_info, scopes=scopes)
assertion = credentials._make_authorization_grant_assertion().decode('utf-8')
logger.info(f"jwt_assertion: {assertion}")
params = {
'grant_type': "urn:ietf:params:oauth:grant-type:jwt-bearer",
'assertion': assertion
}
response = httpx.post(url="https://oauth2.googleapis.com/token",
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data=params)
response.raise_for_status()
return GoogleAuthUtils.GoogleAuthResponse.model_validate_json(response.text)