import json
import logging
import random
import uuid
from typing import Optional

import aiohttp
import websockets
from aiohttp import ClientTimeout

from app.comfy.comfy_workflow import ComfyWorkflow, API_OUTPUT_PREFIX
from app.comfy.comfy_server import ComfyUIServerInfo, server_manager
from app.database.api import (
    create_workflow_run,
    create_workflow_run_nodes,
    get_workflow_run,
    update_workflow_run_status,
    update_workflow_run_node_status,
)


class ComfyUIExecutionError(Exception):
    """Raised when ComfyUI sends an ``execution_error`` message for a prompt.

    The raw message payload is kept on ``error_data`` so callers can inspect it.
    """

    def __init__(self, error_data):
        self.error_data = error_data
        super().__init__(f"ComfyUI execution error: {error_data}")


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ComfyRun:
    """A single run of a ComfyUI workflow.

    Bundles the workflow, the request data, and the identifiers (run id,
    client id, prompt id) used to submit the run and follow it to completion.
    """

    def __init__(self, workflow: ComfyWorkflow, run_id: str, request_data: dict):
        """Initialise a run instance without touching the database.

        Args:
            workflow: The ComfyWorkflow to run.
            run_id: Identifier of this run.
            request_data: Request payload used to build the prompt.
        """
        self.workflow = workflow
        self.run_id = run_id
        self.request_data = request_data
        # A fresh client id per instance; it is sent with the prompt and used
        # as the ``clientId`` query parameter of the WebSocket URL.
        self.client_id = str(uuid.uuid4())
        self.prompt_id: Optional[str] = None
        self.server: Optional[ComfyUIServerInfo] = None

    @classmethod
    async def create(cls, workflow: ComfyWorkflow, request_data: dict) -> "ComfyRun":
        """Create a new run and persist it, together with its node records.

        Args:
            workflow: The ComfyWorkflow to run.
            request_data: Request payload used to build the prompt.

        Returns:
            The newly created ComfyRun.
        """
        run_id = str(uuid.uuid4())

        # Persist the run record.
        await create_workflow_run(
            workflow_run_id=run_id,
            workflow_name=workflow.workflow_name,
            workflow_json=json.dumps(workflow.workflow_data.model_dump()),
            api_spec=json.dumps(workflow.get_api_spec().model_dump()),
            request_data=json.dumps(request_data),
        )

        # Persist one record per workflow node.
        nodes_data = [
            {"id": str(node.id), "type": node.type}
            for node in workflow.workflow_data.nodes
        ]
        await create_workflow_run_nodes(run_id, nodes_data)

        return cls(workflow, run_id, request_data)

    @classmethod
    async def from_run_id(cls, run_id: str) -> Optional["ComfyRun"]:
        """Rebuild a run instance from its stored record.

        Args:
            run_id: Identifier of the run to load.

        Returns:
            The ComfyRun, or None when no record exists for ``run_id``.
        """
        workflow_run = await get_workflow_run(run_id)
        if not workflow_run:
            return None

        workflow_data = json.loads(workflow_run.workflow_json)
        workflow = ComfyWorkflow(workflow_run.workflow_name, workflow_data)
        request_data = json.loads(workflow_run.request_data)

        return cls(workflow, run_id, request_data)

    async def execute(self, server: ComfyUIServerInfo) -> dict:
        """Run the workflow on the given server and return its outputs.

        Args:
            server: The ComfyUI server to run on.

        Returns:
            The outputs collected from the workflow's API output nodes.

        Raises:
            Exception: Any failure is recorded as status ``failed`` and then
                re-raised unchanged.
        """
        self.server = server
        # Tracks whether allocation succeeded, so that a failed allocation is
        # not followed by a release of a server that was never allocated.
        allocated = False

        try:
            # Reserve the server.
            await server_manager.allocate_server(server.name)
            allocated = True

            # Build the prompt.
            prompt = await self.workflow.build_prompt(server, self.request_data)

            # Mark the run as running.
            await self._update_status(
                "running", server.http_url, client_id=self.client_id
            )

            # Submit to ComfyUI.
            # NOTE(review): the WebSocket is only opened after the prompt has
            # been queued, so messages emitted before the connection is up
            # may be missed -- confirm whether ComfyUI replays them.
            self.prompt_id = await self._queue_prompt(prompt, server.http_url)

            # Record the prompt id.
            await self._update_status(
                "running", server.http_url, self.prompt_id, self.client_id
            )

            logger.info(
                f"工作流 {self.run_id} 已在 {server.http_url} 上入队,Prompt ID: {self.prompt_id}"
            )

            # Wait for the results.
            results = await self._get_execution_results(server.ws_url)

            # Mark the run as completed.
            await self._update_status(
                "completed", result=json.dumps(results, ensure_ascii=False)
            )

            return results

        except Exception as e:
            # Mark the run as failed, then propagate.
            await self._update_status("failed", error_message=str(e))
            raise
        finally:
            # Release the server; a release failure is logged, not raised.
            if allocated:
                try:
                    await server_manager.release_server(server.name)
                except Exception as e:
                    logger.error(f"释放服务器资源时出错: {e}")

    async def _update_status(
        self,
        status: str,
        server_url: Optional[str] = None,
        prompt_id: Optional[str] = None,
        client_id: Optional[str] = None,
        result: Optional[str] = None,
        error_message: Optional[str] = None,
    ):
        """Persist the run's status and the optional details that go with it."""
        # Argument order follows the database helper's positional signature
        # (error_message before result), which differs from this method's.
        await update_workflow_run_status(
            self.run_id, status, server_url, prompt_id, client_id, error_message, result
        )

    async def _queue_prompt(self, prompt: dict, http_url: str) -> str:
        """Submit the prompt to the server's ``/prompt`` endpoint.

        Args:
            prompt: Prompt graph keyed by node id; it is modified in place.
            http_url: Base HTTP URL of the ComfyUI server.

        Returns:
            The ``prompt_id`` returned by the server.

        Raises:
            Exception: When the request fails or the response has no
                ``prompt_id``.
        """
        # Add a uniquely named random input to every node to defeat caching.
        for node_id in prompt:
            prompt[node_id]["inputs"][
                f"cache_buster_{uuid.uuid4().hex}"
            ] = random.random()

        payload = {
            "prompt": prompt,
            "client_id": self.client_id,
            "extra_data": {
                "api_key_comfy_org": "",
                "extra_pnginfo": {"workflow": self.workflow.workflow_data.model_dump()},
            },
        }

        async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
            prompt_url = f"{http_url}/prompt"
            try:
                async with session.post(prompt_url, json=payload) as response:
                    logger.info(f"ComfyUI /prompt 端点返回的响应: {response}")
                    response.raise_for_status()
                    result = await response.json()
                    logger.info(f"ComfyUI /prompt 端点返回的响应: {result}")
                    if "prompt_id" not in result:
                        raise Exception(
                            f"从 ComfyUI /prompt 端点返回的响应无效: {result}"
                        )
                    return result["prompt_id"]
            except Exception as e:
                logger.error(f"提交到 ComfyUI /prompt 端点时发生错误: {e}")
                raise

    async def _get_execution_results(self, ws_url: str) -> dict:
        """Follow the run over WebSocket and collect its outputs.

        Node status records are updated as ``executing``, ``executed`` and
        ``execution_error`` messages arrive for this run's prompt id.

        Args:
            ws_url: WebSocket URL of the ComfyUI server.

        Returns:
            Outputs of the API output nodes, keyed by title without the prefix.

        Raises:
            ComfyUIExecutionError: When ComfyUI reports an execution error.
            Exception: Any connection or message-handling error is re-raised.
        """
        aggregated_outputs = {}
        full_ws_url = f"{ws_url}?clientId={self.client_id}"

        try:
            async with websockets.connect(full_ws_url) as websocket:
                logger.info(f"已连接到WebSocket: {full_ws_url}")

                while True:
                    try:
                        out = await websocket.recv()
                        # Only text frames are parsed; anything else is skipped.
                        if not isinstance(out, str):
                            continue

                        message = json.loads(out)
                        msg_type = message.get("type")
                        data = message.get("data")

                        # Ignore messages that belong to other prompts.
                        if not (data and data.get("prompt_id") == self.prompt_id):
                            continue

                        # Execution error: record it on the node, then raise.
                        if msg_type == "execution_error":
                            error_data = data
                            logger.error(
                                f"ComfyUI执行错误 (Prompt ID: {self.prompt_id}): {error_data}"
                            )

                            node_id = error_data.get("node_id")
                            if node_id:
                                await update_workflow_run_node_status(
                                    self.run_id,
                                    node_id,
                                    "failed",
                                    error_message=error_data.get(
                                        "exception_message", "Unknown error"
                                    ),
                                )

                            # Hand the error details to the caller.
                            raise ComfyUIExecutionError(error_data)

                        # A node started executing.
                        if msg_type == "executing" and data.get("node"):
                            node_id = data.get("node")
                            logger.info(
                                f"节点 {node_id} 开始执行 (Prompt ID: {self.prompt_id})"
                            )
                            await update_workflow_run_node_status(
                                self.run_id, node_id, "running"
                            )

                        # A node finished executing.
                        elif msg_type == "executed":
                            await self._handle_node_executed(data, aggregated_outputs)

                        # ``executing`` with node None ends the whole prompt.
                        elif msg_type == "executing" and data.get("node") is None:
                            logger.info(f"Prompt ID: {self.prompt_id} 执行完成。")
                            return aggregated_outputs

                    except websockets.exceptions.ConnectionClosed as e:
                        # NOTE(review): a dropped connection returns whatever
                        # was collected so far, and the caller then records
                        # the run as completed -- confirm this is intended.
                        logger.warning(
                            f"WebSocket 连接已关闭 (Prompt ID: {self.prompt_id})。错误: {e}"
                        )
                        return aggregated_outputs
                    except Exception as e:
                        # Our own error was already logged above; log only
                        # unexpected ones before re-raising.
                        if not isinstance(e, ComfyUIExecutionError):
                            logger.error(
                                f"处理 prompt {self.prompt_id} 时发生意外错误: {e}"
                            )
                        raise

        except ComfyUIExecutionError:
            # Already logged with full details; not a connection problem, so
            # keep it out of the connection-error log below.
            raise
        except websockets.exceptions.InvalidURI as e:
            logger.error(
                f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
            )
            raise
        except Exception as e:
            logger.error(f"WebSocket连接出错: {e}")
            raise

    async def _handle_node_executed(self, data: dict, aggregated_outputs: dict):
        """Handle an ``executed`` message for one node.

        Stores the output in ``aggregated_outputs`` when the node is an API
        output node, then marks the node as completed.

        Args:
            data: Message payload; ``node`` and ``output`` are read from it.
            aggregated_outputs: Mapping of collected outputs, updated in place.
        """
        node_id = data.get("node")
        output_data = data.get("output")

        if not node_id or not output_data:
            return

        # Find the workflow node this message refers to.
        node = next(
            (x for x in self.workflow.workflow_data.nodes if str(x.id) == node_id),
            None,
        )
        if not node:
            return

        # Output nodes are recognised by their title prefix.
        if node.title and node.title.startswith(API_OUTPUT_PREFIX):
            # Strip only the leading prefix; later occurrences of the same
            # text inside the title are part of the output name.
            title = node.title[len(API_OUTPUT_PREFIX):]
            aggregated_outputs[title] = output_data
            logger.info(f"收到节点 {node_id} 的输出 (Prompt ID: {self.prompt_id})")

        # Mark the node as completed.
        await update_workflow_run_node_status(
            self.run_id,
            node_id,
            "completed",
            output_data=json.dumps(output_data),
        )

    def get_status_info(self) -> dict:
        """Return the identifiers and server URL of this run as a dict."""
        return {
            "run_id": self.run_id,
            "workflow_name": self.workflow.workflow_name,
            "client_id": self.client_id,
            "prompt_id": self.prompt_id,
            "server_url": self.server.http_url if self.server else None,
        }