ADD Gemini接口改为可不传参考图

This commit is contained in:
kyj@bowong.ai 2025-07-21 18:35:57 +08:00
parent aaae9060df
commit 1a134cd779
2 changed files with 47 additions and 41 deletions

View File

@ -7,6 +7,7 @@ from typing import Annotated, Optional, cast
import httpx import httpx
import modal import modal
import sentry_sdk import sentry_sdk
from aiohttp.web_fileresponse import content_type
from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends, File, Form from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends, File, Form
from fastapi.responses import Response from fastapi.responses import Response
from loguru import logger from loguru import logger
@ -303,42 +304,45 @@ async def clothes_mark(
return ModalTaskResponse(success=True, taskId=fn_call.object_id) return ModalTaskResponse(success=True, taskId=fn_call.object_id)
@router.post("/image/edit_custom", summary="使用Gemini进行图像编辑") @router.post("/image/edit_custom", summary="使用Gemini进行图像生成/编辑")
async def clothes_mark( async def clothes_mark(
headers: Annotated[BundleHeaders, Header()], headers: Annotated[BundleHeaders, Header()],
origin_image: Annotated[UploadFile, File(description="待处理图片")], prompt: Annotated[str, Form(description="图像生成/编辑提示词")],
prompt: Annotated[str, Form(description="图像编辑提示词")], origin_image: Annotated[UploadFile, File(description="待处理图片")] = None,
temperature: Annotated[float, Form()] = 0.01, temperature: Annotated[float, Form()] = 0.01,
topP: Annotated[float, Form()] = 0.7, topP: Annotated[float, Form()] = 0.7,
) -> ModalTaskResponse: ) -> ModalTaskResponse:
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY") google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
if not google_api_key: if not google_api_key:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
content_type = origin_image.content_type if origin_image is not None:
if content_type not in SUPPORTED_IMAGE_TYPES: content_type = origin_image.content_type
raise HTTPException( if content_type not in SUPPORTED_IMAGE_TYPES:
status_code=400, raise HTTPException(
detail=f"不支持的文件类型: {content_type}。支持的类型: {list(SUPPORTED_IMAGE_TYPES.keys())}" status_code=400,
) detail=f"不支持的文件类型: {content_type}。支持的类型: {list(SUPPORTED_IMAGE_TYPES.keys())}"
)
try: try:
# 获取对应文件类型的扩展名 # 获取对应文件类型的扩展名
file_extension = SUPPORTED_IMAGE_TYPES[content_type] file_extension = SUPPORTED_IMAGE_TYPES[content_type]
# 创建临时文件,使用正确的扩展名 # 创建临时文件,使用正确的扩展名
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(
delete=False, delete=False,
suffix=file_extension, suffix=file_extension,
mode='wb' mode='wb'
) as temp_file: ) as temp_file:
# 写入文件内容 # 写入文件内容
contents = await origin_image.read() contents = await origin_image.read()
temp_file.write(contents) temp_file.write(contents)
temp_file_path = temp_file.name temp_file_path = temp_file.name
logger.success(f"上传图片已保存到{temp_file_path}") logger.success(f"上传图片已保存到{temp_file_path}")
except Exception as e: except Exception as e:
logger.exception("保存图片失败") logger.exception("保存图片失败")
raise e raise e
uri = await upload(temp_file_path, content_type, google_api_key) uri = await upload(temp_file_path, content_type, google_api_key)
else:
uri, content_type = None, None
fn = modal.Function.from_name(config.modal_app_name, "image_edit_custom", fn = modal.Function.from_name(config.modal_app_name, "image_edit_custom",
environment_name=config.modal_environment) environment_name=config.modal_environment)
fn_call = fn.spawn(google_api_key, uri, content_type, prompt, temperature, topP) fn_call = fn.spawn(google_api_key, uri, content_type, prompt, temperature, topP)

View File

@ -132,22 +132,24 @@ with (downloader_image.imports()):
) )
try: try:
logger.info("🐬开始处理图片") logger.info("🐬开始处理图片")
parts = [
types.Part.from_text(
text="<prompt>"
"<instruction>"
"{0}"
"</instruction>"
"</prompt>".format(
prompt
)
)
]
if origin_image_uri is not None:
parts.insert(0, types.Part.from_uri(
file_uri=origin_image_uri,
mime_type=content_type))
resp, resp_code = client.generate_content(model_id="gemini-2.0-flash-preview-image-generation", resp, resp_code = client.generate_content(model_id="gemini-2.0-flash-preview-image-generation",
contents=[types.Content(role='user', contents=[types.Content(role='user',
parts=[ parts=parts
types.Part.from_uri(
file_uri=origin_image_uri,
mime_type=content_type),
types.Part.from_text(
text="<prompt>"
"<instruction>"
"{0}"
"</instruction>"
"</prompt>".format(
prompt
)
)
]
) )
], ],
config=types.GenerateContentConfig.model_validate({ config=types.GenerateContentConfig.model_validate({