ComfyUI-WorkflowPublisher/app/comfy/comfy_run.py

346 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import random
import uuid
from typing import Optional
import aiohttp
import websockets
from aiohttp import ClientTimeout
from app.comfy.comfy_workflow import ComfyWorkflow, API_OUTPUT_PREFIX
from app.comfy.comfy_server import ComfyUIServerInfo, server_manager
from app.database.api import (
create_workflow_run,
create_workflow_run_nodes,
get_workflow_run,
update_workflow_run_status,
update_workflow_run_node_status,
)
class ComfyUIExecutionError(Exception):
    """Raised when ComfyUI reports that executing a prompt went wrong.

    The raw payload received from ComfyUI is preserved on ``error_data`` so
    callers can inspect details such as the failing node.
    """

    def __init__(self, error_data):
        message = f"ComfyUI execution error: {error_data}"
        super().__init__(message)
        # Raw error payload exactly as received from ComfyUI.
        self.error_data = error_data
# Logger used by everything in this module.
logger = logging.getLogger(__name__)
# NOTE(review): this configures the root logger at import time, which affects
# the whole process; consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
class ComfyRun:
    """One run of a ComfyUI workflow.

    Wraps everything a single execution needs: persisting the run record,
    submitting the prompt to a ComfyUI server, following progress over the
    server's WebSocket, and mirroring run/node status into the database.
    """

    def __init__(self, workflow: ComfyWorkflow, run_id: str, request_data: dict):
        """Initialise a run instance (does not touch the database).

        Args:
            workflow: The ComfyWorkflow to execute.
            run_id: Identifier of the run record.
            request_data: Request payload used to build the prompt.
        """
        self.workflow = workflow
        self.run_id = run_id
        self.request_data = request_data
        # Fresh client id per run; it is sent with the prompt and used as the
        # WebSocket clientId when listening for progress messages.
        self.client_id = str(uuid.uuid4())
        self.prompt_id: Optional[str] = None
        self.server: Optional[ComfyUIServerInfo] = None

    @classmethod
    async def create(cls, workflow: ComfyWorkflow, request_data: dict) -> "ComfyRun":
        """Create a new run, persist it together with its node records, and return it.

        Args:
            workflow: The ComfyWorkflow to execute.
            request_data: Request payload used to build the prompt.

        Returns:
            The new ComfyRun instance.
        """
        run_id = str(uuid.uuid4())
        # Run record.
        await create_workflow_run(
            workflow_run_id=run_id,
            workflow_name=workflow.workflow_name,
            workflow_json=json.dumps(workflow.workflow_data.model_dump()),
            api_spec=json.dumps(workflow.get_api_spec().model_dump()),
            request_data=json.dumps(request_data),
        )
        # One record per workflow node so per-node status can be tracked.
        nodes_data = [
            {"id": str(node.id), "type": node.type}
            for node in workflow.workflow_data.nodes
        ]
        await create_workflow_run_nodes(run_id, nodes_data)
        return cls(workflow, run_id, request_data)

    @classmethod
    async def from_run_id(cls, run_id: str) -> Optional["ComfyRun"]:
        """Rebuild a ComfyRun from a stored run record.

        Args:
            run_id: Identifier of the run record.

        Returns:
            The ComfyRun instance, or None if no record exists for ``run_id``.
        """
        workflow_run = await get_workflow_run(run_id)
        if not workflow_run:
            return None
        workflow_data = json.loads(workflow_run.workflow_json)
        workflow = ComfyWorkflow(workflow_run.workflow_name, workflow_data)
        request_data = json.loads(workflow_run.request_data)
        return cls(workflow, run_id, request_data)

    async def execute(self, server: ComfyUIServerInfo) -> dict:
        """Execute the workflow on ``server`` and record the outcome.

        The run is marked ``running`` once the prompt is built, updated with the
        prompt id after queueing, and finally marked ``completed`` (with the
        JSON-encoded results) or ``failed`` (with the error message).

        Args:
            server: The ComfyUI server to run on.

        Returns:
            The aggregated API outputs (see ``_get_execution_results``).

        Raises:
            Whatever the failing step raised, after the run is marked failed.
        """
        self.server = server
        # Only release what was actually allocated: releasing a server this run
        # never acquired could corrupt server_manager's bookkeeping.
        allocated = False
        try:
            await server_manager.allocate_server(server.name)
            allocated = True

            prompt = await self.workflow.build_prompt(server, self.request_data)
            await self._update_status(
                "running", server.http_url, client_id=self.client_id
            )

            self.prompt_id = await self._queue_prompt(prompt, server.http_url)
            # Record the prompt id as soon as ComfyUI has accepted the prompt.
            await self._update_status(
                "running", server.http_url, self.prompt_id, self.client_id
            )
            logger.info(
                f"工作流 {self.run_id} 已在 {server.http_url} 上入队Prompt ID: {self.prompt_id}"
            )

            results = await self._get_execution_results(server.ws_url)
            await self._update_status(
                "completed", result=json.dumps(results, ensure_ascii=False)
            )
            return results
        except Exception as e:
            try:
                await self._update_status("failed", error_message=str(e))
            except Exception as status_error:
                # A failing status write must not mask the original error.
                logger.error(f"更新运行失败状态时出错: {status_error}")
            raise
        finally:
            if allocated:
                try:
                    await server_manager.release_server(server.name)
                except Exception as e:
                    logger.error(f"释放服务器资源时出错: {e}")

    async def _update_status(
        self,
        status: str,
        server_url: Optional[str] = None,
        prompt_id: Optional[str] = None,
        client_id: Optional[str] = None,
        result: Optional[str] = None,
        error_message: Optional[str] = None,
    ):
        """Persist the run status; a thin wrapper around update_workflow_run_status."""
        # NOTE: the database helper takes error_message before result positionally.
        await update_workflow_run_status(
            self.run_id, status, server_url, prompt_id, client_id, error_message, result
        )

    async def _queue_prompt(self, prompt: dict, http_url: str) -> str:
        """Submit ``prompt`` to the ComfyUI ``/prompt`` endpoint.

        Every node gets an extra, uniquely named random input so ComfyUI never
        reuses a cached result. ``prompt`` is modified in place.

        Args:
            prompt: API-format prompt mapping node id -> node definition.
            http_url: Base HTTP URL of the ComfyUI server.

        Returns:
            The ``prompt_id`` assigned by ComfyUI.

        Raises:
            aiohttp.ClientResponseError: The endpoint returned an HTTP error.
            Exception: The response did not contain a ``prompt_id``.
        """
        for node in prompt.values():
            node.setdefault("inputs", {})[
                f"cache_buster_{uuid.uuid4().hex}"
            ] = random.random()

        payload = {
            "prompt": prompt,
            "client_id": self.client_id,
            "extra_data": {
                "api_key_comfy_org": "",
                "extra_pnginfo": {"workflow": self.workflow.workflow_data.model_dump()},
            },
        }
        prompt_url = f"{http_url}/prompt"
        async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
            try:
                async with session.post(prompt_url, json=payload) as response:
                    if response.status >= 400:
                        # ComfyUI typically explains a rejected prompt in the
                        # body, which raise_for_status() would discard.
                        body = await response.text(errors="replace")
                        logger.error(
                            f"ComfyUI /prompt 端点返回错误状态 {response.status}: {body}"
                        )
                    response.raise_for_status()
                    result = await response.json()
                    logger.info(f"ComfyUI /prompt 端点返回的响应: {result}")
                    if "prompt_id" not in result:
                        raise Exception(
                            f"从 ComfyUI /prompt 端点返回的响应无效: {result}"
                        )
                    return result["prompt_id"]
            except Exception as e:
                logger.error(f"提交到 ComfyUI /prompt 端点时发生错误: {e}")
                raise

    async def _get_execution_results(self, ws_url: str) -> dict:
        """Follow this run's prompt over the ComfyUI WebSocket until it finishes.

        Per-node status is mirrored into the database as messages arrive. Nodes
        without UI output never receive an ``executed`` message. For that reason a
        node is also marked completed once the next node starts executing or the
        whole prompt finishes.

        Args:
            ws_url: Base WebSocket URL of the ComfyUI server.

        Returns:
            Mapping of API output name (node title without API_OUTPUT_PREFIX) to
            that node's ``output`` payload. NOTE(review): if the socket closes
            before completion, the outputs collected so far are returned and the
            caller treats the run as completed.

        Raises:
            ComfyUIExecutionError: ComfyUI reported an execution error, or the
                prompt was interrupted.
        """
        aggregated_outputs = {}
        full_ws_url = f"{ws_url}?clientId={self.client_id}"
        # Node most recently reported as executing, and nodes already marked completed.
        current_node = None
        finished_nodes = set()
        try:
            async with websockets.connect(full_ws_url) as websocket:
                logger.info(f"已连接到WebSocket: {full_ws_url}")
                try:
                    while True:
                        out = await websocket.recv()
                        # Binary frames (e.g. previews) carry no status we track.
                        if not isinstance(out, str):
                            continue
                        message = json.loads(out)
                        msg_type = message.get("type")
                        data = message.get("data")
                        # Skip messages that belong to other prompts.
                        if not (data and data.get("prompt_id") == self.prompt_id):
                            continue

                        if msg_type in ("execution_error", "execution_interrupted"):
                            # Always raises ComfyUIExecutionError.
                            await self._handle_execution_failure(msg_type, data)

                        if msg_type == "executing":
                            node_id = data.get("node")
                            if node_id is None:
                                # "executing" without a node marks the end of the prompt.
                                await self._complete_pending_node(
                                    current_node, finished_nodes
                                )
                                logger.info(f"Prompt ID: {self.prompt_id} 执行完成。")
                                return aggregated_outputs
                            if node_id:
                                if current_node != node_id:
                                    await self._complete_pending_node(
                                        current_node, finished_nodes
                                    )
                                current_node = node_id
                                # A node may start again after it was marked finished.
                                finished_nodes.discard(node_id)
                                logger.info(
                                    f"节点 {node_id} 开始执行 (Prompt ID: {self.prompt_id})"
                                )
                                await update_workflow_run_node_status(
                                    self.run_id, node_id, "running"
                                )
                        elif msg_type == "executed":
                            await self._handle_node_executed(data, aggregated_outputs)
                            executed_node = data.get("node")
                            if executed_node:
                                finished_nodes.add(executed_node)
                except websockets.exceptions.ConnectionClosed as e:
                    logger.warning(
                        f"WebSocket 连接已关闭 (Prompt ID: {self.prompt_id})。错误: {e}"
                    )
                    return aggregated_outputs
                except ComfyUIExecutionError:
                    raise
                except Exception as e:
                    logger.error(
                        f"处理 prompt {self.prompt_id} 时发生意外错误: {e}"
                    )
                    raise
        except websockets.exceptions.InvalidURI as e:
            logger.error(
                f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
            )
            raise
        except Exception as e:
            logger.error(f"WebSocket连接出错: {e}")
            raise

    async def _handle_execution_failure(self, msg_type: str, data: dict) -> None:
        """Record a failed or interrupted prompt, then raise ComfyUIExecutionError.

        Args:
            msg_type: ``"execution_error"`` or ``"execution_interrupted"``.
            data: The message's ``data`` payload.

        Raises:
            ComfyUIExecutionError: Always, carrying ``data``.
        """
        if msg_type == "execution_interrupted":
            logger.error(f"ComfyUI执行被中断 (Prompt ID: {self.prompt_id}): {data}")
            default_message = "Execution interrupted"
        else:
            logger.error(f"ComfyUI执行错误 (Prompt ID: {self.prompt_id}): {data}")
            default_message = "Unknown error"
        node_id = data.get("node_id")
        if node_id:
            await update_workflow_run_node_status(
                self.run_id,
                node_id,
                "failed",
                error_message=data.get("exception_message", default_message),
            )
        raise ComfyUIExecutionError(data)

    async def _complete_pending_node(self, node_id, finished_nodes: set) -> None:
        """Mark ``node_id`` completed unless it is None or already finished."""
        if node_id is None or node_id in finished_nodes:
            return
        finished_nodes.add(node_id)
        await update_workflow_run_node_status(self.run_id, node_id, "completed")

    async def _handle_node_executed(self, data: dict, aggregated_outputs: dict):
        """Handle an ``executed`` message: collect API output and mark the node completed.

        Args:
            data: The message's ``data`` payload (``node`` and ``output``).
            aggregated_outputs: Mapping updated in place with API outputs.
        """
        node_id = data.get("node")
        if not node_id:
            return
        output_data = data.get("output")

        node = next(
            (x for x in self.workflow.workflow_data.nodes if str(x.id) == node_id),
            None,
        )
        # Only nodes titled with API_OUTPUT_PREFIX contribute to the API result;
        # the key is the title with that leading prefix removed (once).
        if (
            output_data
            and node is not None
            and node.title
            and node.title.startswith(API_OUTPUT_PREFIX)
        ):
            title = node.title[len(API_OUTPUT_PREFIX):]
            aggregated_outputs[title] = output_data
        if output_data:
            logger.info(f"收到节点 {node_id} 的输出 (Prompt ID: {self.prompt_id})")

        status_kwargs = {}
        if output_data is not None:
            status_kwargs["output_data"] = json.dumps(output_data)
        await update_workflow_run_node_status(
            self.run_id, node_id, "completed", **status_kwargs
        )

    def get_status_info(self) -> dict:
        """Return a summary of this run's identifiers and the server it uses."""
        return {
            "run_id": self.run_id,
            "workflow_name": self.workflow.workflow_name,
            "client_id": self.client_id,
            "prompt_id": self.prompt_id,
            "server_url": self.server.http_url if self.server else None,
        }