ComfyUI-WorkflowPublisher/workflow_service/workflow_parser.py

185 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
from collections import defaultdict
# Node-title prefixes that opt a workflow node into the published API:
# a title starting with API_INPUT_PREFIX exposes the node's unlinked widget
# inputs as request parameters, and a title starting with API_OUTPUT_PREFIX
# exposes the node's outputs (see parse_api_spec).
API_INPUT_PREFIX, API_OUTPUT_PREFIX = "INPUT_", "OUTPUT_"
def _unique_param_name(counter: defaultdict, candidate: str) -> str:
    """Return *candidate*, suffixed with ``_<n>`` on its n-th occurrence (n > 1)."""
    counter[candidate] += 1
    count = counter[candidate]
    return f"{candidate}_{count}" if count > 1 else candidate


def _infer_param_type(type_str: str) -> str:
    """Map an upper-cased ComfyUI input type string to an API parameter type name."""
    # Substring checks, evaluated in this order (COMBO first).
    # NOTE(review): every COMBO widget is exposed as "UploadFile"; presumably
    # intended for file-picker combos -- confirm for other combo widgets.
    if "COMBO" in type_str:
        return "UploadFile"
    if "INT" in type_str:
        return "int"
    if "FLOAT" in type_str:
        return "float"
    return "string"


def parse_api_spec(workflow_data: dict) -> dict:
    """Build the API spec for a workflow from its specially titled nodes.

    Parameter names follow the convention ``{base}_{attribute}[_{count}]``:

    * Nodes titled ``INPUT_<base>`` contribute one input per *unlinked*
      widget input, named ``<base>_<widget name>`` (lower-cased).
    * Nodes titled ``OUTPUT_<base>`` contribute one output per output slot,
      named ``<base>_<output name>`` (lower-cased).

    When a name repeats, the 2nd, 3rd, ... occurrences get ``_2``, ``_3``, ...
    appended.

    Args:
        workflow_data: Workflow JSON with a ``"nodes"`` list.

    Returns:
        ``{"inputs": {name: {"node_id", "type", "widget_name"}},
        "outputs": {name: {"node_id", "class_type", "output_name",
        "output_index"}}}``, where ``node_id`` is the node id as a string.

    Raises:
        ValueError: if ``"nodes"`` is missing or is not a list.
    """
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError("Invalid workflow format: 'nodes' key not found or is not a list.")

    spec = {"inputs": {}, "outputs": {}}
    # Keyed by stringified id; a later node with a duplicate id replaces an earlier one.
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    input_name_counter = defaultdict(int)
    output_name_counter = defaultdict(int)

    for node_id, node in nodes_map.items():
        title = node.get("title")
        if not title:
            continue

        if title.startswith(API_INPUT_PREFIX):
            base_name = title[len(API_INPUT_PREFIX):].lower()
            for a_input in node.get("inputs") or []:
                widget_name = (a_input.get("widget") or {}).get("name")
                # Only widget inputs that are not fed by a link can be set from a request.
                if a_input.get("link") is not None or not widget_name:
                    continue
                param_name = _unique_param_name(
                    input_name_counter, f"{base_name}_{widget_name.lower()}"
                )
                input_type_str = str(a_input.get("type") or "STRING").upper()
                spec["inputs"][param_name] = {
                    "node_id": node_id,
                    "type": _infer_param_type(input_type_str),
                    "widget_name": widget_name,
                }

        elif title.startswith(API_OUTPUT_PREFIX):
            base_name = title[len(API_OUTPUT_PREFIX):].lower()
            # enumerate yields each slot's true position; list.index() would
            # return the first *equal* dict and costs O(n) per lookup.
            for output_index, an_output in enumerate(node.get("outputs") or []):
                param_name = _unique_param_name(
                    output_name_counter, f"{base_name}_{an_output['name'].lower()}"
                )
                spec["outputs"][param_name] = {
                    "node_id": node_id,
                    "class_type": node.get("type"),
                    "output_name": an_output["name"],
                    "output_index": output_index,
                }

    return spec
def patch_workflow(workflow_data: dict, api_spec: dict, request_data: dict) -> dict:
    """Write API request values into the matching nodes' ``widgets_values``.

    For every key of *request_data* that appears in ``api_spec["inputs"]``,
    the value is converted to the spec's type and stored in the
    ``widgets_values`` slot of the spec's node and widget:

    * ``"int"`` uses ``int()``.
    * ``"float"`` uses ``float()``.
    * Anything else uses ``str()``.

    Unknown parameters, missing nodes, missing widgets and failed conversions
    are skipped; the latter two print a warning. A skipped parameter leaves
    its node untouched.

    The workflow is modified in place and also returned.

    Args:
        workflow_data: Workflow JSON with a ``"nodes"`` list.
        api_spec: Spec as produced by ``parse_api_spec``.
        request_data: Mapping of API parameter name to raw value.

    Returns:
        The same *workflow_data* object, patched.

    Raises:
        ValueError: if *workflow_data* has no ``"nodes"`` key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    converters = {"int": int, "float": float}
    # Values are the node dicts themselves, so patching them patches the workflow.
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}

    for param_name, value in request_data.items():
        if param_name not in api_spec["inputs"]:
            continue
        spec = api_spec["inputs"][param_name]
        node_id = spec["node_id"]
        target_node = nodes_map.get(node_id)
        if target_node is None:
            continue

        widget_name_to_patch = spec["widget_name"]
        # Assumes widgets_values holds one slot per widget input, in the order
        # those inputs appear in node["inputs"].
        # NOTE(review): nodes that store extra widget values with no matching
        # input would shift these positions -- confirm.
        widget_inputs = [inp for inp in target_node.get("inputs") or [] if "widget" in inp]
        target_widget_index = next(
            (
                i
                for i, widget_input in enumerate(widget_inputs)
                if (widget_input.get("widget") or {}).get("name") == widget_name_to_patch
            ),
            -1,
        )
        if target_widget_index == -1:
            print(f"警告: 在节点 {node_id} 中未找到名为 '{widget_name_to_patch}' 的 widget。跳过此参数。")
            continue

        # Convert before touching the node so a bad value leaves it unmodified.
        target_type = converters.get(spec["type"], str)
        try:
            converted = target_type(value)
        except (ValueError, TypeError) as e:
            print(f"警告: 无法将参数 '{param_name}' 的值 '{value}' 转换为类型 '{spec['type']}'。错误: {e}")
            continue

        widgets_values = target_node.get("widgets_values")
        if not isinstance(widgets_values, list):
            # Missing or malformed: start from one placeholder per widget input.
            widgets_values = [None] * len(widget_inputs)
            target_node["widgets_values"] = widgets_values
        if len(widgets_values) <= target_widget_index:
            widgets_values.extend([None] * (target_widget_index + 1 - len(widgets_values)))
        widgets_values[target_widget_index] = converted

    # The node list is left as-is: rebuilding it from nodes_map would silently
    # drop nodes that share an id.
    return workflow_data
def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
    """Convert a node/link workflow graph into the payload for the ``/prompt`` endpoint.

    Widget values come from each node's ``widgets_values``, so changes made by
    ``patch_workflow`` are reflected. Linked inputs are emitted as
    ``[origin_node_id_str, origin_slot]`` and override a widget value of the
    same name.

    Links may be given as arrays ``[id, origin_id, origin_slot, target_id,
    target_slot, type]`` or as objects with ``"id"``, ``"origin_id"`` and
    ``"origin_slot"`` keys.

    Args:
        workflow_data: Workflow JSON with a ``"nodes"`` list and an optional
            ``"links"`` list.

    Returns:
        ``{node_id_str: {"class_type": node["type"], "inputs": {...}}}`` for
        every node.

    Raises:
        ValueError: if *workflow_data* has no ``"nodes"`` key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    # Map link id -> [origin node id (str), origin slot index].
    link_map = {}
    for link in workflow_data.get("links") or []:
        if isinstance(link, dict):
            link_id, origin_id, origin_slot = link["id"], link["origin_id"], link["origin_slot"]
        else:
            link_id, origin_id, origin_slot = link[0], link[1], link[2]
        link_map[link_id] = [str(origin_id), origin_slot]

    # NOTE(review): every node is emitted, including any frontend-only or
    # muted nodes -- confirm the server accepts them.
    prompt_api_format = {}
    for node in workflow_data["nodes"]:
        node_inputs = node.get("inputs") or []
        widgets_values = node.get("widgets_values") or []
        inputs_dict = {}

        # 1. Widget inputs: the k-th widget input takes widgets_values[k]
        #    (the same positional mapping patch_workflow writes into).
        widget_cursor = 0
        for input_config in node_inputs:
            if "widget" not in input_config or widget_cursor >= len(widgets_values):
                continue
            widget_name = (input_config["widget"] or {}).get("name")
            if widget_name:
                inputs_dict[widget_name] = widgets_values[widget_cursor]
            widget_cursor += 1

        # 2. Linked inputs: processed second so a link overrides a widget value.
        for input_config in node_inputs:
            link_id = input_config.get("link")
            if link_id is not None and link_id in link_map:
                inputs_dict[input_config["name"]] = link_map[link_id]

        prompt_api_format[str(node["id"])] = {
            "class_type": node["type"],
            "inputs": inputs_dict,
        }
    return prompt_api_format