75 lines
3.2 KiB
Python
75 lines
3.2 KiB
Python
import websockets
|
||
import json
|
||
import uuid
|
||
import aiohttp
|
||
import random
|
||
from workflow_service.config import Settings
|
||
# Module-level settings instance shared by the functions below: they read
# COMFYUI_HTTP_URL (for the HTTP /prompt endpoint) and COMFYUI_URL (for the
# WebSocket event stream) from it.
settings = Settings()
|
||
|
||
|
||
def _with_cache_busters(prompt: dict) -> dict:
    """Return a copy of *prompt* with a uniquely named random input added to every node.

    Cache busting: the extra random input is meant to force every node to be
    re-executed instead of reusing earlier results. The caller's dict and its
    nested ``inputs`` dicts are NOT modified, so the same prompt can be
    submitted repeatedly without accumulating ``cache_buster_*`` keys.
    """
    busted = {}
    for node_id, node in prompt.items():
        # Copy the inputs so the caller's node definition stays untouched;
        # a node with no (or empty/None) inputs gets a fresh dict.
        inputs = dict(node.get("inputs") or {})
        inputs[f"cache_buster_{uuid.uuid4().hex}"] = random.random()
        busted[node_id] = {**node, "inputs": inputs}
    return busted


async def queue_prompt(prompt: dict, client_id: str) -> str:
    """Submit a workflow to the ComfyUI queue via HTTP POST.

    Args:
        prompt: Workflow mapping of node id -> node definition. A copy of it,
            with a random cache-busting input added to every node, is what
            gets submitted; *prompt* itself is left unchanged.
        client_id: Identifier sent as ``client_id`` in the request payload.

    Returns:
        The ``prompt_id`` from the ``/prompt`` response.

    Raises:
        aiohttp.ClientResponseError: If the server responds with an HTTP error status.
        RuntimeError: If the response body is not an object containing ``prompt_id``
            (a subclass of ``Exception``, so existing broad handlers still catch it).
    """
    payload = {"prompt": _with_cache_busters(prompt), "client_id": client_id}
    http_url = f"{settings.COMFYUI_HTTP_URL}/prompt"

    async with aiohttp.ClientSession() as session:
        async with session.post(http_url, json=payload) as response:
            response.raise_for_status()
            result = await response.json()

    if not isinstance(result, dict) or "prompt_id" not in result:
        raise RuntimeError(f"Invalid response from ComfyUI /prompt endpoint: {result}")
    return result["prompt_id"]
|
||
|
||
|
||
def _decode_event(raw):
    """Decode one WebSocket frame into a ``(msg_type, data)`` pair.

    Returns ``None`` for frames that should simply be skipped: non-text
    frames, text that is not valid JSON, JSON that is not an object, and
    messages whose ``data`` is not a dict. Skipping (instead of raising)
    keeps one unexpected frame from aborting the whole wait.
    """
    if not isinstance(raw, str):
        # Only text frames are parsed as JSON events; binary frames are ignored.
        return None
    try:
        message = json.loads(raw)
    except json.JSONDecodeError as e:
        print(f"Ignoring undecodable WebSocket message: {e}")
        return None
    if not isinstance(message, dict):
        return None
    data = message.get('data')
    if not isinstance(data, dict):
        return None
    return message.get('type'), data


async def get_execution_results(prompt_id: str, client_id: str) -> dict:
    """Aggregate the outputs of one prompt from the ComfyUI WebSocket event stream.

    Connects to ``settings.COMFYUI_URL`` with ``clientId=client_id`` and, for
    events whose ``data.prompt_id`` equals ``prompt_id``, stores the ``output``
    of every ``'executed'`` event under its ``node`` id. The wait ends when an
    ``'executing'`` event with ``node`` set to ``None`` arrives (the completion
    signal the official UI uses), when the connection closes, or when
    receiving fails; in every case the outputs collected so far are returned.

    Args:
        prompt_id: Id returned by ``queue_prompt`` for the prompt to follow.
        client_id: Sent as the ``clientId`` query parameter; should be the same
            id the prompt was queued with (as ``run_workflow`` does).

    Returns:
        Mapping of node id -> that node's ``output`` payload (possibly partial).
    """
    ws_url = f"{settings.COMFYUI_URL}?clientId={client_id}"
    aggregated_outputs = {}

    async with websockets.connect(ws_url) as websocket:
        while True:
            try:
                raw = await websocket.recv()
            except websockets.exceptions.ConnectionClosed as e:
                print(f"WebSocket connection closed for {prompt_id}. Returning aggregated results. Error: {e}")
                return aggregated_outputs  # a closed connection also counts as the end
            except Exception as e:
                # Best effort: report the receive failure and return what we have.
                print(f"An error occurred for {prompt_id}: {e}")
                return aggregated_outputs

            event = _decode_event(raw)
            if event is None:
                continue
            msg_type, data = event

            # Only events belonging to our prompt are relevant.
            if data.get('prompt_id') != prompt_id:
                continue

            if msg_type == 'executed':
                # Aggregate each node's output.
                node_id = data.get('node')
                output_data = data.get('output')
                if node_id and output_data:
                    aggregated_outputs[node_id] = output_data
                    print(f"Output received for node {node_id} (Prompt ID: {prompt_id})")
            elif msg_type == 'execution_error':
                # Log node failures so a failed run is visible instead of just
                # yielding fewer outputs. Field names (node_id, exception_message)
                # are assumed from the ComfyUI protocol -- confirm against the
                # server version; .get() keeps this safe if they differ.
                print(
                    f"Execution error for Prompt ID: {prompt_id} at node "
                    f"{data.get('node_id')}: {data.get('exception_message')}"
                )
            elif msg_type == 'executing' and data.get('node') is None:
                # The official UI treats "executing" with node == None as the end.
                print(f"Execution finished for Prompt ID: {prompt_id}")
                return aggregated_outputs
|
||
|
||
|
||
async def run_workflow(prompt: dict) -> dict:
    """Queue *prompt* on ComfyUI and wait for its outputs.

    A fresh UUID4 string is generated per call and used both when queueing and
    when listening for events, so the two steps refer to the same client.

    Returns:
        The node-id -> output mapping produced by ``get_execution_results``.

    NOTE(review): the WebSocket is only opened inside ``get_execution_results``,
    i.e. after ``queue_prompt`` has returned. Events the server emits for this
    client before the socket connects are presumably lost -- consider opening
    the connection before queueing.
    """
    session_client = str(uuid.uuid4())
    queued_id = await queue_prompt(prompt, session_client)
    print(f"Workflow successfully queued with Prompt ID: {queued_id}")
    return await get_execution_results(queued_id, session_client)