feat: 引入ComfyAPISpec类以增强API规范定义,更新工作流管理逻辑以支持新结构,优化输入和输出处理

This commit is contained in:
iHeyTang 2025-08-21 14:16:14 +08:00
parent df0c133461
commit 13580de0a1
2 changed files with 342 additions and 143 deletions

View File

@ -11,7 +11,7 @@ from typing import Dict, Any, Optional
import aiohttp
from aiohttp import ClientTimeout
from workflow_service.comfy.comfy_workflow import ComfyWorkflow
from workflow_service.comfy.comfy_workflow import ComfyAPISpec, ComfyWorkflow
from workflow_service.config import Settings
from workflow_service.comfy.comfy_server import server_manager, ComfyUIServerInfo
from workflow_service.database.api import (
@ -165,7 +165,7 @@ class WorkflowQueueManager:
self,
workflow_name: str,
workflow_data: dict,
api_spec: dict,
api_spec: ComfyAPISpec,
request_data: dict,
):
"""添加新任务到队列"""
@ -176,7 +176,7 @@ class WorkflowQueueManager:
workflow_run_id=workflow_run_id,
workflow_name=workflow_name,
workflow_json=json.dumps(workflow_data),
api_spec=json.dumps(api_spec),
api_spec=json.dumps(api_spec.model_dump()),
request_data=json.dumps(request_data),
)

View File

@ -1,7 +1,8 @@
import logging
from typing import Any, Dict
from typing import Any, Dict, List, Optional, Union
import aiohttp
from pydantic import BaseModel
from workflow_service.comfy.comfy_server import ComfyUIServerInfo
@ -13,6 +14,147 @@ API_INPUT_PREFIX = "INPUT_"
API_OUTPUT_PREFIX = "OUTPUT_"
class ComfyWorkflowNodeWidget(BaseModel):
"""节点控件定义"""
name: str
class ComfyWorkflowNodeInput(BaseModel):
"""节点输入定义"""
label: Optional[str] = None
localized_name: Optional[str] = None
name: str
shape: Optional[int] = None
type: str
link: Optional[int] = None
widget: Optional[ComfyWorkflowNodeWidget] = None
class ComfyWorkflowNodeOutput(BaseModel):
"""节点输出定义"""
label: Optional[str] = None
localized_name: Optional[str] = None
name: str
type: str
links: Optional[List[int]] = None
class ComfyWorkflowNodeProperties(BaseModel):
"""节点属性定义"""
cnr_id: Optional[str] = None
aux_id: Optional[str] = None
ver: Optional[str] = None
Node_name_for_SR: Optional[str] = None
class ComfyAPIFieldSpec(BaseModel):
"""API字段规范"""
type: str
widget_name: str
class ComfyAPIOutputSpec(BaseModel):
"""API输出规范"""
output_name: str
class ComfyAPINodeSpec(BaseModel):
"""API节点规范"""
node_id: str
class_type: str
inputs: Optional[Dict[str, ComfyAPIFieldSpec]] = {}
outputs: Optional[Dict[str, ComfyAPIOutputSpec]] = {}
class ComfyAPISpec(BaseModel):
"""API规范定义"""
inputs: Dict[str, ComfyAPINodeSpec] = {}
outputs: Dict[str, ComfyAPINodeSpec] = {}
class ComfyJSONSchemaProperty(BaseModel):
"""JSON Schema属性定义"""
type: str
title: Optional[str] = None
description: Optional[str] = None
default: Optional[Any] = None
minimum: Optional[Union[int, float]] = None
maximum: Optional[Union[int, float]] = None
enum: Optional[List[str]] = None
format: Optional[str] = None
contentMediaType: Optional[str] = None
pattern: Optional[str] = None
class ComfyJSONSchemaNode(BaseModel):
"""JSON Schema节点定义"""
type: str = "object"
title: str
description: str
properties: Dict[str, ComfyJSONSchemaProperty] = {}
required: List[str] = []
additionalProperties: bool = False
class ComfyJSONSchema(BaseModel):
"""JSON Schema定义"""
schema: str = "http://json-schema.org/draft-07/schema#"
type: str = "object"
title: str = "ComfyUI Workflow Input Schema"
description: str = "工作流输入参数定义"
properties: Dict[str, ComfyJSONSchemaNode] = {}
required: List[str] = []
class Config:
fields = {"schema": "$schema"}
class ComfyWorkflowNode(BaseModel):
"""工作流节点定义"""
id: int
type: str
pos: List[float]
size: List[float]
flags: Dict[str, Any] = {}
order: int
mode: int
inputs: Optional[List[ComfyWorkflowNodeInput]] = []
outputs: Optional[List[ComfyWorkflowNodeOutput]] = []
properties: ComfyWorkflowNodeProperties = ComfyWorkflowNodeProperties()
widgets_values: Optional[List[Any]] = []
title: Optional[str] = None
class ComfyWorkflowExtra(BaseModel):
"""工作流额外配置"""
ds: Optional[Dict[str, Any]] = None
frontendVersion: Optional[str] = None
VHS_latentpreview: Optional[bool] = None
VHS_latentpreviewrate: Optional[int] = None
VHS_MetadataImage: Optional[bool] = None
VHS_KeepIntermediate: Optional[bool] = None
class ComfyWorkflowDataSpec(BaseModel):
"""ComfyUI工作流数据规范"""
id: str
revision: int = 0
last_node_id: int
last_link_id: int
nodes: List[ComfyWorkflowNode]
links: List[List[Union[int, str]]] = []
groups: List[Any] = []
config: Dict[str, Any] = {}
extra: ComfyWorkflowExtra = ComfyWorkflowExtra()
version: float
class ComfyWorkflow:
"""
ComfyUI工作流处理器类提供面向对象的工作流管理和处理能力
@ -25,8 +167,8 @@ class ComfyWorkflow:
Args:
workflow_data: 工作流数据
"""
self.workflow_data = workflow_data
self._nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data)
self._nodes_map = {str(node.id): node for node in self.workflow_data.nodes}
self._api_spec = self._parse_api_spec()
self._inputs_json_schema = self._parse_inputs_json_schema()
@ -80,33 +222,34 @@ class ComfyWorkflow:
return prompt_api_format
def get_api_spec(self) -> Dict[str, Dict[str, Any]]:
def get_api_spec(self) -> ComfyAPISpec:
"""
获取API规范
Returns:
Dict[str, Dict[str, Any]]: API规范
ComfyAPISpec: API规范
"""
return self._api_spec
def get_inputs_json_schema(self) -> dict:
def get_inputs_json_schema(self) -> ComfyJSONSchema:
"""
获取输入参数的JSON Schema
Returns:
dict: JSON Schema
ComfyJSONSchema: JSON Schema
"""
return self._inputs_json_schema
def _parse_api_spec(self) -> Dict[str, Dict[str, Any]]:
def _parse_api_spec(self) -> ComfyAPISpec:
"""
解析工作流并根据规范生成嵌套结构的API参数名
结构: {'节点名': {'node_id': '节点ID', '字段名': {字段信息}}}
"""
spec = {"inputs": {}, "outputs": {}}
inputs = {}
outputs = {}
for node_id, node in self._nodes_map.items():
title: str = node.get("title")
title: str = node.title
if not title:
continue
@ -115,82 +258,64 @@ class ComfyWorkflow:
node_name = title[len(API_INPUT_PREFIX) :]
# 如果节点名不在inputs中创建新的节点条目
if node_name not in spec["inputs"]:
spec["inputs"][node_name] = {
"node_id": node_id,
"class_type": node.get("type"),
}
if node_name not in inputs:
inputs[node_name] = ComfyAPINodeSpec(
node_id=node_id,
class_type=node.type,
inputs={},
outputs={}
)
if "inputs" in node:
for a_input in node.get("inputs", []):
if a_input.get("link") is None and "widget" in a_input:
widget_name = a_input["widget"]["name"]
if node.inputs:
for a_input in node.inputs:
if a_input.link is None and a_input.widget:
widget_name = a_input.widget.name
# 特殊处理隐藏LoadImage节点的upload字段
if (
node.get("type") == "LoadImage"
and widget_name == "upload"
):
if node.type == "LoadImage" and widget_name == "upload":
continue
# 确保inputs字段存在
if "inputs" not in spec["inputs"][node_name]:
spec["inputs"][node_name]["inputs"] = {}
# 直接使用widget_name作为字段名
spec["inputs"][node_name]["inputs"][widget_name] = {
"type": self._get_param_type(
a_input.get("type", "STRING")
),
"widget_name": a_input["widget"]["name"],
}
inputs[node_name].inputs[widget_name] = ComfyAPIFieldSpec(
type=self._get_param_type(a_input.type or "STRING"),
widget_name=a_input.widget.name
)
elif title.startswith(API_OUTPUT_PREFIX):
# 提取节点名去掉API_OUTPUT_PREFIX前缀
node_name = title[len(API_OUTPUT_PREFIX) :]
# 如果节点名不在outputs中创建新的节点条目
if node_name not in spec["outputs"]:
spec["outputs"][node_name] = {
"node_id": node_id,
"class_type": node.get("type"),
}
if node_name not in outputs:
outputs[node_name] = ComfyAPINodeSpec(
node_id=node_id,
class_type=node.type,
inputs={},
outputs={}
)
if "outputs" in node:
for an_output in node.get("outputs", []):
output_name = an_output["name"]
# 确保outputs字段存在
if "outputs" not in spec["outputs"][node_name]:
spec["outputs"][node_name]["outputs"] = {}
if node.outputs:
for an_output in node.outputs:
output_name = an_output.name
# 直接使用output_name作为字段名
spec["outputs"][node_name]["outputs"][output_name] = {
"output_name": an_output["name"],
}
outputs[node_name].outputs[output_name] = ComfyAPIOutputSpec(
output_name=an_output.name
)
return spec
return ComfyAPISpec(inputs=inputs, outputs=outputs)
def _parse_inputs_json_schema(self) -> dict:
def _parse_inputs_json_schema(self) -> ComfyJSONSchema:
"""
解析工作流生成符合 JSON Schema 标准的 API 规范
返回一个只包含输入参数的 JSON Schema 对象
"""
# 基础 JSON Schema 结构,只包含输入
schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"title": "ComfyUI Workflow Input Schema",
"description": "工作流输入参数定义",
"properties": {},
"required": [],
}
# 收集所有输入节点
input_nodes = {}
properties = {}
required = []
for node_id, node in self._nodes_map.items():
title: str = node.get("title")
for _, node in self._nodes_map.items():
title: str = node.title
if not title:
continue
@ -198,62 +323,46 @@ class ComfyWorkflow:
# 提取节点名去掉API_INPUT_PREFIX前缀
node_name = title[len(API_INPUT_PREFIX) :]
if node_name not in input_nodes:
input_nodes[node_name] = {
"node_id": node_id,
"class_type": node.get("type"),
"properties": {},
"required": [],
}
node_properties = {}
node_required = []
if "inputs" in node:
for a_input in node.get("inputs", []):
if a_input.get("link") is None and "widget" in a_input:
widget_name = a_input["widget"]["name"]
if node.inputs:
for a_input in node.inputs:
if a_input.link is None and a_input.widget:
widget_name = a_input.widget.name
# 特殊处理隐藏LoadImage节点的upload字段
if (
node.get("type") == "LoadImage"
and widget_name == "upload"
):
if node.type == "LoadImage" and widget_name == "upload":
continue
# 获取节点的当前值作为默认值
# 创建空的link_map因为输入节点通常没有连接
empty_link_map = {}
node_current_inputs = self._get_node_inputs(
node, empty_link_map
)
node_current_inputs = self._get_node_inputs(node, empty_link_map)
current_value = node_current_inputs.get(widget_name)
# 生成字段的 JSON Schema 定义,包含默认值
field_schema = self._generate_field_schema(
a_input, current_value
)
input_nodes[node_name]["properties"][
widget_name
] = field_schema
field_schema = self._generate_field_schema_model(a_input, current_value)
node_properties[widget_name] = field_schema
# 如果字段是必需的添加到required列表
if a_input.get("required", True): # 默认所有字段都是必需的
input_nodes[node_name]["required"].append(widget_name)
# TODO: 从输入配置中获取是否必需的信息
node_required.append(widget_name)
# 将收集的节点信息转换为 JSON Schema 格式
for node_name, node_info in input_nodes.items():
schema["properties"][node_name] = {
"type": "object",
"title": f"输入节点: {node_name}",
"description": f"节点类型: {node_info['class_type']}",
"properties": node_info["properties"],
"required": node_info["required"] if node_info["required"] else [],
"additionalProperties": False,
}
# 创建节点的JSON Schema
if node_properties:
properties[node_name] = ComfyJSONSchemaNode(
title=f"输入节点: {node_name}",
description=f"节点类型: {node.type}",
properties=node_properties,
required=node_required
)
# 如果节点有必需字段,将整个节点标记为必需
if node_info["required"]:
schema["required"].append(node_name)
# 如果节点有必需字段,将整个节点标记为必需
if node_required:
required.append(node_name)
return schema
return ComfyJSONSchema(properties=properties, required=required)
async def _patch_workflow(
self,
@ -265,17 +374,15 @@ class ComfyWorkflow:
request_data结构: {"节点名称": {"字段名称": "字段值"}}
"""
if "nodes" not in self.workflow_data:
raise ValueError("无效的工作流格式")
nodes_map = self._nodes_map.copy()
nodes_map = {k: v.model_copy(deep=True) for k, v in self._nodes_map.items()}
for node_name, node_fields in request_data.items():
if node_name not in self._api_spec["inputs"]:
if node_name not in self._api_spec.inputs:
continue
node_spec = self._api_spec["inputs"][node_name]
node_id = node_spec["node_id"]
node_spec = self._api_spec.inputs[node_name]
node_id = node_spec.node_id
if node_id not in nodes_map:
continue
@ -283,24 +390,20 @@ class ComfyWorkflow:
# 处理该节点下的所有字段
for field_name, value in node_fields.items():
if "inputs" not in node_spec or field_name not in node_spec["inputs"]:
if field_name not in node_spec.inputs:
continue
field_spec = node_spec["inputs"][field_name]
widget_name_to_patch = field_spec["widget_name"]
field_spec = node_spec.inputs[field_name]
widget_name_to_patch = field_spec.widget_name
# 找到所有属于widget的输入并确定目标widget的索引
widget_inputs = [
inp for inp in target_node.get("inputs", []) if "widget" in inp
inp for inp in (target_node.inputs or []) if inp.widget is not None
]
target_widget_index = -1
for i, widget_input in enumerate(widget_inputs):
# "widget" 字典可能不存在或没有 "name" 键
if (
widget_input.get("widget", {}).get("name")
== widget_name_to_patch
):
if widget_input.widget and widget_input.widget.name == widget_name_to_patch:
target_widget_index = i
break
@ -311,46 +414,41 @@ class ComfyWorkflow:
continue
# 确保 `widgets_values` 存在且为列表
if "widgets_values" not in target_node or not isinstance(
target_node.get("widgets_values"), list
):
if not target_node.widgets_values or not isinstance(target_node.widgets_values, list):
# 如果不存在或格式错误根据widget数量创建一个占位符列表
target_node["widgets_values"] = [None] * len(widget_inputs)
target_node.widgets_values = [None] * len(widget_inputs)
# 确保 `widgets_values` 列表足够长
while len(target_node["widgets_values"]) <= target_widget_index:
target_node["widgets_values"].append(None)
while len(target_node.widgets_values) <= target_widget_index:
target_node.widgets_values.append(None)
# 根据API规范转换数据类型
target_type = str
if field_spec["type"] == "int":
if field_spec.type == "int":
target_type = int
elif field_spec["type"] == "float":
elif field_spec.type == "float":
target_type = float
# 在正确的位置上更新值
try:
if target_node.get("type") == "LoadImage":
if target_node.type == "LoadImage":
value = await self._upload_image_to_comfy(server, value)
target_node["widgets_values"][target_widget_index] = target_type(
value
)
target_node.widgets_values[target_widget_index] = target_type(value)
except (ValueError, TypeError) as e:
logger.warning(
f"无法将节点 '{node_name}' 的字段 '{field_name}' 的值 '{value}' 转换为类型 '{field_spec['type']}'。错误: {e}"
f"无法将节点 '{node_name}' 的字段 '{field_name}' 的值 '{value}' 转换为类型 '{field_spec.type}'。错误: {e}"
)
continue
# 创建新的workflow_data副本
patched_workflow = self.workflow_data.copy()
patched_workflow["nodes"] = list(nodes_map.values())
patched_workflow_data = self.workflow_data.model_copy(deep=True)
patched_workflow_data.nodes = list(nodes_map.values())
# 对workflow_data中的nodes进行排序按照id从小到大
patched_workflow["nodes"] = sorted(
patched_workflow["nodes"], key=lambda x: x["id"]
)
patched_workflow_data.nodes = sorted(patched_workflow_data.nodes, key=lambda x: x.id)
return patched_workflow
# 转换为字典格式以保持向后兼容
return patched_workflow_data.model_dump()
def _convert_workflow_to_prompt_api_format(self, workflow_data: dict) -> dict:
"""
@ -432,7 +530,7 @@ class ComfyWorkflow:
self, input_field: dict, default_value: Any = None
) -> dict:
"""
根据输入字段信息生成 JSON Schema 字段定义
根据输入字段信息生成 JSON Schema 字段定义字典版本向后兼容
Args:
input_field: 输入字段配置
@ -490,6 +588,57 @@ class ComfyWorkflow:
return field_schema
def _generate_field_schema_model(
self, input_field: ComfyWorkflowNodeInput, default_value: Any = None
) -> ComfyJSONSchemaProperty:
"""
根据输入字段信息生成 JSON Schema 字段定义模型版本
Args:
input_field: 输入字段配置
default_value: 字段的默认值
"""
title = input_field.widget.name if input_field.widget else input_field.name
description = input_field.localized_name or ""
# 根据字段类型设置 schema 类型
field_type = self._get_param_type(input_field.type or "STRING")
if field_type == "int":
return ComfyJSONSchemaProperty(
type="integer",
title=title,
description=description,
default=default_value,
# TODO: 从输入配置中获取min/max值
)
elif field_type == "float":
return ComfyJSONSchemaProperty(
type="number",
title=title,
description=description,
default=default_value,
# TODO: 从输入配置中获取min/max值
)
elif field_type == "UploadFile":
return ComfyJSONSchemaProperty(
type="string",
title=title,
format="binary",
contentMediaType="image/*",
pattern=".*\\.(jpg|jpeg|png)$",
description="[ext:jpg,jpeg,png][file-size:1MB] 上传图片文件",
default=default_value,
)
else: # string
return ComfyJSONSchemaProperty(
type="string",
title=title,
description=description,
default=default_value,
# TODO: 从输入配置中获取枚举值
)
def _get_param_type(self, input_type_str: str) -> str:
"""
根据输入类型字符串确定参数类型
@ -504,12 +653,62 @@ class ComfyWorkflow:
else:
return "string"
def _get_node_inputs(self, node: dict, link_map: dict) -> dict:
def _get_node_inputs(self, node: Union[ComfyWorkflowNode, dict], link_map: dict) -> dict:
"""
获取节点的输入字段值
Args:
node: 节点数据
node: 节点数据可以是ComfyWorkflowNode模型或字典
link_map: 从link_id到源节点的映射
Returns:
dict: 包含节点输入字段值的字典
"""
inputs_dict = {}
# 兼容处理:如果是字典,转换为模型
if isinstance(node, dict):
try:
node_model = ComfyWorkflowNode.model_validate(node)
except Exception:
# 如果转换失败,使用原来的字典方式
return self._get_node_inputs_dict(node, link_map)
else:
node_model = node
# 1. 处理控件输入 (widgets)
widgets_values = node_model.widgets_values or []
widget_cursor = 0
for input_config in (node_model.inputs or []):
# 如果是widget并且有对应的widgets_values
if input_config.widget and widget_cursor < len(widgets_values):
widget_name = input_config.widget.name
if widget_name:
# 使用widgets_values中的值因为这里已经包含了API传入的修改
inputs_dict[widget_name] = widgets_values[widget_cursor]
widget_cursor += 1
# 特殊处理:如果节点是 EG_CYQ_JB 则需要忽略widgets_values的第1个值这个是用来控制control net的它不参与赋值
if node_model.type == "EG_CYQ_JB" and widget_cursor == 1:
widget_cursor += 1
# 2. 处理连接输入 (links)
for input_config in (node_model.inputs or []):
if input_config.link is not None:
link_id = input_config.link
if link_id in link_map:
# 输入名称是input_config中的'name'
inputs_dict[input_config.name] = link_map[link_id]
return inputs_dict
def _get_node_inputs_dict(self, node: dict, link_map: dict) -> dict:
"""
获取节点的输入字段值字典版本用于向后兼容
Args:
node: 节点数据字典
link_map: 从link_id到源节点的映射
Returns: