diff --git a/workflow_service/comfyui_client.py b/workflow_service/comfyui_client.py index 823eeb2..12a03c5 100644 --- a/workflow_service/comfyui_client.py +++ b/workflow_service/comfyui_client.py @@ -437,6 +437,132 @@ def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]: return spec +def parse_inputs_json_schema(workflow_data: dict) -> dict: + """ + 解析工作流,生成符合 JSON Schema 标准的 API 规范。 + 返回一个只包含输入参数的 JSON Schema 对象。 + """ + # 基础 JSON Schema 结构,只包含输入 + schema = { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "title": "ComfyUI Workflow Input Schema", + "description": "工作流输入参数定义", + "properties": {}, + "required": [], + } + + if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list): + raise ValueError( + "Invalid workflow format: 'nodes' key not found or is not a list." + ) + + nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]} + + # 收集所有输入节点 + input_nodes = {} + + for node_id, node in nodes_map.items(): + title: str = node.get("title") + if not title: + continue + + if title.startswith(API_INPUT_PREFIX): + # 提取节点名(去掉API_INPUT_PREFIX前缀) + node_name = title[len(API_INPUT_PREFIX) :].lower() + + if node_name not in input_nodes: + input_nodes[node_name] = { + "node_id": node_id, + "class_type": node.get("type"), + "properties": {}, + "required": [], + } + + if "inputs" in node: + for a_input in node.get("inputs", []): + if a_input.get("link") is None and "widget" in a_input: + widget_name = a_input["widget"]["name"].lower() + + # 特殊处理:隐藏LoadImage节点的upload字段 + if node.get("type") == "LoadImage" and widget_name == "upload": + continue + + # 生成字段的 JSON Schema 定义 + field_schema = _generate_field_schema(a_input) + input_nodes[node_name]["properties"][widget_name] = field_schema + + # 如果字段是必需的,添加到required列表 + if a_input.get("required", True): # 默认所有字段都是必需的 + input_nodes[node_name]["required"].append(widget_name) + + # 将收集的节点信息转换为 JSON Schema 格式 + for node_name, node_info in input_nodes.items(): + schema["properties"][node_name] = { + "type": "object", + "title": f"输入节点: {node_name}", + "description": f"节点类型: {node_info['class_type']}", + "properties": node_info["properties"], + "required": node_info["required"] if node_info["required"] else [], + "additionalProperties": False, + } + + # 如果节点有必需字段,将整个节点标记为必需 + if node_info["required"]: + schema["required"].append(node_name) + + return schema + + +def _generate_field_schema(input_field: dict) -> dict: + """ + 根据输入字段信息生成 JSON Schema 字段定义 + """ + field_schema = { + "title": input_field["widget"]["name"], + "description": input_field.get("description", ""), + } + + # 根据字段类型设置 schema 类型 + field_type = _get_param_type(input_field.get("type", "STRING")) + + if field_type == "int": + field_schema.update( + { + "type": "integer", + "minimum": input_field.get("min", None), + "maximum": input_field.get("max", None), + } + ) + elif field_type == "float": + field_schema.update( + { + "type": "number", + "minimum": input_field.get("min", None), + "maximum": input_field.get("max", None), + } + ) + elif field_type == "UploadFile": + field_schema.update( + {"type": "string", "format": "binary", "description": "上传文件"} + ) + else: # string + field_schema.update({"type": "string"}) + + # 处理枚举值 + if "options" in input_field: + field_schema["enum"] = input_field["options"] + field_schema[ + "description" + ] += f" 可选值: {', '.join(input_field['options'])}" + + # 设置默认值 + if "default" in input_field: + field_schema["default"] = input_field["default"] + + return field_schema + + def _get_param_type(input_type_str: str) -> str: """ 根据输入类型字符串确定参数类型 diff --git a/workflow_service/routes/workflow.py b/workflow_service/routes/workflow.py index ddfcaac..a19430c 100644 --- a/workflow_service/routes/workflow.py +++ b/workflow_service/routes/workflow.py @@ -91,6 +91,7 @@ async def get_one_workflow_endpoint(base_name: str, version: Optional[str] = Non return { "workflow": workflow, "api_spec": comfyui_client.parse_api_spec(workflow), + "inputs_json_schema": comfyui_client.parse_inputs_json_schema(workflow), } except Exception as e: raise HTTPException(