151 lines
7.0 KiB
Python
151 lines
7.0 KiB
Python
import re
|
|
from collections import defaultdict
|
|
|
|
# Node-title prefix marking a node whose unlinked widget inputs are exposed as API inputs.
API_INPUT_PREFIX = "INPUT_"
# Node-title prefix marking a node whose output slots are exposed as API outputs.
API_OUTPUT_PREFIX = "OUTPUT_"

# Widget names on a COMBO input that are mapped to the "UploadFile" parameter type.
_UPLOAD_WIDGET_NAMES = ("image", "video")


def _unique_param_name(candidate: str, counter: defaultdict, taken: dict) -> str:
    """Return ``candidate`` on first use, ``candidate_N`` on its N-th use, never a key already in ``taken``."""
    # The loop only repeats on a genuine clash: e.g. widget "y" seen twice on a node
    # titled INPUT_X yields "x_y_2", which a widget literally named "y_2" also produces.
    # Without this check the later entry would silently overwrite the earlier one.
    while True:
        counter[candidate] += 1
        count = counter[candidate]
        name = f"{candidate}_{count}" if count > 1 else candidate
        if name not in taken:
            return name


def _infer_input_type(a_input: dict) -> str:
    """Map a workflow input's declared ``type`` to an API parameter type string."""
    # ``or`` also covers an explicit ``"type": null``; str() guards against non-string types.
    # Previously the COMBO test used the raw value and crashed when "type" was absent.
    type_str = str(a_input.get("type") or "STRING").upper()
    if "COMBO" in type_str:
        # On a COMBO input, an "image"/"video" widget is treated as a file upload.
        return "UploadFile" if a_input["widget"]["name"] in _UPLOAD_WIDGET_NAMES else "string"
    if "INT" in type_str:
        return "int"
    if "FLOAT" in type_str:
        return "float"
    return "string"


def _collect_api_inputs(node_id: str, node: dict, base_name: str,
                        counter: defaultdict, inputs_spec: dict) -> None:
    """Add one entry to ``inputs_spec`` for every unlinked widget input of ``node``."""
    for a_input in node.get("inputs") or []:
        widget = a_input.get("widget")
        # Only widgets set directly by the user (not fed by a link) become API inputs.
        if a_input.get("link") is not None or not isinstance(widget, dict) or "name" not in widget:
            continue
        widget_name = widget["name"]
        final_name = _unique_param_name(f"{base_name}_{widget_name.lower()}", counter, inputs_spec)
        inputs_spec[final_name] = {
            "node_id": node_id,
            "type": _infer_input_type(a_input),
            "widget_name": widget_name,  # case-preserved name, needed later to patch the widget
        }


def _collect_api_outputs(node_id: str, node: dict, base_name: str,
                         counter: defaultdict, outputs_spec: dict) -> None:
    """Add one entry to ``outputs_spec`` for every output slot of ``node``."""
    # enumerate yields the true slot index; list.index() returned the first *equal*
    # entry, so two identical output dicts were both reported as index 0.
    for index, an_output in enumerate(node.get("outputs") or []):
        output_name = an_output["name"]
        final_name = _unique_param_name(f"{base_name}_{output_name.lower()}", counter, outputs_spec)
        outputs_spec[final_name] = {
            "node_id": node_id,
            "class_type": node.get("type"),
            "output_name": output_name,  # case-preserved name, kept for matching results
            "output_index": index,
        }


def parse_api_spec(workflow_data: dict) -> dict:
    """Build an API spec from a workflow by scanning specially titled nodes.

    Nodes titled ``INPUT_<Base>`` expose each widget input that is not driven by a
    link as an API input. Nodes titled ``OUTPUT_<Base>`` expose each output slot as an
    API output. Parameter names follow ``{base}_{attribute}[_{count}]``, lower-cased.
    The numeric suffix is added from the second occurrence of a name onwards.

    Args:
        workflow_data: Workflow dict with a ``"nodes"`` list; every node needs an ``"id"``.

    Returns:
        ``{"inputs": {name: {"node_id", "type", "widget_name"}},
           "outputs": {name: {"node_id", "class_type", "output_name", "output_index"}}}``
        where ``node_id`` is always the node id converted to ``str`` and ``type`` is one
        of ``"int"``, ``"float"``, ``"string"`` or ``"UploadFile"``.

    Raises:
        ValueError: if ``"nodes"`` is missing or is not a list.
    """
    spec = {"inputs": {}, "outputs": {}}

    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError("Invalid workflow format: 'nodes' key not found or is not a list.")

    # Keyed by string id; a repeated id keeps the last node (normal dict semantics).
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}

    # Separate counters so input names and output names are de-duplicated independently.
    input_name_counter = defaultdict(int)
    output_name_counter = defaultdict(int)

    for node_id, node in nodes_map.items():
        title = node.get("title")
        if not isinstance(title, str) or not title:
            continue

        if title.startswith(API_INPUT_PREFIX):
            base_name = title[len(API_INPUT_PREFIX):].lower()
            _collect_api_inputs(node_id, node, base_name, input_name_counter, spec["inputs"])
        elif title.startswith(API_OUTPUT_PREFIX):
            base_name = title[len(API_OUTPUT_PREFIX):].lower()
            _collect_api_outputs(node_id, node, base_name, output_name_counter, spec["outputs"])

    return spec
|
|
|
|
|
|
# --- Workflow patching and prompt-format conversion ---
|
|
|
|
def _coerce_widget_value(value, param_type):
    """Convert ``value`` to the Python type named by an api_spec input ``type``.

    ``"int"`` -> ``int()``, ``"float"`` -> ``float()``. Any other type, including
    ``"string"`` and ``"UploadFile"``, -> ``str()``. Errors from int()/float() propagate.
    """
    if param_type == "int":
        return int(value)
    if param_type == "float":
        return float(value)
    return str(value)


def patch_workflow(workflow_data: dict, api_spec: dict, request_data: dict) -> dict:
    """Write request values into the widgets of the nodes referenced by ``api_spec``.

    For every ``request_data`` key found in ``api_spec["inputs"]``, the value is
    coerced according to the spec's ``type`` and stored in the target node's
    ``widgets_values``. A dict is written by widget name; a list is written by
    position. The same coercion now applies to both layouts; previously dict-shaped
    ``widgets_values`` received the raw, unconverted value. A node with no
    ``widgets_values`` gets a new dict. Unknown parameters, and parameters whose
    node id is not in the workflow, are ignored.

    Nodes are modified in place and ``workflow_data["nodes"]`` is rebuilt from them,
    so the returned object is ``workflow_data`` itself (not a copy).

    Raises:
        ValueError: if ``workflow_data`` has no ``"nodes"`` key, or int()/float()
            rejects a value.
        TypeError: if int()/float() is given an unsupported value such as ``None``.
    """
    if "nodes" not in workflow_data:
        raise ValueError("Invalid workflow format")

    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    input_specs = api_spec["inputs"]

    for param_name, value in request_data.items():
        spec = input_specs.get(param_name)
        if spec is None:
            continue
        target_node = nodes_map.get(spec["node_id"])
        if target_node is None:
            continue

        widget_name = spec["widget_name"]
        widgets_values = target_node.get("widgets_values", {})

        if isinstance(widgets_values, dict):
            widgets_values[widget_name] = _coerce_widget_value(value, spec.get("type"))
        elif isinstance(widgets_values, list):
            # Position = index among the node's inputs that carry a "widget" entry.
            # NOTE(review): assumes widgets_values lists values in that same order;
            # confirm for nodes that store extra UI-only widget values.
            widget_inputs = (cfg for cfg in target_node.get("inputs", []) if "widget" in cfg)
            for position, cfg in enumerate(widget_inputs):
                if cfg["widget"].get("name") == widget_name:
                    if position < len(widgets_values):
                        widgets_values[position] = _coerce_widget_value(value, spec.get("type"))
                    break

        target_node["widgets_values"] = widgets_values

    workflow_data["nodes"] = list(nodes_map.values())
    return workflow_data
|
|
|
|
|
|
def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
    """Translate a workflow's nodes and links into a prompt dict keyed by node id.

    Every node becomes ``{"class_type": node["type"], "inputs": {...}}``, with node
    ids converted to ``str``. The inputs mapping is filled in two passes:

    1. Widget values. A dict ``widgets_values`` is copied, skipping values that are
       themselves dicts. A list ``widgets_values`` is matched by position to the
       node's inputs that carry a ``"widget"`` entry, keyed by the input's ``"name"``.
    2. Links. An input whose ``"link"`` id appears in ``workflow_data["links"]`` is
       set to ``[str(link[1]), link[2]]`` (presumably source node id and output
       slot), overriding any widget value of the same name.

    Raises:
        ValueError: if ``workflow_data`` has no ``"nodes"`` key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("Invalid workflow format")

    # link id -> [str(link[1]), link[2]]; a repeated link id keeps the last entry.
    link_sources = {
        link[0]: [str(link[1]), link[2]]
        for link in workflow_data.get("links", [])
    }

    prompt = {}
    for node in workflow_data["nodes"]:
        node_id = str(node["id"])
        node_inputs = node.get("inputs", [])
        values = node.get("widgets_values", [])
        resolved = {}

        # Pass 1: widget values.
        if isinstance(values, dict):
            resolved.update((key, val) for key, val in values.items() if not isinstance(val, dict))
        elif isinstance(values, list):
            widget_configs = [cfg for cfg in node_inputs if "widget" in cfg]
            # zip stops at the shorter sequence, so surplus widget inputs get no value.
            for cfg, widget_value in zip(widget_configs, values):
                resolved[cfg["name"]] = widget_value

        # Pass 2: linked inputs take precedence over widget values.
        for cfg in node_inputs:
            link_id = cfg.get("link")
            if link_id is not None and link_id in link_sources:
                resolved[cfg["name"]] = link_sources[link_id]

        prompt[node_id] = {"class_type": node["type"], "inputs": resolved}

    return prompt