新增Google Vertex AI的access token获取接口
This commit is contained in:
parent
79c462d0e4
commit
ad41b80f4f
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Union, Any
|
from typing import Union, Any
|
||||||
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator, ConfigDict
|
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator
|
||||||
from pydantic.json_schema import JsonSchemaValue
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
from ..utils.TimeUtils import TimeDelta
|
from ..utils.TimeUtils import TimeDelta
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,15 +7,17 @@ import sentry_sdk
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, UploadFile, Header, HTTPException
|
from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from starlette import status
|
from starlette import status
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
from BowongModalFunctions.config import WorkerConfig
|
from BowongModalFunctions.config import WorkerConfig
|
||||||
|
from BowongModalFunctions.middleware.authorization import verify_token
|
||||||
from BowongModalFunctions.models.web_model import SentryTransactionInfo, GeminiResultResponse, GeminiRequest, \
|
from BowongModalFunctions.models.web_model import SentryTransactionInfo, GeminiResultResponse, GeminiRequest, \
|
||||||
ModalTaskResponse, MakeGridGeminiRequest
|
ModalTaskResponse, MakeGridGeminiRequest
|
||||||
from BowongModalFunctions.utils.ModalUtils import ModalUtils
|
from BowongModalFunctions.utils.ModalUtils import ModalUtils
|
||||||
|
from BowongModalFunctions.utils.HTTPUtils import GoogleAuthUtils
|
||||||
|
|
||||||
config = WorkerConfig()
|
config = WorkerConfig()
|
||||||
|
|
||||||
|
|
@ -25,6 +27,7 @@ router = APIRouter(prefix="/google", tags=["Google"])
|
||||||
class GoogleAPIKeyHeaders(BaseModel):
|
class GoogleAPIKeyHeaders(BaseModel):
|
||||||
x_google_api_key: Optional[str] = Field(description="Google API Key")
|
x_google_api_key: Optional[str] = Field(description="Google API Key")
|
||||||
|
|
||||||
|
|
||||||
class BundleHeaders(BaseModel):
|
class BundleHeaders(BaseModel):
|
||||||
x_google_api_key: Optional[str] = Field(description="Google API Key")
|
x_google_api_key: Optional[str] = Field(description="Google API Key")
|
||||||
x_trace_id: str = Field(description="Sentry Transaction ID", default=None)
|
x_trace_id: str = Field(description="Sentry Transaction ID", default=None)
|
||||||
|
|
@ -101,6 +104,7 @@ async def delete_file(filename: str, headers: Annotated[GoogleAPIKeyHeaders, Hea
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return JSONResponse(content=response.json(), status_code=response.status_code)
|
return JSONResponse(content=response.json(), status_code=response.status_code)
|
||||||
|
|
||||||
|
|
||||||
@router.delete('/delete_all', summary="删除所有已上传的文件/第一页")
|
@router.delete('/delete_all', summary="删除所有已上传的文件/第一页")
|
||||||
async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
|
async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
|
||||||
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
||||||
|
|
@ -120,6 +124,7 @@ async def delete_all(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
|
||||||
return JSONResponse(content={"msg": f"删除文件{filename}失败"}, status_code=resp.status_code)
|
return JSONResponse(content={"msg": f"删除文件{filename}失败"}, status_code=resp.status_code)
|
||||||
return JSONResponse(content=response.json(), status_code=response.status_code)
|
return JSONResponse(content=response.json(), status_code=response.status_code)
|
||||||
|
|
||||||
|
|
||||||
@router.get('/list', summary="列出已上传的文件")
|
@router.get('/list', summary="列出已上传的文件")
|
||||||
async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
|
async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
|
||||||
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
||||||
|
|
@ -132,16 +137,33 @@ async def list_files(headers: Annotated[GoogleAPIKeyHeaders, Header()]):
|
||||||
return JSONResponse(content={}, status_code=response.status_code)
|
return JSONResponse(content={}, status_code=response.status_code)
|
||||||
return JSONResponse(content=response.json() if len(response.text) > 0 else "", status_code=response.status_code)
|
return JSONResponse(content=response.json() if len(response.text) > 0 else "", status_code=response.status_code)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/access-token', summary="获取一个新的Google Access Token", dependencies=[Depends(verify_token)])
|
||||||
|
async def get_access_token() -> GoogleAuthUtils.GoogleAuthResponse:
|
||||||
|
service_account_json = os.environ.get("GOOGLE_AUTH_JSON", None)
|
||||||
|
if not service_account_json:
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google IAM Service Account JSON")
|
||||||
|
service_account_info = json.loads(service_account_json)
|
||||||
|
access_token = await GoogleAuthUtils.get_google_auth_jwt(service_account_info=service_account_info,
|
||||||
|
scopes=['https://www.googleapis.com/auth/cloud-platform'])
|
||||||
|
return access_token
|
||||||
|
|
||||||
|
|
||||||
@router.post("/make_grid_gemini", summary="将输入图拼为网格上传到Gemini网盘")
|
@router.post("/make_grid_gemini", summary="将输入图拼为网格上传到Gemini网盘")
|
||||||
async def make_grid_gemini_upload(data: MakeGridGeminiRequest, headers: Annotated[BundleHeaders, Header()]):
|
async def make_grid_gemini_upload(data: MakeGridGeminiRequest, headers: Annotated[BundleHeaders, Header()]):
|
||||||
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
||||||
if not google_api_key:
|
if not google_api_key:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
|
||||||
fn = modal.Function.from_name(config.modal_app_name,"make_image_grid_upload", environment_name=config.modal_environment)
|
fn = modal.Function.from_name(config.modal_app_name, "make_image_grid_upload",
|
||||||
image_grid_uri = await fn.remote.aio(data.pic_info_list, data.image_size, data.text_height, data.font_size, data.padding, data.separator, google_api_key,
|
environment_name=config.modal_environment)
|
||||||
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), x_baggage=sentry_sdk.get_baggage())
|
image_grid_uri = await fn.remote.aio(data.pic_info_list, data.image_size, data.text_height, data.font_size,
|
||||||
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage))
|
data.padding, data.separator, google_api_key,
|
||||||
return JSONResponse(content={"uri":image_grid_uri}, status_code=200)
|
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(),
|
||||||
|
x_baggage=sentry_sdk.get_baggage())
|
||||||
|
if headers.x_trace_id is None else SentryTransactionInfo(
|
||||||
|
x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage))
|
||||||
|
return JSONResponse(content={"uri": image_grid_uri}, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
@router.post('/inference_gemini', summary="使用Gemini推理hls视频流指定时间段的打点情况")
|
@router.post('/inference_gemini', summary="使用Gemini推理hls视频流指定时间段的打点情况")
|
||||||
async def inference_gemini(
|
async def inference_gemini(
|
||||||
|
|
@ -151,10 +173,15 @@ async def inference_gemini(
|
||||||
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
||||||
if not google_api_key:
|
if not google_api_key:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
|
||||||
fn = modal.Function.from_name(config.modal_app_name,"video_hls_slice_inference", environment_name=config.modal_environment)
|
fn = modal.Function.from_name(config.modal_app_name, "video_hls_slice_inference",
|
||||||
fn_call = fn.spawn(data.media_hls_url, google_api_key, data.product_cover_grid_uri_list, data.product_list, data.start_time, data.end_time, data.options,
|
environment_name=config.modal_environment)
|
||||||
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(), x_baggage=sentry_sdk.get_baggage())
|
fn_call = fn.spawn(data.media_hls_url, google_api_key, data.product_cover_grid_uri_list, data.product_list,
|
||||||
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id, x_baggage=headers.x_baggage), data.webhook, 2, data.scale)
|
data.start_time, data.end_time, data.options,
|
||||||
|
SentryTransactionInfo(x_trace_id=sentry_sdk.get_traceparent(),
|
||||||
|
x_baggage=sentry_sdk.get_baggage())
|
||||||
|
if headers.x_trace_id is None else SentryTransactionInfo(x_trace_id=headers.x_trace_id,
|
||||||
|
x_baggage=headers.x_baggage),
|
||||||
|
data.webhook, 2, data.scale)
|
||||||
return ModalTaskResponse(success=True, taskId=fn_call.object_id)
|
return ModalTaskResponse(success=True, taskId=fn_call.object_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -166,7 +193,8 @@ async def gemini_status(task_id: str, response: Response) -> GeminiResultRespons
|
||||||
response.headers["x-baggage"] = task_info.transaction.x_baggage
|
response.headers["x-baggage"] = task_info.transaction.x_baggage
|
||||||
try:
|
try:
|
||||||
return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value),
|
return GeminiResultResponse(taskId=task_id, status=task_info.status, code=cast(int, task_info.error_code.value),
|
||||||
error=task_info.error_reason, result=json.dumps(task_info.results, ensure_ascii=False)
|
error=task_info.error_reason,
|
||||||
|
result=json.dumps(task_info.results, ensure_ascii=False)
|
||||||
if task_info.results is not None else "")
|
if task_info.results is not None else "")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"获取Gemini状态发生错误 {e}")
|
logger.exception(f"获取Gemini状态发生错误 {e}")
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Optional, Any, Union
|
from typing import Union, Dict, Any
|
||||||
from functools import wraps
|
|
||||||
import backoff
|
import backoff
|
||||||
import httpx
|
import httpx
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
@ -7,6 +7,8 @@ from loguru import logger
|
||||||
import aiofiles
|
import aiofiles
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class HTTPDownloadUtils:
|
class HTTPDownloadUtils:
|
||||||
# 创建一个类级别的 Semaphore 来控制并发
|
# 创建一个类级别的 Semaphore 来控制并发
|
||||||
|
|
@ -98,3 +100,29 @@ class HTTPDownloadUtils:
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
|
from google.oauth2 import service_account
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleAuthUtils:
|
||||||
|
class GoogleAuthResponse(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
expires_in: int
|
||||||
|
token_type: str
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_google_auth_jwt(service_account_info: Dict[str, str], scopes: list[str]) -> GoogleAuthResponse:
|
||||||
|
credentials: service_account.Credentials = service_account.Credentials.from_service_account_info(
|
||||||
|
service_account_info, scopes=scopes)
|
||||||
|
assertion = credentials._make_authorization_grant_assertion().decode('utf-8')
|
||||||
|
logger.info(f"jwt_assertion: {assertion}")
|
||||||
|
params = {
|
||||||
|
'grant_type': "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||||
|
'assertion': assertion
|
||||||
|
}
|
||||||
|
response = httpx.post(url="https://oauth2.googleapis.com/token",
|
||||||
|
headers={'Content-Type': 'application/x-www-form-urlencoded'},
|
||||||
|
data=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
return GoogleAuthUtils.GoogleAuthResponse.model_validate_json(response.text)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue