ComfyUI-WorkflowPublisher/workflow_service/comfyui_client.py

74 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import websockets
import json
import uuid
import aiohttp
import random
from config import settings
async def queue_prompt(prompt: dict, client_id: str) -> str:
    """Submit a workflow to the ComfyUI queue via HTTP POST.

    Every node receives an extra, uniquely named random input so that ComfyUI
    is forced to re-execute it instead of reusing cached results. This is done
    on a copy of the workflow: the caller's ``prompt`` is left untouched, so
    submitting the same workflow dict repeatedly does not accumulate
    ``cache_buster_*`` inputs.

    Args:
        prompt: API-format workflow mapping node id -> node dict. Each node's
            ``"inputs"`` dict (created empty if absent) gets the extra input.
        client_id: Identifier of the WebSocket client that should receive
            this prompt's events; sent as ``client_id`` in the payload.

    Returns:
        The ``prompt_id`` assigned by ComfyUI.

    Raises:
        aiohttp.ClientResponseError: On an HTTP error status. The response
            body is appended to the message so rejection details are kept.
        Exception: If the response JSON has no ``prompt_id``.
    """
    # [Cache busting] A fresh random input per node changes each node's
    # inputs, forcing re-execution. Copy each node and its inputs instead of
    # mutating the caller's workflow.
    busted_prompt = {}
    for node_id, node in prompt.items():
        inputs = dict(node.get("inputs") or {})
        inputs[f"cache_buster_{uuid.uuid4().hex}"] = random.random()
        busted_prompt[node_id] = {**node, "inputs": inputs}

    payload = {"prompt": busted_prompt, "client_id": client_id}
    http_url = f"{settings.COMFYUI_HTTP_URL}/prompt"
    async with aiohttp.ClientSession() as session:
        async with session.post(http_url, json=payload) as response:
            if response.status >= 400:
                # raise_for_status() would drop the body, which typically holds
                # ComfyUI's explanation (e.g. node validation errors). Raise the
                # same exception type, but keep that detail in the message.
                body = await response.text()
                raise aiohttp.ClientResponseError(
                    response.request_info,
                    response.history,
                    status=response.status,
                    message=f"{response.reason}: {body}",
                    headers=response.headers,
                )
            result = await response.json()
    if "prompt_id" not in result:
        raise Exception(f"Invalid response from ComfyUI /prompt endpoint: {result}")
    return result["prompt_id"]
def _ws_url(client_id: str) -> str:
    """Return the ComfyUI WebSocket URL identifying ``client_id`` via ``clientId``."""
    return f"{settings.COMFYUI_URL}?clientId={client_id}"


async def _collect_outputs(websocket, prompt_id: str) -> dict:
    """Aggregate the 'executed' outputs of ``prompt_id`` from an open WebSocket.

    Reads messages until an 'executing' event with ``node`` set to None
    arrives for this prompt (the end-of-execution marker used by the official
    UI), then returns the outputs keyed by node id.

    This is best-effort: if the connection closes or receiving fails, the
    outputs gathered so far are returned. Binary frames, malformed JSON and
    events for other prompts are skipped individually rather than aborting
    the whole collection.
    """
    aggregated_outputs = {}
    while True:
        try:
            out = await websocket.recv()
        except websockets.exceptions.ConnectionClosed as e:
            print(f"WebSocket connection closed for {prompt_id}. Returning aggregated results. Error: {e}")
            return aggregated_outputs  # a closed connection also counts as the end
        except Exception as e:
            print(f"An error occurred for {prompt_id}: {e}")
            return aggregated_outputs

        # Non-text (binary) frames are not JSON events.
        if not isinstance(out, str):
            continue
        try:
            message = json.loads(out)
        except json.JSONDecodeError as e:
            # One bad frame must not throw away the results collected so far.
            print(f"Skipping malformed WebSocket message for {prompt_id}: {e}")
            continue
        if not isinstance(message, dict):
            continue

        msg_type = message.get('type')
        data = message.get('data')
        # Only events that belong to our prompt_id matter.
        if not isinstance(data, dict) or data.get('prompt_id') != prompt_id:
            continue

        if msg_type == 'executed':
            # Keep each node's output, keyed by node id.
            node_id = data.get('node')
            output_data = data.get('output')
            if node_id and output_data:
                aggregated_outputs[node_id] = output_data
                print(f"Output received for node {node_id} (Prompt ID: {prompt_id})")
        elif msg_type == 'executing' and data.get('node') is None:
            # The official UI treats "executing" with a null node as "finished".
            print(f"Execution finished for Prompt ID: {prompt_id}")
            return aggregated_outputs
        elif msg_type == 'execution_error':
            # Make failures visible instead of silently returning partial
            # outputs. Collection presumably still ends on the 'executing'
            # marker (or on connection close) — TODO confirm against ComfyUI.
            print(f"Execution error for Prompt ID: {prompt_id}: {data}")


async def get_execution_results(prompt_id: str, client_id: str) -> dict:
    """Connect as ``client_id`` and aggregate ``prompt_id``'s outputs until it finishes.

    Returns a mapping of node id -> that node's 'executed' output (possibly
    partial if the connection drops).

    NOTE: ComfyUI does not appear to replay events to sockets that join late
    (its official websocket example connects *before* queuing). If the prompt
    was already queued, a fast run may finish before this connects, and the
    call would then wait indefinitely. Prefer ``run_workflow``, which connects
    first.
    """
    async with websockets.connect(_ws_url(client_id)) as websocket:
        return await _collect_outputs(websocket, prompt_id)


async def run_workflow(prompt: dict) -> dict:
    """Queue ``prompt`` on ComfyUI and wait for its aggregated outputs.

    The WebSocket is opened *before* the prompt is queued, so the
    end-of-execution event cannot be emitted before we are listening.
    Queuing first risked missing it and hanging forever on quick runs.

    Returns:
        Mapping of node id -> that node's 'executed' output.
    """
    client_id = str(uuid.uuid4())
    async with websockets.connect(_ws_url(client_id)) as websocket:
        prompt_id = await queue_prompt(prompt, client_id)
        print(f"Workflow successfully queued with Prompt ID: {prompt_id}")
        return await _collect_outputs(websocket, prompt_id)