ComfyUI-WorkflowPublisher/workflow_service/comfyui_client.py

966 lines
35 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import base64
import json
import logging
import os
import random
import uuid
import websockets
from collections import defaultdict
from datetime import datetime
from typing import Dict, Any, Optional, List, Set, Union
import aiohttp
from aiohttp import ClientTimeout
from workflow_service.config import Settings, ComfyUIServer
from workflow_service.database import (
create_workflow_run,
update_workflow_run_status,
create_workflow_run_nodes,
update_workflow_run_node_status,
get_workflow_run,
get_pending_workflow_runs,
get_running_workflow_runs,
get_workflow_run_nodes,
)
# Service configuration; settings.SERVERS lists the ComfyUI servers used below.
settings = Settings()
# Node-title prefixes that expose a workflow node as an API input / output.
API_INPUT_PREFIX = "INPUT_"
API_OUTPUT_PREFIX = "OUTPUT_"
# Configure logging (root logger at INFO) and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global workflow task-queue manager.
class WorkflowQueueManager:
    """Process-local FIFO queue that dispatches workflow runs to ComfyUI servers.

    A configured server receives at most one run from this manager at a time
    (tracked in ``running_tasks``) and is only used when its own ComfyUI queue
    is empty. Run and node records are persisted through the
    ``workflow_service.database`` helpers.
    """

    # Seconds to wait before re-checking the servers when runs are pending, no
    # server is available, and no locally running run will trigger another pass.
    RETRY_DELAY_SECONDS = 5.0

    def __init__(self):
        self.running_tasks = {}  # server_url -> {"workflow_run_id", "started_at"}
        self.pending_tasks = []  # FIFO of workflow_run_ids waiting for a server
        # Serialises dispatch passes so a server or run is never dispatched twice.
        self.lock = asyncio.Lock()
        # asyncio keeps only weak references to running tasks; hold strong
        # references to fire-and-forget tasks so they cannot vanish mid-run.
        self._background_tasks: set[asyncio.Task] = set()
        self._retry_scheduled = False

    def _spawn(self, coro) -> asyncio.Task:
        """Run *coro* as a background task, keeping a reference until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _schedule_retry(self) -> None:
        """Schedule one delayed dispatch pass unless one is already scheduled."""
        if self._retry_scheduled:
            return
        self._retry_scheduled = True

        async def _retry():
            try:
                await asyncio.sleep(self.RETRY_DELAY_SECONDS)
            finally:
                # Reset before the pass so it may schedule a follow-up retry.
                self._retry_scheduled = False
            await self._process_queue()

        self._spawn(_retry())

    async def add_task(
        self,
        workflow_run_id: str,
        workflow_name: str,
        workflow_data: dict,
        api_spec: dict,
        request_data: dict,
    ):
        """Persist a new run and its node records, enqueue it, and trigger dispatch.

        The dispatch lock is not taken here, so submitting a run does not wait
        behind a dispatch pass that is probing servers over the network.

        Returns:
            The given ``workflow_run_id``.
        """
        await create_workflow_run(
            workflow_run_id=workflow_run_id,
            workflow_name=workflow_name,
            workflow_json=json.dumps(workflow_data),
            api_spec=json.dumps(api_spec),
            request_data=json.dumps(request_data),
        )
        nodes_data = [
            {"id": str(node["id"]), "type": node.get("type", "unknown")}
            for node in workflow_data.get("nodes", [])
        ]
        await create_workflow_run_nodes(workflow_run_id, nodes_data)

        # Safe without the lock: the event loop is single-threaded, and dispatch
        # only removes the specific run it has just marked as running.
        self.pending_tasks.append(workflow_run_id)
        logger.info(
            f"任务 {workflow_run_id} 已添加到队列,当前队列长度: {len(self.pending_tasks)}"
        )
        self._spawn(self._process_queue())
        return workflow_run_id

    async def _process_queue(self):
        """Dispatch pending runs, in FIFO order, to every currently available server."""
        async with self.lock:
            if not self.pending_tasks:
                return

            available_servers = await self._get_available_servers()
            if not available_servers:
                logger.info("没有可用的服务器,等待中...")
                # With nothing running locally, no completion will trigger another
                # pass, so poll again later instead of stalling indefinitely.
                if not self.running_tasks:
                    self._schedule_retry()
                return

            for server in available_servers:
                if not self.pending_tasks:
                    break
                if server.http_url in self.running_tasks:
                    # Guards against the same URL being configured twice.
                    continue
                workflow_run_id = self.pending_tasks[0]
                try:
                    # Mark as running before dequeuing so a DB failure cannot
                    # silently drop the run from the queue.
                    await update_workflow_run_status(
                        workflow_run_id, "running", server.http_url
                    )
                except Exception:
                    logger.exception(
                        f"无法将任务 {workflow_run_id} 标记为运行中,稍后重试"
                    )
                    self._schedule_retry()
                    return
                self.pending_tasks.remove(workflow_run_id)
                self.running_tasks[server.http_url] = {
                    "workflow_run_id": workflow_run_id,
                    "started_at": datetime.now(),
                }
                self._spawn(self._execute_task(workflow_run_id, server))

    async def _get_available_servers(self) -> list[ComfyUIServer]:
        """Return configured servers that this manager is not using and whose ComfyUI queue is empty.

        Servers are probed concurrently over one shared HTTP session. Probe
        errors are logged, and the affected server is treated as unavailable.
        """
        candidates = [
            server
            for server in settings.SERVERS
            if server.http_url not in self.running_tasks
        ]
        if not candidates:
            return []

        try:
            async with aiohttp.ClientSession() as session:
                statuses = await asyncio.gather(
                    *(get_server_status(server, session) for server in candidates),
                    return_exceptions=True,
                )
        except Exception as e:
            logger.warning(f"检查服务器状态时出错: {e}")
            return []

        available_servers = []
        for server, status in zip(candidates, statuses):
            if isinstance(status, BaseException):
                logger.warning(f"检查服务器 {server.http_url} 状态时出错: {status}")
            elif status["is_reachable"] and status["is_free"]:
                available_servers.append(server)
        return available_servers

    async def _execute_task(self, workflow_run_id: str, server: ComfyUIServer):
        """Execute one run on *server*, persist its outcome, then free the server.

        Any error is logged with its traceback and stored as the run's
        ``error_message`` with status "failed". Afterwards the server is released
        and another dispatch pass is started.
        """
        try:
            workflow_run = await get_workflow_run(workflow_run_id)
            if not workflow_run:
                raise Exception(f"找不到工作流运行记录: {workflow_run_id}")
            workflow_data = json.loads(workflow_run["workflow_json"])
            api_spec = json.loads(workflow_run["api_spec"])
            request_data = json.loads(workflow_run["request_data"])

            result = await execute_prompt_on_server(
                workflow_data, api_spec, request_data, server, workflow_run_id
            )
            # Persist the processed result.
            await update_workflow_run_status(
                workflow_run_id,
                "completed",
                result=json.dumps(result, ensure_ascii=False),
            )
        except Exception as e:
            logger.exception(f"执行任务 {workflow_run_id} 时出错: {e}")
            await update_workflow_run_status(
                workflow_run_id, "failed", error_message=str(e)
            )
        finally:
            async with self.lock:
                self.running_tasks.pop(server.http_url, None)
            # The server is free again: try to dispatch the next run.
            self._spawn(self._process_queue())

    async def get_task_status(self, workflow_run_id: str) -> dict:
        """Return a run's status record, its node records and, if completed, its parsed result.

        Returns ``{"error": "任务不存在"}`` when the run does not exist. A stored
        result that is not valid JSON is reported as ``None``.
        """
        workflow_run = await get_workflow_run(workflow_run_id)
        if not workflow_run:
            return {"error": "任务不存在"}
        nodes = await get_workflow_run_nodes(workflow_run_id)
        result = {
            "id": workflow_run_id,
            "status": workflow_run["status"],
            "created_at": workflow_run["created_at"],
            "started_at": workflow_run["started_at"],
            "completed_at": workflow_run["completed_at"],
            "server_url": workflow_run["server_url"],
            "error_message": workflow_run["error_message"],
            "nodes": nodes,
        }
        # For completed runs, attach the result stored in the database.
        if workflow_run["status"] == "completed" and workflow_run.get("result"):
            try:
                result["result"] = json.loads(workflow_run["result"])
            except (json.JSONDecodeError, TypeError):
                result["result"] = None
        return result


# Global queue manager instance.
queue_manager = WorkflowQueueManager()
# Custom exception wrapping an execution error reported by ComfyUI.
class ComfyUIExecutionError(Exception):
    """Raised when ComfyUI reports an ``execution_error`` for a prompt.

    The raw payload is kept on ``error_data``; the exception message summarises
    the failing node's id and type plus ComfyUI's exception message
    ("N/A" when absent).
    """

    def __init__(self, error_data: dict):
        self.error_data = error_data
        failed_node = error_data.get('node_id')
        failed_type = error_data.get('node_type')
        reason = error_data.get('exception_message', 'N/A')
        super().__init__(
            f"ComfyUI节点执行失败。节点ID: {failed_node}, 节点类型: {failed_type}. 错误: {reason}"
        )
async def get_server_status(
    server: ComfyUIServer,
    session: aiohttp.ClientSession,
    *,
    timeout: float = 60,
) -> Dict[str, Any]:
    """Probe one ComfyUI server's ``/queue`` endpoint and summarise its state.

    The returned dict always has the same shape, whether or not the probe
    succeeds, so it satisfies the Pydantic model it is fed into::

        {
            "is_reachable": bool,  # /queue answered 2xx with a parsable payload
            "is_free": bool,       # reachable and both queues are empty
            "queue_details": {"running_count": int, "pending_count": int},
        }

    Only the queue lengths are reported, not the queue contents.

    Args:
        server: Server to probe (its ``http_url`` is used).
        session: Shared aiohttp session used for the request.
        timeout: Total request timeout in seconds (keyword-only, default 60).

    Errors (``Exception`` subclasses) are logged and reported as
    "unreachable / not free" rather than raised.
    """
    status_info = {
        "is_reachable": False,
        "is_free": False,
        "queue_details": {"running_count": 0, "pending_count": 0},
    }
    queue_url = f"{server.http_url}/queue"
    try:
        async with session.get(
            queue_url, timeout=ClientTimeout(total=timeout)
        ) as response:
            response.raise_for_status()
            queue_data = await response.json()
        # Parse the counts before declaring the server reachable, so a malformed
        # payload does not yield a half-filled "reachable" result.
        running_count = len(queue_data.get("queue_running") or [])
        pending_count = len(queue_data.get("queue_pending") or [])
    except Exception as e:
        # On failure, return the well-formed default status_info defined above.
        logger.warning(f"无法检查服务器 {server.http_url} 的队列状态: {e}")
        return status_info

    status_info["is_reachable"] = True
    status_info["queue_details"] = {
        "running_count": running_count,
        "pending_count": pending_count,
    }
    status_info["is_free"] = running_count == 0 and pending_count == 0
    return status_info
async def select_server_for_execution() -> ComfyUIServer:
    """Choose a ComfyUI server to run a prompt on.

    With a single configured server it is returned without probing. Otherwise
    every server is probed concurrently: a random idle server is preferred, and
    failing that a random reachable (busy) one.

    Raises:
        ValueError: if no servers are configured.
        ConnectionError: if none of the servers is reachable.
    """
    candidates = settings.SERVERS
    if not candidates:
        raise ValueError("没有在 COMFYUI_SERVERS_JSON 中配置任何服务器。")
    if len(candidates) == 1:
        return candidates[0]

    async with aiohttp.ClientSession() as session:
        statuses = await asyncio.gather(
            *(get_server_status(candidate, session) for candidate in candidates)
        )

    idle = [srv for srv, state in zip(candidates, statuses) if state["is_free"]]
    if idle:
        chosen = random.choice(idle)
        logger.info(f"发现 {len(idle)} 个空闲服务器。已选择: {chosen.http_url}")
        return chosen

    # Fallback: any reachable server, even if it is busy.
    reachable = [
        srv for srv, state in zip(candidates, statuses) if state["is_reachable"]
    ]
    if not reachable:
        # Worst case: every configured server is unreachable.
        raise ConnectionError("所有配置的ComfyUI服务器都不可达。")
    chosen = random.choice(reachable)
    logger.info(f"所有服务器当前都在忙。从可达服务器中随机选择: {chosen.http_url}")
    return chosen
async def execute_prompt_on_server(
    workflow_data: Dict,
    api_spec: Dict,
    request_data: Dict,
    server: ComfyUIServer,
    workflow_run_id: str,
) -> Dict:
    """Patch, submit and follow one workflow run on *server*.

    The request values are applied to the workflow, which is converted to the
    /prompt format and queued. The run record is updated first with the client
    id and then with the prompt id. Finally the call waits for the results over
    the WebSocket, with per-node status tracking.

    NOTE(review): the WebSocket is opened only after the prompt has been
    queued; if the run finishes before it connects, the completion message may
    be missed — confirm.

    Returns:
        The outputs aggregated by ``_get_execution_results``.
    """
    session_client_id = str(uuid.uuid4())

    patched = await patch_workflow(workflow_data, api_spec, request_data, server)
    prompt = convert_workflow_to_prompt_api_format(patched)

    async def record_running(known_prompt_id):
        await update_workflow_run_status(
            workflow_run_id,
            "running",
            server.http_url,
            known_prompt_id,
            session_client_id,
        )

    # The prompt id is only known once the prompt has been queued.
    await record_running(None)
    prompt_id = await _queue_prompt(
        workflow_data, prompt, session_client_id, server.http_url
    )
    await record_running(prompt_id)
    logger.info(
        f"工作流 {workflow_run_id} 已在 {server.http_url} 上入队Prompt ID: {prompt_id}"
    )

    return await _get_execution_results(
        workflow_data, prompt_id, session_client_id, server.ws_url, workflow_run_id
    )
async def submit_workflow_to_queue(
    workflow_name: str, workflow_data: Dict, api_spec: Dict, request_data: Dict
) -> str:
    """Enqueue a workflow run and return its id without waiting for execution.

    Callers poll the run's status with the returned id.
    """
    run_id = f"{uuid.uuid4()}"
    await queue_manager.add_task(
        run_id, workflow_name, workflow_data, api_spec, request_data
    )
    return run_id
def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]:
    """Derive the API parameter spec from specially titled workflow nodes.

    Nodes titled ``API_INPUT_PREFIX + name`` / ``API_OUTPUT_PREFIX + name`` are
    exposed under the lower-cased ``name``. Result structure::

        {
            "inputs": {
                name: {
                    "node_id": str,
                    "class_type": node["type"],
                    # only present when the node has unlinked widget inputs:
                    "inputs": {
                        widget_name_lower: {"type": ..., "widget_name": original},
                    },
                },
            },
            "outputs": {
                name: {
                    "node_id": str,
                    "class_type": node["type"],
                    # only present when the node has at least one output:
                    "outputs": {output_name_lower: {"output_name": original}},
                },
            },
        }

    The "upload" widget of LoadImage nodes is not exposed. When several nodes
    map to the same name, the first is kept and the others are skipped with a
    warning.

    Raises:
        ValueError: if ``workflow_data["nodes"]`` is missing or not a list.
    """
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError(
            "Invalid workflow format: 'nodes' key not found or is not a list."
        )

    spec = {"inputs": {}, "outputs": {}}
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    for node_id, node in nodes_map.items():
        title: str = node.get("title")
        if not title:
            continue

        if title.startswith(API_INPUT_PREFIX):
            section, node_name = "inputs", title[len(API_INPUT_PREFIX) :].lower()
        elif title.startswith(API_OUTPUT_PREFIX):
            section, node_name = "outputs", title[len(API_OUTPUT_PREFIX) :].lower()
        else:
            continue

        if node_name in spec[section]:
            # Merging would attribute this node's fields to the first node's
            # node_id (which patch_workflow writes to), so keep only the first.
            logger.warning(
                f"API节点名称 '{node_name}' 重复 (节点 {node_id}, 标题 '{title}'),已忽略该节点。"
            )
            continue

        entry = {"node_id": node_id, "class_type": node.get("type")}
        if section == "inputs":
            fields = {}
            for a_input in node.get("inputs") or []:
                # Only unlinked widget inputs can be set through the API.
                if a_input.get("link") is not None or "widget" not in a_input:
                    continue
                widget_name = a_input["widget"]["name"]
                field_key = widget_name.lower()
                if node.get("type") == "LoadImage" and field_key == "upload":
                    continue
                fields[field_key] = {
                    "type": _get_param_type(a_input.get("type", "STRING")),
                    "widget_name": widget_name,
                }
        else:
            fields = {
                an_output["name"].lower(): {"output_name": an_output["name"]}
                for an_output in node.get("outputs") or []
            }
        if fields:
            entry[section] = fields
        spec[section][node_name] = entry
    return spec
def parse_inputs_json_schema(workflow_data: dict) -> dict:
    """Build a draft-07 JSON Schema describing the workflow's API inputs.

    Every node titled ``API_INPUT_PREFIX + name`` becomes an object property
    ``name`` (lower-cased). Its properties are the node's unlinked widget
    inputs, except the "upload" widget of LoadImage nodes. A field is required
    unless its input entry has ``"required": False``. A node is required when it
    has at least one required field. If several nodes map to the same name,
    only the first one is described and the others are skipped with a warning.

    Raises:
        ValueError: if ``workflow_data["nodes"]`` is missing or not a list.
    """
    # Base JSON Schema structure; only inputs are described.
    schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "object",
        "title": "ComfyUI Workflow Input Schema",
        "description": "工作流输入参数定义",
        "properties": {},
        "required": [],
    }
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError(
            "Invalid workflow format: 'nodes' key not found or is not a list."
        )

    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    for node_id, node in nodes_map.items():
        title: str = node.get("title")
        if not title or not title.startswith(API_INPUT_PREFIX):
            continue
        node_name = title[len(API_INPUT_PREFIX) :].lower()
        if node_name in schema["properties"]:
            # Merging would describe this node's fields under another node's name.
            logger.warning(
                f"输入节点名称 '{node_name}' 重复 (节点 {node_id}),已忽略该节点。"
            )
            continue

        properties = {}
        required = []
        for a_input in node.get("inputs") or []:
            if a_input.get("link") is not None or "widget" not in a_input:
                continue
            widget_name = a_input["widget"]["name"].lower()
            if node.get("type") == "LoadImage" and widget_name == "upload":
                continue
            properties[widget_name] = _generate_field_schema(a_input)
            # Fields are required by default; JSON Schema requires unique entries.
            if a_input.get("required", True) and widget_name not in required:
                required.append(widget_name)

        schema["properties"][node_name] = {
            "type": "object",
            "title": f"输入节点: {node_name}",
            "description": f"节点类型: {node.get('type')}",
            "properties": properties,
            "required": required,
            "additionalProperties": False,
        }
        if required:
            schema["required"].append(node_name)
    return schema
def _generate_field_schema(input_field: dict) -> dict:
    """Build the JSON Schema definition for one widget input.

    Args:
        input_field: A node ``inputs`` entry with a ``widget`` dict. The optional
            keys "type", "description", "min", "max", "options" and "default"
            are used when present.

    Returns:
        A property schema. Integer/number types get bounds only when they are
        known. A "boolean" param type maps to a boolean schema, upload types to
        a binary string, and anything else to a plain string. "enum" and
        "default" are added when present.
    """
    field_schema = {
        "title": input_field["widget"]["name"],
        "description": input_field.get("description", ""),
    }

    field_type = _get_param_type(input_field.get("type", "STRING"))
    if field_type in ("int", "float"):
        field_schema["type"] = "integer" if field_type == "int" else "number"
        # JSON Schema requires "minimum"/"maximum" to be numbers, so omit
        # unknown bounds instead of emitting null.
        for source_key, schema_key in (("min", "minimum"), ("max", "maximum")):
            bound = input_field.get(source_key)
            if bound is not None:
                field_schema[schema_key] = bound
    elif field_type == "boolean":
        field_schema["type"] = "boolean"
    elif field_type == "UploadFile":
        field_schema.update(
            {"type": "string", "format": "binary", "description": "上传文件"}
        )
    else:  # string
        field_schema["type"] = "string"

    # Enumerated choices.
    if "options" in input_field:
        options = input_field["options"]
        field_schema["enum"] = options
        # str() so non-string options (e.g. numbers) do not break the join.
        field_schema["description"] += (
            f" 可选值: {', '.join(str(option) for option in options)}"
        )

    if "default" in input_field:
        field_schema["default"] = input_field["default"]
    return field_schema
def _get_param_type(input_type_str: str) -> str:
"""
根据输入类型字符串确定参数类型
"""
input_type_str = input_type_str.upper()
if "COMBO" in input_type_str:
return "UploadFile"
elif "INT" in input_type_str:
return "int"
elif "FLOAT" in input_type_str:
return "float"
else:
return "string"
def _coerce_bool(value: Any) -> bool:
    """Convert a request value for a boolean widget to ``bool``.

    ``str(False)`` would give the truthy string "False", so strings are parsed
    explicitly ("true"/"1"/"yes"/"on" and "false"/"0"/"no"/"off", any case).

    Raises:
        ValueError: for values that cannot be interpreted as a boolean.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return bool(value)
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in ("true", "1", "yes", "on"):
            return True
        if lowered in ("false", "0", "no", "off"):
            return False
    raise ValueError(f"无法解析布尔值: {value!r}")


async def patch_workflow(
    workflow_data: dict,
    api_spec: dict[str, dict[str, Any]],
    request_data: dict[str, Any],
    server: ComfyUIServer,
) -> dict:
    """Write request values into the workflow's ``widgets_values`` and return it.

    ``request_data`` maps API node names to ``{field_name: value}``. Names and
    fields absent from ``api_spec["inputs"]`` are ignored. Values are converted
    by the field's spec type: "int" -> int, "float" -> float, "boolean" -> bool,
    anything else -> str. For LoadImage nodes, the value is first passed to
    ``upload_image_to_comfy`` and the returned name is stored. Values that
    cannot be converted are logged and skipped.

    The node dicts are modified in place; the returned dict is ``workflow_data``.

    NOTE(review): a widget's position among the node's widget inputs is used as
    its index into ``widgets_values``. Widgets that store extra values (e.g. a
    seed's control_after_generate setting) would shift this mapping — confirm
    against the workflows in use.

    Raises:
        ValueError: if ``workflow_data`` has no "nodes" key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    input_specs = api_spec.get("inputs", {})
    converters = {"int": int, "float": float, "boolean": _coerce_bool}

    for node_name, node_fields in request_data.items():
        node_spec = input_specs.get(node_name)
        if node_spec is None:
            continue
        node_id = node_spec["node_id"]
        target_node = nodes_map.get(node_id)
        if target_node is None:
            continue
        if not isinstance(node_fields, dict):
            logger.warning(f"节点 '{node_name}' 的参数应为对象,已跳过: {node_fields!r}")
            continue

        field_specs = node_spec.get("inputs", {})
        # Widget-backed inputs in declaration order; their positions index widgets_values.
        widget_inputs = [
            inp for inp in target_node.get("inputs", []) if "widget" in inp
        ]

        for field_name, value in node_fields.items():
            field_spec = field_specs.get(field_name)
            if field_spec is None:
                continue
            widget_name_to_patch = field_spec["widget_name"]

            target_widget_index = next(
                (
                    i
                    for i, widget_input in enumerate(widget_inputs)
                    if widget_input.get("widget", {}).get("name") == widget_name_to_patch
                ),
                -1,
            )
            if target_widget_index == -1:
                logger.warning(
                    f"在节点 {node_id} 中未找到名为 '{widget_name_to_patch}' 的 widget。跳过此参数。"
                )
                continue

            # Ensure widgets_values is a list long enough for the target index.
            widgets_values = target_node.get("widgets_values")
            if not isinstance(widgets_values, list):
                widgets_values = [None] * len(widget_inputs)
                target_node["widgets_values"] = widgets_values
            if len(widgets_values) <= target_widget_index:
                widgets_values.extend(
                    [None] * (target_widget_index + 1 - len(widgets_values))
                )

            convert = converters.get(field_spec["type"], str)
            try:
                if target_node.get("type") == "LoadImage":
                    value = await upload_image_to_comfy(value, server)
                widgets_values[target_widget_index] = convert(value)
            except (ValueError, TypeError) as e:
                logger.warning(
                    f"无法将节点 '{node_name}' 的字段 '{field_name}' 的值 '{value}' 转换为类型 '{field_spec['type']}'。错误: {e}"
                )
                continue

    workflow_data["nodes"] = list(nodes_map.values())
    return workflow_data
def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
    """Translate a UI-format workflow into the node map accepted by ``/prompt``.

    Widget inputs take their values positionally from ``widgets_values``, which
    ``patch_workflow`` may already have updated. Linked inputs become
    ``[origin_node_id, origin_slot]`` references resolved through ``links``
    (6-element entries). A widget input that is also linked ends up with the
    link reference.

    NOTE(review): every node is emitted, including UI-only nodes (e.g. notes
    or reroutes) and muted/bypassed ones — confirm the target server accepts
    these.

    Raises:
        ValueError: if ``workflow_data`` has no "nodes" key, or a link entry
            does not have exactly six elements.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    # link id -> [origin node id (str), origin slot index]
    sources = {
        link_id: [str(src_node), src_slot]
        for link_id, src_node, src_slot, _dst_node, _dst_slot, _kind in workflow_data.get(
            "links", []
        )
    }

    prompt = {}
    for node in workflow_data["nodes"]:
        node_key = str(node["id"])
        slots = node.get("inputs", [])
        values = node.get("widgets_values", [])
        inputs = {}

        # Pass 1: widget values, consumed in order.
        position = 0
        for slot in slots:
            if "widget" not in slot or position >= len(values):
                continue
            name = slot["widget"].get("name")
            if name:
                inputs[name] = values[position]
            position += 1

        # Pass 2: links override any widget value of the same input.
        for slot in slots:
            link_id = slot.get("link")
            if link_id is not None and link_id in sources:
                inputs[slot["name"]] = sources[link_id]

        prompt[node_key] = {"class_type": node["type"], "inputs": inputs}
    return prompt
async def _queue_prompt(
    workflow: dict, prompt: dict, client_id: str, http_url: str
) -> str:
    """Submit a prompt to ``{http_url}/prompt`` and return the prompt_id.

    Every node in ``prompt`` (mutated in place) gets an extra random
    ``cache_buster_<hex>`` input, presumably so the server never reuses cached
    results for this submission. ``workflow`` is sent as
    ``extra_data.extra_pnginfo.workflow``.

    Raises:
        aiohttp.ClientResponseError: on a 4xx/5xx reply. Unlike
            ``raise_for_status()``, the message includes the response body, so
            error details returned by the server (such as prompt validation
            errors) reach the caller.
        Exception: if the reply contains no "prompt_id".
    """
    for node_id in prompt:
        prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random()

    payload = {
        "prompt": prompt,
        "client_id": client_id,
        "extra_data": {
            "api_key_comfy_org": "",
            "extra_pnginfo": {"workflow": workflow},
        },
    }
    logger.info(f"提交到 ComfyUI /prompt 端点的payload: {json.dumps(payload)}")

    prompt_url = f"{http_url}/prompt"
    async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
        try:
            async with session.post(prompt_url, json=payload) as response:
                logger.info(f"ComfyUI /prompt 端点返回的响应: {response}")
                if response.status >= 400:
                    body = (await response.read()).decode("utf-8", errors="replace")
                    raise aiohttp.ClientResponseError(
                        response.request_info,
                        response.history,
                        status=response.status,
                        message=f"{response.reason}: {body}",
                        headers=response.headers,
                    )
                result = await response.json()
                logger.info(f"ComfyUI /prompt 端点返回的响应: {result}")
                if "prompt_id" not in result:
                    raise Exception(f"从 ComfyUI /prompt 端点返回的响应无效: {result}")
                return result["prompt_id"]
        except Exception as e:
            logger.error(f"提交到 ComfyUI /prompt 端点时发生错误: {e}")
            raise
async def _get_execution_results(
    workflow_data: Dict,
    prompt_id: str,
    client_id: str,
    ws_url: str,
    workflow_run_id: str,
) -> dict:
    """Follow a prompt over the server's WebSocket, tracking node states in the DB.

    Only JSON text messages whose ``data.prompt_id`` equals ``prompt_id`` are
    handled:

    * ``executing`` with a node id: that node is marked "running". The node
      that was executing before it is marked "completed" unless it already was.
      Nodes without UI output presumably never receive an ``executed`` message
      and would otherwise stay "running".
    * ``executed``: the node is marked "completed" with its output. The
      outputs of nodes whose title starts with API_OUTPUT_PREFIX are collected
      under that full title.
    * ``executing`` with ``node`` None: the run is finished and the collected
      outputs are returned.
    * ``execution_error``: the failing node is marked "failed" and
      ComfyUIExecutionError is raised.

    Raises:
        ComfyUIExecutionError: the server reported an execution error.
        ConnectionError: the socket closed before the completion message, so
            the outcome of the run is unknown.
        websockets.exceptions.InvalidURI: the resulting URI is invalid.
    """
    full_ws_url = f"{ws_url}?clientId={client_id}"
    aggregated_outputs = {}
    # Node lookup by string id, built once; the first occurrence wins for duplicate ids.
    nodes_by_id = {}
    for workflow_node in workflow_data.get("nodes", []):
        nodes_by_id.setdefault(str(workflow_node["id"]), workflow_node)
    current_node = None  # node most recently reported as executing
    finished_nodes = set()  # nodes already marked "completed"

    try:
        async with websockets.connect(full_ws_url) as websocket:
            while True:
                try:
                    out = await websocket.recv()
                    if not isinstance(out, str):
                        continue  # binary frames carry no status messages
                    message = json.loads(out)
                    if not isinstance(message, dict):
                        continue
                    msg_type = message.get("type")
                    data = message.get("data")
                    if not isinstance(data, dict) or data.get("prompt_id") != prompt_id:
                        continue

                    if msg_type == "execution_error":
                        logger.error(
                            f"ComfyUI执行错误 (Prompt ID: {prompt_id}): {data}"
                        )
                        node_id = data.get("node_id")
                        if node_id:
                            await update_workflow_run_node_status(
                                workflow_run_id,
                                node_id,
                                "failed",
                                error_message=data.get(
                                    "exception_message", "Unknown error"
                                ),
                            )
                        # Propagate the error details to the caller.
                        raise ComfyUIExecutionError(data)

                    if msg_type == "executing":
                        executing_node = data.get("node")
                        # A new "executing" event means the previous node is done.
                        if current_node and current_node not in finished_nodes:
                            await update_workflow_run_node_status(
                                workflow_run_id, current_node, "completed"
                            )
                            finished_nodes.add(current_node)
                        current_node = executing_node
                        if executing_node:
                            finished_nodes.discard(executing_node)
                            logger.info(
                                f"节点 {executing_node} 开始执行 (Prompt ID: {prompt_id})"
                            )
                            await update_workflow_run_node_status(
                                workflow_run_id, executing_node, "running"
                            )
                        elif executing_node is None:
                            logger.info(f"Prompt ID: {prompt_id} 执行完成。")
                            return aggregated_outputs

                    elif msg_type == "executed":
                        node_id = data.get("node")
                        output_data = data.get("output")
                        if node_id and output_data:
                            node = nodes_by_id.get(node_id)
                            title = node.get("title") if node else None
                            if title and title.startswith(API_OUTPUT_PREFIX):
                                aggregated_outputs[title] = output_data
                            logger.info(
                                f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})"
                            )
                            await update_workflow_run_node_status(
                                workflow_run_id,
                                node_id,
                                "completed",
                                output_data=json.dumps(output_data),
                            )
                            finished_nodes.add(node_id)
                except websockets.exceptions.ConnectionClosed as e:
                    logger.warning(
                        f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}"
                    )
                    # Without the completion message the outcome is unknown;
                    # returning partial outputs would record the run as completed.
                    raise ConnectionError(
                        f"WebSocket 连接在 Prompt {prompt_id} 执行完成前关闭: {e}"
                    ) from e
                except Exception as e:
                    # Re-raise our own exception as-is; log anything unexpected.
                    if not isinstance(e, ComfyUIExecutionError):
                        logger.error(f"处理 prompt {prompt_id} 时发生意外错误: {e}")
                    raise
    except websockets.exceptions.InvalidURI as e:
        logger.error(
            f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
        )
        raise
async def _get_execution_results_legacy(
    prompt_id: str, client_id: str, ws_url: str
) -> dict:
    """Follow a prompt over the server's WebSocket, without database tracking.

    Only JSON text messages whose ``data.prompt_id`` equals ``prompt_id`` are
    handled. Outputs from ``executed`` messages are collected under their node
    id and returned once ``executing`` reports ``node`` None.

    Raises:
        ComfyUIExecutionError: the server reported an ``execution_error``.
        ConnectionError: the socket closed before completion was reported, so
            the outcome of the run is unknown.
        websockets.exceptions.InvalidURI: the resulting URI is invalid.
    """
    full_ws_url = f"{ws_url}?clientId={client_id}"
    aggregated_outputs = {}
    try:
        async with websockets.connect(full_ws_url) as websocket:
            while True:
                try:
                    out = await websocket.recv()
                    if not isinstance(out, str):
                        continue  # binary frames carry no status messages
                    message = json.loads(out)
                    if not isinstance(message, dict):
                        continue
                    msg_type = message.get("type")
                    data = message.get("data")
                    if not isinstance(data, dict) or data.get("prompt_id") != prompt_id:
                        continue

                    if msg_type == "execution_error":
                        logger.error(
                            f"ComfyUI执行错误 (Prompt ID: {prompt_id}): {data}"
                        )
                        # Propagate the error details to the caller.
                        raise ComfyUIExecutionError(data)

                    if msg_type == "executed":
                        node_id = data.get("node")
                        output_data = data.get("output")
                        if node_id and output_data:
                            aggregated_outputs[node_id] = output_data
                            logger.info(
                                f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})"
                            )
                    elif msg_type == "executing" and data.get("node") is None:
                        logger.info(f"Prompt ID: {prompt_id} 执行完成。")
                        return aggregated_outputs
                except websockets.exceptions.ConnectionClosed as e:
                    logger.warning(
                        f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}"
                    )
                    # Without the completion message the outcome is unknown;
                    # partial outputs must not be reported as a finished run.
                    raise ConnectionError(
                        f"WebSocket 连接在 Prompt {prompt_id} 执行完成前关闭: {e}"
                    ) from e
                except Exception as e:
                    # Re-raise our own exception as-is; log anything unexpected.
                    if not isinstance(e, ComfyUIExecutionError):
                        logger.error(f"处理 prompt {prompt_id} 时发生意外错误: {e}")
                    raise
    except websockets.exceptions.InvalidURI as e:
        logger.error(
            f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
        )
        raise
async def upload_image_to_comfy(file_path: str, server: ComfyUIServer) -> str:
"""
上传文件到服务器。
file 是一个http链接
返回一个名称
"""
file_name = file_path.split("/")[-1]
file_data = await download_file(file_path)
form_data = aiohttp.FormData()
form_data.add_field("image", file_data, filename=file_name, content_type="image/*")
form_data.add_field("type", "input")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{server.http_url}/api/upload/image", data=form_data
) as response:
response.raise_for_status()
result = await response.json()
return result["name"]
async def download_file(src: str) -> bytes:
    """Fetch *src* with an HTTP GET and return the response body.

    The content is held in memory and returned as bytes; nothing is written to
    disk.

    Raises:
        aiohttp.ClientResponseError: for non-2xx responses.
    """
    async with aiohttp.ClientSession() as http:
        async with http.get(src) as reply:
            reply.raise_for_status()
            content = await reply.read()
    return content