744 lines
26 KiB
Python
744 lines
26 KiB
Python
import logging
from typing import Any, Dict, List, Optional, Union
from urllib.parse import unquote, urlsplit

import aiohttp
from pydantic import BaseModel, Field

from workflow_service.comfy.comfy_server import ComfyUIServerInfo
|
||
|
||
|
||
# NOTE(review): basicConfig runs at import time and configures the *root*
# logger (it is a no-op if the root logger already has handlers). Consider
# moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(name=__name__)

# Node-title prefixes that mark a workflow node as part of the public API:
# "INPUT_<name>" nodes expose their unlinked widgets as request fields,
# "OUTPUT_<name>" nodes expose their output slots.
API_INPUT_PREFIX: str = "INPUT_"
API_OUTPUT_PREFIX: str = "OUTPUT_"
|
||
|
||
|
||
class ComfyWorkflowNodeWidget(BaseModel):
    """Widget descriptor attached to a node input in the workflow JSON."""

    # Widget identifier; also used as the input key in the API-format prompt.
    name: str
|
||
|
||
|
||
class ComfyWorkflowNodeInput(BaseModel):
    """One entry of a node's ``inputs`` list in the workflow JSON."""

    # Optional display metadata written by the frontend.
    label: Optional[str] = None
    localized_name: Optional[str] = None
    # Input slot name; used as the prompt input key for linked inputs.
    name: str
    shape: Optional[int] = None
    # Declared ComfyUI type string, e.g. "INT", "FLOAT", "COMBO".
    type: str
    # Id of the incoming link, or None when the input is not connected.
    link: Optional[int] = None
    # Set when the input is backed by a widget.
    widget: Optional[ComfyWorkflowNodeWidget] = None
|
||
|
||
|
||
class ComfyWorkflowNodeOutput(BaseModel):
    """One entry of a node's ``outputs`` list in the workflow JSON."""

    # Optional display metadata written by the frontend.
    label: Optional[str] = None
    localized_name: Optional[str] = None
    # Output slot name; exposed as the field name of API output nodes.
    name: str
    type: str
    # Ids of the outgoing links, if any.
    links: Optional[list[int]] = None
|
||
|
||
|
||
class ComfyWorkflowNodeProperties(BaseModel):
    """The ``properties`` object of a node in the workflow JSON."""

    # Accept both the Python field name and the JSON alias when validating, so
    # data dumped with ``model_dump()`` (field names) validates again.
    model_config = {"populate_by_name": True}

    cnr_id: Optional[str] = None
    aux_id: Optional[str] = None
    ver: Optional[str] = None
    # The ComfyUI frontend stores this under the literal key "Node name for S&R",
    # which is not a valid identifier; without the alias the field never fills.
    Node_name_for_SR: Optional[str] = Field(default=None, alias="Node name for S&R")
|
||
|
||
|
||
class ComfyAPIFieldSpec(BaseModel):
    """A single request field exposed by an API input node."""

    # Simplified parameter type: "int", "float", "UploadFile" or "string".
    type: str
    # Name of the widget on the underlying node that receives the value.
    widget_name: str
|
||
|
||
|
||
class ComfyAPIOutputSpec(BaseModel):
    """A single output slot exposed by an API output node."""

    # Name of the node output slot.
    output_name: str
|
||
|
||
|
||
class ComfyAPINodeSpec(BaseModel):
    """API description of one workflow node (an input or an output node)."""

    # Id of the workflow node (stringified).
    node_id: str
    # ComfyUI node class, e.g. "LoadImage".
    class_type: str
    # Request fields keyed by field name (input nodes).
    inputs: Optional[dict[str, ComfyAPIFieldSpec]] = {}
    # Output slots keyed by slot name (output nodes).
    outputs: Optional[dict[str, ComfyAPIOutputSpec]] = {}
|
||
|
||
|
||
class ComfyAPISpec(BaseModel):
    """Public API of a workflow: its input and output nodes keyed by API name."""

    # API name (title minus the input prefix) -> input node description.
    inputs: dict[str, ComfyAPINodeSpec] = {}
    # API name (title minus the output prefix) -> output node description.
    outputs: dict[str, ComfyAPINodeSpec] = {}
|
||
|
||
|
||
class ComfyJSONSchemaProperty(BaseModel):
    """JSON Schema definition of a single input field."""

    # JSON Schema type keyword, e.g. "integer", "number", "string".
    type: str
    title: Optional[str] = None
    description: Optional[str] = None
    default: Optional[Any] = None
    # Numeric bounds (only meaningful for "integer"/"number").
    minimum: Optional[Union[int, float]] = None
    maximum: Optional[Union[int, float]] = None
    enum: Optional[list[str]] = None
    # File-upload hints for string fields.
    format: Optional[str] = None
    contentMediaType: Optional[str] = None
    pattern: Optional[str] = None
|
||
|
||
|
||
class ComfyJSONSchemaNode(BaseModel):
    """JSON Schema object describing the fields of one API input node."""

    type: str = "object"
    title: str
    description: str
    # Field name -> field schema.
    properties: dict[str, ComfyJSONSchemaProperty] = {}
    # Names of required fields.
    required: list[str] = []
    additionalProperties: bool = False
|
||
|
||
|
||
class ComfyJSONSchema(BaseModel):
    """Top-level JSON Schema document describing all workflow inputs."""

    # Allow construction/validation with either "schema" or "$schema".
    model_config = {"populate_by_name": True}

    # Emitted as "$schema" when dumped with ``by_alias=True``. The former
    # pydantic-v1 ``class Config: fields = {"schema": "$schema"}`` is ignored by
    # pydantic v2 (the ``fields`` config key was removed), which silently dropped
    # the alias; declaring it on the field restores it.
    schema: str = Field(
        default="http://json-schema.org/draft-07/schema#", alias="$schema"
    )
    type: str = "object"
    title: str = "ComfyUI Workflow Input Schema"
    description: str = "工作流输入参数定义"
    # API node name -> schema of that node's fields.
    properties: dict[str, ComfyJSONSchemaNode] = {}
    # API node names that have at least one required field.
    required: list[str] = []
|
||
|
||
|
||
class ComfyWorkflowNode(BaseModel):
    """A node entry of the workflow JSON ``nodes`` list."""

    id: int
    # Node class type, e.g. "LoadImage" or "Note".
    type: str
    pos: list[float]
    size: list[float]
    flags: dict[str, Any] = {}
    order: int
    mode: int
    inputs: Optional[list[ComfyWorkflowNodeInput]] = []
    outputs: Optional[list[ComfyWorkflowNodeOutput]] = []
    properties: ComfyWorkflowNodeProperties = ComfyWorkflowNodeProperties()
    # Positional widget values; read/patched by ComfyWorkflow.
    widgets_values: Optional[list[Any]] = []
    # User-given title; API nodes are identified by a title prefix.
    title: Optional[str] = None
|
||
|
||
|
||
class ComfyWorkflowExtra(BaseModel):
    """The ``extra`` object of the workflow JSON (frontend/extension settings)."""

    ds: Optional[dict[str, Any]] = None
    frontendVersion: Optional[str] = None
    # Settings written by the VideoHelperSuite (VHS_*) extension.
    VHS_latentpreview: Optional[bool] = None
    VHS_latentpreviewrate: Optional[int] = None
    VHS_MetadataImage: Optional[bool] = None
    VHS_KeepIntermediate: Optional[bool] = None
|
||
|
||
|
||
class ComfyWorkflowDataSpec(BaseModel):
    """Top-level ComfyUI workflow document (UI format)."""

    id: str
    revision: int = 0
    last_node_id: int
    last_link_id: int
    nodes: list[ComfyWorkflowNode]
    # Each link is a list whose first three items are consumed as
    # [link_id, origin_node_id, origin_slot_index].
    links: list[list[Union[int, str]]] = []
    groups: list[Any] = []
    config: dict[str, Any] = {}
    extra: ComfyWorkflowExtra = ComfyWorkflowExtra()
    version: float
|
||
|
||
|
||
class ComfyWorkflow:
    """
    Wrapper around a ComfyUI workflow (UI-format JSON) that derives an API from it.

    Nodes titled ``API_INPUT_PREFIX + <name>`` expose their unlinked widget inputs
    as request fields. Nodes titled ``API_OUTPUT_PREFIX + <name>`` expose their
    output slots. ``build_prompt`` applies request values to a copy of the workflow
    and converts that copy to the format submitted to ComfyUI's ``/prompt`` endpoint.
    """

    # Node types that exist only in the frontend and are never sent to /prompt.
    _FRONTEND_ONLY_NODE_TYPES = frozenset({"Note", "MarkdownNote"})

    # Values the frontend writes into ``widgets_values`` directly after an INT
    # widget that has a "control after generate" option (e.g. KSampler's seed).
    # Such a value occupies a slot in ``widgets_values`` without belonging to
    # any node input.
    _CONTROL_AFTER_GENERATE_VALUES = frozenset(
        {"fixed", "increment", "decrement", "randomize"}
    )

    def __init__(self, workflow_data: dict):
        """
        Validate the workflow and precompute its API spec and input JSON Schema.

        Args:
            workflow_data: Workflow JSON (UI format) as a dict.
        """
        self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data)
        # Node lookup by id; ids are stringified to match prompt keys.
        self._nodes_map = {str(node.id): node for node in self.workflow_data.nodes}
        self._api_spec = self._parse_api_spec()
        self._inputs_json_schema = self._parse_inputs_json_schema()

    async def build_prompt(
        self, server: ComfyUIServerInfo, request_data: dict[str, Any]
    ) -> dict:
        """
        Build the /prompt payload for this workflow with the request values applied.

        Args:
            server: ComfyUI server that receives uploaded images for ``LoadImage``
                inputs.
            request_data: ``{"<node name>": {"<field name>": value}}``. Node names
                are titles with ``API_INPUT_PREFIX`` removed.

        Returns:
            dict: API-format prompt keyed by stringified node id.

        Raises:
            ValueError: if the patched workflow has no ``nodes`` key.
        """
        patched_workflow = await self._patch_workflow(server, request_data)
        return self._convert_workflow_to_prompt_api_format(patched_workflow)

    def get_api_spec(self) -> ComfyAPISpec:
        """
        Return the API spec derived from the workflow at construction time.

        Returns:
            ComfyAPISpec: API spec.
        """
        return self._api_spec

    def get_inputs_json_schema(self) -> ComfyJSONSchema:
        """
        Return the JSON Schema of the request inputs derived at construction time.

        Returns:
            ComfyJSONSchema: JSON Schema.
        """
        return self._inputs_json_schema

    def _parse_api_spec(self) -> ComfyAPISpec:
        """
        Build the nested API spec from nodes carrying the input/output prefixes.

        Shape: ``inputs[<name>]`` holds the node id, class type and one field per
        exposed widget. ``outputs[<name>]`` holds one entry per output slot. When
        several nodes share an API name, their fields are merged into the first
        node's entry (and a warning is logged).
        """
        inputs: Dict[str, ComfyAPINodeSpec] = {}
        outputs: Dict[str, ComfyAPINodeSpec] = {}

        for node_id, node in self._nodes_map.items():
            title = node.title
            if not title:
                continue

            if title.startswith(API_INPUT_PREFIX):
                node_name = title[len(API_INPUT_PREFIX):]
                spec = self._get_or_create_node_spec(inputs, node_name, node_id, node.type)
                for field_input in self._exposed_widget_inputs(node):
                    # The widget name doubles as the API field name.
                    widget_name = field_input.widget.name
                    spec.inputs[widget_name] = ComfyAPIFieldSpec(
                        type=self._get_param_type(field_input.type or "STRING"),
                        widget_name=widget_name,
                    )

            elif title.startswith(API_OUTPUT_PREFIX):
                node_name = title[len(API_OUTPUT_PREFIX):]
                spec = self._get_or_create_node_spec(outputs, node_name, node_id, node.type)
                for an_output in node.outputs or []:
                    # The output slot name doubles as the API field name.
                    spec.outputs[an_output.name] = ComfyAPIOutputSpec(
                        output_name=an_output.name
                    )

        return ComfyAPISpec(inputs=inputs, outputs=outputs)

    @staticmethod
    def _get_or_create_node_spec(
        specs: Dict[str, ComfyAPINodeSpec], node_name: str, node_id: str, class_type: str
    ) -> ComfyAPINodeSpec:
        """Return ``specs[node_name]``, creating it for ``node_id`` if absent."""
        existing = specs.get(node_name)
        if existing is None:
            existing = ComfyAPINodeSpec(
                node_id=node_id, class_type=class_type, inputs={}, outputs={}
            )
            specs[node_name] = existing
        elif existing.node_id != node_id:
            logger.warning(
                f"多个节点使用相同的 API 名称 '{node_name}'(节点 {existing.node_id} 与 {node_id}),"
                f"字段将合并到节点 {existing.node_id}。"
            )
        return existing

    @staticmethod
    def _exposed_widget_inputs(node: ComfyWorkflowNode) -> List[ComfyWorkflowNodeInput]:
        """Return the unlinked widget inputs of ``node`` that are exposed via the API."""
        return [
            field_input
            for field_input in node.inputs or []
            if field_input.link is None
            and field_input.widget is not None
            # The LoadImage "upload" widget is a frontend button, not a parameter.
            and not (node.type == "LoadImage" and field_input.widget.name == "upload")
        ]

    def _parse_inputs_json_schema(self) -> ComfyJSONSchema:
        """
        Build a JSON Schema that describes only the workflow's API inputs.

        Each input node becomes an object property whose fields take their current
        widget values as defaults. Nodes sharing an API name are merged, matching
        ``_parse_api_spec``.
        """
        properties: Dict[str, ComfyJSONSchemaNode] = {}
        required: List[str] = []

        for node in self._nodes_map.values():
            title = node.title
            if not title or not title.startswith(API_INPUT_PREFIX):
                continue

            node_name = title[len(API_INPUT_PREFIX):]
            # Current widget values become the schema defaults. Links are irrelevant
            # here, so an empty link map is enough. Computed once per node.
            current_inputs = self._get_node_inputs(node, {})

            node_properties: Dict[str, ComfyJSONSchemaProperty] = {}
            node_required: List[str] = []
            for field_input in self._exposed_widget_inputs(node):
                widget_name = field_input.widget.name
                node_properties[widget_name] = self._generate_field_schema_model(
                    field_input, current_inputs.get(widget_name)
                )
                # TODO: read "required" from the node's input definition; every
                # exposed field is treated as required for now.
                node_required.append(widget_name)

            if not node_properties:
                continue

            existing = properties.get(node_name)
            if existing is None:
                properties[node_name] = ComfyJSONSchemaNode(
                    title=f"输入节点: {node_name}",
                    description=f"节点类型: {node.type}",
                    properties=node_properties,
                    required=node_required,
                )
                # A node with required fields is itself required. "required"
                # entries must be unique, so each name is added only once.
                if node_required:
                    required.append(node_name)
            else:
                existing.properties.update(node_properties)
                existing.required.extend(
                    name for name in node_required if name not in existing.required
                )

        return ComfyJSONSchema(properties=properties, required=required)

    async def _patch_workflow(
        self,
        server: ComfyUIServerInfo,
        request_data: dict[str, Any],
    ) -> dict:
        """
        Apply request values to a copy of the workflow and return it as a dict.

        ``request_data`` has the shape ``{"<node name>": {"<field name>": value}}``.
        Unknown node names and fields are ignored. Values that cannot be converted
        to the field type are logged and skipped. For ``LoadImage`` nodes the value
        is treated as a URL: the file is downloaded, uploaded to ``server``, and the
        name returned by the server is stored. ``self.workflow_data`` is not modified.
        """
        nodes_map = {k: v.model_copy(deep=True) for k, v in self._nodes_map.items()}

        for node_name, node_fields in request_data.items():
            node_spec = self._api_spec.inputs.get(node_name)
            if node_spec is None:
                continue

            node_id = node_spec.node_id
            if node_id not in nodes_map:
                continue

            if not isinstance(node_fields, dict):
                logger.warning(
                    f"节点 '{node_name}' 的参数应为对象,实际为 {type(node_fields).__name__}。跳过此节点。"
                )
                continue

            target_node = nodes_map[node_id]
            field_specs = node_spec.inputs or {}

            for field_name, value in node_fields.items():
                field_spec = field_specs.get(field_name)
                if field_spec is None:
                    continue

                widget_name_to_patch = field_spec.widget_name
                widgets = self._widget_specs(target_node.inputs)
                widget_types = dict(widgets)
                if widget_name_to_patch not in widget_types:
                    logger.warning(
                        f"在节点 {node_id} 中未找到名为 '{widget_name_to_patch}' 的 widget。跳过此参数。"
                    )
                    continue

                # widgets_values must be a list. When missing or malformed, start
                # from one placeholder per widget.
                if not target_node.widgets_values or not isinstance(target_node.widgets_values, list):
                    target_node.widgets_values = [None] * len(widgets)

                # Use the same index mapping as the read path (_get_node_inputs),
                # so patched values land where the prompt builder reads them.
                target_index = self._widget_value_indices(
                    target_node.type, widgets, target_node.widgets_values
                )[widget_name_to_patch]
                while len(target_node.widgets_values) <= target_index:
                    target_node.widgets_values.append(None)

                try:
                    if target_node.type == "LoadImage":
                        value = await self._upload_image_to_comfy(server, value)
                    target_node.widgets_values[target_index] = self._coerce_widget_value(
                        value, field_spec.type, widget_types[widget_name_to_patch]
                    )
                except (ValueError, TypeError) as e:
                    logger.warning(
                        f"无法将节点 '{node_name}' 的字段 '{field_name}' 的值 '{value}' 转换为类型 '{field_spec.type}'。错误: {e}"
                    )

        # Shallow copy with the patched nodes (sorted by id). model_dump() below
        # builds fresh containers, so the shared fields are never mutated.
        patched_workflow_data = self.workflow_data.model_copy(
            update={"nodes": sorted(nodes_map.values(), key=lambda n: n.id)}
        )
        # Return a plain dict for backward compatibility.
        return patched_workflow_data.model_dump()

    def _coerce_widget_value(self, value: Any, field_type: str, input_type: Optional[str]) -> Any:
        """
        Convert a request value to the Python type stored in ``widgets_values``.

        Args:
            value: Raw request value.
            field_type: Simplified API type ("int", "float", "UploadFile", "string").
            input_type: Raw ComfyUI input type of the widget (e.g. "BOOLEAN").

        Raises:
            ValueError, TypeError: if the value cannot be converted.
        """
        # BOOLEAN maps to the "string" API type, but str(False) == "False" is
        # truthy, so booleans are converted explicitly.
        if (input_type or "").upper() == "BOOLEAN":
            return self._to_bool(value)
        if field_type == "int":
            return int(value)
        if field_type == "float":
            return float(value)
        return str(value)

    @staticmethod
    def _to_bool(value: Any) -> bool:
        """
        Parse ``value`` as a boolean (bools, numbers, and common string spellings).

        Raises:
            ValueError: if ``value`` is not a recognizable boolean.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return bool(value)
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in ("true", "1", "yes", "on"):
                return True
            if lowered in ("false", "0", "no", "off"):
                return False
        raise ValueError(f"无法解析为布尔值: {value!r}")

    def _convert_workflow_to_prompt_api_format(self, workflow_data: dict) -> dict:
        """
        Convert a workflow dict (e.g. from ``_patch_workflow``) to /prompt format.

        Widget values come from ``widgets_values``, so values patched by
        ``_patch_workflow`` are picked up. Frontend-only nodes (notes) are omitted.

        Raises:
            ValueError: if ``workflow_data`` has no ``nodes`` key.
        """
        if "nodes" not in workflow_data:
            raise ValueError("无效的工作流格式")

        # Each link starts with [link_id, origin_node_id, origin_slot_index].
        # Key by link id so a target input can resolve its source.
        link_map = {
            link_data[0]: [str(link_data[1]), link_data[2]]
            for link_data in workflow_data.get("links") or []
        }

        prompt_api_format = {}
        for node in workflow_data["nodes"]:
            if node["type"] in self._FRONTEND_ONLY_NODE_TYPES:
                continue
            prompt_api_format[str(node["id"])] = {
                "inputs": self._get_node_inputs(node, link_map),
                "class_type": node["type"],
                # model_dump() keeps "title": None, so `or` (not a .get default)
                # is what turns a missing title into "".
                "_meta": {"title": node.get("title") or ""},
            }
        return prompt_api_format

    async def _upload_image_to_comfy(
        self, server: ComfyUIServerInfo, file_path: str
    ) -> str:
        """
        Download the image at URL ``file_path`` and upload it to ``server``.

        Returns:
            str: The ``name`` the server reports for the uploaded file.

        Raises:
            aiohttp.ClientError: on download or upload failure.
        """
        # Use only the last path segment: query strings and fragments (e.g.
        # signed-URL tokens) must not end up in the server-side file name.
        file_name = unquote(urlsplit(file_path).path).rsplit("/", 1)[-1] or "image"
        file_data = await self._download_file(file_path)

        form_data = aiohttp.FormData()
        form_data.add_field(
            "image", file_data, filename=file_name, content_type="image/*"
        )
        form_data.add_field("type", "input")

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{server.http_url}/api/upload/image", data=form_data
            ) as response:
                response.raise_for_status()
                result = await response.json()
                return result["name"]

    async def _download_file(self, src: str) -> bytes:
        """
        Download ``src`` and return its body as bytes.

        Raises:
            aiohttp.ClientError: on request failure or non-2xx status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(src) as response:
                response.raise_for_status()
                return await response.read()

    def _generate_field_schema(
        self, input_field: dict, default_value: Any = None
    ) -> dict:
        """
        Build a JSON Schema field definition from a raw input dict (legacy dict form).

        Args:
            input_field: Input configuration (needs ``widget.name``; optional
                ``type``, ``description``, ``min``, ``max``, ``options``, ``default``).
            default_value: Default to advertise. Takes precedence over
                ``input_field["default"]``.
        """
        field_schema = {
            "title": input_field["widget"]["name"],
            "description": input_field.get("description", ""),
        }

        field_type = self._get_param_type(input_field.get("type") or "STRING")

        if field_type in ("int", "float"):
            field_schema["type"] = "integer" if field_type == "int" else "number"
            # Emit only bounds that are known: "minimum": null is not valid JSON Schema.
            for schema_key, source_key in (("minimum", "min"), ("maximum", "max")):
                if input_field.get(source_key) is not None:
                    field_schema[schema_key] = input_field[source_key]
        elif field_type == "UploadFile":
            field_schema.update(
                {
                    "type": "string",
                    "format": "binary",
                    "contentMediaType": "image/*",
                    "pattern": ".*\\.(jpg|jpeg|png)$",
                    "description": "[ext:jpg,jpeg,png][file-size:1MB] 上传图片文件",
                }
            )
        else:
            field_schema["type"] = "string"

        if "options" in input_field:
            options = input_field["options"]
            field_schema["enum"] = options
            field_schema["description"] += f" 可选值: {', '.join(str(o) for o in options)}"

        if default_value is not None:
            field_schema["default"] = default_value
        elif "default" in input_field:
            field_schema["default"] = input_field["default"]

        return field_schema

    def _generate_field_schema_model(
        self, input_field: ComfyWorkflowNodeInput, default_value: Any = None
    ) -> ComfyJSONSchemaProperty:
        """
        Build a JSON Schema field definition from a node input (model form).

        Args:
            input_field: Node input to describe.
            default_value: Default to advertise (typically the current widget value).
        """
        title = input_field.widget.name if input_field.widget else input_field.name
        description = input_field.localized_name or ""

        field_type = self._get_param_type(input_field.type or "STRING")

        if field_type == "int":
            # TODO: read min/max from the input configuration.
            return ComfyJSONSchemaProperty(
                type="integer", title=title, description=description, default=default_value
            )
        if field_type == "float":
            # TODO: read min/max from the input configuration.
            return ComfyJSONSchemaProperty(
                type="number", title=title, description=description, default=default_value
            )
        if field_type == "UploadFile":
            return ComfyJSONSchemaProperty(
                type="string",
                title=title,
                format="binary",
                contentMediaType="image/*",
                pattern=".*\\.(jpg|jpeg|png)$",
                description="[ext:jpg,jpeg,png][file-size:1MB] 上传图片文件",
                default=default_value,
            )
        # TODO: read enum values from the input configuration.
        return ComfyJSONSchemaProperty(
            type="string", title=title, description=description, default=default_value
        )

    def _get_param_type(self, input_type_str: str) -> str:
        """
        Map a ComfyUI input type string to a simplified API parameter type.

        NOTE(review): every COMBO input maps to "UploadFile", although only
        LoadImage values are actually uploaded in ``_patch_workflow``.
        """
        input_type_str = input_type_str.upper()
        if "COMBO" in input_type_str:
            return "UploadFile"
        if "INT" in input_type_str:
            return "int"
        if "FLOAT" in input_type_str:
            return "float"
        return "string"

    def _get_node_inputs(self, node: Union[ComfyWorkflowNode, dict], link_map: dict) -> dict:
        """
        Collect a node's API-format inputs: widget values plus resolved links.

        Args:
            node: Node as a ComfyWorkflowNode model or a plain dict.
            link_map: link_id -> ``[origin_node_id, origin_slot_index]``.

        Returns:
            dict: input name -> widget value, or ``[node_id, slot]`` for linked
            inputs. A link overrides a widget value with the same name.
        """
        if isinstance(node, dict):
            try:
                node_model = ComfyWorkflowNode.model_validate(node)
            except (ValueError, TypeError):
                # pydantic's ValidationError subclasses ValueError. Fall back to
                # reading the raw dict.
                return self._get_node_inputs_dict(node, link_map)
        else:
            node_model = node

        node_inputs = node_model.inputs or []
        return self._collect_inputs(
            node_model.type,
            self._widget_specs(node_inputs),
            node_model.widgets_values or [],
            [(field_input.name, field_input.link) for field_input in node_inputs],
            link_map,
        )

    def _get_node_inputs_dict(self, node: dict, link_map: dict) -> dict:
        """
        Dict-based variant of ``_get_node_inputs`` for nodes that fail validation.

        Args:
            node: Raw node dict.
            link_map: link_id -> ``[origin_node_id, origin_slot_index]``.

        Returns:
            dict: Same shape as ``_get_node_inputs``.
        """
        node_inputs = node.get("inputs") or []
        # model_dump() writes "widget": None for non-widget inputs, so test the
        # type instead of mere key presence.
        widgets = [
            (field_input["widget"].get("name"), field_input.get("type"))
            for field_input in node_inputs
            if isinstance(field_input.get("widget"), dict)
        ]
        links = [
            (field_input["name"], field_input["link"])
            for field_input in node_inputs
            if field_input.get("link") is not None
        ]
        return self._collect_inputs(
            node.get("type"), widgets, node.get("widgets_values") or [], links, link_map
        )

    @staticmethod
    def _widget_specs(inputs: Optional[List[ComfyWorkflowNodeInput]]) -> list:
        """Return ``(widget_name, input_type)`` pairs for widget-backed inputs, in order."""
        return [
            (field_input.widget.name, field_input.type)
            for field_input in inputs or []
            if field_input.widget is not None
        ]

    @classmethod
    def _widget_value_indices(
        cls, node_type: Optional[str], widgets: list, widgets_values: list
    ) -> Dict[Any, int]:
        """
        Map each widget name to its position in ``widgets_values``.

        Widgets consume ``widgets_values`` slots in input order, with two
        exceptions:
        - ``EG_CYQ_JB`` nodes reserve the second slot for a control-net switch
          that is not a widget input.
        - An INT widget followed by a "control after generate" value skips that
          slot, but only when more values remain than widgets still to be mapped.

        Positions are computed even past the end of ``widgets_values``, so writers
        can pad the list. Readers must bounds-check.

        Args:
            node_type: Node class type.
            widgets: ``(widget_name, input_type)`` pairs in input order.
            widgets_values: The node's widget values.
        """
        indices: Dict[Any, int] = {}
        cursor = 0
        total = len(widgets_values)
        for position, (name, input_type) in enumerate(widgets):
            indices[name] = cursor
            cursor += 1

            if node_type == "EG_CYQ_JB" and cursor == 1:
                cursor += 1

            widgets_left = len(widgets) - position - 1
            if (
                (input_type or "").upper() == "INT"
                and cursor < total
                and isinstance(widgets_values[cursor], str)
                and widgets_values[cursor] in cls._CONTROL_AFTER_GENERATE_VALUES
                and total - cursor > widgets_left
            ):
                cursor += 1
        return indices

    def _collect_inputs(
        self,
        node_type: Optional[str],
        widgets: list,
        widgets_values: list,
        links: list,
        link_map: dict,
    ) -> dict:
        """
        Assemble the API-format inputs dict shared by both ``_get_node_inputs`` variants.

        Args:
            node_type: Node class type.
            widgets: ``(widget_name, input_type)`` pairs in input order.
            widgets_values: The node's widget values (may already include patches).
            links: ``(input_name, link_id)`` pairs.
            link_map: link_id -> ``[origin_node_id, origin_slot_index]``.
        """
        inputs_dict: dict = {}
        indices = self._widget_value_indices(node_type, widgets, widgets_values)
        for name, index in indices.items():
            if name and index < len(widgets_values):
                inputs_dict[name] = widgets_values[index]

        # Linked inputs reference their source as [node_id, slot].
        for name, link_id in links:
            if link_id is not None and link_id in link_map:
                inputs_dict[name] = link_map[link_id]
        return inputs_dict
|