diff --git a/src/BowongModalFunctions/router/google.py b/src/BowongModalFunctions/router/google.py
index a85ce86..4465d1f 100644
--- a/src/BowongModalFunctions/router/google.py
+++ b/src/BowongModalFunctions/router/google.py
@@ -7,6 +7,7 @@ from typing import Annotated, Optional, cast
import httpx
import modal
import sentry_sdk
+from aiohttp.web_fileresponse import content_type
from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends, File, Form
from fastapi.responses import Response
from loguru import logger
@@ -303,42 +304,45 @@ async def clothes_mark(
return ModalTaskResponse(success=True, taskId=fn_call.object_id)
-@router.post("/image/edit_custom", summary="使用Gemini进行图像编辑")
+@router.post("/image/edit_custom", summary="使用Gemini进行图像生成/编辑")
async def clothes_mark(
headers: Annotated[BundleHeaders, Header()],
- origin_image: Annotated[UploadFile, File(description="待处理图片")],
- prompt: Annotated[str, Form(description="图像编辑提示词")],
+ prompt: Annotated[str, Form(description="图像生成/编辑提示词")],
+ origin_image: Annotated[UploadFile, File(description="待处理图片")] = None,
temperature: Annotated[float, Form()] = 0.01,
topP: Annotated[float, Form()] = 0.7,
) -> ModalTaskResponse:
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
if not google_api_key:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
- content_type = origin_image.content_type
- if content_type not in SUPPORTED_IMAGE_TYPES:
- raise HTTPException(
- status_code=400,
- detail=f"不支持的文件类型: {content_type}。支持的类型: {list(SUPPORTED_IMAGE_TYPES.keys())}"
- )
+ if origin_image is not None:
+ content_type = origin_image.content_type
+ if content_type not in SUPPORTED_IMAGE_TYPES:
+ raise HTTPException(
+ status_code=400,
+ detail=f"不支持的文件类型: {content_type}。支持的类型: {list(SUPPORTED_IMAGE_TYPES.keys())}"
+ )
- try:
- # 获取对应文件类型的扩展名
- file_extension = SUPPORTED_IMAGE_TYPES[content_type]
- # 创建临时文件,使用正确的扩展名
- with tempfile.NamedTemporaryFile(
- delete=False,
- suffix=file_extension,
- mode='wb'
- ) as temp_file:
- # 写入文件内容
- contents = await origin_image.read()
- temp_file.write(contents)
- temp_file_path = temp_file.name
- logger.success(f"上传图片已保存到{temp_file_path}")
- except Exception as e:
- logger.exception("保存图片失败")
- raise e
- uri = await upload(temp_file_path, content_type, google_api_key)
+ try:
+ # 获取对应文件类型的扩展名
+ file_extension = SUPPORTED_IMAGE_TYPES[content_type]
+ # 创建临时文件,使用正确的扩展名
+ with tempfile.NamedTemporaryFile(
+ delete=False,
+ suffix=file_extension,
+ mode='wb'
+ ) as temp_file:
+ # 写入文件内容
+ contents = await origin_image.read()
+ temp_file.write(contents)
+ temp_file_path = temp_file.name
+ logger.success(f"上传图片已保存到{temp_file_path}")
+ except Exception as e:
+ logger.exception("保存图片失败")
+ raise e
+ uri = await upload(temp_file_path, content_type, google_api_key)
+ else:
+ uri, content_type = None, None
fn = modal.Function.from_name(config.modal_app_name, "image_edit_custom",
environment_name=config.modal_environment)
fn_call = fn.spawn(google_api_key, uri, content_type, prompt, temperature, topP)
diff --git a/src/cluster/image_apps/image_edit.py b/src/cluster/image_apps/image_edit.py
index 752b709..d78f3b4 100644
--- a/src/cluster/image_apps/image_edit.py
+++ b/src/cluster/image_apps/image_edit.py
@@ -132,22 +132,24 @@ with (downloader_image.imports()):
)
try:
logger.info("🐬开始处理图片")
+ parts = [
+ types.Part.from_text(
+ text=""
+ ""
+ "{0}"
+ ""
+ "".format(
+ prompt
+ )
+ )
+ ]
+ if origin_image_uri is not None:
+ parts.insert(0, types.Part.from_uri(
+ file_uri=origin_image_uri,
+ mime_type=content_type))
resp, resp_code = client.generate_content(model_id="gemini-2.0-flash-preview-image-generation",
contents=[types.Content(role='user',
- parts=[
- types.Part.from_uri(
- file_uri=origin_image_uri,
- mime_type=content_type),
- types.Part.from_text(
- text=""
- ""
- "{0}"
- ""
- "".format(
- prompt
- )
- )
- ]
+ parts=parts
)
],
config=types.GenerateContentConfig.model_validate({