import asyncio
import json
import random
import uuid
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict

import aiohttp
import websockets
from aiohttp import ClientTimeout

from workflow_service.config import Settings, ComfyUIServer

settings = Settings()

# Timeout for one GET /queue status probe. Probes for all servers run
# concurrently, so this also bounds how long server selection can take.
_STATUS_PROBE_TIMEOUT = ClientTimeout(total=60)
# Timeout for submitting a workflow via POST /prompt.
_QUEUE_PROMPT_TIMEOUT = ClientTimeout(total=90)


class ComfyUIExecutionError(Exception):
    """Raised when ComfyUI sends an ``execution_error`` message for a prompt.

    The raw ``data`` payload of that message is kept on ``error_data`` so
    callers can inspect fields such as ``node_id``, ``node_type`` and
    ``exception_message``.
    """

    def __init__(self, error_data: dict):
        self.error_data = error_data
        # Build a developer-friendly message from the error payload.
        message = (
            f"ComfyUI节点执行失败。节点ID: {error_data.get('node_id')}, "
            f"节点类型: {error_data.get('node_type')}. "
            f"错误: {error_data.get('exception_message', 'N/A')}"
        )
        super().__init__(message)


async def get_server_status(
    server: ComfyUIServer, session: aiohttp.ClientSession
) -> Dict[str, Any]:
    """
    Probe a single ComfyUI server via ``GET {http_url}/queue``.

    Returns a dict with ``is_reachable``, ``is_free`` (nothing running and
    nothing pending) and ``queue_details`` (``running_count`` /
    ``pending_count``). This is a best-effort probe: any exception is logged
    and the values filled in so far are returned (by default: not reachable,
    not free, zero counts).
    """
    # Start from the same shape as the success case so the result always
    # satisfies the downstream Pydantic model, even on failure.
    status_info: Dict[str, Any] = {
        "is_reachable": False,
        "is_free": False,
        "queue_details": {"running_count": 0, "pending_count": 0},
    }
    try:
        queue_url = f"{server.http_url}/queue"
        async with session.get(queue_url, timeout=_STATUS_PROBE_TIMEOUT) as response:
            response.raise_for_status()
            queue_data = await response.json()

        status_info["is_reachable"] = True
        running_count = len(queue_data.get("queue_running", []))
        pending_count = len(queue_data.get("queue_pending", []))
        status_info["queue_details"] = {
            "running_count": running_count,
            "pending_count": pending_count,
        }
        status_info["is_free"] = running_count == 0 and pending_count == 0
    except Exception as e:
        # Deliberately best-effort: log and fall back to the defaults above.
        print(f"警告: 无法检查服务器 {server.http_url} 的队列状态: {e}")
    return status_info


async def select_server_for_execution() -> ComfyUIServer:
    """
    Choose a ComfyUI server for the next execution.

    A single configured server is returned without probing. Otherwise every
    server is probed concurrently; a random idle server is preferred, and if
    none is idle a random reachable (busy) server is used.

    Raises:
        ValueError: no servers are configured.
        ConnectionError: every configured server failed its status probe.
    """
    servers = settings.SERVERS
    if not servers:
        raise ValueError("没有在 COMFYUI_SERVERS_JSON 中配置任何服务器。")
    if len(servers) == 1:
        return servers[0]

    async with aiohttp.ClientSession() as session:
        results = await asyncio.gather(
            *(get_server_status(server, session) for server in servers)
        )

    free_servers = [
        server for server, status in zip(servers, results) if status["is_free"]
    ]
    if free_servers:
        selected_server = random.choice(free_servers)
        print(
            f"发现 {len(free_servers)} 个空闲服务器。已选择: {selected_server.http_url}"
        )
        return selected_server

    # Fallback: pick a reachable server even though it is busy.
    reachable_servers = [
        server for server, status in zip(servers, results) if status["is_reachable"]
    ]
    if reachable_servers:
        selected_server = random.choice(reachable_servers)
        print(
            f"所有服务器当前都在忙。从可达服务器中随机选择: {selected_server.http_url}"
        )
        return selected_server

    # Worst case: no server answered its probe.
    raise ConnectionError("所有配置的ComfyUI服务器都不可达。")


async def execute_prompt_on_server(prompt: Dict, server: ComfyUIServer) -> Dict:
    """
    Queue ``prompt`` on ``server`` and wait for its outputs.

    The WebSocket is opened *before* the prompt is queued. NOTE(review):
    ComfyUI appears to push per-client messages only to sockets connected
    under that ``clientId`` at send time. Connecting after queueing can
    therefore miss early ``executed`` outputs, or miss the final completion
    message, which makes the wait hang indefinitely. This mirrors ComfyUI's
    own websockets API example.

    Returns:
        Mapping of node id -> that node's ``output`` payload.

    Raises:
        ComfyUIExecutionError: ComfyUI reported an ``execution_error``.
    """
    client_id = str(uuid.uuid4())
    async with _open_client_websocket(client_id, server.ws_url) as websocket:
        prompt_id = await _queue_prompt(prompt, client_id, server.http_url)
        print(f"工作流已在 {server.http_url} 上入队,Prompt ID: {prompt_id}")
        return await _collect_execution_results(websocket, prompt_id)


async def _queue_prompt(prompt: dict, client_id: str, http_url: str) -> str:
    """
    Submit a workflow via ``POST {http_url}/prompt`` and return its ``prompt_id``.

    Every node of a per-call *copy* of ``prompt`` gets a random
    ``cache_buster_<hex>`` input, presumably so ComfyUI cannot reuse cached
    results. The caller's dict is not modified. Writing into it made these
    keys accumulate each time the same prompt object was submitted.

    Raises:
        aiohttp.ClientResponseError: non-2xx response (the body is logged first).
        Exception: the response JSON contains no ``prompt_id``.
    """
    busted_prompt = {}
    for node_id, node in prompt.items():
        inputs = dict(node.get("inputs", {}))
        inputs[f"cache_buster_{uuid.uuid4().hex}"] = random.random()
        busted_prompt[node_id] = {**node, "inputs": inputs}

    payload = {"prompt": busted_prompt, "client_id": client_id}
    async with aiohttp.ClientSession(timeout=_QUEUE_PROMPT_TIMEOUT) as session:
        prompt_url = f"{http_url}/prompt"
        async with session.post(prompt_url, json=payload) as response:
            if response.status >= 400:
                # raise_for_status() discards the body. NOTE(review): the body
                # is assumed to carry ComfyUI's validation details, so log it.
                error_body = await response.text()
                print(f"ComfyUI /prompt 请求失败 (HTTP {response.status}): {error_body}")
            response.raise_for_status()
            result = await response.json()

    if "prompt_id" not in result:
        raise Exception(f"从 ComfyUI /prompt 端点返回的响应无效: {result}")
    return result["prompt_id"]


@asynccontextmanager
async def _open_client_websocket(client_id: str, ws_url: str) -> AsyncIterator[Any]:
    """Open a WebSocket to ``ws_url`` under ``client_id``; log and re-raise invalid URIs."""
    full_ws_url = f"{ws_url}?clientId={client_id}"
    try:
        async with websockets.connect(full_ws_url) as websocket:
            yield websocket
    except websockets.exceptions.InvalidURI as e:
        print(
            f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
        )
        raise


async def _collect_execution_results(websocket: Any, prompt_id: str) -> dict:
    """
    Read messages from an open ComfyUI WebSocket until ``prompt_id`` finishes.

    Non-text frames and messages whose ``data.prompt_id`` differs are skipped.
    Each ``executed`` message with a node id and non-empty output is stored
    under that node id. An ``executing`` message whose ``node`` is None ends
    the wait and returns the aggregated outputs.

    If the connection closes first, the error is logged and the outputs
    gathered so far are returned (possibly partial or empty).

    Raises:
        ComfyUIExecutionError: on an ``execution_error`` message.
    """
    aggregated_outputs: dict = {}
    try:
        while True:
            raw = await websocket.recv()
            if not isinstance(raw, str):
                continue

            message = json.loads(raw)
            msg_type = message.get("type")
            data = message.get("data")
            if not (data and data.get("prompt_id") == prompt_id):
                continue

            if msg_type == "execution_error":
                print(f"ComfyUI执行错误 (Prompt ID: {prompt_id}): {data}")
                # Surface the error details to the caller.
                raise ComfyUIExecutionError(data)

            if msg_type == "executed":
                node_id = data.get("node")
                output_data = data.get("output")
                if node_id and output_data:
                    aggregated_outputs[node_id] = output_data
                    print(f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})")
            elif msg_type == "executing" and data.get("node") is None:
                print(f"Prompt ID: {prompt_id} 执行完成。")
                return aggregated_outputs
    except websockets.exceptions.ConnectionClosed as e:
        # NOTE(review): a dropped connection is treated as best-effort
        # completion, so callers may receive partial outputs here.
        print(f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}")
        return aggregated_outputs
    except ComfyUIExecutionError:
        raise
    except Exception as e:
        print(f"处理 prompt {prompt_id} 时发生意外错误: {e}")
        raise


async def _get_execution_results(prompt_id: str, client_id: str, ws_url: str) -> dict:
    """
    Connect to ``ws_url`` as ``client_id`` and collect the results of ``prompt_id``.

    Kept for callers that queue the prompt themselves. NOTE: if the prompt was
    queued before this connects, early messages may already have been missed.
    ``execute_prompt_on_server`` avoids that by connecting first.

    Raises:
        ComfyUIExecutionError: ComfyUI reported an ``execution_error``.
        websockets.exceptions.InvalidURI: the composed WebSocket URI is invalid.
    """
    async with _open_client_websocket(client_id, ws_url) as websocket:
        return await _collect_execution_results(websocket, prompt_id)