ComfyUI-WorkflowPublisher/app/database/api.py

283 lines
8.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库操作API
提供工作流相关的所有数据库操作函数
"""
from datetime import datetime
from typing import Optional, List
import json
import uuid
from sqlalchemy.future import select
from sqlalchemy import delete
from .models import Base, Workflow, WorkflowRun, WorkflowRunNode, ComfyUIServer
from .connection import async_engine, AsyncSessionLocal
async def init_db():
    """Create every table declared on ``Base.metadata``.

    ``create_all`` runs inside a transaction opened on the async engine, and a
    confirmation message is printed afterwards.
    """
    async with async_engine.begin() as connection:
        # create_all is a synchronous API; run_sync executes it on the async connection.
        await connection.run_sync(Base.metadata.create_all)
    print("数据库表结构已创建完成。")
async def save_workflow(name: str, workflow_json: str):
    """Persist a new, timestamped version of a workflow.

    The version label is the current local time formatted as
    ``YYYYMMDDHHMMSS``. The row's ``name`` is ``"<name> [<version>]"``, and
    the plain ``name`` is stored as ``base_name``.

    Args:
        name: Base (unversioned) workflow name.
        workflow_json: Serialized workflow definition, stored verbatim.
    """
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    record = Workflow(
        name=f"{name} [{stamp}]",
        base_name=name,
        version=stamp,
        workflow_json=workflow_json,
    )
    async with AsyncSessionLocal() as session:
        # merge() instead of add(): a row with the same identity is overwritten
        # rather than causing a conflict (INSERT OR REPLACE behaviour).
        # NOTE(review): presumably the identity key is ``name``, so two saves of
        # the same workflow within one second replace each other — confirm in models.
        await session.merge(record)
        await session.commit()
async def get_all_workflows() -> List[dict]:
    """Return the newest version of every workflow, ordered by base name.

    Returns:
        One ``{"name": <versioned name>, "workflow": <parsed JSON>}`` dict per
        distinct ``base_name``.
    """
    from sqlalchemy import func

    # Greatest version label per base_name. Versions written by save_workflow
    # are fixed-width YYYYMMDDHHMMSS strings, so the string max is the newest.
    newest = (
        select(
            Workflow.base_name,
            func.max(Workflow.version).label("max_version"),
        )
        .group_by(Workflow.base_name)
        .subquery()
    )
    # Join back to Workflow to fetch the full row for each (base_name, max_version) pair.
    is_newest = (Workflow.base_name == newest.c.base_name) & (
        Workflow.version == newest.c.max_version
    )
    query = select(Workflow).join(newest, is_newest).order_by(Workflow.base_name)

    async with AsyncSessionLocal() as session:
        rows = (await session.execute(query)).scalars().all()
        return [
            {"name": row.name, "workflow": json.loads(row.workflow_json)}
            for row in rows
        ]
async def get_latest_workflow_by_base_name(base_name: str) -> Optional[dict]:
    """Fetch the most recent version of the workflow called *base_name*.

    Returns:
        The row's ``to_dict()`` output, or ``None`` when no version exists.
    """
    newest_first = (
        select(Workflow)
        .where(Workflow.base_name == base_name)
        .order_by(Workflow.version.desc())
        .limit(1)
    )
    async with AsyncSessionLocal() as session:
        found = (await session.execute(newest_first)).scalar_one_or_none()
        if not found:
            return None
        return found.to_dict()
async def get_workflow_by_version(base_name: str, version: str) -> Optional[dict]:
    """Fetch one specific version of a workflow.

    The lookup matches the composed ``name`` column, ``"<base_name> [<version>]"``
    (the format save_workflow writes). It does not use the separate
    base_name/version columns.

    Returns:
        The row's ``to_dict()`` output, or ``None`` if that version is absent.
    """
    versioned_name = f"{base_name} [{version}]"
    query = select(Workflow).where(Workflow.name == versioned_name)
    async with AsyncSessionLocal() as session:
        match = (await session.execute(query)).scalar_one_or_none()
        if not match:
            return None
        return match.to_dict()
async def get_workflow(name: str, version: Optional[str] = None) -> Optional[dict]:
    """Fetch a workflow by base name, optionally pinned to a version.

    Args:
        name: Base (unversioned) workflow name.
        version: Exact version label. When falsy (``None`` or ``""``), the
            newest version is returned instead.

    Returns:
        The workflow as a dict, or ``None`` if nothing matches.
    """
    if not version:
        return await get_latest_workflow_by_base_name(name)
    return await get_workflow_by_version(name, version)
async def delete_workflow(name: str) -> bool:
    """Delete every workflow row whose ``name`` column equals *name*.

    For rows written by save_workflow, ``name`` is ``"<base> [<version>]"``.
    Passing a bare base name therefore matches only a row literally named that.

    Returns:
        ``True`` if at least one row was deleted, otherwise ``False``.
    """
    removal = delete(Workflow).where(Workflow.name == name)
    async with AsyncSessionLocal() as session:
        outcome = await session.execute(removal)
        await session.commit()
        return outcome.rowcount > 0
async def create_workflow_run(
    workflow_run_id: str,
    workflow_name: str,
    workflow_json: str,
    api_spec: str,
    request_data: str,
) -> str:
    """Insert a new WorkflowRun row and return its id.

    Every argument is stored as given. Columns not passed here (status,
    timestamps, ...) get whatever WorkflowRun defines for them — see models.

    Returns:
        The ``workflow_run_id`` that was passed in.
    """
    run = WorkflowRun(
        id=workflow_run_id,
        workflow_name=workflow_name,
        workflow_json=workflow_json,
        api_spec=api_spec,
        request_data=request_data,
    )
    async with AsyncSessionLocal() as session:
        session.add(run)
        await session.commit()
    return workflow_run_id
async def update_workflow_run_status(
    workflow_run_id: str,
    status: str,
    server_url: Optional[str] = None,
    prompt_id: Optional[str] = None,
    client_id: Optional[str] = None,
    error_message: Optional[str] = None,
    result: Optional[str] = None,
):
    """Update a workflow run's status and the fields tied to that status.

    ``status`` is always written. Depending on its value, extra columns are set:

    * ``"running"``: ``server_url``, ``prompt_id`` and ``client_id`` (written
      even when they are ``None``), plus ``started_at``.
    * ``"completed"``: ``completed_at`` and ``result``.
    * ``"failed"``: ``error_message`` and ``completed_at``.
    * any other value: only ``status`` changes.

    Timestamps come from ``datetime.utcnow()``, i.e. naive UTC datetimes.

    Args:
        workflow_run_id: Primary key of the run to update.
        status: New status value.
        server_url: Recorded only when ``status == "running"``.
        prompt_id: Recorded only when ``status == "running"``.
        client_id: Recorded only when ``status == "running"``.
        error_message: Recorded only when ``status == "failed"``.
        result: Recorded only when ``status == "completed"``.

    Returns:
        None. If no run has ``workflow_run_id``, the call is a silent no-op.
    """
    async with AsyncSessionLocal() as session:
        workflow_run = await session.get(WorkflowRun, workflow_run_id)
        if not workflow_run:
            # Unknown run id: nothing to update, and no error is raised.
            return

        workflow_run.status = status
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12; moving to
        # aware datetimes must be coordinated with how these columns are stored/compared.
        if status == "running":
            workflow_run.server_url = server_url
            workflow_run.prompt_id = prompt_id
            workflow_run.client_id = client_id
            workflow_run.started_at = datetime.utcnow()
        elif status == "completed":
            workflow_run.completed_at = datetime.utcnow()
            workflow_run.result = result
        elif status == "failed":
            workflow_run.error_message = error_message
            workflow_run.completed_at = datetime.utcnow()
        await session.commit()
async def create_workflow_run_nodes(workflow_run_id: str, nodes_data: List[dict]):
    """Insert one WorkflowRunNode row per entry in *nodes_data*.

    Each row gets a fresh random UUID4 string as its own ``id``.

    Args:
        workflow_run_id: Id of the owning workflow run.
        nodes_data: Node descriptors. Each must provide ``"id"`` and
            ``"type"`` keys. A missing key raises ``KeyError`` while the rows
            are being built, before any session is opened.
    """
    rows = [
        WorkflowRunNode(
            id=str(uuid.uuid4()),
            workflow_run_id=workflow_run_id,
            node_id=entry["id"],
            node_type=entry["type"],
        )
        for entry in nodes_data
    ]
    async with AsyncSessionLocal() as session:
        session.add_all(rows)
        await session.commit()
async def update_workflow_run_node_status(
    workflow_run_id: str,
    node_id: str,
    status: str,
    output_data: Optional[str] = None,
    error_message: Optional[str] = None,
):
    """Update one node of a workflow run and the fields tied to its new status.

    The node is found by the ``(workflow_run_id, node_id)`` pair. ``status`` is
    always written. Depending on its value, extra columns are set:

    * ``"running"``: ``started_at``.
    * ``"completed"``: ``output_data`` and ``completed_at``.
    * ``"failed"``: ``error_message`` and ``completed_at``.
    * any other value: only ``status`` changes.

    Timestamps come from ``datetime.utcnow()``, i.e. naive UTC datetimes.

    Args:
        workflow_run_id: Id of the owning workflow run.
        node_id: Node identifier within that run.
        status: New status value.
        output_data: Recorded only when ``status == "completed"``.
        error_message: Recorded only when ``status == "failed"``.

    Returns:
        None. If no matching node exists, the call is a silent no-op.
    """
    async with AsyncSessionLocal() as session:
        stmt = select(WorkflowRunNode).where(
            WorkflowRunNode.workflow_run_id == workflow_run_id,
            WorkflowRunNode.node_id == node_id,
        )
        result = await session.execute(stmt)
        # NOTE(review): scalar_one_or_none() raises if duplicate rows exist for this
        # (run, node) pair, e.g. if create_workflow_run_nodes ran twice for one run.
        node = result.scalar_one_or_none()
        if not node:
            # Unknown node: nothing to update, and no error is raised.
            return

        node.status = status
        if status == "running":
            node.started_at = datetime.utcnow()
        elif status == "completed":
            node.output_data = output_data
            node.completed_at = datetime.utcnow()
        elif status == "failed":
            node.error_message = error_message
            node.completed_at = datetime.utcnow()
        await session.commit()
async def get_workflow_run(workflow_run_id: str) -> Optional[WorkflowRun]:
    """Load a workflow run by primary key.

    Unlike the other getters in this module, this returns the ORM object
    itself rather than a ``to_dict()`` result, or ``None`` if the id is unknown.

    NOTE(review): the object outlives the session that loaded it. Callers
    should presumably touch only column attributes, since lazy-loaded
    relationships would need a live session — confirm against callers.
    """
    async with AsyncSessionLocal() as session:
        return await session.get(WorkflowRun, workflow_run_id)
async def get_workflow_run_nodes(workflow_run_id: str) -> List[dict]:
    """Return every node row belonging to *workflow_run_id* as a dict.

    No ORDER BY is applied, so the row order is whatever the database returns.
    """
    belongs_to_run = WorkflowRunNode.workflow_run_id == workflow_run_id
    query = select(WorkflowRunNode).where(belongs_to_run)
    async with AsyncSessionLocal() as session:
        found = await session.execute(query)
        return [item.to_dict() for item in found.scalars().all()]
async def get_pending_workflow_runs() -> List[dict]:
    """Return all runs with status ``"pending"``, oldest first, as dicts."""
    oldest_first = (
        select(WorkflowRun)
        .where(WorkflowRun.status == "pending")
        .order_by(WorkflowRun.created_at.asc())
    )
    async with AsyncSessionLocal() as session:
        found = await session.execute(oldest_first)
        return [item.to_dict() for item in found.scalars().all()]
async def get_running_workflow_runs() -> List[dict]:
    """Return all runs with status ``"running"`` as dicts.

    No ORDER BY is applied, so the row order is whatever the database returns.
    """
    query = select(WorkflowRun).where(WorkflowRun.status == "running")
    async with AsyncSessionLocal() as session:
        found = await session.execute(query)
        return [item.to_dict() for item in found.scalars().all()]
async def get_workflow_runs_recent(
    start_time: datetime, end_time: datetime
) -> List[dict]:
    """Return runs created within ``[start_time, end_time]``, newest first.

    Both bounds are inclusive.

    Args:
        start_time: Earliest ``created_at`` to include.
        end_time: Latest ``created_at`` to include.
    """
    # SQL BETWEEN is inclusive on both ends, matching created_at >= start AND <= end.
    in_window = WorkflowRun.created_at.between(start_time, end_time)
    newest_first = (
        select(WorkflowRun)
        .where(in_window)
        .order_by(WorkflowRun.created_at.desc())
    )
    async with AsyncSessionLocal() as session:
        found = await session.execute(newest_first)
        return [item.to_dict() for item in found.scalars().all()]