From fc265570d3b79fc8be96ead3c55da8b181f6031c Mon Sep 17 00:00:00 2001 From: iHeyTang Date: Thu, 21 Aug 2025 14:59:43 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=9B=B4=E6=96=B0ComfyWorkflow?= =?UTF-8?q?=E7=B1=BB=E7=9A=84=E6=9E=84=E9=80=A0=E5=87=BD=E6=95=B0=EF=BC=8C?= =?UTF-8?q?=E7=AE=80=E5=8C=96=E5=B7=A5=E4=BD=9C=E6=B5=81=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E5=B9=B6=E5=9C=A8?= =?UTF-8?q?=E7=9B=B8=E5=85=B3API=E8=B0=83=E7=94=A8=E4=B8=AD=E5=BA=94?= =?UTF-8?q?=E7=94=A8=E6=96=B0=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- workflow_service/comfy/comfy_queue.py | 49 ++++++++-------- workflow_service/comfy/comfy_workflow.py | 57 ++++++++++++------- workflow_service/database/api.py | 4 +- workflow_service/routes/run.py | 2 +- .../routes/runx/model_with_multi_dress.py | 2 +- workflow_service/routes/workflow.py | 3 +- 6 files changed, 66 insertions(+), 51 deletions(-) diff --git a/workflow_service/comfy/comfy_queue.py b/workflow_service/comfy/comfy_queue.py index 133db59..34fcdbc 100644 --- a/workflow_service/comfy/comfy_queue.py +++ b/workflow_service/comfy/comfy_queue.py @@ -258,12 +258,12 @@ class WorkflowQueueManager: if not workflow_run: raise Exception(f"找不到工作流运行记录: {workflow_run_id}") - workflow_data = json.loads(workflow_run["workflow_json"]) - request_data = json.loads(workflow_run["request_data"]) - + workflow_data = json.loads(workflow_run.workflow_json) + request_data = json.loads(workflow_run.request_data) + workflow = ComfyWorkflow(workflow_run.workflow_name, workflow_data) # 执行工作流 result = await _execute_prompt_on_server( - workflow_data, request_data, server, workflow_run_id + workflow, request_data, server, workflow_run_id ) # 保存处理后的结果到数据库 @@ -307,19 +307,19 @@ class WorkflowQueueManager: result = { "id": workflow_run_id, - "status": workflow_run["status"], - "created_at": workflow_run["created_at"], - "started_at": workflow_run["started_at"], - "completed_at": workflow_run["completed_at"], - "server_url": workflow_run["server_url"], - "error_message": workflow_run["error_message"], + "status": workflow_run.status, + "created_at": workflow_run.created_at, + "started_at": workflow_run.started_at, + "completed_at": workflow_run.completed_at, + "server_url": workflow_run.server_url, + "error_message": workflow_run.error_message, "nodes": nodes, } # 如果任务完成,从数据库获取结果 - if workflow_run["status"] == "completed" and workflow_run.get("result"): + if workflow_run.status == "completed" and workflow_run.result: try: - result["result"] = json.loads(workflow_run["result"]) + result["result"] = json.loads(workflow_run.result) except (json.JSONDecodeError, TypeError): result["result"] = None @@ -369,7 +369,7 @@ queue_manager = WorkflowQueueManager(monitor_interval=5) async def _execute_prompt_on_server( - workflow_data: dict, + workflow: ComfyWorkflow, request_data: dict, server: ComfyUIServerInfo, workflow_run_id: str, @@ -381,8 +381,7 @@ async def _execute_prompt_on_server( client_id = str(uuid.uuid4()) # 构建prompt - flow = ComfyWorkflow(workflow_data) - prompt = await flow.build_prompt(server, request_data) + prompt = await workflow.build_prompt(server, request_data) # 更新工作流运行状态,记录prompt_id和client_id await update_workflow_run_status( @@ -394,7 +393,7 @@ async def _execute_prompt_on_server( ) # 提交到ComfyUI - prompt_id = await _queue_prompt(workflow_data, prompt, client_id, server.http_url) + prompt_id = await _queue_prompt(workflow, prompt, client_id, server.http_url) # 更新prompt_id await update_workflow_run_status( @@ -407,13 +406,13 @@ async def _execute_prompt_on_server( # 获取执行结果,现在支持节点级别的状态跟踪 results = await _get_execution_results( - workflow_data, prompt_id, client_id, server.ws_url, workflow_run_id + workflow, prompt_id, client_id, server.ws_url, workflow_run_id ) return results async def _queue_prompt( - workflow: dict, prompt: dict, client_id: str, http_url: str + workflow: ComfyWorkflow, prompt: dict, client_id: str, http_url: str ) -> str: """通过HTTP POST将工作流任务提交到指定的ComfyUI服务器。""" for node_id in prompt: @@ -423,7 +422,7 @@ async def _queue_prompt( "client_id": client_id, "extra_data": { "api_key_comfy_org": "", - "extra_pnginfo": {"workflow": workflow}, + "extra_pnginfo": {"workflow": workflow.workflow_data.model_dump()}, }, } async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session: @@ -443,7 +442,7 @@ async def _queue_prompt( async def _get_execution_results( - workflow_data: Dict, + workflow: ComfyWorkflow, prompt_id: str, client_id: str, ws_url: str, @@ -511,17 +510,17 @@ async def _get_execution_results( node = next( ( x - for x in workflow_data["nodes"] - if str(x["id"]) == node_id + for x in workflow.workflow_data.nodes + if str(x.id) == node_id ), None, ) if ( node - and node.get("title", "") - and node["title"].startswith(API_OUTPUT_PREFIX) + and node.title + and node.title.startswith(API_OUTPUT_PREFIX) ): - title = node["title"].replace(API_OUTPUT_PREFIX, "") + title = node.title.replace(API_OUTPUT_PREFIX, "") aggregated_outputs[title] = output_data logger.info( f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})" diff --git a/workflow_service/comfy/comfy_workflow.py b/workflow_service/comfy/comfy_workflow.py index 70b0ce4..6175ba7 100644 --- a/workflow_service/comfy/comfy_workflow.py +++ b/workflow_service/comfy/comfy_workflow.py @@ -16,6 +16,7 @@ API_OUTPUT_PREFIX = "OUTPUT_" class ComfyWorkflowNodeWidget(BaseModel): """节点控件定义""" + name: str @@ -52,17 +53,20 @@ class ComfyWorkflowNodeProperties(BaseModel): class ComfyAPIFieldSpec(BaseModel): """API字段规范""" + type: str widget_name: str class ComfyAPIOutputSpec(BaseModel): """API输出规范""" + output_name: str class ComfyAPINodeSpec(BaseModel): """API节点规范""" + node_id: str class_type: str inputs: Optional[Dict[str, ComfyAPIFieldSpec]] = {} @@ -71,12 +75,14 @@ class ComfyAPINodeSpec(BaseModel): class ComfyAPISpec(BaseModel): """API规范定义""" + inputs: Dict[str, ComfyAPINodeSpec] = {} outputs: Dict[str, ComfyAPINodeSpec] = {} class ComfyJSONSchemaProperty(BaseModel): """JSON Schema属性定义""" + type: str title: Optional[str] = None description: Optional[str] = None @@ -91,6 +97,7 @@ class ComfyJSONSchemaProperty(BaseModel): class ComfyJSONSchemaNode(BaseModel): """JSON Schema节点定义""" + type: str = "object" title: str description: str @@ -101,6 +108,7 @@ class ComfyJSONSchemaNode(BaseModel): class ComfyJSONSchema(BaseModel): """JSON Schema定义""" + schema: str = "http://json-schema.org/draft-07/schema#" type: str = "object" title: str = "ComfyUI Workflow Input Schema" @@ -160,13 +168,13 @@ class ComfyWorkflow: ComfyUI工作流处理器类,提供面向对象的工作流管理和处理能力 """ - def __init__(self, workflow_data: dict, workflow_name: str = None): + def __init__(self, workflow_name: str, workflow_data: dict): """ 初始化工作流实例 Args: - workflow_data: 工作流数据 workflow_name: 工作流名称 + workflow_data: 工作流数据 """ self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data) self.workflow_name = workflow_name @@ -262,10 +270,7 @@ class ComfyWorkflow: # 如果节点名不在inputs中,创建新的节点条目 if node_name not in inputs: inputs[node_name] = ComfyAPINodeSpec( - node_id=node_id, - class_type=node.type, - inputs={}, - outputs={} + node_id=node_id, class_type=node.type, inputs={}, outputs={} ) if node.inputs: @@ -280,7 +285,7 @@ class ComfyWorkflow: # 直接使用widget_name作为字段名 inputs[node_name].inputs[widget_name] = ComfyAPIFieldSpec( type=self._get_param_type(a_input.type or "STRING"), - widget_name=a_input.widget.name + widget_name=a_input.widget.name, ) elif title.startswith(API_OUTPUT_PREFIX): @@ -290,10 +295,7 @@ class ComfyWorkflow: # 如果节点名不在outputs中,创建新的节点条目 if node_name not in outputs: outputs[node_name] = ComfyAPINodeSpec( - node_id=node_id, - class_type=node.type, - inputs={}, - outputs={} + node_id=node_id, class_type=node.type, inputs={}, outputs={} ) if node.outputs: @@ -340,11 +342,15 @@ class ComfyWorkflow: # 获取节点的当前值作为默认值 # 创建空的link_map,因为输入节点通常没有连接 empty_link_map = {} - node_current_inputs = self._get_node_inputs(node, empty_link_map) + node_current_inputs = self._get_node_inputs( + node, empty_link_map + ) current_value = node_current_inputs.get(widget_name) # 生成字段的 JSON Schema 定义,包含默认值 - field_schema = self._generate_field_schema_model(a_input, current_value) + field_schema = self._generate_field_schema_model( + a_input, current_value + ) node_properties[widget_name] = field_schema # 如果字段是必需的,添加到required列表 @@ -357,7 +363,7 @@ class ComfyWorkflow: title=f"输入节点: {node_name}", description=f"节点类型: {node.type}", properties=node_properties, - required=node_required + required=node_required, ) # 如果节点有必需字段,将整个节点标记为必需 @@ -405,7 +411,10 @@ class ComfyWorkflow: target_widget_index = -1 for i, widget_input in enumerate(widget_inputs): - if widget_input.widget and widget_input.widget.name == widget_name_to_patch: + if ( + widget_input.widget + and widget_input.widget.name == widget_name_to_patch + ): target_widget_index = i break @@ -416,7 +425,9 @@ class ComfyWorkflow: continue # 确保 `widgets_values` 存在且为列表 - if not target_node.widgets_values or not isinstance(target_node.widgets_values, list): + if not target_node.widgets_values or not isinstance( + target_node.widgets_values, list + ): # 如果不存在或格式错误,根据widget数量创建一个占位符列表 target_node.widgets_values = [None] * len(widget_inputs) @@ -447,7 +458,9 @@ class ComfyWorkflow: patched_workflow_data.nodes = list(nodes_map.values()) # 对workflow_data中的nodes进行排序,按照id从小到大 - patched_workflow_data.nodes = sorted(patched_workflow_data.nodes, key=lambda x: x.id) + patched_workflow_data.nodes = sorted( + patched_workflow_data.nodes, key=lambda x: x.id + ) # 转换为字典格式以保持向后兼容 return patched_workflow_data.model_dump() @@ -655,7 +668,9 @@ class ComfyWorkflow: else: return "string" - def _get_node_inputs(self, node: Union[ComfyWorkflowNode, dict], link_map: dict) -> dict: + def _get_node_inputs( + self, node: Union[ComfyWorkflowNode, dict], link_map: dict + ) -> dict: """ 获取节点的输入字段值 @@ -681,8 +696,8 @@ class ComfyWorkflow: # 1. 处理控件输入 (widgets) widgets_values = node_model.widgets_values or [] widget_cursor = 0 - - for input_config in (node_model.inputs or []): + + for input_config in node_model.inputs or []: # 如果是widget并且有对应的widgets_values if input_config.widget and widget_cursor < len(widgets_values): widget_name = input_config.widget.name @@ -696,7 +711,7 @@ class ComfyWorkflow: widget_cursor += 1 # 2. 处理连接输入 (links) - for input_config in (node_model.inputs or []): + for input_config in node_model.inputs or []: if input_config.link is not None: link_id = input_config.link if link_id in link_map: diff --git a/workflow_service/database/api.py b/workflow_service/database/api.py index 54026e9..9f98261 100644 --- a/workflow_service/database/api.py +++ b/workflow_service/database/api.py @@ -219,11 +219,11 @@ async def update_workflow_run_node_status( await session.commit() -async def get_workflow_run(workflow_run_id: str) -> Optional[dict]: +async def get_workflow_run(workflow_run_id: str) -> Optional[WorkflowRun]: """获取工作流运行记录""" async with AsyncSessionLocal() as session: workflow_run = await session.get(WorkflowRun, workflow_run_id) - return workflow_run.to_dict() if workflow_run else None + return workflow_run async def get_workflow_run_nodes(workflow_run_id: str) -> List[dict]: diff --git a/workflow_service/routes/run.py b/workflow_service/routes/run.py index d675a98..01df260 100644 --- a/workflow_service/routes/run.py +++ b/workflow_service/routes/run.py @@ -40,7 +40,7 @@ async def run_workflow( raise HTTPException(status_code=404, detail=detail) workflow = json.loads(workflow_data["workflow_json"]) - flow = ComfyWorkflow(workflow, workflow_name) + flow = ComfyWorkflow(workflow_name, workflow) # 提交到队列 workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=data) diff --git a/workflow_service/routes/runx/model_with_multi_dress.py b/workflow_service/routes/runx/model_with_multi_dress.py index d708139..3f9b47f 100644 --- a/workflow_service/routes/runx/model_with_multi_dress.py +++ b/workflow_service/routes/runx/model_with_multi_dress.py @@ -33,7 +33,7 @@ async def model_with_multi_dress( raise HTTPException(status_code=404, detail=detail) workflow = json.loads(workflow_data["workflow_json"]) - flow = ComfyWorkflow(workflow, workflow_name) + flow = ComfyWorkflow(workflow_name, workflow) # 将请求拆分为多个请求 batch_data = _convert(data) diff --git a/workflow_service/routes/workflow.py b/workflow_service/routes/workflow.py index dea9ac0..f4ad2f5 100644 --- a/workflow_service/routes/workflow.py +++ b/workflow_service/routes/workflow.py @@ -92,8 +92,9 @@ async def get_one_workflow_endpoint(base_name: str, version: Optional[str] = Non try: workflow = json.loads(workflow_data["workflow_json"]) - flow = ComfyWorkflow(workflow) + flow = ComfyWorkflow(base_name, workflow) return { + "name": flow.workflow_name, "workflow": flow.workflow_data, "api_spec": flow.get_api_spec(), "inputs_json_schema": flow.get_inputs_json_schema(),