"""Workflow execution queue: dispatches ComfyUI prompts onto managed servers."""
import asyncio
import json
import logging
import os
import random
import uuid
import websockets
from datetime import datetime
from typing import Dict, Any, Optional

import aiohttp
from aiohttp import ClientTimeout

from workflow_service.comfy.comfy_workflow import build_prompt
from workflow_service.config import Settings
from workflow_service.comfy.comfy_server import server_manager, ComfyUIServerInfo
from workflow_service.database.api import (
    create_workflow_run,
    update_workflow_run_status,
    create_workflow_run_nodes,
    update_workflow_run_node_status,
    get_workflow_run,
    get_workflow_run_nodes,
)
|
||
settings = Settings()
|
||
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
API_INPUT_PREFIX = "INPUT_"
|
||
API_OUTPUT_PREFIX = "OUTPUT_"
|
||
|
||
|
||
class ComfyUIExecutionError(Exception):
|
||
def __init__(self, error_data: dict):
|
||
self.error_data = error_data
|
||
# 创建一个对开发者友好的异常消息
|
||
message = (
|
||
f"ComfyUI节点执行失败。节点ID: {error_data.get('node_id')}, "
|
||
f"节点类型: {error_data.get('node_type')}. "
|
||
f"错误: {error_data.get('exception_message', 'N/A')}"
|
||
)
|
||
super().__init__(message)
|
||
|
||
|
||
# 全局任务队列管理器
|
||
class WorkflowQueueManager:
|
||
def __init__(self, monitor_interval: int = 5):
|
||
self.running_tasks = {} # server_url -> task_info
|
||
self.pending_tasks = [] # 等待队列
|
||
self.lock = asyncio.Lock()
|
||
self._queue_monitor_task: Optional[asyncio.Task] = None
|
||
self._monitor_interval = monitor_interval # 监控间隔(秒)
|
||
self._last_processing_time = None # 上次处理队列的时间
|
||
|
||
async def start_queue_monitor(self):
|
||
"""启动队列监控任务"""
|
||
if self._queue_monitor_task is None or self._queue_monitor_task.done():
|
||
self._queue_monitor_task = asyncio.create_task(self._monitor_queue())
|
||
logger.info(f"队列监控任务已启动,监控间隔: {self._monitor_interval}秒")
|
||
|
||
async def stop_queue_monitor(self):
|
||
"""停止队列监控任务"""
|
||
if self._queue_monitor_task and not self._queue_monitor_task.done():
|
||
self._queue_monitor_task.cancel()
|
||
try:
|
||
await self._queue_monitor_task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
logger.info("队列监控任务已停止")
|
||
|
||
async def set_monitor_interval(self, interval: int):
|
||
"""设置监控间隔"""
|
||
if interval < 1:
|
||
raise ValueError("监控间隔不能小于1秒")
|
||
self._monitor_interval = interval
|
||
logger.info(f"队列监控间隔已设置为 {interval} 秒")
|
||
|
||
# 如果监控任务正在运行,重启它以应用新间隔
|
||
if self._queue_monitor_task and not self._queue_monitor_task.done():
|
||
await self.stop_queue_monitor()
|
||
await self.start_queue_monitor()
|
||
|
||
async def trigger_queue_processing(self):
|
||
"""手动触发队列处理(用于测试或紧急情况)"""
|
||
logger.info("手动触发队列处理")
|
||
asyncio.create_task(self._process_queue())
|
||
|
||
async def get_queue_status(self) -> dict:
|
||
"""获取队列状态信息"""
|
||
async with self.lock:
|
||
return {
|
||
"pending_tasks_count": len(self.pending_tasks),
|
||
"running_tasks_count": len(self.running_tasks),
|
||
"monitor_active": self._queue_monitor_task is not None
|
||
and not self._queue_monitor_task.done(),
|
||
"monitor_interval": self._monitor_interval,
|
||
}
|
||
|
||
async def _monitor_queue(self):
|
||
"""定期监控队列,当有可用服务器时自动处理"""
|
||
while True:
|
||
try:
|
||
await asyncio.sleep(self._monitor_interval)
|
||
|
||
# 每次定时检测触发时,打印当前状态
|
||
async with self.lock:
|
||
pending_count = len(self.pending_tasks)
|
||
running_count = len(self.running_tasks)
|
||
current_time = datetime.now()
|
||
timestamp = current_time.strftime("%Y-%m-%d %H:%M:%S")
|
||
time_only = current_time.strftime("%H:%M:%S")
|
||
|
||
# 计算距离上次处理队列的时间
|
||
time_since_last_processing = ""
|
||
if self._last_processing_time:
|
||
time_diff = current_time - self._last_processing_time
|
||
minutes = int(time_diff.total_seconds() // 60)
|
||
seconds = int(time_diff.total_seconds() % 60)
|
||
time_since_last_processing = (
|
||
f"上次处理: {minutes}分{seconds}秒前"
|
||
)
|
||
else:
|
||
time_since_last_processing = "上次处理: 从未"
|
||
|
||
# 格式化状态信息,提高可读性
|
||
status_info = (
|
||
f"⏰ 定时检测触发 [{timestamp}]\n"
|
||
f" 📋 待处理任务: {pending_count} 个\n"
|
||
f" 🚀 运行中任务: {running_count} 个\n"
|
||
f" ⏱️ 监控间隔: {self._monitor_interval} 秒\n"
|
||
f" 🕐 {time_since_last_processing}"
|
||
)
|
||
logger.info(status_info)
|
||
|
||
if self.pending_tasks:
|
||
available_servers = await self._get_available_servers()
|
||
if available_servers:
|
||
logger.info(
|
||
f"监控发现 {len(available_servers)} 个可用服务器,开始处理队列"
|
||
)
|
||
# 使用create_task避免阻塞监控循环
|
||
asyncio.create_task(self._process_queue())
|
||
else:
|
||
# 只在有任务等待时记录日志,避免日志噪音
|
||
if len(self.pending_tasks) > 0:
|
||
logger.debug(
|
||
f"队列中有 {len(self.pending_tasks)} 个待处理任务,但无可用服务器"
|
||
)
|
||
else:
|
||
# 队列为空时,减少日志输出
|
||
logger.debug("队列为空,继续监控中...")
|
||
|
||
except asyncio.CancelledError:
|
||
logger.info("队列监控任务被取消")
|
||
break
|
||
except Exception as e:
|
||
logger.error(f"队列监控任务出错: {e}")
|
||
# 出错时稍微延长等待时间,避免频繁重试
|
||
await asyncio.sleep(self._monitor_interval * 2)
|
||
|
||
async def add_task(
|
||
self,
|
||
workflow_name: str,
|
||
workflow_data: dict,
|
||
api_spec: dict,
|
||
request_data: dict,
|
||
):
|
||
"""添加新任务到队列"""
|
||
workflow_run_id = str(uuid.uuid4())
|
||
async with self.lock:
|
||
# 创建任务记录
|
||
await create_workflow_run(
|
||
workflow_run_id=workflow_run_id,
|
||
workflow_name=workflow_name,
|
||
workflow_json=json.dumps(workflow_data),
|
||
api_spec=json.dumps(api_spec),
|
||
request_data=json.dumps(request_data),
|
||
)
|
||
|
||
# 创建工作流节点记录
|
||
nodes_data = []
|
||
for node in workflow_data.get("nodes", []):
|
||
nodes_data.append(
|
||
{"id": str(node["id"]), "type": node.get("type", "unknown")}
|
||
)
|
||
await create_workflow_run_nodes(workflow_run_id, nodes_data)
|
||
|
||
# 添加到待处理队列
|
||
self.pending_tasks.append(workflow_run_id)
|
||
logger.info(
|
||
f"任务 {workflow_run_id} 已添加到队列,当前队列长度: {len(self.pending_tasks)}"
|
||
)
|
||
|
||
# 注意:不再手动触发队列处理,由定时监控自动处理
|
||
# 这样可以避免在没有可用服务器时的不必要尝试
|
||
|
||
return workflow_run_id
|
||
|
||
async def _process_queue(self):
|
||
"""处理队列中的任务"""
|
||
async with self.lock:
|
||
# 更新上次处理队列的时间
|
||
self._last_processing_time = datetime.now()
|
||
|
||
if not self.pending_tasks:
|
||
logger.debug("队列为空,无需处理")
|
||
return
|
||
|
||
# 检查是否有空闲的服务器
|
||
available_servers = await self._get_available_servers()
|
||
if not available_servers:
|
||
logger.info(
|
||
f"队列中有 {len(self.pending_tasks)} 个待处理任务,但没有可用的服务器"
|
||
)
|
||
return
|
||
|
||
# 获取一个待处理任务
|
||
workflow_run_id = self.pending_tasks.pop(0)
|
||
server = available_servers[0]
|
||
|
||
logger.info(
|
||
f"开始处理任务 {workflow_run_id},使用服务器 {server.name} ({server.http_url})"
|
||
)
|
||
|
||
# 分配服务器资源
|
||
await server_manager.allocate_server(server.name)
|
||
|
||
# 标记任务为运行中
|
||
await update_workflow_run_status(
|
||
workflow_run_id, "running", server.http_url
|
||
)
|
||
self.running_tasks[server.http_url] = {
|
||
"workflow_run_id": workflow_run_id,
|
||
"started_at": datetime.now(),
|
||
"server_name": server.name,
|
||
}
|
||
|
||
# 启动任务执行
|
||
asyncio.create_task(self._execute_task(workflow_run_id, server))
|
||
|
||
logger.info(
|
||
f"任务 {workflow_run_id} 已分配到服务器 {server.name},当前运行中任务数: {len(self.running_tasks)}"
|
||
)
|
||
|
||
async def _get_available_servers(self) -> list[ComfyUIServerInfo]:
|
||
"""获取可用的服务器"""
|
||
# 使用新的服务器管理器获取可用服务器
|
||
available_server = await server_manager.get_available_server()
|
||
if available_server:
|
||
return [available_server]
|
||
return []
|
||
|
||
async def _execute_task(self, workflow_run_id: str, server: ComfyUIServerInfo):
|
||
"""执行任务"""
|
||
cleanup_paths = []
|
||
try:
|
||
# 获取工作流数据
|
||
workflow_run = await get_workflow_run(workflow_run_id)
|
||
if not workflow_run:
|
||
raise Exception(f"找不到工作流运行记录: {workflow_run_id}")
|
||
|
||
workflow_data = json.loads(workflow_run["workflow_json"])
|
||
api_spec = json.loads(workflow_run["api_spec"])
|
||
request_data = json.loads(workflow_run["request_data"])
|
||
|
||
# 执行工作流
|
||
result = await _execute_prompt_on_server(
|
||
workflow_data, api_spec, request_data, server, workflow_run_id
|
||
)
|
||
|
||
# 保存处理后的结果到数据库
|
||
await update_workflow_run_status(
|
||
workflow_run_id,
|
||
"completed",
|
||
result=json.dumps(result, ensure_ascii=False),
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"执行任务 {workflow_run_id} 时出错: {e}")
|
||
await update_workflow_run_status(
|
||
workflow_run_id, "failed", error_message=str(e)
|
||
)
|
||
finally:
|
||
# 清理临时文件
|
||
if cleanup_paths:
|
||
logger.info(f"正在清理 {len(cleanup_paths)} 个临时文件...")
|
||
for path in cleanup_paths:
|
||
try:
|
||
if os.path.exists(path):
|
||
os.remove(path)
|
||
logger.info(f" - 已删除: {path}")
|
||
except Exception as e:
|
||
logger.warning(f" - 删除 {path} 时出错: {e}")
|
||
|
||
# 清理运行状态
|
||
async with self.lock:
|
||
if server.http_url in self.running_tasks:
|
||
del self.running_tasks[server.http_url]
|
||
|
||
# 释放服务器资源
|
||
await server_manager.release_server(server.name)
|
||
|
||
async def get_task_status(self, workflow_run_id: str) -> dict:
|
||
"""获取任务状态"""
|
||
workflow_run = await get_workflow_run(workflow_run_id)
|
||
if not workflow_run:
|
||
return {"error": "任务不存在"}
|
||
nodes = await get_workflow_run_nodes(workflow_run_id)
|
||
|
||
result = {
|
||
"id": workflow_run_id,
|
||
"status": workflow_run["status"],
|
||
"created_at": workflow_run["created_at"],
|
||
"started_at": workflow_run["started_at"],
|
||
"completed_at": workflow_run["completed_at"],
|
||
"server_url": workflow_run["server_url"],
|
||
"error_message": workflow_run["error_message"],
|
||
"nodes": nodes,
|
||
}
|
||
|
||
# 如果任务完成,从数据库获取结果
|
||
if workflow_run["status"] == "completed" and workflow_run.get("result"):
|
||
try:
|
||
result["result"] = json.loads(workflow_run["result"])
|
||
except (json.JSONDecodeError, TypeError):
|
||
result["result"] = None
|
||
|
||
return result
|
||
|
||
async def get_server_status(
|
||
self, server: ComfyUIServerInfo, session: aiohttp.ClientSession
|
||
) -> dict[str, Any]:
|
||
"""
|
||
检查单个ComfyUI服务器的详细状态。
|
||
返回一个包含可达性、队列状态和详细队列内容的字典。
|
||
"""
|
||
# [BUG修复] 确保初始字典结构与成功时的结构一致,以满足Pydantic模型
|
||
status_info = {
|
||
"is_reachable": False,
|
||
"is_free": False,
|
||
"queue_details": {"running_count": 0, "pending_count": 0},
|
||
}
|
||
try:
|
||
queue_url = f"{server.http_url}/queue"
|
||
async with session.get(queue_url, timeout=60) as response:
|
||
response.raise_for_status()
|
||
queue_data = await response.json()
|
||
|
||
status_info["is_reachable"] = True
|
||
|
||
running_count = len(queue_data.get("queue_running", []))
|
||
pending_count = len(queue_data.get("queue_pending", []))
|
||
|
||
status_info["queue_details"] = {
|
||
"running_count": running_count,
|
||
"pending_count": pending_count,
|
||
}
|
||
|
||
status_info["is_free"] = running_count == 0 and pending_count == 0
|
||
|
||
except Exception as e:
|
||
# 当请求失败时,将返回上面定义的、结构正确的初始 status_info
|
||
logger.warning(f"无法检查服务器 {server.http_url} 的队列状态: {e}")
|
||
|
||
return status_info
|
||
|
||
|
||
# 全局队列管理器实例
|
||
# 从配置中读取监控间隔,默认为5秒
|
||
queue_manager = WorkflowQueueManager(monitor_interval=5)
|
||
|
||
|
||
async def _execute_prompt_on_server(
|
||
workflow_data: dict,
|
||
api_spec: dict,
|
||
request_data: dict,
|
||
server: ComfyUIServerInfo,
|
||
workflow_run_id: str,
|
||
) -> dict:
|
||
"""
|
||
在指定的服务器上执行一个准备好的prompt。
|
||
现在支持节点级别的状态跟踪。
|
||
"""
|
||
client_id = str(uuid.uuid4())
|
||
|
||
# 构建prompt
|
||
prompt = await build_prompt(workflow_data, api_spec, request_data, server)
|
||
|
||
# 更新工作流运行状态,记录prompt_id和client_id
|
||
await update_workflow_run_status(
|
||
workflow_run_id,
|
||
"running",
|
||
server.http_url,
|
||
None, # prompt_id将在_queue_prompt后更新
|
||
client_id,
|
||
)
|
||
|
||
# 提交到ComfyUI
|
||
prompt_id = await _queue_prompt(workflow_data, prompt, client_id, server.http_url)
|
||
|
||
# 更新prompt_id
|
||
await update_workflow_run_status(
|
||
workflow_run_id, "running", server.http_url, prompt_id, client_id
|
||
)
|
||
|
||
logger.info(
|
||
f"工作流 {workflow_run_id} 已在 {server.http_url} 上入队,Prompt ID: {prompt_id}"
|
||
)
|
||
|
||
# 获取执行结果,现在支持节点级别的状态跟踪
|
||
results = await _get_execution_results(
|
||
workflow_data, prompt_id, client_id, server.ws_url, workflow_run_id
|
||
)
|
||
return results
|
||
|
||
|
||
async def _queue_prompt(
|
||
workflow: dict, prompt: dict, client_id: str, http_url: str
|
||
) -> str:
|
||
"""通过HTTP POST将工作流任务提交到指定的ComfyUI服务器。"""
|
||
for node_id in prompt:
|
||
prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random()
|
||
payload = {
|
||
"prompt": prompt,
|
||
"client_id": client_id,
|
||
"extra_data": {
|
||
"api_key_comfy_org": "",
|
||
"extra_pnginfo": {"workflow": workflow},
|
||
},
|
||
}
|
||
async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
|
||
prompt_url = f"{http_url}/prompt"
|
||
try:
|
||
async with session.post(prompt_url, json=payload) as response:
|
||
logger.info(f"ComfyUI /prompt 端点返回的响应: {response}")
|
||
response.raise_for_status()
|
||
result = await response.json()
|
||
logger.info(f"ComfyUI /prompt 端点返回的响应: {result}")
|
||
if "prompt_id" not in result:
|
||
raise Exception(f"从 ComfyUI /prompt 端点返回的响应无效: {result}")
|
||
return result["prompt_id"]
|
||
except Exception as e:
|
||
logger.error(f"提交到 ComfyUI /prompt 端点时发生错误: {e}")
|
||
raise e
|
||
|
||
|
||
async def _get_execution_results(
|
||
workflow_data: Dict,
|
||
prompt_id: str,
|
||
client_id: str,
|
||
ws_url: str,
|
||
workflow_run_id: str,
|
||
) -> dict:
|
||
"""
|
||
通过WebSocket连接到指定的ComfyUI服务器,聚合执行结果。
|
||
现在支持节点级别的状态跟踪。
|
||
"""
|
||
full_ws_url = f"{ws_url}?clientId={client_id}"
|
||
aggregated_outputs = {}
|
||
|
||
try:
|
||
async with websockets.connect(full_ws_url) as websocket:
|
||
while True:
|
||
try:
|
||
out = await websocket.recv()
|
||
if not isinstance(out, str):
|
||
continue
|
||
|
||
message = json.loads(out)
|
||
msg_type = message.get("type")
|
||
data = message.get("data")
|
||
|
||
if not (data and data.get("prompt_id") == prompt_id):
|
||
continue
|
||
|
||
# 捕获并处理执行错误
|
||
if msg_type == "execution_error":
|
||
error_data = data
|
||
logger.error(
|
||
f"ComfyUI执行错误 (Prompt ID: {prompt_id}): {error_data}"
|
||
)
|
||
|
||
# 更新节点状态为失败
|
||
node_id = error_data.get("node_id")
|
||
if node_id:
|
||
await update_workflow_run_node_status(
|
||
workflow_run_id,
|
||
node_id,
|
||
"failed",
|
||
error_message=error_data.get(
|
||
"exception_message", "Unknown error"
|
||
),
|
||
)
|
||
|
||
# 抛出自定义异常,将错误详情传递出去
|
||
raise ComfyUIExecutionError(error_data)
|
||
|
||
# 处理节点开始执行
|
||
if msg_type == "executing" and data.get("node"):
|
||
node_id = data.get("node")
|
||
logger.info(f"节点 {node_id} 开始执行 (Prompt ID: {prompt_id})")
|
||
|
||
# 更新节点状态为运行中
|
||
await update_workflow_run_node_status(
|
||
workflow_run_id, node_id, "running"
|
||
)
|
||
|
||
# 处理节点执行完成
|
||
if msg_type == "executed":
|
||
node_id = data.get("node")
|
||
output_data = data.get("output")
|
||
if node_id and output_data:
|
||
node = next(
|
||
(
|
||
x
|
||
for x in workflow_data["nodes"]
|
||
if str(x["id"]) == node_id
|
||
),
|
||
None,
|
||
)
|
||
if (
|
||
node
|
||
and node.get("title", "")
|
||
and node["title"].startswith(API_OUTPUT_PREFIX)
|
||
):
|
||
title = node["title"].replace(API_OUTPUT_PREFIX, "")
|
||
aggregated_outputs[title] = output_data
|
||
logger.info(
|
||
f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})"
|
||
)
|
||
|
||
# 更新节点状态为完成
|
||
await update_workflow_run_node_status(
|
||
workflow_run_id,
|
||
node_id,
|
||
"completed",
|
||
output_data=json.dumps(output_data),
|
||
)
|
||
|
||
# 处理整个工作流执行完成
|
||
elif msg_type == "executing" and data.get("node") is None:
|
||
logger.info(f"Prompt ID: {prompt_id} 执行完成。")
|
||
return aggregated_outputs
|
||
|
||
except websockets.exceptions.ConnectionClosed as e:
|
||
logger.warning(
|
||
f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}"
|
||
)
|
||
return aggregated_outputs
|
||
except Exception as e:
|
||
# 重新抛出我们自己的异常,或者处理其他意外错误
|
||
if not isinstance(e, ComfyUIExecutionError):
|
||
logger.error(f"处理 prompt {prompt_id} 时发生意外错误: {e}")
|
||
raise e
|
||
|
||
except websockets.exceptions.InvalidURI as e:
|
||
logger.error(
|
||
f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
|
||
)
|
||
raise e
|