diff --git a/src/BowongModalFunctions/models/ffmpeg_worker_model.py b/src/BowongModalFunctions/models/ffmpeg_worker_model.py index 310173f..ed28886 100644 --- a/src/BowongModalFunctions/models/ffmpeg_worker_model.py +++ b/src/BowongModalFunctions/models/ffmpeg_worker_model.py @@ -1,5 +1,5 @@ from typing import Union, Any -from pydantic import BaseModel, Field, computed_field, field_validator, model_validator, ConfigDict +from pydantic import BaseModel, Field, computed_field, field_validator, model_validator from pydantic.json_schema import JsonSchemaValue from ..utils.TimeUtils import TimeDelta diff --git a/src/BowongModalFunctions/router/google.py b/src/BowongModalFunctions/router/google.py index 8516358..fb6775e 100644 --- a/src/BowongModalFunctions/router/google.py +++ b/src/BowongModalFunctions/router/google.py @@ -7,15 +7,17 @@ import sentry_sdk from fastapi.responses import Response from loguru import logger import httpx -from fastapi import APIRouter, UploadFile, Header, HTTPException +from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends from pydantic import BaseModel, Field from starlette import status from starlette.responses import JSONResponse from BowongModalFunctions.config import WorkerConfig +from BowongModalFunctions.middleware.authorization import verify_token from BowongModalFunctions.models.web_model import SentryTransactionInfo, GeminiResultResponse, GeminiRequest, \ ModalTaskResponse, MakeGridGeminiRequest from BowongModalFunctions.utils.ModalUtils import ModalUtils +from BowongModalFunctions.utils.HTTPUtils import GoogleAuthUtils config = WorkerConfig() @@ -25,6 +27,7 @@ router = APIRouter(prefix="/google", tags=["Google"]) class GoogleAPIKeyHeaders(BaseModel): x_google_api_key: Optional[str] = Field(description="Google API Key") + class BundleHeaders(BaseModel): x_google_api_key: Optional[str] = Field(description="Google API Key") x_trace_id: str = Field(description="Sentry Transaction ID", default=None) @@ -101,6 +104,7 @@ async def delete_file(filename: str, headers: Annotated[GoogleAPIKeyHeaders, Hea response.raise_for_status() return JSONResponse(content=response.json(), status_code=response.status_code) + @router.delete('/delete_all', summary="删除所有已上传的文件/第一页") async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]): google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY") @@ -120,6 +124,7 @@ async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]): return JSONResponse(content={"msg": f"删除文件{filename}失败"}, status_code=resp.status_code) return JSONResponse(content=response.json(), status_code=response.status_code) + @router.get('/list', summary="列出已上传的文件") async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]): google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY") @@ -132,16 +137,33 @@ async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]): return JSONResponse(content={}, status_code=response.status_code) return JSONResponse(content=response.json() if len(response.text) > 0 else "", status_code=response.status_code) + +@router.get('/access-token', summary="获取一个新的Google Access Token", dependencies=[Depends(verify_token)]) +async def get_access_token() -> GoogleAuthUtils.GoogleAuthResponse: + service_account_json = os.environ.get("GOOGLE_AUTH_JSON", None) + if not service_account_json: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google IAM Service Account JSON") + service_account_info = json.loads(service_account_json) + access_token = await GoogleAuthUtils.get_google_auth_jwt(service_account_info=service_account_info, + scopes=['https://www.googleapis.com/auth/cloud-platform']) + return access_token + + @router.post("/make_grid_gemini", summary="将输入图拼为网格上传到Gemini网盘") async def make_grid_gemini_upload(data: MakeGridGeminiRequest, headers: Annotated[BundleHeaders, Header()]): google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY") if not google_api_key: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key") - fn = modal.Function.from_name(config.modal_app_name,"make_image_grid_upload", environment_name=config.modal_environment) - image_grid_uri = await fn.remote.aio(data.pic_info_list, data.image_size, data.text_height, data.font_size, data.padding, data.separator, google_api_key, - SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), x_baggage=sentry_sdk.get_baggage()) - if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage)) - return JSONResponse(content={"uri":image_grid_uri}, status_code=200) + fn = modal.Function.from_name(config.modal_app_name, "make_image_grid_upload", + environment_name=config.modal_environment) + image_grid_uri = await fn.remote.aio(data.pic_info_list, data.image_size, data.text_height, data.font_size, + data.padding, data.separator, google_api_key, + SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), + x_baggage=sentry_sdk.get_baggage()) + if headers.x_trace_id is None else SentryTransactionInfo( + x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage)) + return JSONResponse(content={"uri": image_grid_uri}, status_code=200) + @router.post('/inference_gemini', summary="使用Gemini推理hls视频流指定时间段的打点情况") async def inference_gemini( @@ -151,10 +173,15 @@ async def inference_gemini( google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY") if not google_api_key: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key") - fn = modal.Function.from_name(config.modal_app_name,"video_hls_slice_inference", environment_name=config.modal_environment) - fn_call = fn.spawn(data.media_hls_url, google_api_key, data.product_cover_grid_uri_list, data.product_list, data.start_time, data.end_time, data.options, - SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), x_baggage=sentry_sdk.get_baggage()) - if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage), data.webhook, 2, data.scale) + fn = modal.Function.from_name(config.modal_app_name, "video_hls_slice_inference", + environment_name=config.modal_environment) + fn_call = fn.spawn(data.media_hls_url, google_api_key, data.product_cover_grid_uri_list, data.product_list, + data.start_time, data.end_time, data.options, + SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), + x_baggage=sentry_sdk.get_baggage()) + if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id, + x_baggage=headers.x_baggage), + data.webhook, 2, data.scale) return ModalTaskResponse(success=True, taskId=fn_call.object_id) @@ -166,10 +193,11 @@ async def gemini_status(task_id: str, response: Response) -> GeminiResultRespons response.headers["x-baggage"] = task_info.transaction.x_baggage try: return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value), - error=task_info.error_reason, result=json.dumps(task_info.results, ensure_ascii=False) + error=task_info.error_reason, + result=json.dumps(task_info.results, ensure_ascii=False) if task_info.results is not None else "") except Exception as e: logger.exception(f"获取Gemini状态发生错误 {e}") return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value), error=task_info.error_reason + f"获取Gemini状态发生错误 {e}" - if task_info.error_reason else f"获取Gemini状态发生错误 {e}", result="") \ No newline at end of file + if task_info.error_reason else f"获取Gemini状态发生错误 {e}", result="") diff --git a/src/BowongModalFunctions/utils/HTTPUtils.py b/src/BowongModalFunctions/utils/HTTPUtils.py index 4ea07a3..e885c56 100644 --- a/src/BowongModalFunctions/utils/HTTPUtils.py +++ b/src/BowongModalFunctions/utils/HTTPUtils.py @@ -1,5 +1,5 @@ -from typing import Optional, Any, Union -from functools import wraps +from typing import Union, Dict, Any + import backoff import httpx import asyncio @@ -7,6 +7,8 @@ from loguru import logger import aiofiles from pathlib import Path +from pydantic import BaseModel + class HTTPDownloadUtils: # 创建一个类级别的 Semaphore 来控制并发 @@ -98,3 +100,29 @@ class HTTPDownloadUtils: tasks.append(task) return await asyncio.gather(*tasks, return_exceptions=True) + + +from google.oauth2 import service_account + + +class GoogleAuthUtils: + class GoogleAuthResponse(BaseModel): + access_token: str + expires_in: int + token_type: str + + @staticmethod + async def get_google_auth_jwt(service_account_info: Dict[str, str], scopes: list[str]) -> GoogleAuthResponse: + credentials: service_account.Credentials = service_account.Credentials.from_service_account_info( + service_account_info, scopes=scopes) + assertion = credentials._make_authorization_grant_assertion().decode('utf-8') + logger.info(f"jwt_assertion: {assertion}") + params = { + 'grant_type': "urn:ietf:params:oauth:grant-type:jwt-bearer", + 'assertion': assertion + } + response = httpx.post(url="https://oauth2.googleapis.com/token", + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + data=params) + response.raise_for_status() + return GoogleAuthUtils.GoogleAuthResponse.model_validate_json(response.text)