import asyncio import json import random import uuid from typing import Dict, Any import aiohttp import websockets from aiohttp import ClientTimeout from workflow_service.config import Settings, ComfyUIServer settings = Settings() async def get_server_status(server: ComfyUIServer, session: aiohttp.ClientSession) -> Dict[str, Any]: """ 检查单个ComfyUI服务器的详细状态。 返回一个包含可达性、队列状态和详细队列内容的字典。 """ # [BUG修复] 确保初始字典结构与成功时的结构一致,以满足Pydantic模型 status_info = { "is_reachable": False, "is_free": False, "queue_details": { "running_count": 0, "pending_count": 0 } } try: queue_url = f"{server.http_url}/queue" async with session.get(queue_url, timeout=60) as response: response.raise_for_status() queue_data = await response.json() status_info["is_reachable"] = True running_count = len(queue_data.get("queue_running", [])) pending_count = len(queue_data.get("queue_pending", [])) status_info["queue_details"] = { "running_count": running_count, "pending_count": pending_count } status_info["is_free"] = (running_count == 0 and pending_count == 0) except Exception as e: # 当请求失败时,将返回上面定义的、结构正确的初始 status_info print(f"警告: 无法检查服务器 {server.http_url} 的队列状态: {e}") return status_info async def select_server_for_execution() -> ComfyUIServer: """ 智能选择一个ComfyUI服务器。 优先选择一个空闲的服务器,如果所有服务器都忙,则随机选择一个。 """ servers = settings.SERVERS if not servers: raise ValueError("没有在 COMFYUI_SERVERS_JSON 中配置任何服务器。") if len(servers) == 1: return servers[0] async with aiohttp.ClientSession() as session: tasks = [get_server_status(server, session) for server in servers] results = await asyncio.gather(*tasks) free_servers = [servers[i] for i, status in enumerate(results) if status["is_free"]] if free_servers: selected_server = random.choice(free_servers) print(f"发现 {len(free_servers)} 个空闲服务器。已选择: {selected_server.http_url}") return selected_server else: # 后备方案:选择一个可达的服务器,即使它很忙 reachable_servers = [servers[i] for i, status in enumerate(results) if status["is_reachable"]] if reachable_servers: selected_server = random.choice(reachable_servers) print(f"所有服务器当前都在忙。从可达服务器中随机选择: {selected_server.http_url}") return selected_server else: # 最坏情况:所有服务器都不可达,抛出异常 raise ConnectionError("所有配置的ComfyUI服务器都不可达。") async def queue_prompt(prompt: dict, client_id: str, http_url: str) -> str: """通过HTTP POST将工作流任务提交到指定的ComfyUI服务器。""" for node_id in prompt: prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random() payload = {"prompt": prompt, "client_id": client_id} async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session: prompt_url = f"{http_url}/prompt" async with session.post(prompt_url, json=payload) as response: response.raise_for_status() result = await response.json() if "prompt_id" not in result: raise Exception(f"从 ComfyUI /prompt 端点返回的响应无效: {result}") return result["prompt_id"] async def get_execution_results(prompt_id: str, client_id: str, ws_url: str) -> dict: """通过WebSocket连接到指定的ComfyUI服务器,聚合执行结果。""" full_ws_url = f"{ws_url}?clientId={client_id}" aggregated_outputs = {} async with websockets.connect(full_ws_url) as websocket: while True: try: out = await websocket.recv() if isinstance(out, str): message = json.loads(out) msg_type = message.get('type') data = message.get('data') if data and data.get('prompt_id') == prompt_id: if msg_type == 'executed': node_id = data.get('node') output_data = data.get('output') if node_id and output_data: aggregated_outputs[node_id] = output_data print(f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})") elif msg_type == 'executing' and data.get('node') is None: print(f"Prompt ID: {prompt_id} 执行完成。") return aggregated_outputs except websockets.exceptions.ConnectionClosed as e: print(f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}") return aggregated_outputs except Exception as e: print(f"处理 prompt {prompt_id} 时发生错误: {e}") break return aggregated_outputs async def execute_prompt_on_server(prompt: Dict, server: ComfyUIServer) -> Dict: """ 在指定的服务器上执行一个准备好的prompt。 """ client_id = str(uuid.uuid4()) prompt_id = await queue_prompt(prompt, client_id, server.http_url) print(f"工作流已在 {server.http_url} 上入队,Prompt ID: {prompt_id}") results = await get_execution_results(prompt_id, client_id, server.ws_url) return results