ComfyUI-WorkflowPublisher/workflow_service/comfyui_client.py

142 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import random
import uuid
from typing import Dict, Any
import aiohttp
import websockets
from aiohttp import ClientTimeout
from workflow_service.config import Settings, ComfyUIServer
# Module-level configuration instance, created once at import time.
# select_server_for_execution reads the configured server list from settings.SERVERS.
settings: Settings = Settings()
async def get_server_status(
    server: ComfyUIServer,
    session: aiohttp.ClientSession,
    timeout: float = 60,
) -> Dict[str, Any]:
    """
    Probe a single ComfyUI server's ``/queue`` endpoint and summarize its load.

    Args:
        server: Server to probe; only ``server.http_url`` is read.
        session: Shared aiohttp session used for the GET request.
        timeout: Total request timeout in seconds (defaults to the previously
            hard-coded 60).

    Returns:
        ``{"is_reachable": bool, "is_free": bool,
        "queue_details": {"running_count": int, "pending_count": int}}``.
        The structure is identical on success and failure so it always
        satisfies the caller's Pydantic model. ``is_free`` is True only when
        both ``queue_running`` and ``queue_pending`` are empty.

    Never raises for request or parse errors: any exception is printed as a
    warning and the server is reported as unreachable (all flags False,
    counts 0).
    """
    # [Bug fix] Default to the "unreachable" result, built with exactly the
    # same structure as the success case so it satisfies the Pydantic model.
    status_info: Dict[str, Any] = {
        "is_reachable": False,
        "is_free": False,
        "queue_details": {
            "running_count": 0,
            "pending_count": 0,
        },
    }
    try:
        queue_url = f"{server.http_url}/queue"
        # Use an explicit ClientTimeout (as queue_prompt does) rather than a
        # bare number for the per-request timeout.
        request_timeout = ClientTimeout(total=timeout)
        async with session.get(queue_url, timeout=request_timeout) as response:
            response.raise_for_status()
            queue_data = await response.json()
        # Parse everything before touching status_info, so a malformed payload
        # cannot leave it half-updated (is_reachable=True with zeroed counts).
        running_count = len(queue_data.get("queue_running", []))
        pending_count = len(queue_data.get("queue_pending", []))
    except Exception as e:
        # Best-effort probe: on any failure return the untouched default above.
        print(f"警告: 无法检查服务器 {server.http_url} 的队列状态: {e}")
        return status_info

    status_info["is_reachable"] = True
    status_info["queue_details"] = {
        "running_count": running_count,
        "pending_count": pending_count,
    }
    status_info["is_free"] = running_count == 0 and pending_count == 0
    return status_info
async def select_server_for_execution() -> ComfyUIServer:
    """
    Choose the ComfyUI server that should run the next prompt.

    Selection order:
      1. If exactly one server is configured, return it without probing.
      2. Probe every configured server concurrently and pick one at random
         among those reporting an empty queue.
      3. Otherwise pick one at random among servers that answered the probe.

    Raises:
        ValueError: if ``settings.SERVERS`` is empty.
        ConnectionError: if no configured server answered the probe.
    """
    candidates = settings.SERVERS
    if not candidates:
        raise ValueError("没有在 COMFYUI_SERVERS_JSON 中配置任何服务器。")
    if len(candidates) == 1:
        # Nothing to choose between, so skip the network round-trip.
        return candidates[0]

    async with aiohttp.ClientSession() as session:
        statuses = await asyncio.gather(
            *(get_server_status(srv, session) for srv in candidates)
        )

    probed = list(zip(candidates, statuses))

    idle = [srv for srv, st in probed if st["is_free"]]
    if idle:
        chosen = random.choice(idle)
        print(f"发现 {len(idle)} 个空闲服务器。已选择: {chosen.http_url}")
        return chosen

    # Fallback: nobody is idle, so accept any server that is at least reachable.
    alive = [srv for srv, st in probed if st["is_reachable"]]
    if not alive:
        # Worst case: no server answered the probe at all.
        raise ConnectionError("所有配置的ComfyUI服务器都不可达。")
    chosen = random.choice(alive)
    print(f"所有服务器当前都在忙。从可达服务器中随机选择: {chosen.http_url}")
    return chosen
def _with_cache_busters(prompt: dict) -> dict:
    """Return a copy of *prompt* with a uniquely named random input added to every node."""
    busted = {}
    for node_id, node in prompt.items():
        # Copy the node and its inputs dict so the caller's prompt is never
        # mutated (repeated submissions would otherwise accumulate keys).
        inputs = dict(node["inputs"])
        inputs[f"cache_buster_{uuid.uuid4().hex}"] = random.random()
        busted[node_id] = {**node, "inputs": inputs}
    return busted


async def queue_prompt(prompt: dict, client_id: str, http_url: str) -> str:
    """
    Submit a workflow to a ComfyUI server via HTTP POST to ``/prompt``.

    A randomly valued, uniquely named ``cache_buster_*`` input is added to a
    copy of every node (presumably so ComfyUI cannot reuse cached node
    results; confirm against ComfyUI's caching behaviour). The caller's
    ``prompt`` dict is left unmodified.

    Args:
        prompt: ComfyUI API-format workflow; each node must have an
            ``"inputs"`` dict (a missing one raises ``KeyError``).
        client_id: Client id sent with the prompt; WebSocket progress
            messages are keyed by it.
        http_url: Base HTTP URL of the target ComfyUI server.

    Returns:
        The ``prompt_id`` assigned by the server.

    Raises:
        aiohttp.ClientResponseError: on a non-2xx HTTP status.
        Exception: if the response JSON has no ``prompt_id``.
    """
    payload = {"prompt": _with_cache_busters(prompt), "client_id": client_id}
    async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
        prompt_url = f"{http_url}/prompt"
        async with session.post(prompt_url, json=payload) as response:
            response.raise_for_status()
            result = await response.json()
    if "prompt_id" not in result:
        raise Exception(f"从 ComfyUI /prompt 端点返回的响应无效: {result}")
    return result["prompt_id"]
async def _collect_outputs(websocket, prompt_id: str) -> dict:
    """
    Aggregate the outputs of ``prompt_id`` from an already-open ComfyUI WebSocket.

    Text frames are parsed as JSON. Binary frames, and frames whose
    ``data.prompt_id`` does not match, are ignored. Each ``executed`` message
    contributes ``{node_id: output}``. An ``executing`` message whose ``node``
    is None is treated as completion.

    Returns:
        Mapping of node id -> output data. If the connection closes or any
        other error occurs while receiving or parsing, the error is printed
        and the outputs gathered so far are returned (nothing is raised).
    """
    aggregated_outputs = {}
    while True:
        try:
            out = await websocket.recv()
            if isinstance(out, str):
                message = json.loads(out)
                msg_type = message.get('type')
                data = message.get('data')
                if data and data.get('prompt_id') == prompt_id:
                    if msg_type == 'executed':
                        node_id = data.get('node')
                        output_data = data.get('output')
                        if node_id and output_data:
                            aggregated_outputs[node_id] = output_data
                            print(f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})")
                    elif msg_type == 'executing' and data.get('node') is None:
                        print(f"Prompt ID: {prompt_id} 执行完成。")
                        return aggregated_outputs
        except websockets.exceptions.ConnectionClosed as e:
            print(f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}")
            return aggregated_outputs
        except Exception as e:
            # Best-effort: report the problem and hand back partial results.
            print(f"处理 prompt {prompt_id} 时发生错误: {e}")
            return aggregated_outputs


async def get_execution_results(prompt_id: str, client_id: str, ws_url: str) -> dict:
    """
    Connect to a ComfyUI server's WebSocket and aggregate the results of a prompt.

    Args:
        prompt_id: Id returned by ``queue_prompt``.
        client_id: The same client id used when queueing the prompt.
        ws_url: Base WebSocket URL of the server (``?clientId=`` is appended).

    Returns:
        Mapping of node id -> output data (see ``_collect_outputs``).

    NOTE(review): the connection is opened only when this is called, i.e.
    after the prompt was queued. Messages sent before that moment are not
    seen, so a prompt that finishes very quickly can leave this waiting
    indefinitely. ``execute_prompt_on_server`` avoids this by connecting first.
    """
    full_ws_url = f"{ws_url}?clientId={client_id}"
    async with websockets.connect(full_ws_url) as websocket:
        return await _collect_outputs(websocket, prompt_id)


async def execute_prompt_on_server(prompt: Dict, server: ComfyUIServer) -> Dict:
    """
    Execute a prepared prompt on the given server and return its node outputs.

    Opens the WebSocket (under a fresh client id) BEFORE queueing the prompt,
    so no progress or completion message for the prompt can be missed.

    Args:
        prompt: ComfyUI API-format workflow (passed to ``queue_prompt``).
        server: Target server; ``http_url`` and ``ws_url`` are used.

    Returns:
        Mapping of node id -> output data.
    """
    client_id = str(uuid.uuid4())
    full_ws_url = f"{server.ws_url}?clientId={client_id}"
    # Connect first, as ComfyUI's websockets_api_example.py does. Status
    # messages go only to sockets already registered under this clientId.
    # Queueing first could let a fast prompt finish unobserved and leave the
    # receive loop waiting forever for its completion message.
    async with websockets.connect(full_ws_url) as websocket:
        prompt_id = await queue_prompt(prompt, client_id, server.http_url)
        print(f"工作流已在 {server.http_url} 上入队Prompt ID: {prompt_id}")
        results = await _collect_outputs(websocket, prompt_id)
    return results