feat: 增强工作流提示构建逻辑,支持节点输入映射和过滤,优化API格式返回
This commit is contained in:
parent
861d9ddf46
commit
df0c133461
|
|
@ -28,6 +28,7 @@ class ComfyWorkflow:
|
|||
self.workflow_data = workflow_data
|
||||
self._nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
|
||||
self._api_spec = self._parse_api_spec()
|
||||
self._inputs_json_schema = self._parse_inputs_json_schema()
|
||||
|
||||
async def build_prompt(
|
||||
self, server: ComfyUIServerInfo, request_data: dict[str, Any]
|
||||
|
|
@ -41,10 +42,43 @@ class ComfyWorkflow:
|
|||
Returns:
|
||||
dict: 构建好的prompt
|
||||
"""
|
||||
api_spec = self._get_or_create_api_spec()
|
||||
patched_workflow = await self._patch_workflow(server, api_spec, request_data)
|
||||
prompt = self._convert_workflow_to_prompt_api_format(patched_workflow)
|
||||
return prompt
|
||||
patched_workflow = await self._patch_workflow(server, request_data)
|
||||
if "nodes" not in patched_workflow:
|
||||
raise ValueError("无效的工作流格式")
|
||||
|
||||
prompt_api_format = {}
|
||||
|
||||
# 建立从link_id到源节点的映射
|
||||
link_map = {}
|
||||
for link_data in patched_workflow.get("links", []):
|
||||
link_id, origin_node_id, origin_slot_index = (
|
||||
link_data[0],
|
||||
link_data[1],
|
||||
link_data[2],
|
||||
)
|
||||
|
||||
# 键是目标节点的输入link_id
|
||||
link_map[link_id] = [str(origin_node_id), origin_slot_index]
|
||||
|
||||
for node in patched_workflow["nodes"]:
|
||||
node_id = str(node["id"])
|
||||
inputs_dict = self._get_node_inputs(node, link_map)
|
||||
prompt_api_format[node_id] = {
|
||||
"inputs": inputs_dict,
|
||||
"class_type": node["type"],
|
||||
"_meta": {
|
||||
"title": node.get("title", ""),
|
||||
},
|
||||
}
|
||||
|
||||
# 过滤掉类型为 Note 的节点
|
||||
prompt_api_format = {
|
||||
k: v
|
||||
for k, v in prompt_api_format.items()
|
||||
if v["class_type"] not in ["Note"]
|
||||
}
|
||||
|
||||
return prompt_api_format
|
||||
|
||||
def get_api_spec(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
|
|
@ -53,7 +87,7 @@ class ComfyWorkflow:
|
|||
Returns:
|
||||
Dict[str, Dict[str, Any]]: API规范
|
||||
"""
|
||||
return self._get_or_create_api_spec()
|
||||
return self._api_spec
|
||||
|
||||
def get_inputs_json_schema(self) -> dict:
|
||||
"""
|
||||
|
|
@ -62,7 +96,7 @@ class ComfyWorkflow:
|
|||
Returns:
|
||||
dict: JSON Schema
|
||||
"""
|
||||
return self._parse_inputs_json_schema()
|
||||
return self._inputs_json_schema
|
||||
|
||||
def _parse_api_spec(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
|
|
@ -70,9 +104,8 @@ class ComfyWorkflow:
|
|||
结构: {'节点名': {'node_id': '节点ID', '字段名': {字段信息}}}
|
||||
"""
|
||||
spec = {"inputs": {}, "outputs": {}}
|
||||
nodes_map = self._get_or_create_nodes_map()
|
||||
|
||||
for node_id, node in nodes_map.items():
|
||||
for node_id, node in self._nodes_map.items():
|
||||
title: str = node.get("title")
|
||||
if not title:
|
||||
continue
|
||||
|
|
@ -153,12 +186,10 @@ class ComfyWorkflow:
|
|||
"required": [],
|
||||
}
|
||||
|
||||
nodes_map = self._get_or_create_nodes_map()
|
||||
|
||||
# 收集所有输入节点
|
||||
input_nodes = {}
|
||||
|
||||
for node_id, node in nodes_map.items():
|
||||
for node_id, node in self._nodes_map.items():
|
||||
title: str = node.get("title")
|
||||
if not title:
|
||||
continue
|
||||
|
|
@ -227,7 +258,6 @@ class ComfyWorkflow:
|
|||
async def _patch_workflow(
|
||||
self,
|
||||
server: ComfyUIServerInfo,
|
||||
api_spec: dict[str, dict[str, Any]],
|
||||
request_data: dict[str, Any],
|
||||
) -> dict:
|
||||
"""
|
||||
|
|
@ -238,13 +268,13 @@ class ComfyWorkflow:
|
|||
if "nodes" not in self.workflow_data:
|
||||
raise ValueError("无效的工作流格式")
|
||||
|
||||
nodes_map = self._get_or_create_nodes_map().copy()
|
||||
nodes_map = self._nodes_map.copy()
|
||||
|
||||
for node_name, node_fields in request_data.items():
|
||||
if node_name not in api_spec["inputs"]:
|
||||
if node_name not in self._api_spec["inputs"]:
|
||||
continue
|
||||
|
||||
node_spec = api_spec["inputs"][node_name]
|
||||
node_spec = self._api_spec["inputs"][node_name]
|
||||
node_id = node_spec["node_id"]
|
||||
if node_id not in nodes_map:
|
||||
continue
|
||||
|
|
@ -512,23 +542,3 @@ class ComfyWorkflow:
|
|||
inputs_dict[input_config["name"]] = link_map[link_id]
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def _get_or_create_nodes_map(self) -> dict:
|
||||
"""获取或创建节点映射"""
|
||||
if self._nodes_map is None:
|
||||
if "nodes" not in self.workflow_data or not isinstance(
|
||||
self.workflow_data["nodes"], list
|
||||
):
|
||||
raise ValueError(
|
||||
"Invalid workflow format: 'nodes' key not found or is not a list."
|
||||
)
|
||||
self._nodes_map = {
|
||||
str(node["id"]): node for node in self.workflow_data["nodes"]
|
||||
}
|
||||
return self._nodes_map
|
||||
|
||||
def _get_or_create_api_spec(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""获取或创建API规范"""
|
||||
if self._api_spec is None:
|
||||
self._api_spec = self._parse_api_spec()
|
||||
return self._api_spec
|
||||
|
|
|
|||
Loading…
Reference in New Issue