diff --git a/workflow_service/comfy/comfy_workflow.py b/workflow_service/comfy/comfy_workflow.py index 68acfed..408a534 100644 --- a/workflow_service/comfy/comfy_workflow.py +++ b/workflow_service/comfy/comfy_workflow.py @@ -28,6 +28,7 @@ class ComfyWorkflow: self.workflow_data = workflow_data self._nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]} self._api_spec = self._parse_api_spec() + self._inputs_json_schema = self._parse_inputs_json_schema() async def build_prompt( self, server: ComfyUIServerInfo, request_data: dict[str, Any] @@ -41,10 +42,43 @@ class ComfyWorkflow: Returns: dict: 构建好的prompt """ - api_spec = self._get_or_create_api_spec() - patched_workflow = await self._patch_workflow(server, api_spec, request_data) - prompt = self._convert_workflow_to_prompt_api_format(patched_workflow) - return prompt + patched_workflow = await self._patch_workflow(server, request_data) + if "nodes" not in patched_workflow: + raise ValueError("无效的工作流格式") + + prompt_api_format = {} + + # 建立从link_id到源节点的映射 + link_map = {} + for link_data in patched_workflow.get("links", []): + link_id, origin_node_id, origin_slot_index = ( + link_data[0], + link_data[1], + link_data[2], + ) + + # 键是目标节点的输入link_id + link_map[link_id] = [str(origin_node_id), origin_slot_index] + + for node in patched_workflow["nodes"]: + node_id = str(node["id"]) + inputs_dict = self._get_node_inputs(node, link_map) + prompt_api_format[node_id] = { + "inputs": inputs_dict, + "class_type": node["type"], + "_meta": { + "title": node.get("title", ""), + }, + } + + # 过滤掉类型为 Note 的节点 + prompt_api_format = { + k: v + for k, v in prompt_api_format.items() + if v["class_type"] not in ["Note"] + } + + return prompt_api_format def get_api_spec(self) -> Dict[str, Dict[str, Any]]: """ @@ -53,7 +87,7 @@ class ComfyWorkflow: Returns: Dict[str, Dict[str, Any]]: API规范 """ - return self._get_or_create_api_spec() + return self._api_spec def get_inputs_json_schema(self) -> dict: """ @@ -62,7 +96,7 @@ class ComfyWorkflow: Returns: dict: JSON Schema """ - return self._parse_inputs_json_schema() + return self._inputs_json_schema def _parse_api_spec(self) -> Dict[str, Dict[str, Any]]: """ @@ -70,9 +104,8 @@ class ComfyWorkflow: 结构: {'节点名': {'node_id': '节点ID', '字段名': {字段信息}}} """ spec = {"inputs": {}, "outputs": {}} - nodes_map = self._get_or_create_nodes_map() - for node_id, node in nodes_map.items(): + for node_id, node in self._nodes_map.items(): title: str = node.get("title") if not title: continue @@ -153,12 +186,10 @@ class ComfyWorkflow: "required": [], } - nodes_map = self._get_or_create_nodes_map() - # 收集所有输入节点 input_nodes = {} - for node_id, node in nodes_map.items(): + for node_id, node in self._nodes_map.items(): title: str = node.get("title") if not title: continue @@ -227,7 +258,6 @@ class ComfyWorkflow: async def _patch_workflow( self, server: ComfyUIServerInfo, - api_spec: dict[str, dict[str, Any]], request_data: dict[str, Any], ) -> dict: """ @@ -238,13 +268,13 @@ class ComfyWorkflow: if "nodes" not in self.workflow_data: raise ValueError("无效的工作流格式") - nodes_map = self._get_or_create_nodes_map().copy() + nodes_map = self._nodes_map.copy() for node_name, node_fields in request_data.items(): - if node_name not in api_spec["inputs"]: + if node_name not in self._api_spec["inputs"]: continue - node_spec = api_spec["inputs"][node_name] + node_spec = self._api_spec["inputs"][node_name] node_id = node_spec["node_id"] if node_id not in nodes_map: continue @@ -512,23 +542,3 @@ class ComfyWorkflow: inputs_dict[input_config["name"]] = link_map[link_id] return inputs_dict - - def _get_or_create_nodes_map(self) -> dict: - """获取或创建节点映射""" - if self._nodes_map is None: - if "nodes" not in self.workflow_data or not isinstance( - self.workflow_data["nodes"], list - ): - raise ValueError( - "Invalid workflow format: 'nodes' key not found or is not a list." - ) - self._nodes_map = { - str(node["id"]): node for node in self.workflow_data["nodes"] - } - return self._nodes_map - - def _get_or_create_api_spec(self) -> Dict[str, Dict[str, Any]]: - """获取或创建API规范""" - if self._api_spec is None: - self._api_spec = self._parse_api_spec() - return self._api_spec