346 lines
12 KiB
Python
346 lines
12 KiB
Python
import json
|
||
import logging
|
||
import random
|
||
import uuid
|
||
from typing import Optional
|
||
|
||
import aiohttp
|
||
import websockets
|
||
from aiohttp import ClientTimeout
|
||
|
||
from app.comfy.comfy_workflow import ComfyWorkflow, API_OUTPUT_PREFIX
|
||
from app.comfy.comfy_server import ComfyUIServerInfo, server_manager
|
||
from app.database.api import (
|
||
create_workflow_run,
|
||
create_workflow_run_nodes,
|
||
get_workflow_run,
|
||
update_workflow_run_status,
|
||
update_workflow_run_node_status,
|
||
)
|
||
|
||
|
||
class ComfyUIExecutionError(Exception):
    """Error carrying the payload of a failed ComfyUI execution.

    Attributes:
        error_data: The error payload exactly as it was handed to the
            constructor; kept so callers can inspect the details.
    """

    def __init__(self, error_data):
        """Build the exception message from ``error_data`` and keep the payload.

        Args:
            error_data: Error details to attach to the exception.
        """
        message = f"ComfyUI execution error: {error_data}"
        super().__init__(message)
        self.error_data = error_data
|
||
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Configure the root logger at import time so INFO-level records are emitted.
# NOTE(review): this is an import-time side effect on global logging state;
# confirm the application does not configure logging elsewhere.
logging.basicConfig(level=logging.INFO)
|
||
|
||
|
||
class ComfyRun:
    """A single run of a ComfyUI workflow.

    Wraps the operations for one run: persisting the run record, submitting
    the prompt over HTTP, following execution over WebSocket, and recording
    per-node and overall status through the database API.
    """

    def __init__(self, workflow: ComfyWorkflow, run_id: str, request_data: dict):
        """Initialise the run instance.

        Args:
            workflow: The workflow to run.
            run_id: Identifier of this run.
            request_data: Request payload used when building the prompt.
        """
        self.workflow = workflow
        self.run_id = run_id
        self.request_data = request_data
        # A fresh client id per instance; it is sent in the /prompt payload
        # and appended to the WebSocket URL as ``clientId``.
        self.client_id = str(uuid.uuid4())
        # Set once the prompt has been queued by ``execute``.
        self.prompt_id: Optional[str] = None
        # Set by ``execute`` to the server the run was started on.
        self.server: Optional[ComfyUIServerInfo] = None

    @classmethod
    async def create(cls, workflow: ComfyWorkflow, request_data: dict) -> "ComfyRun":
        """Create a new run and persist it together with its node records.

        Args:
            workflow: The workflow to run.
            request_data: Request payload; must be JSON-serialisable because
                it is stored via ``json.dumps``.

        Returns:
            The new ``ComfyRun`` with a freshly generated run id.
        """
        run_id = str(uuid.uuid4())

        # Persist the run record.
        await create_workflow_run(
            workflow_run_id=run_id,
            workflow_name=workflow.workflow_name,
            workflow_json=json.dumps(workflow.workflow_data.model_dump()),
            api_spec=json.dumps(workflow.get_api_spec().model_dump()),
            request_data=json.dumps(request_data),
        )

        # Persist one record per workflow node (id stored as a string).
        nodes_data = [
            {"id": str(node.id), "type": node.type}
            for node in workflow.workflow_data.nodes
        ]
        await create_workflow_run_nodes(run_id, nodes_data)

        return cls(workflow, run_id, request_data)

    @classmethod
    async def from_run_id(cls, run_id: str) -> Optional["ComfyRun"]:
        """Rebuild a run instance from a stored run record.

        Only the workflow and the request data are restored; the returned
        instance gets a new ``client_id`` and has no ``prompt_id`` or
        ``server`` set.

        Args:
            run_id: Identifier of the stored run.

        Returns:
            The rebuilt ``ComfyRun``, or ``None`` when no record is found.
        """
        workflow_run = await get_workflow_run(run_id)
        if not workflow_run:
            return None

        workflow_data = json.loads(workflow_run.workflow_json)
        workflow = ComfyWorkflow(workflow_run.workflow_name, workflow_data)
        request_data = json.loads(workflow_run.request_data)

        return cls(workflow, run_id, request_data)

    async def execute(self, server: ComfyUIServerInfo) -> dict:
        """Run the workflow on the given server and return its outputs.

        The run status is updated to ``running`` before submission, to
        ``completed`` (with the JSON-encoded results) on success, and to
        ``failed`` (with the error text) when any step raises.

        Args:
            server: The ComfyUI server to run on.

        Returns:
            The aggregated outputs collected while following the execution.

        Raises:
            Exception: Any error raised by a step is re-raised after the run
                has been marked as failed.
        """
        self.server = server
        # Tracks whether allocation succeeded, so a failed allocation is not
        # followed by a release of a server that was never acquired.
        allocated = False

        try:
            # Reserve the server.
            await server_manager.allocate_server(server.name)
            allocated = True

            # Build the prompt.
            prompt = await self.workflow.build_prompt(server, self.request_data)

            # Mark the run as running.
            await self._update_status(
                "running", server.http_url, client_id=self.client_id
            )

            # Submit to ComfyUI.
            self.prompt_id = await self._queue_prompt(prompt, server.http_url)

            # Record the prompt id now that it is known.
            await self._update_status(
                "running", server.http_url, self.prompt_id, self.client_id
            )

            logger.info(
                f"工作流 {self.run_id} 已在 {server.http_url} 上入队,Prompt ID: {self.prompt_id}"
            )

            # Follow the execution and collect outputs.
            # NOTE(review): the WebSocket is opened only after the prompt has
            # been queued; messages sent before the connection exists would
            # not be seen here - confirm this cannot stall the receive loop.
            results = await self._get_execution_results(server.ws_url)

            # Mark the run as completed.
            await self._update_status(
                "completed", result=json.dumps(results, ensure_ascii=False)
            )

            return results

        except Exception as e:
            # Mark the run as failed, then let the caller see the error.
            await self._update_status("failed", error_message=str(e))
            raise

        finally:
            # Release the server only if it was actually allocated; a release
            # failure is logged rather than raised so it cannot mask the
            # outcome of the run itself.
            if allocated:
                try:
                    await server_manager.release_server(server.name)
                except Exception as e:
                    logger.error(f"释放服务器资源时出错: {e}")

    async def _update_status(
        self,
        status: str,
        server_url: Optional[str] = None,
        prompt_id: Optional[str] = None,
        client_id: Optional[str] = None,
        result: Optional[str] = None,
        error_message: Optional[str] = None,
    ):
        """Update the stored status of this run.

        Args:
            status: New status value.
            server_url: HTTP URL of the server, if known.
            prompt_id: ComfyUI prompt id, if known.
            client_id: Client id used for this run, if it should be recorded.
            result: JSON-encoded results, if any.
            error_message: Error text, if any.
        """
        # The database call is positional and takes ``error_message`` before
        # ``result`` - the reverse of this method's own parameter order.
        await update_workflow_run_status(
            self.run_id, status, server_url, prompt_id, client_id, error_message, result
        )

    async def _queue_prompt(self, prompt: dict, http_url: str) -> str:
        """Submit the prompt to the server's ``/prompt`` endpoint.

        Note that ``prompt`` is modified in place: a uniquely named
        ``cache_buster_*`` input with a random value is added to every node.

        Args:
            prompt: Mapping of node id to node definition; each node must
                have an ``"inputs"`` dict.
            http_url: Base HTTP URL of the server.

        Returns:
            The ``prompt_id`` reported by the server.

        Raises:
            Exception: If the request fails or the response carries no
                ``prompt_id``; the error is logged and re-raised.
        """
        # Add a random cache-buster input to every node to avoid caching.
        for node in prompt.values():
            node["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random()

        payload = {
            "prompt": prompt,
            "client_id": self.client_id,
            "extra_data": {
                "api_key_comfy_org": "",
                "extra_pnginfo": {"workflow": self.workflow.workflow_data.model_dump()},
            },
        }

        async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
            prompt_url = f"{http_url}/prompt"
            try:
                async with session.post(prompt_url, json=payload) as response:
                    logger.info(f"ComfyUI /prompt 端点返回的响应: {response}")
                    response.raise_for_status()
                    result = await response.json()
                    logger.info(f"ComfyUI /prompt 端点返回的响应: {result}")

                    if "prompt_id" not in result:
                        raise Exception(
                            f"从 ComfyUI /prompt 端点返回的响应无效: {result}"
                        )

                    return result["prompt_id"]

            except Exception as e:
                logger.error(f"提交到 ComfyUI /prompt 端点时发生错误: {e}")
                raise

    async def _get_execution_results(self, ws_url: str) -> dict:
        """Follow the execution over WebSocket and collect the outputs.

        Only text messages whose ``data.prompt_id`` equals ``self.prompt_id``
        are processed. Node status is updated as nodes start, finish or fail.

        Args:
            ws_url: WebSocket URL of the server; ``clientId`` is appended.

        Returns:
            The aggregated outputs. If the connection closes before the
            completion message arrives, whatever was collected so far is
            returned.
            NOTE(review): ``execute`` records that partial result as
            ``completed`` - confirm this is intended.

        Raises:
            ComfyUIExecutionError: When the server reports an
                ``execution_error`` for this prompt.
            Exception: Any other error is logged and re-raised.
        """
        aggregated_outputs = {}
        full_ws_url = f"{ws_url}?clientId={self.client_id}"

        try:
            async with websockets.connect(full_ws_url) as websocket:
                logger.info(f"已连接到WebSocket: {full_ws_url}")

                while True:
                    try:
                        out = await websocket.recv()
                        # Non-text frames are ignored.
                        if not isinstance(out, str):
                            continue

                        message = json.loads(out)
                        msg_type = message.get("type")
                        data = message.get("data")

                        # Skip messages that do not belong to this prompt.
                        if not (data and data.get("prompt_id") == self.prompt_id):
                            continue

                        # Execution error: record it on the node and abort.
                        if msg_type == "execution_error":
                            error_data = data
                            logger.error(
                                f"ComfyUI执行错误 (Prompt ID: {self.prompt_id}): {error_data}"
                            )

                            node_id = error_data.get("node_id")
                            if node_id:
                                await update_workflow_run_node_status(
                                    self.run_id,
                                    node_id,
                                    "failed",
                                    error_message=error_data.get(
                                        "exception_message", "Unknown error"
                                    ),
                                )

                            # Hand the error details to the caller.
                            raise ComfyUIExecutionError(error_data)

                        if msg_type == "executing":
                            node_id = data.get("node")

                            # ``node`` being None marks the end of the whole
                            # workflow execution.
                            if node_id is None:
                                logger.info(f"Prompt ID: {self.prompt_id} 执行完成。")
                                return aggregated_outputs

                            # A node started executing.
                            if node_id:
                                logger.info(
                                    f"节点 {node_id} 开始执行 (Prompt ID: {self.prompt_id})"
                                )
                                await update_workflow_run_node_status(
                                    self.run_id, node_id, "running"
                                )

                        # A node finished executing.
                        elif msg_type == "executed":
                            await self._handle_node_executed(data, aggregated_outputs)

                    except websockets.exceptions.ConnectionClosed as e:
                        logger.warning(
                            f"WebSocket 连接已关闭 (Prompt ID: {self.prompt_id})。错误: {e}"
                        )
                        return aggregated_outputs
                    except ComfyUIExecutionError:
                        # Already logged above; pass it on unchanged.
                        raise
                    except Exception as e:
                        logger.error(
                            f"处理 prompt {self.prompt_id} 时发生意外错误: {e}"
                        )
                        raise

        except ComfyUIExecutionError:
            # An execution error is not a connection problem, so it must not
            # be reported by the generic WebSocket error log below.
            raise
        except websockets.exceptions.InvalidURI as e:
            logger.error(
                f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
            )
            raise
        except Exception as e:
            logger.error(f"WebSocket连接出错: {e}")
            raise

    async def _handle_node_executed(self, data: dict, aggregated_outputs: dict):
        """Handle an ``executed`` message for one node.

        Messages without a node id or without output, and messages for nodes
        that are not part of this workflow, are ignored.

        Args:
            data: The message payload; ``node`` and ``output`` are read.
            aggregated_outputs: Collected outputs, updated in place. Output
                nodes are stored under their title with the leading
                ``API_OUTPUT_PREFIX`` removed.
        """
        node_id = data.get("node")
        output_data = data.get("output")

        if not node_id or not output_data:
            return

        # Find the matching workflow node (ids are compared as strings).
        node = next(
            (x for x in self.workflow.workflow_data.nodes if str(x.id) == node_id),
            None,
        )

        if not node:
            return

        # Output nodes are identified by their title prefix. Strip only the
        # leading prefix so later occurrences inside the title are kept.
        if node.title and node.title.startswith(API_OUTPUT_PREFIX):
            title = node.title[len(API_OUTPUT_PREFIX):]
            aggregated_outputs[title] = output_data

        logger.info(f"收到节点 {node_id} 的输出 (Prompt ID: {self.prompt_id})")

        # Mark the node as completed and store its output.
        await update_workflow_run_node_status(
            self.run_id,
            node_id,
            "completed",
            output_data=json.dumps(output_data),
        )

    def get_status_info(self) -> dict:
        """Return a summary of this run's identifiers.

        Returns:
            A dict with ``run_id``, ``workflow_name``, ``client_id``,
            ``prompt_id`` and ``server_url`` (``None`` when no server has
            been set).
        """
        return {
            "run_id": self.run_id,
            "workflow_name": self.workflow.workflow_name,
            "client_id": self.client_id,
            "prompt_id": self.prompt_id,
            "server_url": self.server.http_url if self.server else None,
        }
|