diff --git a/workflow_service/comfy/comfy_queue.py b/workflow_service/comfy/comfy_queue.py index f5afa44..133db59 100644 --- a/workflow_service/comfy/comfy_queue.py +++ b/workflow_service/comfy/comfy_queue.py @@ -11,7 +11,7 @@ from typing import Dict, Any, Optional import aiohttp from aiohttp import ClientTimeout -from workflow_service.comfy.comfy_workflow import ComfyAPISpec, ComfyWorkflow +from workflow_service.comfy.comfy_workflow import ComfyWorkflow from workflow_service.config import Settings from workflow_service.comfy.comfy_server import server_manager, ComfyUIServerInfo from workflow_service.database.api import ( @@ -163,9 +163,7 @@ class WorkflowQueueManager: async def add_task( self, - workflow_name: str, - workflow_data: dict, - api_spec: ComfyAPISpec, + workflow: ComfyWorkflow, request_data: dict, ): """添加新任务到队列""" @@ -174,18 +172,16 @@ class WorkflowQueueManager: # 创建任务记录 await create_workflow_run( workflow_run_id=workflow_run_id, - workflow_name=workflow_name, - workflow_json=json.dumps(workflow_data), - api_spec=json.dumps(api_spec.model_dump()), + workflow_name=workflow.workflow_name, + workflow_json=json.dumps(workflow.workflow_data.model_dump()), + api_spec=json.dumps(workflow.get_api_spec().model_dump()), request_data=json.dumps(request_data), ) # 创建工作流节点记录 nodes_data = [] - for node in workflow_data.get("nodes", []): - nodes_data.append( - {"id": str(node["id"]), "type": node.get("type", "unknown")} - ) + for node in workflow.workflow_data.nodes: + nodes_data.append({"id": str(node.id), "type": node.type}) await create_workflow_run_nodes(workflow_run_id, nodes_data) # 添加到待处理队列 @@ -263,12 +259,11 @@ class WorkflowQueueManager: raise Exception(f"找不到工作流运行记录: {workflow_run_id}") workflow_data = json.loads(workflow_run["workflow_json"]) - api_spec = json.loads(workflow_run["api_spec"]) request_data = json.loads(workflow_run["request_data"]) # 执行工作流 result = await _execute_prompt_on_server( - workflow_data, api_spec, request_data, server, workflow_run_id + workflow_data, request_data, server, workflow_run_id ) # 保存处理后的结果到数据库 @@ -375,7 +370,6 @@ queue_manager = WorkflowQueueManager(monitor_interval=5) async def _execute_prompt_on_server( workflow_data: dict, - api_spec: dict, request_data: dict, server: ComfyUIServerInfo, workflow_run_id: str, diff --git a/workflow_service/comfy/comfy_workflow.py b/workflow_service/comfy/comfy_workflow.py index a5c56d9..70b0ce4 100644 --- a/workflow_service/comfy/comfy_workflow.py +++ b/workflow_service/comfy/comfy_workflow.py @@ -160,14 +160,16 @@ class ComfyWorkflow: ComfyUI工作流处理器类,提供面向对象的工作流管理和处理能力 """ - def __init__(self, workflow_data: dict): + def __init__(self, workflow_data: dict, workflow_name: str = None): """ 初始化工作流实例 Args: workflow_data: 工作流数据 + workflow_name: 工作流名称 """ self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data) + self.workflow_name = workflow_name self._nodes_map = {str(node.id): node for node in self.workflow_data.nodes} self._api_spec = self._parse_api_spec() self._inputs_json_schema = self._parse_inputs_json_schema() diff --git a/workflow_service/routes/run.py b/workflow_service/routes/run.py index 0950b91..d675a98 100644 --- a/workflow_service/routes/run.py +++ b/workflow_service/routes/run.py @@ -40,16 +40,10 @@ async def run_workflow( raise HTTPException(status_code=404, detail=detail) workflow = json.loads(workflow_data["workflow_json"]) - flow = ComfyWorkflow(workflow) - api_spec = flow.get_api_spec() + flow = ComfyWorkflow(workflow, workflow_name) # 提交到队列 - workflow_run_id = await queue_manager.add_task( - workflow_name=workflow_name, - workflow_data=workflow, - api_spec=api_spec, - request_data=data, - ) + workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=data) return JSONResponse( content={ diff --git a/workflow_service/routes/runx/model_with_multi_dress.py b/workflow_service/routes/runx/model_with_multi_dress.py index 07a7d15..d708139 100644 --- a/workflow_service/routes/runx/model_with_multi_dress.py +++ b/workflow_service/routes/runx/model_with_multi_dress.py @@ -33,8 +33,7 @@ async def model_with_multi_dress( raise HTTPException(status_code=404, detail=detail) workflow = json.loads(workflow_data["workflow_json"]) - flow = ComfyWorkflow(workflow) - api_spec = flow.get_api_spec() + flow = ComfyWorkflow(workflow, workflow_name) # 将请求拆分为多个请求 batch_data = _convert(data) @@ -42,12 +41,7 @@ async def model_with_multi_dress( # 提交到队列 workflow_run_ids: list[str] = [] for item in batch_data: - workflow_run_id = await queue_manager.add_task( - workflow_name=workflow_name, - workflow_data=workflow, - api_spec=api_spec, - request_data=item, - ) + workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=item) workflow_run_ids.append(workflow_run_id) return JSONResponse(