import re
from collections import defaultdict

API_INPUT_PREFIX = "INPUT_"
API_OUTPUT_PREFIX = "OUTPUT_"

# Maps an API parameter type (as produced by parse_api_spec) to the callable
# patch_workflow uses to coerce request values; any other type is coerced to str.
_TYPE_CONVERTERS = {"int": int, "float": float}


def _widget_name(input_config: dict):
    """Return the widget name of a node input entry, or None if it has none.

    Tolerates a missing, ``None`` or non-dict ``"widget"`` value instead of crashing.
    """
    widget = input_config.get("widget")
    if isinstance(widget, dict):
        return widget.get("name")
    return None


def _unique_name(counter: defaultdict, candidate: str) -> str:
    """Return *candidate* the first time it is seen, then ``candidate_2``, ``candidate_3``, ..."""
    counter[candidate] += 1
    count = counter[candidate]
    return f"{candidate}_{count}" if count > 1 else candidate


def _api_param_type(input_type) -> str:
    """Map a node input's type string to the API parameter type stored in the spec.

    Substring matching: anything containing "COMBO" -> "UploadFile", "INT" -> "int",
    "FLOAT" -> "float"; everything else (including a missing/None type) -> "string".
    """
    type_str = str(input_type or "STRING").upper()
    if "COMBO" in type_str:
        return "UploadFile"
    if "INT" in type_str:
        return "int"
    if "FLOAT" in type_str:
        return "float"
    return "string"


def _link_source(link_data):
    """Return ``(link_id, [origin_node_id_str, origin_slot])`` for one workflow link.

    Accepts the positional list form ``[id, origin_id, origin_slot, target_id, target_slot, type]``
    and (assumed, TODO confirm against the workflows you ingest) the object form with
    ``id`` / ``origin_id`` / ``origin_slot`` keys.
    """
    if isinstance(link_data, dict):
        return link_data["id"], [str(link_data["origin_id"]), link_data["origin_slot"]]
    link_id, origin_node_id, origin_slot_index = link_data[:3]
    return link_id, [str(origin_node_id), origin_slot_index]


def parse_api_spec(workflow_data: dict) -> dict:
    """Parse a workflow and derive API parameter names as '{base}_{attribute}_{optional count}'.

    Nodes whose title starts with ``API_INPUT_PREFIX`` expose each unlinked,
    widget-backed input as an API input; nodes whose title starts with
    ``API_OUTPUT_PREFIX`` expose each of their outputs as an API output. The base
    name is the lower-cased title with the prefix removed; repeated names get a
    ``_2``, ``_3``, ... suffix in node-list order.

    Args:
        workflow_data: Workflow JSON with a ``"nodes"`` list.

    Returns:
        ``{"inputs": {name: {"node_id", "type", "widget_name"}},
        "outputs": {name: {"node_id", "class_type", "output_name", "output_index"}}}``.

    Raises:
        ValueError: if ``"nodes"`` is missing or is not a list.
    """
    spec = {"inputs": {}, "outputs": {}}
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError("Invalid workflow format: 'nodes' key not found or is not a list.")

    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    input_name_counter = defaultdict(int)
    output_name_counter = defaultdict(int)

    for node_id, node in nodes_map.items():
        title = node.get("title")
        if not title:
            continue

        if title.startswith(API_INPUT_PREFIX):
            base_name = title[len(API_INPUT_PREFIX):].lower()
            for a_input in node.get("inputs") or []:
                widget_name = _widget_name(a_input)
                # Only widgets that are not fed by a link can be set through the API.
                if a_input.get("link") is not None or not widget_name:
                    continue
                final_param_name = _unique_name(
                    input_name_counter, f"{base_name}_{widget_name.lower()}"
                )
                spec["inputs"][final_param_name] = {
                    "node_id": node_id,
                    "type": _api_param_type(a_input.get("type")),
                    "widget_name": widget_name,
                }

        elif title.startswith(API_OUTPUT_PREFIX):
            base_name = title[len(API_OUTPUT_PREFIX):].lower()
            # enumerate gives the true slot even when two output entries compare equal
            # (list.index would return the first equal entry's slot for both).
            for output_index, an_output in enumerate(node.get("outputs") or []):
                output_name = an_output["name"]
                final_param_name = _unique_name(
                    output_name_counter, f"{base_name}_{output_name.lower()}"
                )
                spec["outputs"][final_param_name] = {
                    "node_id": node_id,
                    "class_type": node.get("type"),
                    "output_name": output_name,
                    "output_index": output_index,
                }

    return spec


def patch_workflow(workflow_data: dict, api_spec: dict, request_data: dict) -> dict:
    """Write API request values into the ``widgets_values`` of the workflow's nodes.

    The nodes of *workflow_data* are mutated in place and *workflow_data* itself is
    returned; pass a copy if the same template workflow is reused across requests.
    Request keys not present in ``api_spec["inputs"]``, or pointing at a node id not
    in the workflow, are ignored. Unknown widgets and values that fail type
    conversion are reported with a printed warning and skipped, leaving the node
    unchanged.

    A list-form ``widgets_values`` is indexed by the target widget's position among
    the node's widget-backed inputs (the same mapping
    ``convert_workflow_to_prompt_api_format`` reads back). A dict-form
    ``widgets_values`` is updated by widget name.
    NOTE(review): if a node stores extra, non-input entries in a list-form
    ``widgets_values``, this positional mapping will be off — confirm against the
    workflows in use.

    Raises:
        ValueError: if *workflow_data* has no ``"nodes"`` key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    input_specs = api_spec.get("inputs") or {}

    for param_name, value in request_data.items():
        if param_name not in input_specs:
            continue
        spec = input_specs[param_name]
        node_id = spec["node_id"]
        target_node = nodes_map.get(node_id)
        if target_node is None:
            continue
        widget_name_to_patch = spec["widget_name"]

        widget_inputs = [inp for inp in target_node.get("inputs") or [] if "widget" in inp]
        widget_names = [_widget_name(inp) for inp in widget_inputs]
        if widget_name_to_patch not in widget_names:
            print(f"警告: 在节点 {node_id} 中未找到名为 '{widget_name_to_patch}' 的 widget。跳过此参数。")
            continue
        target_widget_index = widget_names.index(widget_name_to_patch)

        # Convert before touching the node so a bad value leaves it unmodified.
        converter = _TYPE_CONVERTERS.get(spec["type"], str)
        try:
            new_value = converter(value)
        except (ValueError, TypeError) as e:
            print(f"警告: 无法将参数 '{param_name}' 的值 '{value}' 转换为类型 '{spec['type']}'。错误: {e}")
            continue

        widgets_values = target_node.get("widgets_values")
        if isinstance(widgets_values, dict):
            # Name-keyed storage: update in place instead of replacing it with a list.
            widgets_values[widget_name_to_patch] = new_value
            continue
        if not isinstance(widgets_values, list):
            # Missing or malformed: start with one placeholder per widget input.
            widgets_values = [None] * len(widget_inputs)
            target_node["widgets_values"] = widgets_values
        if len(widgets_values) <= target_widget_index:
            widgets_values.extend([None] * (target_widget_index + 1 - len(widgets_values)))
        widgets_values[target_widget_index] = new_value

    # Nodes were mutated in place, so the node list is left untouched (rebuilding it
    # from nodes_map would silently drop nodes sharing an id).
    return workflow_data


def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
    """Convert a workflow into the ``{node_id: {"class_type", "inputs"}}`` shape posted to /prompt.

    Widget inputs take their value from ``widgets_values`` (positionally for a list,
    by widget name for a dict), so values written by ``patch_workflow`` are picked
    up. Linked inputs become ``[origin_node_id, origin_slot]`` and override a widget
    value of the same input name.
    NOTE(review): every node is emitted as-is, including any frontend-only or
    muted/bypassed nodes the workflow may contain — confirm the target server
    accepts them.

    Raises:
        ValueError: if *workflow_data* has no ``"nodes"`` key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    # link_id -> [origin node id (str), origin output slot]
    link_map = dict(_link_source(link) for link in workflow_data.get("links") or [])

    prompt_api_format = {}
    for node in workflow_data["nodes"]:
        node_inputs = node.get("inputs") or []
        widgets_values = node.get("widgets_values") or []
        widget_inputs = [inp for inp in node_inputs if "widget" in inp]
        inputs_dict = {}

        # 1. Widget inputs.
        if isinstance(widgets_values, dict):
            for input_config in widget_inputs:
                name = _widget_name(input_config)
                if name and name in widgets_values:
                    inputs_dict[name] = widgets_values[name]
        else:
            # The i-th widget input takes widgets_values[i]; widget inputs beyond the
            # end of the list get no value.
            for input_config, widget_value in zip(widget_inputs, widgets_values):
                name = _widget_name(input_config)
                if name:
                    inputs_dict[name] = widget_value

        # 2. Linked inputs (processed second so they win over widget values).
        for input_config in node_inputs:
            link_id = input_config.get("link")
            if link_id is not None and link_id in link_map:
                inputs_dict[input_config["name"]] = link_map[link_id]

        prompt_api_format[str(node["id"])] = {
            "class_type": node["type"],
            "inputs": inputs_dict,
        }
    return prompt_api_format