ADD Gemini接口改为可不传参考图

This commit is contained in:
kyj@bowong.ai 2025-07-21 18:35:57 +08:00
parent aaae9060df
commit 1a134cd779
2 changed files with 47 additions and 41 deletions

View File

@@ -7,6 +7,7 @@ from typing import Annotated, Optional, cast
import httpx import httpx
import modal import modal
import sentry_sdk import sentry_sdk
from aiohttp.web_fileresponse import content_type
from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends, File, Form from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends, File, Form
from fastapi.responses import Response from fastapi.responses import Response
from loguru import logger from loguru import logger
@@ -303,17 +304,18 @@ async def clothes_mark(
return ModalTaskResponse(success=True, taskId=fn_call.object_id) return ModalTaskResponse(success=True, taskId=fn_call.object_id)
@router.post("/image/edit_custom", summary="使用Gemini进行图像编辑") @router.post("/image/edit_custom", summary="使用Gemini进行图像生成/编辑")
async def clothes_mark( async def clothes_mark(
headers: Annotated[BundleHeaders, Header()], headers: Annotated[BundleHeaders, Header()],
origin_image: Annotated[UploadFile, File(description="待处理图片")], prompt: Annotated[str, Form(description="图像生成/编辑提示词")],
prompt: Annotated[str, Form(description="图像编辑提示词")], origin_image: Annotated[UploadFile, File(description="待处理图片")] = None,
temperature: Annotated[float, Form()] = 0.01, temperature: Annotated[float, Form()] = 0.01,
topP: Annotated[float, Form()] = 0.7, topP: Annotated[float, Form()] = 0.7,
) -> ModalTaskResponse: ) -> ModalTaskResponse:
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY") google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
if not google_api_key: if not google_api_key:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
if origin_image is not None:
content_type = origin_image.content_type content_type = origin_image.content_type
if content_type not in SUPPORTED_IMAGE_TYPES: if content_type not in SUPPORTED_IMAGE_TYPES:
raise HTTPException( raise HTTPException(
@@ -339,6 +341,8 @@ async def clothes_mark(
logger.exception("保存图片失败") logger.exception("保存图片失败")
raise e raise e
uri = await upload(temp_file_path, content_type, google_api_key) uri = await upload(temp_file_path, content_type, google_api_key)
else:
uri, content_type = None, None
fn = modal.Function.from_name(config.modal_app_name, "image_edit_custom", fn = modal.Function.from_name(config.modal_app_name, "image_edit_custom",
environment_name=config.modal_environment) environment_name=config.modal_environment)
fn_call = fn.spawn(google_api_key, uri, content_type, prompt, temperature, topP) fn_call = fn.spawn(google_api_key, uri, content_type, prompt, temperature, topP)

View File

@@ -132,12 +132,7 @@ with (downloader_image.imports()):
) )
try: try:
logger.info("🐬开始处理图片") logger.info("🐬开始处理图片")
resp, resp_code = client.generate_content(model_id="gemini-2.0-flash-preview-image-generation", parts = [
contents=[types.Content(role='user',
parts=[
types.Part.from_uri(
file_uri=origin_image_uri,
mime_type=content_type),
types.Part.from_text( types.Part.from_text(
text="<prompt>" text="<prompt>"
"<instruction>" "<instruction>"
@@ -148,6 +143,13 @@ with (downloader_image.imports()):
) )
) )
] ]
if origin_image_uri is not None:
parts.insert(0, types.Part.from_uri(
file_uri=origin_image_uri,
mime_type=content_type))
resp, resp_code = client.generate_content(model_id="gemini-2.0-flash-preview-image-generation",
contents=[types.Content(role='user',
parts=parts
) )
], ],
config=types.GenerateContentConfig.model_validate({ config=types.GenerateContentConfig.model_validate({