import logging
import mimetypes
import posixpath
from typing import Any, Dict, Iterator, List, Optional, Union
from urllib.parse import urlparse

import aiohttp
from pydantic import BaseModel, ConfigDict, Field, ValidationError

from workflow_service.comfy.comfy_server import ComfyUIServerInfo

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Node titles with these prefixes are exposed through the API:
# "INPUT_<name>" nodes accept request values, "OUTPUT_<name>" nodes have their
# outputs listed in the API spec.
API_INPUT_PREFIX = "INPUT_"
API_OUTPUT_PREFIX = "OUTPUT_"

# Frontend-only annotation node types; they are dropped from the /prompt payload.
FRONTEND_ONLY_NODE_TYPES = frozenset({"Note", "MarkdownNote"})


class ComfyWorkflowNodeWidget(BaseModel):
    """Widget bound to a node input."""

    name: str


class ComfyWorkflowNodeInput(BaseModel):
    """Input slot of a workflow node."""

    label: Optional[str] = None
    localized_name: Optional[str] = None
    name: str
    shape: Optional[int] = None
    type: str
    # ID of the incoming link, or None when the input is not connected.
    link: Optional[int] = None
    # Set when the input's value is stored in the node's ``widgets_values``.
    widget: Optional[ComfyWorkflowNodeWidget] = None


class ComfyWorkflowNodeOutput(BaseModel):
    """Output slot of a workflow node."""

    label: Optional[str] = None
    localized_name: Optional[str] = None
    name: str
    type: str
    links: Optional[List[int]] = None


class ComfyWorkflowNodeProperties(BaseModel):
    """Node properties."""

    cnr_id: Optional[str] = None
    aux_id: Optional[str] = None
    ver: Optional[str] = None
    Node_name_for_SR: Optional[str] = None


class ComfyAPIFieldSpec(BaseModel):
    """API spec of one exposed input field."""

    # One of "int", "float", "UploadFile" or "string" (see _get_param_type).
    type: str
    widget_name: str


class ComfyAPIOutputSpec(BaseModel):
    """API spec of one exposed output."""

    output_name: str


class ComfyAPINodeSpec(BaseModel):
    """API spec of one exposed node."""

    node_id: str
    class_type: str
    inputs: Optional[Dict[str, ComfyAPIFieldSpec]] = {}
    outputs: Optional[Dict[str, ComfyAPIOutputSpec]] = {}


class ComfyAPISpec(BaseModel):
    """API spec of a workflow: exposed input/output nodes keyed by node name."""

    inputs: Dict[str, ComfyAPINodeSpec] = {}
    outputs: Dict[str, ComfyAPINodeSpec] = {}


class ComfyJSONSchemaProperty(BaseModel):
    """JSON Schema property definition."""

    type: str
    title: Optional[str] = None
    description: Optional[str] = None
    default: Optional[Any] = None
    minimum: Optional[Union[int, float]] = None
    maximum: Optional[Union[int, float]] = None
    enum: Optional[List[str]] = None
    format: Optional[str] = None
    contentMediaType: Optional[str] = None
    pattern: Optional[str] = None


class ComfyJSONSchemaNode(BaseModel):
    """JSON Schema object describing the fields of one input node."""

    type: str = "object"
    title: str
    description: str
    properties: Dict[str, ComfyJSONSchemaProperty] = {}
    required: List[str] = []
    additionalProperties: bool = False


class ComfyJSONSchema(BaseModel):
    """Top-level JSON Schema describing all workflow inputs."""

    # Pydantic v2 ignores the v1-style ``Config.fields`` mapping, so the
    # "$schema" key is declared via an explicit alias. ``populate_by_name``
    # keeps ``ComfyJSONSchema(schema=...)`` working; dump with
    # ``model_dump(by_alias=True)`` to emit the "$schema" key.
    model_config = ConfigDict(populate_by_name=True)

    schema: str = Field(
        default="http://json-schema.org/draft-07/schema#", alias="$schema"
    )
    type: str = "object"
    title: str = "ComfyUI Workflow Input Schema"
    description: str = "工作流输入参数定义"
    properties: Dict[str, ComfyJSONSchemaNode] = {}
    required: List[str] = []


class ComfyWorkflowNode(BaseModel):
    """Workflow node."""

    id: int
    type: str
    pos: List[float]
    size: List[float]
    flags: Dict[str, Any] = {}
    order: int
    mode: int
    inputs: Optional[List[ComfyWorkflowNodeInput]] = []
    outputs: Optional[List[ComfyWorkflowNodeOutput]] = []
    properties: ComfyWorkflowNodeProperties = ComfyWorkflowNodeProperties()
    # Positional list of widget values, or a mapping keyed by widget name.
    widgets_values: Optional[Union[List[Any], Dict[str, Any]]] = []
    title: Optional[str] = None


class ComfyWorkflowExtra(BaseModel):
    """Extra workflow settings."""

    ds: Optional[Dict[str, Any]] = None
    frontendVersion: Optional[str] = None
    VHS_latentpreview: Optional[bool] = None
    VHS_latentpreviewrate: Optional[int] = None
    VHS_MetadataImage: Optional[bool] = None
    VHS_KeepIntermediate: Optional[bool] = None


class ComfyWorkflowDataSpec(BaseModel):
    """ComfyUI (UI-format) workflow data."""

    id: str
    revision: int = 0
    last_node_id: int
    last_link_id: int
    nodes: List[ComfyWorkflowNode]
    # Each link is [link_id, origin_node_id, origin_slot, ...].
    links: List[List[Union[int, str]]] = []
    groups: List[Any] = []
    config: Dict[str, Any] = {}
    extra: ComfyWorkflowExtra = ComfyWorkflowExtra()
    version: float


class ComfyWorkflow:
    """
    ComfyUI workflow processor.

    Parses a UI-format workflow, derives the API spec and input JSON Schema of
    the nodes whose titles start with ``API_INPUT_PREFIX``/``API_OUTPUT_PREFIX``,
    and builds /prompt payloads with request values patched in.
    """

    def __init__(self, workflow_data: dict, workflow_name: Optional[str] = None):
        """
        Initialize the workflow.

        Args:
            workflow_data: UI-format workflow data.
            workflow_name: optional workflow name.

        Raises:
            pydantic.ValidationError: if ``workflow_data`` does not match
                ``ComfyWorkflowDataSpec``.
        """
        self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data)
        self.workflow_name = workflow_name
        self._nodes_map = {str(node.id): node for node in self.workflow_data.nodes}
        self._api_spec = self._parse_api_spec()
        self._inputs_json_schema = self._parse_inputs_json_schema()

    async def build_prompt(
        self, server: ComfyUIServerInfo, request_data: dict[str, Any]
    ) -> dict:
        """
        Build the /prompt payload for ``request_data``.

        Args:
            server: ComfyUI server; images for ``LoadImage`` inputs are uploaded to it.
            request_data: ``{"node name": {"field name": value}}``.

        Returns:
            dict: the prompt in ComfyUI API format.
        """
        patched_workflow = await self._patch_workflow(server, request_data)
        return self._convert_workflow_to_prompt_api_format(patched_workflow)

    def get_api_spec(self) -> ComfyAPISpec:
        """
        Return the API spec.

        Returns:
            ComfyAPISpec: API spec parsed at construction time.
        """
        return self._api_spec

    def get_inputs_json_schema(self) -> ComfyJSONSchema:
        """
        Return the JSON Schema of the input parameters.

        Returns:
            ComfyJSONSchema: JSON Schema parsed at construction time.
        """
        return self._inputs_json_schema

    @staticmethod
    def _exposed_widget_inputs(
        node: ComfyWorkflowNode,
    ) -> Iterator[ComfyWorkflowNodeInput]:
        """Yield the widget inputs of *node* that the API exposes.

        An input is exposed when it has a widget and is not driven by a link.
        The ``upload`` widget of ``LoadImage`` is hidden.
        """
        for a_input in node.inputs or []:
            if a_input.link is not None or a_input.widget is None:
                continue
            if node.type == "LoadImage" and a_input.widget.name == "upload":
                continue
            yield a_input

    @staticmethod
    def _widget_value_index(node_type: Optional[str], widget_position: int) -> int:
        """Map a widget's position among a node's widget inputs to its index in
        ``widgets_values``.

        ``EG_CYQ_JB`` stores an extra value at index 1 (it controls the
        ControlNet and is not bound to any widget input), so every widget after
        the first is shifted by one slot. Reading and patching must both use
        this mapping so they address the same slot.
        """
        if node_type == "EG_CYQ_JB" and widget_position >= 1:
            return widget_position + 1
        return widget_position

    @classmethod
    def _collect_widget_values(
        cls,
        node_type: Optional[str],
        widget_names: List[Optional[str]],
        widgets_values: Any,
    ) -> Dict[str, Any]:
        """
        Pair widget input names with their stored values.

        Args:
            node_type: node class type (selects the slot layout).
            widget_names: names of the node's widget inputs, in input order.
            widgets_values: the node's ``widgets_values``: a positional list,
                a mapping keyed by widget name, or None.

        Returns:
            dict: widget name -> stored value, for widgets that have one.
        """
        if isinstance(widgets_values, dict):
            # Mapping form: values are looked up by widget name.
            return {
                name: widgets_values[name]
                for name in widget_names
                if name and name in widgets_values
            }
        values = widgets_values or []
        collected: Dict[str, Any] = {}
        for position, name in enumerate(widget_names):
            index = cls._widget_value_index(node_type, position)
            if index >= len(values):
                # Indices only grow, so no later widget has a value either.
                break
            if name:
                collected[name] = values[index]
        return collected

    @staticmethod
    def _get_or_create_node_spec(
        specs: Dict[str, ComfyAPINodeSpec],
        node_name: str,
        node_id: str,
        node: ComfyWorkflowNode,
    ) -> ComfyAPINodeSpec:
        """Return ``specs[node_name]``, creating it for *node* if missing."""
        spec = specs.get(node_name)
        if spec is None:
            spec = ComfyAPINodeSpec(
                node_id=node_id, class_type=node.type, inputs={}, outputs={}
            )
            specs[node_name] = spec
        elif spec.node_id != node_id:
            # The first node keeps the name; fields of later nodes are merged
            # into its spec and will be patched onto the first node.
            logger.warning(
                f"Duplicate API node name '{node_name}' on nodes {spec.node_id} "
                f"and {node_id}; fields are merged into node {spec.node_id}."
            )
        return spec

    def _parse_api_spec(self) -> ComfyAPISpec:
        """
        Parse the workflow into an API spec.

        Nodes titled ``INPUT_<name>`` contribute their exposed widget inputs;
        nodes titled ``OUTPUT_<name>`` contribute their outputs. Both are keyed
        by ``<name>``.

        Returns:
            ComfyAPISpec: the parsed spec.
        """
        inputs: Dict[str, ComfyAPINodeSpec] = {}
        outputs: Dict[str, ComfyAPINodeSpec] = {}

        for node_id, node in self._nodes_map.items():
            title = node.title
            if not title:
                continue

            if title.startswith(API_INPUT_PREFIX):
                node_name = title[len(API_INPUT_PREFIX):]
                node_spec = self._get_or_create_node_spec(
                    inputs, node_name, node_id, node
                )
                for a_input in self._exposed_widget_inputs(node):
                    # The widget name is used directly as the API field name.
                    widget_name = a_input.widget.name
                    node_spec.inputs[widget_name] = ComfyAPIFieldSpec(
                        type=self._get_param_type(a_input.type or "STRING"),
                        widget_name=widget_name,
                    )
            elif title.startswith(API_OUTPUT_PREFIX):
                node_name = title[len(API_OUTPUT_PREFIX):]
                node_spec = self._get_or_create_node_spec(
                    outputs, node_name, node_id, node
                )
                for an_output in node.outputs or []:
                    node_spec.outputs[an_output.name] = ComfyAPIOutputSpec(
                        output_name=an_output.name
                    )

        return ComfyAPISpec(inputs=inputs, outputs=outputs)

    def _parse_inputs_json_schema(self) -> ComfyJSONSchema:
        """
        Build a JSON Schema describing the workflow's input parameters.

        Each ``INPUT_<name>`` node with at least one exposed widget becomes an
        object property ``<name>``. Its current widget values are used as defaults.

        Returns:
            ComfyJSONSchema: schema containing only input parameters.
        """
        properties: Dict[str, ComfyJSONSchemaNode] = {}
        required: List[str] = []

        for node in self._nodes_map.values():
            title = node.title
            if not title or not title.startswith(API_INPUT_PREFIX):
                continue

            node_name = title[len(API_INPUT_PREFIX):]
            # Current values serve as defaults. An empty link map is enough
            # because exposed widgets are by definition not linked. Computed
            # once per node rather than once per input.
            current_values = self._get_node_inputs(node, {})

            node_properties: Dict[str, ComfyJSONSchemaProperty] = {}
            node_required: List[str] = []
            for a_input in self._exposed_widget_inputs(node):
                widget_name = a_input.widget.name
                node_properties[widget_name] = self._generate_field_schema_model(
                    a_input, current_values.get(widget_name)
                )
                # TODO: read "required" from the input configuration; for now
                # every exposed field is required.
                node_required.append(widget_name)

            if node_properties:
                properties[node_name] = ComfyJSONSchemaNode(
                    title=f"输入节点: {node_name}",
                    description=f"节点类型: {node.type}",
                    properties=node_properties,
                    required=node_required,
                )
                # JSON Schema requires "required" entries to be unique.
                if node_required and node_name not in required:
                    required.append(node_name)

        return ComfyJSONSchema(properties=properties, required=required)

    def _store_widget_value(
        self,
        node: ComfyWorkflowNode,
        widget_position: int,
        widget_count: int,
        widget_name: str,
        value: Any,
    ) -> None:
        """Write *value* into *node*'s widget storage, in place.

        Mapping-form ``widgets_values`` are updated by name. A missing or empty
        list is replaced with ``widget_count`` placeholders and padded with
        None as needed before the positional write.
        """
        if isinstance(node.widgets_values, dict):
            node.widgets_values[widget_name] = value
            return
        if not node.widgets_values:
            node.widgets_values = [None] * widget_count
        value_index = self._widget_value_index(node.type, widget_position)
        missing = value_index + 1 - len(node.widgets_values)
        if missing > 0:
            node.widgets_values.extend([None] * missing)
        node.widgets_values[value_index] = value

    async def _patch_workflow(
        self,
        server: ComfyUIServerInfo,
        request_data: dict[str, Any],
    ) -> dict:
        """
        Patch the values from ``request_data`` into a copy of the workflow.

        ``request_data`` has the shape ``{"node name": {"field name": value}}``.
        Unknown nodes and fields are ignored. Values that cannot be converted to
        the field's type are skipped with a warning. For ``LoadImage`` nodes the
        value is passed to ``_upload_image_to_comfy`` and the returned server
        file name is stored instead. ``self.workflow_data`` is not modified.

        Returns:
            dict: the patched workflow, nodes sorted by id, as a plain dict.
        """
        # Deep copies so a build never mutates the parsed workflow.
        nodes_map = {k: v.model_copy(deep=True) for k, v in self._nodes_map.items()}

        for node_name, node_fields in request_data.items():
            node_spec = self._api_spec.inputs.get(node_name)
            if node_spec is None:
                continue
            node_id = node_spec.node_id
            target_node = nodes_map.get(node_id)
            if target_node is None:
                continue
            if not isinstance(node_fields, dict):
                logger.warning(
                    f"Expected a mapping of fields for node '{node_name}', "
                    f"got {type(node_fields).__name__}; skipping."
                )
                continue

            # Widget inputs in order; their position determines the slot in
            # ``widgets_values``.
            widget_inputs = [
                inp for inp in (target_node.inputs or []) if inp.widget is not None
            ]

            for field_name, value in node_fields.items():
                field_spec = (node_spec.inputs or {}).get(field_name)
                if field_spec is None:
                    continue
                widget_name_to_patch = field_spec.widget_name

                widget_position = next(
                    (
                        position
                        for position, inp in enumerate(widget_inputs)
                        if inp.widget.name == widget_name_to_patch
                    ),
                    -1,
                )
                if widget_position == -1:
                    logger.warning(
                        f"在节点 {node_id} 中未找到名为 '{widget_name_to_patch}' 的 widget。跳过此参数。"
                    )
                    continue

                # Convert according to the API spec type; anything else is a string.
                target_type = {"int": int, "float": float}.get(field_spec.type, str)
                try:
                    if target_node.type == "LoadImage":
                        if not isinstance(value, str):
                            raise TypeError("LoadImage expects an image URL string")
                        value = await self._upload_image_to_comfy(server, value)
                    converted = target_type(value)
                except (ValueError, TypeError) as e:
                    logger.warning(
                        f"无法将节点 '{node_name}' 的字段 '{field_name}' 的值 '{value}' 转换为类型 '{field_spec.type}'。错误: {e}"
                    )
                    continue

                self._store_widget_value(
                    target_node,
                    widget_position,
                    len(widget_inputs),
                    widget_name_to_patch,
                    converted,
                )

        # Shallow copy is enough: only the (already deep-copied) nodes differ,
        # and model_dump builds the returned structure.
        patched_workflow_data = self.workflow_data.model_copy(
            update={"nodes": sorted(nodes_map.values(), key=lambda n: n.id)}
        )
        # Return a dict for backward compatibility.
        return patched_workflow_data.model_dump()

    def _convert_workflow_to_prompt_api_format(self, workflow_data: dict) -> dict:
        """
        Convert a (patched) UI-format workflow dict into the /prompt payload.

        Returns ``{node_id: {"inputs": ..., "class_type": ..., "_meta": {"title": ...}}}``.
        Linked inputs are ``[origin_node_id, origin_slot_index]``. Node types in
        ``FRONTEND_ONLY_NODE_TYPES`` are omitted. Nodes are emitted regardless
        of their ``mode`` (NOTE(review): confirm whether muted/bypassed nodes
        should be excluded).

        Raises:
            ValueError: if ``workflow_data`` has no "nodes" key.
        """
        if "nodes" not in workflow_data:
            raise ValueError("无效的工作流格式")

        # link_id -> [origin node id (str), origin output slot].
        link_map: Dict[Any, List[Any]] = {}
        for link_data in workflow_data.get("links") or []:
            if len(link_data) < 3:
                logger.warning(f"Skipping malformed link entry: {link_data!r}")
                continue
            link_id, origin_node_id, origin_slot_index = link_data[:3]
            link_map[link_id] = [str(origin_node_id), origin_slot_index]

        prompt_api_format: Dict[str, dict] = {}
        for node in workflow_data["nodes"]:
            if node["type"] in FRONTEND_ONLY_NODE_TYPES:
                continue
            prompt_api_format[str(node["id"])] = {
                "inputs": self._get_node_inputs(node, link_map),
                "class_type": node["type"],
                # model_dump yields None for an unset title; normalize to "".
                "_meta": {"title": node.get("title") or ""},
            }
        return prompt_api_format

    async def _upload_image_to_comfy(
        self, server: ComfyUIServerInfo, file_path: str
    ) -> str:
        """
        Download an image from a URL and upload it to the ComfyUI server.

        Args:
            server: target server; the file is posted to
                ``{server.http_url}/api/upload/image`` with ``type=input``.
            file_path: HTTP URL of the image.

        Returns:
            str: the ``name`` field of the server's JSON response.

        Raises:
            aiohttp.ClientResponseError: if the download or the upload returns
                an error status.
        """
        # Use only the URL path so query strings/fragments don't end up in the name.
        file_name = posixpath.basename(urlparse(file_path).path) or "image"
        content_type = mimetypes.guess_type(file_name)[0] or "image/*"
        file_data = await self._download_file(file_path)

        form_data = aiohttp.FormData()
        form_data.add_field(
            "image", file_data, filename=file_name, content_type=content_type
        )
        form_data.add_field("type", "input")

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{server.http_url}/api/upload/image", data=form_data
            ) as response:
                response.raise_for_status()
                result = await response.json()
                return result["name"]

    async def _download_file(self, src: str) -> bytes:
        """
        Download *src* and return the response body as bytes (held in memory).

        Raises:
            aiohttp.ClientResponseError: on an error status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(src) as response:
                response.raise_for_status()
                return await response.read()

    def _generate_field_schema(
        self, input_field: dict, default_value: Any = None
    ) -> dict:
        """
        Build a JSON Schema field definition from an input config dict
        (dict variant, kept for backward compatibility).

        Args:
            input_field: input field config (must contain ``widget.name``).
            default_value: default value; takes precedence over
                ``input_field["default"]``.

        Returns:
            dict: the field schema.
        """
        field_schema = {
            "title": input_field["widget"]["name"],
            # ``or ""`` also covers an explicit None, which would break the
            # options concatenation below.
            "description": input_field.get("description") or "",
        }

        field_type = self._get_param_type(input_field.get("type") or "STRING")
        if field_type == "int":
            field_schema.update(
                {
                    "type": "integer",
                    "minimum": input_field.get("min", None),
                    "maximum": input_field.get("max", None),
                }
            )
        elif field_type == "float":
            field_schema.update(
                {
                    "type": "number",
                    "minimum": input_field.get("min", None),
                    "maximum": input_field.get("max", None),
                }
            )
        elif field_type == "UploadFile":
            field_schema.update(
                {
                    "type": "string",
                    "format": "binary",
                    "contentMediaType": "image/*",
                    "pattern": ".*\\.(jpg|jpeg|png)$",
                    "description": "[ext:jpg,jpeg,png][file-size:1MB] 上传图片文件",
                }
            )
        else:  # string
            field_schema.update({"type": "string"})

        # Enumerated values.
        if "options" in input_field:
            options = input_field["options"]
            field_schema["enum"] = options
            # str() so non-string options don't make join() raise.
            field_schema["description"] += (
                f" 可选值: {', '.join(str(option) for option in options)}"
            )

        if default_value is not None:
            field_schema["default"] = default_value
        elif "default" in input_field:
            field_schema["default"] = input_field["default"]

        return field_schema

    def _generate_field_schema_model(
        self, input_field: ComfyWorkflowNodeInput, default_value: Any = None
    ) -> ComfyJSONSchemaProperty:
        """
        Build a JSON Schema field definition from an input model (model variant).

        Args:
            input_field: the node input.
            default_value: default value for the field.

        Returns:
            ComfyJSONSchemaProperty: the field schema.
        """
        title = input_field.widget.name if input_field.widget else input_field.name
        description = input_field.localized_name or ""

        field_type = self._get_param_type(input_field.type or "STRING")
        if field_type == "int":
            return ComfyJSONSchemaProperty(
                type="integer",
                title=title,
                description=description,
                default=default_value,
                # TODO: read min/max from the input configuration.
            )
        if field_type == "float":
            return ComfyJSONSchemaProperty(
                type="number",
                title=title,
                description=description,
                default=default_value,
                # TODO: read min/max from the input configuration.
            )
        if field_type == "UploadFile":
            return ComfyJSONSchemaProperty(
                type="string",
                title=title,
                format="binary",
                contentMediaType="image/*",
                pattern=".*\\.(jpg|jpeg|png)$",
                description="[ext:jpg,jpeg,png][file-size:1MB] 上传图片文件",
                default=default_value,
            )
        # string
        return ComfyJSONSchemaProperty(
            type="string",
            title=title,
            description=description,
            default=default_value,
            # TODO: read enum values from the input configuration.
        )

    def _get_param_type(self, input_type_str: str) -> str:
        """
        Map a ComfyUI input type string to an API parameter type.

        Returns one of "UploadFile", "int", "float" or "string" (substring
        match, case-insensitive, checked in that order).
        NOTE(review): every COMBO input maps to "UploadFile", not only image
        pickers — confirm this is intended.
        """
        input_type_str = input_type_str.upper()
        if "COMBO" in input_type_str:
            return "UploadFile"
        elif "INT" in input_type_str:
            return "int"
        elif "FLOAT" in input_type_str:
            return "float"
        else:
            return "string"

    def _get_node_inputs(
        self, node: Union[ComfyWorkflowNode, dict], link_map: dict
    ) -> dict:
        """
        Collect a node's input values.

        Args:
            node: node model, or a node dict (falls back to the dict variant
                when it does not validate as ``ComfyWorkflowNode``).
            link_map: link_id -> [origin node id, origin slot].

        Returns:
            dict: input name -> value. Widget values come from ``widgets_values``
            (which already contains API-patched values); linked inputs are then
            set (overriding a same-named widget value) to their link_map entry.
        """
        if isinstance(node, dict):
            try:
                node_model = ComfyWorkflowNode.model_validate(node)
            except ValidationError:
                return self._get_node_inputs_dict(node, link_map)
        else:
            node_model = node

        node_inputs = node_model.inputs or []

        # 1. Widget inputs.
        widget_names = [inp.widget.name for inp in node_inputs if inp.widget is not None]
        inputs_dict = self._collect_widget_values(
            node_model.type, widget_names, node_model.widgets_values
        )

        # 2. Linked inputs, keyed by the input's own name.
        for inp in node_inputs:
            if inp.link is not None and inp.link in link_map:
                inputs_dict[inp.name] = link_map[inp.link]

        return inputs_dict

    def _get_node_inputs_dict(self, node: dict, link_map: dict) -> dict:
        """
        Collect a node's input values from a raw node dict (backward-compatible
        variant of ``_get_node_inputs``).

        Args:
            node: node data dict.
            link_map: link_id -> [origin node id, origin slot].

        Returns:
            dict: input name -> value.
        """
        node_inputs = node.get("inputs") or []

        # 1. Widget inputs.
        widget_names = [
            inp["widget"].get("name")
            for inp in node_inputs
            if inp.get("widget") is not None
        ]
        inputs_dict = self._collect_widget_values(
            node.get("type"), widget_names, node.get("widgets_values")
        )

        # 2. Linked inputs, keyed by the input's own name.
        for inp in node_inputs:
            link_id = inp.get("link")
            if link_id is not None and link_id in link_map:
                inputs_dict[inp["name"]] = link_map[link_id]

        return inputs_dict