feat: 增强工作流提示构建逻辑,支持节点输入映射和过滤,优化API格式返回
This commit is contained in:
parent
861d9ddf46
commit
df0c133461
|
|
@ -28,6 +28,7 @@ class ComfyWorkflow:
|
||||||
self.workflow_data = workflow_data
|
self.workflow_data = workflow_data
|
||||||
self._nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
|
self._nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
|
||||||
self._api_spec = self._parse_api_spec()
|
self._api_spec = self._parse_api_spec()
|
||||||
|
self._inputs_json_schema = self._parse_inputs_json_schema()
|
||||||
|
|
||||||
async def build_prompt(
|
async def build_prompt(
|
||||||
self, server: ComfyUIServerInfo, request_data: dict[str, Any]
|
self, server: ComfyUIServerInfo, request_data: dict[str, Any]
|
||||||
|
|
@ -41,10 +42,43 @@ class ComfyWorkflow:
|
||||||
Returns:
|
Returns:
|
||||||
dict: 构建好的prompt
|
dict: 构建好的prompt
|
||||||
"""
|
"""
|
||||||
api_spec = self._get_or_create_api_spec()
|
patched_workflow = await self._patch_workflow(server, request_data)
|
||||||
patched_workflow = await self._patch_workflow(server, api_spec, request_data)
|
if "nodes" not in patched_workflow:
|
||||||
prompt = self._convert_workflow_to_prompt_api_format(patched_workflow)
|
raise ValueError("无效的工作流格式")
|
||||||
return prompt
|
|
||||||
|
prompt_api_format = {}
|
||||||
|
|
||||||
|
# 建立从link_id到源节点的映射
|
||||||
|
link_map = {}
|
||||||
|
for link_data in patched_workflow.get("links", []):
|
||||||
|
link_id, origin_node_id, origin_slot_index = (
|
||||||
|
link_data[0],
|
||||||
|
link_data[1],
|
||||||
|
link_data[2],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 键是目标节点的输入link_id
|
||||||
|
link_map[link_id] = [str(origin_node_id), origin_slot_index]
|
||||||
|
|
||||||
|
for node in patched_workflow["nodes"]:
|
||||||
|
node_id = str(node["id"])
|
||||||
|
inputs_dict = self._get_node_inputs(node, link_map)
|
||||||
|
prompt_api_format[node_id] = {
|
||||||
|
"inputs": inputs_dict,
|
||||||
|
"class_type": node["type"],
|
||||||
|
"_meta": {
|
||||||
|
"title": node.get("title", ""),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# 过滤掉类型为 Note 的节点
|
||||||
|
prompt_api_format = {
|
||||||
|
k: v
|
||||||
|
for k, v in prompt_api_format.items()
|
||||||
|
if v["class_type"] not in ["Note"]
|
||||||
|
}
|
||||||
|
|
||||||
|
return prompt_api_format
|
||||||
|
|
||||||
def get_api_spec(self) -> Dict[str, Dict[str, Any]]:
|
def get_api_spec(self) -> Dict[str, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -53,7 +87,7 @@ class ComfyWorkflow:
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Dict[str, Any]]: API规范
|
Dict[str, Dict[str, Any]]: API规范
|
||||||
"""
|
"""
|
||||||
return self._get_or_create_api_spec()
|
return self._api_spec
|
||||||
|
|
||||||
def get_inputs_json_schema(self) -> dict:
|
def get_inputs_json_schema(self) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|
@ -62,7 +96,7 @@ class ComfyWorkflow:
|
||||||
Returns:
|
Returns:
|
||||||
dict: JSON Schema
|
dict: JSON Schema
|
||||||
"""
|
"""
|
||||||
return self._parse_inputs_json_schema()
|
return self._inputs_json_schema
|
||||||
|
|
||||||
def _parse_api_spec(self) -> Dict[str, Dict[str, Any]]:
|
def _parse_api_spec(self) -> Dict[str, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -70,9 +104,8 @@ class ComfyWorkflow:
|
||||||
结构: {'节点名': {'node_id': '节点ID', '字段名': {字段信息}}}
|
结构: {'节点名': {'node_id': '节点ID', '字段名': {字段信息}}}
|
||||||
"""
|
"""
|
||||||
spec = {"inputs": {}, "outputs": {}}
|
spec = {"inputs": {}, "outputs": {}}
|
||||||
nodes_map = self._get_or_create_nodes_map()
|
|
||||||
|
|
||||||
for node_id, node in nodes_map.items():
|
for node_id, node in self._nodes_map.items():
|
||||||
title: str = node.get("title")
|
title: str = node.get("title")
|
||||||
if not title:
|
if not title:
|
||||||
continue
|
continue
|
||||||
|
|
@ -153,12 +186,10 @@ class ComfyWorkflow:
|
||||||
"required": [],
|
"required": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes_map = self._get_or_create_nodes_map()
|
|
||||||
|
|
||||||
# 收集所有输入节点
|
# 收集所有输入节点
|
||||||
input_nodes = {}
|
input_nodes = {}
|
||||||
|
|
||||||
for node_id, node in nodes_map.items():
|
for node_id, node in self._nodes_map.items():
|
||||||
title: str = node.get("title")
|
title: str = node.get("title")
|
||||||
if not title:
|
if not title:
|
||||||
continue
|
continue
|
||||||
|
|
@ -227,7 +258,6 @@ class ComfyWorkflow:
|
||||||
async def _patch_workflow(
|
async def _patch_workflow(
|
||||||
self,
|
self,
|
||||||
server: ComfyUIServerInfo,
|
server: ComfyUIServerInfo,
|
||||||
api_spec: dict[str, dict[str, Any]],
|
|
||||||
request_data: dict[str, Any],
|
request_data: dict[str, Any],
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|
@ -238,13 +268,13 @@ class ComfyWorkflow:
|
||||||
if "nodes" not in self.workflow_data:
|
if "nodes" not in self.workflow_data:
|
||||||
raise ValueError("无效的工作流格式")
|
raise ValueError("无效的工作流格式")
|
||||||
|
|
||||||
nodes_map = self._get_or_create_nodes_map().copy()
|
nodes_map = self._nodes_map.copy()
|
||||||
|
|
||||||
for node_name, node_fields in request_data.items():
|
for node_name, node_fields in request_data.items():
|
||||||
if node_name not in api_spec["inputs"]:
|
if node_name not in self._api_spec["inputs"]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
node_spec = api_spec["inputs"][node_name]
|
node_spec = self._api_spec["inputs"][node_name]
|
||||||
node_id = node_spec["node_id"]
|
node_id = node_spec["node_id"]
|
||||||
if node_id not in nodes_map:
|
if node_id not in nodes_map:
|
||||||
continue
|
continue
|
||||||
|
|
@ -512,23 +542,3 @@ class ComfyWorkflow:
|
||||||
inputs_dict[input_config["name"]] = link_map[link_id]
|
inputs_dict[input_config["name"]] = link_map[link_id]
|
||||||
|
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|
||||||
def _get_or_create_nodes_map(self) -> dict:
|
|
||||||
"""获取或创建节点映射"""
|
|
||||||
if self._nodes_map is None:
|
|
||||||
if "nodes" not in self.workflow_data or not isinstance(
|
|
||||||
self.workflow_data["nodes"], list
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid workflow format: 'nodes' key not found or is not a list."
|
|
||||||
)
|
|
||||||
self._nodes_map = {
|
|
||||||
str(node["id"]): node for node in self.workflow_data["nodes"]
|
|
||||||
}
|
|
||||||
return self._nodes_map
|
|
||||||
|
|
||||||
def _get_or_create_api_spec(self) -> Dict[str, Dict[str, Any]]:
|
|
||||||
"""获取或创建API规范"""
|
|
||||||
if self._api_spec is None:
|
|
||||||
self._api_spec = self._parse_api_spec()
|
|
||||||
return self._api_spec
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue