"""Client helpers for submitting workflows to ComfyUI and collecting their outputs."""

import copy
import json
import random
import uuid

import aiohttp
import websockets

from workflow_service.config import Settings

settings = Settings()


def _ws_url(client_id: str) -> str:
    """Build the ComfyUI WebSocket URL that identifies this connection as ``client_id``."""
    return f"{settings.COMFYUI_URL}?clientId={client_id}"


async def queue_prompt(prompt: dict, client_id: str) -> str:
    """Submit a workflow to the ComfyUI queue via HTTP POST.

    The caller's ``prompt`` is not modified; a copy is submitted.

    Args:
        prompt: Workflow graph keyed by node id. Each node value is expected to
            be a dict with an ``"inputs"`` mapping (ComfyUI API format —
            assumed, confirm against callers).
        client_id: Client id sent with the prompt; a WebSocket opened with the
            same id is used to receive the prompt's events.

    Returns:
        The ``prompt_id`` returned by the ``/prompt`` endpoint.

    Raises:
        aiohttp.ClientResponseError: If the endpoint answers with an HTTP error.
        Exception: If the JSON response has no ``prompt_id``.
    """
    # Work on a copy: mutating the caller's dict would make a reused template
    # accumulate one extra cache-buster input per submission.
    prompt = copy.deepcopy(prompt)

    # [Cache busting] Give every node a uniquely named input holding a random
    # number so cached results are never reused and every node re-executes.
    for node in prompt.values():
        node.setdefault("inputs", {})[f"cache_buster_{uuid.uuid4().hex}"] = random.random()

    payload = {"prompt": prompt, "client_id": client_id}
    http_url = f"{settings.COMFYUI_HTTP_URL}/prompt"
    async with aiohttp.ClientSession() as session:
        async with session.post(http_url, json=payload) as response:
            if response.status >= 400:
                # Read the body first so any error details it carries are not
                # lost when raise_for_status() raises below.
                detail = await response.text()
                print(f"ComfyUI rejected the prompt (HTTP {response.status}): {detail}")
            response.raise_for_status()
            result = await response.json()

    if "prompt_id" not in result:
        raise Exception(f"Invalid response from ComfyUI /prompt endpoint: {result}")
    return result["prompt_id"]


async def _collect_outputs(websocket, prompt_id: str) -> dict:
    """Read events from an open ComfyUI WebSocket until ``prompt_id`` finishes.

    Aggregates the ``output`` payload of every ``executed`` event belonging to
    ``prompt_id`` into a dict keyed by node id. The wait ends when an
    ``executing`` event with ``node`` set to ``None`` arrives for the prompt
    (the end marker the official UI uses). A closed connection or an unexpected
    error also ends the wait; whatever was collected so far is returned.
    """
    aggregated_outputs = {}
    while True:
        try:
            out = await websocket.recv()
            # Only text frames are parsed as JSON events; binary frames are skipped.
            if not isinstance(out, str):
                continue
            try:
                message = json.loads(out)
            except json.JSONDecodeError as e:
                # One malformed frame must not abort the whole wait.
                print(f"Ignoring non-JSON WebSocket message for {prompt_id}: {e}")
                continue
            if not isinstance(message, dict):
                continue

            msg_type = message.get("type")
            data = message.get("data")
            # Only events tagged with our prompt_id are relevant.
            if not isinstance(data, dict) or data.get("prompt_id") != prompt_id:
                continue

            if msg_type == "executed":
                # Aggregate each node's output.
                node_id = data.get("node")
                output_data = data.get("output")
                if node_id and output_data:
                    aggregated_outputs[node_id] = output_data
                    print(f"Output received for node {node_id} (Prompt ID: {prompt_id})")
            elif msg_type == "executing" and data.get("node") is None:
                # End-of-execution marker.
                print(f"Execution finished for Prompt ID: {prompt_id}")
                return aggregated_outputs
            elif msg_type == "execution_error":
                # Surface the error; the loop keeps waiting for the end marker above.
                print(f"Execution error reported for Prompt ID: {prompt_id}: {data}")
        except websockets.exceptions.ConnectionClosed as e:
            # A closed connection is also treated as the end of execution.
            print(
                f"WebSocket connection closed for {prompt_id}. "
                f"Returning aggregated results. Error: {e}"
            )
            return aggregated_outputs
        except Exception as e:
            # Best-effort: report and return what was collected so far.
            print(f"An error occurred for {prompt_id}: {e}")
            return aggregated_outputs


async def get_execution_results(prompt_id: str, client_id: str) -> dict:
    """Connect as ``client_id`` and aggregate ``prompt_id``'s node outputs until it finishes.

    NOTE: events emitted before this connection is opened are presumably not
    replayed. If the prompt was queued earlier, early outputs (or the final
    end marker) may be missed. ``run_workflow`` avoids this by connecting
    before it queues.

    Returns:
        Mapping of node id -> that node's ``output`` payload.
    """
    async with websockets.connect(_ws_url(client_id)) as websocket:
        return await _collect_outputs(websocket, prompt_id)


async def run_workflow(prompt: dict) -> dict:
    """Queue ``prompt`` on ComfyUI and wait for its aggregated node outputs.

    The WebSocket is opened *before* the prompt is queued. Opening it afterwards
    races with execution: a fast prompt could emit its outputs and end marker
    before anyone is listening, leaving the wait hanging.

    Returns:
        Mapping of node id -> that node's ``output`` payload.
    """
    client_id = str(uuid.uuid4())
    async with websockets.connect(_ws_url(client_id)) as websocket:
        prompt_id = await queue_prompt(prompt, client_id)
        print(f"Workflow successfully queued with Prompt ID: {prompt_id}")
        return await _collect_outputs(websocket, prompt_id)