# ComfyUI-WorkflowPublisher/workflow_service/comfyui_client.py

import asyncio
import base64
import json
import logging
import os
import random
import uuid
import websockets
from collections import defaultdict
from datetime import datetime
from typing import Dict, Any, Optional, List, Set, Union
import aiohttp
from aiohttp import ClientTimeout
from workflow_service.config import Settings, ComfyUIServer
from workflow_service.database import (
    create_workflow_run,
    update_workflow_run_status,
    create_workflow_run_nodes,
    update_workflow_run_node_status,
    get_workflow_run,
    get_pending_workflow_runs,
    get_running_workflow_runs,
    get_workflow_run_nodes,
)

settings = Settings()
API_INPUT_PREFIX = "INPUT_"
API_OUTPUT_PREFIX = "OUTPUT_"

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Global task queue manager
class WorkflowQueueManager:
    def __init__(self):
        self.running_tasks = {}  # server_url -> task_info
        self.pending_tasks = []  # pending queue
        self.lock = asyncio.Lock()

    async def add_task(
        self,
        workflow_run_id: str,
        workflow_name: str,
        workflow_data: dict,
        api_spec: dict,
        request_data: dict,
    ):
        """Add a new task to the queue."""
        async with self.lock:
            # Create the run record
            await create_workflow_run(
                workflow_run_id=workflow_run_id,
                workflow_name=workflow_name,
                workflow_json=json.dumps(workflow_data),
                api_spec=json.dumps(api_spec),
                request_data=json.dumps(request_data),
            )
            # Create per-node records for the workflow
            nodes_data = []
            for node in workflow_data.get("nodes", []):
                nodes_data.append(
                    {"id": str(node["id"]), "type": node.get("type", "unknown")}
                )
            await create_workflow_run_nodes(workflow_run_id, nodes_data)
            # Add to the pending queue
            self.pending_tasks.append(workflow_run_id)
            logger.info(
                f"Task {workflow_run_id} added to queue; current queue length: {len(self.pending_tasks)}"
            )
        # Try to process the queue
        asyncio.create_task(self._process_queue())
        return workflow_run_id

    async def _process_queue(self):
        """Process tasks waiting in the queue."""
        async with self.lock:
            if not self.pending_tasks:
                return
            # Check whether any server is free
            available_servers = await self._get_available_servers()
            if not available_servers:
                logger.info("No server available, waiting...")
                return
            # Take one pending task
            workflow_run_id = self.pending_tasks.pop(0)
            server = available_servers[0]
            # Mark the task as running
            await update_workflow_run_status(
                workflow_run_id, "running", server.http_url
            )
            self.running_tasks[server.http_url] = {
                "workflow_run_id": workflow_run_id,
                "started_at": datetime.now(),
            }
            # Start task execution
            asyncio.create_task(self._execute_task(workflow_run_id, server))

    async def _get_available_servers(self) -> list[ComfyUIServer]:
        """Return servers that are reachable, idle, and not already running one of our tasks."""
        available_servers = []
        for server in settings.SERVERS:
            if server.http_url not in self.running_tasks:
                # Check the server status
                try:
                    async with aiohttp.ClientSession() as session:
                        status = await get_server_status(server, session)
                        if status["is_reachable"] and status["is_free"]:
                            available_servers.append(server)
                except Exception as e:
                    logger.warning(f"Error while checking status of server {server.http_url}: {e}")
        return available_servers

    async def _execute_task(self, workflow_run_id: str, server: ComfyUIServer):
        """Execute a queued task on the given server."""
        cleanup_paths = []
        try:
            # Load the workflow data
            workflow_run = await get_workflow_run(workflow_run_id)
            if not workflow_run:
                raise Exception(f"Workflow run record not found: {workflow_run_id}")
            workflow_data = json.loads(workflow_run["workflow_json"])
            api_spec = json.loads(workflow_run["api_spec"])
            request_data = json.loads(workflow_run["request_data"])
            # Execute the workflow
            result = await execute_prompt_on_server(
                workflow_data, api_spec, request_data, server, workflow_run_id
            )
            # Persist the processed result to the database
            await update_workflow_run_status(
                workflow_run_id,
                "completed",
                result=json.dumps(result, ensure_ascii=False),
            )
        except Exception as e:
            logger.error(f"Error while executing task {workflow_run_id}: {e}")
            await update_workflow_run_status(
                workflow_run_id, "failed", error_message=str(e)
            )
        finally:
            # Clean up temporary files
            if cleanup_paths:
                logger.info(f"Cleaning up {len(cleanup_paths)} temporary file(s)...")
                for path in cleanup_paths:
                    try:
                        if os.path.exists(path):
                            os.remove(path)
                            logger.info(f"  - deleted: {path}")
                    except Exception as e:
                        logger.warning(f"  - error deleting {path}: {e}")
            # Clear the running-state entry
            async with self.lock:
                if server.http_url in self.running_tasks:
                    del self.running_tasks[server.http_url]
            # Continue processing the queue
            asyncio.create_task(self._process_queue())

    async def get_task_status(self, workflow_run_id: str) -> dict:
        """Return the current status of a task."""
        workflow_run = await get_workflow_run(workflow_run_id)
        if not workflow_run:
            return {"error": "Task not found"}
        nodes = await get_workflow_run_nodes(workflow_run_id)
        result = {
            "id": workflow_run_id,
            "status": workflow_run["status"],
            "created_at": workflow_run["created_at"],
            "started_at": workflow_run["started_at"],
            "completed_at": workflow_run["completed_at"],
            "server_url": workflow_run["server_url"],
            "error_message": workflow_run["error_message"],
            "nodes": nodes,
        }
        # If the task has completed, load its result from the database
        if workflow_run["status"] == "completed" and workflow_run.get("result"):
            try:
                result["result"] = json.loads(workflow_run["result"])
            except (json.JSONDecodeError, TypeError):
                result["result"] = None
        return result


# Global queue manager instance
queue_manager = WorkflowQueueManager()


# Custom exception that wraps execution errors reported by ComfyUI
class ComfyUIExecutionError(Exception):
    def __init__(self, error_data: dict):
        self.error_data = error_data
        # Build a developer-friendly exception message
        message = (
            f"ComfyUI node execution failed. Node ID: {error_data.get('node_id')}, "
            f"node type: {error_data.get('node_type')}. "
            f"Error: {error_data.get('exception_message', 'N/A')}"
        )
        super().__init__(message)


async def get_server_status(
    server: ComfyUIServer, session: aiohttp.ClientSession
) -> Dict[str, Any]:
    """
    Check the detailed status of a single ComfyUI server.
    Returns a dict describing reachability, whether it is free, and queue details.
    """
    # [Bugfix] Keep the initial dict structure identical to the success case so it satisfies the Pydantic model
    status_info = {
        "is_reachable": False,
        "is_free": False,
        "queue_details": {"running_count": 0, "pending_count": 0},
    }
    try:
        queue_url = f"{server.http_url}/queue"
        async with session.get(queue_url, timeout=60) as response:
            response.raise_for_status()
            queue_data = await response.json()
            status_info["is_reachable"] = True
            running_count = len(queue_data.get("queue_running", []))
            pending_count = len(queue_data.get("queue_pending", []))
            status_info["queue_details"] = {
                "running_count": running_count,
                "pending_count": pending_count,
            }
            status_info["is_free"] = running_count == 0 and pending_count == 0
    except Exception as e:
        # On failure this returns the correctly structured initial status_info defined above
        logger.warning(f"Could not check queue status of server {server.http_url}: {e}")
    return status_info

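
# Illustrative example of the dict returned by get_server_status for a
# reachable, idle server (the values are hypothetical, not taken from a real run):
#   {"is_reachable": True, "is_free": True,
#    "queue_details": {"running_count": 0, "pending_count": 0}}

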
async def select_server_for_execution() -> ComfyUIServer:
    """
    Pick a ComfyUI server intelligently.
    Prefer an idle server; if every server is busy, pick one at random.
    """
    servers = settings.SERVERS
    if not servers:
        raise ValueError("No servers configured in COMFYUI_SERVERS_JSON.")
    if len(servers) == 1:
        return servers[0]
    async with aiohttp.ClientSession() as session:
        tasks = [get_server_status(server, session) for server in servers]
        results = await asyncio.gather(*tasks)
    free_servers = [servers[i] for i, status in enumerate(results) if status["is_free"]]
    if free_servers:
        selected_server = random.choice(free_servers)
        logger.info(
            f"Found {len(free_servers)} idle server(s). Selected: {selected_server.http_url}"
        )
        return selected_server
    else:
        # Fallback: pick a reachable server even if it is busy
        reachable_servers = [
            servers[i] for i, status in enumerate(results) if status["is_reachable"]
        ]
        if reachable_servers:
            selected_server = random.choice(reachable_servers)
            logger.info(
                f"All servers are currently busy. Picked a reachable one at random: {selected_server.http_url}"
            )
            return selected_server
        else:
            # Worst case: no server is reachable, so raise
            raise ConnectionError("None of the configured ComfyUI servers are reachable.")


async def execute_prompt_on_server(
    workflow_data: Dict,
    api_spec: Dict,
    request_data: Dict,
    server: ComfyUIServer,
    workflow_run_id: str,
) -> Dict:
    """
    Execute a prepared prompt on the given server.
    Supports node-level status tracking.
    """
    client_id = str(uuid.uuid4())
    # Apply the request data to the workflow
    patched_workflow = await patch_workflow(
        workflow_data, api_spec, request_data, server
    )
    # Convert to the /prompt format
    prompt = convert_workflow_to_prompt_api_format(patched_workflow)
    # Record the client_id on the workflow run (the prompt_id comes later)
    await update_workflow_run_status(
        workflow_run_id,
        "running",
        server.http_url,
        None,  # prompt_id is set after _queue_prompt
        client_id,
    )
    # Submit to ComfyUI
    prompt_id = await _queue_prompt(workflow_data, prompt, client_id, server.http_url)
    # Record the prompt_id
    await update_workflow_run_status(
        workflow_run_id, "running", server.http_url, prompt_id, client_id
    )
    logger.info(
        f"Workflow {workflow_run_id} queued on {server.http_url}, prompt ID: {prompt_id}"
    )
    # Collect the execution results, with node-level status tracking
    results = await _get_execution_results(
        workflow_data, prompt_id, client_id, server.ws_url, workflow_run_id
    )
    return results


async def submit_workflow_to_queue(
    workflow_name: str, workflow_data: Dict, api_spec: Dict, request_data: Dict
) -> str:
    """
    Submit a workflow to the queue and return a task ID immediately.
    This is the new asynchronous interface; callers can poll the status via the task ID.
    """
    workflow_run_id = str(uuid.uuid4())
    # Hand the task over to the queue manager
    await queue_manager.add_task(
        workflow_run_id, workflow_name, workflow_data, api_spec, request_data
    )
    return workflow_run_id

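
# Illustrative sketch (not called anywhere in this module): how a caller might
# submit a workflow and poll until it finishes. The function name and the
# 2-second poll interval are arbitrary choices made for this example.
async def _example_submit_and_poll(
    workflow_name: str, workflow_data: Dict, api_spec: Dict, request_data: Dict
) -> dict:
    run_id = await submit_workflow_to_queue(
        workflow_name, workflow_data, api_spec, request_data
    )
    while True:
        status = await queue_manager.get_task_status(run_id)
        if status.get("status") in ("completed", "failed"):
            return status
        await asyncio.sleep(2)

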
def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]:
    """
    Parse the workflow and generate API parameter names following the
    convention '{base_name}_{attribute_name}_{optional_count}'.
    """
    spec = {"inputs": {}, "outputs": {}}
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError(
            "Invalid workflow format: 'nodes' key not found or is not a list."
        )
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    input_name_counter = defaultdict(int)
    output_name_counter = defaultdict(int)
    for node_id, node in nodes_map.items():
        title: str = node.get("title")
        if not title:
            continue
        if title.startswith(API_INPUT_PREFIX):
            base_name = title[len(API_INPUT_PREFIX) :].lower()
            if "inputs" in node:
                for a_input in node.get("inputs", []):
                    if a_input.get("link") is None and "widget" in a_input:
                        widget_name = a_input["widget"]["name"].lower()
                        # [Bugfix] Build a meaningful base parameter name
                        param_name_candidate = f"{base_name}_{widget_name}"
                        input_name_counter[param_name_candidate] += 1
                        count = input_name_counter[param_name_candidate]
                        final_param_name = (
                            f"{param_name_candidate}_{count}"
                            if count > 1
                            else param_name_candidate
                        )
                        input_type_str = a_input.get("type", "STRING").upper()
                        param_type = "string"
                        if "COMBO" in input_type_str:
                            param_type = "UploadFile"
                        elif "INT" in input_type_str:
                            param_type = "int"
                        elif "FLOAT" in input_type_str:
                            param_type = "float"
                        spec["inputs"][final_param_name] = {
                            "node_id": node_id,
                            "type": param_type,
                            "widget_name": a_input["widget"]["name"],
                        }
        elif title.startswith(API_OUTPUT_PREFIX):
            base_name = title[len(API_OUTPUT_PREFIX) :].lower()
            if "outputs" in node:
                for an_output in node.get("outputs", []):
                    output_name = an_output["name"].lower()
                    # [Bugfix] Build a meaningful base parameter name
                    param_name_candidate = f"{base_name}_{output_name}"
                    output_name_counter[param_name_candidate] += 1
                    count = output_name_counter[param_name_candidate]
                    final_param_name = (
                        f"{param_name_candidate}_{count}"
                        if count > 1
                        else param_name_candidate
                    )
                    spec["outputs"][final_param_name] = {
                        "node_id": node_id,
                        "class_type": node.get("type"),
                        "output_name": an_output["name"],
                        "output_index": node["outputs"].index(an_output),
                    }
    return spec

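
# Illustrative sketch of the naming convention implemented above, using a
# hypothetical minimal workflow: a node titled "INPUT_prompt" with a "text"
# widget yields the API parameter "prompt_text", and a node titled
# "OUTPUT_result" with an "IMAGE" output yields "result_image". A second node
# producing the same candidate name would be suffixed "_2".
def _example_parse_api_spec() -> Dict[str, Dict[str, Any]]:
    minimal_workflow = {
        "nodes": [
            {
                "id": 1,
                "type": "CLIPTextEncode",
                "title": "INPUT_prompt",
                "inputs": [
                    {"name": "text", "type": "STRING", "link": None,
                     "widget": {"name": "text"}}
                ],
            },
            {
                "id": 9,
                "type": "VAEDecode",
                "title": "OUTPUT_result",
                "outputs": [{"name": "IMAGE"}],
            },
        ]
    }
    # Expected shape: {"inputs": {"prompt_text": {...}}, "outputs": {"result_image": {...}}}
    return parse_api_spec(minimal_workflow)

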
async def patch_workflow(
    workflow_data: dict,
    api_spec: dict[str, dict[str, Any]],
    request_data: dict[str, Any],
    server: ComfyUIServer,
) -> dict:
    """
    Patch the parameter values from request_data into workflow_data and return
    the modified workflow_data.
    """
    if "nodes" not in workflow_data:
        raise ValueError("Invalid workflow format")
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    for param_name, value in request_data.items():
        if param_name not in api_spec["inputs"]:
            continue
        spec = api_spec["inputs"][param_name]
        node_id = spec["node_id"]
        if node_id not in nodes_map:
            continue
        target_node = nodes_map[node_id]
        widget_name_to_patch = spec["widget_name"]
        # Collect all widget-backed inputs and locate the index of the target widget
        widget_inputs = [
            inp for inp in target_node.get("inputs", []) if "widget" in inp
        ]
        target_widget_index = -1
        for i, widget_input in enumerate(widget_inputs):
            # The "widget" dict may be missing or lack a "name" key
            if widget_input.get("widget", {}).get("name") == widget_name_to_patch:
                target_widget_index = i
                break
        if target_widget_index == -1:
            logger.warning(
                f"Widget '{widget_name_to_patch}' not found in node {node_id}. Skipping this parameter."
            )
            continue
        # Make sure `widgets_values` exists and is a list
        if "widgets_values" not in target_node or not isinstance(
            target_node.get("widgets_values"), list
        ):
            # If missing or malformed, create a placeholder list sized to the number of widgets
            target_node["widgets_values"] = [None] * len(widget_inputs)
        # Make sure the `widgets_values` list is long enough
        while len(target_node["widgets_values"]) <= target_widget_index:
            target_node["widgets_values"].append(None)
        # Convert the value according to the API spec
        target_type = str
        if spec["type"] == "int":
            target_type = int
        elif spec["type"] == "float":
            target_type = float
        # Update the value at the correct position
        try:
            if target_node.get("type") == "LoadImage":
                value = await upload_image_to_comfy(value, server)
            target_node["widgets_values"][target_widget_index] = target_type(value)
        except (ValueError, TypeError) as e:
            logger.warning(
                f"Could not convert value '{value}' of parameter '{param_name}' to type '{spec['type']}'. Error: {e}"
            )
            continue
    workflow_data["nodes"] = list(nodes_map.values())
    return workflow_data

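
# Illustrative example (hypothetical values): given an api_spec entry
#   {"prompt_text": {"node_id": "1", "type": "string", "widget_name": "text"}}
# and request_data {"prompt_text": "a cat"}, patch_workflow writes "a cat" into
# the widgets_values slot of node 1 that backs its "text" widget.

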
def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
    """
    Convert the workflow (UI format) into the format submitted to the /prompt endpoint.
    This correctly handles `widgets_values` that were modified by `patch_workflow`.
    """
    if "nodes" not in workflow_data:
        raise ValueError("Invalid workflow format")
    prompt_api_format = {}
    # Build a map from link_id to the source node
    link_map = {}
    for link_data in workflow_data.get("links", []):
        (
            link_id,
            origin_node_id,
            origin_slot_index,
            target_node_id,
            target_slot_index,
            link_type,
        ) = link_data
        # The key is the link_id referenced by the target node's input
        link_map[link_id] = [str(origin_node_id), origin_slot_index]
    for node in workflow_data["nodes"]:
        node_id = str(node["id"])
        inputs_dict = {}
        # 1. Handle widget inputs
        widgets_values = node.get("widgets_values", [])
        widget_cursor = 0
        for input_config in node.get("inputs", []):
            # If this is a widget input and there is a corresponding widgets_values entry
            if "widget" in input_config and widget_cursor < len(widgets_values):
                widget_name = input_config["widget"].get("name")
                if widget_name:
                    # Use the value from widgets_values, which already contains the API-supplied changes
                    inputs_dict[widget_name] = widgets_values[widget_cursor]
                widget_cursor += 1
        # 2. Handle linked inputs
        for input_config in node.get("inputs", []):
            if "link" in input_config and input_config["link"] is not None:
                link_id = input_config["link"]
                if link_id in link_map:
                    # The input name is the 'name' field of input_config
                    inputs_dict[input_config["name"]] = link_map[link_id]
        prompt_api_format[node_id] = {"class_type": node["type"], "inputs": inputs_dict}
    return prompt_api_format

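
# Illustrative sketch of the /prompt structure produced above, using a
# hypothetical two-node graph: widget inputs become literal values, while
# linked inputs become [source_node_id, source_output_index] pairs.
def _example_prompt_api_format() -> dict:
    ui_workflow = {
        "links": [
            # [link_id, origin_node_id, origin_slot, target_node_id, target_slot, type]
            [5, 1, 0, 2, 0, "CONDITIONING"],
        ],
        "nodes": [
            {
                "id": 1,
                "type": "CLIPTextEncode",
                "inputs": [{"name": "text", "link": None, "widget": {"name": "text"}}],
                "widgets_values": ["a cat"],
            },
            {
                "id": 2,
                "type": "KSampler",
                "inputs": [
                    {"name": "positive", "link": 5},
                    {"name": "seed", "link": None, "widget": {"name": "seed"}},
                ],
                "widgets_values": [42],
            },
        ],
    }
    # Expected: {"1": {"class_type": "CLIPTextEncode", "inputs": {"text": "a cat"}},
    #            "2": {"class_type": "KSampler", "inputs": {"seed": 42, "positive": ["1", 0]}}}
    return convert_workflow_to_prompt_api_format(ui_workflow)

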
async def _queue_prompt(
    workflow: dict, prompt: dict, client_id: str, http_url: str
) -> str:
    """Submit the workflow task to the given ComfyUI server via HTTP POST."""
    for node_id in prompt:
        prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random()
    payload = {
        "prompt": prompt,
        "client_id": client_id,
        "extra_data": {
            "api_key_comfy_org": "",
            "extra_pnginfo": {"workflow": workflow},
        },
    }
    logger.info(f"Payload submitted to the ComfyUI /prompt endpoint: {json.dumps(payload)}")
    async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
        prompt_url = f"{http_url}/prompt"
        try:
            async with session.post(prompt_url, json=payload) as response:
                logger.info(f"Raw response from the ComfyUI /prompt endpoint: {response}")
                response.raise_for_status()
                result = await response.json()
                logger.info(f"Response JSON from the ComfyUI /prompt endpoint: {result}")
                if "prompt_id" not in result:
                    raise Exception(f"Invalid response from the ComfyUI /prompt endpoint: {result}")
                return result["prompt_id"]
        except Exception as e:
            logger.error(f"Error while submitting to the ComfyUI /prompt endpoint: {e}")
            raise e


async def _get_execution_results(
    workflow_data: Dict,
    prompt_id: str,
    client_id: str,
    ws_url: str,
    workflow_run_id: str,
) -> dict:
    """
    Connect to the given ComfyUI server over WebSocket and aggregate the execution results.
    Supports node-level status tracking.
    """
    full_ws_url = f"{ws_url}?clientId={client_id}"
    aggregated_outputs = {}
    try:
        async with websockets.connect(full_ws_url) as websocket:
            while True:
                try:
                    out = await websocket.recv()
                    if not isinstance(out, str):
                        continue
                    message = json.loads(out)
                    msg_type = message.get("type")
                    data = message.get("data")
                    if not (data and data.get("prompt_id") == prompt_id):
                        continue
                    # Catch and handle execution errors
                    if msg_type == "execution_error":
                        error_data = data
                        logger.error(
                            f"ComfyUI execution error (prompt ID: {prompt_id}): {error_data}"
                        )
                        # Mark the node as failed
                        node_id = error_data.get("node_id")
                        if node_id:
                            await update_workflow_run_node_status(
                                workflow_run_id,
                                node_id,
                                "failed",
                                error_message=error_data.get(
                                    "exception_message", "Unknown error"
                                ),
                            )
                        # Raise a custom exception carrying the error details
                        raise ComfyUIExecutionError(error_data)
                    # A node has started executing
                    if msg_type == "executing" and data.get("node"):
                        node_id = data.get("node")
                        logger.info(f"Node {node_id} started executing (prompt ID: {prompt_id})")
                        # Mark the node as running
                        await update_workflow_run_node_status(
                            workflow_run_id, node_id, "running"
                        )
                    # A node has finished executing
                    if msg_type == "executed":
                        node_id = data.get("node")
                        output_data = data.get("output")
                        if node_id and output_data:
                            node = next(
                                (
                                    x
                                    for x in workflow_data["nodes"]
                                    if str(x["id"]) == node_id
                                ),
                                None,
                            )
                            if (
                                node
                                and node.get("title", "")
                                and node["title"].startswith(API_OUTPUT_PREFIX)
                            ):
                                aggregated_outputs[node["title"]] = output_data
                                logger.info(
                                    f"Received output of node {node_id} (prompt ID: {prompt_id})"
                                )
                            # Mark the node as completed
                            await update_workflow_run_node_status(
                                workflow_run_id,
                                node_id,
                                "completed",
                                output_data=json.dumps(output_data),
                            )
                    # The whole workflow has finished executing
                    elif msg_type == "executing" and data.get("node") is None:
                        logger.info(f"Prompt ID {prompt_id} finished executing.")
                        return aggregated_outputs
                except websockets.exceptions.ConnectionClosed as e:
                    logger.warning(
                        f"WebSocket connection closed (prompt ID: {prompt_id}). Error: {e}"
                    )
                    return aggregated_outputs
                except Exception as e:
                    # Re-raise our own exception, or surface other unexpected errors
                    if not isinstance(e, ComfyUIExecutionError):
                        logger.error(f"Unexpected error while handling prompt {prompt_id}: {e}")
                    raise e
    except websockets.exceptions.InvalidURI as e:
        logger.error(
            f"Error: the WebSocket URI to connect to is invalid: '{full_ws_url}'. Original URL: '{ws_url}'. Error: {e}"
        )
        raise e
    return aggregated_outputs


async def _get_execution_results_legacy(
    prompt_id: str, client_id: str, ws_url: str
) -> dict:
    """
    Simplified variant of the result-collection function, without database status tracking.
    """
    full_ws_url = f"{ws_url}?clientId={client_id}"
    aggregated_outputs = {}
    try:
        async with websockets.connect(full_ws_url) as websocket:
            while True:
                try:
                    out = await websocket.recv()
                    if not isinstance(out, str):
                        continue
                    message = json.loads(out)
                    msg_type = message.get("type")
                    data = message.get("data")
                    if not (data and data.get("prompt_id") == prompt_id):
                        continue
                    # Catch and handle execution errors
                    if msg_type == "execution_error":
                        error_data = data
                        logger.error(
                            f"ComfyUI execution error (prompt ID: {prompt_id}): {error_data}"
                        )
                        # Raise a custom exception carrying the error details
                        raise ComfyUIExecutionError(error_data)
                    # A node has finished executing
                    if msg_type == "executed":
                        node_id = data.get("node")
                        output_data = data.get("output")
                        if node_id and output_data:
                            aggregated_outputs[node_id] = output_data
                            logger.info(
                                f"Received output of node {node_id} (prompt ID: {prompt_id})"
                            )
                    # The whole workflow has finished executing
                    elif msg_type == "executing" and data.get("node") is None:
                        logger.info(f"Prompt ID {prompt_id} finished executing.")
                        return aggregated_outputs
                except websockets.exceptions.ConnectionClosed as e:
                    logger.warning(
                        f"WebSocket connection closed (prompt ID: {prompt_id}). Error: {e}"
                    )
                    return aggregated_outputs
                except Exception as e:
                    # Re-raise our own exception, or surface other unexpected errors
                    if not isinstance(e, ComfyUIExecutionError):
                        logger.error(f"Unexpected error while handling prompt {prompt_id}: {e}")
                    raise e
    except websockets.exceptions.InvalidURI as e:
        logger.error(
            f"Error: the WebSocket URI to connect to is invalid: '{full_ws_url}'. Original URL: '{ws_url}'. Error: {e}"
        )
        raise e
    return aggregated_outputs


async def upload_image_to_comfy(file_path: str, server: ComfyUIServer) -> str:
    """
    Upload a file to the server.
    `file_path` is an HTTP URL; returns the name under which the file was stored.
    """
    file_name = file_path.split("/")[-1]
    file_data = await download_file(file_path)
    form_data = aiohttp.FormData()
    form_data.add_field("image", file_data, filename=file_name, content_type="image/*")
    form_data.add_field("type", "input")
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{server.http_url}/api/upload/image", data=form_data
        ) as response:
            response.raise_for_status()
            result = await response.json()
            return result["name"]


async def download_file(src: str) -> bytes:
    """
    Download a file and return its raw bytes.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(src) as response:
            response.raise_for_status()
            return await response.read()
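

# Minimal manual smoke test (illustrative only; assumes COMFYUI_SERVERS_JSON is
# configured). It only checks server reachability and does not queue a prompt.
if __name__ == "__main__":
    async def _check_servers() -> None:
        async with aiohttp.ClientSession() as session:
            for server in settings.SERVERS:
                status = await get_server_status(server, session)
                print(server.http_url, status)

    asyncio.run(_check_servers())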