diff --git a/workflow_service/comfy/comfy_queue.py b/workflow_service/comfy/comfy_queue.py index 17061bb..f5afa44 100644 --- a/workflow_service/comfy/comfy_queue.py +++ b/workflow_service/comfy/comfy_queue.py @@ -11,7 +11,7 @@ from typing import Dict, Any, Optional import aiohttp from aiohttp import ClientTimeout -from workflow_service.comfy.comfy_workflow import ComfyWorkflow +from workflow_service.comfy.comfy_workflow import ComfyAPISpec, ComfyWorkflow from workflow_service.config import Settings from workflow_service.comfy.comfy_server import server_manager, ComfyUIServerInfo from workflow_service.database.api import ( @@ -165,7 +165,7 @@ class WorkflowQueueManager: self, workflow_name: str, workflow_data: dict, - api_spec: dict, + api_spec: ComfyAPISpec, request_data: dict, ): """添加新任务到队列""" @@ -176,7 +176,7 @@ class WorkflowQueueManager: workflow_run_id=workflow_run_id, workflow_name=workflow_name, workflow_json=json.dumps(workflow_data), - api_spec=json.dumps(api_spec), + api_spec=json.dumps(api_spec.model_dump()), request_data=json.dumps(request_data), ) diff --git a/workflow_service/comfy/comfy_workflow.py b/workflow_service/comfy/comfy_workflow.py index 408a534..a5c56d9 100644 --- a/workflow_service/comfy/comfy_workflow.py +++ b/workflow_service/comfy/comfy_workflow.py @@ -1,7 +1,8 @@ import logging -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Union import aiohttp +from pydantic import BaseModel from workflow_service.comfy.comfy_server import ComfyUIServerInfo @@ -13,6 +14,147 @@ API_INPUT_PREFIX = "INPUT_" API_OUTPUT_PREFIX = "OUTPUT_" +class ComfyWorkflowNodeWidget(BaseModel): + """节点控件定义""" + name: str + + +class ComfyWorkflowNodeInput(BaseModel): + """节点输入定义""" + + label: Optional[str] = None + localized_name: Optional[str] = None + name: str + shape: Optional[int] = None + type: str + link: Optional[int] = None + widget: Optional[ComfyWorkflowNodeWidget] = None + + +class ComfyWorkflowNodeOutput(BaseModel): + """节点输出定义""" + + label: Optional[str] = None + localized_name: Optional[str] = None + name: str + type: str + links: Optional[List[int]] = None + + +class ComfyWorkflowNodeProperties(BaseModel): + """节点属性定义""" + + cnr_id: Optional[str] = None + aux_id: Optional[str] = None + ver: Optional[str] = None + Node_name_for_SR: Optional[str] = None + + +class ComfyAPIFieldSpec(BaseModel): + """API字段规范""" + type: str + widget_name: str + + +class ComfyAPIOutputSpec(BaseModel): + """API输出规范""" + output_name: str + + +class ComfyAPINodeSpec(BaseModel): + """API节点规范""" + node_id: str + class_type: str + inputs: Optional[Dict[str, ComfyAPIFieldSpec]] = {} + outputs: Optional[Dict[str, ComfyAPIOutputSpec]] = {} + + +class ComfyAPISpec(BaseModel): + """API规范定义""" + inputs: Dict[str, ComfyAPINodeSpec] = {} + outputs: Dict[str, ComfyAPINodeSpec] = {} + + +class ComfyJSONSchemaProperty(BaseModel): + """JSON Schema属性定义""" + type: str + title: Optional[str] = None + description: Optional[str] = None + default: Optional[Any] = None + minimum: Optional[Union[int, float]] = None + maximum: Optional[Union[int, float]] = None + enum: Optional[List[str]] = None + format: Optional[str] = None + contentMediaType: Optional[str] = None + pattern: Optional[str] = None + + +class ComfyJSONSchemaNode(BaseModel): + """JSON Schema节点定义""" + type: str = "object" + title: str + description: str + properties: Dict[str, ComfyJSONSchemaProperty] = {} + required: List[str] = [] + additionalProperties: bool = False + + +class ComfyJSONSchema(BaseModel): + """JSON Schema定义""" + schema: str = "http://json-schema.org/draft-07/schema#" + type: str = "object" + title: str = "ComfyUI Workflow Input Schema" + description: str = "工作流输入参数定义" + properties: Dict[str, ComfyJSONSchemaNode] = {} + required: List[str] = [] + + class Config: + fields = {"schema": "$schema"} + + +class ComfyWorkflowNode(BaseModel): + """工作流节点定义""" + + id: int + type: str + pos: List[float] + size: List[float] + flags: Dict[str, Any] = {} + order: int + mode: int + inputs: Optional[List[ComfyWorkflowNodeInput]] = [] + outputs: Optional[List[ComfyWorkflowNodeOutput]] = [] + properties: ComfyWorkflowNodeProperties = ComfyWorkflowNodeProperties() + widgets_values: Optional[List[Any]] = [] + title: Optional[str] = None + + +class ComfyWorkflowExtra(BaseModel): + """工作流额外配置""" + + ds: Optional[Dict[str, Any]] = None + frontendVersion: Optional[str] = None + VHS_latentpreview: Optional[bool] = None + VHS_latentpreviewrate: Optional[int] = None + VHS_MetadataImage: Optional[bool] = None + VHS_KeepIntermediate: Optional[bool] = None + + +class ComfyWorkflowDataSpec(BaseModel): + """ComfyUI工作流数据规范""" + + id: str + revision: int = 0 + last_node_id: int + last_link_id: int + nodes: List[ComfyWorkflowNode] + links: List[List[Union[int, str]]] = [] + groups: List[Any] = [] + config: Dict[str, Any] = {} + extra: ComfyWorkflowExtra = ComfyWorkflowExtra() + version: float + + class ComfyWorkflow: """ ComfyUI工作流处理器类,提供面向对象的工作流管理和处理能力 @@ -25,8 +167,8 @@ class ComfyWorkflow: Args: workflow_data: 工作流数据 """ - self.workflow_data = workflow_data - self._nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]} + self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data) + self._nodes_map = {str(node.id): node for node in self.workflow_data.nodes} self._api_spec = self._parse_api_spec() self._inputs_json_schema = self._parse_inputs_json_schema() @@ -80,33 +222,34 @@ class ComfyWorkflow: return prompt_api_format - def get_api_spec(self) -> Dict[str, Dict[str, Any]]: + def get_api_spec(self) -> ComfyAPISpec: """ 获取API规范 Returns: - Dict[str, Dict[str, Any]]: API规范 + ComfyAPISpec: API规范 """ return self._api_spec - def get_inputs_json_schema(self) -> dict: + def get_inputs_json_schema(self) -> ComfyJSONSchema: """ 获取输入参数的JSON Schema Returns: - dict: JSON Schema + ComfyJSONSchema: JSON Schema """ return self._inputs_json_schema - def _parse_api_spec(self) -> Dict[str, Dict[str, Any]]: + def _parse_api_spec(self) -> ComfyAPISpec: """ 解析工作流,并根据规范生成嵌套结构的API参数名。 结构: {'节点名': {'node_id': '节点ID', '字段名': {字段信息}}} """ - spec = {"inputs": {}, "outputs": {}} + inputs = {} + outputs = {} for node_id, node in self._nodes_map.items(): - title: str = node.get("title") + title: str = node.title if not title: continue @@ -115,82 +258,64 @@ class ComfyWorkflow: node_name = title[len(API_INPUT_PREFIX) :] # 如果节点名不在inputs中,创建新的节点条目 - if node_name not in spec["inputs"]: - spec["inputs"][node_name] = { - "node_id": node_id, - "class_type": node.get("type"), - } + if node_name not in inputs: + inputs[node_name] = ComfyAPINodeSpec( + node_id=node_id, + class_type=node.type, + inputs={}, + outputs={} + ) - if "inputs" in node: - for a_input in node.get("inputs", []): - if a_input.get("link") is None and "widget" in a_input: - widget_name = a_input["widget"]["name"] + if node.inputs: + for a_input in node.inputs: + if a_input.link is None and a_input.widget: + widget_name = a_input.widget.name # 特殊处理:隐藏LoadImage节点的upload字段 - if ( - node.get("type") == "LoadImage" - and widget_name == "upload" - ): + if node.type == "LoadImage" and widget_name == "upload": continue - # 确保inputs字段存在 - if "inputs" not in spec["inputs"][node_name]: - spec["inputs"][node_name]["inputs"] = {} - # 直接使用widget_name作为字段名 - spec["inputs"][node_name]["inputs"][widget_name] = { - "type": self._get_param_type( - a_input.get("type", "STRING") - ), - "widget_name": a_input["widget"]["name"], - } + inputs[node_name].inputs[widget_name] = ComfyAPIFieldSpec( + type=self._get_param_type(a_input.type or "STRING"), + widget_name=a_input.widget.name + ) elif title.startswith(API_OUTPUT_PREFIX): # 提取节点名(去掉API_OUTPUT_PREFIX前缀) node_name = title[len(API_OUTPUT_PREFIX) :] # 如果节点名不在outputs中,创建新的节点条目 - if node_name not in spec["outputs"]: - spec["outputs"][node_name] = { - "node_id": node_id, - "class_type": node.get("type"), - } + if node_name not in outputs: + outputs[node_name] = ComfyAPINodeSpec( + node_id=node_id, + class_type=node.type, + inputs={}, + outputs={} + ) - if "outputs" in node: - for an_output in node.get("outputs", []): - output_name = an_output["name"] - - # 确保outputs字段存在 - if "outputs" not in spec["outputs"][node_name]: - spec["outputs"][node_name]["outputs"] = {} + if node.outputs: + for an_output in node.outputs: + output_name = an_output.name # 直接使用output_name作为字段名 - spec["outputs"][node_name]["outputs"][output_name] = { - "output_name": an_output["name"], - } + outputs[node_name].outputs[output_name] = ComfyAPIOutputSpec( + output_name=an_output.name + ) - return spec + return ComfyAPISpec(inputs=inputs, outputs=outputs) - def _parse_inputs_json_schema(self) -> dict: + def _parse_inputs_json_schema(self) -> ComfyJSONSchema: """ 解析工作流,生成符合 JSON Schema 标准的 API 规范。 返回一个只包含输入参数的 JSON Schema 对象。 """ - # 基础 JSON Schema 结构,只包含输入 - schema = { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "title": "ComfyUI Workflow Input Schema", - "description": "工作流输入参数定义", - "properties": {}, - "required": [], - } - # 收集所有输入节点 - input_nodes = {} + properties = {} + required = [] - for node_id, node in self._nodes_map.items(): - title: str = node.get("title") + for _, node in self._nodes_map.items(): + title: str = node.title if not title: continue @@ -198,62 +323,46 @@ class ComfyWorkflow: # 提取节点名(去掉API_INPUT_PREFIX前缀) node_name = title[len(API_INPUT_PREFIX) :] - if node_name not in input_nodes: - input_nodes[node_name] = { - "node_id": node_id, - "class_type": node.get("type"), - "properties": {}, - "required": [], - } + node_properties = {} + node_required = [] - if "inputs" in node: - for a_input in node.get("inputs", []): - if a_input.get("link") is None and "widget" in a_input: - widget_name = a_input["widget"]["name"] + if node.inputs: + for a_input in node.inputs: + if a_input.link is None and a_input.widget: + widget_name = a_input.widget.name # 特殊处理:隐藏LoadImage节点的upload字段 - if ( - node.get("type") == "LoadImage" - and widget_name == "upload" - ): + if node.type == "LoadImage" and widget_name == "upload": continue # 获取节点的当前值作为默认值 # 创建空的link_map,因为输入节点通常没有连接 empty_link_map = {} - node_current_inputs = self._get_node_inputs( - node, empty_link_map - ) + node_current_inputs = self._get_node_inputs(node, empty_link_map) current_value = node_current_inputs.get(widget_name) # 生成字段的 JSON Schema 定义,包含默认值 - field_schema = self._generate_field_schema( - a_input, current_value - ) - input_nodes[node_name]["properties"][ - widget_name - ] = field_schema + field_schema = self._generate_field_schema_model(a_input, current_value) + node_properties[widget_name] = field_schema # 如果字段是必需的,添加到required列表 - if a_input.get("required", True): # 默认所有字段都是必需的 - input_nodes[node_name]["required"].append(widget_name) + # TODO: 从输入配置中获取是否必需的信息 + node_required.append(widget_name) - # 将收集的节点信息转换为 JSON Schema 格式 - for node_name, node_info in input_nodes.items(): - schema["properties"][node_name] = { - "type": "object", - "title": f"输入节点: {node_name}", - "description": f"节点类型: {node_info['class_type']}", - "properties": node_info["properties"], - "required": node_info["required"] if node_info["required"] else [], - "additionalProperties": False, - } + # 创建节点的JSON Schema + if node_properties: + properties[node_name] = ComfyJSONSchemaNode( + title=f"输入节点: {node_name}", + description=f"节点类型: {node.type}", + properties=node_properties, + required=node_required + ) - # 如果节点有必需字段,将整个节点标记为必需 - if node_info["required"]: - schema["required"].append(node_name) + # 如果节点有必需字段,将整个节点标记为必需 + if node_required: + required.append(node_name) - return schema + return ComfyJSONSchema(properties=properties, required=required) async def _patch_workflow( self, @@ -265,17 +374,15 @@ class ComfyWorkflow: request_data结构: {"节点名称": {"字段名称": "字段值"}} """ - if "nodes" not in self.workflow_data: - raise ValueError("无效的工作流格式") - nodes_map = self._nodes_map.copy() + nodes_map = {k: v.model_copy(deep=True) for k, v in self._nodes_map.items()} for node_name, node_fields in request_data.items(): - if node_name not in self._api_spec["inputs"]: + if node_name not in self._api_spec.inputs: continue - node_spec = self._api_spec["inputs"][node_name] - node_id = node_spec["node_id"] + node_spec = self._api_spec.inputs[node_name] + node_id = node_spec.node_id if node_id not in nodes_map: continue @@ -283,24 +390,20 @@ class ComfyWorkflow: # 处理该节点下的所有字段 for field_name, value in node_fields.items(): - if "inputs" not in node_spec or field_name not in node_spec["inputs"]: + if field_name not in node_spec.inputs: continue - field_spec = node_spec["inputs"][field_name] - widget_name_to_patch = field_spec["widget_name"] + field_spec = node_spec.inputs[field_name] + widget_name_to_patch = field_spec.widget_name # 找到所有属于widget的输入,并确定目标widget的索引 widget_inputs = [ - inp for inp in target_node.get("inputs", []) if "widget" in inp + inp for inp in (target_node.inputs or []) if inp.widget is not None ] target_widget_index = -1 for i, widget_input in enumerate(widget_inputs): - # "widget" 字典可能不存在或没有 "name" 键 - if ( - widget_input.get("widget", {}).get("name") - == widget_name_to_patch - ): + if widget_input.widget and widget_input.widget.name == widget_name_to_patch: target_widget_index = i break @@ -311,46 +414,41 @@ class ComfyWorkflow: continue # 确保 `widgets_values` 存在且为列表 - if "widgets_values" not in target_node or not isinstance( - target_node.get("widgets_values"), list - ): + if not target_node.widgets_values or not isinstance(target_node.widgets_values, list): # 如果不存在或格式错误,根据widget数量创建一个占位符列表 - target_node["widgets_values"] = [None] * len(widget_inputs) + target_node.widgets_values = [None] * len(widget_inputs) # 确保 `widgets_values` 列表足够长 - while len(target_node["widgets_values"]) <= target_widget_index: - target_node["widgets_values"].append(None) + while len(target_node.widgets_values) <= target_widget_index: + target_node.widgets_values.append(None) # 根据API规范转换数据类型 target_type = str - if field_spec["type"] == "int": + if field_spec.type == "int": target_type = int - elif field_spec["type"] == "float": + elif field_spec.type == "float": target_type = float # 在正确的位置上更新值 try: - if target_node.get("type") == "LoadImage": + if target_node.type == "LoadImage": value = await self._upload_image_to_comfy(server, value) - target_node["widgets_values"][target_widget_index] = target_type( - value - ) + target_node.widgets_values[target_widget_index] = target_type(value) except (ValueError, TypeError) as e: logger.warning( - f"无法将节点 '{node_name}' 的字段 '{field_name}' 的值 '{value}' 转换为类型 '{field_spec['type']}'。错误: {e}" + f"无法将节点 '{node_name}' 的字段 '{field_name}' 的值 '{value}' 转换为类型 '{field_spec.type}'。错误: {e}" ) continue # 创建新的workflow_data副本 - patched_workflow = self.workflow_data.copy() - patched_workflow["nodes"] = list(nodes_map.values()) + patched_workflow_data = self.workflow_data.model_copy(deep=True) + patched_workflow_data.nodes = list(nodes_map.values()) # 对workflow_data中的nodes进行排序,按照id从小到大 - patched_workflow["nodes"] = sorted( - patched_workflow["nodes"], key=lambda x: x["id"] - ) + patched_workflow_data.nodes = sorted(patched_workflow_data.nodes, key=lambda x: x.id) - return patched_workflow + # 转换为字典格式以保持向后兼容 + return patched_workflow_data.model_dump() def _convert_workflow_to_prompt_api_format(self, workflow_data: dict) -> dict: """ @@ -432,7 +530,7 @@ class ComfyWorkflow: self, input_field: dict, default_value: Any = None ) -> dict: """ - 根据输入字段信息生成 JSON Schema 字段定义 + 根据输入字段信息生成 JSON Schema 字段定义(字典版本,向后兼容) Args: input_field: 输入字段配置 @@ -490,6 +588,57 @@ class ComfyWorkflow: return field_schema + def _generate_field_schema_model( + self, input_field: ComfyWorkflowNodeInput, default_value: Any = None + ) -> ComfyJSONSchemaProperty: + """ + 根据输入字段信息生成 JSON Schema 字段定义(模型版本) + + Args: + input_field: 输入字段配置 + default_value: 字段的默认值 + """ + title = input_field.widget.name if input_field.widget else input_field.name + description = input_field.localized_name or "" + + # 根据字段类型设置 schema 类型 + field_type = self._get_param_type(input_field.type or "STRING") + + if field_type == "int": + return ComfyJSONSchemaProperty( + type="integer", + title=title, + description=description, + default=default_value, + # TODO: 从输入配置中获取min/max值 + ) + elif field_type == "float": + return ComfyJSONSchemaProperty( + type="number", + title=title, + description=description, + default=default_value, + # TODO: 从输入配置中获取min/max值 + ) + elif field_type == "UploadFile": + return ComfyJSONSchemaProperty( + type="string", + title=title, + format="binary", + contentMediaType="image/*", + pattern=".*\\.(jpg|jpeg|png)$", + description="[ext:jpg,jpeg,png][file-size:1MB] 上传图片文件", + default=default_value, + ) + else: # string + return ComfyJSONSchemaProperty( + type="string", + title=title, + description=description, + default=default_value, + # TODO: 从输入配置中获取枚举值 + ) + def _get_param_type(self, input_type_str: str) -> str: """ 根据输入类型字符串确定参数类型 @@ -504,12 +653,62 @@ class ComfyWorkflow: else: return "string" - def _get_node_inputs(self, node: dict, link_map: dict) -> dict: + def _get_node_inputs(self, node: Union[ComfyWorkflowNode, dict], link_map: dict) -> dict: """ 获取节点的输入字段值 Args: - node: 节点数据 + node: 节点数据(可以是ComfyWorkflowNode模型或字典) + link_map: 从link_id到源节点的映射 + + Returns: + dict: 包含节点输入字段值的字典 + """ + inputs_dict = {} + + # 兼容处理:如果是字典,转换为模型 + if isinstance(node, dict): + try: + node_model = ComfyWorkflowNode.model_validate(node) + except Exception: + # 如果转换失败,使用原来的字典方式 + return self._get_node_inputs_dict(node, link_map) + else: + node_model = node + + # 1. 处理控件输入 (widgets) + widgets_values = node_model.widgets_values or [] + widget_cursor = 0 + + for input_config in (node_model.inputs or []): + # 如果是widget并且有对应的widgets_values + if input_config.widget and widget_cursor < len(widgets_values): + widget_name = input_config.widget.name + if widget_name: + # 使用widgets_values中的值,因为这里已经包含了API传入的修改 + inputs_dict[widget_name] = widgets_values[widget_cursor] + widget_cursor += 1 + + # 特殊处理:如果节点是 EG_CYQ_JB ,则需要忽略widgets_values的第1个值,这个是用来控制control net的,它不参与赋值 + if node_model.type == "EG_CYQ_JB" and widget_cursor == 1: + widget_cursor += 1 + + # 2. 处理连接输入 (links) + for input_config in (node_model.inputs or []): + if input_config.link is not None: + link_id = input_config.link + if link_id in link_map: + # 输入名称是input_config中的'name' + inputs_dict[input_config.name] = link_map[link_id] + + return inputs_dict + + def _get_node_inputs_dict(self, node: dict, link_map: dict) -> dict: + """ + 获取节点的输入字段值(字典版本,用于向后兼容) + + Args: + node: 节点数据字典 link_map: 从link_id到源节点的映射 Returns: