This commit is contained in:
iHeyTang 2025-08-13 16:21:39 +08:00
parent 1af6b42573
commit 7830d2e02d
2 changed files with 50 additions and 14 deletions

View File

@ -1,4 +1,5 @@
import asyncio
import base64
import json
import logging
import os
@ -7,7 +8,7 @@ import uuid
import websockets
from collections import defaultdict
from datetime import datetime
from typing import Dict, Any, Optional, List, Set
from typing import Dict, Any, Optional, List, Set, Union
import aiohttp
from aiohttp import ClientTimeout
@ -309,7 +310,9 @@ async def execute_prompt_on_server(
client_id = str(uuid.uuid4())
# 应用请求数据到工作流
patched_workflow = patch_workflow(workflow_data, api_spec, request_data)
patched_workflow = await patch_workflow(
workflow_data, api_spec, request_data, server
)
# 转换为prompt格式
prompt = convert_workflow_to_prompt_api_format(patched_workflow)
@ -438,10 +441,11 @@ def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]:
return spec
def patch_workflow(
async def patch_workflow(
workflow_data: dict,
api_spec: dict[str, dict[str, Any]],
request_data: dict[str, Any],
server: ComfyUIServer,
) -> dict:
"""
将request_data中的参数值patch到workflow_data中并返回修改后的workflow_data
@ -501,6 +505,8 @@ def patch_workflow(
# 在正确的位置上更新值
try:
if target_node.get("type") == "LoadImage":
value = await upload_image_to_comfy(value, server)
target_node["widgets_values"][target_widget_index] = target_type(value)
except (ValueError, TypeError) as e:
logger.warning(
@ -779,3 +785,35 @@ async def _get_execution_results_legacy(
raise e
return aggregated_outputs
async def upload_image_to_comfy(file_path: str, server: ComfyUIServer) -> str:
"""
上传文件到服务器
file 是一个http链接
返回一个名称
"""
file_name = file_path.split("/")[-1]
file_data = await download_file(file_path)
form_data = aiohttp.FormData()
form_data.add_field("image", file_data, filename=file_name, content_type="image/*")
form_data.add_field("type", "input")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{server.http_url}/api/upload/image", data=form_data
) as response:
response.raise_for_status()
result = await response.json()
return result["name"]
async def download_file(src: str) -> bytes:
"""
下载文件到本地
"""
async with aiohttp.ClientSession() as session:
async with session.get(src) as response:
response.raise_for_status()
return await response.read()

View File

@ -15,11 +15,6 @@ from pydantic import BaseModel
from workflow_service import comfyui_client
from workflow_service import database
from workflow_service.utils.s3_client import upload_file_to_s3
from workflow_service.comfyui_client import (
ComfyUIExecutionError,
queue_manager,
submit_workflow_to_queue,
)
from workflow_service.config import Settings
settings = Settings()
@ -197,8 +192,11 @@ async def run_workflow(
api_spec = comfyui_client.parse_api_spec(workflow)
# 提交到队列
workflow_run_id = await submit_workflow_to_queue(
workflow_name=workflow_name, workflow_data=workflow, api_spec=api_spec, request_data=data
workflow_run_id = await comfyui_client.submit_workflow_to_queue(
workflow_name=workflow_name,
workflow_data=workflow,
api_spec=api_spec,
request_data=data,
)
return JSONResponse(
@ -220,7 +218,7 @@ async def get_run_status(workflow_run_id: str):
获取工作流执行状态
"""
try:
status = await queue_manager.get_task_status(workflow_run_id)
status = await comfyui_client.queue_manager.get_task_status(workflow_run_id)
return status
except Exception as e:
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}")
@ -232,13 +230,13 @@ async def get_metrics():
获取队列状态概览
"""
try:
pending_count = len(queue_manager.pending_tasks)
running_count = len(queue_manager.running_tasks)
pending_count = len(comfyui_client.queue_manager.pending_tasks)
running_count = len(comfyui_client.queue_manager.running_tasks)
return {
"pending_tasks": pending_count,
"running_tasks": running_count,
"total_servers": len(queue_manager.running_tasks),
"total_servers": len(comfyui_client.queue_manager.running_tasks),
"queue_manager_status": "active",
}
except Exception as e: