refactor: 更新WorkflowQueueManager以使用ComfyWorkflow类,简化任务添加逻辑并移除冗余API规范参数

This commit is contained in:
iHeyTang 2025-08-21 14:28:42 +08:00
parent 13580de0a1
commit 3c4c689532
4 changed files with 15 additions and 31 deletions

View File

@ -11,7 +11,7 @@ from typing import Dict, Any, Optional
import aiohttp import aiohttp
from aiohttp import ClientTimeout from aiohttp import ClientTimeout
from workflow_service.comfy.comfy_workflow import ComfyAPISpec, ComfyWorkflow from workflow_service.comfy.comfy_workflow import ComfyWorkflow
from workflow_service.config import Settings from workflow_service.config import Settings
from workflow_service.comfy.comfy_server import server_manager, ComfyUIServerInfo from workflow_service.comfy.comfy_server import server_manager, ComfyUIServerInfo
from workflow_service.database.api import ( from workflow_service.database.api import (
@ -163,9 +163,7 @@ class WorkflowQueueManager:
async def add_task( async def add_task(
self, self,
workflow_name: str, workflow: ComfyWorkflow,
workflow_data: dict,
api_spec: ComfyAPISpec,
request_data: dict, request_data: dict,
): ):
"""添加新任务到队列""" """添加新任务到队列"""
@ -174,18 +172,16 @@ class WorkflowQueueManager:
# 创建任务记录 # 创建任务记录
await create_workflow_run( await create_workflow_run(
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
workflow_name=workflow_name, workflow_name=workflow.workflow_name,
workflow_json=json.dumps(workflow_data), workflow_json=json.dumps(workflow.workflow_data.model_dump()),
api_spec=json.dumps(api_spec.model_dump()), api_spec=json.dumps(workflow.get_api_spec().model_dump()),
request_data=json.dumps(request_data), request_data=json.dumps(request_data),
) )
# 创建工作流节点记录 # 创建工作流节点记录
nodes_data = [] nodes_data = []
for node in workflow_data.get("nodes", []): for node in workflow.workflow_data.nodes:
nodes_data.append( nodes_data.append({"id": str(node.id), "type": node.type})
{"id": str(node["id"]), "type": node.get("type", "unknown")}
)
await create_workflow_run_nodes(workflow_run_id, nodes_data) await create_workflow_run_nodes(workflow_run_id, nodes_data)
# 添加到待处理队列 # 添加到待处理队列
@ -263,12 +259,11 @@ class WorkflowQueueManager:
raise Exception(f"找不到工作流运行记录: {workflow_run_id}") raise Exception(f"找不到工作流运行记录: {workflow_run_id}")
workflow_data = json.loads(workflow_run["workflow_json"]) workflow_data = json.loads(workflow_run["workflow_json"])
api_spec = json.loads(workflow_run["api_spec"])
request_data = json.loads(workflow_run["request_data"]) request_data = json.loads(workflow_run["request_data"])
# 执行工作流 # 执行工作流
result = await _execute_prompt_on_server( result = await _execute_prompt_on_server(
workflow_data, api_spec, request_data, server, workflow_run_id workflow_data, request_data, server, workflow_run_id
) )
# 保存处理后的结果到数据库 # 保存处理后的结果到数据库
@ -375,7 +370,6 @@ queue_manager = WorkflowQueueManager(monitor_interval=5)
async def _execute_prompt_on_server( async def _execute_prompt_on_server(
workflow_data: dict, workflow_data: dict,
api_spec: dict,
request_data: dict, request_data: dict,
server: ComfyUIServerInfo, server: ComfyUIServerInfo,
workflow_run_id: str, workflow_run_id: str,

View File

@ -160,14 +160,16 @@ class ComfyWorkflow:
ComfyUI工作流处理器类提供面向对象的工作流管理和处理能力 ComfyUI工作流处理器类提供面向对象的工作流管理和处理能力
""" """
def __init__(self, workflow_data: dict): def __init__(self, workflow_data: dict, workflow_name: str = None):
""" """
初始化工作流实例 初始化工作流实例
Args: Args:
workflow_data: 工作流数据 workflow_data: 工作流数据
workflow_name: 工作流名称
""" """
self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data) self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data)
self.workflow_name = workflow_name
self._nodes_map = {str(node.id): node for node in self.workflow_data.nodes} self._nodes_map = {str(node.id): node for node in self.workflow_data.nodes}
self._api_spec = self._parse_api_spec() self._api_spec = self._parse_api_spec()
self._inputs_json_schema = self._parse_inputs_json_schema() self._inputs_json_schema = self._parse_inputs_json_schema()

View File

@ -40,16 +40,10 @@ async def run_workflow(
raise HTTPException(status_code=404, detail=detail) raise HTTPException(status_code=404, detail=detail)
workflow = json.loads(workflow_data["workflow_json"]) workflow = json.loads(workflow_data["workflow_json"])
flow = ComfyWorkflow(workflow) flow = ComfyWorkflow(workflow, workflow_name)
api_spec = flow.get_api_spec()
# 提交到队列 # 提交到队列
workflow_run_id = await queue_manager.add_task( workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=data)
workflow_name=workflow_name,
workflow_data=workflow,
api_spec=api_spec,
request_data=data,
)
return JSONResponse( return JSONResponse(
content={ content={

View File

@ -33,8 +33,7 @@ async def model_with_multi_dress(
raise HTTPException(status_code=404, detail=detail) raise HTTPException(status_code=404, detail=detail)
workflow = json.loads(workflow_data["workflow_json"]) workflow = json.loads(workflow_data["workflow_json"])
flow = ComfyWorkflow(workflow) flow = ComfyWorkflow(workflow, workflow_name)
api_spec = flow.get_api_spec()
# 将请求拆分为多个请求 # 将请求拆分为多个请求
batch_data = _convert(data) batch_data = _convert(data)
@ -42,12 +41,7 @@ async def model_with_multi_dress(
# 提交到队列 # 提交到队列
workflow_run_ids: list[str] = [] workflow_run_ids: list[str] = []
for item in batch_data: for item in batch_data:
workflow_run_id = await queue_manager.add_task( workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=item)
workflow_name=workflow_name,
workflow_data=workflow,
api_spec=api_spec,
request_data=item,
)
workflow_run_ids.append(workflow_run_id) workflow_run_ids.append(workflow_run_id)
return JSONResponse( return JSONResponse(