import asyncio import json import logging import os import random import uuid import websockets from datetime import datetime from typing import Dict, Any, Optional import aiohttp from aiohttp import ClientTimeout from workflow_service.comfy.comfy_workflow import build_prompt from workflow_service.config import Settings from workflow_service.comfy.comfy_server import server_manager, ComfyUIServerInfo from workflow_service.database.api import ( create_workflow_run, update_workflow_run_status, create_workflow_run_nodes, update_workflow_run_node_status, get_workflow_run, get_workflow_run_nodes, ) settings = Settings() logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) API_INPUT_PREFIX = "INPUT_" API_OUTPUT_PREFIX = "OUTPUT_" class ComfyUIExecutionError(Exception): def __init__(self, error_data: dict): self.error_data = error_data # 创建一个对开发者友好的异常消息 message = ( f"ComfyUI节点执行失败。节点ID: {error_data.get('node_id')}, " f"节点类型: {error_data.get('node_type')}. " f"错误: {error_data.get('exception_message', 'N/A')}" ) super().__init__(message) # 全局任务队列管理器 class WorkflowQueueManager: def __init__(self, monitor_interval: int = 5): self.running_tasks = {} # server_url -> task_info self.pending_tasks = [] # 等待队列 self.lock = asyncio.Lock() self._queue_monitor_task: Optional[asyncio.Task] = None self._monitor_interval = monitor_interval # 监控间隔(秒) self._last_processing_time = None # 上次处理队列的时间 async def start_queue_monitor(self): """启动队列监控任务""" if self._queue_monitor_task is None or self._queue_monitor_task.done(): self._queue_monitor_task = asyncio.create_task(self._monitor_queue()) logger.info(f"队列监控任务已启动,监控间隔: {self._monitor_interval}秒") async def stop_queue_monitor(self): """停止队列监控任务""" if self._queue_monitor_task and not self._queue_monitor_task.done(): self._queue_monitor_task.cancel() try: await self._queue_monitor_task except asyncio.CancelledError: pass logger.info("队列监控任务已停止") async def set_monitor_interval(self, interval: int): """设置监控间隔""" if interval < 1: raise ValueError("监控间隔不能小于1秒") self._monitor_interval = interval logger.info(f"队列监控间隔已设置为 {interval} 秒") # 如果监控任务正在运行,重启它以应用新间隔 if self._queue_monitor_task and not self._queue_monitor_task.done(): await self.stop_queue_monitor() await self.start_queue_monitor() async def trigger_queue_processing(self): """手动触发队列处理(用于测试或紧急情况)""" logger.info("手动触发队列处理") asyncio.create_task(self._process_queue()) async def get_queue_status(self) -> dict: """获取队列状态信息""" async with self.lock: return { "pending_tasks_count": len(self.pending_tasks), "running_tasks_count": len(self.running_tasks), "monitor_active": self._queue_monitor_task is not None and not self._queue_monitor_task.done(), "monitor_interval": self._monitor_interval, } async def _monitor_queue(self): """定期监控队列,当有可用服务器时自动处理""" while True: try: await asyncio.sleep(self._monitor_interval) # 每次定时检测触发时,打印当前状态 async with self.lock: pending_count = len(self.pending_tasks) running_count = len(self.running_tasks) current_time = datetime.now() timestamp = current_time.strftime("%Y-%m-%d %H:%M:%S") time_only = current_time.strftime("%H:%M:%S") # 计算距离上次处理队列的时间 time_since_last_processing = "" if self._last_processing_time: time_diff = current_time - self._last_processing_time minutes = int(time_diff.total_seconds() // 60) seconds = int(time_diff.total_seconds() % 60) time_since_last_processing = ( f"上次处理: {minutes}分{seconds}秒前" ) else: time_since_last_processing = "上次处理: 从未" # 格式化状态信息,提高可读性 status_info = ( f"⏰ 定时检测触发 [{timestamp}]\n" f" 📋 待处理任务: {pending_count} 个\n" f" 🚀 运行中任务: {running_count} 个\n" f" ⏱️ 监控间隔: {self._monitor_interval} 秒\n" f" 🕐 {time_since_last_processing}" ) logger.info(status_info) if self.pending_tasks: available_servers = await self._get_available_servers() if available_servers: logger.info( f"监控发现 {len(available_servers)} 个可用服务器,开始处理队列" ) # 使用create_task避免阻塞监控循环 asyncio.create_task(self._process_queue()) else: # 只在有任务等待时记录日志,避免日志噪音 if len(self.pending_tasks) > 0: logger.debug( f"队列中有 {len(self.pending_tasks)} 个待处理任务,但无可用服务器" ) else: # 队列为空时,减少日志输出 logger.debug("队列为空,继续监控中...") except asyncio.CancelledError: logger.info("队列监控任务被取消") break except Exception as e: logger.error(f"队列监控任务出错: {e}") # 出错时稍微延长等待时间,避免频繁重试 await asyncio.sleep(self._monitor_interval * 2) async def add_task( self, workflow_name: str, workflow_data: dict, api_spec: dict, request_data: dict, ): """添加新任务到队列""" workflow_run_id = str(uuid.uuid4()) async with self.lock: # 创建任务记录 await create_workflow_run( workflow_run_id=workflow_run_id, workflow_name=workflow_name, workflow_json=json.dumps(workflow_data), api_spec=json.dumps(api_spec), request_data=json.dumps(request_data), ) # 创建工作流节点记录 nodes_data = [] for node in workflow_data.get("nodes", []): nodes_data.append( {"id": str(node["id"]), "type": node.get("type", "unknown")} ) await create_workflow_run_nodes(workflow_run_id, nodes_data) # 添加到待处理队列 self.pending_tasks.append(workflow_run_id) logger.info( f"任务 {workflow_run_id} 已添加到队列,当前队列长度: {len(self.pending_tasks)}" ) # 注意:不再手动触发队列处理,由定时监控自动处理 # 这样可以避免在没有可用服务器时的不必要尝试 return workflow_run_id async def _process_queue(self): """处理队列中的任务""" async with self.lock: # 更新上次处理队列的时间 self._last_processing_time = datetime.now() if not self.pending_tasks: logger.debug("队列为空,无需处理") return # 检查是否有空闲的服务器 available_servers = await self._get_available_servers() if not available_servers: logger.info( f"队列中有 {len(self.pending_tasks)} 个待处理任务,但没有可用的服务器" ) return # 获取一个待处理任务 workflow_run_id = self.pending_tasks.pop(0) server = available_servers[0] logger.info( f"开始处理任务 {workflow_run_id},使用服务器 {server.name} ({server.http_url})" ) # 分配服务器资源 await server_manager.allocate_server(server.name) # 标记任务为运行中 await update_workflow_run_status( workflow_run_id, "running", server.http_url ) self.running_tasks[server.http_url] = { "workflow_run_id": workflow_run_id, "started_at": datetime.now(), "server_name": server.name, } # 启动任务执行 asyncio.create_task(self._execute_task(workflow_run_id, server)) logger.info( f"任务 {workflow_run_id} 已分配到服务器 {server.name},当前运行中任务数: {len(self.running_tasks)}" ) async def _get_available_servers(self) -> list[ComfyUIServerInfo]: """获取可用的服务器""" # 使用新的服务器管理器获取可用服务器 available_server = await server_manager.get_available_server() if available_server: return [available_server] return [] async def _execute_task(self, workflow_run_id: str, server: ComfyUIServerInfo): """执行任务""" cleanup_paths = [] try: # 获取工作流数据 workflow_run = await get_workflow_run(workflow_run_id) if not workflow_run: raise Exception(f"找不到工作流运行记录: {workflow_run_id}") workflow_data = json.loads(workflow_run["workflow_json"]) api_spec = json.loads(workflow_run["api_spec"]) request_data = json.loads(workflow_run["request_data"]) # 执行工作流 result = await _execute_prompt_on_server( workflow_data, api_spec, request_data, server, workflow_run_id ) # 保存处理后的结果到数据库 await update_workflow_run_status( workflow_run_id, "completed", result=json.dumps(result, ensure_ascii=False), ) except Exception as e: logger.error(f"执行任务 {workflow_run_id} 时出错: {e}") await update_workflow_run_status( workflow_run_id, "failed", error_message=str(e) ) finally: # 清理临时文件 if cleanup_paths: logger.info(f"正在清理 {len(cleanup_paths)} 个临时文件...") for path in cleanup_paths: try: if os.path.exists(path): os.remove(path) logger.info(f" - 已删除: {path}") except Exception as e: logger.warning(f" - 删除 {path} 时出错: {e}") # 清理运行状态 async with self.lock: if server.http_url in self.running_tasks: del self.running_tasks[server.http_url] # 释放服务器资源 await server_manager.release_server(server.name) async def get_task_status(self, workflow_run_id: str) -> dict: """获取任务状态""" workflow_run = await get_workflow_run(workflow_run_id) if not workflow_run: return {"error": "任务不存在"} nodes = await get_workflow_run_nodes(workflow_run_id) result = { "id": workflow_run_id, "status": workflow_run["status"], "created_at": workflow_run["created_at"], "started_at": workflow_run["started_at"], "completed_at": workflow_run["completed_at"], "server_url": workflow_run["server_url"], "error_message": workflow_run["error_message"], "nodes": nodes, } # 如果任务完成,从数据库获取结果 if workflow_run["status"] == "completed" and workflow_run.get("result"): try: result["result"] = json.loads(workflow_run["result"]) except (json.JSONDecodeError, TypeError): result["result"] = None return result async def get_server_status( self, server: ComfyUIServerInfo, session: aiohttp.ClientSession ) -> dict[str, Any]: """ 检查单个ComfyUI服务器的详细状态。 返回一个包含可达性、队列状态和详细队列内容的字典。 """ # [BUG修复] 确保初始字典结构与成功时的结构一致,以满足Pydantic模型 status_info = { "is_reachable": False, "is_free": False, "queue_details": {"running_count": 0, "pending_count": 0}, } try: queue_url = f"{server.http_url}/queue" async with session.get(queue_url, timeout=60) as response: response.raise_for_status() queue_data = await response.json() status_info["is_reachable"] = True running_count = len(queue_data.get("queue_running", [])) pending_count = len(queue_data.get("queue_pending", [])) status_info["queue_details"] = { "running_count": running_count, "pending_count": pending_count, } status_info["is_free"] = running_count == 0 and pending_count == 0 except Exception as e: # 当请求失败时,将返回上面定义的、结构正确的初始 status_info logger.warning(f"无法检查服务器 {server.http_url} 的队列状态: {e}") return status_info # 全局队列管理器实例 # 从配置中读取监控间隔,默认为5秒 queue_manager = WorkflowQueueManager(monitor_interval=5) async def _execute_prompt_on_server( workflow_data: dict, api_spec: dict, request_data: dict, server: ComfyUIServerInfo, workflow_run_id: str, ) -> dict: """ 在指定的服务器上执行一个准备好的prompt。 现在支持节点级别的状态跟踪。 """ client_id = str(uuid.uuid4()) # 构建prompt prompt = await build_prompt(workflow_data, api_spec, request_data, server) # 更新工作流运行状态,记录prompt_id和client_id await update_workflow_run_status( workflow_run_id, "running", server.http_url, None, # prompt_id将在_queue_prompt后更新 client_id, ) # 提交到ComfyUI prompt_id = await _queue_prompt(workflow_data, prompt, client_id, server.http_url) # 更新prompt_id await update_workflow_run_status( workflow_run_id, "running", server.http_url, prompt_id, client_id ) logger.info( f"工作流 {workflow_run_id} 已在 {server.http_url} 上入队,Prompt ID: {prompt_id}" ) # 获取执行结果,现在支持节点级别的状态跟踪 results = await _get_execution_results( workflow_data, prompt_id, client_id, server.ws_url, workflow_run_id ) return results async def _queue_prompt( workflow: dict, prompt: dict, client_id: str, http_url: str ) -> str: """通过HTTP POST将工作流任务提交到指定的ComfyUI服务器。""" for node_id in prompt: prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random() payload = { "prompt": prompt, "client_id": client_id, "extra_data": { "api_key_comfy_org": "", "extra_pnginfo": {"workflow": workflow}, }, } async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session: prompt_url = f"{http_url}/prompt" try: async with session.post(prompt_url, json=payload) as response: logger.info(f"ComfyUI /prompt 端点返回的响应: {response}") response.raise_for_status() result = await response.json() logger.info(f"ComfyUI /prompt 端点返回的响应: {result}") if "prompt_id" not in result: raise Exception(f"从 ComfyUI /prompt 端点返回的响应无效: {result}") return result["prompt_id"] except Exception as e: logger.error(f"提交到 ComfyUI /prompt 端点时发生错误: {e}") raise e async def _get_execution_results( workflow_data: Dict, prompt_id: str, client_id: str, ws_url: str, workflow_run_id: str, ) -> dict: """ 通过WebSocket连接到指定的ComfyUI服务器,聚合执行结果。 现在支持节点级别的状态跟踪。 """ full_ws_url = f"{ws_url}?clientId={client_id}" aggregated_outputs = {} try: async with websockets.connect(full_ws_url) as websocket: while True: try: out = await websocket.recv() if not isinstance(out, str): continue message = json.loads(out) msg_type = message.get("type") data = message.get("data") if not (data and data.get("prompt_id") == prompt_id): continue # 捕获并处理执行错误 if msg_type == "execution_error": error_data = data logger.error( f"ComfyUI执行错误 (Prompt ID: {prompt_id}): {error_data}" ) # 更新节点状态为失败 node_id = error_data.get("node_id") if node_id: await update_workflow_run_node_status( workflow_run_id, node_id, "failed", error_message=error_data.get( "exception_message", "Unknown error" ), ) # 抛出自定义异常,将错误详情传递出去 raise ComfyUIExecutionError(error_data) # 处理节点开始执行 if msg_type == "executing" and data.get("node"): node_id = data.get("node") logger.info(f"节点 {node_id} 开始执行 (Prompt ID: {prompt_id})") # 更新节点状态为运行中 await update_workflow_run_node_status( workflow_run_id, node_id, "running" ) # 处理节点执行完成 if msg_type == "executed": node_id = data.get("node") output_data = data.get("output") if node_id and output_data: node = next( ( x for x in workflow_data["nodes"] if str(x["id"]) == node_id ), None, ) if ( node and node.get("title", "") and node["title"].startswith(API_OUTPUT_PREFIX) ): title = node["title"].replace(API_OUTPUT_PREFIX, "") aggregated_outputs[title] = output_data logger.info( f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})" ) # 更新节点状态为完成 await update_workflow_run_node_status( workflow_run_id, node_id, "completed", output_data=json.dumps(output_data), ) # 处理整个工作流执行完成 elif msg_type == "executing" and data.get("node") is None: logger.info(f"Prompt ID: {prompt_id} 执行完成。") return aggregated_outputs except websockets.exceptions.ConnectionClosed as e: logger.warning( f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}" ) return aggregated_outputs except Exception as e: # 重新抛出我们自己的异常,或者处理其他意外错误 if not isinstance(e, ComfyUIExecutionError): logger.error(f"处理 prompt {prompt_id} 时发生意外错误: {e}") raise e except websockets.exceptions.InvalidURI as e: logger.error( f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}" ) raise e