ADD Gemini接口改为可不传参考图

This commit is contained in:
kyj@bowong.ai 2025-07-21 18:35:57 +08:00
parent aaae9060df
commit 1a134cd779
2 changed files with 47 additions and 41 deletions

View File

@ -7,6 +7,7 @@ from typing import Annotated, Optional, cast
import httpx
import modal
import sentry_sdk
from aiohttp.web_fileresponse import content_type
from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends, File, Form
from fastapi.responses import Response
from loguru import logger
@ -303,17 +304,18 @@ async def clothes_mark(
return ModalTaskResponse(success=True, taskId=fn_call.object_id)
@router.post("/image/edit_custom", summary="使用Gemini进行图像编辑")
@router.post("/image/edit_custom", summary="使用Gemini进行图像生成/编辑")
async def clothes_mark(
headers: Annotated[BundleHeaders, Header()],
origin_image: Annotated[UploadFile, File(description="待处理图片")],
prompt: Annotated[str, Form(description="图像编辑提示词")],
prompt: Annotated[str, Form(description="图像生成/编辑提示词")],
origin_image: Annotated[UploadFile, File(description="待处理图片")] = None,
temperature: Annotated[float, Form()] = 0.01,
topP: Annotated[float, Form()] = 0.7,
) -> ModalTaskResponse:
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
if not google_api_key:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
if origin_image is not None:
content_type = origin_image.content_type
if content_type not in SUPPORTED_IMAGE_TYPES:
raise HTTPException(
@ -339,6 +341,8 @@ async def clothes_mark(
logger.exception("保存图片失败")
raise e
uri = await upload(temp_file_path, content_type, google_api_key)
else:
uri, content_type = None, None
fn = modal.Function.from_name(config.modal_app_name, "image_edit_custom",
environment_name=config.modal_environment)
fn_call = fn.spawn(google_api_key, uri, content_type, prompt, temperature, topP)

View File

@ -132,12 +132,7 @@ with (downloader_image.imports()):
)
try:
logger.info("🐬开始处理图片")
resp, resp_code = client.generate_content(model_id="gemini-2.0-flash-preview-image-generation",
contents=[types.Content(role='user',
parts = [
types.Part.from_uri(
file_uri=origin_image_uri,
mime_type=content_type),
types.Part.from_text(
text="<prompt>"
"<instruction>"
@ -148,6 +143,13 @@ with (downloader_image.imports()):
)
)
]
if origin_image_uri is not None:
parts.insert(0, types.Part.from_uri(
file_uri=origin_image_uri,
mime_type=content_type))
resp, resp_code = client.generate_content(model_id="gemini-2.0-flash-preview-image-generation",
contents=[types.Content(role='user',
parts=parts
)
],
config=types.GenerateContentConfig.model_validate({