feat: 更新节点输入处理逻辑,支持获取默认值并优化JSON Schema生成
This commit is contained in:
parent
7ab784eef6
commit
829daebc61
|
|
@ -158,8 +158,14 @@ def parse_inputs_json_schema(workflow_data: dict) -> dict:
|
|||
if node.get("type") == "LoadImage" and widget_name == "upload":
|
||||
continue
|
||||
|
||||
# 生成字段的 JSON Schema 定义
|
||||
field_schema = _generate_field_schema(a_input)
|
||||
# 获取节点的当前值作为默认值
|
||||
# 创建空的link_map,因为输入节点通常没有连接
|
||||
empty_link_map = {}
|
||||
node_current_inputs = _get_node_inputs(node, empty_link_map)
|
||||
current_value = node_current_inputs.get(widget_name)
|
||||
|
||||
# 生成字段的 JSON Schema 定义,包含默认值
|
||||
field_schema = _generate_field_schema(a_input, current_value)
|
||||
input_nodes[node_name]["properties"][widget_name] = field_schema
|
||||
|
||||
# 如果字段是必需的,添加到required列表
|
||||
|
|
@ -306,31 +312,7 @@ def _convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
|
|||
|
||||
for node in workflow_data["nodes"]:
|
||||
node_id = str(node["id"])
|
||||
inputs_dict = {}
|
||||
# 1. 处理控件输入 (widgets)
|
||||
widgets_values = node.get("widgets_values", [])
|
||||
widget_cursor = 0
|
||||
for input_config in node.get("inputs", []):
|
||||
# 如果是widget并且有对应的widgets_values
|
||||
if "widget" in input_config and widget_cursor < len(widgets_values):
|
||||
widget_name = input_config["widget"].get("name")
|
||||
if widget_name:
|
||||
# 使用widgets_values中的值,因为这里已经包含了API传入的修改
|
||||
inputs_dict[widget_name] = widgets_values[widget_cursor]
|
||||
widget_cursor += 1
|
||||
|
||||
# 特殊处理:如果节点是 EG_CYQ_JB ,则需要忽略widgets_values的第1个值,这个是用来控制control net的,它不参与赋值
|
||||
if node["type"] == "EG_CYQ_JB" and widget_cursor == 1:
|
||||
widget_cursor += 1
|
||||
|
||||
# 2. 处理连接输入 (links)
|
||||
for input_config in node.get("inputs", []):
|
||||
if "link" in input_config and input_config["link"] is not None:
|
||||
link_id = input_config["link"]
|
||||
if link_id in link_map:
|
||||
# 输入名称是input_config中的'name'
|
||||
inputs_dict[input_config["name"]] = link_map[link_id]
|
||||
|
||||
inputs_dict = _get_node_inputs(node, link_map)
|
||||
prompt_api_format[node_id] = {
|
||||
"inputs": inputs_dict,
|
||||
"class_type": node["type"],
|
||||
|
|
@ -379,9 +361,13 @@ async def _download_file(src: str) -> bytes:
|
|||
return await response.read()
|
||||
|
||||
|
||||
def _generate_field_schema(input_field: dict) -> dict:
|
||||
def _generate_field_schema(input_field: dict, default_value: Any = None) -> dict:
|
||||
"""
|
||||
根据输入字段信息生成 JSON Schema 字段定义
|
||||
|
||||
Args:
|
||||
input_field: 输入字段配置
|
||||
default_value: 字段的默认值
|
||||
"""
|
||||
field_schema = {
|
||||
"title": input_field["widget"]["name"],
|
||||
|
|
@ -428,7 +414,9 @@ def _generate_field_schema(input_field: dict) -> dict:
|
|||
] += f" 可选值: {', '.join(input_field['options'])}"
|
||||
|
||||
# 设置默认值
|
||||
if "default" in input_field:
|
||||
if default_value is not None:
|
||||
field_schema["default"] = default_value
|
||||
elif "default" in input_field:
|
||||
field_schema["default"] = input_field["default"]
|
||||
|
||||
return field_schema
|
||||
|
|
@ -447,3 +435,43 @@ def _get_param_type(input_type_str: str) -> str:
|
|||
return "float"
|
||||
else:
|
||||
return "string"
|
||||
|
||||
|
||||
def _get_node_inputs(node: dict, link_map: dict) -> dict:
|
||||
"""
|
||||
获取节点的输入字段值
|
||||
|
||||
Args:
|
||||
node: 节点数据
|
||||
link_map: 从link_id到源节点的映射
|
||||
|
||||
Returns:
|
||||
dict: 包含节点输入字段值的字典
|
||||
"""
|
||||
inputs_dict = {}
|
||||
|
||||
# 1. 处理控件输入 (widgets)
|
||||
widgets_values = node.get("widgets_values", [])
|
||||
widget_cursor = 0
|
||||
for input_config in node.get("inputs", []):
|
||||
# 如果是widget并且有对应的widgets_values
|
||||
if "widget" in input_config and widget_cursor < len(widgets_values):
|
||||
widget_name = input_config["widget"].get("name")
|
||||
if widget_name:
|
||||
# 使用widgets_values中的值,因为这里已经包含了API传入的修改
|
||||
inputs_dict[widget_name] = widgets_values[widget_cursor]
|
||||
widget_cursor += 1
|
||||
|
||||
# 特殊处理:如果节点是 EG_CYQ_JB ,则需要忽略widgets_values的第1个值,这个是用来控制control net的,它不参与赋值
|
||||
if node["type"] == "EG_CYQ_JB" and widget_cursor == 1:
|
||||
widget_cursor += 1
|
||||
|
||||
# 2. 处理连接输入 (links)
|
||||
for input_config in node.get("inputs", []):
|
||||
if "link" in input_config and input_config["link"] is not None:
|
||||
link_id = input_config["link"]
|
||||
if link_id in link_map:
|
||||
# 输入名称是input_config中的'name'
|
||||
inputs_dict[input_config["name"]] = link_map[link_id]
|
||||
|
||||
return inputs_dict
|
||||
|
|
|
|||
Loading…
Reference in New Issue