新增Google Vertex AI的access token获取接口

This commit is contained in:
shuohigh@gmail.com 2025-06-17 16:41:48 +08:00
parent 79c462d0e4
commit ad41b80f4f
3 changed files with 71 additions and 15 deletions

View File

@ -1,5 +1,5 @@
from typing import Union, Any
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator, ConfigDict
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator
from pydantic.json_schema import JsonSchemaValue
from ..utils.TimeUtils import TimeDelta

View File

@ -7,15 +7,17 @@ import sentry_sdk
from fastapi.responses import Response
from loguru import logger
import httpx
from fastapi import APIRouter, UploadFile, Header, HTTPException
from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends
from pydantic import BaseModel, Field
from starlette import status
from starlette.responses import JSONResponse
from BowongModalFunctions.config import WorkerConfig
from BowongModalFunctions.middleware.authorization import verify_token
from BowongModalFunctions.models.web_model import SentryTransactionInfo, GeminiResultResponse, GeminiRequest, \
ModalTaskResponse, MakeGridGeminiRequest
from BowongModalFunctions.utils.ModalUtils import ModalUtils
from BowongModalFunctions.utils.HTTPUtils import GoogleAuthUtils
config = WorkerConfig()
@ -25,6 +27,7 @@ router = APIRouter(prefix="/google", tags=["Google"])
class GoogleAPIKeyHeaders(BaseModel):
x_google_api_key: Optional[str] = Field(description="Google API Key")
class BundleHeaders(BaseModel):
x_google_api_key: Optional[str] = Field(description="Google API Key")
x_trace_id: str = Field(description="Sentry Transaction ID", default=None)
@ -101,6 +104,7 @@ async def delete_file(filename: str, headers: Annotated[GoogleAPIKeyHeaders, Hea
response.raise_for_status()
return JSONResponse(content=response.json(), status_code=response.status_code)
@router.delete('/delete_all', summary="删除所有已上传的文件/第一页")
async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
@ -120,6 +124,7 @@ async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
return JSONResponse(content={"msg": f"删除文件{filename}失败"}, status_code=resp.status_code)
return JSONResponse(content=response.json(), status_code=response.status_code)
@router.get('/list', summary="列出已上传的文件")
async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
@ -132,16 +137,33 @@ async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
return JSONResponse(content={}, status_code=response.status_code)
return JSONResponse(content=response.json() if len(response.text) > 0 else "", status_code=response.status_code)
@router.get('/access-token', summary="获取一个新的Google Access Token", dependencies=[Depends(verify_token)])
async def get_access_token() -> GoogleAuthUtils.GoogleAuthResponse:
service_account_json = os.environ.get("GOOGLE_AUTH_JSON", None)
if not service_account_json:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google IAM Service Account JSON")
service_account_info = json.loads(service_account_json)
access_token = await GoogleAuthUtils.get_google_auth_jwt(service_account_info=service_account_info,
scopes=['https://www.googleapis.com/auth/cloud-platform'])
return access_token
@router.post("/make_grid_gemini", summary="将输入图拼为网格上传到Gemini网盘")
async def make_grid_gemini_upload(data: MakeGridGeminiRequest, headers: Annotated[BundleHeaders, Header()]):
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
if not google_api_key:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
fn = modal.Function.from_name(config.modal_app_name,"make_image_grid_upload", environment_name=config.modal_environment)
image_grid_uri = await fn.remote.aio(data.pic_info_list, data.image_size, data.text_height, data.font_size, data.padding, data.separator, google_api_key,
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), x_baggage=sentry_sdk.get_baggage())
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage))
return JSONResponse(content={"uri":image_grid_uri}, status_code=200)
fn = modal.Function.from_name(config.modal_app_name, "make_image_grid_upload",
environment_name=config.modal_environment)
image_grid_uri = await fn.remote.aio(data.pic_info_list, data.image_size, data.text_height, data.font_size,
data.padding, data.separator, google_api_key,
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(),
x_baggage=sentry_sdk.get_baggage())
if headers.x_trace_id is None else SentryTransactionInfo(
x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage))
return JSONResponse(content={"uri": image_grid_uri}, status_code=200)
@router.post('/inference_gemini', summary="使用Gemini推理hls视频流指定时间段的打点情况")
async def inference_gemini(
@ -151,10 +173,15 @@ async def inference_gemini(
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
if not google_api_key:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
fn = modal.Function.from_name(config.modal_app_name,"video_hls_slice_inference", environment_name=config.modal_environment)
fn_call = fn.spawn(data.media_hls_url, google_api_key, data.product_cover_grid_uri_list, data.product_list, data.start_time, data.end_time, data.options,
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), x_baggage=sentry_sdk.get_baggage())
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage), data.webhook, 2, data.scale)
fn = modal.Function.from_name(config.modal_app_name, "video_hls_slice_inference",
environment_name=config.modal_environment)
fn_call = fn.spawn(data.media_hls_url, google_api_key, data.product_cover_grid_uri_list, data.product_list,
data.start_time, data.end_time, data.options,
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(),
x_baggage=sentry_sdk.get_baggage())
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id,
x_baggage=headers.x_baggage),
data.webhook, 2, data.scale)
return ModalTaskResponse(success=True, taskId=fn_call.object_id)
@ -166,10 +193,11 @@ async def gemini_status(task_id: str, response: Response) -> GeminiResultRespons
response.headers["x-baggage"] = task_info.transaction.x_baggage
try:
return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value),
error=task_info.error_reason, result=json.dumps(task_info.results, ensure_ascii=False)
error=task_info.error_reason,
result=json.dumps(task_info.results, ensure_ascii=False)
if task_info.results is not None else "")
except Exception as e:
logger.exception(f"获取Gemini状态发生错误 {e}")
return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value),
error=task_info.error_reason + f"获取Gemini状态发生错误 {e}"
if task_info.error_reason else f"获取Gemini状态发生错误 {e}", result="")
if task_info.error_reason else f"获取Gemini状态发生错误 {e}", result="")

View File

@ -1,5 +1,5 @@
from typing import Optional, Any, Union
from functools import wraps
from typing import Union, Dict, Any
import backoff
import httpx
import asyncio
@ -7,6 +7,8 @@ from loguru import logger
import aiofiles
from pathlib import Path
from pydantic import BaseModel
class HTTPDownloadUtils:
# 创建一个类级别的 Semaphore 来控制并发
@ -98,3 +100,29 @@ class HTTPDownloadUtils:
tasks.append(task)
return await asyncio.gather(*tasks, return_exceptions=True)
from google.oauth2 import service_account
class GoogleAuthUtils:
class GoogleAuthResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
@staticmethod
async def get_google_auth_jwt(service_account_info: Dict[str, str], scopes: list[str]) -> GoogleAuthResponse:
credentials: service_account.Credentials = service_account.Credentials.from_service_account_info(
service_account_info, scopes=scopes)
assertion = credentials._make_authorization_grant_assertion().decode('utf-8')
logger.info(f"jwt_assertion: {assertion}")
params = {
'grant_type': "urn:ietf:params:oauth:grant-type:jwt-bearer",
'assertion': assertion
}
response = httpx.post(url="https://oauth2.googleapis.com/token",
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data=params)
response.raise_for_status()
return GoogleAuthUtils.GoogleAuthResponse.model_validate_json(response.text)