ComfyUI-WorkflowPublisher/workflow_service/comfyui_client.py

188 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import random
import uuid
from typing import Dict, Any
import aiohttp
import websockets
from aiohttp import ClientTimeout
from workflow_service.config import Settings, ComfyUIServer
settings = Settings()
# [新增] 定义一个自定义异常用于封装来自ComfyUI的执行错误
class ComfyUIExecutionError(Exception):
def __init__(self, error_data: dict):
self.error_data = error_data
# 创建一个对开发者友好的异常消息
message = (
f"ComfyUI节点执行失败。节点ID: {error_data.get('node_id')}, "
f"节点类型: {error_data.get('node_type')}. "
f"错误: {error_data.get('exception_message', 'N/A')}"
)
super().__init__(message)
async def get_server_status(
server: ComfyUIServer, session: aiohttp.ClientSession
) -> Dict[str, Any]:
"""
检查单个ComfyUI服务器的详细状态。
返回一个包含可达性、队列状态和详细队列内容的字典。
"""
# [BUG修复] 确保初始字典结构与成功时的结构一致以满足Pydantic模型
status_info = {
"is_reachable": False,
"is_free": False,
"queue_details": {"running_count": 0, "pending_count": 0},
}
try:
queue_url = f"{server.http_url}/queue"
async with session.get(queue_url, timeout=60) as response:
response.raise_for_status()
queue_data = await response.json()
status_info["is_reachable"] = True
running_count = len(queue_data.get("queue_running", []))
pending_count = len(queue_data.get("queue_pending", []))
status_info["queue_details"] = {
"running_count": running_count,
"pending_count": pending_count,
}
status_info["is_free"] = running_count == 0 and pending_count == 0
except Exception as e:
# 当请求失败时,将返回上面定义的、结构正确的初始 status_info
print(f"警告: 无法检查服务器 {server.http_url} 的队列状态: {e}")
return status_info
async def select_server_for_execution() -> ComfyUIServer:
"""
智能选择一个ComfyUI服务器。
优先选择一个空闲的服务器,如果所有服务器都忙,则随机选择一个。
"""
servers = settings.SERVERS
if not servers:
raise ValueError("没有在 COMFYUI_SERVERS_JSON 中配置任何服务器。")
if len(servers) == 1:
return servers[0]
async with aiohttp.ClientSession() as session:
tasks = [get_server_status(server, session) for server in servers]
results = await asyncio.gather(*tasks)
free_servers = [servers[i] for i, status in enumerate(results) if status["is_free"]]
if free_servers:
selected_server = random.choice(free_servers)
print(
f"发现 {len(free_servers)} 个空闲服务器。已选择: {selected_server.http_url}"
)
return selected_server
else:
# 后备方案:选择一个可达的服务器,即使它很忙
reachable_servers = [
servers[i] for i, status in enumerate(results) if status["is_reachable"]
]
if reachable_servers:
selected_server = random.choice(reachable_servers)
print(
f"所有服务器当前都在忙。从可达服务器中随机选择: {selected_server.http_url}"
)
return selected_server
else:
# 最坏情况:所有服务器都不可达,抛出异常
raise ConnectionError("所有配置的ComfyUI服务器都不可达。")
async def execute_prompt_on_server(prompt: Dict, server: ComfyUIServer) -> Dict:
"""
在指定的服务器上执行一个准备好的prompt。
"""
client_id = str(uuid.uuid4())
prompt_id = await _queue_prompt(prompt, client_id, server.http_url)
print(f"工作流已在 {server.http_url} 上入队Prompt ID: {prompt_id}")
results = await _get_execution_results(prompt_id, client_id, server.ws_url)
return results
async def _queue_prompt(prompt: dict, client_id: str, http_url: str) -> str:
"""通过HTTP POST将工作流任务提交到指定的ComfyUI服务器。"""
for node_id in prompt:
prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random()
payload = {"prompt": prompt, "client_id": client_id}
async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
prompt_url = f"{http_url}/prompt"
async with session.post(prompt_url, json=payload) as response:
response.raise_for_status()
result = await response.json()
if "prompt_id" not in result:
raise Exception(f"从 ComfyUI /prompt 端点返回的响应无效: {result}")
return result["prompt_id"]
async def _get_execution_results(prompt_id: str, client_id: str, ws_url: str) -> dict:
"""
通过WebSocket连接到指定的ComfyUI服务器聚合执行结果。
[核心改动] 新增对 'execution_error' 消息的处理。
"""
full_ws_url = f"{ws_url}?clientId={client_id}"
aggregated_outputs = {}
try:
async with websockets.connect(full_ws_url) as websocket:
while True:
try:
out = await websocket.recv()
if not isinstance(out, str):
continue
message = json.loads(out)
msg_type = message.get("type")
data = message.get("data")
if not (data and data.get("prompt_id") == prompt_id):
continue
# [核心改动] 捕获并处理执行错误
if msg_type == "execution_error":
print(f"ComfyUI执行错误 (Prompt ID: {prompt_id}): {data}")
# 抛出自定义异常,将错误详情传递出去
raise ComfyUIExecutionError(data)
if msg_type == "executed":
node_id = data.get("node")
output_data = data.get("output")
if node_id and output_data:
aggregated_outputs[node_id] = output_data
print(f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})")
elif msg_type == "executing" and data.get("node") is None:
print(f"Prompt ID: {prompt_id} 执行完成。")
return aggregated_outputs
except websockets.exceptions.ConnectionClosed as e:
print(f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}")
return aggregated_outputs
except Exception as e:
# 重新抛出我们自己的异常,或者处理其他意外错误
if not isinstance(e, ComfyUIExecutionError):
print(f"处理 prompt {prompt_id} 时发生意外错误: {e}")
raise e
except websockets.exceptions.InvalidURI as e:
print(
f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
)
raise e
return aggregated_outputs