refactor: 更新WorkflowQueueManager以使用ComfyWorkflow类,简化任务添加逻辑并移除冗余API规范参数
This commit is contained in:
parent
13580de0a1
commit
3c4c689532
|
|
@ -11,7 +11,7 @@ from typing import Dict, Any, Optional
|
|||
import aiohttp
|
||||
from aiohttp import ClientTimeout
|
||||
|
||||
from workflow_service.comfy.comfy_workflow import ComfyAPISpec, ComfyWorkflow
|
||||
from workflow_service.comfy.comfy_workflow import ComfyWorkflow
|
||||
from workflow_service.config import Settings
|
||||
from workflow_service.comfy.comfy_server import server_manager, ComfyUIServerInfo
|
||||
from workflow_service.database.api import (
|
||||
|
|
@ -163,9 +163,7 @@ class WorkflowQueueManager:
|
|||
|
||||
async def add_task(
|
||||
self,
|
||||
workflow_name: str,
|
||||
workflow_data: dict,
|
||||
api_spec: ComfyAPISpec,
|
||||
workflow: ComfyWorkflow,
|
||||
request_data: dict,
|
||||
):
|
||||
"""添加新任务到队列"""
|
||||
|
|
@ -174,18 +172,16 @@ class WorkflowQueueManager:
|
|||
# 创建任务记录
|
||||
await create_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_name=workflow_name,
|
||||
workflow_json=json.dumps(workflow_data),
|
||||
api_spec=json.dumps(api_spec.model_dump()),
|
||||
workflow_name=workflow.workflow_name,
|
||||
workflow_json=json.dumps(workflow.workflow_data.model_dump()),
|
||||
api_spec=json.dumps(workflow.get_api_spec().model_dump()),
|
||||
request_data=json.dumps(request_data),
|
||||
)
|
||||
|
||||
# 创建工作流节点记录
|
||||
nodes_data = []
|
||||
for node in workflow_data.get("nodes", []):
|
||||
nodes_data.append(
|
||||
{"id": str(node["id"]), "type": node.get("type", "unknown")}
|
||||
)
|
||||
for node in workflow.workflow_data.nodes:
|
||||
nodes_data.append({"id": str(node.id), "type": node.type})
|
||||
await create_workflow_run_nodes(workflow_run_id, nodes_data)
|
||||
|
||||
# 添加到待处理队列
|
||||
|
|
@ -263,12 +259,11 @@ class WorkflowQueueManager:
|
|||
raise Exception(f"找不到工作流运行记录: {workflow_run_id}")
|
||||
|
||||
workflow_data = json.loads(workflow_run["workflow_json"])
|
||||
api_spec = json.loads(workflow_run["api_spec"])
|
||||
request_data = json.loads(workflow_run["request_data"])
|
||||
|
||||
# 执行工作流
|
||||
result = await _execute_prompt_on_server(
|
||||
workflow_data, api_spec, request_data, server, workflow_run_id
|
||||
workflow_data, request_data, server, workflow_run_id
|
||||
)
|
||||
|
||||
# 保存处理后的结果到数据库
|
||||
|
|
@ -375,7 +370,6 @@ queue_manager = WorkflowQueueManager(monitor_interval=5)
|
|||
|
||||
async def _execute_prompt_on_server(
|
||||
workflow_data: dict,
|
||||
api_spec: dict,
|
||||
request_data: dict,
|
||||
server: ComfyUIServerInfo,
|
||||
workflow_run_id: str,
|
||||
|
|
|
|||
|
|
@ -160,14 +160,16 @@ class ComfyWorkflow:
|
|||
ComfyUI工作流处理器类,提供面向对象的工作流管理和处理能力
|
||||
"""
|
||||
|
||||
def __init__(self, workflow_data: dict):
|
||||
def __init__(self, workflow_data: dict, workflow_name: str = None):
|
||||
"""
|
||||
初始化工作流实例
|
||||
|
||||
Args:
|
||||
workflow_data: 工作流数据
|
||||
workflow_name: 工作流名称
|
||||
"""
|
||||
self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data)
|
||||
self.workflow_name = workflow_name
|
||||
self._nodes_map = {str(node.id): node for node in self.workflow_data.nodes}
|
||||
self._api_spec = self._parse_api_spec()
|
||||
self._inputs_json_schema = self._parse_inputs_json_schema()
|
||||
|
|
|
|||
|
|
@ -40,16 +40,10 @@ async def run_workflow(
|
|||
raise HTTPException(status_code=404, detail=detail)
|
||||
|
||||
workflow = json.loads(workflow_data["workflow_json"])
|
||||
flow = ComfyWorkflow(workflow)
|
||||
api_spec = flow.get_api_spec()
|
||||
flow = ComfyWorkflow(workflow, workflow_name)
|
||||
|
||||
# 提交到队列
|
||||
workflow_run_id = await queue_manager.add_task(
|
||||
workflow_name=workflow_name,
|
||||
workflow_data=workflow,
|
||||
api_spec=api_spec,
|
||||
request_data=data,
|
||||
)
|
||||
workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=data)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
|
|
|
|||
|
|
@ -33,8 +33,7 @@ async def model_with_multi_dress(
|
|||
raise HTTPException(status_code=404, detail=detail)
|
||||
|
||||
workflow = json.loads(workflow_data["workflow_json"])
|
||||
flow = ComfyWorkflow(workflow)
|
||||
api_spec = flow.get_api_spec()
|
||||
flow = ComfyWorkflow(workflow, workflow_name)
|
||||
|
||||
# 将请求拆分为多个请求
|
||||
batch_data = _convert(data)
|
||||
|
|
@ -42,12 +41,7 @@ async def model_with_multi_dress(
|
|||
# 提交到队列
|
||||
workflow_run_ids: list[str] = []
|
||||
for item in batch_data:
|
||||
workflow_run_id = await queue_manager.add_task(
|
||||
workflow_name=workflow_name,
|
||||
workflow_data=workflow,
|
||||
api_spec=api_spec,
|
||||
request_data=item,
|
||||
)
|
||||
workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=item)
|
||||
workflow_run_ids.append(workflow_run_id)
|
||||
|
||||
return JSONResponse(
|
||||
|
|
|
|||
Loading…
Reference in New Issue