refactor: 更新WorkflowQueueManager以使用ComfyWorkflow类,简化任务添加逻辑并移除冗余API规范参数
This commit is contained in:
parent
13580de0a1
commit
3c4c689532
|
|
@ -11,7 +11,7 @@ from typing import Dict, Any, Optional
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import ClientTimeout
|
from aiohttp import ClientTimeout
|
||||||
|
|
||||||
from workflow_service.comfy.comfy_workflow import ComfyAPISpec, ComfyWorkflow
|
from workflow_service.comfy.comfy_workflow import ComfyWorkflow
|
||||||
from workflow_service.config import Settings
|
from workflow_service.config import Settings
|
||||||
from workflow_service.comfy.comfy_server import server_manager, ComfyUIServerInfo
|
from workflow_service.comfy.comfy_server import server_manager, ComfyUIServerInfo
|
||||||
from workflow_service.database.api import (
|
from workflow_service.database.api import (
|
||||||
|
|
@ -163,9 +163,7 @@ class WorkflowQueueManager:
|
||||||
|
|
||||||
async def add_task(
|
async def add_task(
|
||||||
self,
|
self,
|
||||||
workflow_name: str,
|
workflow: ComfyWorkflow,
|
||||||
workflow_data: dict,
|
|
||||||
api_spec: ComfyAPISpec,
|
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
):
|
):
|
||||||
"""添加新任务到队列"""
|
"""添加新任务到队列"""
|
||||||
|
|
@ -174,18 +172,16 @@ class WorkflowQueueManager:
|
||||||
# 创建任务记录
|
# 创建任务记录
|
||||||
await create_workflow_run(
|
await create_workflow_run(
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
workflow_name=workflow_name,
|
workflow_name=workflow.workflow_name,
|
||||||
workflow_json=json.dumps(workflow_data),
|
workflow_json=json.dumps(workflow.workflow_data.model_dump()),
|
||||||
api_spec=json.dumps(api_spec.model_dump()),
|
api_spec=json.dumps(workflow.get_api_spec().model_dump()),
|
||||||
request_data=json.dumps(request_data),
|
request_data=json.dumps(request_data),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建工作流节点记录
|
# 创建工作流节点记录
|
||||||
nodes_data = []
|
nodes_data = []
|
||||||
for node in workflow_data.get("nodes", []):
|
for node in workflow.workflow_data.nodes:
|
||||||
nodes_data.append(
|
nodes_data.append({"id": str(node.id), "type": node.type})
|
||||||
{"id": str(node["id"]), "type": node.get("type", "unknown")}
|
|
||||||
)
|
|
||||||
await create_workflow_run_nodes(workflow_run_id, nodes_data)
|
await create_workflow_run_nodes(workflow_run_id, nodes_data)
|
||||||
|
|
||||||
# 添加到待处理队列
|
# 添加到待处理队列
|
||||||
|
|
@ -263,12 +259,11 @@ class WorkflowQueueManager:
|
||||||
raise Exception(f"找不到工作流运行记录: {workflow_run_id}")
|
raise Exception(f"找不到工作流运行记录: {workflow_run_id}")
|
||||||
|
|
||||||
workflow_data = json.loads(workflow_run["workflow_json"])
|
workflow_data = json.loads(workflow_run["workflow_json"])
|
||||||
api_spec = json.loads(workflow_run["api_spec"])
|
|
||||||
request_data = json.loads(workflow_run["request_data"])
|
request_data = json.loads(workflow_run["request_data"])
|
||||||
|
|
||||||
# 执行工作流
|
# 执行工作流
|
||||||
result = await _execute_prompt_on_server(
|
result = await _execute_prompt_on_server(
|
||||||
workflow_data, api_spec, request_data, server, workflow_run_id
|
workflow_data, request_data, server, workflow_run_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存处理后的结果到数据库
|
# 保存处理后的结果到数据库
|
||||||
|
|
@ -375,7 +370,6 @@ queue_manager = WorkflowQueueManager(monitor_interval=5)
|
||||||
|
|
||||||
async def _execute_prompt_on_server(
|
async def _execute_prompt_on_server(
|
||||||
workflow_data: dict,
|
workflow_data: dict,
|
||||||
api_spec: dict,
|
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
server: ComfyUIServerInfo,
|
server: ComfyUIServerInfo,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
|
|
|
||||||
|
|
@ -160,14 +160,16 @@ class ComfyWorkflow:
|
||||||
ComfyUI工作流处理器类,提供面向对象的工作流管理和处理能力
|
ComfyUI工作流处理器类,提供面向对象的工作流管理和处理能力
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, workflow_data: dict):
|
def __init__(self, workflow_data: dict, workflow_name: str = None):
|
||||||
"""
|
"""
|
||||||
初始化工作流实例
|
初始化工作流实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workflow_data: 工作流数据
|
workflow_data: 工作流数据
|
||||||
|
workflow_name: 工作流名称
|
||||||
"""
|
"""
|
||||||
self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data)
|
self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data)
|
||||||
|
self.workflow_name = workflow_name
|
||||||
self._nodes_map = {str(node.id): node for node in self.workflow_data.nodes}
|
self._nodes_map = {str(node.id): node for node in self.workflow_data.nodes}
|
||||||
self._api_spec = self._parse_api_spec()
|
self._api_spec = self._parse_api_spec()
|
||||||
self._inputs_json_schema = self._parse_inputs_json_schema()
|
self._inputs_json_schema = self._parse_inputs_json_schema()
|
||||||
|
|
|
||||||
|
|
@ -40,16 +40,10 @@ async def run_workflow(
|
||||||
raise HTTPException(status_code=404, detail=detail)
|
raise HTTPException(status_code=404, detail=detail)
|
||||||
|
|
||||||
workflow = json.loads(workflow_data["workflow_json"])
|
workflow = json.loads(workflow_data["workflow_json"])
|
||||||
flow = ComfyWorkflow(workflow)
|
flow = ComfyWorkflow(workflow, workflow_name)
|
||||||
api_spec = flow.get_api_spec()
|
|
||||||
|
|
||||||
# 提交到队列
|
# 提交到队列
|
||||||
workflow_run_id = await queue_manager.add_task(
|
workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=data)
|
||||||
workflow_name=workflow_name,
|
|
||||||
workflow_data=workflow,
|
|
||||||
api_spec=api_spec,
|
|
||||||
request_data=data,
|
|
||||||
)
|
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={
|
content={
|
||||||
|
|
|
||||||
|
|
@ -33,8 +33,7 @@ async def model_with_multi_dress(
|
||||||
raise HTTPException(status_code=404, detail=detail)
|
raise HTTPException(status_code=404, detail=detail)
|
||||||
|
|
||||||
workflow = json.loads(workflow_data["workflow_json"])
|
workflow = json.loads(workflow_data["workflow_json"])
|
||||||
flow = ComfyWorkflow(workflow)
|
flow = ComfyWorkflow(workflow, workflow_name)
|
||||||
api_spec = flow.get_api_spec()
|
|
||||||
|
|
||||||
# 将请求拆分为多个请求
|
# 将请求拆分为多个请求
|
||||||
batch_data = _convert(data)
|
batch_data = _convert(data)
|
||||||
|
|
@ -42,12 +41,7 @@ async def model_with_multi_dress(
|
||||||
# 提交到队列
|
# 提交到队列
|
||||||
workflow_run_ids: list[str] = []
|
workflow_run_ids: list[str] = []
|
||||||
for item in batch_data:
|
for item in batch_data:
|
||||||
workflow_run_id = await queue_manager.add_task(
|
workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=item)
|
||||||
workflow_name=workflow_name,
|
|
||||||
workflow_data=workflow,
|
|
||||||
api_spec=api_spec,
|
|
||||||
request_data=item,
|
|
||||||
)
|
|
||||||
workflow_run_ids.append(workflow_run_id)
|
workflow_run_ids.append(workflow_run_id)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue