# ComfyUI-WorkflowPublisher/workflow_service/comfy/comfy_workflow.py
#
# (Source-viewer metadata: 478 lines, 17 KiB, Python. The viewer warned that this
# file contains Unicode characters that might be confused with other characters.)

from ast import List
import json
import logging
from typing import Any, Dict
import aiohttp
from workflow_service.comfy.comfy_server import ComfyUIServerInfo
# Title prefixes that expose a node through the API: a node titled
# "INPUT_<name>" becomes API input <name>, "OUTPUT_<name>" becomes API output <name>.
API_INPUT_PREFIX = "INPUT_"
API_OUTPUT_PREFIX = "OUTPUT_"

# NOTE(review): basicConfig() configures the root logger as soon as this module is
# imported, which affects the whole process; consider moving it to the app entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def build_prompt(
    workflow_data: dict,
    api_spec: dict[str, dict[str, Any]],
    request_data: dict[str, Any],
    server: ComfyUIServerInfo,
) -> dict:
    """Build the /prompt payload for one API request.

    The request values are patched into a deep copy of ``workflow_data``, and the
    copy is converted to the prompt format. The caller's workflow is left unchanged.

    Args:
        workflow_data: UI-format workflow containing a ``nodes`` list.
        api_spec: Spec produced by ``parse_api_spec``.
        request_data: ``{node_name: {field_name: value}}`` supplied by the client.
        server: ComfyUI server that images for ``LoadImage`` nodes are uploaded to.

    Returns:
        dict: ``{node_id: {"inputs": ..., "class_type": ..., "_meta": {...}}}``.

    Raises:
        ValueError: If the workflow has no ``nodes`` key.
        Errors raised while uploading images also propagate.
    """
    import copy

    # _patch_workflow edits widgets_values (and re-sorts the nodes) in place.
    # Working on a copy keeps one request's values from leaking into later
    # requests if the same workflow dict is reused by the caller.
    working_copy = copy.deepcopy(workflow_data)
    patched_workflow = await _patch_workflow(
        working_copy, api_spec, request_data, server
    )
    return _convert_workflow_to_prompt_api_format(patched_workflow)
def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]:
    """Collect the API inputs and outputs declared through node titles.

    A node titled ``API_INPUT_PREFIX + name`` becomes API input ``name``. A node
    titled ``API_OUTPUT_PREFIX + name`` becomes API output ``name``.

    Returns:
        A dict of the form::

            {
                "inputs": {
                    name: {
                        "node_id": <node id as str>,
                        "class_type": <node type>,
                        "inputs": {widget_name: {"type": ..., "widget_name": widget_name}},
                    },
                },
                "outputs": {
                    name: {
                        "node_id": <node id as str>,
                        "class_type": <node type>,
                        "outputs": {output_name: {"output_name": output_name}},
                    },
                },
            }

        The nested ``"inputs"`` / ``"outputs"`` key is present only when at least
        one field was found. Only unlinked widget inputs are exposed, and the
        ``upload`` widget of ``LoadImage`` nodes is hidden. ``"type"`` comes from
        ``_get_param_type``.

        If several nodes map to the same name, the first node's ``node_id`` and
        ``class_type`` are kept and later nodes' fields are merged into that entry.
        A warning is logged in that case.

    Raises:
        ValueError: If ``workflow_data["nodes"]`` is missing or not a list.
    """
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError(
            "Invalid workflow format: 'nodes' key not found or is not a list."
        )
    spec = {"inputs": {}, "outputs": {}}
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    for node_id, node in nodes_map.items():
        title = node.get("title")
        if not title:
            continue

        if title.startswith(API_INPUT_PREFIX):
            node_name = title[len(API_INPUT_PREFIX) :]
            if node_name in spec["inputs"]:
                logger.warning(
                    "API input name '%s' is used by several nodes; fields of node %s "
                    "are merged into node %s",
                    node_name,
                    node_id,
                    spec["inputs"][node_name]["node_id"],
                )
            else:
                spec["inputs"][node_name] = {
                    "node_id": node_id,
                    "class_type": node.get("type"),
                }
            entry = spec["inputs"][node_name]
            # `or []` also covers an explicit "inputs": None.
            for a_input in node.get("inputs") or []:
                if a_input.get("link") is not None or "widget" not in a_input:
                    continue
                widget_name = a_input["widget"]["name"]
                # Hide the upload widget of LoadImage nodes from the API.
                if node.get("type") == "LoadImage" and widget_name == "upload":
                    continue
                entry.setdefault("inputs", {})[widget_name] = {
                    "type": _get_param_type(a_input.get("type", "STRING")),
                    "widget_name": widget_name,
                }

        elif title.startswith(API_OUTPUT_PREFIX):
            node_name = title[len(API_OUTPUT_PREFIX) :]
            if node_name in spec["outputs"]:
                logger.warning(
                    "API output name '%s' is used by several nodes; fields of node %s "
                    "are merged into node %s",
                    node_name,
                    node_id,
                    spec["outputs"][node_name]["node_id"],
                )
            else:
                spec["outputs"][node_name] = {
                    "node_id": node_id,
                    "class_type": node.get("type"),
                }
            entry = spec["outputs"][node_name]
            for an_output in node.get("outputs") or []:
                output_name = an_output["name"]
                entry.setdefault("outputs", {})[output_name] = {
                    "output_name": output_name,
                }
    return spec
def parse_inputs_json_schema(workflow_data: dict) -> dict:
    """Build a JSON Schema (draft-07) describing the workflow's API inputs.

    Every node whose title starts with ``API_INPUT_PREFIX`` becomes an object
    property named after the rest of the title. Its unlinked widget inputs become
    nested properties (the ``upload`` widget of ``LoadImage`` nodes is hidden).
    The node's current widget values are used as defaults. Nodes that share a name
    are merged into one property.

    Returns:
        dict: The schema object. Only inputs are described. A field is listed in its
        node's ``required`` unless its input entry sets ``"required": False``. A
        node is required at the top level when it has at least one required field.

    Raises:
        ValueError: If ``workflow_data["nodes"]`` is missing or not a list.
    """
    schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "object",
        "title": "ComfyUI Workflow Input Schema",
        "description": "工作流输入参数定义",
        "properties": {},
        "required": [],
    }
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError(
            "Invalid workflow format: 'nodes' key not found or is not a list."
        )
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}

    # API name -> collected node info.
    input_nodes = {}
    for node_id, node in nodes_map.items():
        title = node.get("title")
        if not title or not title.startswith(API_INPUT_PREFIX):
            continue
        node_name = title[len(API_INPUT_PREFIX) :]
        node_info = input_nodes.get(node_name)
        if node_info is None:
            node_info = input_nodes[node_name] = {
                "node_id": node_id,
                "class_type": node.get("type"),
                "properties": {},
                "required": [],
            }

        exposed_inputs = []
        # `or []` also covers an explicit "inputs": None.
        for a_input in node.get("inputs") or []:
            if a_input.get("link") is not None or "widget" not in a_input:
                continue
            # Hide the upload widget of LoadImage nodes from the API.
            if node.get("type") == "LoadImage" and a_input["widget"]["name"] == "upload":
                continue
            exposed_inputs.append(a_input)
        if not exposed_inputs:
            continue

        # Current widget values serve as schema defaults. Only unlinked widgets are
        # exposed, so an empty link map suffices. The values are read once per node
        # rather than once per field.
        current_values = _get_node_inputs(node, {})
        for a_input in exposed_inputs:
            widget_name = a_input["widget"]["name"]
            node_info["properties"][widget_name] = _generate_field_schema(
                a_input, current_values.get(widget_name)
            )
            # Fields are required unless the input opts out. JSON Schema requires the
            # entries of "required" to be unique, which matters when names repeat.
            if a_input.get("required", True) and widget_name not in node_info["required"]:
                node_info["required"].append(widget_name)

    for node_name, node_info in input_nodes.items():
        schema["properties"][node_name] = {
            "type": "object",
            "title": f"输入节点: {node_name}",
            "description": f"节点类型: {node_info['class_type']}",
            "properties": node_info["properties"],
            "required": node_info["required"],
            "additionalProperties": False,
        }
        # A node with at least one required field is itself required.
        if node_info["required"]:
            schema["required"].append(node_name)
    return schema
########################################################
# Internal helpers below
########################################################
async def _patch_workflow(
    workflow_data: dict,
    api_spec: dict[str, dict[str, Any]],
    request_data: dict[str, Any],
    server: ComfyUIServerInfo,
) -> dict:
    """Write request values into the widgets of the workflow's API input nodes.

    ``workflow_data`` is modified in place and also returned. Its ``nodes`` list is
    re-sorted by node id.

    Args:
        workflow_data: UI-format workflow containing a ``nodes`` list.
        api_spec: Spec produced by ``parse_api_spec``; only ``api_spec["inputs"]``
            is read.
        request_data: ``{node_name: {field_name: value}}``. Unknown node or field
            names are ignored.
        server: ComfyUI server that images for ``LoadImage`` nodes are uploaded to.

    Values are converted to ``int`` / ``float`` / ``str`` according to the field's
    spec type. A value that fails conversion is logged and skipped.

    List-style ``widgets_values`` are positional. The slot for widget input ``k``
    matches how ``_get_node_inputs`` reads it back: on ``EG_CYQ_JB`` nodes,
    ``widgets_values[1]`` is skipped. Mapping-style ``widgets_values`` are written by
    widget name (assumes the mapping is keyed by widget name).

    Raises:
        ValueError: If ``workflow_data`` has no ``nodes`` key.
    """
    if "nodes" not in workflow_data:
        raise ValueError("无效的工作流格式")
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    input_specs = api_spec["inputs"]

    for node_name, node_fields in request_data.items():
        if node_name not in input_specs:
            continue
        node_spec = input_specs[node_name]
        node_id = node_spec["node_id"]
        if node_id not in nodes_map:
            continue
        target_node = nodes_map[node_id]
        field_specs = node_spec.get("inputs", {})
        # Widget inputs in declaration order. The node's inputs are not changed
        # below, so this is computed once per node.
        widget_inputs = [
            inp for inp in target_node.get("inputs") or [] if "widget" in inp
        ]

        for field_name, value in node_fields.items():
            if field_name not in field_specs:
                continue
            field_spec = field_specs[field_name]
            widget_name_to_patch = field_spec["widget_name"]

            target_widget_index = next(
                (
                    i
                    for i, widget_input in enumerate(widget_inputs)
                    if (widget_input.get("widget") or {}).get("name")
                    == widget_name_to_patch
                ),
                -1,
            )
            if target_widget_index == -1:
                logger.warning(
                    f"在节点 {node_id} 中未找到名为 '{widget_name_to_patch}' 的 widget。跳过此参数。"
                )
                continue

            widgets_values = target_node.get("widgets_values")
            if isinstance(widgets_values, dict):
                # Mapping-style values: write by name instead of wiping the mapping.
                slot = widget_name_to_patch
            else:
                if not isinstance(widgets_values, list):
                    # Missing or malformed: start from a placeholder list.
                    widgets_values = [None] * len(widget_inputs)
                    target_node["widgets_values"] = widgets_values
                slot = target_widget_index
                # _get_node_inputs skips widgets_values[1] on EG_CYQ_JB nodes
                # (ControlNet control value), so every widget after the first is
                # stored one slot further on.
                if target_node.get("type") == "EG_CYQ_JB" and target_widget_index >= 1:
                    slot += 1
                while len(widgets_values) <= slot:
                    widgets_values.append(None)

            # Convert according to the API spec type; anything else becomes str.
            target_type = {"int": int, "float": float}.get(field_spec["type"], str)
            try:
                if target_node.get("type") == "LoadImage":
                    value = await _upload_image_to_comfy(value, server)
                widgets_values[slot] = target_type(value)
            except (ValueError, TypeError) as e:
                logger.warning(
                    f"无法将节点 '{node_name}' 的字段 '{field_name}' 的值 '{value}' 转换为类型 '{field_spec['type']}'。错误: {e}"
                )

    # Write the nodes back sorted by id, ascending.
    workflow_data["nodes"] = sorted(nodes_map.values(), key=lambda x: x["id"])
    return workflow_data
def _convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
"""
将工作流API格式转换为提交到/prompt端点的格式。
此函数现在能正确处理已通过 `patch_workflow` 修改的 `widgets_values`。
"""
if "nodes" not in workflow_data:
raise ValueError("无效的工作流格式")
prompt_api_format = {}
# 建立从link_id到源节点的映射
link_map = {}
for link_data in workflow_data.get("links", []):
(
link_id,
origin_node_id,
origin_slot_index,
target_node_id,
target_slot_index,
link_type,
) = link_data
# 键是目标节点的输入link_id
link_map[link_id] = [str(origin_node_id), origin_slot_index]
for node in workflow_data["nodes"]:
node_id = str(node["id"])
inputs_dict = _get_node_inputs(node, link_map)
prompt_api_format[node_id] = {
"inputs": inputs_dict,
"class_type": node["type"],
"_meta": {
"title": node.get("title", ""),
},
}
# 过滤掉类型为 Note 的节点
prompt_api_format = {
k: v for k, v in prompt_api_format.items() if v["class_type"] not in ["Note"]
}
return prompt_api_format
async def _upload_image_to_comfy(file_path: str, server: ComfyUIServerInfo) -> str:
    """Download an image from a URL and upload it to the ComfyUI server.

    Args:
        file_path: HTTP(S) URL of the image (despite the name, not a local path).
        server: Target server. The image is POSTed to
            ``{server.http_url}/api/upload/image`` with ``type=input``.

    Returns:
        The ``name`` field of the upload response, which the caller stores in the
        LoadImage widget.

    Raises:
        aiohttp.ClientResponseError: If the download or the upload returns an
            error status.
    """
    import posixpath
    from urllib.parse import unquote, urlsplit

    # NOTE(review): file_path comes from API request data and is fetched by this
    # server; consider restricting allowed schemes/hosts to avoid SSRF.
    # Take the name from the URL path only, so query strings and fragments
    # (e.g. signed-URL tokens) don't end up in the uploaded filename. Decoding
    # before basename() also removes any percent-encoded "/" segments.
    file_name = posixpath.basename(unquote(urlsplit(file_path).path))
    if file_name in ("", ".", ".."):
        # e.g. "https://host/?id=1": fall back to a usable multipart filename.
        file_name = "image.png"

    file_data = await _download_file(file_path)
    form_data = aiohttp.FormData()
    form_data.add_field("image", file_data, filename=file_name, content_type="image/*")
    form_data.add_field("type", "input")
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{server.http_url}/api/upload/image", data=form_data
        ) as response:
            response.raise_for_status()
            result = await response.json()
    return result["name"]
async def _download_file(src: str) -> bytes:
    """Fetch ``src`` over HTTP and return the response body.

    The content is held in memory and returned; nothing is written to disk.
    A new client session is opened for each call.

    Raises:
        aiohttp.ClientResponseError: If the server answers with an error status.
    """
    async with aiohttp.ClientSession() as session, session.get(src) as response:
        response.raise_for_status()
        body = await response.read()
    return body
def _generate_field_schema(input_field: dict, default_value: Any = None) -> dict:
    """Build the JSON Schema definition of one widget input.

    Args:
        input_field: Workflow input entry. Reads ``widget.name`` and, when present,
            ``type``, ``description``, ``min``, ``max``, ``options`` and ``default``.
        default_value: Current widget value. Takes precedence over
            ``input_field["default"]``; ignored when ``None``.

    Returns:
        dict: A JSON Schema fragment with ``title``, ``description`` and ``type``,
        plus bounds, file constraints, ``enum`` or ``default`` where applicable.
    """
    field_schema = {
        "title": input_field["widget"]["name"],
        "description": input_field.get("description", ""),
    }

    field_type = _get_param_type(input_field.get("type", "STRING"))
    if field_type in ("int", "float"):
        field_schema["type"] = "integer" if field_type == "int" else "number"
        # JSON Schema requires bounds to be numbers, so absent bounds are left out
        # instead of being emitted as null.
        for source_key, schema_key in (("min", "minimum"), ("max", "maximum")):
            bound = input_field.get(source_key)
            if bound is not None:
                field_schema[schema_key] = bound
    elif field_type == "UploadFile":
        field_schema.update(
            {
                "type": "string",
                "format": "binary",
                "contentMediaType": "image/*",
                "pattern": ".*\\.(jpg|jpeg|png)$",
                "description": "[ext:jpg,jpeg,png][file-size:1MB] 上传图片文件",
            }
        )
    else:  # string
        field_schema["type"] = "string"

    # Enumerated choices.
    if "options" in input_field:
        options = input_field["options"]
        field_schema["enum"] = options
        # str() each option so numeric choices don't break the join.
        field_schema["description"] += (
            f" 可选值: {', '.join(str(option) for option in options)}"
        )

    # Default: the node's current value wins over the declared default.
    if default_value is not None:
        field_schema["default"] = default_value
    elif "default" in input_field:
        field_schema["default"] = input_field["default"]
    return field_schema
def _get_param_type(input_type_str: str) -> str:
"""
根据输入类型字符串确定参数类型
"""
input_type_str = input_type_str.upper()
if "COMBO" in input_type_str:
return "UploadFile"
elif "INT" in input_type_str:
return "int"
elif "FLOAT" in input_type_str:
return "float"
else:
return "string"
def _get_node_inputs(node: dict, link_map: dict) -> dict:
"""
获取节点的输入字段值
Args:
node: 节点数据
link_map: 从link_id到源节点的映射
Returns:
dict: 包含节点输入字段值的字典
"""
inputs_dict = {}
# 1. 处理控件输入 (widgets)
widgets_values = node.get("widgets_values", [])
widget_cursor = 0
for input_config in node.get("inputs", []):
# 如果是widget并且有对应的widgets_values
if "widget" in input_config and widget_cursor < len(widgets_values):
widget_name = input_config["widget"].get("name")
if widget_name:
# 使用widgets_values中的值因为这里已经包含了API传入的修改
inputs_dict[widget_name] = widgets_values[widget_cursor]
widget_cursor += 1
# 特殊处理:如果节点是 EG_CYQ_JB 则需要忽略widgets_values的第1个值这个是用来控制control net的它不参与赋值
if node["type"] == "EG_CYQ_JB" and widget_cursor == 1:
widget_cursor += 1
# 2. 处理连接输入 (links)
for input_config in node.get("inputs", []):
if "link" in input_config and input_config["link"] is not None:
link_id = input_config["link"]
if link_id in link_map:
# 输入名称是input_config中的'name'
inputs_dict[input_config["name"]] = link_map[link_id]
return inputs_dict