128 lines
3.4 KiB
Python
128 lines
3.4 KiB
Python
import json
|
|
from typing import Optional
|
|
|
|
from fastapi import Body, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
from app.comfy.comfy_queue import queue_manager
|
|
from app.comfy.comfy_workflow import ComfyWorkflow
|
|
from app.database.api import get_workflow
|
|
from ._base import runx_router
|
|
|
|
# Path segment used for this module's routes on ``runx_router``
# (POST "/<name>" and GET "/<name>/json_schema").
RUNX_NAME: str = "model_with_multi_dress"
# Name of the stored workflow fetched with ``get_workflow`` for every request.
RUNX_WORKFLOW_NAME: str = "测试"
|
|
|
|
|
|
@runx_router.post(f"/{RUNX_NAME}")
async def model_with_multi_dress(
    data: dict = Body(...),
    workflow_version: Optional[str] = None,
):
    """
    One model, wearing several different outfits.

    The request body is split into one request per outfit image (via
    ``_convert``). Each split request is submitted with
    ``queue_manager.add_task`` as a separate run of the ``RUNX_WORKFLOW_NAME``
    workflow.

    Args:
        data: Request body carrying "模特图", "模特描述" and a list "穿搭图"
            (see the ``/json_schema`` endpoint for the expected shape).
        workflow_version: Workflow version to run. When ``None``, whatever
            ``get_workflow`` resolves for ``None`` is used (the 404 message
            calls this the latest version).

    Returns:
        HTTP 202 ``JSONResponse`` with ``workflow_run_ids`` (one id per outfit
        image, in request order), ``status`` and ``message``.

    Raises:
        HTTPException: 422 if "穿搭图" cannot be iterated or yields no outfit
            images, so nothing would be queued; 404 if the workflow (or the
            requested version) cannot be found.
    """
    # Split the request into one task per outfit image first, so a malformed
    # request is rejected before any database lookup.
    try:
        batch_data = _convert(data)
    except TypeError as exc:
        # e.g. "穿搭图" is null or a number, which cannot be iterated.
        raise HTTPException(status_code=422, detail="'穿搭图' 必须是列表。") from exc
    if not batch_data:
        # Without this guard the endpoint would answer 202 "queued" while
        # having queued nothing.
        raise HTTPException(status_code=422, detail="'穿搭图' 不能为空。")

    workflow_name = RUNX_WORKFLOW_NAME

    # Fetch the stored workflow definition.
    workflow_data = await get_workflow(workflow_name, workflow_version)
    if not workflow_data:
        version_text = f" 带版本 '{workflow_version}'" if workflow_version else " (最新版)"
        raise HTTPException(
            status_code=404,
            detail=f"工作流 '{workflow_name}'{version_text} 未找到。",
        )

    workflow = json.loads(workflow_data["workflow_json"])
    flow = ComfyWorkflow(workflow_name, workflow)

    # Submit sequentially. NOTE(review): if add_task raises part-way, the
    # earlier items have already been submitted and their ids are not returned.
    workflow_run_ids: list[str] = []
    for item in batch_data:
        workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=item)
        workflow_run_ids.append(workflow_run_id)

    return JSONResponse(
        content={
            "workflow_run_ids": workflow_run_ids,
            "status": "queued",
            "message": "工作流已提交到队列,正在等待执行",
        },
        status_code=202,
    )
|
|
|
|
|
|
@runx_router.get(f"/{RUNX_NAME}/json_schema")
async def model_with_multi_dress_json_schema():
    """
    Return the JSON Schema of the request body accepted by ``POST /{RUNX_NAME}``.

    The body has three required properties:

    * "模特图": an object with a string "image";
    * "模特描述": an object with a string "value";
    * "穿搭图": an array of objects, each with a string "image".
    """

    def single_string_object(field_name: str) -> dict:
        # Schema for an object holding exactly one string-typed property.
        return {
            "type": "object",
            "properties": {field_name: {"type": "string"}},
        }

    body_properties = {
        "模特图": single_string_object("image"),
        "模特描述": single_string_object("value"),
        "穿搭图": {
            "type": "array",
            "items": single_string_object("image"),
        },
    }

    # Every top-level property is required; dict insertion order keeps the
    # list in declaration order.
    return {
        "type": "object",
        "properties": body_properties,
        "required": list(body_properties),
    }
|
|
|
|
|
|
def _convert(data: dict) -> list[dict]:
    """
    Split a multi-outfit request into one request per outfit image.

    Input format::

        {
            "模特图": {"image": "xxx"},
            "模特描述": {"value": "xxx"},
            "穿搭图": [
                {"image": "xxx"},
                {"image": "xxx"}
            ]
        }

    is converted to::

        [
            {
                "模特图": {"image": "xxx"},
                "模特描述": {"value": "xxx"},
                "穿搭图": {"image": "xxx"}
            },
            ...
        ]

    The model image and description objects are shared (not copied) by every
    record. Each record carries exactly one "穿搭图" entry. Missing "模特图" or
    "模特描述" default to ``{}``.

    Args:
        data: The raw request body.

    Returns:
        One dict per outfit entry, in input order. The list is empty when
        "穿搭图" is missing, ``None`` or an empty list. A lone outfit object
        (a dict instead of a list) is treated as a one-element list.

    Raises:
        TypeError: If "穿搭图" is present but is not a list, tuple, dict or
            ``None``.
    """
    model_image = data.get("模特图", {})
    model_description = data.get("模特描述", {})
    dress_images = data.get("穿搭图")

    if dress_images is None:
        dress_images = []
    elif isinstance(dress_images, dict):
        # Iterating a dict directly would yield its keys (e.g. "image") as
        # outfit entries, so wrap a single outfit object instead.
        dress_images = [dress_images]
    elif not isinstance(dress_images, (list, tuple)):
        # A string would otherwise be split into one task per character.
        raise TypeError(
            f"'穿搭图' must be a list of objects, got {type(dress_images).__name__}"
        )

    # One record per outfit image.
    return [
        {
            "模特图": model_image,
            "模特描述": model_description,
            "穿搭图": dress_image,
        }
        for dress_image in dress_images
    ]
|