diff --git a/src/BowongModalFunctions/router/google.py b/src/BowongModalFunctions/router/google.py index a85ce86..4465d1f 100644 --- a/src/BowongModalFunctions/router/google.py +++ b/src/BowongModalFunctions/router/google.py @@ -7,6 +7,7 @@ from typing import Annotated, Optional, cast import httpx import modal import sentry_sdk +from aiohttp.web_fileresponse import content_type from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends, File, Form from fastapi.responses import Response from loguru import logger @@ -303,42 +304,45 @@ async def clothes_mark( return ModalTaskResponse(success=True, taskId=fn_call.object_id) -@router.post("/image/edit_custom", summary="使用Gemini进行图像编辑") +@router.post("/image/edit_custom", summary="使用Gemini进行图像生成/编辑") async def clothes_mark( headers: Annotated[BundleHeaders, Header()], - origin_image: Annotated[UploadFile, File(description="待处理图片")], - prompt: Annotated[str, Form(description="图像编辑提示词")], + prompt: Annotated[str, Form(description="图像生成/编辑提示词")], + origin_image: Annotated[UploadFile, File(description="待处理图片")] = None, temperature: Annotated[float, Form()] = 0.01, topP: Annotated[float, Form()] = 0.7, ) -> ModalTaskResponse: google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY") if not google_api_key: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key") - content_type = origin_image.content_type - if content_type not in SUPPORTED_IMAGE_TYPES: - raise HTTPException( - status_code=400, - detail=f"不支持的文件类型: {content_type}。支持的类型: {list(SUPPORTED_IMAGE_TYPES.keys())}" - ) + if origin_image is not None: + content_type = origin_image.content_type + if content_type not in SUPPORTED_IMAGE_TYPES: + raise HTTPException( + status_code=400, + detail=f"不支持的文件类型: {content_type}。支持的类型: {list(SUPPORTED_IMAGE_TYPES.keys())}" + ) - try: - # 获取对应文件类型的扩展名 - file_extension = SUPPORTED_IMAGE_TYPES[content_type] - # 创建临时文件,使用正确的扩展名 - with tempfile.NamedTemporaryFile( - delete=False, - suffix=file_extension, - mode='wb' - ) as temp_file: - # 写入文件内容 - contents = await origin_image.read() - temp_file.write(contents) - temp_file_path = temp_file.name - logger.success(f"上传图片已保存到{temp_file_path}") - except Exception as e: - logger.exception("保存图片失败") - raise e - uri = await upload(temp_file_path, content_type, google_api_key) + try: + # 获取对应文件类型的扩展名 + file_extension = SUPPORTED_IMAGE_TYPES[content_type] + # 创建临时文件,使用正确的扩展名 + with tempfile.NamedTemporaryFile( + delete=False, + suffix=file_extension, + mode='wb' + ) as temp_file: + # 写入文件内容 + contents = await origin_image.read() + temp_file.write(contents) + temp_file_path = temp_file.name + logger.success(f"上传图片已保存到{temp_file_path}") + except Exception as e: + logger.exception("保存图片失败") + raise e + uri = await upload(temp_file_path, content_type, google_api_key) + else: + uri, content_type = None, None fn = modal.Function.from_name(config.modal_app_name, "image_edit_custom", environment_name=config.modal_environment) fn_call = fn.spawn(google_api_key, uri, content_type, prompt, temperature, topP) diff --git a/src/cluster/image_apps/image_edit.py b/src/cluster/image_apps/image_edit.py index 752b709..d78f3b4 100644 --- a/src/cluster/image_apps/image_edit.py +++ b/src/cluster/image_apps/image_edit.py @@ -132,22 +132,24 @@ with (downloader_image.imports()): ) try: logger.info("🐬开始处理图片") + parts = [ + types.Part.from_text( + text="" + "" + "{0}" + "" + "".format( + prompt + ) + ) + ] + if origin_image_uri is not None: + parts.insert(0, types.Part.from_uri( + file_uri=origin_image_uri, + mime_type=content_type)) resp, resp_code = client.generate_content(model_id="gemini-2.0-flash-preview-image-generation", contents=[types.Content(role='user', - parts=[ - types.Part.from_uri( - file_uri=origin_image_uri, - mime_type=content_type), - types.Part.from_text( - text="" - "" - "{0}" - "" - "".format( - prompt - ) - ) - ] + parts=parts ) ], config=types.GenerateContentConfig.model_validate({