# ComfyUI-WorkflowPublisher/workflow_service/comfy/comfy_workflow.py
#
# Note: comments, docstrings and log/error messages in this file are partly
# written in Chinese (non-ASCII characters).

import logging
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlsplit

import aiohttp
from pydantic import BaseModel, ConfigDict, Field, ValidationError

from workflow_service.comfy.comfy_server import ComfyUIServerInfo
# Title prefixes that mark API nodes: a node titled API_INPUT_PREFIX + name is
# exposed as the API input `name`, API_OUTPUT_PREFIX + name as the output `name`.
API_INPUT_PREFIX = "INPUT_"
API_OUTPUT_PREFIX = "OUTPUT_"

# NOTE(review): basicConfig configures the root logger as a side effect of
# importing this module; consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ComfyWorkflowNodeWidget(BaseModel):
    """Widget reference attached to a node input."""

    # Widget name; used as the field key when building prompt inputs.
    name: str
class ComfyWorkflowNodeInput(BaseModel):
    """One entry of a workflow node's ``inputs`` list."""

    label: Optional[str] = None
    localized_name: Optional[str] = None
    # Input (socket) name; used as the prompt key for linked inputs.
    name: str
    shape: Optional[int] = None
    # ComfyUI type string, e.g. "INT", "FLOAT", "STRING", "COMBO".
    type: str
    # Id of the incoming link, or None when nothing is connected.
    link: Optional[int] = None
    # Present when the input is backed by a widget.
    widget: Optional[ComfyWorkflowNodeWidget] = None
class ComfyWorkflowNodeOutput(BaseModel):
    """One entry of a workflow node's ``outputs`` list."""

    label: Optional[str] = None
    localized_name: Optional[str] = None
    # Output name; used as the key in the API output spec.
    name: str
    type: str
    # Ids of the links leaving this output, if any.
    links: Optional[list[int]] = None
class ComfyWorkflowNodeProperties(BaseModel):
    """Subset of a workflow node's ``properties`` object."""

    # Accept both the Python field names and the aliases below on input, so
    # model_dump() output (which uses field names) validates again.
    model_config = ConfigDict(populate_by_name=True)

    cnr_id: Optional[str] = None
    aux_id: Optional[str] = None
    ver: Optional[str] = None
    # ComfyUI workflows store this under the key "Node name for S&R", which is
    # not a valid Python identifier; without the alias the value was never read.
    Node_name_for_SR: Optional[str] = Field(default=None, alias="Node name for S&R")
class ComfyAPIFieldSpec(BaseModel):
    """Description of a single API input field."""

    # Parameter type as produced by ComfyWorkflow._get_param_type
    # (e.g. "int", "float", "string", "UploadFile").
    type: str
    # Name of the widget whose value this field sets.
    widget_name: str
class ComfyAPIOutputSpec(BaseModel):
    """Description of a single API output."""

    # Name of the node output this entry refers to.
    output_name: str
class ComfyAPINodeSpec(BaseModel):
    """API description of one workflow node."""

    # Workflow node id, as a string.
    node_id: str
    # ComfyUI node type of that node.
    class_type: str
    # Input fields keyed by field name.
    inputs: Optional[dict[str, ComfyAPIFieldSpec]] = {}
    # Outputs keyed by output name.
    outputs: Optional[dict[str, ComfyAPIOutputSpec]] = {}
class ComfyAPISpec(BaseModel):
    """API exposed by a workflow: input and output nodes keyed by API name."""

    inputs: dict[str, ComfyAPINodeSpec] = {}
    outputs: dict[str, ComfyAPINodeSpec] = {}
class ComfyJSONSchemaProperty(BaseModel):
    """JSON Schema definition of a single input field."""

    # JSON Schema "type" keyword, e.g. "integer", "number", "string".
    type: str
    title: Optional[str] = None
    description: Optional[str] = None
    default: Optional[Any] = None
    # Numeric bounds for "integer" / "number" fields.
    minimum: Optional[Union[int, float]] = None
    maximum: Optional[Union[int, float]] = None
    # Allowed values when the field is an enumeration.
    enum: Optional[list[str]] = None
    # Format hint and media type, used for file fields.
    format: Optional[str] = None
    contentMediaType: Optional[str] = None
    # Regular expression a string value must match.
    pattern: Optional[str] = None
class ComfyJSONSchemaNode(BaseModel):
    """JSON Schema object describing the fields of one API input node."""

    type: str = "object"
    title: str
    description: str
    # Field schemas keyed by field (widget) name.
    properties: dict[str, ComfyJSONSchemaProperty] = {}
    required: list[str] = []
    additionalProperties: bool = False
class ComfyJSONSchema(BaseModel):
    """Top-level JSON Schema describing a workflow's API inputs."""

    # Pydantic v2 configuration. The former v1-style ``class Config: fields``
    # mapping is not supported by v2 (it is dropped with a warning), so the
    # "$schema" alias never took effect. ``populate_by_name`` keeps
    # ``ComfyJSONSchema(schema=...)`` working alongside the alias.
    model_config = ConfigDict(populate_by_name=True)

    # Serialized as "$schema" by model_dump(by_alias=True) /
    # model_dump_json(by_alias=True).
    schema: str = Field(
        default="http://json-schema.org/draft-07/schema#", alias="$schema"
    )
    type: str = "object"
    title: str = "ComfyUI Workflow Input Schema"
    description: str = "工作流输入参数定义"
    # One object schema per API input node, keyed by API node name.
    properties: Dict[str, ComfyJSONSchemaNode] = {}
    required: List[str] = []
class ComfyWorkflowNode(BaseModel):
    """A node of a ComfyUI workflow (one entry of the ``nodes`` list)."""

    id: int
    # Node type, e.g. "LoadImage"; becomes ``class_type`` in the prompt.
    type: str
    pos: list[float]
    size: list[float]
    flags: dict[str, Any] = {}
    order: int
    mode: int
    inputs: Optional[list[ComfyWorkflowNodeInput]] = []
    outputs: Optional[list[ComfyWorkflowNodeOutput]] = []
    properties: ComfyWorkflowNodeProperties = ComfyWorkflowNodeProperties()
    # Widget values, read positionally against the node's widget inputs.
    # NOTE(review): only list-shaped values validate; some custom nodes may
    # serialize widgets_values as an object -- confirm such workflows load.
    widgets_values: Optional[list[Any]] = []
    # Display title; API_INPUT_PREFIX / API_OUTPUT_PREFIX on it mark API nodes.
    title: Optional[str] = None
class ComfyWorkflowExtra(BaseModel):
    """Subset of the workflow's ``extra`` object (undeclared keys are ignored)."""

    ds: Optional[dict[str, Any]] = None
    frontendVersion: Optional[str] = None
    # VHS_* settings (presumably from the VideoHelperSuite custom nodes).
    VHS_latentpreview: Optional[bool] = None
    VHS_latentpreviewrate: Optional[int] = None
    VHS_MetadataImage: Optional[bool] = None
    VHS_KeepIntermediate: Optional[bool] = None
class ComfyWorkflowDataSpec(BaseModel):
    """Top-level ComfyUI workflow document."""

    id: str
    revision: int = 0
    last_node_id: int
    last_link_id: int
    nodes: list[ComfyWorkflowNode]
    # Each link is a list starting with [link_id, origin_node_id, origin_slot, ...].
    links: list[list[Union[int, str]]] = []
    groups: list[Any] = []
    config: dict[str, Any] = {}
    extra: ComfyWorkflowExtra = ComfyWorkflowExtra()
    version: float
class ComfyWorkflow:
    """
    Object-oriented wrapper around a ComfyUI workflow (editor ``nodes``/``links`` format).

    On construction the workflow is validated into ``ComfyWorkflowDataSpec`` and
    two derived views are computed once:

    * the API spec (:meth:`get_api_spec`): a node whose title starts with
      ``API_INPUT_PREFIX`` exposes its unlinked widgets as request fields, and a
      node whose title starts with ``API_OUTPUT_PREFIX`` exposes its outputs; the
      rest of the title is the node's API name;
    * a JSON Schema of the request fields (:meth:`get_inputs_json_schema`).

    :meth:`build_prompt` applies request values to a copy of the workflow and
    converts it into the payload format of ComfyUI's ``/prompt`` endpoint.
    """

    # Annotation node types left out of the generated prompt
    # ("MarkdownNote" is the markdown variant of "Note").
    # NOTE(review): other editor-only nodes (e.g. Reroute, PrimitiveNode) and
    # muted/bypassed nodes (``mode``) are passed through unchanged -- confirm
    # that published workflows never contain them.
    _FRONTEND_ONLY_NODE_TYPES = frozenset({"Note", "MarkdownNote"})

    # Node type whose field values are treated as image URLs: the image is
    # uploaded to the ComfyUI server before patching, and the node's "upload"
    # widget is not exposed through the API.
    _LOAD_IMAGE_NODE_TYPE = "LoadImage"

    # Node types whose ``widgets_values`` hold one extra entry right after the
    # first widget's value that belongs to no input (for EG_CYQ_JB it controls
    # the ControlNet). Reading and patching both skip that entry.
    # NOTE(review): seed widgets of core nodes (e.g. KSampler) may likewise be
    # followed by a "control after generate" value; they are not special-cased.
    _EXTRA_VALUE_AFTER_FIRST_WIDGET_NODE_TYPES = frozenset({"EG_CYQ_JB"})

    def __init__(self, workflow_data: dict):
        """
        Validate the workflow and precompute its API spec and input schema.

        Args:
            workflow_data: Workflow JSON as a dict (``nodes``, ``links``, ...).

        Raises:
            pydantic.ValidationError: If ``workflow_data`` does not match
                ``ComfyWorkflowDataSpec``.
        """
        self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data)
        # Node ids are kept as strings, matching the prompt's node keys.
        self._nodes_map = {str(node.id): node for node in self.workflow_data.nodes}
        self._api_spec = self._parse_api_spec()
        self._inputs_json_schema = self._parse_inputs_json_schema()

    async def build_prompt(
        self, server: ComfyUIServerInfo, request_data: dict[str, Any]
    ) -> dict:
        """
        Build the prompt to submit to ComfyUI's ``/prompt`` endpoint.

        Args:
            server: ComfyUI server; images for ``LoadImage`` inputs are uploaded to it.
            request_data: Values to apply, shaped ``{node_name: {field_name: value}}``.

        Returns:
            dict: The prompt, keyed by node id (as str).

        Raises:
            ValueError: If the patched workflow has no ``nodes`` key.
        """
        patched_workflow = await self._patch_workflow(server, request_data)
        return self._convert_workflow_to_prompt_api_format(patched_workflow)

    def get_api_spec(self) -> ComfyAPISpec:
        """
        Return the API spec computed at construction.

        Returns:
            ComfyAPISpec: The API spec.
        """
        return self._api_spec

    def get_inputs_json_schema(self) -> ComfyJSONSchema:
        """
        Return the JSON Schema of the API inputs computed at construction.

        Returns:
            ComfyJSONSchema: The JSON Schema.
        """
        return self._inputs_json_schema

    @classmethod
    def _exposed_widget_inputs(
        cls, node: ComfyWorkflowNode
    ) -> List[ComfyWorkflowNodeInput]:
        """
        Return the inputs of ``node`` that are exposed as API fields.

        An input is exposed when it is backed by a widget and has no incoming
        link. The ``upload`` widget of ``LoadImage`` nodes is never exposed.
        """
        return [
            node_input
            for node_input in (node.inputs or [])
            if node_input.link is None
            and node_input.widget is not None
            and not (
                node.type == cls._LOAD_IMAGE_NODE_TYPE
                and node_input.widget.name == "upload"
            )
        ]

    @classmethod
    def _widget_value_index(cls, node_type: Optional[str], widget_position: int) -> int:
        """
        Map a widget's position among a node's widget inputs to its index in
        ``widgets_values``.

        Shared by :meth:`_get_node_inputs`, :meth:`_get_node_inputs_dict` and
        :meth:`_patch_workflow`, so values are read from and written to the same
        slot.
        """
        if (
            node_type in cls._EXTRA_VALUE_AFTER_FIRST_WIDGET_NODE_TYPES
            and widget_position >= 1
        ):
            return widget_position + 1
        return widget_position

    def _parse_api_spec(self) -> ComfyAPISpec:
        """
        Build the API spec from node titles.

        A node titled ``API_INPUT_PREFIX + name`` becomes ``inputs[name]``, with
        one ``ComfyAPIFieldSpec`` per exposed widget keyed by widget name. A node
        titled ``API_OUTPUT_PREFIX + name`` becomes ``outputs[name]``, with one
        ``ComfyAPIOutputSpec`` per node output keyed by output name.

        If several nodes share an API name, the first one (in workflow order)
        wins and the others are logged and ignored. Fields of different nodes
        must not be merged, because patching targets a single node id.
        """
        inputs: Dict[str, ComfyAPINodeSpec] = {}
        outputs: Dict[str, ComfyAPINodeSpec] = {}

        for node_id, node in self._nodes_map.items():
            title = node.title or ""

            if title.startswith(API_INPUT_PREFIX):
                node_name = title[len(API_INPUT_PREFIX):]
                if node_name in inputs:
                    logger.warning(
                        f"API 输入节点名称 '{node_name}' 重复 (节点 {node_id}), 已忽略; 仅使用第一个同名节点。"
                    )
                    continue
                inputs[node_name] = ComfyAPINodeSpec(
                    node_id=node_id,
                    class_type=node.type,
                    inputs={
                        node_input.widget.name: ComfyAPIFieldSpec(
                            type=self._get_param_type(
                                node_input.type or "STRING", node.type
                            ),
                            widget_name=node_input.widget.name,
                        )
                        for node_input in self._exposed_widget_inputs(node)
                    },
                    outputs={},
                )

            elif title.startswith(API_OUTPUT_PREFIX):
                node_name = title[len(API_OUTPUT_PREFIX):]
                if node_name in outputs:
                    logger.warning(
                        f"API 输出节点名称 '{node_name}' 重复 (节点 {node_id}), 已忽略; 仅使用第一个同名节点。"
                    )
                    continue
                outputs[node_name] = ComfyAPINodeSpec(
                    node_id=node_id,
                    class_type=node.type,
                    inputs={},
                    outputs={
                        node_output.name: ComfyAPIOutputSpec(
                            output_name=node_output.name
                        )
                        for node_output in (node.outputs or [])
                    },
                )

        return ComfyAPISpec(inputs=inputs, outputs=outputs)

    def _parse_inputs_json_schema(self) -> ComfyJSONSchema:
        """
        Build a JSON Schema describing the API input fields.

        Each API input node with at least one exposed widget becomes an object
        property named after its API name, with one property per exposed widget.
        The widget's current value is used as that property's default. As in
        :meth:`_parse_api_spec`, only the first node with a given API name is used.
        """
        properties: Dict[str, ComfyJSONSchemaNode] = {}
        required: List[str] = []
        seen_node_names = set()

        for node in self._nodes_map.values():
            title = node.title or ""
            if not title.startswith(API_INPUT_PREFIX):
                continue
            node_name = title[len(API_INPUT_PREFIX):]
            if node_name in seen_node_names:
                continue  # duplicate; already reported by _parse_api_spec
            seen_node_names.add(node_name)

            exposed_inputs = self._exposed_widget_inputs(node)
            if not exposed_inputs:
                continue

            # Current widget values become schema defaults. Exposed inputs are
            # unlinked, so an empty link map is sufficient. Computed once per node.
            current_values = self._get_node_inputs(node, {})

            node_properties: Dict[str, ComfyJSONSchemaProperty] = {}
            node_required: List[str] = []
            for node_input in exposed_inputs:
                widget_name = node_input.widget.name
                node_properties[widget_name] = self._generate_field_schema_model(
                    node_input,
                    current_values.get(widget_name),
                    param_type=self._get_param_type(
                        node_input.type or "STRING", node.type
                    ),
                )
                # TODO: read whether the field is required from the input config;
                # for now every exposed field is required.
                node_required.append(widget_name)

            properties[node_name] = ComfyJSONSchemaNode(
                title=f"输入节点: {node_name}",
                description=f"节点类型: {node.type}",
                properties=node_properties,
                required=node_required,
            )
            # The node has required fields, so the node itself is required.
            required.append(node_name)

        return ComfyJSONSchema(properties=properties, required=required)

    @staticmethod
    def _coerce_value(value: Any, param_type: str) -> Any:
        """
        Convert a request value to the Python type for an API parameter type.

        Args:
            value: Raw request value.
            param_type: Result of :meth:`_get_param_type`.

        Returns:
            The converted value: int, float, bool, or str (all other types).

        Raises:
            ValueError, TypeError: If the value cannot be converted.
        """
        if param_type == "int":
            return int(value)
        if param_type == "float":
            return float(value)
        if param_type == "boolean":
            # bool("False") is True, so textual booleans must be parsed.
            if isinstance(value, str):
                normalized = value.strip().lower()
                if normalized in ("true", "1", "yes", "on"):
                    return True
                if normalized in ("false", "0", "no", "off"):
                    return False
                raise ValueError(f"cannot interpret {value!r} as a boolean")
            return bool(value)
        return str(value)

    async def _patch_workflow(
        self,
        server: ComfyUIServerInfo,
        request_data: dict[str, Any],
    ) -> dict:
        """
        Apply request values to a copy of the workflow and return it as a dict.

        ``request_data`` is shaped ``{node_name: {field_name: value}}``. Unknown
        node and field names are ignored. A value that cannot be converted to its
        field's type is logged and skipped. For ``LoadImage`` nodes the value is
        treated as an image URL: the image is uploaded to ``server``, and the file
        name returned by the server is stored instead.

        The instance's own parsed workflow is never modified.
        """
        # Work on deep copies so the parsed workflow stays reusable.
        nodes_map = {
            node_id: node.model_copy(deep=True)
            for node_id, node in self._nodes_map.items()
        }

        for node_name, node_fields in request_data.items():
            node_spec = self._api_spec.inputs.get(node_name)
            if node_spec is None or node_spec.node_id not in nodes_map:
                continue
            if not isinstance(node_fields, dict):
                logger.warning(
                    f"节点 '{node_name}' 的参数应为字典, 实际为 {type(node_fields).__name__}。跳过此节点。"
                )
                continue

            node_id = node_spec.node_id
            target_node = nodes_map[node_id]
            widget_inputs = [
                node_input
                for node_input in (target_node.inputs or [])
                if node_input.widget is not None
            ]

            for field_name, value in node_fields.items():
                field_spec = (node_spec.inputs or {}).get(field_name)
                if field_spec is None:
                    continue
                widget_name_to_patch = field_spec.widget_name

                # Position of the target widget among the node's widget inputs.
                widget_position = next(
                    (
                        position
                        for position, node_input in enumerate(widget_inputs)
                        if node_input.widget.name == widget_name_to_patch
                    ),
                    None,
                )
                if widget_position is None:
                    logger.warning(
                        f"在节点 {node_id} 中未找到名为 '{widget_name_to_patch}' 的 widget。跳过此参数。"
                    )
                    continue
                value_index = self._widget_value_index(target_node.type, widget_position)

                # Ensure widgets_values is a list long enough to hold value_index.
                if not target_node.widgets_values or not isinstance(
                    target_node.widgets_values, list
                ):
                    target_node.widgets_values = [None] * len(widget_inputs)
                missing = value_index + 1 - len(target_node.widgets_values)
                if missing > 0:
                    target_node.widgets_values.extend([None] * missing)

                try:
                    if target_node.type == self._LOAD_IMAGE_NODE_TYPE:
                        value = await self._upload_image_to_comfy(server, value)
                    target_node.widgets_values[value_index] = self._coerce_value(
                        value, field_spec.type
                    )
                except (ValueError, TypeError) as e:
                    logger.warning(
                        f"无法将节点 '{node_name}' 的字段 '{field_name}' 的值 '{value}' 转换为类型 '{field_spec.type}'。错误: {e}"
                    )

        patched_workflow_data = self.workflow_data.model_copy(deep=True)
        # Nodes are ordered by id in the patched workflow.
        patched_workflow_data.nodes = sorted(nodes_map.values(), key=lambda node: node.id)
        # Return a plain dict for backward compatibility.
        return patched_workflow_data.model_dump()

    def _convert_workflow_to_prompt_api_format(self, workflow_data: dict) -> dict:
        """
        Convert a workflow dict (e.g. the output of :meth:`_patch_workflow`) into
        the payload format of ComfyUI's ``/prompt`` endpoint.

        Widget values come from ``widgets_values``, so patched values are
        included. Linked inputs become ``[origin_node_id, origin_slot]``.
        Annotation nodes (``_FRONTEND_ONLY_NODE_TYPES``) are left out.

        Raises:
            ValueError: If ``workflow_data`` has no ``nodes`` key.
        """
        if "nodes" not in workflow_data:
            raise ValueError("无效的工作流格式")

        # A link entry starts with [link_id, origin_node_id, origin_slot, ...].
        # Linked inputs are keyed by link id and reference their source output.
        link_map = {
            link_data[0]: [str(link_data[1]), link_data[2]]
            for link_data in workflow_data.get("links") or []
        }

        prompt_api_format = {}
        for node in workflow_data["nodes"]:
            if node["type"] in self._FRONTEND_ONLY_NODE_TYPES:
                continue
            prompt_api_format[str(node["id"])] = {
                "inputs": self._get_node_inputs(node, link_map),
                "class_type": node["type"],
                # model_dump() emits "title": None for untitled nodes, so a
                # .get() default alone would not apply.
                "_meta": {"title": node.get("title") or ""},
            }
        return prompt_api_format

    async def _upload_image_to_comfy(
        self, server: ComfyUIServerInfo, file_path: str
    ) -> str:
        """
        Download an image from an HTTP(S) URL and upload it to the ComfyUI server.

        Args:
            server: Target ComfyUI server (``server.http_url`` is the base URL).
            file_path: HTTP(S) URL of the image.

        Returns:
            str: The ``name`` field of the server's JSON response, used as the
            ``LoadImage`` widget value.

        Raises:
            aiohttp.ClientResponseError: If the download or upload fails with an
                error status.
        """
        # Take the file name from the URL path only, so query strings and
        # fragments (e.g. access tokens) do not become part of it.
        file_name = urlsplit(file_path).path.rsplit("/", 1)[-1] or "image"
        file_data = await self._download_file(file_path)

        form_data = aiohttp.FormData()
        form_data.add_field(
            "image", file_data, filename=file_name, content_type="image/*"
        )
        form_data.add_field("type", "input")

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{server.http_url}/api/upload/image", data=form_data
            ) as response:
                response.raise_for_status()
                result = await response.json()
        return result["name"]

    async def _download_file(self, src: str) -> bytes:
        """
        Download ``src`` over HTTP and return the response body in memory.

        Raises:
            aiohttp.ClientResponseError: If the server answers with an error status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(src) as response:
                response.raise_for_status()
                return await response.read()

    def _generate_field_schema(
        self, input_field: dict, default_value: Any = None
    ) -> dict:
        """
        Build a JSON Schema field definition as a plain dict (legacy variant of
        :meth:`_generate_field_schema_model`).

        Args:
            input_field: Input configuration. Reads ``widget.name`` and, when
                present, ``type``, ``description``, ``min``, ``max``, ``options``
                and ``default``.
            default_value: Default to advertise; takes precedence over
                ``input_field["default"]`` when not None.

        Returns:
            dict: The field schema.
        """
        field_schema = {
            "title": input_field["widget"]["name"],
            "description": input_field.get("description") or "",
        }

        field_type = self._get_param_type(input_field.get("type", "STRING"))
        if field_type == "int":
            field_schema.update(
                {
                    "type": "integer",
                    "minimum": input_field.get("min", None),
                    "maximum": input_field.get("max", None),
                }
            )
        elif field_type == "float":
            field_schema.update(
                {
                    "type": "number",
                    "minimum": input_field.get("min", None),
                    "maximum": input_field.get("max", None),
                }
            )
        elif field_type == "boolean":
            field_schema["type"] = "boolean"
        elif field_type == "UploadFile":
            field_schema.update(
                {
                    "type": "string",
                    "format": "binary",
                    "contentMediaType": "image/*",
                    "pattern": ".*\\.(jpg|jpeg|png)$",
                    "description": "[ext:jpg,jpeg,png][file-size:1MB] 上传图片文件",
                }
            )
        else:  # string
            field_schema["type"] = "string"

        if "options" in input_field:
            options = input_field["options"]
            field_schema["enum"] = options
            # map(str, ...) so non-string options (e.g. numbers) can be listed.
            field_schema["description"] += f" 可选值: {', '.join(map(str, options))}"

        if default_value is not None:
            field_schema["default"] = default_value
        elif "default" in input_field:
            field_schema["default"] = input_field["default"]

        return field_schema

    def _generate_field_schema_model(
        self,
        input_field: ComfyWorkflowNodeInput,
        default_value: Any = None,
        param_type: Optional[str] = None,
    ) -> ComfyJSONSchemaProperty:
        """
        Build a JSON Schema field definition for a widget input (model variant).

        Args:
            input_field: The node input to describe.
            default_value: Default value to advertise (may be None).
            param_type: Precomputed result of :meth:`_get_param_type`. When
                omitted it is derived from ``input_field.type`` alone.

        Returns:
            ComfyJSONSchemaProperty: The field schema.
        """
        title = input_field.widget.name if input_field.widget else input_field.name
        description = input_field.localized_name or ""
        field_type = param_type or self._get_param_type(input_field.type or "STRING")

        if field_type == "int":
            return ComfyJSONSchemaProperty(
                type="integer",
                title=title,
                description=description,
                default=default_value,
                # TODO: take min/max from the input configuration.
            )
        if field_type == "float":
            return ComfyJSONSchemaProperty(
                type="number",
                title=title,
                description=description,
                default=default_value,
                # TODO: take min/max from the input configuration.
            )
        if field_type == "boolean":
            return ComfyJSONSchemaProperty(
                type="boolean",
                title=title,
                description=description,
                default=default_value,
            )
        if field_type == "UploadFile":
            return ComfyJSONSchemaProperty(
                type="string",
                title=title,
                format="binary",
                contentMediaType="image/*",
                pattern=".*\\.(jpg|jpeg|png)$",
                description="[ext:jpg,jpeg,png][file-size:1MB] 上传图片文件",
                default=default_value,
            )
        # string
        return ComfyJSONSchemaProperty(
            type="string",
            title=title,
            description=description,
            default=default_value,
            # TODO: take enum values from the input configuration.
        )

    def _get_param_type(
        self, input_type_str: str, node_type: Optional[str] = None
    ) -> str:
        """
        Map a ComfyUI input type string to an API parameter type.

        Matching is case-insensitive and by substring. The result is one of
        ``"UploadFile"``, ``"int"``, ``"float"``, ``"boolean"`` or ``"string"``.

        Args:
            input_type_str: ComfyUI type, e.g. ``"INT"`` or ``"COMBO"``.
            node_type: Type of the node owning the input. ``COMBO`` inputs are
                image uploads only on ``LoadImage`` nodes, the only node type
                whose values :meth:`_patch_workflow` uploads; on other nodes they
                are plain strings. When omitted, every ``COMBO`` is reported as
                ``"UploadFile"`` (legacy behavior).
        """
        normalized = (input_type_str or "STRING").upper()
        if "COMBO" in normalized:
            if node_type is None or node_type == self._LOAD_IMAGE_NODE_TYPE:
                return "UploadFile"
            return "string"
        if "INT" in normalized:
            return "int"
        if "FLOAT" in normalized:
            return "float"
        if "BOOLEAN" in normalized:
            # Distinct from "string" so patching stores real booleans.
            return "boolean"
        return "string"

    def _get_node_inputs(self, node: Union[ComfyWorkflowNode, dict], link_map: dict) -> dict:
        """
        Collect a node's prompt inputs.

        Widget inputs take their value from ``widgets_values``, which already
        contains any patched values. Linked inputs become the ``link_map`` entry
        of their link id and override a widget value with the same key.

        Args:
            node: A ``ComfyWorkflowNode`` or a node dict. A dict that fails model
                validation is handled by :meth:`_get_node_inputs_dict`.
            link_map: Link id -> ``[origin_node_id, origin_slot]``.

        Returns:
            dict: Input name -> value or link reference.
        """
        if isinstance(node, dict):
            try:
                node_model = ComfyWorkflowNode.model_validate(node)
            except ValidationError:
                return self._get_node_inputs_dict(node, link_map)
        else:
            node_model = node

        node_inputs = node_model.inputs or []
        widgets_values = node_model.widgets_values or []
        inputs_dict = {}

        # 1. Widget inputs, read positionally from widgets_values.
        widget_inputs = [node_input for node_input in node_inputs if node_input.widget]
        for position, node_input in enumerate(widget_inputs):
            value_index = self._widget_value_index(node_model.type, position)
            if value_index >= len(widgets_values):
                break
            if node_input.widget.name:
                inputs_dict[node_input.widget.name] = widgets_values[value_index]

        # 2. Linked inputs, keyed by input (socket) name.
        for node_input in node_inputs:
            if node_input.link is not None and node_input.link in link_map:
                inputs_dict[node_input.name] = link_map[node_input.link]

        return inputs_dict

    def _get_node_inputs_dict(self, node: dict, link_map: dict) -> dict:
        """
        Dict-based variant of :meth:`_get_node_inputs`, used when the node does
        not validate as a ``ComfyWorkflowNode``.

        Args:
            node: Node dict.
            link_map: Link id -> ``[origin_node_id, origin_slot]``.

        Returns:
            dict: Input name -> value or link reference.
        """
        # `or []` also covers keys that are present with a null value.
        node_inputs = node.get("inputs") or []
        widgets_values = node.get("widgets_values") or []
        inputs_dict = {}

        # 1. Widget inputs, read positionally from widgets_values.
        widget_inputs = [node_input for node_input in node_inputs if node_input.get("widget")]
        for position, node_input in enumerate(widget_inputs):
            value_index = self._widget_value_index(node.get("type"), position)
            if value_index >= len(widgets_values):
                break
            widget_name = node_input["widget"].get("name")
            if widget_name:
                inputs_dict[widget_name] = widgets_values[value_index]

        # 2. Linked inputs, keyed by input (socket) name.
        for node_input in node_inputs:
            link_id = node_input.get("link")
            if link_id is not None and link_id in link_map:
                inputs_dict[node_input["name"]] = link_map[link_id]

        return inputs_dict