From 829daebc617658d4a11bcc1141c0c862a5fb753b Mon Sep 17 00:00:00 2001 From: iHeyTang Date: Wed, 20 Aug 2025 13:52:28 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E8=8A=82=E7=82=B9?= =?UTF-8?q?=E8=BE=93=E5=85=A5=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=8E=B7=E5=8F=96=E9=BB=98=E8=AE=A4=E5=80=BC?= =?UTF-8?q?=E5=B9=B6=E4=BC=98=E5=8C=96JSON=20Schema=E7=94=9F=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- workflow_service/comfy/comfy_workflow.py | 86 ++++++++++++++++-------- 1 file changed, 57 insertions(+), 29 deletions(-) diff --git a/workflow_service/comfy/comfy_workflow.py b/workflow_service/comfy/comfy_workflow.py index fe3ffeb..3cf78a4 100644 --- a/workflow_service/comfy/comfy_workflow.py +++ b/workflow_service/comfy/comfy_workflow.py @@ -158,8 +158,14 @@ def parse_inputs_json_schema(workflow_data: dict) -> dict: if node.get("type") == "LoadImage" and widget_name == "upload": continue - # 生成字段的 JSON Schema 定义 - field_schema = _generate_field_schema(a_input) + # 获取节点的当前值作为默认值 + # 创建空的link_map,因为输入节点通常没有连接 + empty_link_map = {} + node_current_inputs = _get_node_inputs(node, empty_link_map) + current_value = node_current_inputs.get(widget_name) + + # 生成字段的 JSON Schema 定义,包含默认值 + field_schema = _generate_field_schema(a_input, current_value) input_nodes[node_name]["properties"][widget_name] = field_schema # 如果字段是必需的,添加到required列表 @@ -306,31 +312,7 @@ def _convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict: for node in workflow_data["nodes"]: node_id = str(node["id"]) - inputs_dict = {} - # 1. 处理控件输入 (widgets) - widgets_values = node.get("widgets_values", []) - widget_cursor = 0 - for input_config in node.get("inputs", []): - # 如果是widget并且有对应的widgets_values - if "widget" in input_config and widget_cursor < len(widgets_values): - widget_name = input_config["widget"].get("name") - if widget_name: - # 使用widgets_values中的值,因为这里已经包含了API传入的修改 - inputs_dict[widget_name] = widgets_values[widget_cursor] - widget_cursor += 1 - - # 特殊处理:如果节点是 EG_CYQ_JB ,则需要忽略widgets_values的第1个值,这个是用来控制control net的,它不参与赋值 - if node["type"] == "EG_CYQ_JB" and widget_cursor == 1: - widget_cursor += 1 - - # 2. 处理连接输入 (links) - for input_config in node.get("inputs", []): - if "link" in input_config and input_config["link"] is not None: - link_id = input_config["link"] - if link_id in link_map: - # 输入名称是input_config中的'name' - inputs_dict[input_config["name"]] = link_map[link_id] - + inputs_dict = _get_node_inputs(node, link_map) prompt_api_format[node_id] = { "inputs": inputs_dict, "class_type": node["type"], @@ -379,9 +361,13 @@ async def _download_file(src: str) -> bytes: return await response.read() -def _generate_field_schema(input_field: dict) -> dict: +def _generate_field_schema(input_field: dict, default_value: Any = None) -> dict: """ 根据输入字段信息生成 JSON Schema 字段定义 + + Args: + input_field: 输入字段配置 + default_value: 字段的默认值 """ field_schema = { "title": input_field["widget"]["name"], @@ -428,7 +414,9 @@ def _generate_field_schema(input_field: dict) -> dict: ] += f" 可选值: {', '.join(input_field['options'])}" # 设置默认值 - if "default" in input_field: + if default_value is not None: + field_schema["default"] = default_value + elif "default" in input_field: field_schema["default"] = input_field["default"] return field_schema @@ -447,3 +435,43 @@ def _get_param_type(input_type_str: str) -> str: return "float" else: return "string" + + +def _get_node_inputs(node: dict, link_map: dict) -> dict: + """ + 获取节点的输入字段值 + + Args: + node: 节点数据 + link_map: 从link_id到源节点的映射 + + Returns: + dict: 包含节点输入字段值的字典 + """ + inputs_dict = {} + + # 1. 处理控件输入 (widgets) + widgets_values = node.get("widgets_values", []) + widget_cursor = 0 + for input_config in node.get("inputs", []): + # 如果是widget并且有对应的widgets_values + if "widget" in input_config and widget_cursor < len(widgets_values): + widget_name = input_config["widget"].get("name") + if widget_name: + # 使用widgets_values中的值,因为这里已经包含了API传入的修改 + inputs_dict[widget_name] = widgets_values[widget_cursor] + widget_cursor += 1 + + # 特殊处理:如果节点是 EG_CYQ_JB ,则需要忽略widgets_values的第1个值,这个是用来控制control net的,它不参与赋值 + if node["type"] == "EG_CYQ_JB" and widget_cursor == 1: + widget_cursor += 1 + + # 2. 处理连接输入 (links) + for input_config in node.get("inputs", []): + if "link" in input_config and input_config["link"] is not None: + link_id = input_config["link"] + if link_id in link_map: + # 输入名称是input_config中的'name' + inputs_dict[input_config["name"]] = link_map[link_id] + + return inputs_dict