ComfyUI-WorkflowPublisher/workflow_service/database.py

253 lines
9.3 KiB
Python

from datetime import datetime
from typing import Optional
import aiosqlite
import json
import re
import uuid
from workflow_service.config import Settings
# Path of the SQLite database file, read from the service settings when this module is imported.
DATABASE_FILE = Settings().DB_FILE
async def init_db():
    """Create the service's tables if they are not present yet.

    Runs ``CREATE TABLE IF NOT EXISTS`` for ``workflows``, ``workflow_run``
    and ``workflow_run_nodes`` on a single connection, commits, and prints a
    readiness message that names the database file.
    """
    # Stored workflow definitions, keyed by their full name.
    workflows_ddl = (
        "\n"
        "CREATE TABLE IF NOT EXISTS workflows (\n"
        "name TEXT PRIMARY KEY,\n"
        "base_name TEXT NOT NULL,\n"
        "version TEXT NOT NULL,\n"
        "workflow_json TEXT NOT NULL\n"
        ")\n"
    )
    # One row per workflow execution.
    workflow_run_ddl = (
        "\n"
        "CREATE TABLE IF NOT EXISTS workflow_run (\n"
        "id TEXT PRIMARY KEY,\n"
        "workflow_name TEXT NOT NULL,\n"
        "prompt_id TEXT,\n"
        "client_id TEXT,\n"
        "status TEXT NOT NULL DEFAULT 'pending',\n"
        "server_url TEXT,\n"
        "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n"
        "started_at TIMESTAMP,\n"
        "completed_at TIMESTAMP,\n"
        "error_message TEXT,\n"
        "workflow_json TEXT NOT NULL,\n"
        "api_spec TEXT NOT NULL,\n"
        "request_data TEXT NOT NULL,\n"
        "result TEXT\n"
        ")\n"
    )
    # One row per node of a run, recording that node's execution status.
    workflow_run_nodes_ddl = (
        "\n"
        "CREATE TABLE IF NOT EXISTS workflow_run_nodes (\n"
        "id TEXT PRIMARY KEY,\n"
        "workflow_run_id TEXT NOT NULL,\n"
        "node_id TEXT NOT NULL,\n"
        "node_type TEXT NOT NULL,\n"
        "status TEXT NOT NULL DEFAULT 'pending',\n"
        "started_at TIMESTAMP,\n"
        "completed_at TIMESTAMP,\n"
        "output_data TEXT,\n"
        "error_message TEXT,\n"
        "FOREIGN KEY (workflow_run_id) REFERENCES workflow_run (id)\n"
        ")\n"
    )

    async with aiosqlite.connect(DATABASE_FILE) as conn:
        # Same creation order as before: workflows, workflow_run, workflow_run_nodes.
        for ddl in (workflows_ddl, workflow_run_ddl, workflow_run_nodes_ddl):
            await conn.execute(ddl)
        await conn.commit()
        print(f"数据库 '{DATABASE_FILE}' 已准备就绪。")
async def save_workflow(name: str, workflow_json: str):
    """Store ``workflow_json`` as a new timestamp-versioned workflow row.

    The version is the current time formatted as ``%Y%m%d%H%M%S``. The row
    key is ``"<name> [<version>]"`` and ``name`` is stored as the base name.
    ``INSERT OR REPLACE`` is used, so an existing row with the same key is
    overwritten.
    """
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    record = (f"{name} [{stamp}]", name, stamp, workflow_json)
    insert_sql = "INSERT OR REPLACE INTO workflows (name, base_name, version, workflow_json) VALUES (?, ?, ?, ?)"
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        await conn.execute(insert_sql, record)
        await conn.commit()
async def get_all_workflows() -> list[dict]:
    """Return the highest-version workflow of every base name.

    Each entry is ``{"name": <row name>, "workflow": <parsed workflow_json>}``;
    entries are ordered by base name.
    """
    # Group by base_name and join back to pick each group's MAX(version) row.
    latest_per_base_sql = (
        "\n"
        "SELECT w1.name, w1.workflow_json, w1.base_name, w1.version\n"
        "FROM workflows w1\n"
        "INNER JOIN (\n"
        "SELECT base_name, MAX(version) as max_version\n"
        "FROM workflows\n"
        "GROUP BY base_name\n"
        ") w2 ON w1.base_name = w2.base_name AND w1.version = w2.max_version\n"
        "ORDER BY w1.base_name\n"
    )
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(latest_per_base_sql)
        records = await cur.fetchall()
        return [
            {
                "name": record["name"],
                "workflow": json.loads(record["workflow_json"]),
            }
            for record in records
        ]
async def get_latest_workflow_by_base_name(base_name: str) -> dict | None:
    """Return the highest-version ``workflows`` row for ``base_name``.

    The row is returned as a plain dict of all its columns; ``None`` is
    returned when no row has that base name.
    """
    latest_sql = "SELECT * FROM workflows WHERE base_name = ? ORDER BY version DESC LIMIT 1"
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(latest_sql, (base_name,))
        record = await cur.fetchone()
        if not record:
            return None
        return dict(record)
async def get_workflow_by_version(base_name: str, version: str) -> dict | None:
    """Return the ``workflows`` row for one specific version, or ``None``.

    The lookup key is the full name ``"<base_name> [<version>]"``; a match is
    returned as a plain dict of all its columns.
    """
    full_name = f"{base_name} [{version}]"
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute("SELECT * FROM workflows WHERE name = ?", (full_name,))
        record = await cur.fetchone()
        if not record:
            return None
        return dict(record)
async def get_workflow(name: str, version: Optional[str] = None) -> dict | None:
    """Look up a workflow by base name, optionally pinned to a version.

    With a falsy ``version`` (``None`` or empty string) the latest row for
    ``name`` is returned; otherwise the row for exactly that version.
    ``None`` is returned when nothing matches.
    """
    if not version:
        return await get_latest_workflow_by_base_name(name)
    return await get_workflow_by_version(name, version)
async def delete_workflow(name: str) -> bool:
    """Delete the ``workflows`` row whose full name is ``name``.

    Returns ``True`` when at least one row was removed, ``False`` otherwise.
    """
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        cur = await conn.execute("DELETE FROM workflows WHERE name = ?", (name,))
        await conn.commit()
        removed_any = cur.rowcount > 0
        return removed_any
# Newly added: functions for the workflow_run tables
async def create_workflow_run(
    workflow_run_id: str,
    workflow_name: str,
    workflow_json: str,
    api_spec: str,
    request_data: str
) -> str:
    """Insert a new ``workflow_run`` row and return its id.

    Only the five supplied columns are written; the remaining columns take
    their table defaults. The returned value is ``workflow_run_id`` unchanged.
    """
    insert_sql = (
        "\n"
        "INSERT INTO workflow_run (id, workflow_name, workflow_json, api_spec, request_data)\n"
        "VALUES (?, ?, ?, ?, ?)\n"
    )
    values = (workflow_run_id, workflow_name, workflow_json, api_spec, request_data)
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        await conn.execute(insert_sql, values)
        await conn.commit()
        return workflow_run_id
async def update_workflow_run_status(
    workflow_run_id: str,
    status: str,
    server_url: str | None = None,
    prompt_id: str | None = None,
    client_id: str | None = None,
    error_message: str | None = None,
    result: str | None = None,
) -> None:
    """Update the status of a ``workflow_run`` row.

    Which extra columns are written depends on ``status``:

    * ``'running'``   - ``server_url``, ``prompt_id`` and ``client_id`` are
      stored as given (an omitted one is written as NULL) and ``started_at``
      is set to the current timestamp.
    * ``'completed'`` - ``result`` is stored and ``completed_at`` is set.
    * ``'failed'``    - ``error_message`` is stored and ``completed_at`` is set.
    * anything else   - only ``status`` is updated.

    Args:
        workflow_run_id: Id of the row to update.
        status: New status value.
        server_url: Used only when ``status == 'running'``.
        prompt_id: Used only when ``status == 'running'``.
        client_id: Used only when ``status == 'running'``.
        error_message: Used only when ``status == 'failed'``.
        result: Used only when ``status == 'completed'``.
    """
    # Pick the SET clause and its bound values; everything else is shared.
    if status == 'running':
        set_clause = (
            "SET status = ?, server_url = ?, prompt_id = ?, client_id = ?, "
            "started_at = CURRENT_TIMESTAMP"
        )
        values = (status, server_url, prompt_id, client_id)
    elif status == 'completed':
        set_clause = "SET status = ?, completed_at = CURRENT_TIMESTAMP, result = ?"
        values = (status, result)
    elif status == 'failed':
        set_clause = "SET status = ?, error_message = ?, completed_at = CURRENT_TIMESTAMP"
        values = (status, error_message)
    else:
        set_clause = "SET status = ?"
        values = (status,)

    # Only the fixed literals above are interpolated; all data goes through
    # bound parameters.
    sql = f"\nUPDATE workflow_run\n{set_clause}\nWHERE id = ?\n"

    async with aiosqlite.connect(DATABASE_FILE) as db:
        await db.execute(sql, (*values, workflow_run_id))
        await db.commit()
async def create_workflow_run_nodes(workflow_run_id: str, nodes_data: list) -> None:
    """Insert one ``workflow_run_nodes`` row per entry of ``nodes_data``.

    Args:
        workflow_run_id: Id of the run the nodes belong to.
        nodes_data: Sequence of mappings; each must provide ``"id"`` (stored
            as ``node_id``) and ``"type"`` (stored as ``node_type``).

    Raises:
        KeyError: If an entry lacks ``"id"`` or ``"type"``. Rows are built
            before the connection is opened, so nothing is written then.

    Each row gets a fresh random UUID as its primary key. All rows are sent
    in a single ``executemany`` call and committed together.
    """
    rows = [
        (str(uuid.uuid4()), workflow_run_id, node["id"], node["type"])
        for node in nodes_data
    ]
    insert_sql = (
        "\n"
        "INSERT INTO workflow_run_nodes (id, workflow_run_id, node_id, node_type)\n"
        "VALUES (?, ?, ?, ?)\n"
    )
    async with aiosqlite.connect(DATABASE_FILE) as db:
        await db.executemany(insert_sql, rows)
        await db.commit()
async def update_workflow_run_node_status(
    workflow_run_id: str,
    node_id: str,
    status: str,
    output_data: str | None = None,
    error_message: str | None = None,
) -> None:
    """Update the status of one node of a workflow run.

    The row is selected by ``workflow_run_id`` and ``node_id``. Which extra
    columns are written depends on ``status``:

    * ``'running'``   - ``started_at`` is set to the current timestamp.
    * ``'completed'`` - ``output_data`` is stored and ``completed_at`` is set.
    * ``'failed'``    - ``error_message`` is stored and ``completed_at`` is set.
    * anything else   - only ``status`` is updated.

    Args:
        workflow_run_id: Id of the run the node belongs to.
        node_id: The node's id within the workflow (the ``node_id`` column).
        status: New status value.
        output_data: Used only when ``status == 'completed'``.
        error_message: Used only when ``status == 'failed'``.
    """
    # Pick the SET clause and its bound values; everything else is shared.
    if status == 'running':
        set_clause = "SET status = ?, started_at = CURRENT_TIMESTAMP"
        values = (status,)
    elif status == 'completed':
        set_clause = "SET status = ?, output_data = ?, completed_at = CURRENT_TIMESTAMP"
        values = (status, output_data)
    elif status == 'failed':
        set_clause = "SET status = ?, error_message = ?, completed_at = CURRENT_TIMESTAMP"
        values = (status, error_message)
    else:
        set_clause = "SET status = ?"
        values = (status,)

    # Only the fixed literals above are interpolated; all data goes through
    # bound parameters.
    sql = (
        f"\nUPDATE workflow_run_nodes\n{set_clause}\n"
        "WHERE workflow_run_id = ? AND node_id = ?\n"
    )

    async with aiosqlite.connect(DATABASE_FILE) as db:
        await db.execute(sql, (*values, workflow_run_id, node_id))
        await db.commit()
async def get_workflow_run(workflow_run_id: str) -> dict | None:
    """Return the ``workflow_run`` row with the given id as a dict, or ``None``."""
    select_sql = "\nSELECT * FROM workflow_run WHERE id = ?\n"
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(select_sql, (workflow_run_id,))
        record = await cur.fetchone()
        if not record:
            return None
        return dict(record)
async def get_workflow_run_nodes(workflow_run_id: str) -> list[dict]:
    """Return every ``workflow_run_nodes`` row of a run as a list of dicts.

    The list is empty when the run has no node rows.
    """
    select_sql = "\nSELECT * FROM workflow_run_nodes WHERE workflow_run_id = ?\n"
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(select_sql, (workflow_run_id,))
        records = await cur.fetchall()
        return list(map(dict, records))
async def get_pending_workflow_runs() -> list[dict]:
    """Return all ``workflow_run`` rows whose status is ``'pending'``.

    Rows are returned as dicts, ordered by ``created_at`` ascending.
    """
    select_sql = "\nSELECT * FROM workflow_run WHERE status = 'pending' ORDER BY created_at ASC\n"
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(select_sql)
        records = await cur.fetchall()
        return list(map(dict, records))
async def get_running_workflow_runs() -> list[dict]:
    """Return all ``workflow_run`` rows whose status is ``'running'``, as dicts."""
    select_sql = "\nSELECT * FROM workflow_run WHERE status = 'running'\n"
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(select_sql)
        records = await cur.fetchall()
        return list(map(dict, records))