新增Google Vertex AI的access token获取接口
This commit is contained in:
parent
79c462d0e4
commit
ad41b80f4f
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Union, Any
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator, ConfigDict
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from ..utils.TimeUtils import TimeDelta
|
||||
|
||||
|
|
|
|||
|
|
@ -7,15 +7,17 @@ import sentry_sdk
|
|||
from fastapi.responses import Response
|
||||
from loguru import logger
|
||||
import httpx
|
||||
from fastapi import APIRouter, UploadFile, Header, HTTPException
|
||||
from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette import status
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from BowongModalFunctions.config import WorkerConfig
|
||||
from BowongModalFunctions.middleware.authorization import verify_token
|
||||
from BowongModalFunctions.models.web_model import SentryTransactionInfo, GeminiResultResponse, GeminiRequest, \
|
||||
ModalTaskResponse, MakeGridGeminiRequest
|
||||
from BowongModalFunctions.utils.ModalUtils import ModalUtils
|
||||
from BowongModalFunctions.utils.HTTPUtils import GoogleAuthUtils
|
||||
|
||||
config = WorkerConfig()
|
||||
|
||||
|
|
@ -25,6 +27,7 @@ router = APIRouter(prefix="/google", tags=["Google"])
|
|||
class GoogleAPIKeyHeaders(BaseModel):
|
||||
x_google_api_key: Optional[str] = Field(description="Google API Key")
|
||||
|
||||
|
||||
class BundleHeaders(BaseModel):
|
||||
x_google_api_key: Optional[str] = Field(description="Google API Key")
|
||||
x_trace_id: str = Field(description="Sentry Transaction ID", default=None)
|
||||
|
|
@ -101,6 +104,7 @@ async def delete_file(filename: str, headers: Annotated[GoogleAPIKeyHeaders, Hea
|
|||
response.raise_for_status()
|
||||
return JSONResponse(content=response.json(), status_code=response.status_code)
|
||||
|
||||
|
||||
@router.delete('/delete_all', summary="删除所有已上传的文件/第一页")
|
||||
async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
|
||||
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
||||
|
|
@ -120,6 +124,7 @@ async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
|
|||
return JSONResponse(content={"msg": f"删除文件{filename}失败"}, status_code=resp.status_code)
|
||||
return JSONResponse(content=response.json(), status_code=response.status_code)
|
||||
|
||||
|
||||
@router.get('/list', summary="列出已上传的文件")
|
||||
async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
|
||||
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
||||
|
|
@ -132,16 +137,33 @@ async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
|
|||
return JSONResponse(content={}, status_code=response.status_code)
|
||||
return JSONResponse(content=response.json() if len(response.text) > 0 else "", status_code=response.status_code)
|
||||
|
||||
|
||||
@router.get('/access-token', summary="获取一个新的Google Access Token", dependencies=[Depends(verify_token)])
|
||||
async def get_access_token() -> GoogleAuthUtils.GoogleAuthResponse:
|
||||
service_account_json = os.environ.get("GOOGLE_AUTH_JSON", None)
|
||||
if not service_account_json:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google IAM Service Account JSON")
|
||||
service_account_info = json.loads(service_account_json)
|
||||
access_token = await GoogleAuthUtils.get_google_auth_jwt(service_account_info=service_account_info,
|
||||
scopes=['https://www.googleapis.com/auth/cloud-platform'])
|
||||
return access_token
|
||||
|
||||
|
||||
@router.post("/make_grid_gemini", summary="将输入图拼为网格上传到Gemini网盘")
|
||||
async def make_grid_gemini_upload(data: MakeGridGeminiRequest, headers: Annotated[BundleHeaders, Header()]):
|
||||
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
||||
if not google_api_key:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
|
||||
fn = modal.Function.from_name(config.modal_app_name,"make_image_grid_upload", environment_name=config.modal_environment)
|
||||
image_grid_uri = await fn.remote.aio(data.pic_info_list, data.image_size, data.text_height, data.font_size, data.padding, data.separator, google_api_key,
|
||||
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), x_baggage=sentry_sdk.get_baggage())
|
||||
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage))
|
||||
return JSONResponse(content={"uri":image_grid_uri}, status_code=200)
|
||||
fn = modal.Function.from_name(config.modal_app_name, "make_image_grid_upload",
|
||||
environment_name=config.modal_environment)
|
||||
image_grid_uri = await fn.remote.aio(data.pic_info_list, data.image_size, data.text_height, data.font_size,
|
||||
data.padding, data.separator, google_api_key,
|
||||
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(),
|
||||
x_baggage=sentry_sdk.get_baggage())
|
||||
if headers.x_trace_id is None else SentryTransactionInfo(
|
||||
x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage))
|
||||
return JSONResponse(content={"uri": image_grid_uri}, status_code=200)
|
||||
|
||||
|
||||
@router.post('/inference_gemini', summary="使用Gemini推理hls视频流指定时间段的打点情况")
|
||||
async def inference_gemini(
|
||||
|
|
@ -151,10 +173,15 @@ async def inference_gemini(
|
|||
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
||||
if not google_api_key:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
|
||||
fn = modal.Function.from_name(config.modal_app_name,"video_hls_slice_inference", environment_name=config.modal_environment)
|
||||
fn_call = fn.spawn(data.media_hls_url, google_api_key, data.product_cover_grid_uri_list, data.product_list, data.start_time, data.end_time, data.options,
|
||||
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), x_baggage=sentry_sdk.get_baggage())
|
||||
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage), data.webhook, 2, data.scale)
|
||||
fn = modal.Function.from_name(config.modal_app_name, "video_hls_slice_inference",
|
||||
environment_name=config.modal_environment)
|
||||
fn_call = fn.spawn(data.media_hls_url, google_api_key, data.product_cover_grid_uri_list, data.product_list,
|
||||
data.start_time, data.end_time, data.options,
|
||||
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(),
|
||||
x_baggage=sentry_sdk.get_baggage())
|
||||
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id,
|
||||
x_baggage=headers.x_baggage),
|
||||
data.webhook, 2, data.scale)
|
||||
return ModalTaskResponse(success=True, taskId=fn_call.object_id)
|
||||
|
||||
|
||||
|
|
@ -166,10 +193,11 @@ async def gemini_status(task_id: str, response: Response) -> GeminiResultRespons
|
|||
response.headers["x-baggage"] = task_info.transaction.x_baggage
|
||||
try:
|
||||
return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value),
|
||||
error=task_info.error_reason, result=json.dumps(task_info.results, ensure_ascii=False)
|
||||
error=task_info.error_reason,
|
||||
result=json.dumps(task_info.results, ensure_ascii=False)
|
||||
if task_info.results is not None else "")
|
||||
except Exception as e:
|
||||
logger.exception(f"获取Gemini状态发生错误 {e}")
|
||||
return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value),
|
||||
error=task_info.error_reason + f"获取Gemini状态发生错误 {e}"
|
||||
if task_info.error_reason else f"获取Gemini状态发生错误 {e}", result="")
|
||||
if task_info.error_reason else f"获取Gemini状态发生错误 {e}", result="")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Optional, Any, Union
|
||||
from functools import wraps
|
||||
from typing import Union, Dict, Any
|
||||
|
||||
import backoff
|
||||
import httpx
|
||||
import asyncio
|
||||
|
|
@ -7,6 +7,8 @@ from loguru import logger
|
|||
import aiofiles
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HTTPDownloadUtils:
|
||||
# 创建一个类级别的 Semaphore 来控制并发
|
||||
|
|
@ -98,3 +100,29 @@ class HTTPDownloadUtils:
|
|||
tasks.append(task)
|
||||
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
from google.oauth2 import service_account
|
||||
|
||||
|
||||
class GoogleAuthUtils:
|
||||
class GoogleAuthResponse(BaseModel):
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
|
||||
@staticmethod
|
||||
async def get_google_auth_jwt(service_account_info: Dict[str, str], scopes: list[str]) -> GoogleAuthResponse:
|
||||
credentials: service_account.Credentials = service_account.Credentials.from_service_account_info(
|
||||
service_account_info, scopes=scopes)
|
||||
assertion = credentials._make_authorization_grant_assertion().decode('utf-8')
|
||||
logger.info(f"jwt_assertion: {assertion}")
|
||||
params = {
|
||||
'grant_type': "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||
'assertion': assertion
|
||||
}
|
||||
response = httpx.post(url="https://oauth2.googleapis.com/token",
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'},
|
||||
data=params)
|
||||
response.raise_for_status()
|
||||
return GoogleAuthUtils.GoogleAuthResponse.model_validate_json(response.text)
|
||||
|
|
|
|||
Loading…
Reference in New Issue