from ast import List  # NOTE(review): this is the `ast.List` node class, not `typing.List`; unused in this module
import copy
import json
import logging
import posixpath
from typing import Any, Dict, Iterator
from urllib.parse import urlsplit

import aiohttp

from workflow_service.comfy.comfy_server import ComfyUIServerInfo

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Nodes whose title starts with one of these prefixes are exposed through the API;
# the rest of the title (lower-cased) becomes the API node name.
API_INPUT_PREFIX = "INPUT_"
API_OUTPUT_PREFIX = "OUTPUT_"

# Node types whose `widgets_values` list contains one extra entry that does not
# belong to any widget input, mapped to the index of that entry.
# EG_CYQ_JB: the entry at index 1 controls ControlNet and is not assigned to an input.
_EXTRA_WIDGET_VALUE_SLOTS: Dict[str, int] = {"EG_CYQ_JB": 1}


async def build_prompt(
    workflow_data: dict,
    api_spec: dict[str, dict[str, Any]],
    request_data: dict[str, Any],
    server: ComfyUIServerInfo,
) -> dict:
    """
    Build the payload for ComfyUI's /prompt endpoint from a workflow and request data.

    Args:
        workflow_data: Workflow JSON containing "nodes" (and "links") lists, as
            saved by the ComfyUI editor. It is not modified.
        api_spec: Spec produced by `parse_api_spec` for this workflow.
        request_data: {"node name": {"field name": value}} values to apply.
        server: ComfyUI server that LoadImage inputs are uploaded to.

    Returns:
        The prompt dict keyed by node id (see `_convert_workflow_to_prompt_api_format`).
    """
    patched_workflow = await _patch_workflow(workflow_data, api_spec, request_data, server)
    return _convert_workflow_to_prompt_api_format(patched_workflow)


def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]:
    """
    Parse a workflow and build the API spec for its exposed input/output nodes.

    A node titled "<API_INPUT_PREFIX><name>" or "<API_OUTPUT_PREFIX><name>" is
    exposed under the lower-cased <name>. Result structure::

        {
            "inputs": {
                "<node name>": {
                    "node_id": "<node id as str>",
                    "class_type": "<node type>",
                    # present only when the node has at least one exposed widget
                    "inputs": {
                        "<widget name, lower-cased>": {
                            "type": "int" | "float" | "UploadFile" | "string",
                            "widget_name": "<original widget name>",
                        },
                    },
                },
            },
            "outputs": {
                "<node name>": {
                    "node_id": "<node id as str>",
                    "class_type": "<node type>",
                    # present only when the node has at least one output
                    "outputs": {
                        "<output name, lower-cased>": {"output_name": "<original name>"},
                    },
                },
            },
        }

    If two nodes map to the same API node name, the first one's id/type is kept
    and the fields of both are merged into that entry.

    Raises:
        ValueError: if `workflow_data["nodes"]` is missing or not a list.
    """
    spec: Dict[str, Dict[str, Any]] = {"inputs": {}, "outputs": {}}

    for node_id, node in _index_nodes_by_id(workflow_data).items():
        title = node.get("title")
        if not title:
            continue

        if title.startswith(API_INPUT_PREFIX):
            node_name = title[len(API_INPUT_PREFIX):].lower()
            entry = spec["inputs"].setdefault(
                node_name, {"node_id": node_id, "class_type": node.get("type")}
            )
            for widget_name, a_input in _iter_exposed_widget_inputs(node):
                # The field key is the lower-cased widget name; the original
                # spelling is kept so the widget can be located when patching.
                entry.setdefault("inputs", {})[widget_name] = {
                    "type": _get_param_type(a_input.get("type", "STRING")),
                    "widget_name": a_input["widget"]["name"],
                }

        elif title.startswith(API_OUTPUT_PREFIX):
            node_name = title[len(API_OUTPUT_PREFIX):].lower()
            entry = spec["outputs"].setdefault(
                node_name, {"node_id": node_id, "class_type": node.get("type")}
            )
            for an_output in node.get("outputs") or []:
                entry.setdefault("outputs", {})[an_output["name"].lower()] = {
                    "output_name": an_output["name"],
                }

    return spec


def parse_inputs_json_schema(workflow_data: dict) -> dict:
    """
    Parse a workflow and return a JSON Schema (draft-07) describing its API inputs.

    Each exposed input node (see `parse_api_spec`) becomes an object property
    whose own properties are the node's exposed widgets. A widget is required
    unless its input declares "required": false; a node is required when it
    has at least one required widget.

    Raises:
        ValueError: if `workflow_data["nodes"]` is missing or not a list.
    """
    schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "object",
        "title": "ComfyUI Workflow Input Schema",
        "description": "工作流输入参数定义",
        "properties": {},
        "required": [],
    }

    # Collect per-node field schemas first so nodes sharing a name are merged.
    input_nodes: Dict[str, Dict[str, Any]] = {}
    for node_id, node in _index_nodes_by_id(workflow_data).items():
        title = node.get("title")
        if not title or not title.startswith(API_INPUT_PREFIX):
            continue

        node_name = title[len(API_INPUT_PREFIX):].lower()
        node_info = input_nodes.setdefault(
            node_name,
            {
                "node_id": node_id,
                "class_type": node.get("type"),
                "properties": {},
                "required": [],
            },
        )
        for widget_name, a_input in _iter_exposed_widget_inputs(node):
            node_info["properties"][widget_name] = _generate_field_schema(a_input)
            # Fields are required unless the input explicitly opts out.
            if a_input.get("required", True):
                node_info["required"].append(widget_name)

    for node_name, node_info in input_nodes.items():
        schema["properties"][node_name] = {
            "type": "object",
            "title": f"输入节点: {node_name}",
            "description": f"节点类型: {node_info['class_type']}",
            "properties": node_info["properties"],
            "required": node_info["required"],
            "additionalProperties": False,
        }
        if node_info["required"]:
            schema["required"].append(node_name)

    return schema


########################################################
# Internal helpers below
########################################################


def _index_nodes_by_id(workflow_data: dict) -> Dict[str, dict]:
    """
    Validate `workflow_data["nodes"]` and map each node's id (as str) to the node.

    Raises:
        ValueError: if "nodes" is missing or not a list.
    """
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError(
            "Invalid workflow format: 'nodes' key not found or is not a list."
        )
    return {str(node["id"]): node for node in workflow_data["nodes"]}


def _iter_exposed_widget_inputs(node: dict) -> Iterator[tuple]:
    """
    Yield `(lower-cased widget name, input dict)` for each widget input of
    `node` that can be set through the API.

    Inputs connected by a link are skipped, as is the "upload" widget of
    LoadImage nodes.
    """
    for a_input in node.get("inputs") or []:
        if a_input.get("link") is not None or "widget" not in a_input:
            continue
        widget_name = a_input["widget"]["name"].lower()
        if node.get("type") == "LoadImage" and widget_name == "upload":
            continue
        yield widget_name, a_input


def _widget_value_slot(node_type: Any, widget_index: int) -> int:
    """
    Map a widget's position among a node's widget inputs to its index in the
    node's `widgets_values` list.

    The two usually coincide; node types listed in `_EXTRA_WIDGET_VALUE_SLOTS`
    have one extra entry that shifts every later widget by one. Both the
    patcher and the prompt converter use this so they agree on the layout.
    """
    extra_slot = _EXTRA_WIDGET_VALUE_SLOTS.get(node_type)
    if extra_slot is not None and widget_index >= extra_slot:
        return widget_index + 1
    return widget_index


async def _patch_workflow(
    workflow_data: dict,
    api_spec: dict[str, dict[str, Any]],
    request_data: dict[str, Any],
    server: ComfyUIServerInfo,
) -> dict:
    """
    Return a copy of `workflow_data` with request values written into the
    matching nodes' `widgets_values`.

    request_data structure: {"node name": {"field name": value}}, using the
    names from `parse_api_spec`. Unknown nodes and fields are ignored. For
    LoadImage nodes the value is treated as a file URL: the file is downloaded,
    uploaded to `server`, and the name returned by the server is stored.
    Values that cannot be converted to the field's type are logged and skipped.

    The input is deep-copied, so the caller's workflow is never modified.
    Nodes in the returned workflow are sorted by id.

    Raises:
        ValueError: if `workflow_data` has no "nodes" key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    # Callers may pass a cached template; patching it in place would leak one
    # request's values into the next.
    workflow_data = copy.deepcopy(workflow_data)
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}

    for node_name, node_fields in request_data.items():
        node_spec = api_spec["inputs"].get(node_name)
        if node_spec is None:
            continue

        node_id = node_spec["node_id"]
        target_node = nodes_map.get(node_id)
        if target_node is None:
            continue

        field_specs = node_spec.get("inputs", {})
        node_type = target_node.get("type")
        # Widget inputs in declaration order; a widget's position here
        # determines its slot in widgets_values.
        widget_inputs = [
            inp for inp in target_node.get("inputs") or [] if "widget" in inp
        ]

        for field_name, value in node_fields.items():
            field_spec = field_specs.get(field_name)
            if field_spec is None:
                continue

            widget_name_to_patch = field_spec["widget_name"]
            target_widget_index = next(
                (
                    index
                    for index, widget_input in enumerate(widget_inputs)
                    if widget_input.get("widget", {}).get("name") == widget_name_to_patch
                ),
                -1,
            )
            if target_widget_index == -1:
                logger.warning(
                    "在节点 %s 中未找到名为 '%s' 的 widget。跳过此参数。",
                    node_id,
                    widget_name_to_patch,
                )
                continue

            slot = _widget_value_slot(node_type, target_widget_index)

            # Ensure widgets_values is a list long enough to hold the slot.
            widgets_values = target_node.get("widgets_values")
            if not isinstance(widgets_values, list):
                widgets_values = [None] * len(widget_inputs)
                target_node["widgets_values"] = widgets_values
            if len(widgets_values) <= slot:
                widgets_values.extend([None] * (slot + 1 - len(widgets_values)))

            # Convert according to the API spec type; everything else is a string.
            target_type = {"int": int, "float": float}.get(field_spec["type"], str)

            try:
                if node_type == "LoadImage":
                    value = await _upload_image_to_comfy(value, server)
                widgets_values[slot] = target_type(value)
            except (ValueError, TypeError) as e:
                logger.warning(
                    "无法将节点 '%s' 的字段 '%s' 的值 '%s' 转换为类型 '%s'。错误: %s",
                    node_name,
                    field_name,
                    value,
                    field_spec["type"],
                    e,
                )

    workflow_data["nodes"] = sorted(nodes_map.values(), key=lambda node: node["id"])
    return workflow_data


def _convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
    """
    Convert a workflow with "nodes"/"links" lists into the payload accepted by
    the ComfyUI /prompt endpoint.

    Each node becomes `{"inputs": ..., "class_type": ..., "_meta": {"title": ...}}`
    keyed by its id as a string. Widget inputs take their value from
    `widgets_values` (so values written by `_patch_workflow` are used). Linked
    inputs become `[origin node id, origin slot index]` and override a widget
    value of the same name. Nodes of type "Note" are dropped.

    Raises:
        ValueError: if `workflow_data` has no "nodes" key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")

    # link id -> [origin node id (str), origin output slot index].
    # Each link entry starts with [link_id, origin_node_id, origin_slot_index, ...].
    link_map = {
        link[0]: [str(link[1]), link[2]] for link in workflow_data.get("links") or []
    }

    prompt_api_format = {}
    for node in workflow_data["nodes"]:
        node_id = str(node["id"])
        node_inputs = node.get("inputs") or []
        widgets_values = node.get("widgets_values") or []
        inputs_dict = {}

        # 1. Widget inputs: read each value from its widgets_values slot.
        widget_inputs = [inp for inp in node_inputs if "widget" in inp]
        for widget_index, input_config in enumerate(widget_inputs):
            slot = _widget_value_slot(node["type"], widget_index)
            if slot >= len(widgets_values):
                # Slots only increase, so no later widget has a value either.
                break
            widget_name = input_config["widget"].get("name")
            if widget_name:
                inputs_dict[widget_name] = widgets_values[slot]

        # 2. Linked inputs: reference the upstream node's output slot.
        for input_config in node_inputs:
            link_id = input_config.get("link")
            if link_id is not None and link_id in link_map:
                inputs_dict[input_config["name"]] = link_map[link_id]

        prompt_api_format[node_id] = {
            "inputs": inputs_dict,
            "class_type": node["type"],
            "_meta": {
                "title": node.get("title", ""),
            },
        }

    # Drop Note nodes; they are not part of the submitted prompt.
    return {
        key: value
        for key, value in prompt_api_format.items()
        if value["class_type"] not in ["Note"]
    }


async def _upload_image_to_comfy(file_path: str, server: ComfyUIServerInfo) -> str:
    """
    Download the image at URL `file_path` and upload it to the ComfyUI server
    via /api/upload/image with type "input".

    Args:
        file_path: HTTP(S) URL of the image.
        server: Target ComfyUI server.

    Returns:
        The file name reported by the server ("name" in the JSON response).

    Raises:
        aiohttp.ClientResponseError: if the download or the upload returns an
            error status.
    """
    # Use only the last segment of the URL path, so query strings and fragments
    # (e.g. signed-URL tokens) do not end up in the uploaded file name.
    file_name = posixpath.basename(urlsplit(file_path).path) or "image"
    file_data = await _download_file(file_path)

    form_data = aiohttp.FormData()
    form_data.add_field("image", file_data, filename=file_name, content_type="image/*")
    form_data.add_field("type", "input")

    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{server.http_url}/api/upload/image", data=form_data
        ) as response:
            response.raise_for_status()
            result = await response.json()
            return result["name"]


async def _download_file(src: str) -> bytes:
    """
    Fetch `src` over HTTP and return the response body as bytes.

    Raises:
        aiohttp.ClientResponseError: if the server returns an error status.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(src) as response:
            response.raise_for_status()
            return await response.read()


def _generate_field_schema(input_field: dict) -> dict:
    """
    Build the JSON Schema definition for one widget input.

    The schema type comes from `_get_param_type`. Numeric fields carry
    "minimum"/"maximum" only when the input defines "min"/"max", because
    JSON Schema requires these keywords to be numbers. "options" become an
    "enum" and are listed in the description; "default" is copied through.
    """
    field_schema = {
        "title": input_field["widget"]["name"],
        "description": input_field.get("description", ""),
    }

    field_type = _get_param_type(input_field.get("type", "STRING"))
    if field_type in ("int", "float"):
        field_schema["type"] = "integer" if field_type == "int" else "number"
        for source_key, schema_key in (("min", "minimum"), ("max", "maximum")):
            bound = input_field.get(source_key)
            if bound is not None:
                field_schema[schema_key] = bound
    elif field_type == "UploadFile":
        field_schema.update(
            {"type": "string", "format": "binary", "description": "上传文件"}
        )
    else:
        field_schema["type"] = "string"

    if "options" in input_field:
        options = input_field["options"]
        field_schema["enum"] = options
        # str() each option so numeric choices do not break the join.
        field_schema["description"] += f" 可选值: {', '.join(map(str, options))}"

    if "default" in input_field:
        field_schema["default"] = input_field["default"]

    return field_schema


def _get_param_type(input_type_str: str) -> str:
    """
    Map a ComfyUI input type name to the API parameter type.

    Matching is case-insensitive and by substring, checked in this order:
    "COMBO" -> "UploadFile", "INT" -> "int", "FLOAT" -> "float",
    anything else -> "string".
    """
    normalized = input_type_str.upper()
    # NOTE(review): every COMBO widget (not only LoadImage's image picker) is
    # reported as "UploadFile"; confirm this is intended for other combos.
    if "COMBO" in normalized:
        return "UploadFile"
    if "INT" in normalized:
        return "int"
    if "FLOAT" in normalized:
        return "float"
    return "string"