"""Database operations API.

Provides the database operation functions related to workflows: saved
workflow definitions, workflow runs, and per-node run records.
"""

from datetime import datetime, timezone
from typing import Optional, List
import json
import uuid

from sqlalchemy.future import select
from sqlalchemy import delete, func

from .models import Base, Workflow, WorkflowRun, WorkflowRunNode, ComfyUIServer
from .connection import async_engine, AsyncSessionLocal


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value the deprecated ``datetime.utcnow()`` did
    (deprecated since Python 3.12), so timestamps written by this module
    keep their previous naive-UTC form.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


async def init_db():
    """Create the tables declared on ``Base.metadata``."""
    async with async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    print("数据库表结构已创建完成。")


async def save_workflow(name: str, workflow_json: str):
    """Save a new version of a workflow.

    Args:
        name: Base name of the workflow.
        workflow_json: Serialized workflow JSON text, stored as-is.

    The version is the current local time formatted ``YYYYmmddHHMMSS``;
    the stored ``name`` is ``"<name> [<version>]"``. ``session.merge`` is
    used to get INSERT-OR-REPLACE behavior.
    NOTE(review): presumably keyed on ``name``; if so, two saves of the
    same base name within the same second overwrite each other — confirm
    the primary key in ``models``.
    """
    version = datetime.now().strftime("%Y%m%d%H%M%S")
    workflow = Workflow(
        name=f"{name} [{version}]",
        base_name=name,
        version=version,
        workflow_json=workflow_json,
    )
    async with AsyncSessionLocal() as session:
        await session.merge(workflow)  # merge gives INSERT OR REPLACE semantics
        await session.commit()


async def get_all_workflows() -> List[dict]:
    """Return the latest version of every workflow, ordered by base name.

    Returns:
        A list of ``{"name": <versioned name>, "workflow": <parsed JSON>}``.
    """
    async with AsyncSessionLocal() as session:
        # Subquery: the highest version string for each base_name.
        latest_versions = (
            select(Workflow.base_name, func.max(Workflow.version).label("max_version"))
            .group_by(Workflow.base_name)
            .subquery()
        )
        # Main query: join back to fetch the full row of each latest version.
        stmt = (
            select(Workflow)
            .join(
                latest_versions,
                (Workflow.base_name == latest_versions.c.base_name)
                & (Workflow.version == latest_versions.c.max_version),
            )
            .order_by(Workflow.base_name)
        )
        result = await session.execute(stmt)
        workflows = result.scalars().all()
        return [
            {"name": wf.name, "workflow": json.loads(wf.workflow_json)}
            for wf in workflows
        ]


async def get_latest_workflow_by_base_name(base_name: str) -> Optional[dict]:
    """Return the newest version of the workflow ``base_name``, or ``None``."""
    async with AsyncSessionLocal() as session:
        stmt = (
            select(Workflow)
            .where(Workflow.base_name == base_name)
            .order_by(Workflow.version.desc())
            .limit(1)
        )
        result = await session.execute(stmt)
        workflow = result.scalar_one_or_none()
        return workflow.to_dict() if workflow else None


async def get_workflow_by_version(base_name: str, version: str) -> Optional[dict]:
    """Return one specific workflow version, or ``None`` if it does not exist.

    The lookup uses the same ``"<base_name> [<version>]"`` naming scheme
    that ``save_workflow`` writes.
    """
    name = f"{base_name} [{version}]"
    async with AsyncSessionLocal() as session:
        stmt = select(Workflow).where(Workflow.name == name)
        result = await session.execute(stmt)
        workflow = result.scalar_one_or_none()
        return workflow.to_dict() if workflow else None


async def get_workflow(name: str, version: Optional[str] = None) -> Optional[dict]:
    """Return a workflow by base name.

    Args:
        name: Base name of the workflow.
        version: Specific version to fetch; when falsy the latest version
            is returned.
    """
    if version:
        return await get_workflow_by_version(name, version)
    return await get_latest_workflow_by_base_name(name)


async def delete_workflow(name: str) -> bool:
    """Delete workflow rows whose stored ``name`` equals ``name``.

    Note that the stored name is the versioned form ``"<base> [<version>]"``.

    Returns:
        ``True`` if at least one row was deleted.
    """
    async with AsyncSessionLocal() as session:
        stmt = delete(Workflow).where(Workflow.name == name)
        result = await session.execute(stmt)
        await session.commit()
        return result.rowcount > 0


async def create_workflow_run(
    workflow_run_id: str,
    workflow_name: str,
    workflow_json: str,
    api_spec: str,
    request_data: str,
) -> str:
    """Insert a new workflow run record and return its id."""
    workflow_run = WorkflowRun(
        id=workflow_run_id,
        workflow_name=workflow_name,
        workflow_json=workflow_json,
        api_spec=api_spec,
        request_data=request_data,
    )
    async with AsyncSessionLocal() as session:
        session.add(workflow_run)
        await session.commit()
        return workflow_run_id


async def update_workflow_run_status(
    workflow_run_id: str,
    status: str,
    server_url: Optional[str] = None,
    prompt_id: Optional[str] = None,
    client_id: Optional[str] = None,
    error_message: Optional[str] = None,
    result: Optional[str] = None,
):
    """Update the status of a workflow run; silently no-op if it does not exist.

    Depending on ``status``:
        - ``"running"``: records ``server_url``, ``prompt_id``, ``client_id``
          and ``started_at``.
        - ``"completed"``: records ``completed_at`` and ``result``.
        - ``"failed"``: records ``error_message`` and ``completed_at``.
    Any other status only updates the ``status`` column.
    Timestamps are naive UTC.
    """
    async with AsyncSessionLocal() as session:
        workflow_run = await session.get(WorkflowRun, workflow_run_id)
        if not workflow_run:
            return

        workflow_run.status = status
        if status == "running":
            workflow_run.server_url = server_url
            workflow_run.prompt_id = prompt_id
            workflow_run.client_id = client_id
            workflow_run.started_at = _utcnow()
        elif status == "completed":
            workflow_run.completed_at = _utcnow()
            workflow_run.result = result
        elif status == "failed":
            workflow_run.error_message = error_message
            workflow_run.completed_at = _utcnow()

        await session.commit()


async def create_workflow_run_nodes(workflow_run_id: str, nodes_data: List[dict]):
    """Insert one node record per entry of ``nodes_data`` for a workflow run.

    Args:
        workflow_run_id: Id of the owning workflow run.
        nodes_data: Dicts providing at least ``"id"`` and ``"type"`` keys.
    """
    nodes = [
        WorkflowRunNode(
            id=str(uuid.uuid4()),
            workflow_run_id=workflow_run_id,
            node_id=node["id"],
            node_type=node["type"],
        )
        for node in nodes_data
    ]
    async with AsyncSessionLocal() as session:
        session.add_all(nodes)
        await session.commit()


async def update_workflow_run_node_status(
    workflow_run_id: str,
    node_id: str,
    status: str,
    output_data: Optional[str] = None,
    error_message: Optional[str] = None,
):
    """Update the status of one node in a workflow run; no-op if not found.

    Depending on ``status``:
        - ``"running"``: records ``started_at``.
        - ``"completed"``: records ``output_data`` and ``completed_at``.
        - ``"failed"``: records ``error_message`` and ``completed_at``.
    Timestamps are naive UTC.
    """
    async with AsyncSessionLocal() as session:
        stmt = select(WorkflowRunNode).where(
            WorkflowRunNode.workflow_run_id == workflow_run_id,
            WorkflowRunNode.node_id == node_id,
        )
        result = await session.execute(stmt)
        node = result.scalar_one_or_none()
        if not node:
            return

        node.status = status
        if status == "running":
            node.started_at = _utcnow()
        elif status == "completed":
            node.output_data = output_data
            node.completed_at = _utcnow()
        elif status == "failed":
            node.error_message = error_message
            node.completed_at = _utcnow()

        await session.commit()


async def get_workflow_run(workflow_run_id: str) -> Optional[dict]:
    """Return a workflow run as a dict, or ``None`` if it does not exist."""
    async with AsyncSessionLocal() as session:
        workflow_run = await session.get(WorkflowRun, workflow_run_id)
        return workflow_run.to_dict() if workflow_run else None


async def get_workflow_run_nodes(workflow_run_id: str) -> List[dict]:
    """Return all node records belonging to a workflow run."""
    async with AsyncSessionLocal() as session:
        stmt = select(WorkflowRunNode).where(
            WorkflowRunNode.workflow_run_id == workflow_run_id
        )
        result = await session.execute(stmt)
        nodes = result.scalars().all()
        return [node.to_dict() for node in nodes]


async def get_pending_workflow_runs() -> List[dict]:
    """Return all runs with status ``"pending"``, oldest ``created_at`` first."""
    async with AsyncSessionLocal() as session:
        stmt = (
            select(WorkflowRun)
            .where(WorkflowRun.status == "pending")
            .order_by(WorkflowRun.created_at.asc())
        )
        result = await session.execute(stmt)
        runs = result.scalars().all()
        return [run.to_dict() for run in runs]


async def get_running_workflow_runs() -> List[dict]:
    """Return all runs with status ``"running"`` (no particular order)."""
    async with AsyncSessionLocal() as session:
        stmt = select(WorkflowRun).where(WorkflowRun.status == "running")
        result = await session.execute(stmt)
        runs = result.scalars().all()
        return [run.to_dict() for run in runs]


async def get_workflow_runs_recent(
    start_time: datetime, end_time: datetime
) -> List[dict]:
    """Return runs created within ``[start_time, end_time]``, newest first.

    Both bounds are inclusive. NOTE(review): the bounds should match the
    naive/aware form of the stored ``created_at`` values — confirm in models.
    """
    async with AsyncSessionLocal() as session:
        stmt = (
            select(WorkflowRun)
            .where(
                WorkflowRun.created_at >= start_time,
                WorkflowRun.created_at <= end_time,
            )
            .order_by(WorkflowRun.created_at.desc())
        )
        result = await session.execute(stmt)
        runs = result.scalars().all()
        return [run.to_dict() for run in runs]