257 lines
7.7 KiB
Python
257 lines
7.7 KiB
Python
"""
|
|
数据库操作API
|
|
提供工作流相关的所有数据库操作函数
|
|
"""
|
|
|
|
from datetime import datetime
|
|
from typing import Optional, List
|
|
import json
|
|
import uuid
|
|
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy import delete
|
|
|
|
from .models import Base, Workflow, WorkflowRun, WorkflowRunNode
|
|
from .connection import async_engine, AsyncSessionLocal
|
|
|
|
|
|
async def init_db():
    """Create every table registered on ``Base.metadata``.

    Runs inside a transaction opened on the shared async engine; the
    synchronous ``create_all`` is dispatched through ``run_sync``.
    """
    async with async_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    print("数据库表结构已创建完成。")
|
|
|
|
|
|
async def save_workflow(name: str, workflow_json: str):
    """Persist a new timestamp-versioned copy of a workflow.

    The version is the local wall-clock time formatted as ``YYYYMMDDHHMMSS``.
    The stored ``name`` is ``"<name> [<version>]"``, while ``base_name`` keeps
    the caller's plain name so lookups by base name can find every version.

    Args:
        name: Plain (unversioned) workflow name.
        workflow_json: Serialized workflow definition, stored verbatim.
    """
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    record = Workflow(
        name=f"{name} [{stamp}]",
        base_name=name,
        version=stamp,
        workflow_json=workflow_json,
    )

    async with AsyncSessionLocal() as session:
        # merge() is meant to give INSERT-OR-REPLACE semantics. That only
        # holds if an existing row is matched by primary key (presumably
        # ``name`` -- NOTE(review): confirm against the Workflow model).
        await session.merge(record)
        await session.commit()
|
|
|
|
|
|
async def get_all_workflows() -> List[dict]:
    """Return the newest saved version of every workflow, ordered by base name.

    Versions are ``YYYYMMDDHHMMSS`` strings of fixed width, so lexicographic
    order matches chronological order.

    Returns:
        One ``{"name": <versioned name>, "workflow": <parsed JSON>}`` dict per
        distinct ``base_name``.
    """
    # Local import keeps this fix self-contained. aliased() gives the inner
    # query its own reference to the workflow table so it can be correlated
    # with the outer row.
    from sqlalchemy.orm import aliased

    inner = aliased(Workflow)
    # Correlated scalar subquery: for each outer row, the highest version
    # sharing its base_name. Comparing Workflow.base_name with itself would be
    # a tautology (and, with a single FROM, is not auto-correlated), which
    # selects the single globally newest version instead of one per workflow.
    latest_version = (
        select(inner.version)
        .where(inner.base_name == Workflow.base_name)
        .order_by(inner.version.desc())
        .limit(1)
        .correlate(Workflow)
        .scalar_subquery()
    )
    stmt = (
        select(Workflow)
        .where(Workflow.version == latest_version)
        .order_by(Workflow.base_name)
    )

    async with AsyncSessionLocal() as session:
        result = await session.execute(stmt)
        workflows = result.scalars().all()

    return [
        {"name": wf.name, "workflow": json.loads(wf.workflow_json)}
        for wf in workflows
    ]
|
|
|
|
|
|
async def get_latest_workflow_by_base_name(base_name: str) -> Optional[dict]:
    """Return the newest version of the workflow called ``base_name``.

    Returns:
        The row's ``to_dict()`` representation, or None if no row has that
        base name.
    """
    newest_first = (
        select(Workflow)
        .where(Workflow.base_name == base_name)
        .order_by(Workflow.version.desc())
        .limit(1)
    )
    async with AsyncSessionLocal() as session:
        found = (await session.execute(newest_first)).scalar_one_or_none()
        if not found:
            return None
        return found.to_dict()
|
|
|
|
|
|
async def get_workflow_by_version(base_name: str, version: str) -> Optional[dict]:
    """Fetch the workflow saved under ``base_name`` at an exact ``version``.

    The lookup matches the composite ``"<base_name> [<version>]"`` value in
    the ``name`` column, the same format save_workflow writes.

    Returns:
        The row's ``to_dict()`` representation, or None if not found.
    """
    composite_name = f"{base_name} [{version}]"
    query = select(Workflow).where(Workflow.name == composite_name)
    async with AsyncSessionLocal() as session:
        found = (await session.execute(query)).scalar_one_or_none()
        if not found:
            return None
        return found.to_dict()
|
|
|
|
|
|
async def get_workflow(name: str, version: Optional[str] = None) -> Optional[dict]:
    """Fetch a workflow by base name, optionally pinned to a version.

    A falsy ``version`` (None or empty string) selects the newest version.
    """
    if not version:
        return await get_latest_workflow_by_base_name(name)
    return await get_workflow_by_version(name, version)
|
|
|
|
|
|
async def delete_workflow(name: str) -> bool:
    """Delete workflow rows whose ``name`` column equals ``name`` exactly.

    The match is on the full stored name (``"<base> [<version>]"`` for rows
    written by save_workflow), not on ``base_name``.

    Returns:
        True if at least one row was deleted.
    """
    removal = delete(Workflow).where(Workflow.name == name)
    async with AsyncSessionLocal() as session:
        outcome = await session.execute(removal)
        await session.commit()
        return outcome.rowcount > 0
|
|
|
|
|
|
async def create_workflow_run(
    workflow_run_id: str,
    workflow_name: str,
    workflow_json: str,
    api_spec: str,
    request_data: str,
) -> str:
    """Insert a new workflow-run record and return its id.

    Args:
        workflow_run_id: Caller-supplied primary key for the run.
        workflow_name: Name of the workflow being run.
        workflow_json: Serialized workflow snapshot used for this run.
        api_spec: Serialized API spec, stored verbatim.
        request_data: Serialized request payload, stored verbatim.

    Returns:
        ``workflow_run_id``, unchanged.
    """
    fields = dict(
        id=workflow_run_id,
        workflow_name=workflow_name,
        workflow_json=workflow_json,
        api_spec=api_spec,
        request_data=request_data,
    )
    new_run = WorkflowRun(**fields)

    async with AsyncSessionLocal() as session:
        session.add(new_run)
        await session.commit()

    return workflow_run_id
|
|
|
|
|
|
async def update_workflow_run_status(
    workflow_run_id: str,
    status: str,
    server_url: Optional[str] = None,
    prompt_id: Optional[str] = None,
    client_id: Optional[str] = None,
    error_message: Optional[str] = None,
    result: Optional[str] = None,
):
    """Set a workflow run's status and record status-specific details.

    Unknown ``workflow_run_id`` values are ignored (no error, no write).
    Besides ``status``, the following fields are written:

    * ``"running"``: ``server_url``, ``prompt_id``, ``client_id`` and
      ``started_at``.
    * ``"completed"``: ``result`` and ``completed_at``.
    * ``"failed"``: ``error_message`` and ``completed_at``.
    * any other status: only ``status`` itself.

    Timestamps are naive UTC values from ``datetime.utcnow()``.
    """
    async with AsyncSessionLocal() as session:
        workflow_run = await session.get(WorkflowRun, workflow_run_id)
        if workflow_run is None:
            return

        workflow_run.status = status

        if status == "running":
            workflow_run.server_url = server_url
            workflow_run.prompt_id = prompt_id
            workflow_run.client_id = client_id
            workflow_run.started_at = datetime.utcnow()
        elif status == "completed":
            workflow_run.completed_at = datetime.utcnow()
            workflow_run.result = result
        elif status == "failed":
            workflow_run.error_message = error_message
            workflow_run.completed_at = datetime.utcnow()

        await session.commit()
|
|
|
|
|
|
async def create_workflow_run_nodes(workflow_run_id: str, nodes_data: List[dict]):
    """Insert one node record per entry in ``nodes_data`` for a workflow run.

    Each entry must provide ``"id"`` and ``"type"`` keys; a missing key raises
    KeyError before any database work begins. Every record gets a fresh
    random UUID4 string as its primary key.
    """
    pending_nodes = [
        WorkflowRunNode(
            id=str(uuid.uuid4()),
            workflow_run_id=workflow_run_id,
            node_id=entry["id"],
            node_type=entry["type"],
        )
        for entry in nodes_data
    ]

    async with AsyncSessionLocal() as session:
        session.add_all(pending_nodes)
        await session.commit()
|
|
|
|
|
|
async def update_workflow_run_node_status(
    workflow_run_id: str,
    node_id: str,
    status: str,
    output_data: Optional[str] = None,
    error_message: Optional[str] = None,
):
    """Set the status of one node within a workflow run.

    The node is located by ``(workflow_run_id, node_id)``; if none matches,
    the call is a no-op. Besides ``status``, the following fields are written:

    * ``"running"``: ``started_at``.
    * ``"completed"``: ``output_data`` and ``completed_at``.
    * ``"failed"``: ``error_message`` and ``completed_at``.
    * any other status: only ``status`` itself.

    Timestamps are naive UTC values from ``datetime.utcnow()``.

    Raises:
        sqlalchemy.exc.MultipleResultsFound: if more than one node row matches
            (``scalar_one_or_none`` semantics).
    """
    async with AsyncSessionLocal() as session:
        stmt = select(WorkflowRunNode).where(
            WorkflowRunNode.workflow_run_id == workflow_run_id,
            WorkflowRunNode.node_id == node_id,
        )
        result = await session.execute(stmt)
        node = result.scalar_one_or_none()
        if node is None:
            return

        node.status = status

        if status == "running":
            node.started_at = datetime.utcnow()
        elif status == "completed":
            node.output_data = output_data
            node.completed_at = datetime.utcnow()
        elif status == "failed":
            node.error_message = error_message
            node.completed_at = datetime.utcnow()

        await session.commit()
|
|
|
|
|
|
async def get_workflow_run(workflow_run_id: str) -> Optional[dict]:
    """Load a workflow run by primary key.

    Returns:
        The run's ``to_dict()`` representation, or None if it does not exist.
    """
    async with AsyncSessionLocal() as session:
        run = await session.get(WorkflowRun, workflow_run_id)
        if not run:
            return None
        return run.to_dict()
|
|
|
|
|
|
async def get_workflow_run_nodes(workflow_run_id: str) -> List[dict]:
    """Return every node record belonging to ``workflow_run_id`` as dicts.

    No ORDER BY is applied, so row order is whatever the database returns.
    """
    by_run = select(WorkflowRunNode).where(
        WorkflowRunNode.workflow_run_id == workflow_run_id
    )
    async with AsyncSessionLocal() as session:
        rows = (await session.execute(by_run)).scalars().all()
        return [row.to_dict() for row in rows]
|
|
|
|
|
|
async def get_pending_workflow_runs() -> List[dict]:
    """Return all runs with status ``"pending"``, oldest ``created_at`` first."""
    oldest_pending = (
        select(WorkflowRun)
        .where(WorkflowRun.status == "pending")
        .order_by(WorkflowRun.created_at.asc())
    )
    async with AsyncSessionLocal() as session:
        rows = (await session.execute(oldest_pending)).scalars().all()
        return [row.to_dict() for row in rows]
|
|
|
|
|
|
async def get_running_workflow_runs() -> List[dict]:
    """Return all runs whose status is ``"running"`` (unordered)."""
    running_only = select(WorkflowRun).where(WorkflowRun.status == "running")
    async with AsyncSessionLocal() as session:
        rows = (await session.execute(running_only)).scalars().all()
        return [row.to_dict() for row in rows]
|