feat: 添加解析输入参数的JSON Schema功能,增强工作流API的输入验证和文档生成能力
This commit is contained in:
parent
89101e3341
commit
0d5963fb56
|
|
@ -437,6 +437,132 @@ def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]:
|
||||||
return spec
|
return spec
|
||||||
|
|
||||||
|
|
||||||
|
def parse_inputs_json_schema(workflow_data: dict) -> dict:
|
||||||
|
"""
|
||||||
|
解析工作流,生成符合 JSON Schema 标准的 API 规范。
|
||||||
|
返回一个只包含输入参数的 JSON Schema 对象。
|
||||||
|
"""
|
||||||
|
# 基础 JSON Schema 结构,只包含输入
|
||||||
|
schema = {
|
||||||
|
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||||
|
"type": "object",
|
||||||
|
"title": "ComfyUI Workflow Input Schema",
|
||||||
|
"description": "工作流输入参数定义",
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid workflow format: 'nodes' key not found or is not a list."
|
||||||
|
)
|
||||||
|
|
||||||
|
nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
|
||||||
|
|
||||||
|
# 收集所有输入节点
|
||||||
|
input_nodes = {}
|
||||||
|
|
||||||
|
for node_id, node in nodes_map.items():
|
||||||
|
title: str = node.get("title")
|
||||||
|
if not title:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if title.startswith(API_INPUT_PREFIX):
|
||||||
|
# 提取节点名(去掉API_INPUT_PREFIX前缀)
|
||||||
|
node_name = title[len(API_INPUT_PREFIX) :].lower()
|
||||||
|
|
||||||
|
if node_name not in input_nodes:
|
||||||
|
input_nodes[node_name] = {
|
||||||
|
"node_id": node_id,
|
||||||
|
"class_type": node.get("type"),
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
if "inputs" in node:
|
||||||
|
for a_input in node.get("inputs", []):
|
||||||
|
if a_input.get("link") is None and "widget" in a_input:
|
||||||
|
widget_name = a_input["widget"]["name"].lower()
|
||||||
|
|
||||||
|
# 特殊处理:隐藏LoadImage节点的upload字段
|
||||||
|
if node.get("type") == "LoadImage" and widget_name == "upload":
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 生成字段的 JSON Schema 定义
|
||||||
|
field_schema = _generate_field_schema(a_input)
|
||||||
|
input_nodes[node_name]["properties"][widget_name] = field_schema
|
||||||
|
|
||||||
|
# 如果字段是必需的,添加到required列表
|
||||||
|
if a_input.get("required", True): # 默认所有字段都是必需的
|
||||||
|
input_nodes[node_name]["required"].append(widget_name)
|
||||||
|
|
||||||
|
# 将收集的节点信息转换为 JSON Schema 格式
|
||||||
|
for node_name, node_info in input_nodes.items():
|
||||||
|
schema["properties"][node_name] = {
|
||||||
|
"type": "object",
|
||||||
|
"title": f"输入节点: {node_name}",
|
||||||
|
"description": f"节点类型: {node_info['class_type']}",
|
||||||
|
"properties": node_info["properties"],
|
||||||
|
"required": node_info["required"] if node_info["required"] else [],
|
||||||
|
"additionalProperties": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果节点有必需字段,将整个节点标记为必需
|
||||||
|
if node_info["required"]:
|
||||||
|
schema["required"].append(node_name)
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_field_schema(input_field: dict) -> dict:
|
||||||
|
"""
|
||||||
|
根据输入字段信息生成 JSON Schema 字段定义
|
||||||
|
"""
|
||||||
|
field_schema = {
|
||||||
|
"title": input_field["widget"]["name"],
|
||||||
|
"description": input_field.get("description", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 根据字段类型设置 schema 类型
|
||||||
|
field_type = _get_param_type(input_field.get("type", "STRING"))
|
||||||
|
|
||||||
|
if field_type == "int":
|
||||||
|
field_schema.update(
|
||||||
|
{
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": input_field.get("min", None),
|
||||||
|
"maximum": input_field.get("max", None),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif field_type == "float":
|
||||||
|
field_schema.update(
|
||||||
|
{
|
||||||
|
"type": "number",
|
||||||
|
"minimum": input_field.get("min", None),
|
||||||
|
"maximum": input_field.get("max", None),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif field_type == "UploadFile":
|
||||||
|
field_schema.update(
|
||||||
|
{"type": "string", "format": "binary", "description": "上传文件"}
|
||||||
|
)
|
||||||
|
else: # string
|
||||||
|
field_schema.update({"type": "string"})
|
||||||
|
|
||||||
|
# 处理枚举值
|
||||||
|
if "options" in input_field:
|
||||||
|
field_schema["enum"] = input_field["options"]
|
||||||
|
field_schema[
|
||||||
|
"description"
|
||||||
|
] += f" 可选值: {', '.join(input_field['options'])}"
|
||||||
|
|
||||||
|
# 设置默认值
|
||||||
|
if "default" in input_field:
|
||||||
|
field_schema["default"] = input_field["default"]
|
||||||
|
|
||||||
|
return field_schema
|
||||||
|
|
||||||
|
|
||||||
def _get_param_type(input_type_str: str) -> str:
|
def _get_param_type(input_type_str: str) -> str:
|
||||||
"""
|
"""
|
||||||
根据输入类型字符串确定参数类型
|
根据输入类型字符串确定参数类型
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,7 @@ async def get_one_workflow_endpoint(base_name: str, version: Optional[str] = Non
|
||||||
return {
|
return {
|
||||||
"workflow": workflow,
|
"workflow": workflow,
|
||||||
"api_spec": comfyui_client.parse_api_spec(workflow),
|
"api_spec": comfyui_client.parse_api_spec(workflow),
|
||||||
|
"inputs_json_schema": comfyui_client.parse_inputs_json_schema(workflow),
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue