diff --git a/workflow_service/comfyui_client.py b/workflow_service/comfyui_client.py index 95e27e6..818870a 100644 --- a/workflow_service/comfyui_client.py +++ b/workflow_service/comfyui_client.py @@ -1,4 +1,5 @@ import asyncio +import base64 import json import logging import os @@ -7,7 +8,7 @@ import uuid import websockets from collections import defaultdict from datetime import datetime -from typing import Dict, Any, Optional, List, Set +from typing import Dict, Any, Optional, List, Set, Union import aiohttp from aiohttp import ClientTimeout @@ -309,7 +310,9 @@ async def execute_prompt_on_server( client_id = str(uuid.uuid4()) # 应用请求数据到工作流 - patched_workflow = patch_workflow(workflow_data, api_spec, request_data) + patched_workflow = await patch_workflow( + workflow_data, api_spec, request_data, server + ) # 转换为prompt格式 prompt = convert_workflow_to_prompt_api_format(patched_workflow) @@ -438,10 +441,11 @@ def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]: return spec -def patch_workflow( +async def patch_workflow( workflow_data: dict, api_spec: dict[str, dict[str, Any]], request_data: dict[str, Any], + server: ComfyUIServer, ) -> dict: """ 将request_data中的参数值,patch到workflow_data中。并返回修改后的workflow_data。 @@ -501,6 +505,8 @@ def patch_workflow( # 在正确的位置上更新值 try: + if target_node.get("type") == "LoadImage": + value = await upload_image_to_comfy(value, server) target_node["widgets_values"][target_widget_index] = target_type(value) except (ValueError, TypeError) as e: logger.warning( @@ -779,3 +785,35 @@ async def _get_execution_results_legacy( raise e return aggregated_outputs + + +async def upload_image_to_comfy(file_path: str, server: ComfyUIServer) -> str: + """ + 上传文件到服务器。 + file 是一个http链接 + 返回一个名称 + """ + file_name = file_path.split("/")[-1] + file_data = await download_file(file_path) + + form_data = aiohttp.FormData() + form_data.add_field("image", file_data, filename=file_name, content_type="image/*") + form_data.add_field("type", "input") + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{server.http_url}/api/upload/image", data=form_data + ) as response: + response.raise_for_status() + result = await response.json() + return result["name"] + + +async def download_file(src: str) -> bytes: + """ + 下载文件到本地。 + """ + async with aiohttp.ClientSession() as session: + async with session.get(src) as response: + response.raise_for_status() + return await response.read() diff --git a/workflow_service/main.py b/workflow_service/main.py index fe76cb1..0863b99 100644 --- a/workflow_service/main.py +++ b/workflow_service/main.py @@ -15,11 +15,6 @@ from pydantic import BaseModel from workflow_service import comfyui_client from workflow_service import database from workflow_service.utils.s3_client import upload_file_to_s3 -from workflow_service.comfyui_client import ( - ComfyUIExecutionError, - queue_manager, - submit_workflow_to_queue, -) from workflow_service.config import Settings settings = Settings() @@ -197,8 +192,11 @@ async def run_workflow( api_spec = comfyui_client.parse_api_spec(workflow) # 提交到队列 - workflow_run_id = await submit_workflow_to_queue( - workflow_name=workflow_name, workflow_data=workflow, api_spec=api_spec, request_data=data + workflow_run_id = await comfyui_client.submit_workflow_to_queue( + workflow_name=workflow_name, + workflow_data=workflow, + api_spec=api_spec, + request_data=data, ) return JSONResponse( @@ -220,7 +218,7 @@ async def get_run_status(workflow_run_id: str): 获取工作流执行状态。 """ try: - status = await queue_manager.get_task_status(workflow_run_id) + status = await comfyui_client.queue_manager.get_task_status(workflow_run_id) return status except Exception as e: raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") @@ -232,13 +230,13 @@ async def get_metrics(): 获取队列状态概览。 """ try: - pending_count = len(queue_manager.pending_tasks) - running_count = len(queue_manager.running_tasks) + pending_count = len(comfyui_client.queue_manager.pending_tasks) + running_count = len(comfyui_client.queue_manager.running_tasks) return { "pending_tasks": pending_count, "running_tasks": running_count, - "total_servers": len(queue_manager.running_tasks), + "total_servers": len(comfyui_client.queue_manager.running_tasks), "queue_manager_status": "active", } except Exception as e: