283 lines
8.5 KiB
Python
283 lines
8.5 KiB
Python
"""
|
||
数据库操作API
|
||
提供工作流相关的所有数据库操作函数
|
||
"""
|
||
|
||
from datetime import datetime
|
||
from typing import Optional, List
|
||
import json
|
||
import uuid
|
||
|
||
from sqlalchemy.future import select
|
||
from sqlalchemy import delete
|
||
|
||
from .models import Base, Workflow, WorkflowRun, WorkflowRunNode, ComfyUIServer
|
||
from .connection import async_engine, AsyncSessionLocal
|
||
|
||
|
||
async def init_db():
    """Create every table declared on ``Base.metadata``.

    ``Base.metadata.create_all`` is a synchronous call, so it is executed on
    the async connection through ``run_sync`` inside a transaction opened by
    ``async_engine.begin()``. ``create_all`` is called with its defaults
    (``checkfirst=True`` per the SQLAlchemy docs), so tables that already
    exist are left untouched.
    """
    async with async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    # Report success only once the transaction block has exited (committed).
    # Plain string literal: the original f-string had no placeholders (F541).
    print("数据库表结构已创建完成。")
|
||
|
||
|
||
async def save_workflow(name: str, workflow_json: str):
    """Persist a new version of a workflow.

    The version tag is the current local time formatted ``YYYYMMDDHHMMSS``.
    The row is stored with ``name="<name> [<version>]"``, ``base_name=name``
    and ``version=<version>``.

    The row is written with ``session.merge`` rather than ``session.add``, so
    an existing row with the same primary key is updated instead of causing a
    duplicate insert (INSERT OR REPLACE semantics).

    NOTE(review): two saves of the same ``name`` within one second yield the
    same version string and stored name; whether the second overwrites the
    first depends on the model's primary key — confirm in ``.models``.
    """
    version = datetime.now().strftime("%Y%m%d%H%M%S")
    fields = dict(
        name=f"{name} [{version}]",
        base_name=name,
        version=version,
        workflow_json=workflow_json,
    )
    record = Workflow(**fields)

    async with AsyncSessionLocal() as session:
        # merge() gives INSERT OR REPLACE behaviour keyed on the primary key.
        await session.merge(record)
        await session.commit()
|
||
|
||
|
||
async def get_all_workflows() -> List[dict]:
    """Return the newest version of every workflow, ordered by base name.

    A grouped subquery finds ``max(version)`` for each ``base_name``. It is
    joined back to the workflow table to fetch the full row. Versions are
    fixed-width ``YYYYMMDDHHMMSS`` strings (as written by ``save_workflow``),
    so the string maximum is the most recent version.

    Returns:
        A list of ``{"name": <stored versioned name>, "workflow": <parsed
        JSON>}`` dicts. ``json.loads`` is applied to each row's
        ``workflow_json``, so a malformed row raises ``json.JSONDecodeError``.
    """
    async with AsyncSessionLocal() as session:
        # Function-scope import, as in the rest of this module's history.
        from sqlalchemy import func

        # base_name -> newest version string
        newest = (
            select(Workflow.base_name, func.max(Workflow.version).label("max_version"))
            .group_by(Workflow.base_name)
            .subquery()
        )
        on_newest = (Workflow.base_name == newest.c.base_name) & (
            Workflow.version == newest.c.max_version
        )
        query = select(Workflow).join(newest, on_newest).order_by(Workflow.base_name)

        rows = (await session.execute(query)).scalars().all()

        summaries = []
        for row in rows:
            summaries.append({"name": row.name, "workflow": json.loads(row.workflow_json)})
        return summaries
|
||
|
||
|
||
async def get_latest_workflow_by_base_name(base_name: str) -> Optional[dict]:
    """Return the newest version of ``base_name`` as a dict, or None if absent.

    "Newest" means the row with the greatest ``version`` string (descending
    sort, first row). The row is serialised with ``Workflow.to_dict()``.
    """
    query = (
        select(Workflow)
        .where(Workflow.base_name == base_name)
        .order_by(Workflow.version.desc())
        .limit(1)
    )
    async with AsyncSessionLocal() as session:
        newest = (await session.execute(query)).scalar_one_or_none()
        if not newest:
            return None
        return newest.to_dict()
|
||
|
||
|
||
async def get_workflow_by_version(base_name: str, version: str) -> Optional[dict]:
    """Return one specific workflow version as a dict, or None if missing.

    The row is matched on its full stored name, ``"<base_name> [<version>]"``,
    which is the format ``save_workflow`` writes. If that name matched more
    than one row, ``scalar_one_or_none`` would raise ``MultipleResultsFound``.
    """
    full_name = f"{base_name} [{version}]"
    async with AsyncSessionLocal() as session:
        found = (
            await session.execute(select(Workflow).where(Workflow.name == full_name))
        ).scalar_one_or_none()
        if not found:
            return None
        return found.to_dict()
|
||
|
||
|
||
async def get_workflow(name: str, version: Optional[str] = None) -> Optional[dict]:
    """Fetch a workflow by base name, optionally pinned to a specific version.

    Args:
        name: The workflow's base name.
        version: Version tag to fetch. When falsy (``None`` or ``""``), the
            newest version of ``name`` is returned instead.

    Returns:
        The workflow as a dict, or None when no matching row exists.
    """
    if not version:
        return await get_latest_workflow_by_base_name(name)
    return await get_workflow_by_version(name, version)
|
||
|
||
|
||
async def delete_workflow(name: str) -> bool:
    """Delete every workflow row whose stored ``name`` equals ``name``.

    The match is on the full stored name. ``save_workflow`` stores
    ``"<base> [<version>]"``, so passing a bare base name will not match rows
    written by it.

    Returns:
        True if at least one row was deleted, otherwise False.
    """
    async with AsyncSessionLocal() as session:
        outcome = await session.execute(delete(Workflow).where(Workflow.name == name))
        await session.commit()
        deleted_any = outcome.rowcount > 0
        return deleted_any
|
||
|
||
|
||
async def create_workflow_run(
    workflow_run_id: str,
    workflow_name: str,
    workflow_json: str,
    api_spec: str,
    request_data: str,
) -> str:
    """Insert a new ``WorkflowRun`` row and commit it.

    Only the columns passed here are set explicitly; any other columns are
    left to the model's defaults.

    Returns:
        ``workflow_run_id``, unchanged, for caller convenience.
    """
    columns = {
        "id": workflow_run_id,
        "workflow_name": workflow_name,
        "workflow_json": workflow_json,
        "api_spec": api_spec,
        "request_data": request_data,
    }
    run = WorkflowRun(**columns)

    async with AsyncSessionLocal() as session:
        session.add(run)
        await session.commit()

    return workflow_run_id
|
||
|
||
|
||
async def update_workflow_run_status(
    workflow_run_id: str,
    status: str,
    server_url: Optional[str] = None,
    prompt_id: Optional[str] = None,
    client_id: Optional[str] = None,
    error_message: Optional[str] = None,
    result: Optional[str] = None,
):
    """Set a workflow run's status and record the fields tied to that status.

    The run is loaded by primary key and its ``status`` is always updated.
    Additional columns depend on the new status:

    * ``"running"``: ``server_url``, ``prompt_id`` and ``client_id`` are
      written (overwriting existing values, even with None), and
      ``started_at`` is set to now.
    * ``"completed"``: ``completed_at`` is set to now and ``result`` is stored.
    * ``"failed"``: ``error_message`` is stored and ``completed_at`` is set.
    * any other status: only ``status`` changes.

    Timestamps are naive UTC datetimes.

    Returns:
        None. If no run with ``workflow_run_id`` exists, nothing is written.
    """
    # Function-scope import keeps this block self-contained.
    from datetime import timezone

    async with AsyncSessionLocal() as session:
        workflow_run = await session.get(WorkflowRun, workflow_run_id)
        if workflow_run is None:
            return

        # Value-identical replacement for the deprecated datetime.utcnow()
        # (deprecated since Python 3.12): naive datetime in UTC.
        now = datetime.now(timezone.utc).replace(tzinfo=None)

        workflow_run.status = status

        if status == "running":
            workflow_run.server_url = server_url
            workflow_run.prompt_id = prompt_id
            workflow_run.client_id = client_id
            workflow_run.started_at = now
        elif status == "completed":
            workflow_run.completed_at = now
            workflow_run.result = result
        elif status == "failed":
            workflow_run.error_message = error_message
            workflow_run.completed_at = now

        await session.commit()
|
||
|
||
|
||
async def create_workflow_run_nodes(workflow_run_id: str, nodes_data: List[dict]):
    """Insert one ``WorkflowRunNode`` row per entry of ``nodes_data``.

    Each entry must provide ``"id"`` (stored as ``node_id``) and ``"type"``
    (stored as ``node_type``); a missing key raises ``KeyError`` before any
    database work starts. Every row gets a fresh ``uuid4`` string as its
    ``id``, and all rows are committed together in one session.
    """
    rows = [
        WorkflowRunNode(
            id=str(uuid.uuid4()),
            workflow_run_id=workflow_run_id,
            node_id=entry["id"],
            node_type=entry["type"],
        )
        for entry in nodes_data
    ]

    async with AsyncSessionLocal() as session:
        session.add_all(rows)
        await session.commit()
|
||
|
||
|
||
async def update_workflow_run_node_status(
    workflow_run_id: str,
    node_id: str,
    status: str,
    output_data: Optional[str] = None,
    error_message: Optional[str] = None,
):
    """Set a run node's status and record the fields tied to that status.

    The node is located by ``(workflow_run_id, node_id)``. Its ``status`` is
    always updated, and additionally:

    * ``"running"``: ``started_at`` is set to now.
    * ``"completed"``: ``output_data`` is stored and ``completed_at`` is set.
    * ``"failed"``: ``error_message`` is stored and ``completed_at`` is set.
    * any other status: only ``status`` changes.

    Timestamps are naive UTC datetimes.

    Returns:
        None. If no matching node exists, nothing is written.

    Raises:
        sqlalchemy.exc.MultipleResultsFound: if more than one node row matches
            the ``(workflow_run_id, node_id)`` pair.
    """
    # Function-scope import keeps this block self-contained.
    from datetime import timezone

    async with AsyncSessionLocal() as session:
        stmt = select(WorkflowRunNode).where(
            WorkflowRunNode.workflow_run_id == workflow_run_id,
            WorkflowRunNode.node_id == node_id,
        )
        result = await session.execute(stmt)
        node = result.scalar_one_or_none()
        if node is None:
            return

        # Value-identical replacement for the deprecated datetime.utcnow()
        # (deprecated since Python 3.12): naive datetime in UTC.
        now = datetime.now(timezone.utc).replace(tzinfo=None)

        node.status = status

        if status == "running":
            node.started_at = now
        elif status == "completed":
            node.output_data = output_data
            node.completed_at = now
        elif status == "failed":
            node.error_message = error_message
            node.completed_at = now

        await session.commit()
|
||
|
||
|
||
async def get_workflow_run(workflow_run_id: str) -> Optional[WorkflowRun]:
    """Load a ``WorkflowRun`` by primary key.

    Unlike the other getters in this module, this returns the ORM instance
    itself rather than a ``to_dict()`` result, or None if no such run exists.

    NOTE(review): the instance is handed back after its session has closed,
    so it is detached — attributes not already loaded (e.g. lazy
    relationships, if the model has any) cannot be fetched through it.
    """
    async with AsyncSessionLocal() as session:
        return await session.get(WorkflowRun, workflow_run_id)
|
||
|
||
|
||
async def get_workflow_run_nodes(workflow_run_id: str) -> List[dict]:
    """Return all node records of one workflow run, each as ``to_dict()``.

    No ORDER BY is applied, so rows come back in whatever order the
    database returns them.
    """
    query = select(WorkflowRunNode).where(
        WorkflowRunNode.workflow_run_id == workflow_run_id
    )
    async with AsyncSessionLocal() as session:
        fetched = await session.execute(query)
        return [record.to_dict() for record in fetched.scalars().all()]
|
||
|
||
|
||
async def get_pending_workflow_runs() -> List[dict]:
    """Return every workflow run whose status is ``"pending"``, oldest first.

    Rows are ordered by ``created_at`` ascending and serialised via
    ``to_dict()``.
    """
    pending_query = (
        select(WorkflowRun)
        .where(WorkflowRun.status == "pending")
        .order_by(WorkflowRun.created_at.asc())
    )
    async with AsyncSessionLocal() as session:
        pending = (await session.execute(pending_query)).scalars().all()
        return [run.to_dict() for run in pending]
|
||
|
||
|
||
async def get_running_workflow_runs() -> List[dict]:
    """Return every workflow run whose status is ``"running"``, oldest first.

    Rows are ordered by ``created_at`` ascending, matching
    ``get_pending_workflow_runs``, and serialised via ``to_dict()``.
    """
    async with AsyncSessionLocal() as session:
        stmt = (
            select(WorkflowRun)
            .where(WorkflowRun.status == "running")
            # Without an ORDER BY, SQL leaves the row order unspecified; sort
            # so callers get a deterministic, oldest-first list.
            .order_by(WorkflowRun.created_at.asc())
        )
        result = await session.execute(stmt)
        runs = result.scalars().all()
        return [run.to_dict() for run in runs]
|
||
|
||
|
||
async def get_workflow_runs_recent(
    start_time: datetime, end_time: datetime
) -> List[dict]:
    """Return workflow runs created within ``[start_time, end_time]``, newest first.

    Both bounds are inclusive (SQL ``BETWEEN``). Each run is serialised via
    ``to_dict()``.

    NOTE(review): ``created_at`` is never set in this module (presumably a
    model default), while ``started_at``/``completed_at`` here are naive UTC;
    confirm callers pass bounds on the same time basis as ``created_at``.
    """
    in_window = WorkflowRun.created_at.between(start_time, end_time)
    async with AsyncSessionLocal() as session:
        found = await session.execute(
            select(WorkflowRun).where(in_window).order_by(WorkflowRun.created_at.desc())
        )
        return [run.to_dict() for run in found.scalars().all()]
|