ADD Gemini接口改为可不传参考图
This commit is contained in:
parent
aaae9060df
commit
1a134cd779
|
|
@ -7,6 +7,7 @@ from typing import Annotated, Optional, cast
|
||||||
import httpx
|
import httpx
|
||||||
import modal
|
import modal
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
|
from aiohttp.web_fileresponse import content_type
|
||||||
from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends, File, Form
|
from fastapi import APIRouter, UploadFile, Header, HTTPException, Depends, File, Form
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
@ -303,42 +304,45 @@ async def clothes_mark(
|
||||||
return ModalTaskResponse(success=True, taskId=fn_call.object_id)
|
return ModalTaskResponse(success=True, taskId=fn_call.object_id)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/image/edit_custom", summary="使用Gemini进行图像编辑")
|
@router.post("/image/edit_custom", summary="使用Gemini进行图像生成/编辑")
|
||||||
async def clothes_mark(
|
async def clothes_mark(
|
||||||
headers: Annotated[BundleHeaders, Header()],
|
headers: Annotated[BundleHeaders, Header()],
|
||||||
origin_image: Annotated[UploadFile, File(description="待处理图片")],
|
prompt: Annotated[str, Form(description="图像生成/编辑提示词")],
|
||||||
prompt: Annotated[str, Form(description="图像编辑提示词")],
|
origin_image: Annotated[UploadFile, File(description="待处理图片")] = None,
|
||||||
temperature: Annotated[float, Form()] = 0.01,
|
temperature: Annotated[float, Form()] = 0.01,
|
||||||
topP: Annotated[float, Form()] = 0.7,
|
topP: Annotated[float, Form()] = 0.7,
|
||||||
) -> ModalTaskResponse:
|
) -> ModalTaskResponse:
|
||||||
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
google_api_key = headers.x_google_api_key or os.environ.get("GOOGLE_API_KEY")
|
||||||
if not google_api_key:
|
if not google_api_key:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Google API Key")
|
||||||
content_type = origin_image.content_type
|
if origin_image is not None:
|
||||||
if content_type not in SUPPORTED_IMAGE_TYPES:
|
content_type = origin_image.content_type
|
||||||
raise HTTPException(
|
if content_type not in SUPPORTED_IMAGE_TYPES:
|
||||||
status_code=400,
|
raise HTTPException(
|
||||||
detail=f"不支持的文件类型: {content_type}。支持的类型: {list(SUPPORTED_IMAGE_TYPES.keys())}"
|
status_code=400,
|
||||||
)
|
detail=f"不支持的文件类型: {content_type}。支持的类型: {list(SUPPORTED_IMAGE_TYPES.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取对应文件类型的扩展名
|
# 获取对应文件类型的扩展名
|
||||||
file_extension = SUPPORTED_IMAGE_TYPES[content_type]
|
file_extension = SUPPORTED_IMAGE_TYPES[content_type]
|
||||||
# 创建临时文件,使用正确的扩展名
|
# 创建临时文件,使用正确的扩展名
|
||||||
with tempfile.NamedTemporaryFile(
|
with tempfile.NamedTemporaryFile(
|
||||||
delete=False,
|
delete=False,
|
||||||
suffix=file_extension,
|
suffix=file_extension,
|
||||||
mode='wb'
|
mode='wb'
|
||||||
) as temp_file:
|
) as temp_file:
|
||||||
# 写入文件内容
|
# 写入文件内容
|
||||||
contents = await origin_image.read()
|
contents = await origin_image.read()
|
||||||
temp_file.write(contents)
|
temp_file.write(contents)
|
||||||
temp_file_path = temp_file.name
|
temp_file_path = temp_file.name
|
||||||
logger.success(f"上传图片已保存到{temp_file_path}")
|
logger.success(f"上传图片已保存到{temp_file_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("保存图片失败")
|
logger.exception("保存图片失败")
|
||||||
raise e
|
raise e
|
||||||
uri = await upload(temp_file_path, content_type, google_api_key)
|
uri = await upload(temp_file_path, content_type, google_api_key)
|
||||||
|
else:
|
||||||
|
uri, content_type = None, None
|
||||||
fn = modal.Function.from_name(config.modal_app_name, "image_edit_custom",
|
fn = modal.Function.from_name(config.modal_app_name, "image_edit_custom",
|
||||||
environment_name=config.modal_environment)
|
environment_name=config.modal_environment)
|
||||||
fn_call = fn.spawn(google_api_key, uri, content_type, prompt, temperature, topP)
|
fn_call = fn.spawn(google_api_key, uri, content_type, prompt, temperature, topP)
|
||||||
|
|
|
||||||
|
|
@ -132,22 +132,24 @@ with (downloader_image.imports()):
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
logger.info("🐬开始处理图片")
|
logger.info("🐬开始处理图片")
|
||||||
|
parts = [
|
||||||
|
types.Part.from_text(
|
||||||
|
text="<prompt>"
|
||||||
|
"<instruction>"
|
||||||
|
"{0}"
|
||||||
|
"</instruction>"
|
||||||
|
"</prompt>".format(
|
||||||
|
prompt
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if origin_image_uri is not None:
|
||||||
|
parts.insert(0, types.Part.from_uri(
|
||||||
|
file_uri=origin_image_uri,
|
||||||
|
mime_type=content_type))
|
||||||
resp, resp_code = client.generate_content(model_id="gemini-2.0-flash-preview-image-generation",
|
resp, resp_code = client.generate_content(model_id="gemini-2.0-flash-preview-image-generation",
|
||||||
contents=[types.Content(role='user',
|
contents=[types.Content(role='user',
|
||||||
parts=[
|
parts=parts
|
||||||
types.Part.from_uri(
|
|
||||||
file_uri=origin_image_uri,
|
|
||||||
mime_type=content_type),
|
|
||||||
types.Part.from_text(
|
|
||||||
text="<prompt>"
|
|
||||||
"<instruction>"
|
|
||||||
"{0}"
|
|
||||||
"</instruction>"
|
|
||||||
"</prompt>".format(
|
|
||||||
prompt
|
|
||||||
)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
config=types.GenerateContentConfig.model_validate({
|
config=types.GenerateContentConfig.model_validate({
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue