142 lines
5.8 KiB
Python
142 lines
5.8 KiB
Python
import asyncio
|
||
import json
|
||
import random
|
||
import uuid
|
||
from typing import Dict, Any
|
||
|
||
import aiohttp
|
||
import websockets
|
||
from aiohttp import ClientTimeout
|
||
|
||
from workflow_service.config import Settings, ComfyUIServer
|
||
|
||
settings = Settings()
|
||
|
||
|
||
async def get_server_status(server: ComfyUIServer, session: aiohttp.ClientSession) -> Dict[str, Any]:
|
||
"""
|
||
检查单个ComfyUI服务器的详细状态。
|
||
返回一个包含可达性、队列状态和详细队列内容的字典。
|
||
"""
|
||
# [BUG修复] 确保初始字典结构与成功时的结构一致,以满足Pydantic模型
|
||
status_info = {
|
||
"is_reachable": False,
|
||
"is_free": False,
|
||
"queue_details": {
|
||
"running_count": 0,
|
||
"pending_count": 0
|
||
}
|
||
}
|
||
try:
|
||
queue_url = f"{server.http_url}/queue"
|
||
async with session.get(queue_url, timeout=60) as response:
|
||
response.raise_for_status()
|
||
queue_data = await response.json()
|
||
|
||
status_info["is_reachable"] = True
|
||
|
||
running_count = len(queue_data.get("queue_running", []))
|
||
pending_count = len(queue_data.get("queue_pending", []))
|
||
|
||
status_info["queue_details"] = {
|
||
"running_count": running_count,
|
||
"pending_count": pending_count
|
||
}
|
||
|
||
status_info["is_free"] = (running_count == 0 and pending_count == 0)
|
||
|
||
except Exception as e:
|
||
# 当请求失败时,将返回上面定义的、结构正确的初始 status_info
|
||
print(f"警告: 无法检查服务器 {server.http_url} 的队列状态: {e}")
|
||
|
||
return status_info
|
||
|
||
|
||
async def select_server_for_execution() -> ComfyUIServer:
|
||
"""
|
||
智能选择一个ComfyUI服务器。
|
||
优先选择一个空闲的服务器,如果所有服务器都忙,则随机选择一个。
|
||
"""
|
||
servers = settings.SERVERS
|
||
if not servers:
|
||
raise ValueError("没有在 COMFYUI_SERVERS_JSON 中配置任何服务器。")
|
||
if len(servers) == 1:
|
||
return servers[0]
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
tasks = [get_server_status(server, session) for server in servers]
|
||
results = await asyncio.gather(*tasks)
|
||
|
||
free_servers = [servers[i] for i, status in enumerate(results) if status["is_free"]]
|
||
|
||
if free_servers:
|
||
selected_server = random.choice(free_servers)
|
||
print(f"发现 {len(free_servers)} 个空闲服务器。已选择: {selected_server.http_url}")
|
||
return selected_server
|
||
else:
|
||
# 后备方案:选择一个可达的服务器,即使它很忙
|
||
reachable_servers = [servers[i] for i, status in enumerate(results) if status["is_reachable"]]
|
||
if reachable_servers:
|
||
selected_server = random.choice(reachable_servers)
|
||
print(f"所有服务器当前都在忙。从可达服务器中随机选择: {selected_server.http_url}")
|
||
return selected_server
|
||
else:
|
||
# 最坏情况:所有服务器都不可达,抛出异常
|
||
raise ConnectionError("所有配置的ComfyUI服务器都不可达。")
|
||
|
||
|
||
async def queue_prompt(prompt: dict, client_id: str, http_url: str) -> str:
|
||
"""通过HTTP POST将工作流任务提交到指定的ComfyUI服务器。"""
|
||
for node_id in prompt:
|
||
prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random()
|
||
payload = {"prompt": prompt, "client_id": client_id}
|
||
async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
|
||
prompt_url = f"{http_url}/prompt"
|
||
async with session.post(prompt_url, json=payload) as response:
|
||
response.raise_for_status()
|
||
result = await response.json()
|
||
if "prompt_id" not in result:
|
||
raise Exception(f"从 ComfyUI /prompt 端点返回的响应无效: {result}")
|
||
return result["prompt_id"]
|
||
|
||
|
||
async def get_execution_results(prompt_id: str, client_id: str, ws_url: str) -> dict:
|
||
"""通过WebSocket连接到指定的ComfyUI服务器,聚合执行结果。"""
|
||
full_ws_url = f"{ws_url}?clientId={client_id}"
|
||
aggregated_outputs = {}
|
||
async with websockets.connect(full_ws_url) as websocket:
|
||
while True:
|
||
try:
|
||
out = await websocket.recv()
|
||
if isinstance(out, str):
|
||
message = json.loads(out)
|
||
msg_type = message.get('type')
|
||
data = message.get('data')
|
||
if data and data.get('prompt_id') == prompt_id:
|
||
if msg_type == 'executed':
|
||
node_id = data.get('node')
|
||
output_data = data.get('output')
|
||
if node_id and output_data:
|
||
aggregated_outputs[node_id] = output_data
|
||
print(f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})")
|
||
elif msg_type == 'executing' and data.get('node') is None:
|
||
print(f"Prompt ID: {prompt_id} 执行完成。")
|
||
return aggregated_outputs
|
||
except websockets.exceptions.ConnectionClosed as e:
|
||
print(f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}")
|
||
return aggregated_outputs
|
||
except Exception as e:
|
||
print(f"处理 prompt {prompt_id} 时发生错误: {e}")
|
||
break
|
||
return aggregated_outputs
|
||
|
||
|
||
async def execute_prompt_on_server(prompt: Dict, server: ComfyUIServer) -> Dict:
|
||
"""
|
||
在指定的服务器上执行一个准备好的prompt。
|
||
"""
|
||
client_id = str(uuid.uuid4())
|
||
prompt_id = await queue_prompt(prompt, client_id, server.http_url)
|
||
print(f"工作流已在 {server.http_url} 上入队,Prompt ID: {prompt_id}")
|
||
results = await get_execution_results(prompt_id, client_id, server.ws_url)
|
||
return results |