import asyncio
import json
import random
import uuid
from typing import Dict, Any

import aiohttp
import websockets
from aiohttp import ClientTimeout

from workflow_service.config import Settings, ComfyUIServer

settings = Settings()


# [Added] Custom exception that wraps execution errors reported by ComfyUI.
class ComfyUIExecutionError(Exception):
    def __init__(self, error_data: dict):
        self.error_data = error_data
        # Build a developer-friendly exception message.
        message = (
            f"ComfyUI node execution failed. Node ID: {error_data.get('node_id')}, "
            f"node type: {error_data.get('node_type')}. "
            f"Error: {error_data.get('exception_message', 'N/A')}"
        )
        super().__init__(message)


async def get_server_status(
    server: ComfyUIServer, session: aiohttp.ClientSession
) -> Dict[str, Any]:
    """
    Check the detailed status of a single ComfyUI server.

    Returns a dict containing reachability, queue status, and the detailed
    queue contents.
    """
    # [Bug fix] Keep the initial dict structure identical to the success case
    # so that it always satisfies the Pydantic model.
    status_info = {
        "is_reachable": False,
        "is_free": False,
        "queue_details": {"running_count": 0, "pending_count": 0},
    }
    try:
        queue_url = f"{server.http_url}/queue"
        async with session.get(queue_url, timeout=60) as response:
            response.raise_for_status()
            queue_data = await response.json()

            status_info["is_reachable"] = True

            running_count = len(queue_data.get("queue_running", []))
            pending_count = len(queue_data.get("queue_pending", []))

            status_info["queue_details"] = {
                "running_count": running_count,
                "pending_count": pending_count,
            }

            status_info["is_free"] = running_count == 0 and pending_count == 0

    except Exception as e:
        # On failure, the structurally correct initial status_info defined
        # above is returned unchanged.
        print(f"Warning: failed to check queue status of server {server.http_url}: {e}")

    return status_info


async def select_server_for_execution() -> ComfyUIServer:
    """
    Intelligently select a ComfyUI server.

    Prefer an idle server; if every server is busy, pick a reachable one at
    random.
    """
    servers = settings.SERVERS
    if not servers:
        raise ValueError("No servers are configured in COMFYUI_SERVERS_JSON.")
    if len(servers) == 1:
        return servers[0]

    async with aiohttp.ClientSession() as session:
        tasks = [get_server_status(server, session) for server in servers]
        results = await asyncio.gather(*tasks)

        free_servers = [servers[i] for i, status in enumerate(results) if status["is_free"]]

        if free_servers:
            selected_server = random.choice(free_servers)
            print(
                f"Found {len(free_servers)} idle server(s). Selected: {selected_server.http_url}"
            )
            return selected_server
        else:
            # Fallback: pick a reachable server even if it is busy.
            reachable_servers = [
                servers[i] for i, status in enumerate(results) if status["is_reachable"]
            ]
            if reachable_servers:
                selected_server = random.choice(reachable_servers)
                print(
                    f"All servers are currently busy. Randomly selected a reachable one: {selected_server.http_url}"
                )
                return selected_server
            else:
                # Worst case: no server is reachable, so raise.
                raise ConnectionError("None of the configured ComfyUI servers are reachable.")


async def execute_prompt_on_server(prompt: Dict, server: ComfyUIServer) -> Dict:
    """
    Execute a prepared prompt on the specified server.
    """
    client_id = str(uuid.uuid4())
    prompt_id = await _queue_prompt(prompt, client_id, server.http_url)
    print(f"Workflow queued on {server.http_url}, prompt ID: {prompt_id}")
    results = await _get_execution_results(prompt_id, client_id, server.ws_url)
    return results


async def _queue_prompt(prompt: dict, client_id: str, http_url: str) -> str:
    """Submit the workflow to the specified ComfyUI server via HTTP POST."""
    # Inject a random, uniquely named input into every node so ComfyUI's output
    # caching never skips re-execution (note: this mutates `prompt` in place).
    for node_id in prompt:
        prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random()
    payload = {"prompt": prompt, "client_id": client_id}
    async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
        prompt_url = f"{http_url}/prompt"
        async with session.post(prompt_url, json=payload) as response:
            response.raise_for_status()
            result = await response.json()
            if "prompt_id" not in result:
                raise Exception(f"Invalid response from the ComfyUI /prompt endpoint: {result}")
            return result["prompt_id"]


async def _get_execution_results(prompt_id: str, client_id: str, ws_url: str) -> dict:
    """
    Connect to the specified ComfyUI server over WebSocket and aggregate the
    execution results.

    [Key change] Added handling of 'execution_error' messages.
    """
    full_ws_url = f"{ws_url}?clientId={client_id}"
    aggregated_outputs = {}

    try:
        async with websockets.connect(full_ws_url) as websocket:
            while True:
                try:
                    out = await websocket.recv()
                    if not isinstance(out, str):
                        continue

                    message = json.loads(out)
                    msg_type = message.get("type")
                    data = message.get("data")

                    if not (data and data.get("prompt_id") == prompt_id):
                        continue

                    # [Key change] Catch and handle execution errors.
                    if msg_type == "execution_error":
                        print(f"ComfyUI execution error (prompt ID: {prompt_id}): {data}")
                        # Raise the custom exception so the error details propagate.
                        raise ComfyUIExecutionError(data)

                    if msg_type == "executed":
                        node_id = data.get("node")
                        output_data = data.get("output")
                        if node_id and output_data:
                            aggregated_outputs[node_id] = output_data
                            print(f"Received output from node {node_id} (prompt ID: {prompt_id})")

                    elif msg_type == "executing" and data.get("node") is None:
                        print(f"Prompt ID {prompt_id} finished executing.")
                        return aggregated_outputs

                except websockets.exceptions.ConnectionClosed as e:
                    print(f"WebSocket connection closed (prompt ID: {prompt_id}). Error: {e}")
                    return aggregated_outputs
                except Exception as e:
                    # Re-raise our own exception; report anything else as unexpected.
                    if not isinstance(e, ComfyUIExecutionError):
                        print(f"Unexpected error while handling prompt {prompt_id}: {e}")
                    raise e

    except websockets.exceptions.InvalidURI as e:
        print(
            f"Error: invalid WebSocket URI: '{full_ws_url}'. Original URL: '{ws_url}'. Error: {e}"
        )
        raise e

    return aggregated_outputs
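

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the service code). It assumes the
# environment is configured for Settings (e.g. COMFYUI_SERVERS_JSON) and that
# "workflow_api.json" is an API-format workflow exported from ComfyUI; the
# file name and the _example_run helper are placeholders for illustration.
# ---------------------------------------------------------------------------
async def _example_run(workflow_path: str = "workflow_api.json") -> None:
    # Load an API-format workflow: a mapping of node_id -> node definition.
    with open(workflow_path, "r", encoding="utf-8") as f:
        prompt = json.load(f)

    # Pick an idle (or at least reachable) server, then execute the prompt on it.
    server = await select_server_for_execution()
    try:
        outputs = await execute_prompt_on_server(prompt, server)
        print(f"Aggregated outputs from {len(outputs)} node(s): {list(outputs)}")
    except ComfyUIExecutionError as exc:
        # error_data carries the raw ComfyUI error payload (node_id, node_type,
        # exception_message, ...), useful for surfacing details to callers.
        print(f"Execution failed: {exc}; details: {exc.error_data}")


if __name__ == "__main__":
    asyncio.run(_example_run())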