feat: 更新节点输入处理逻辑,支持获取默认值并优化JSON Schema生成

This commit is contained in:
iHeyTang 2025-08-20 13:52:28 +08:00
parent 7ab784eef6
commit 829daebc61
1 changed files with 57 additions and 29 deletions

View File

@ -158,8 +158,14 @@ def parse_inputs_json_schema(workflow_data: dict) -> dict:
if node.get("type") == "LoadImage" and widget_name == "upload":
continue
# 生成字段的 JSON Schema 定义
field_schema = _generate_field_schema(a_input)
# 获取节点的当前值作为默认值
# 创建空的link_map因为输入节点通常没有连接
empty_link_map = {}
node_current_inputs = _get_node_inputs(node, empty_link_map)
current_value = node_current_inputs.get(widget_name)
# 生成字段的 JSON Schema 定义,包含默认值
field_schema = _generate_field_schema(a_input, current_value)
input_nodes[node_name]["properties"][widget_name] = field_schema
# 如果字段是必需的添加到required列表
@ -306,31 +312,7 @@ def _convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
for node in workflow_data["nodes"]:
node_id = str(node["id"])
inputs_dict = {}
# 1. 处理控件输入 (widgets)
widgets_values = node.get("widgets_values", [])
widget_cursor = 0
for input_config in node.get("inputs", []):
# 如果是widget并且有对应的widgets_values
if "widget" in input_config and widget_cursor < len(widgets_values):
widget_name = input_config["widget"].get("name")
if widget_name:
# 使用widgets_values中的值因为这里已经包含了API传入的修改
inputs_dict[widget_name] = widgets_values[widget_cursor]
widget_cursor += 1
# 特殊处理:如果节点是 EG_CYQ_JB 则需要忽略widgets_values的第1个值这个是用来控制control net的它不参与赋值
if node["type"] == "EG_CYQ_JB" and widget_cursor == 1:
widget_cursor += 1
# 2. 处理连接输入 (links)
for input_config in node.get("inputs", []):
if "link" in input_config and input_config["link"] is not None:
link_id = input_config["link"]
if link_id in link_map:
# 输入名称是input_config中的'name'
inputs_dict[input_config["name"]] = link_map[link_id]
inputs_dict = _get_node_inputs(node, link_map)
prompt_api_format[node_id] = {
"inputs": inputs_dict,
"class_type": node["type"],
@ -379,9 +361,13 @@ async def _download_file(src: str) -> bytes:
return await response.read()
def _generate_field_schema(input_field: dict) -> dict:
def _generate_field_schema(input_field: dict, default_value: Any = None) -> dict:
"""
根据输入字段信息生成 JSON Schema 字段定义
Args:
input_field: 输入字段配置
default_value: 字段的默认值
"""
field_schema = {
"title": input_field["widget"]["name"],
@ -428,7 +414,9 @@ def _generate_field_schema(input_field: dict) -> dict:
] += f" 可选值: {', '.join(input_field['options'])}"
# 设置默认值
if "default" in input_field:
if default_value is not None:
field_schema["default"] = default_value
elif "default" in input_field:
field_schema["default"] = input_field["default"]
return field_schema
@ -447,3 +435,43 @@ def _get_param_type(input_type_str: str) -> str:
return "float"
else:
return "string"
def _get_node_inputs(node: dict, link_map: dict) -> dict:
"""
获取节点的输入字段值
Args:
node: 节点数据
link_map: 从link_id到源节点的映射
Returns:
dict: 包含节点输入字段值的字典
"""
inputs_dict = {}
# 1. 处理控件输入 (widgets)
widgets_values = node.get("widgets_values", [])
widget_cursor = 0
for input_config in node.get("inputs", []):
# 如果是widget并且有对应的widgets_values
if "widget" in input_config and widget_cursor < len(widgets_values):
widget_name = input_config["widget"].get("name")
if widget_name:
# 使用widgets_values中的值因为这里已经包含了API传入的修改
inputs_dict[widget_name] = widgets_values[widget_cursor]
widget_cursor += 1
# 特殊处理:如果节点是 EG_CYQ_JB 则需要忽略widgets_values的第1个值这个是用来控制control net的它不参与赋值
if node["type"] == "EG_CYQ_JB" and widget_cursor == 1:
widget_cursor += 1
# 2. 处理连接输入 (links)
for input_config in node.get("inputs", []):
if "link" in input_config and input_config["link"] is not None:
link_id = input_config["link"]
if link_id in link_map:
# 输入名称是input_config中的'name'
inputs_dict[input_config["name"]] = link_map[link_id]
return inputs_dict