ComfyUI-WorkflowPublisher/workflow_service/workflow_parser.py

151 lines
7.0 KiB
Python

import re
from collections import defaultdict
# Node-title prefixes that expose a node as an API input / API output
# (matched with str.startswith by parse_api_spec and stripped to form the base name).
API_INPUT_PREFIX, API_OUTPUT_PREFIX = "INPUT_", "OUTPUT_"
def _unique_param_name(counter, candidate: str) -> str:
    """Return *candidate*, suffixed with ``_{n}`` on its n-th (n >= 2) occurrence in *counter*."""
    counter[candidate] += 1
    count = counter[candidate]
    return f"{candidate}_{count}" if count > 1 else candidate


def _infer_input_param_type(raw_type, widget_name: str) -> str:
    """Map a workflow input's ``type`` and widget name to an API parameter type.

    A missing/empty type is treated as ``"STRING"``; matching is case-insensitive
    and by substring, as before.
    """
    type_str = str(raw_type or "STRING").upper()
    if "COMBO" in type_str:
        # On loader nodes, a combo widget named 'image' or 'video' selects an uploaded file;
        # any other combo value is passed through as a string.
        return "UploadFile" if widget_name in ("image", "video") else "string"
    if "INT" in type_str:
        return "int"
    if "FLOAT" in type_str:
        return "float"
    return "string"


def parse_api_spec(workflow_data: dict) -> dict:
    """Derive the API parameter spec from specially titled workflow nodes.

    Nodes whose title starts with ``API_INPUT_PREFIX`` expose every unlinked
    widget input as an API input; nodes whose title starts with
    ``API_OUTPUT_PREFIX`` expose every output as an API output.
    Parameter names follow ``{base_name}_{attribute_name}[_{count}]``:
    ``base_name`` is the lower-cased title without its prefix, the attribute
    name is the lower-cased widget/output name, and ``_{count}`` (starting at
    2) is appended only when that name has already been produced.

    Args:
        workflow_data: Workflow dict that must contain a ``"nodes"`` list.

    Returns:
        ``{"inputs": {...}, "outputs": {...}}``. Each input entry holds
        ``node_id``, ``type`` and ``widget_name`` (original casing, used when
        patching); each output entry holds ``node_id``, ``class_type``,
        ``output_name`` (original casing, used to match results) and
        ``output_index`` (the output's position in the node's ``outputs``).

    Raises:
        ValueError: If ``"nodes"`` is missing or is not a list.
    """
    spec = {"inputs": {}, "outputs": {}}
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError("Invalid workflow format: 'nodes' key not found or is not a list.")
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    # Separate counters, so an input and an output may share a name.
    input_name_counter = defaultdict(int)
    output_name_counter = defaultdict(int)

    for node_id, node in nodes_map.items():
        title = node.get("title")
        if not title:
            continue

        if title.startswith(API_INPUT_PREFIX):
            base_name = title[len(API_INPUT_PREFIX):].lower()
            for a_input in node.get("inputs") or []:
                # Only widgets the user sets directly; linked inputs are fed by other nodes.
                if a_input.get("link") is not None or "widget" not in a_input:
                    continue
                widget_name = a_input["widget"]["name"]
                param_name = _unique_param_name(
                    input_name_counter, f"{base_name}_{widget_name.lower()}"
                )
                spec["inputs"][param_name] = {
                    "node_id": node_id,
                    "type": _infer_input_param_type(a_input.get("type"), widget_name),
                    "widget_name": widget_name,
                }

        elif title.startswith(API_OUTPUT_PREFIX):
            base_name = title[len(API_OUTPUT_PREFIX):].lower()
            # enumerate yields each slot's real position; list.index() would return the
            # first *equal* entry (wrong for duplicate outputs) and costs O(n) per call.
            for output_index, an_output in enumerate(node.get("outputs") or []):
                output_name = an_output["name"]
                param_name = _unique_param_name(
                    output_name_counter, f"{base_name}_{output_name.lower()}"
                )
                spec["outputs"][param_name] = {
                    "node_id": node_id,
                    "class_type": node.get("type"),
                    "output_name": output_name,
                    "output_index": output_index,
                }
    return spec
# --- Remaining functions: workflow patching and prompt-format conversion ---
def _coerce_patch_value(value, param_type: str):
    """Cast *value* to the Python type named by a spec ``type`` field.

    ``"int"`` and ``"float"`` use the matching constructor; every other type
    (``"string"``, ``"UploadFile"``, ...) is stored as ``str``.
    """
    if param_type == "int":
        return int(value)
    if param_type == "float":
        return float(value)
    return str(value)


def _list_widget_index(node: dict, widget_name: str):
    """Return the position of *widget_name* within a list-style ``widgets_values``.

    Positions are counted over the node's ``inputs`` entries that carry a
    ``"widget"`` key, in order. Returns ``None`` when no widget input matches.
    """
    # NOTE(review): assumes widgets_values holds exactly one entry per widget input;
    # any extra UI-only entries would shift these positions — confirm against real workflows.
    position = 0
    for input_config in node.get("inputs") or []:
        if "widget" not in input_config:
            continue
        if input_config["widget"].get("name") == widget_name:
            return position
        position += 1
    return None


def patch_workflow(workflow_data: dict, api_spec: dict, request_data: dict) -> dict:
    """Write request values into the widgets referenced by *api_spec*.

    Values are cast to the spec's type (``int``/``float``, otherwise ``str``).
    The cast is applied the same way to dict-style and list-style ``widgets_values``.

    Args:
        workflow_data: UI-format workflow with a ``"nodes"`` list. It is
            modified in place and also returned.
        api_spec: Spec in the shape returned by ``parse_api_spec``; only
            ``api_spec["inputs"]`` (``node_id``/``type``/``widget_name``) is read.
        request_data: API parameter name -> value. Names missing from the spec,
            or whose node is absent from the workflow, are ignored.

    Returns:
        The same ``workflow_data`` object with patched ``widgets_values``. A
        targeted node without ``widgets_values`` gets a new dict holding the value.

    Raises:
        ValueError: If ``"nodes"`` is missing, or if ``int()``/``float()``
            cannot parse a value (``TypeError`` for unconvertible objects).
    """
    if "nodes" not in workflow_data:
        raise ValueError("Invalid workflow format")
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    input_specs = api_spec["inputs"]

    for param_name, value in request_data.items():
        spec = input_specs.get(param_name)
        if spec is None:
            continue  # not an exposed API input
        target_node = nodes_map.get(spec["node_id"])
        if target_node is None:
            continue
        widget_name = spec["widget_name"]
        widgets_values = target_node.get("widgets_values", {})
        if isinstance(widgets_values, dict):
            widgets_values[widget_name] = _coerce_patch_value(value, spec["type"])
        elif isinstance(widgets_values, list):
            index = _list_widget_index(target_node, widget_name)
            if index is not None and index < len(widgets_values):
                widgets_values[index] = _coerce_patch_value(value, spec["type"])
        target_node["widgets_values"] = widgets_values

    workflow_data["nodes"] = list(nodes_map.values())
    return workflow_data
def _link_endpoint(link):
    """Return ``(link_id, [str(origin_node_id), origin_slot])`` for one workflow link.

    Accepts the array form, whose first three items are link id, origin node
    id and origin slot, and an object form keyed by ``id`` / ``origin_id`` /
    ``origin_slot``.
    """
    if isinstance(link, dict):
        return link["id"], [str(link["origin_id"]), link["origin_slot"]]
    return link[0], [str(link[1]), link[2]]


def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
    """Convert a UI-format workflow into the prompt (API) format.

    Every node becomes ``{str(node_id): {"class_type": node["type"], "inputs": {...}}}``.
    Widget values are filled in first:
    - Dict-style ``widgets_values`` are copied key by key, skipping dict-valued entries.
    - List-style ``widgets_values`` are matched positionally to the node's inputs
      that carry a ``"widget"`` key; extra values or extra widgets are ignored.

    Linked inputs are then set to ``[origin_node_id, origin_slot]`` and override
    any widget value with the same name. A missing or ``null`` ``"links"`` entry
    is treated as no links.

    Raises:
        ValueError: If ``"nodes"`` is missing.
    """
    if "nodes" not in workflow_data:
        raise ValueError("Invalid workflow format")
    # link id -> [source node id (str), source output slot]
    link_map = dict(_link_endpoint(link) for link in workflow_data.get("links") or [])

    prompt_api_format = {}
    for node in workflow_data["nodes"]:
        node_inputs = node.get("inputs") or []
        inputs_dict = {}
        widgets_values = node.get("widgets_values", [])
        if isinstance(widgets_values, dict):
            # Dict-valued entries are skipped (presumably UI state, not node inputs — confirm).
            inputs_dict.update(
                {key: val for key, val in widgets_values.items() if not isinstance(val, dict)}
            )
        elif isinstance(widgets_values, list):
            widget_inputs = [cfg for cfg in node_inputs if "widget" in cfg]
            # zip stops at the shorter sequence, matching the original bounds check.
            for input_config, value in zip(widget_inputs, widgets_values):
                inputs_dict[input_config["name"]] = value
        # Separate pass so a link always wins over a widget value of the same name.
        for input_config in node_inputs:
            link_id = input_config.get("link")
            if link_id is not None and link_id in link_map:
                inputs_dict[input_config["name"]] = link_map[link_id]
        prompt_api_format[str(node["id"])] = {"class_type": node["type"], "inputs": inputs_dict}
    return prompt_api_format