# ComfyUI-WorkflowPublisher/workflow_service/comfy/comfy_queue.py
#
# (Repository web-viewer residue, kept as comments: 444 lines, 16 KiB, Python.)
# Viewer notice: this file contains Unicode characters that might be confused
# with other characters. If that is intentional, the warning can be ignored.

import asyncio
import json
import logging
import os
import random
import uuid
import websockets
from datetime import datetime
from typing import Dict, Any
import aiohttp
from aiohttp import ClientTimeout
from workflow_service.comfy.comfy_workflow import build_prompt
from workflow_service.config import Settings
from workflow_service.comfy.comfy_server import server_manager, ComfyUIServerInfo
from workflow_service.database.api import (
create_workflow_run,
update_workflow_run_status,
create_workflow_run_nodes,
update_workflow_run_node_status,
get_workflow_run,
get_workflow_run_nodes,
)
# Service configuration object, instantiated once when this module is imported.
settings = Settings()
# NOTE(review): basicConfig() configures the root logger as an import-time side
# effect of this module — confirm that is intended for the whole process.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Node-title prefixes. API_OUTPUT_PREFIX marks workflow nodes whose outputs are
# collected into the run result (see _get_execution_results). API_INPUT_PREFIX
# is not referenced in the visible part of this file — presumably it is used by
# the prompt builder; verify.
API_INPUT_PREFIX = "INPUT_"
API_OUTPUT_PREFIX = "OUTPUT_"
class ComfyUIExecutionError(Exception):
    """Raised when ComfyUI reports that a node failed while running a prompt.

    The raw error payload is kept on ``error_data``. ``str(exc)`` is a
    developer-oriented summary built from the payload's ``node_id``,
    ``node_type`` and ``exception_message`` entries (the last one falls back
    to ``'N/A'`` when absent).
    """

    def __init__(self, error_data: dict):
        self.error_data = error_data
        # Pull out the three fields that make up the summary line.
        failed_node = error_data.get("node_id")
        failed_type = error_data.get("node_type")
        reason = error_data.get("exception_message", "N/A")
        summary_parts = [
            f"ComfyUI节点执行失败。节点ID: {failed_node}, ",
            f"节点类型: {failed_type}. ",
            f"错误: {reason}",
        ]
        super().__init__("".join(summary_parts))
# Queue manager for workflow runs; the process-wide instance is ``queue_manager``.
class WorkflowQueueManager:
    """Queue workflow runs and dispatch them onto ComfyUI servers.

    In-memory state:
      * ``pending_tasks`` -- workflow run ids waiting for a server, oldest first.
      * ``running_tasks`` -- maps a server's ``http_url`` to a dict holding
        ``workflow_run_id``, ``started_at`` and ``server_name``.
      * ``lock`` -- an ``asyncio.Lock`` guarding both containers.

    Run and node records are persisted through the imported database helpers
    (``create_workflow_run``, ``update_workflow_run_status``, ...).
    """

    def __init__(self):
        self.running_tasks = {}  # server http_url -> task info dict
        self.pending_tasks = []  # workflow_run_id strings, FIFO
        self.lock = asyncio.Lock()
        # The event loop only keeps weak references to tasks, so fire-and-forget
        # tasks must be referenced here or they may be garbage-collected before
        # they finish.
        self._background_tasks = set()

    def _spawn_background(self, coro):
        """Schedule ``coro`` as a task and keep it referenced until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    def _on_background_task_done(self, task):
        """Forget a finished background task and log any exception it raised."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("后台任务异常退出: %r", exc, exc_info=exc)

    async def add_task(
        self,
        workflow_name: str,
        workflow_data: dict,
        api_spec: dict,
        request_data: dict,
    ):
        """Persist a new workflow run, enqueue it and trigger queue processing.

        Args:
            workflow_name: Name stored on the run record.
            workflow_data: Workflow definition; its ``"nodes"`` list (if any)
                is used to create one node record per entry.
            api_spec: API specification, stored as JSON on the run record.
            request_data: Request payload, stored as JSON on the run record.

        Returns:
            The newly generated workflow run id (a UUID4 string).
        """
        workflow_run_id = str(uuid.uuid4())
        async with self.lock:
            # Create the run record.
            await create_workflow_run(
                workflow_run_id=workflow_run_id,
                workflow_name=workflow_name,
                workflow_json=json.dumps(workflow_data),
                api_spec=json.dumps(api_spec),
                request_data=json.dumps(request_data),
            )
            # Create one node record per workflow node.
            nodes_data = [
                {"id": str(node["id"]), "type": node.get("type", "unknown")}
                for node in workflow_data.get("nodes", [])
            ]
            await create_workflow_run_nodes(workflow_run_id, nodes_data)
            # Append to the pending queue.
            self.pending_tasks.append(workflow_run_id)
            logger.info(
                f"任务 {workflow_run_id} 已添加到队列,当前队列长度: {len(self.pending_tasks)}"
            )
        # Try to dispatch without making the caller wait for it.
        self._spawn_background(self._process_queue())
        return workflow_run_id

    async def _process_queue(self):
        """Dispatch at most one pending run onto an available server.

        Does nothing when the queue is empty or no server is available. If
        allocating the server or marking the run as running fails, the run is
        put back at the head of the queue (and the server released if it had
        been allocated) before the error is re-raised, so the run is not lost.
        """
        async with self.lock:
            if not self.pending_tasks:
                return
            # Is there a free server?
            available_servers = await self._get_available_servers()
            if not available_servers:
                logger.info("没有可用的服务器,等待中...")
                return
            # Take the oldest pending run.
            workflow_run_id = self.pending_tasks.pop(0)
            server = available_servers[0]
            allocated = False
            try:
                # Reserve the server, then mark the run as running.
                await server_manager.allocate_server(server.name)
                allocated = True
                await update_workflow_run_status(
                    workflow_run_id, "running", server.http_url
                )
            except Exception:
                # Undo the dequeue so a later trigger can retry this run.
                self.pending_tasks.insert(0, workflow_run_id)
                if allocated:
                    await server_manager.release_server(server.name)
                raise
            self.running_tasks[server.http_url] = {
                "workflow_run_id": workflow_run_id,
                "started_at": datetime.now(),
                "server_name": server.name,
            }
            # Start executing in the background.
            self._spawn_background(self._execute_task(workflow_run_id, server))

    async def _get_available_servers(self) -> list[ComfyUIServerInfo]:
        """Return a list with the server manager's available server, or ``[]``."""
        available_server = await server_manager.get_available_server()
        if available_server:
            return [available_server]
        return []

    async def _execute_task(self, workflow_run_id: str, server: ComfyUIServerInfo):
        """Run one workflow on ``server`` and record the outcome.

        Loads the run record, executes it via ``_execute_prompt_on_server`` and
        stores either the JSON result (status ``"completed"``) or the error
        text (status ``"failed"``). Afterwards the server's ``running_tasks``
        entry is removed, the server is released, and queue processing is
        triggered again -- the re-trigger happens even if the release fails.
        """
        try:
            # Load the stored workflow definition and request.
            workflow_run = await get_workflow_run(workflow_run_id)
            if not workflow_run:
                raise Exception(f"找不到工作流运行记录: {workflow_run_id}")
            workflow_data = json.loads(workflow_run["workflow_json"])
            api_spec = json.loads(workflow_run["api_spec"])
            request_data = json.loads(workflow_run["request_data"])
            # Execute the workflow.
            result = await _execute_prompt_on_server(
                workflow_data, api_spec, request_data, server, workflow_run_id
            )
            # Persist the result.
            await update_workflow_run_status(
                workflow_run_id,
                "completed",
                result=json.dumps(result, ensure_ascii=False),
            )
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"执行任务 {workflow_run_id} 时出错: {e}")
            await update_workflow_run_status(
                workflow_run_id, "failed", error_message=str(e)
            )
        finally:
            try:
                # Clear the in-memory running state.
                async with self.lock:
                    self.running_tasks.pop(server.http_url, None)
                # Hand the server back to the manager.
                await server_manager.release_server(server.name)
            finally:
                # Keep draining the queue whatever happened above.
                self._spawn_background(self._process_queue())

    async def get_task_status(self, workflow_run_id: str) -> dict:
        """Return a status summary for a run.

        Returns ``{"error": "任务不存在"}`` when the run is unknown. Otherwise
        returns the run's id, status, timestamps, server URL, error message
        and node records; for a completed run with a stored result, the
        decoded result is added under ``"result"`` (``None`` if it cannot be
        decoded).
        """
        workflow_run = await get_workflow_run(workflow_run_id)
        if not workflow_run:
            return {"error": "任务不存在"}
        nodes = await get_workflow_run_nodes(workflow_run_id)
        result = {
            "id": workflow_run_id,
            "status": workflow_run["status"],
            "created_at": workflow_run["created_at"],
            "started_at": workflow_run["started_at"],
            "completed_at": workflow_run["completed_at"],
            "server_url": workflow_run["server_url"],
            "error_message": workflow_run["error_message"],
            "nodes": nodes,
        }
        # Completed runs carry their result in the database record.
        if workflow_run["status"] == "completed" and workflow_run.get("result"):
            try:
                result["result"] = json.loads(workflow_run["result"])
            except (json.JSONDecodeError, TypeError):
                result["result"] = None
        return result

    async def get_server_status(
        self, server: ComfyUIServerInfo, session: aiohttp.ClientSession
    ) -> dict[str, Any]:
        """Query one ComfyUI server's ``/queue`` endpoint.

        Returns a dict with ``is_reachable``, ``is_free`` and
        ``queue_details`` (``running_count`` / ``pending_count``). Any failure
        is logged and reported as unreachable, not free, with zero counts.
        """
        # The fallback has the same shape as the success case; the original
        # author's note says this is needed to satisfy a Pydantic model.
        status_info = {
            "is_reachable": False,
            "is_free": False,
            "queue_details": {"running_count": 0, "pending_count": 0},
        }
        try:
            queue_url = f"{server.http_url}/queue"
            # aiohttp's documented timeout type is ClientTimeout; a bare number
            # is only accepted for backward compatibility.
            async with session.get(
                queue_url, timeout=ClientTimeout(total=60)
            ) as response:
                response.raise_for_status()
                queue_data = await response.json()
                status_info["is_reachable"] = True
                running_count = len(queue_data.get("queue_running", []))
                pending_count = len(queue_data.get("queue_pending", []))
                status_info["queue_details"] = {
                    "running_count": running_count,
                    "pending_count": pending_count,
                }
                status_info["is_free"] = running_count == 0 and pending_count == 0
        except Exception as e:
            # On failure the correctly shaped defaults above are returned.
            logger.warning(f"无法检查服务器 {server.http_url} 的队列状态: {e}")
        return status_info
# Module-level queue manager instance, created when this module is imported.
queue_manager = WorkflowQueueManager()
async def _execute_prompt_on_server(
    workflow_data: dict,
    api_spec: dict,
    request_data: dict,
    server: ComfyUIServerInfo,
    workflow_run_id: str,
) -> dict:
    """Build, submit and await one prompt on the given ComfyUI server.

    Steps: build the prompt with ``build_prompt``, record the run as
    ``"running"`` with a fresh client id, queue the prompt over HTTP, record
    the returned prompt id, then collect outputs over the server's WebSocket
    while node-level status is tracked.

    Returns:
        The outputs dict produced by ``_get_execution_results``.
    """
    client_id = str(uuid.uuid4())

    # Build the prompt.
    prompt = await build_prompt(workflow_data, api_spec, request_data, server)

    # First status write: the prompt id is not known until the prompt has been
    # queued, so None is stored in its place for now.
    http_url = server.http_url
    await update_workflow_run_status(
        workflow_run_id, "running", http_url, None, client_id
    )

    # Submit to ComfyUI, then store the prompt id it returned.
    prompt_id = await _queue_prompt(workflow_data, prompt, client_id, http_url)
    await update_workflow_run_status(
        workflow_run_id, "running", http_url, prompt_id, client_id
    )
    logger.info(
        f"工作流 {workflow_run_id} 已在 {server.http_url} 上入队Prompt ID: {prompt_id}"
    )

    # Wait for the execution results (with per-node status updates).
    return await _get_execution_results(
        workflow_data, prompt_id, client_id, server.ws_url, workflow_run_id
    )
async def _queue_prompt(
    workflow: dict, prompt: dict, client_id: str, http_url: str
) -> str:
    """POST a prompt to a ComfyUI server's ``/prompt`` endpoint.

    Args:
        workflow: Workflow definition, sent along as
            ``extra_data.extra_pnginfo.workflow``.
        prompt: Mapping of node id to node definition. It is modified in
            place: every node's ``"inputs"`` gains a uniquely named
            ``cache_buster_*`` entry.
        client_id: Client id sent with the request.
        http_url: Base HTTP URL of the target server.

    Returns:
        The ``prompt_id`` from the server's JSON response.

    Raises:
        Exception: If the response JSON has no ``prompt_id``.
        Any error raised by aiohttp (including HTTP error statuses) is logged
        and re-raised unchanged.
    """
    # Add a random, uniquely named input to every node — presumably so ComfyUI
    # does not reuse cached node outputs from an earlier run (confirm).
    for node in prompt.values():
        node["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random()
    payload = {
        "prompt": prompt,
        "client_id": client_id,
        "extra_data": {
            "api_key_comfy_org": "",
            "extra_pnginfo": {"workflow": workflow},
        },
    }
    logger.info(f"提交到 ComfyUI /prompt 端点的payload: {json.dumps(payload)}")
    async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
        prompt_url = f"{http_url}/prompt"
        try:
            async with session.post(prompt_url, json=payload) as response:
                logger.info(f"ComfyUI /prompt 端点返回的响应: {response}")
                if response.status >= 400:
                    # raise_for_status() only reports status and reason; log
                    # the body first so the server's explanation is not lost.
                    error_body = await response.text()
                    logger.error(
                        f"ComfyUI /prompt 端点返回错误 (HTTP {response.status}): {error_body}"
                    )
                response.raise_for_status()
                result = await response.json()
                logger.info(f"ComfyUI /prompt 端点返回的响应: {result}")
                if "prompt_id" not in result:
                    raise Exception(f"从 ComfyUI /prompt 端点返回的响应无效: {result}")
                return result["prompt_id"]
        except Exception as e:
            logger.error(f"提交到 ComfyUI /prompt 端点时发生错误: {e}")
            # Bare raise keeps the original traceback intact.
            raise
async def _get_execution_results(
    workflow_data: Dict,
    prompt_id: str,
    client_id: str,
    ws_url: str,
    workflow_run_id: str,
) -> dict:
    """Follow a prompt over the server's WebSocket and collect its outputs.

    Only text messages whose ``data.prompt_id`` equals ``prompt_id`` are
    handled:

      * ``execution_error`` -- marks the node failed (when a node id is given)
        and raises ``ComfyUIExecutionError``.
      * ``executing`` with a node -- marks that node ``"running"``.
      * ``executed`` with node and output -- marks the node ``"completed"``;
        if the workflow node's title starts with ``API_OUTPUT_PREFIX`` the
        output is stored under the title with that prefix removed.
      * ``executing`` with ``node`` of ``None`` -- the prompt is finished and
        the collected outputs are returned.

    If the connection closes first, whatever was collected so far is returned.

    Raises:
        ComfyUIExecutionError: ComfyUI reported an execution error.
        websockets.exceptions.InvalidURI: ``ws_url`` is not a valid URI.
        Any other unexpected error is logged and re-raised.

    NOTE(review): ``recv()`` has no timeout, and the caller opens this socket
    only after the prompt was queued, so a prompt that finishes before the
    connection is established could leave this loop waiting — confirm.
    """
    full_ws_url = f"{ws_url}?clientId={client_id}"
    aggregated_outputs = {}
    try:
        async with websockets.connect(full_ws_url) as websocket:
            while True:
                try:
                    out = await websocket.recv()
                    # Non-text (binary) frames are ignored.
                    if not isinstance(out, str):
                        continue
                    message = json.loads(out)
                    msg_type = message.get("type")
                    data = message.get("data")
                    # Skip messages that belong to other prompts.
                    if not (data and data.get("prompt_id") == prompt_id):
                        continue

                    # Execution error: record it on the node, then abort.
                    if msg_type == "execution_error":
                        error_data = data
                        logger.error(
                            f"ComfyUI执行错误 (Prompt ID: {prompt_id}): {error_data}"
                        )
                        node_id = error_data.get("node_id")
                        if node_id:
                            await update_workflow_run_node_status(
                                workflow_run_id,
                                node_id,
                                "failed",
                                error_message=error_data.get(
                                    "exception_message", "Unknown error"
                                ),
                            )
                        # Hand the error details to the caller.
                        raise ComfyUIExecutionError(error_data)

                    # A node started executing.
                    if msg_type == "executing" and data.get("node"):
                        node_id = data.get("node")
                        logger.info(f"节点 {node_id} 开始执行 (Prompt ID: {prompt_id})")
                        await update_workflow_run_node_status(
                            workflow_run_id, node_id, "running"
                        )

                    # A node finished executing.
                    # NOTE(review): the nesting below was reconstructed from a
                    # paste that lost its indentation — confirm that the status
                    # update applies to every node with output, not only to
                    # OUTPUT_-titled ones.
                    if msg_type == "executed":
                        node_id = data.get("node")
                        output_data = data.get("output")
                        if node_id and output_data:
                            node = next(
                                (
                                    x
                                    for x in workflow_data["nodes"]
                                    if str(x["id"]) == node_id
                                ),
                                None,
                            )
                            if (
                                node
                                and node.get("title", "")
                                and node["title"].startswith(API_OUTPUT_PREFIX)
                            ):
                                # Strip only the leading prefix; str.replace
                                # would also remove later occurrences and
                                # corrupt names such as "OUTPUT_a_OUTPUT_b".
                                title = node["title"].removeprefix(API_OUTPUT_PREFIX)
                                aggregated_outputs[title] = output_data
                                logger.info(
                                    f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})"
                                )
                            # Mark the node as completed.
                            await update_workflow_run_node_status(
                                workflow_run_id,
                                node_id,
                                "completed",
                                output_data=json.dumps(output_data),
                            )
                    # The whole prompt finished.
                    elif msg_type == "executing" and data.get("node") is None:
                        logger.info(f"Prompt ID: {prompt_id} 执行完成。")
                        return aggregated_outputs
                except websockets.exceptions.ConnectionClosed as e:
                    # Best effort: return whatever was collected so far.
                    logger.warning(
                        f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}"
                    )
                    return aggregated_outputs
                except Exception as e:
                    # Our own error was already logged above; log anything else.
                    if not isinstance(e, ComfyUIExecutionError):
                        logger.error(f"处理 prompt {prompt_id} 时发生意外错误: {e}")
                    raise
    except websockets.exceptions.InvalidURI as e:
        logger.error(
            f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
        )
        raise