import asyncio
import base64
import json
import logging
import mimetypes
import os
import posixpath
import random
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Union
from urllib.parse import urlparse

import aiohttp
import websockets
from aiohttp import ClientTimeout

from workflow_service.config import ComfyUIServer, Settings
from workflow_service.database import (
    create_workflow_run,
    create_workflow_run_nodes,
    get_pending_workflow_runs,
    get_running_workflow_runs,
    get_workflow_run,
    get_workflow_run_nodes,
    update_workflow_run_node_status,
    update_workflow_run_status,
)

settings = Settings()

# Node-title prefixes that mark a node as an API input / API output.
API_INPUT_PREFIX = "INPUT_"
API_OUTPUT_PREFIX = "OUTPUT_"

# Values the ComfyUI frontend stores in ``widgets_values`` for the extra
# "control_after_generate" widget it attaches after seed-like INT widgets.
# That widget has no matching entry in ``node["inputs"]``, so it has to be
# skipped when aligning widget inputs with ``widgets_values``.
_CONTROL_AFTER_GENERATE_VALUES = frozenset(
    {"fixed", "increment", "decrement", "randomize"}
)

# Editor-only annotation node types; they have no backend implementation and
# therefore must not be submitted to the /prompt endpoint.
_FRONTEND_ONLY_NODE_TYPES = frozenset({"Note", "MarkdownNote"})

# Total timeout for the /queue probe performed by get_server_status.
_SERVER_STATUS_TIMEOUT = ClientTimeout(total=60)

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class WorkflowQueueManager:
    """Global FIFO queue that dispatches workflow runs to idle ComfyUI servers.

    At most one run is dispatched per server at a time; ``running_tasks`` maps a
    server's ``http_url`` to the run currently executing there.
    """

    def __init__(self, retry_interval: float = 10.0):
        self.running_tasks = {}  # server http_url -> {"workflow_run_id", "started_at"}
        self.pending_tasks = []  # FIFO of workflow_run_ids waiting for a server
        self.lock = asyncio.Lock()
        # Seconds to wait before re-polling servers while runs are pending but
        # some server that is not running one of our tasks is busy/unreachable.
        self.retry_interval = retry_interval
        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected.
        self._background_tasks: Set["asyncio.Task"] = set()
        self._retry_scheduled = False

    def _spawn(self, coro) -> "asyncio.Task":
        """Schedule *coro* as a background task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _schedule_retry(self) -> None:
        """Re-run ``_process_queue`` after ``retry_interval`` seconds (at most one pending retry)."""
        if self._retry_scheduled:
            return
        self._retry_scheduled = True
        self._spawn(self._retry_after_delay())

    async def _retry_after_delay(self) -> None:
        """Sleep ``retry_interval`` seconds, then process the queue again."""
        try:
            await asyncio.sleep(self.retry_interval)
        finally:
            # Reset before processing so _process_queue may schedule the next retry.
            self._retry_scheduled = False
        await self._process_queue()

    async def add_task(
        self,
        workflow_run_id: str,
        workflow_name: str,
        workflow_data: dict,
        api_spec: dict,
        request_data: dict,
    ):
        """Persist a new workflow run and append it to the pending queue.

        Creates the run record and one node record per entry of
        ``workflow_data["nodes"]``, enqueues the run, and triggers queue
        processing in the background.

        Returns:
            The same ``workflow_run_id`` that was passed in.
        """
        # Database writes do not touch in-memory queue state, so they are done
        # outside the lock to avoid blocking queue processing.
        await create_workflow_run(
            workflow_run_id=workflow_run_id,
            workflow_name=workflow_name,
            workflow_json=json.dumps(workflow_data),
            api_spec=json.dumps(api_spec),
            request_data=json.dumps(request_data),
        )
        nodes_data = [
            {"id": str(node["id"]), "type": node.get("type", "unknown")}
            for node in workflow_data.get("nodes", [])
        ]
        await create_workflow_run_nodes(workflow_run_id, nodes_data)

        async with self.lock:
            self.pending_tasks.append(workflow_run_id)
            logger.info(
                f"任务 {workflow_run_id} 已添加到队列,当前队列长度: {len(self.pending_tasks)}"
            )

        self._spawn(self._process_queue())
        return workflow_run_id

    async def _process_queue(self):
        """Dispatch pending runs, in FIFO order, to every currently idle server."""
        async with self.lock:
            if not self.pending_tasks:
                return

            available_servers = await self._get_available_servers()
            if not available_servers:
                logger.info("没有可用的服务器,等待中...")

            for server in available_servers:
                if not self.pending_tasks:
                    break
                workflow_run_id = self.pending_tasks.pop(0)
                try:
                    await update_workflow_run_status(
                        workflow_run_id, "running", server.http_url
                    )
                except Exception:
                    # Do not drop the run: put it back at the head and retry later.
                    logger.exception(
                        f"无法将任务 {workflow_run_id} 标记为运行中,稍后重试"
                    )
                    self.pending_tasks.insert(0, workflow_run_id)
                    break
                self.running_tasks[server.http_url] = {
                    "workflow_run_id": workflow_run_id,
                    "started_at": datetime.now(),
                }
                self._spawn(self._execute_task(workflow_run_id, server))

            # Our own runs re-trigger processing when they finish; servers that
            # are busy for other reasons (or unreachable) are only re-checked by polling.
            if self.pending_tasks and any(
                server.http_url not in self.running_tasks
                for server in settings.SERVERS
            ):
                self._schedule_retry()

    async def _get_available_servers(self) -> list[ComfyUIServer]:
        """Return configured servers not running one of our tasks and reporting an empty queue.

        All candidate servers are probed concurrently over one shared session.
        """
        candidates = [
            server
            for server in settings.SERVERS
            if server.http_url not in self.running_tasks
        ]
        if not candidates:
            return []

        try:
            async with aiohttp.ClientSession() as session:
                statuses = await asyncio.gather(
                    *(get_server_status(server, session) for server in candidates),
                    return_exceptions=True,
                )
        except Exception as e:
            logger.warning(f"检查服务器状态时出错: {e}")
            return []

        available_servers = []
        for server, status in zip(candidates, statuses):
            if isinstance(status, BaseException):
                logger.warning(f"检查服务器 {server.http_url} 状态时出错: {status}")
            elif status["is_reachable"] and status["is_free"]:
                available_servers.append(server)
        return available_servers

    async def _execute_task(self, workflow_run_id: str, server: ComfyUIServer):
        """Run one workflow on *server* and record "completed" or "failed" in the database.

        Always frees the server slot afterwards and re-triggers queue processing.
        """
        try:
            workflow_run = await get_workflow_run(workflow_run_id)
            if not workflow_run:
                raise Exception(f"找不到工作流运行记录: {workflow_run_id}")

            workflow_data = json.loads(workflow_run["workflow_json"])
            api_spec = json.loads(workflow_run["api_spec"])
            request_data = json.loads(workflow_run["request_data"])

            result = await execute_prompt_on_server(
                workflow_data, api_spec, request_data, server, workflow_run_id
            )

            await update_workflow_run_status(
                workflow_run_id,
                "completed",
                result=json.dumps(result, ensure_ascii=False),
            )
        except Exception as e:
            logger.exception(f"执行任务 {workflow_run_id} 时出错: {e}")
            try:
                await update_workflow_run_status(
                    workflow_run_id, "failed", error_message=str(e)
                )
            except Exception:
                logger.exception(f"无法将任务 {workflow_run_id} 标记为失败")
        finally:
            async with self.lock:
                self.running_tasks.pop(server.http_url, None)
            self._spawn(self._process_queue())

    async def get_task_status(self, workflow_run_id: str) -> dict:
        """Return the stored status of a run, its node records and (if completed) its parsed result.

        Returns ``{"error": "任务不存在"}`` when no run with that id exists.
        """
        workflow_run = await get_workflow_run(workflow_run_id)
        if not workflow_run:
            return {"error": "任务不存在"}

        nodes = await get_workflow_run_nodes(workflow_run_id)

        result = {
            "id": workflow_run_id,
            "status": workflow_run["status"],
            "created_at": workflow_run["created_at"],
            "started_at": workflow_run["started_at"],
            "completed_at": workflow_run["completed_at"],
            "server_url": workflow_run["server_url"],
            "error_message": workflow_run["error_message"],
            "nodes": nodes,
        }

        # The result is stored as a JSON string; expose it parsed, or None if unparsable.
        if workflow_run["status"] == "completed" and workflow_run.get("result"):
            try:
                result["result"] = json.loads(workflow_run["result"])
            except (json.JSONDecodeError, TypeError):
                result["result"] = None

        return result


# Global queue manager instance
queue_manager = WorkflowQueueManager()


class ComfyUIExecutionError(Exception):
    """Raised when ComfyUI reports an error (or interruption) for a prompt.

    ``error_data`` holds the raw ``data`` payload of the WebSocket message.
    """

    def __init__(self, error_data: dict):
        self.error_data = error_data
        # Developer-friendly message built from the ComfyUI payload.
        message = (
            f"ComfyUI节点执行失败。节点ID: {error_data.get('node_id')}, "
            f"节点类型: {error_data.get('node_type')}. "
            f"错误: {error_data.get('exception_message', 'N/A')}"
        )
        super().__init__(message)


async def get_server_status(
    server: ComfyUIServer, session: aiohttp.ClientSession
) -> Dict[str, Any]:
    """Probe ``{server.http_url}/queue`` and summarize the server's state.

    Returns a dict with ``is_reachable``, ``is_free`` (no running and no pending
    prompts) and ``queue_details`` counts. On any error, the unreachable default
    (same shape) is returned and a warning is logged.
    """
    # Start with the failure shape so callers always get the same structure.
    status_info = {
        "is_reachable": False,
        "is_free": False,
        "queue_details": {"running_count": 0, "pending_count": 0},
    }
    try:
        queue_url = f"{server.http_url}/queue"
        async with session.get(queue_url, timeout=_SERVER_STATUS_TIMEOUT) as response:
            response.raise_for_status()
            queue_data = await response.json()
            status_info["is_reachable"] = True
            running_count = len(queue_data.get("queue_running", []))
            pending_count = len(queue_data.get("queue_pending", []))
            status_info["queue_details"] = {
                "running_count": running_count,
                "pending_count": pending_count,
            }
            status_info["is_free"] = running_count == 0 and pending_count == 0
    except Exception as e:
        logger.warning(f"无法检查服务器 {server.http_url} 的队列状态: {e}")
    return status_info


async def select_server_for_execution() -> ComfyUIServer:
    """Pick a server: a random idle one if any, otherwise a random reachable one.

    Raises:
        ValueError: no servers are configured.
        ConnectionError: none of the configured servers is reachable.
    """
    servers = settings.SERVERS
    if not servers:
        raise ValueError("没有在 COMFYUI_SERVERS_JSON 中配置任何服务器。")
    if len(servers) == 1:
        return servers[0]

    async with aiohttp.ClientSession() as session:
        results = await asyncio.gather(
            *(get_server_status(server, session) for server in servers)
        )

    free_servers = [
        server for server, status in zip(servers, results) if status["is_free"]
    ]
    if free_servers:
        selected_server = random.choice(free_servers)
        logger.info(
            f"发现 {len(free_servers)} 个空闲服务器。已选择: {selected_server.http_url}"
        )
        return selected_server

    # Fallback: any reachable server, even if busy.
    reachable_servers = [
        server for server, status in zip(servers, results) if status["is_reachable"]
    ]
    if reachable_servers:
        selected_server = random.choice(reachable_servers)
        logger.info(
            f"所有服务器当前都在忙。从可达服务器中随机选择: {selected_server.http_url}"
        )
        return selected_server

    raise ConnectionError("所有配置的ComfyUI服务器都不可达。")


async def execute_prompt_on_server(
    workflow_data: Dict,
    api_spec: Dict,
    request_data: Dict,
    server: ComfyUIServer,
    workflow_run_id: str,
) -> Dict:
    """Patch, convert, submit and await one workflow on *server*.

    Records client_id / prompt_id on the run and returns the aggregated outputs
    produced by ``_get_execution_results``.
    """
    client_id = str(uuid.uuid4())

    # Apply request values to the workflow (patch_workflow mutates and returns it).
    patched_workflow = await patch_workflow(
        workflow_data, api_spec, request_data, server
    )
    prompt = convert_workflow_to_prompt_api_format(patched_workflow)

    # prompt_id is unknown until _queue_prompt returns.
    await update_workflow_run_status(
        workflow_run_id,
        "running",
        server.http_url,
        None,
        client_id,
    )

    prompt_id = await _queue_prompt(patched_workflow, prompt, client_id, server.http_url)

    await update_workflow_run_status(
        workflow_run_id, "running", server.http_url, prompt_id, client_id
    )
    logger.info(
        f"工作流 {workflow_run_id} 已在 {server.http_url} 上入队,Prompt ID: {prompt_id}"
    )

    return await _get_execution_results(
        patched_workflow, prompt_id, client_id, server.ws_url, workflow_run_id
    )


async def submit_workflow_to_queue(
    workflow_name: str, workflow_data: Dict, api_spec: Dict, request_data: Dict
) -> str:
    """Enqueue a workflow run and return its new id immediately.

    Callers can poll ``queue_manager.get_task_status`` with the returned id.
    """
    workflow_run_id = str(uuid.uuid4())
    await queue_manager.add_task(
        workflow_run_id, workflow_name, workflow_data, api_spec, request_data
    )
    return workflow_run_id


def _api_param_type(comfy_type: Any) -> str:
    """Map a ComfyUI input type string to the API parameter type used in the spec."""
    type_str = str(comfy_type or "STRING").upper()
    if "COMBO" in type_str:
        return "UploadFile"
    if "INT" in type_str:
        return "int"
    if "FLOAT" in type_str:
        return "float"
    return "string"


def _unique_param_name(counter: Dict[str, int], candidate: str) -> str:
    """Return *candidate*, suffixed with ``_<n>`` from its second occurrence on."""
    counter[candidate] += 1
    count = counter[candidate]
    return f"{candidate}_{count}" if count > 1 else candidate


def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]:
    """Derive the API input/output spec from node titles.

    Nodes titled ``INPUT_<base>`` expose each unlinked widget input as
    ``<base>_<widget>``; nodes titled ``OUTPUT_<base>`` expose each output as
    ``<base>_<output>``. Names are lower-cased and repeated names get a
    ``_<count>`` suffix.

    Raises:
        ValueError: ``workflow_data["nodes"]`` is missing or not a list.
    """
    spec = {"inputs": {}, "outputs": {}}
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError(
            "Invalid workflow format: 'nodes' key not found or is not a list."
        )

    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    input_name_counter = defaultdict(int)
    output_name_counter = defaultdict(int)

    for node_id, node in nodes_map.items():
        title = node.get("title")
        if not title:
            continue

        if title.startswith(API_INPUT_PREFIX):
            base_name = title[len(API_INPUT_PREFIX) :].lower()
            for a_input in node.get("inputs") or []:
                # Only free (unlinked) widget inputs become API parameters.
                if a_input.get("link") is not None or "widget" not in a_input:
                    continue
                widget_name = a_input["widget"]["name"]
                final_param_name = _unique_param_name(
                    input_name_counter, f"{base_name}_{widget_name.lower()}"
                )
                spec["inputs"][final_param_name] = {
                    "node_id": node_id,
                    "type": _api_param_type(a_input.get("type")),
                    "widget_name": widget_name,
                }

        elif title.startswith(API_OUTPUT_PREFIX):
            base_name = title[len(API_OUTPUT_PREFIX) :].lower()
            # enumerate gives the true slot index even if two outputs compare equal.
            for output_index, an_output in enumerate(node.get("outputs") or []):
                output_name = an_output["name"]
                final_param_name = _unique_param_name(
                    output_name_counter, f"{base_name}_{output_name.lower()}"
                )
                spec["outputs"][final_param_name] = {
                    "node_id": node_id,
                    "class_type": node.get("type"),
                    "output_name": output_name,
                    "output_index": output_index,
                }

    return spec


def _widget_value_indices(node: dict) -> Dict[str, int]:
    """Map each widget-backed input name of *node* to its index in a list ``widgets_values``.

    Widget inputs are matched to ``widgets_values`` in order. When the list holds
    more values than there are widget inputs, a control value (see
    ``_CONTROL_AFTER_GENERATE_VALUES``) directly after an INT widget is treated
    as the frontend's extra "control_after_generate" widget and skipped.
    NOTE(review): heuristic based on the frontend's serialization; nodes with
    other hidden widgets may still misalign.
    """
    widget_inputs = [inp for inp in node.get("inputs") or [] if "widget" in inp]
    values = node.get("widgets_values")
    if not isinstance(values, list):
        values = []
    surplus = len(values) - len(widget_inputs)

    indices: Dict[str, int] = {}
    cursor = 0
    for widget_input in widget_inputs:
        name = (widget_input.get("widget") or {}).get("name")
        if name and name not in indices:
            indices[name] = cursor
        cursor += 1
        if (
            surplus > 0
            and str(widget_input.get("type", "")).upper() == "INT"
            and cursor < len(values)
            and isinstance(values[cursor], str)
            and values[cursor] in _CONTROL_AFTER_GENERATE_VALUES
        ):
            cursor += 1
            surplus -= 1
    return indices


async def patch_workflow(
    workflow_data: dict,
    api_spec: dict[str, dict[str, Any]],
    request_data: dict[str, Any],
    server: ComfyUIServer,
) -> dict:
    """Write request values into the matching widgets of *workflow_data* (in place).

    Parameters not in ``api_spec["inputs"]``, unknown nodes, missing widgets and
    values that fail type conversion are skipped (the last two with a warning).
    For ``LoadImage`` nodes the value is first uploaded via
    ``upload_image_to_comfy`` and replaced by the returned name.

    Returns:
        The mutated *workflow_data*.

    Raises:
        ValueError: *workflow_data* has no ``nodes`` key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}

    for param_name, value in request_data.items():
        spec = api_spec["inputs"].get(param_name)
        if spec is None:
            continue

        node_id = spec["node_id"]
        target_node = nodes_map.get(node_id)
        if target_node is None:
            continue

        widget_name_to_patch = spec["widget_name"]
        value_indices = _widget_value_indices(target_node)
        if widget_name_to_patch not in value_indices:
            logger.warning(
                f"在节点 {node_id} 中未找到名为 '{widget_name_to_patch}' 的 widget。跳过此参数。"
            )
            continue

        target_type = {"int": int, "float": float}.get(spec["type"], str)
        try:
            if target_node.get("type") == "LoadImage":
                value = await upload_image_to_comfy(value, server)
            converted_value = target_type(value)
        except (ValueError, TypeError) as e:
            logger.warning(
                f"无法将参数 '{param_name}' 的值 '{value}' 转换为类型 '{spec['type']}'。错误: {e}"
            )
            continue

        widgets_values = target_node.get("widgets_values")
        if isinstance(widgets_values, dict):
            # Some nodes store widget values keyed by widget name.
            widgets_values[widget_name_to_patch] = converted_value
            continue

        if not isinstance(widgets_values, list):
            # Missing or malformed: start from one placeholder per widget input.
            widget_count = sum(
                1 for inp in target_node.get("inputs") or [] if "widget" in inp
            )
            widgets_values = [None] * widget_count
            target_node["widgets_values"] = widgets_values

        target_index = value_indices[widget_name_to_patch]
        if len(widgets_values) <= target_index:
            widgets_values.extend([None] * (target_index + 1 - len(widgets_values)))
        widgets_values[target_index] = converted_value

    workflow_data["nodes"] = list(nodes_map.values())
    return workflow_data


def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
    """Convert an editor-format workflow into the ``{node_id: {class_type, inputs}}`` /prompt format.

    Widget values come from ``widgets_values`` (so values written by
    ``patch_workflow`` are used); linked inputs become ``[origin_node_id,
    origin_slot]``. Links may be arrays ``[id, origin_id, origin_slot, ...]`` or
    objects with ``id``/``origin_id``/``origin_slot`` keys. Editor-only note
    nodes are omitted.

    Raises:
        ValueError: *workflow_data* has no ``nodes`` key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    # link_id -> [origin node id, origin output slot]
    link_map = {}
    for link_data in workflow_data.get("links") or []:
        if isinstance(link_data, dict):
            link_id = link_data["id"]
            origin_node_id = link_data["origin_id"]
            origin_slot_index = link_data["origin_slot"]
        else:
            link_id, origin_node_id, origin_slot_index = link_data[:3]
        link_map[link_id] = [str(origin_node_id), origin_slot_index]

    prompt_api_format = {}
    for node in workflow_data["nodes"]:
        if node.get("type") in _FRONTEND_ONLY_NODE_TYPES:
            continue

        node_id = str(node["id"])
        inputs_dict = {}

        # 1. Widget inputs
        widgets_values = node.get("widgets_values")
        if isinstance(widgets_values, dict):
            for widget_name in _widget_value_indices(node):
                if widget_name in widgets_values:
                    inputs_dict[widget_name] = widgets_values[widget_name]
        elif isinstance(widgets_values, list):
            for widget_name, index in _widget_value_indices(node).items():
                if index < len(widgets_values):
                    inputs_dict[widget_name] = widgets_values[index]

        # 2. Linked inputs (override a widget value of the same name)
        for input_config in node.get("inputs") or []:
            link_id = input_config.get("link")
            if link_id is not None and link_id in link_map:
                inputs_dict[input_config["name"]] = link_map[link_id]

        prompt_api_format[node_id] = {"class_type": node["type"], "inputs": inputs_dict}

    return prompt_api_format


async def _queue_prompt(
    workflow: dict, prompt: dict, client_id: str, http_url: str
) -> str:
    """POST *prompt* to ``{http_url}/prompt`` and return the ``prompt_id``.

    Mutates *prompt*: every node gets a uniquely named random input
    (``cache_buster_*``), presumably so ComfyUI cannot reuse cached results.
    *workflow* is attached as ``extra_pnginfo``.

    Raises:
        aiohttp.ClientResponseError: HTTP error status; the message includes
            the response body (ComfyUI's validation details).
        Exception: the response JSON has no ``prompt_id``.
    """
    for node_id in prompt:
        prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random()

    payload = {
        "prompt": prompt,
        "client_id": client_id,
        "extra_data": {
            "api_key_comfy_org": "",
            "extra_pnginfo": {"workflow": workflow},
        },
    }
    # The payload embeds the full workflow; only serialize it when DEBUG is on.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"提交到 ComfyUI /prompt 端点的payload: {json.dumps(payload)}")

    prompt_url = f"{http_url}/prompt"
    async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
        try:
            async with session.post(prompt_url, json=payload) as response:
                if response.status >= 400:
                    # Keep ComfyUI's error body; raise_for_status() would discard it.
                    body = await response.text()
                    raise aiohttp.ClientResponseError(
                        response.request_info,
                        response.history,
                        status=response.status,
                        message=f"{response.reason}: {body}",
                        headers=response.headers,
                    )
                result = await response.json()
                logger.info(f"ComfyUI /prompt 端点返回的响应: {result}")
                if "prompt_id" not in result:
                    raise Exception(f"从 ComfyUI /prompt 端点返回的响应无效: {result}")
                return result["prompt_id"]
        except Exception as e:
            logger.error(f"提交到 ComfyUI /prompt 端点时发生错误: {e}")
            raise


async def _get_execution_results(
    workflow_data: Dict,
    prompt_id: str,
    client_id: str,
    ws_url: str,
    workflow_run_id: str,
) -> dict:
    """Follow a prompt over ComfyUI's WebSocket and collect its API outputs.

    Node status is recorded in the database ("running" / "completed" /
    "failed"). Outputs of nodes whose title starts with ``API_OUTPUT_PREFIX``
    are returned keyed by that title once ComfyUI reports ``executing`` with
    ``node`` = None for this prompt.

    Raises:
        ComfyUIExecutionError: ComfyUI reported an execution error or interruption.
        ConnectionError: the WebSocket closed before the prompt finished.
    """
    full_ws_url = f"{ws_url}?clientId={client_id}"
    aggregated_outputs = {}
    # Index nodes once instead of scanning the list for every "executed" message.
    nodes_by_id = {str(node["id"]): node for node in workflow_data.get("nodes", [])}

    try:
        # max_size=None: the trusted server may send frames above the 1 MiB default.
        async with websockets.connect(full_ws_url, max_size=None) as websocket:
            while True:
                out = await websocket.recv()
                if not isinstance(out, str):
                    continue  # binary frames (e.g. previews) are ignored

                message = json.loads(out)
                msg_type = message.get("type")
                data = message.get("data")
                if not (data and data.get("prompt_id") == prompt_id):
                    continue

                if msg_type in ("execution_error", "execution_interrupted"):
                    logger.error(
                        f"ComfyUI执行错误 (Prompt ID: {prompt_id}): {data}"
                    )
                    node_id = data.get("node_id")
                    if node_id:
                        await update_workflow_run_node_status(
                            workflow_run_id,
                            node_id,
                            "failed",
                            error_message=data.get(
                                "exception_message", "Unknown error"
                            ),
                        )
                    raise ComfyUIExecutionError(data)

                if msg_type == "executing":
                    node_id = data.get("node")
                    if node_id is None:
                        logger.info(f"Prompt ID: {prompt_id} 执行完成。")
                        return aggregated_outputs
                    if node_id:
                        logger.info(f"节点 {node_id} 开始执行 (Prompt ID: {prompt_id})")
                        await update_workflow_run_node_status(
                            workflow_run_id, node_id, "running"
                        )
                elif msg_type == "executed":
                    node_id = data.get("node")
                    output_data = data.get("output")
                    if node_id and output_data:
                        node = nodes_by_id.get(node_id)
                        if node and (node.get("title") or "").startswith(
                            API_OUTPUT_PREFIX
                        ):
                            aggregated_outputs[node["title"]] = output_data
                        logger.info(
                            f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})"
                        )
                        await update_workflow_run_node_status(
                            workflow_run_id,
                            node_id,
                            "completed",
                            output_data=json.dumps(output_data),
                        )
    except websockets.exceptions.InvalidURI as e:
        logger.error(
            f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
        )
        raise
    except websockets.exceptions.ConnectionClosed as e:
        logger.warning(
            f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}"
        )
        # The prompt never reported completion; partial outputs are not a success.
        raise ConnectionError(
            f"WebSocket 连接在 Prompt {prompt_id} 执行完成前已关闭: {e}"
        ) from e
    except ComfyUIExecutionError:
        raise
    except Exception as e:
        logger.error(f"处理 prompt {prompt_id} 时发生意外错误: {e}")
        raise


async def _get_execution_results_legacy(
    prompt_id: str, client_id: str, ws_url: str
) -> dict:
    """Simplified variant of ``_get_execution_results`` without database tracking.

    Returns every node's output keyed by node id once ComfyUI reports
    ``executing`` with ``node`` = None for this prompt.

    Raises:
        ComfyUIExecutionError: ComfyUI reported an execution error or interruption.
        ConnectionError: the WebSocket closed before the prompt finished.
    """
    full_ws_url = f"{ws_url}?clientId={client_id}"
    aggregated_outputs = {}

    try:
        async with websockets.connect(full_ws_url, max_size=None) as websocket:
            while True:
                out = await websocket.recv()
                if not isinstance(out, str):
                    continue

                message = json.loads(out)
                msg_type = message.get("type")
                data = message.get("data")
                if not (data and data.get("prompt_id") == prompt_id):
                    continue

                if msg_type in ("execution_error", "execution_interrupted"):
                    logger.error(
                        f"ComfyUI执行错误 (Prompt ID: {prompt_id}): {data}"
                    )
                    raise ComfyUIExecutionError(data)

                if msg_type == "executed":
                    node_id = data.get("node")
                    output_data = data.get("output")
                    if node_id and output_data:
                        aggregated_outputs[node_id] = output_data
                        logger.info(
                            f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})"
                        )
                elif msg_type == "executing" and data.get("node") is None:
                    logger.info(f"Prompt ID: {prompt_id} 执行完成。")
                    return aggregated_outputs
    except websockets.exceptions.InvalidURI as e:
        logger.error(
            f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
        )
        raise
    except websockets.exceptions.ConnectionClosed as e:
        logger.warning(
            f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}"
        )
        raise ConnectionError(
            f"WebSocket 连接在 Prompt {prompt_id} 执行完成前已关闭: {e}"
        ) from e
    except ComfyUIExecutionError:
        raise
    except Exception as e:
        logger.error(f"处理 prompt {prompt_id} 时发生意外错误: {e}")
        raise


async def upload_image_to_comfy(file_path: str, server: ComfyUIServer) -> str:
    """Download the image at URL *file_path* and upload it to the server's input folder.

    The upload filename is the last path segment of the URL (query string and
    fragment excluded).

    Returns:
        The ``name`` ComfyUI assigned to the uploaded file.
    """
    file_name = posixpath.basename(urlparse(file_path).path)
    if not file_name:
        # No usable path segment; the extension is a guess so the file is
        # recognised as an image.
        file_name = f"{uuid.uuid4().hex}.png"
    content_type = mimetypes.guess_type(file_name)[0] or "image/*"

    file_data = await download_file(file_path)
    form_data = aiohttp.FormData()
    form_data.add_field("image", file_data, filename=file_name, content_type=content_type)
    form_data.add_field("type", "input")

    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{server.http_url}/api/upload/image", data=form_data
        ) as response:
            response.raise_for_status()
            result = await response.json()
            return result["name"]


async def download_file(src: str) -> bytes:
    """Fetch *src* over HTTP and return the response body as bytes.

    Raises:
        aiohttp.ClientResponseError: the server returned an error status.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(src) as response:
            response.raise_for_status()
            return await response.read()