966 lines
35 KiB
Python
966 lines
35 KiB
Python
import asyncio
|
||
import base64
|
||
import json
|
||
import logging
|
||
import os
|
||
import random
|
||
import uuid
|
||
import websockets
|
||
from collections import defaultdict
|
||
from datetime import datetime
|
||
from typing import Dict, Any, Optional, List, Set, Union
|
||
|
||
import aiohttp
|
||
from aiohttp import ClientTimeout
|
||
|
||
from workflow_service.config import Settings, ComfyUIServer
|
||
from workflow_service.database import (
|
||
create_workflow_run,
|
||
update_workflow_run_status,
|
||
create_workflow_run_nodes,
|
||
update_workflow_run_node_status,
|
||
get_workflow_run,
|
||
get_pending_workflow_runs,
|
||
get_running_workflow_runs,
|
||
get_workflow_run_nodes,
|
||
)
|
||
|
||
# Service settings, loaded once at import time (provides SERVERS, the
# configured ComfyUI servers).
settings = Settings()

# Node-title prefixes that expose a workflow node as an API input / output.
API_INPUT_PREFIX, API_OUTPUT_PREFIX = "INPUT_", "OUTPUT_"

# Configure logging for this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# Global task-queue manager: runs queued workflows on idle ComfyUI servers.
class WorkflowQueueManager:
    """FIFO queue that dispatches workflow runs to idle ComfyUI servers.

    At most one run from this manager executes on a given server at a time.
    Run and node records are persisted through the database helpers; in memory
    the manager only tracks which servers are busy with its runs
    (``running_tasks``) and which run ids are waiting (``pending_tasks``).

    Args:
        retry_interval: Seconds to wait before checking the servers again when
            runs are pending but no server is available. Without it, a run
            queued while every server is busy with external work waits until
            another run is added or finishes. ``None`` disables the timed retry.
    """

    def __init__(self, retry_interval: Optional[float] = 5.0):
        self.running_tasks = {}  # server http_url -> {"workflow_run_id", "started_at"}
        self.pending_tasks = []  # workflow_run_ids waiting for a server (FIFO)
        self.lock = asyncio.Lock()
        self.retry_interval = retry_interval
        self._retry_scheduled = False
        # The event loop keeps only weak references to tasks, so hold strong
        # references to fire-and-forget tasks until they are done.
        self._background_tasks = set()

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* as a background task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _schedule_retry(self) -> None:
        """Arrange at most one delayed re-run of ``_process_queue``."""
        if self.retry_interval is None or self._retry_scheduled:
            return
        self._retry_scheduled = True
        self._spawn(self._retry_after_delay())

    async def _retry_after_delay(self) -> None:
        """Sleep ``retry_interval`` seconds, then process the queue again."""
        try:
            await asyncio.sleep(self.retry_interval)
        finally:
            self._retry_scheduled = False
        await self._process_queue()

    async def add_task(
        self,
        workflow_run_id: str,
        workflow_name: str,
        workflow_data: dict,
        api_spec: dict,
        request_data: dict,
    ):
        """Persist a new run and its node rows, enqueue it, and trigger dispatch.

        Returns:
            The ``workflow_run_id`` that was passed in.
        """
        async with self.lock:
            # Create the run record.
            await create_workflow_run(
                workflow_run_id=workflow_run_id,
                workflow_name=workflow_name,
                workflow_json=json.dumps(workflow_data),
                api_spec=json.dumps(api_spec),
                request_data=json.dumps(request_data),
            )

            # Create one node record per workflow node.
            nodes_data = [
                {"id": str(node["id"]), "type": node.get("type", "unknown")}
                for node in workflow_data.get("nodes", [])
            ]
            await create_workflow_run_nodes(workflow_run_id, nodes_data)

            self.pending_tasks.append(workflow_run_id)
            logger.info(
                f"任务 {workflow_run_id} 已添加到队列,当前队列长度: {len(self.pending_tasks)}"
            )

        self._spawn(self._process_queue())
        return workflow_run_id

    async def _process_queue(self):
        """Assign pending runs to free servers, one run per free server."""
        async with self.lock:
            if not self.pending_tasks:
                return

            available_servers = await self._get_available_servers()
            if not available_servers:
                logger.info("没有可用的服务器,等待中...")
                self._schedule_retry()
                return

            for server in available_servers:
                if not self.pending_tasks:
                    break
                if server.http_url in self.running_tasks:
                    # Guards against the same URL being configured twice.
                    continue

                workflow_run_id = self.pending_tasks.pop(0)
                try:
                    await update_workflow_run_status(
                        workflow_run_id, "running", server.http_url
                    )
                except Exception:
                    # Keep the run queued instead of losing it.
                    logger.exception(f"无法将任务 {workflow_run_id} 标记为运行中,稍后重试")
                    self.pending_tasks.insert(0, workflow_run_id)
                    self._schedule_retry()
                    break

                self.running_tasks[server.http_url] = {
                    "workflow_run_id": workflow_run_id,
                    "started_at": datetime.now(),
                }
                self._spawn(self._execute_task(workflow_run_id, server))

    async def _get_available_servers(self) -> list[ComfyUIServer]:
        """Return configured servers that are not running one of our runs and report an empty queue.

        Status checks for all candidate servers run concurrently over one session.
        """
        candidates = [
            server
            for server in settings.SERVERS
            if server.http_url not in self.running_tasks
        ]
        if not candidates:
            return []

        async with aiohttp.ClientSession() as session:
            statuses = await asyncio.gather(
                *(get_server_status(server, session) for server in candidates),
                return_exceptions=True,
            )

        available_servers = []
        for server, status in zip(candidates, statuses):
            if isinstance(status, BaseException):
                logger.warning(f"检查服务器 {server.http_url} 状态时出错: {status}")
            elif status["is_reachable"] and status["is_free"]:
                available_servers.append(server)
        return available_servers

    async def _execute_task(self, workflow_run_id: str, server: ComfyUIServer):
        """Run one workflow on *server*, store the outcome, then free the server.

        On success the run is marked ``completed`` with the JSON-encoded
        result; on any exception it is marked ``failed`` with the message.
        The server slot is always released and the queue re-processed.
        """
        try:
            workflow_run = await get_workflow_run(workflow_run_id)
            if not workflow_run:
                raise Exception(f"找不到工作流运行记录: {workflow_run_id}")

            workflow_data = json.loads(workflow_run["workflow_json"])
            api_spec = json.loads(workflow_run["api_spec"])
            request_data = json.loads(workflow_run["request_data"])

            result = await execute_prompt_on_server(
                workflow_data, api_spec, request_data, server, workflow_run_id
            )

            await update_workflow_run_status(
                workflow_run_id,
                "completed",
                result=json.dumps(result, ensure_ascii=False),
            )
        except Exception as e:
            logger.exception(f"执行任务 {workflow_run_id} 时出错: {e}")
            try:
                await update_workflow_run_status(
                    workflow_run_id, "failed", error_message=str(e)
                )
            except Exception:
                # Never let a bookkeeping failure skip the cleanup below.
                logger.exception(f"无法将任务 {workflow_run_id} 标记为失败")
        finally:
            async with self.lock:
                self.running_tasks.pop(server.http_url, None)
            self._spawn(self._process_queue())

    async def get_task_status(self, workflow_run_id: str) -> dict:
        """Return the stored fields and node rows of a run.

        For completed runs, the decoded ``result`` is also included (``None``
        if it cannot be decoded). An unknown id yields ``{"error": ...}``.
        """
        workflow_run = await get_workflow_run(workflow_run_id)
        if not workflow_run:
            return {"error": "任务不存在"}
        nodes = await get_workflow_run_nodes(workflow_run_id)

        result = {
            "id": workflow_run_id,
            "status": workflow_run["status"],
            "created_at": workflow_run["created_at"],
            "started_at": workflow_run["started_at"],
            "completed_at": workflow_run["completed_at"],
            "server_url": workflow_run["server_url"],
            "error_message": workflow_run["error_message"],
            "nodes": nodes,
        }

        if workflow_run["status"] == "completed" and workflow_run.get("result"):
            try:
                result["result"] = json.loads(workflow_run["result"])
            except (json.JSONDecodeError, TypeError):
                result["result"] = None

        return result


# Process-wide queue manager instance.
queue_manager = WorkflowQueueManager()
|
||
|
||
|
||
# Custom exception wrapping an execution error reported by ComfyUI.
class ComfyUIExecutionError(Exception):
    """Raised when ComfyUI reports an ``execution_error`` for a prompt.

    The raw error payload is kept on ``error_data``. The exception message
    names the failing node id, node type and exception message (``N/A`` when
    the payload has no ``exception_message``).
    """

    def __init__(self, error_data: dict):
        self.error_data = error_data
        failed_node = error_data.get("node_id")
        failed_type = error_data.get("node_type")
        reason = error_data.get("exception_message", "N/A")
        super().__init__(
            f"ComfyUI节点执行失败。节点ID: {failed_node}, 节点类型: {failed_type}. 错误: {reason}"
        )
|
||
|
||
|
||
async def get_server_status(
    server: ComfyUIServer,
    session: aiohttp.ClientSession,
    *,
    timeout: float = 60,
) -> Dict[str, Any]:
    """Query a ComfyUI server's ``/queue`` endpoint and summarise its load.

    Args:
        server: Server whose ``http_url`` is queried.
        session: aiohttp session used for the request.
        timeout: Total request timeout in seconds (keyword-only; defaults to
            the previously hard-coded 60).

    Returns:
        ``{"is_reachable": bool, "is_free": bool,
        "queue_details": {"running_count": int, "pending_count": int}}``.
        ``is_free`` is true only when nothing is running or pending. On any
        error the failure is logged (not raised) and the server is reported
        as unreachable and not free, with zero counts.
    """
    # Keep the failure result shaped exactly like a success result so that
    # response models built from it validate either way.
    status_info: Dict[str, Any] = {
        "is_reachable": False,
        "is_free": False,
        "queue_details": {"running_count": 0, "pending_count": 0},
    }
    try:
        queue_url = f"{server.http_url}/queue"
        async with session.get(
            queue_url, timeout=ClientTimeout(total=timeout)
        ) as response:
            response.raise_for_status()
            queue_data = await response.json()

        status_info["is_reachable"] = True

        running_count = len(queue_data.get("queue_running", []))
        pending_count = len(queue_data.get("queue_pending", []))
        status_info["queue_details"] = {
            "running_count": running_count,
            "pending_count": pending_count,
        }
        status_info["is_free"] = running_count == 0 and pending_count == 0
    except Exception as e:
        logger.warning(f"无法检查服务器 {server.http_url} 的队列状态: {e}")

    return status_info
|
||
|
||
|
||
async def select_server_for_execution() -> ComfyUIServer:
    """Pick a ComfyUI server to run a workflow on.

    With a single configured server it is returned without any status check.
    Otherwise all servers are queried concurrently. A random idle server is
    preferred, then a random reachable (busy) one.

    Raises:
        ValueError: No servers are configured.
        ConnectionError: None of the configured servers is reachable.
    """
    candidates = settings.SERVERS
    if not candidates:
        raise ValueError("没有在 COMFYUI_SERVERS_JSON 中配置任何服务器。")
    if len(candidates) == 1:
        return candidates[0]

    async with aiohttp.ClientSession() as session:
        statuses = await asyncio.gather(
            *(get_server_status(candidate, session) for candidate in candidates)
        )

    paired = list(zip(candidates, statuses))

    idle = [candidate for candidate, status in paired if status["is_free"]]
    if idle:
        chosen = random.choice(idle)
        logger.info(f"发现 {len(idle)} 个空闲服务器。已选择: {chosen.http_url}")
        return chosen

    # Fallback: any reachable server, even if it is busy.
    reachable = [candidate for candidate, status in paired if status["is_reachable"]]
    if not reachable:
        raise ConnectionError("所有配置的ComfyUI服务器都不可达。")

    chosen = random.choice(reachable)
    logger.info(f"所有服务器当前都在忙。从可达服务器中随机选择: {chosen.http_url}")
    return chosen
|
||
|
||
|
||
async def execute_prompt_on_server(
    workflow_data: Dict,
    api_spec: Dict,
    request_data: Dict,
    server: ComfyUIServer,
    workflow_run_id: str,
) -> Dict:
    """Patch, submit and follow one workflow run on a specific ComfyUI server.

    Steps:
    1. Apply the request values to the workflow.
    2. Convert it to the /prompt API format.
    3. Record the WebSocket client id on the run, queue the prompt, then
       record the returned prompt id.
    4. Wait for per-node progress and outputs over the WebSocket.

    Returns:
        The aggregated outputs produced by ``_get_execution_results``.
    """
    ws_client_id = str(uuid.uuid4())
    http_url = server.http_url

    # patch_workflow mutates workflow_data in place and returns it.
    patched = await patch_workflow(workflow_data, api_spec, request_data, server)
    prompt_payload = convert_workflow_to_prompt_api_format(patched)

    # The prompt id is unknown until the prompt is queued; store None for now.
    await update_workflow_run_status(
        workflow_run_id, "running", http_url, None, ws_client_id
    )

    prompt_id = await _queue_prompt(workflow_data, prompt_payload, ws_client_id, http_url)

    await update_workflow_run_status(
        workflow_run_id, "running", http_url, prompt_id, ws_client_id
    )
    logger.info(f"工作流 {workflow_run_id} 已在 {http_url} 上入队,Prompt ID: {prompt_id}")

    return await _get_execution_results(
        workflow_data, prompt_id, ws_client_id, server.ws_url, workflow_run_id
    )
|
||
|
||
|
||
async def submit_workflow_to_queue(
    workflow_name: str, workflow_data: Dict, api_spec: Dict, request_data: Dict
) -> str:
    """Enqueue a workflow run and return its id immediately.

    The run is persisted and executed in the background by ``queue_manager``.
    Callers poll its status using the returned id.
    """
    run_id = str(uuid.uuid4())
    await queue_manager.add_task(
        run_id, workflow_name, workflow_data, api_spec, request_data
    )
    return run_id
|
||
|
||
|
||
def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]:
    """Derive the API parameter spec from node titles in a UI-format workflow.

    Nodes titled ``API_INPUT_PREFIX + name`` / ``API_OUTPUT_PREFIX + name``
    become entries keyed by the lower-cased ``name``. The result has the shape::

        {
          "inputs": {
            name: {"node_id": str, "class_type": str,
                   "inputs": {field: {"type": ..., "widget_name": ...}}},
          },
          "outputs": {
            name: {"node_id": str, "class_type": str,
                   "outputs": {field: {"output_name": ...}}},
          },
        }

    Input fields are the node's unlinked widget inputs, keyed by the
    lower-cased widget name (LoadImage's ``upload`` widget is hidden). Output
    fields are keyed by the lower-cased output name. When several nodes share
    a name, the first node's id and type are kept and later fields are merged
    into the same entry.

    Raises:
        ValueError: ``workflow_data`` has no ``"nodes"`` list.
    """
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError(
            "Invalid workflow format: 'nodes' key not found or is not a list."
        )

    spec: Dict[str, Dict[str, Any]] = {"inputs": {}, "outputs": {}}
    nodes_by_id = {str(entry["id"]): entry for entry in workflow_data["nodes"]}

    for node_id, node in nodes_by_id.items():
        title = node.get("title")
        if not title:
            continue
        node_type = node.get("type")

        if title.startswith(API_INPUT_PREFIX):
            section = spec["inputs"].setdefault(
                title[len(API_INPUT_PREFIX):].lower(),
                {"node_id": node_id, "class_type": node_type},
            )
            if "inputs" not in node:
                continue
            for widget_input in node["inputs"]:
                # Only free-standing widgets (not fed by a link) are API inputs.
                if widget_input.get("link") is not None or "widget" not in widget_input:
                    continue
                raw_name = widget_input["widget"]["name"]
                key = raw_name.lower()
                if node_type == "LoadImage" and key == "upload":
                    continue
                section.setdefault("inputs", {})[key] = {
                    "type": _get_param_type(widget_input.get("type", "STRING")),
                    "widget_name": raw_name,
                }

        elif title.startswith(API_OUTPUT_PREFIX):
            section = spec["outputs"].setdefault(
                title[len(API_OUTPUT_PREFIX):].lower(),
                {"node_id": node_id, "class_type": node_type},
            )
            if "outputs" not in node:
                continue
            for node_output in node["outputs"]:
                raw_name = node_output["name"]
                section.setdefault("outputs", {})[raw_name.lower()] = {
                    "output_name": raw_name,
                }

    return spec
|
||
|
||
|
||
def parse_inputs_json_schema(workflow_data: dict) -> dict:
    """Build a draft-07 JSON Schema describing the workflow's API inputs.

    Every node titled ``API_INPUT_PREFIX + name`` becomes an object property
    keyed by the lower-cased ``name``. Its properties are the node's unlinked
    widget inputs (LoadImage's ``upload`` widget is hidden), each described by
    ``_generate_field_schema``.

    A field is required unless its input entry has ``"required": False``.
    A node is required when it has at least one required field. Only inputs
    are described; outputs are not part of the schema.

    Raises:
        ValueError: ``workflow_data`` has no ``"nodes"`` list.
    """
    schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "object",
        "title": "ComfyUI Workflow Input Schema",
        "description": "工作流输入参数定义",
        "properties": {},
        "required": [],
    }

    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError(
            "Invalid workflow format: 'nodes' key not found or is not a list."
        )

    nodes_by_id = {str(entry["id"]): entry for entry in workflow_data["nodes"]}

    # First pass: group the input fields per API node name.
    grouped = {}
    for node_id, node in nodes_by_id.items():
        title = node.get("title")
        if not title or not title.startswith(API_INPUT_PREFIX):
            continue

        group = grouped.setdefault(
            title[len(API_INPUT_PREFIX):].lower(),
            {
                "node_id": node_id,
                "class_type": node.get("type"),
                "properties": {},
                "required": [],
            },
        )
        if "inputs" not in node:
            continue

        for widget_input in node["inputs"]:
            if widget_input.get("link") is not None or "widget" not in widget_input:
                continue
            field_key = widget_input["widget"]["name"].lower()
            if node.get("type") == "LoadImage" and field_key == "upload":
                continue
            group["properties"][field_key] = _generate_field_schema(widget_input)
            # Fields are required unless explicitly marked otherwise.
            if widget_input.get("required", True):
                group["required"].append(field_key)

    # Second pass: emit one object schema per API node.
    for group_name, group in grouped.items():
        schema["properties"][group_name] = {
            "type": "object",
            "title": f"输入节点: {group_name}",
            "description": f"节点类型: {group['class_type']}",
            "properties": group["properties"],
            "required": group["required"] or [],
            "additionalProperties": False,
        }
        if group["required"]:
            schema["required"].append(group_name)

    return schema
|
||
|
||
|
||
def _generate_field_schema(input_field: dict) -> dict:
    """Translate one ComfyUI widget input into a JSON Schema property.

    ``title`` is the widget name and ``description`` the input's
    ``description`` (default ``""``). The type comes from ``_get_param_type``:

    * ``int`` → ``integer``
    * ``float`` → ``number``
    * ``UploadFile`` → binary ``string``, with the description replaced by "上传文件"
    * anything else → ``string``

    For numeric fields, optional ``min``/``max`` become ``minimum``/``maximum``.
    ``options`` become ``enum`` and are appended to the description.
    ``default`` is copied as-is.
    """
    field_schema = {
        "title": input_field["widget"]["name"],
        "description": input_field.get("description", ""),
    }

    field_type = _get_param_type(input_field.get("type", "STRING"))

    if field_type in ("int", "float"):
        field_schema["type"] = "integer" if field_type == "int" else "number"
        # JSON Schema requires minimum/maximum to be numbers: omit absent
        # bounds instead of emitting null.
        for source_key, schema_key in (("min", "minimum"), ("max", "maximum")):
            bound = input_field.get(source_key)
            if bound is not None:
                field_schema[schema_key] = bound
    elif field_type == "UploadFile":
        field_schema.update(
            {"type": "string", "format": "binary", "description": "上传文件"}
        )
    else:  # string
        field_schema["type"] = "string"

    # Enumerated choices.
    if "options" in input_field:
        options = input_field["options"]
        field_schema["enum"] = options
        # str() each option so non-string choices (e.g. ints) can be joined.
        field_schema["description"] += (
            f" 可选值: {', '.join(str(option) for option in options)}"
        )

    if "default" in input_field:
        field_schema["default"] = input_field["default"]

    return field_schema
|
||
|
||
|
||
def _get_param_type(input_type_str: str) -> str:
|
||
"""
|
||
根据输入类型字符串确定参数类型
|
||
"""
|
||
input_type_str = input_type_str.upper()
|
||
if "COMBO" in input_type_str:
|
||
return "UploadFile"
|
||
elif "INT" in input_type_str:
|
||
return "int"
|
||
elif "FLOAT" in input_type_str:
|
||
return "float"
|
||
else:
|
||
return "string"
|
||
|
||
|
||
async def patch_workflow(
    workflow_data: dict,
    api_spec: dict[str, dict[str, Any]],
    request_data: dict[str, Any],
    server: "ComfyUIServer",
) -> dict:
    """Write request values into the workflow's ``widgets_values`` in place.

    ``request_data`` has the shape ``{node_name: {field_name: value}}``, using
    the (lower-cased) names from ``api_spec["inputs"]`` as built by
    ``parse_api_spec``. For each known field:

    * The widget is located by its position among the node's widget inputs.
      This positional convention is the same one
      ``convert_workflow_to_prompt_api_format`` uses to read the values back.
    * The value is cast to the spec's type (``int``/``float``, else ``str``)
      and stored at that position.
    * Values for ``LoadImage`` nodes are first uploaded to *server* with
      ``upload_image_to_comfy`` and replaced by the returned name.

    Unknown node or field names are ignored. Non-mapping field groups,
    missing widgets and values that fail conversion are skipped with a warning.

    Returns:
        ``workflow_data`` itself (mutated), with ``nodes`` rebuilt from the
        id-keyed node map.

    Raises:
        ValueError: ``workflow_data`` has no ``"nodes"`` key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    input_specs = api_spec.get("inputs", {})

    for node_name, node_fields in request_data.items():
        node_spec = input_specs.get(node_name)
        if node_spec is None:
            continue
        if not isinstance(node_fields, dict):
            logger.warning(
                f"节点 '{node_name}' 的请求数据应为对象,实际为 {type(node_fields).__name__}。跳过此节点。"
            )
            continue

        node_id = node_spec["node_id"]
        target_node = nodes_map.get(node_id)
        if target_node is None:
            continue

        field_specs = node_spec.get("inputs", {})
        # Widget inputs in declaration order. A widget's position in this list
        # is its index into widgets_values. Computed once per node.
        widget_inputs = [
            inp for inp in target_node.get("inputs", []) if "widget" in inp
        ]

        for field_name, value in node_fields.items():
            field_spec = field_specs.get(field_name)
            if field_spec is None:
                continue

            widget_name_to_patch = field_spec["widget_name"]
            target_widget_index = next(
                (
                    i
                    for i, widget_input in enumerate(widget_inputs)
                    # The "widget" dict may lack a "name" key.
                    if widget_input.get("widget", {}).get("name") == widget_name_to_patch
                ),
                -1,
            )
            if target_widget_index == -1:
                logger.warning(
                    f"在节点 {node_id} 中未找到名为 '{widget_name_to_patch}' 的 widget。跳过此参数。"
                )
                continue

            # Ensure widgets_values is a list long enough for the target slot.
            values = target_node.get("widgets_values")
            if not isinstance(values, list):
                values = [None] * len(widget_inputs)
                target_node["widgets_values"] = values
            if len(values) <= target_widget_index:
                values.extend([None] * (target_widget_index + 1 - len(values)))

            target_type = {"int": int, "float": float}.get(field_spec["type"], str)

            try:
                if target_node.get("type") == "LoadImage":
                    value = await upload_image_to_comfy(value, server)
                values[target_widget_index] = target_type(value)
            except (ValueError, TypeError) as e:
                logger.warning(
                    f"无法将节点 '{node_name}' 的字段 '{field_name}' 的值 '{value}' 转换为类型 '{field_spec['type']}'。错误: {e}"
                )
                continue

    workflow_data["nodes"] = list(nodes_map.values())
    return workflow_data
|
||
|
||
|
||
# Node types implemented only by the ComfyUI web frontend. The backend has no
# class for them, so they must not appear in a /prompt payload.
_FRONTEND_ONLY_NODE_TYPES = frozenset({"Note", "MarkdownNote", "Reroute", "PrimitiveNode"})


def _link_fields(link_data):
    """Return ``(link_id, origin_node_id, origin_slot)`` for a list- or dict-style link."""
    if isinstance(link_data, dict):
        return link_data["id"], link_data["origin_id"], link_data["origin_slot"]
    return link_data[0], link_data[1], link_data[2]


def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
    """Convert a UI-format workflow (``nodes`` + ``links``) to the /prompt API format.

    Returns ``{node_id: {"class_type": type, "inputs": {...}}}``:

    * Widget inputs take values from ``widgets_values``, read positionally in
      the order widget inputs appear. This is the same convention
      ``patch_workflow`` uses when writing request values.
    * Linked inputs become ``[origin_node_id, origin_slot]`` and override a
      widget value of the same input name.

    Frontend-only node types (``_FRONTEND_ONLY_NODE_TYPES``) are omitted:

    * A link coming out of a ``Reroute`` is followed back to the node feeding
      the reroute.
    * A link from any other frontend-only node (e.g. ``PrimitiveNode``) is
      dropped, so the target keeps its widget value.

    Links may be 6-element lists ``[id, origin_id, origin_slot, target_id,
    target_slot, type]`` or objects with ``id``/``origin_id``/``origin_slot``.

    Raises:
        ValueError: ``workflow_data`` has no ``"nodes"`` key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    nodes_by_id = {str(node["id"]): node for node in workflow_data["nodes"]}

    # link_id -> [origin node id (str), origin slot]
    link_map = {}
    for link_data in workflow_data.get("links") or []:
        link_id, origin_node_id, origin_slot_index = _link_fields(link_data)
        link_map[link_id] = [str(origin_node_id), origin_slot_index]

    def resolve(link_id):
        """Follow *link_id* to a backend node output; None if it cannot be resolved."""
        seen = set()  # guards against reroute cycles
        while link_id in link_map and link_id not in seen:
            seen.add(link_id)
            origin_id, origin_slot = link_map[link_id]
            origin = nodes_by_id.get(origin_id)
            if origin is None or origin.get("type") not in _FRONTEND_ONLY_NODE_TYPES:
                return [origin_id, origin_slot]
            if origin.get("type") != "Reroute":
                return None
            # A Reroute forwards whatever is connected to its single input.
            origin_inputs = origin.get("inputs") or []
            link_id = origin_inputs[0].get("link") if origin_inputs else None
        return None

    prompt_api_format = {}
    for node in workflow_data["nodes"]:
        if node.get("type") in _FRONTEND_ONLY_NODE_TYPES:
            continue

        inputs_dict = {}
        node_inputs = node.get("inputs") or []

        # 1. Widget inputs: values may already contain patched API values.
        widgets_values = node.get("widgets_values") or []
        widget_cursor = 0
        for input_config in node_inputs:
            if "widget" in input_config and widget_cursor < len(widgets_values):
                widget_name = input_config["widget"].get("name")
                if widget_name:
                    inputs_dict[widget_name] = widgets_values[widget_cursor]
                widget_cursor += 1

        # 2. Linked inputs.
        for input_config in node_inputs:
            link_id = input_config.get("link")
            if link_id is None:
                continue
            source = resolve(link_id)
            if source is not None:
                inputs_dict[input_config["name"]] = source

        prompt_api_format[str(node["id"])] = {
            "class_type": node["type"],
            "inputs": inputs_dict,
        }

    return prompt_api_format
|
||
|
||
|
||
async def _queue_prompt(
    workflow: dict, prompt: dict, client_id: str, http_url: str
) -> str:
    """POST *prompt* to ``{http_url}/prompt`` and return the queued prompt id.

    Every node gets an extra, randomly named input holding a random float, so
    results from earlier identical prompts cannot be reused from cache. The
    caller's *prompt* dict is left unmodified. The UI-format *workflow* is
    attached as ``extra_pnginfo`` metadata.

    Raises:
        aiohttp.ClientResponseError: Non-2xx response. The message includes
            the response body, where ComfyUI reports prompt validation errors.
        Exception: The response JSON has no ``prompt_id``.
    """
    busted_prompt = {}
    for node_id, node in prompt.items():
        inputs = dict(node["inputs"])
        inputs[f"cache_buster_{uuid.uuid4().hex}"] = random.random()
        busted_prompt[node_id] = {**node, "inputs": inputs}

    payload = {
        "prompt": busted_prompt,
        "client_id": client_id,
        "extra_data": {
            "api_key_comfy_org": "",
            "extra_pnginfo": {"workflow": workflow},
        },
    }
    # The payload embeds the whole workflow: log it only at DEBUG, and only
    # serialize it when that level is enabled.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"提交到 ComfyUI /prompt 端点的payload: {json.dumps(payload)}")

    prompt_url = f"{http_url}/prompt"
    async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
        try:
            async with session.post(prompt_url, json=payload) as response:
                if response.status >= 400:
                    # raise_for_status() would discard the body that explains
                    # why the prompt was rejected.
                    body = await response.text()
                    raise aiohttp.ClientResponseError(
                        response.request_info,
                        response.history,
                        status=response.status,
                        message=f"{response.reason}: {body}",
                        headers=response.headers,
                    )
                result = await response.json()

            logger.info(f"ComfyUI /prompt 端点返回的响应: {result}")
            if "prompt_id" not in result:
                raise Exception(f"从 ComfyUI /prompt 端点返回的响应无效: {result}")
            return result["prompt_id"]
        except Exception as e:
            logger.error(f"提交到 ComfyUI /prompt 端点时发生错误: {e}")
            raise
|
||
|
||
|
||
async def _get_execution_results(
    workflow_data: Dict,
    prompt_id: str,
    client_id: str,
    ws_url: str,
    workflow_run_id: str,
    *,
    idle_timeout: Optional[float] = None,
) -> dict:
    """Follow a queued prompt over ComfyUI's WebSocket and collect its outputs.

    Connects to ``{ws_url}?clientId={client_id}``. Only JSON messages whose
    ``data.prompt_id`` equals *prompt_id* are handled; progress is mirrored
    into the run's node records:

    * ``execution_cached``: the listed nodes are marked ``completed``.
    * ``executing`` with a node id: that node is marked ``running``, and the
      previously executing node, if not already completed, is marked
      ``completed``.
    * ``executed``: the node is marked ``completed`` with its output. Outputs
      of nodes titled with ``API_OUTPUT_PREFIX`` are collected by title.
    * ``executing`` with ``node`` null, or ``execution_success``: finished;
      the collected outputs are returned.
    * ``execution_error``: the node is marked ``failed`` and
      ``ComfyUIExecutionError`` is raised.
    * ``execution_interrupted``: ``RuntimeError`` is raised.

    Args:
        idle_timeout: Max seconds to wait for the next WebSocket message.
            ``None`` (the default) waits indefinitely.

    Returns:
        ``{node_title: output}`` for the API output nodes.

    Raises:
        ComfyUIExecutionError, RuntimeError: See above.
        ConnectionError: The WebSocket closed before the prompt finished.
        asyncio.TimeoutError: ``idle_timeout`` elapsed without a message.
        websockets.exceptions.InvalidURI: The WebSocket URL is invalid.
    """
    full_ws_url = f"{ws_url}?clientId={client_id}"
    aggregated_outputs = {}
    # Index nodes once instead of scanning the node list for every message.
    nodes_by_id = {str(node["id"]): node for node in workflow_data.get("nodes", [])}
    completed_nodes = set()
    current_node = None

    async def mark_completed(node_id, **extra):
        completed_nodes.add(node_id)
        await update_workflow_run_node_status(
            workflow_run_id, node_id, "completed", **extra
        )

    async def close_current_node():
        # Assumes ComfyUI executes one node at a time, so the node that was
        # executing has finished once another node starts or the run ends.
        if current_node and current_node not in completed_nodes:
            await mark_completed(current_node)

    try:
        async with websockets.connect(full_ws_url) as websocket:
            while True:
                try:
                    if idle_timeout is None:
                        raw = await websocket.recv()
                    else:
                        raw = await asyncio.wait_for(websocket.recv(), idle_timeout)
                except websockets.exceptions.ConnectionClosed as e:
                    logger.warning(
                        f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}"
                    )
                    # The run's outcome is unknown; do not report partial
                    # outputs as a successful completion.
                    raise ConnectionError(
                        f"WebSocket 连接在 prompt {prompt_id} 执行完成前关闭: {e}"
                    ) from e

                if not isinstance(raw, str):
                    continue  # non-text frames are ignored

                message = json.loads(raw)
                msg_type = message.get("type")
                data = message.get("data")
                if not (data and data.get("prompt_id") == prompt_id):
                    continue

                if msg_type == "execution_error":
                    logger.error(f"ComfyUI执行错误 (Prompt ID: {prompt_id}): {data}")
                    failed_node = data.get("node_id")
                    if failed_node:
                        await update_workflow_run_node_status(
                            workflow_run_id,
                            failed_node,
                            "failed",
                            error_message=data.get("exception_message", "Unknown error"),
                        )
                    raise ComfyUIExecutionError(data)

                if msg_type == "execution_interrupted":
                    logger.warning(f"Prompt ID: {prompt_id} 的执行被中断: {data}")
                    raise RuntimeError(
                        f"ComfyUI 执行被中断 (Prompt ID: {prompt_id}, 节点ID: {data.get('node_id')})"
                    )

                if msg_type == "execution_cached":
                    for cached_id in data.get("nodes") or []:
                        cached_id = str(cached_id)
                        if cached_id not in completed_nodes:
                            await mark_completed(cached_id)
                    continue

                if msg_type == "executing":
                    node_id = data.get("node")
                    if node_id != current_node:
                        await close_current_node()
                    if node_id is None:
                        logger.info(f"Prompt ID: {prompt_id} 执行完成。")
                        return aggregated_outputs
                    if not node_id:
                        continue
                    current_node = node_id
                    # A node may be entered again (e.g. after lazy inputs).
                    completed_nodes.discard(node_id)
                    logger.info(f"节点 {node_id} 开始执行 (Prompt ID: {prompt_id})")
                    await update_workflow_run_node_status(
                        workflow_run_id, node_id, "running"
                    )
                    continue

                if msg_type == "executed":
                    node_id = data.get("node")
                    output_data = data.get("output")
                    if node_id and output_data:
                        node = nodes_by_id.get(str(node_id))
                        title = (node or {}).get("title") or ""
                        if title.startswith(API_OUTPUT_PREFIX):
                            aggregated_outputs[title] = output_data
                            logger.info(
                                f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})"
                            )
                        await mark_completed(node_id, output_data=json.dumps(output_data))
                    continue

                if msg_type == "execution_success":
                    await close_current_node()
                    logger.info(f"Prompt ID: {prompt_id} 执行完成。")
                    return aggregated_outputs

    except websockets.exceptions.InvalidURI as e:
        logger.error(
            f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
        )
        raise
    except ComfyUIExecutionError:
        raise
    except Exception as e:
        logger.error(f"处理 prompt {prompt_id} 时发生意外错误: {e}")
        raise
|
||
|
||
|
||
async def _get_execution_results_legacy(
    prompt_id: str, client_id: str, ws_url: str
) -> dict:
    """Simplified result collector without any database status tracking.

    Listens on ``{ws_url}?clientId={client_id}`` for JSON messages about
    *prompt_id*. Returns ``{node_id: output}`` for every ``executed`` message
    that carries output.

    * ``executing`` with ``node`` null, or ``execution_success``: return.
    * ``execution_error``: raise ``ComfyUIExecutionError``.
    * ``execution_interrupted``: raise ``RuntimeError``.
    * WebSocket closed before completion: raise ``ConnectionError``.
    * Invalid WebSocket URL: re-raise ``websockets.exceptions.InvalidURI``.
    """
    full_ws_url = f"{ws_url}?clientId={client_id}"
    aggregated_outputs = {}

    try:
        async with websockets.connect(full_ws_url) as websocket:
            while True:
                try:
                    raw = await websocket.recv()
                except websockets.exceptions.ConnectionClosed as e:
                    logger.warning(
                        f"WebSocket 连接已关闭 (Prompt ID: {prompt_id})。错误: {e}"
                    )
                    # Outcome unknown: do not report partial outputs as success.
                    raise ConnectionError(
                        f"WebSocket 连接在 prompt {prompt_id} 执行完成前关闭: {e}"
                    ) from e

                if not isinstance(raw, str):
                    continue  # non-text frames are ignored

                message = json.loads(raw)
                msg_type = message.get("type")
                data = message.get("data")
                if not (data and data.get("prompt_id") == prompt_id):
                    continue

                if msg_type == "execution_error":
                    logger.error(f"ComfyUI执行错误 (Prompt ID: {prompt_id}): {data}")
                    raise ComfyUIExecutionError(data)

                if msg_type == "execution_interrupted":
                    logger.warning(f"Prompt ID: {prompt_id} 的执行被中断: {data}")
                    raise RuntimeError(
                        f"ComfyUI 执行被中断 (Prompt ID: {prompt_id}, 节点ID: {data.get('node_id')})"
                    )

                if msg_type == "executed":
                    node_id = data.get("node")
                    output_data = data.get("output")
                    if node_id and output_data:
                        aggregated_outputs[node_id] = output_data
                        logger.info(f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})")
                elif (
                    msg_type == "executing" and data.get("node") is None
                ) or msg_type == "execution_success":
                    logger.info(f"Prompt ID: {prompt_id} 执行完成。")
                    return aggregated_outputs

    except websockets.exceptions.InvalidURI as e:
        logger.error(
            f"错误: 尝试连接的WebSocket URI无效: '{full_ws_url}'. 原始URL: '{ws_url}'. 错误: {e}"
        )
        raise
    except ComfyUIExecutionError:
        raise
    except Exception as e:
        logger.error(f"处理 prompt {prompt_id} 时发生意外错误: {e}")
        raise
|
||
|
||
|
||
async def upload_image_to_comfy(file_path: str, server: ComfyUIServer) -> str:
|
||
"""
|
||
上传文件到服务器。
|
||
file 是一个http链接
|
||
返回一个名称
|
||
"""
|
||
file_name = file_path.split("/")[-1]
|
||
file_data = await download_file(file_path)
|
||
|
||
form_data = aiohttp.FormData()
|
||
form_data.add_field("image", file_data, filename=file_name, content_type="image/*")
|
||
form_data.add_field("type", "input")
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(
|
||
f"{server.http_url}/api/upload/image", data=form_data
|
||
) as response:
|
||
response.raise_for_status()
|
||
result = await response.json()
|
||
return result["name"]
|
||
|
||
|
||
async def download_file(src: str) -> bytes:
    """Fetch *src* over HTTP and return the full response body as bytes.

    A non-2xx response raises via ``raise_for_status()``.

    NOTE(review): *src* can originate from API request data (patch_workflow ->
    upload_image_to_comfy -> here), so the service fetches caller-chosen URLs
    from its own network position. Consider a host allow-list if that is a
    concern.
    """
    async with aiohttp.ClientSession() as session, session.get(src) as response:
        response.raise_for_status()
        body = await response.read()
    return body
|