feat: 增强工作流提示构建逻辑,支持节点输入映射和过滤,优化API格式返回

This commit is contained in:
iHeyTang 2025-08-20 16:42:56 +08:00
parent 861d9ddf46
commit df0c133461
1 changed files with 45 additions and 35 deletions

View File

@ -28,6 +28,7 @@ class ComfyWorkflow:
self.workflow_data = workflow_data
self._nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
self._api_spec = self._parse_api_spec()
self._inputs_json_schema = self._parse_inputs_json_schema()
async def build_prompt(
self, server: ComfyUIServerInfo, request_data: dict[str, Any]
@ -41,10 +42,43 @@ class ComfyWorkflow:
Returns:
dict: 构建好的prompt
"""
api_spec = self._get_or_create_api_spec()
patched_workflow = await self._patch_workflow(server, api_spec, request_data)
prompt = self._convert_workflow_to_prompt_api_format(patched_workflow)
return prompt
patched_workflow = await self._patch_workflow(server, request_data)
if "nodes" not in patched_workflow:
raise ValueError("无效的工作流格式")
prompt_api_format = {}
# 建立从link_id到源节点的映射
link_map = {}
for link_data in patched_workflow.get("links", []):
link_id, origin_node_id, origin_slot_index = (
link_data[0],
link_data[1],
link_data[2],
)
# 键是目标节点的输入link_id
link_map[link_id] = [str(origin_node_id), origin_slot_index]
for node in patched_workflow["nodes"]:
node_id = str(node["id"])
inputs_dict = self._get_node_inputs(node, link_map)
prompt_api_format[node_id] = {
"inputs": inputs_dict,
"class_type": node["type"],
"_meta": {
"title": node.get("title", ""),
},
}
# 过滤掉类型为 Note 的节点
prompt_api_format = {
k: v
for k, v in prompt_api_format.items()
if v["class_type"] not in ["Note"]
}
return prompt_api_format
def get_api_spec(self) -> Dict[str, Dict[str, Any]]:
"""
@ -53,7 +87,7 @@ class ComfyWorkflow:
Returns:
Dict[str, Dict[str, Any]]: API规范
"""
return self._get_or_create_api_spec()
return self._api_spec
def get_inputs_json_schema(self) -> dict:
"""
@ -62,7 +96,7 @@ class ComfyWorkflow:
Returns:
dict: JSON Schema
"""
return self._parse_inputs_json_schema()
return self._inputs_json_schema
def _parse_api_spec(self) -> Dict[str, Dict[str, Any]]:
"""
@ -70,9 +104,8 @@ class ComfyWorkflow:
结构: {'节点名': {'node_id': '节点ID', '字段名': {字段信息}}}
"""
spec = {"inputs": {}, "outputs": {}}
nodes_map = self._get_or_create_nodes_map()
for node_id, node in nodes_map.items():
for node_id, node in self._nodes_map.items():
title: str = node.get("title")
if not title:
continue
@ -153,12 +186,10 @@ class ComfyWorkflow:
"required": [],
}
nodes_map = self._get_or_create_nodes_map()
# 收集所有输入节点
input_nodes = {}
for node_id, node in nodes_map.items():
for node_id, node in self._nodes_map.items():
title: str = node.get("title")
if not title:
continue
@ -227,7 +258,6 @@ class ComfyWorkflow:
async def _patch_workflow(
self,
server: ComfyUIServerInfo,
api_spec: dict[str, dict[str, Any]],
request_data: dict[str, Any],
) -> dict:
"""
@ -238,13 +268,13 @@ class ComfyWorkflow:
if "nodes" not in self.workflow_data:
raise ValueError("无效的工作流格式")
nodes_map = self._get_or_create_nodes_map().copy()
nodes_map = self._nodes_map.copy()
for node_name, node_fields in request_data.items():
if node_name not in api_spec["inputs"]:
if node_name not in self._api_spec["inputs"]:
continue
node_spec = api_spec["inputs"][node_name]
node_spec = self._api_spec["inputs"][node_name]
node_id = node_spec["node_id"]
if node_id not in nodes_map:
continue
@ -512,23 +542,3 @@ class ComfyWorkflow:
inputs_dict[input_config["name"]] = link_map[link_id]
return inputs_dict
def _get_or_create_nodes_map(self) -> dict:
"""获取或创建节点映射"""
if self._nodes_map is None:
if "nodes" not in self.workflow_data or not isinstance(
self.workflow_data["nodes"], list
):
raise ValueError(
"Invalid workflow format: 'nodes' key not found or is not a list."
)
self._nodes_map = {
str(node["id"]): node for node in self.workflow_data["nodes"]
}
return self._nodes_map
def _get_or_create_api_spec(self) -> Dict[str, Dict[str, Any]]:
"""获取或创建API规范"""
if self._api_spec is None:
self._api_spec = self._parse_api_spec()
return self._api_spec