185 lines
7.5 KiB
Python
185 lines
7.5 KiB
Python
import re
|
||
from collections import defaultdict
|
||
|
||
# Node-title prefix (case-sensitive, matched with str.startswith in parse_api_spec):
# a node titled e.g. "INPUT_Sampler" exposes its unlinked widget inputs as API
# input parameters named from the remainder of the title ("sampler_...").
API_INPUT_PREFIX = "INPUT_"
# Node-title prefix (case-sensitive, matched with str.startswith in parse_api_spec):
# a node titled e.g. "OUTPUT_Image" exposes each entry of its "outputs" list as an
# API output parameter.
API_OUTPUT_PREFIX = "OUTPUT_"
|
||
|
||
|
||
def _unique_param_name(candidate: str, counter: defaultdict, taken: dict) -> str:
    """Return a name derived from ``candidate`` that is not already a key of ``taken``.

    The first request for ``candidate`` returns it unchanged; later requests append
    ``_{n}`` where ``n`` counts the requests so far (``_2``, ``_3``, ...).
    ``counter`` is updated in place.
    """
    counter[candidate] += 1
    name = f"{candidate}_{counter[candidate]}" if counter[candidate] > 1 else candidate
    # A suffixed name can coincide with a literal candidate (e.g. a widget really
    # called "seed_2"); keep counting instead of overwriting the entry owning it.
    while name in taken:
        counter[candidate] += 1
        name = f"{candidate}_{counter[candidate]}"
    return name


def _api_param_type(widget_type) -> str:
    """Map a widget input's ``type`` field to an API parameter type name.

    Returns ``"UploadFile"``, ``"int"``, ``"float"`` or ``"string"`` (the default,
    also used when ``type`` is missing or empty).
    """
    type_str = str(widget_type or "STRING").upper()
    # Substring tests, checked in this order.
    # NOTE(review): "INT" also matches any type name that merely contains it
    # (e.g. "POINT") — confirm this is intended.
    if "COMBO" in type_str:
        return "UploadFile"
    if "INT" in type_str:
        return "int"
    if "FLOAT" in type_str:
        return "float"
    return "string"


def parse_api_spec(workflow_data: dict) -> dict:
    """Derive API parameter names from specially titled workflow nodes.

    Parameter names follow ``{base}_{attr}[_{count}]``:

    - ``base`` is the node title with its prefix removed, lower-cased.
    - ``attr`` is the widget name (inputs) or output name (outputs), lower-cased.
    - ``_{count}`` is appended from the second use of the same candidate onward
      (``_2``, ``_3``, ...), and bumped further if that name is already taken.

    Which nodes and fields produce parameters:

    - A node whose title starts with ``API_INPUT_PREFIX`` contributes one input
      parameter per entry of ``node["inputs"]`` that has a named ``"widget"`` and
      no ``"link"``.
    - A node whose title starts with ``API_OUTPUT_PREFIX`` contributes one output
      parameter per named entry of ``node["outputs"]``.

    Args:
        workflow_data: Workflow dict whose ``"nodes"`` is a list of node dicts,
            each with an ``"id"``.

    Returns:
        ``{"inputs": {name: {"node_id", "type", "widget_name"}},
        "outputs": {name: {"node_id", "class_type", "output_name", "output_index"}}}``
        where ``node_id`` is the node id as a string, ``type`` is one of
        ``"string"``, ``"int"``, ``"float"`` or ``"UploadFile"``, ``widget_name``
        and ``output_name`` keep their original case, and ``output_index`` is the
        position in ``node["outputs"]``.

    Raises:
        ValueError: If ``"nodes"`` is missing or is not a list.
    """
    spec = {"inputs": {}, "outputs": {}}
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError("Invalid workflow format: 'nodes' key not found or is not a list.")

    # Nodes sharing an id collapse to the last one, as in patch_workflow's lookup.
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    input_name_counter = defaultdict(int)
    output_name_counter = defaultdict(int)

    for node_id, node in nodes_map.items():
        title = node.get("title")
        if not title:
            continue

        if title.startswith(API_INPUT_PREFIX):
            base_name = title[len(API_INPUT_PREFIX):].lower()
            for a_input in node.get("inputs") or []:
                # Only widget inputs that are not fed by a link are settable via the API.
                if a_input.get("link") is not None or "widget" not in a_input:
                    continue
                widget = a_input["widget"]
                widget_name = widget.get("name") if isinstance(widget, dict) else None
                if not widget_name:
                    # patch_workflow locates widgets by name, so a nameless one is unusable.
                    continue
                param_name = _unique_param_name(
                    f"{base_name}_{widget_name.lower()}", input_name_counter, spec["inputs"]
                )
                spec["inputs"][param_name] = {
                    "node_id": node_id,
                    "type": _api_param_type(a_input.get("type")),
                    "widget_name": widget_name,
                }

        elif title.startswith(API_OUTPUT_PREFIX):
            base_name = title[len(API_OUTPUT_PREFIX):].lower()
            # enumerate yields the true slot index; list.index() would be O(n) per
            # output and return the first *equal* dict for identical entries.
            for output_index, an_output in enumerate(node.get("outputs") or []):
                output_name = an_output.get("name")
                if not output_name:
                    continue
                param_name = _unique_param_name(
                    f"{base_name}_{output_name.lower()}", output_name_counter, spec["outputs"]
                )
                spec["outputs"][param_name] = {
                    "node_id": node_id,
                    "class_type": node.get("type"),
                    "output_name": output_name,
                    "output_index": output_index,
                }

    return spec
|
||
|
||
|
||
def patch_workflow(workflow_data: dict, api_spec: dict, request_data: dict) -> dict:
    """Write API request values into the target nodes' ``widgets_values``.

    For each ``(param_name, value)`` in ``request_data`` that is declared in
    ``api_spec["inputs"]``:

    1. The target node is found by the spec's ``node_id``.
    2. The widget's slot is its position among that node's inputs that carry a
       ``"widget"`` key, matched on the spec's ``widget_name``.
    3. ``value`` is converted with ``int``, ``float`` or ``str`` according to the
       spec's ``type``.
    4. The converted value is stored at that slot of ``widgets_values``, which is
       created or padded with ``None`` as needed.

    Unknown parameters and unknown node ids are skipped silently. A missing widget
    or a failed conversion prints a warning and skips that parameter without
    modifying the node.

    Nodes are modified in place and ``workflow_data`` itself is returned; the
    ``"nodes"`` list is not rebuilt, so every node, including ones sharing an id,
    is kept.

    NOTE(review): the positional mapping assumes ``widgets_values`` holds exactly
    one entry per widget input, in the same order. If the saved workflow stores
    extra entries, or widgets that do not appear in ``inputs``, the index will
    point at the wrong value — confirm against the frontend's serialization.

    Args:
        workflow_data: Workflow dict with a ``"nodes"`` list.
        api_spec: Spec as produced by ``parse_api_spec``; only ``"inputs"`` is read.
        request_data: Mapping of API parameter name to raw value.

    Returns:
        ``workflow_data``, with the matching nodes patched.

    Raises:
        ValueError: If ``"nodes"`` is missing.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    # Spec type -> converter; any other type (e.g. "string", "UploadFile") uses str.
    converters = {"int": int, "float": float}
    # Lookup only: with duplicate ids the last node wins here.
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}

    for param_name, value in request_data.items():
        if param_name not in api_spec["inputs"]:
            continue
        spec = api_spec["inputs"][param_name]
        node_id = spec["node_id"]
        if node_id not in nodes_map:
            continue

        target_node = nodes_map[node_id]
        widget_name_to_patch = spec["widget_name"]

        # The widget's slot is its position among the node's widget-backed inputs.
        widget_inputs = [inp for inp in (target_node.get("inputs") or []) if "widget" in inp]
        target_widget_index = next(
            (
                i
                for i, widget_input in enumerate(widget_inputs)
                if (widget_input.get("widget") or {}).get("name") == widget_name_to_patch
            ),
            -1,
        )
        if target_widget_index == -1:
            print(f"警告: 在节点 {node_id} 中未找到名为 '{widget_name_to_patch}' 的 widget。跳过此参数。")
            continue

        # Convert before touching the node so a bad value leaves it unmodified.
        target_type = converters.get(spec["type"], str)
        try:
            new_value = target_type(value)
        except (ValueError, TypeError, OverflowError) as e:
            print(f"警告: 无法将参数 '{param_name}' 的值 '{value}' 转换为类型 '{spec['type']}'。错误: {e}")
            continue

        widgets_values = target_node.get("widgets_values")
        if not isinstance(widgets_values, list):
            # Missing or not a list: start from one None placeholder per widget input.
            widgets_values = [None] * len(widget_inputs)
            target_node["widgets_values"] = widgets_values
        # Pad in case the stored list is shorter than the widget's slot.
        if len(widgets_values) <= target_widget_index:
            widgets_values.extend([None] * (target_widget_index + 1 - len(widgets_values)))
        widgets_values[target_widget_index] = new_value

    return workflow_data
|
||
|
||
|
||
def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
    """Convert a node/link workflow graph into the dict submitted to ``/prompt``.

    Each node becomes ``{"class_type": node["type"], "inputs": {...}}``, keyed by
    the node id as a string. Inputs are filled in two passes.

    1. Widget inputs are the entries of ``node["inputs"]`` that carry a
       ``"widget"`` key. Each takes its value from ``widgets_values``:
       positionally when it is a list (the n-th widget input reads
       ``widgets_values[n]``, the same slot numbering ``patch_workflow`` writes
       to), or by widget name when it is a dict.
    2. Linked inputs become ``[origin_node_id_str, origin_slot]`` and override a
       widget value of the same name.

    Links may be serialized either as arrays
    ``[id, origin_id, origin_slot, target_id, target_slot, type]`` or as objects
    with ``"id"``, ``"origin_id"`` and ``"origin_slot"`` keys.

    NOTE(review): every node is emitted as-is, whatever its ``type`` or ``mode``;
    widgets that do not appear in ``node["inputs"]`` are not emitted. Confirm the
    backend accepts the workflows this is used with.

    Args:
        workflow_data: Workflow dict with a ``"nodes"`` list and an optional
            ``"links"`` list.

    Returns:
        Mapping of node id (str) to ``{"class_type", "inputs"}``.

    Raises:
        ValueError: If ``"nodes"`` is missing.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    # link id -> [origin node id as str, origin output slot]
    link_map = {}
    for link_data in workflow_data.get("links") or []:
        if isinstance(link_data, dict):
            link_id = link_data["id"]
            origin_node_id = link_data["origin_id"]
            origin_slot_index = link_data["origin_slot"]
        else:
            link_id, origin_node_id, origin_slot_index = link_data[0], link_data[1], link_data[2]
        link_map[link_id] = [str(origin_node_id), origin_slot_index]

    prompt_api_format = {}
    for node in workflow_data["nodes"]:
        node_id = str(node["id"])
        node_inputs = node.get("inputs") or []
        widgets_values = node.get("widgets_values") or []
        inputs_dict = {}

        # 1. Widget inputs: values come from widgets_values, which already holds
        #    any edits made by patch_workflow.
        widget_cursor = 0
        for input_config in node_inputs:
            if "widget" not in input_config:
                continue
            widget = input_config["widget"]
            widget_name = widget.get("name") if isinstance(widget, dict) else None
            if isinstance(widgets_values, dict):
                if widget_name and widget_name in widgets_values:
                    inputs_dict[widget_name] = widgets_values[widget_name]
            elif widget_name and widget_cursor < len(widgets_values):
                inputs_dict[widget_name] = widgets_values[widget_cursor]
            # Every widget input owns one slot, named or not.
            widget_cursor += 1

        # 2. Linked inputs (override widget values of the same name).
        for input_config in node_inputs:
            link_id = input_config.get("link")
            if link_id is not None and link_id in link_map:
                inputs_dict[input_config["name"]] = link_map[link_id]

        prompt_api_format[node_id] = {
            "class_type": node["type"],
            "inputs": inputs_dict,
        }

    return prompt_api_format