253 lines
9.3 KiB
Python
253 lines
9.3 KiB
Python
from datetime import datetime
|
|
from typing import Optional
|
|
import aiosqlite
|
|
import json
|
|
import re
|
|
import uuid
|
|
|
|
from workflow_service.config import Settings
|
|
|
|
# Path of the SQLite database file used by every function in this module.
# Settings() is instantiated once, at module import time, to read it.
DATABASE_FILE = Settings().DB_FILE
|
|
|
|
async def init_db():
    """Create the module's tables and lookup indexes if they do not exist yet.

    Tables:
      * ``workflows``          -- saved workflow definitions, one row per version.
      * ``workflow_run``       -- one row per workflow execution.
      * ``workflow_run_nodes`` -- per-node execution status for each run.

    Every statement uses ``IF NOT EXISTS``, so calling this against an
    existing database is safe and only adds what is missing (including the
    indexes on databases created before they were introduced).
    """
    async with aiosqlite.connect(DATABASE_FILE) as db:
        await db.execute("""
            CREATE TABLE IF NOT EXISTS workflows (
                name TEXT PRIMARY KEY,
                base_name TEXT NOT NULL,
                version TEXT NOT NULL,
                workflow_json TEXT NOT NULL
            )
        """)

        # workflow_run table: one record per workflow execution.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS workflow_run (
                id TEXT PRIMARY KEY,
                workflow_name TEXT NOT NULL,
                prompt_id TEXT,
                client_id TEXT,
                status TEXT NOT NULL DEFAULT 'pending',
                server_url TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                started_at TIMESTAMP,
                completed_at TIMESTAMP,
                error_message TEXT,
                workflow_json TEXT NOT NULL,
                api_spec TEXT NOT NULL,
                request_data TEXT NOT NULL,
                result TEXT
            )
        """)

        # workflow_run_nodes table: records the run status of each node.
        # NOTE: SQLite only enforces the FOREIGN KEY below on connections that
        # run `PRAGMA foreign_keys = ON`; the connections opened in this
        # module do not, so the constraint is currently declarative only.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS workflow_run_nodes (
                id TEXT PRIMARY KEY,
                workflow_run_id TEXT NOT NULL,
                node_id TEXT NOT NULL,
                node_type TEXT NOT NULL,
                status TEXT NOT NULL DEFAULT 'pending',
                started_at TIMESTAMP,
                completed_at TIMESTAMP,
                output_data TEXT,
                error_message TEXT,
                FOREIGN KEY (workflow_run_id) REFERENCES workflow_run (id)
            )
        """)

        # Indexes backing the lookups this module performs; without them each
        # of these queries is a full table scan:
        #   * latest-version lookups filter by base_name and order by version;
        #   * node status updates filter by (workflow_run_id, node_id);
        #   * run listings filter by status and order by created_at.
        await db.execute(
            "CREATE INDEX IF NOT EXISTS idx_workflows_base_name_version "
            "ON workflows (base_name, version)"
        )
        await db.execute(
            "CREATE INDEX IF NOT EXISTS idx_workflow_run_nodes_run_node "
            "ON workflow_run_nodes (workflow_run_id, node_id)"
        )
        await db.execute(
            "CREATE INDEX IF NOT EXISTS idx_workflow_run_status_created "
            "ON workflow_run (status, created_at)"
        )

        await db.commit()
        print(f"数据库 '{DATABASE_FILE}' 已准备就绪。")
|
|
|
|
async def save_workflow(name: str, workflow_json: str):
    """Persist a new version of the workflow ``name``.

    The version is the current local time rendered as ``YYYYMMDDHHMMSS``; the
    row's primary key is ``"<name> [<version>]"``. Because the insert is an
    ``INSERT OR REPLACE``, saving the same name twice within one second
    overwrites the earlier row.
    """
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    record = (f"{name} [{stamp}]", name, stamp, workflow_json)

    async with aiosqlite.connect(DATABASE_FILE) as conn:
        await conn.execute(
            "INSERT OR REPLACE INTO workflows (name, base_name, version, workflow_json) "
            "VALUES (?, ?, ?, ?)",
            record,
        )
        await conn.commit()
|
|
|
|
async def get_all_workflows() -> list[dict]:
    """Return the newest version of every workflow, ordered by base name.

    Each element is ``{"name": <versioned row name>, "workflow": <parsed JSON>}``.
    """
    # Group by base_name and keep only the row carrying the maximum version.
    newest_per_base = """
        SELECT w1.name, w1.workflow_json, w1.base_name, w1.version
        FROM workflows w1
        INNER JOIN (
            SELECT base_name, MAX(version) as max_version
            FROM workflows
            GROUP BY base_name
        ) w2 ON w1.base_name = w2.base_name AND w1.version = w2.max_version
        ORDER BY w1.base_name
    """
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(newest_per_base)
        records = await cur.fetchall()

    return [
        {"name": rec["name"], "workflow": json.loads(rec["workflow_json"])}
        for rec in records
    ]
|
|
|
|
async def get_latest_workflow_by_base_name(base_name: str) -> dict | None:
    """Return the highest-version row for ``base_name`` as a dict, or None.

    Versions are compared as strings; the ones written by save_workflow are
    fixed-width ``YYYYMMDDHHMMSS`` stamps, so string order is time order.
    """
    newest_sql = "SELECT * FROM workflows WHERE base_name = ? ORDER BY version DESC LIMIT 1"
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(newest_sql, (base_name,))
        found = await cur.fetchone()

    if found is None:
        return None
    return dict(found)
|
|
|
|
async def get_workflow_by_version(base_name: str, version: str) -> dict | None:
    """Return one specific version of a workflow as a dict, or None.

    The row is found through its primary key, which save_workflow builds as
    ``"<base_name> [<version>]"``.
    """
    row_key = f"{base_name} [{version}]"
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute("SELECT * FROM workflows WHERE name = ?", (row_key,))
        hit = await cur.fetchone()

    return None if hit is None else dict(hit)
|
|
|
|
async def get_workflow(name: str, version: Optional[str] = None) -> dict | None:
    """Look up a workflow by its base name.

    When ``version`` is truthy that exact version is fetched; otherwise
    (``None`` or an empty string) the latest version is returned.
    Returns None when no matching row exists.
    """
    if not version:
        return await get_latest_workflow_by_base_name(name)
    return await get_workflow_by_version(name, version)
|
|
|
|
async def delete_workflow(name: str) -> bool:
    """Delete the workflows row whose ``name`` column equals ``name``.

    ``name`` is the full row key (save_workflow stores it as
    ``"<base_name> [<version>]"``), so this removes a single version.
    Returns True if a row was deleted, False otherwise.
    """
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        cur = await conn.execute("DELETE FROM workflows WHERE name = ?", (name,))
        await conn.commit()
        deleted_count = cur.rowcount
    return deleted_count > 0
|
|
|
|
# Functions for managing workflow_run and workflow_run_nodes records
|
|
async def create_workflow_run(
    workflow_run_id: str,
    workflow_name: str,
    workflow_json: str,
    api_spec: str,
    request_data: str,
) -> str:
    """Insert a new workflow_run row and return its id.

    Only the id, workflow name, workflow JSON, API spec and request data are
    written; the remaining columns take their schema defaults (status
    ``'pending'``, ``created_at`` = current timestamp) or remain NULL.
    """
    values = (workflow_run_id, workflow_name, workflow_json, api_spec, request_data)
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        await conn.execute(
            "INSERT INTO workflow_run (id, workflow_name, workflow_json, api_spec, request_data) "
            "VALUES (?, ?, ?, ?, ?)",
            values,
        )
        await conn.commit()
    return workflow_run_id
|
|
|
|
async def update_workflow_run_status(
    workflow_run_id: str,
    status: str,
    server_url: Optional[str] = None,
    prompt_id: Optional[str] = None,
    client_id: Optional[str] = None,
    error_message: Optional[str] = None,
    result: Optional[str] = None,
):
    """Set a workflow run's status together with the columns tied to it.

    * ``'running'``   -- also writes server_url, prompt_id, client_id and sets
      ``started_at`` to the current timestamp.
    * ``'completed'`` -- also writes ``result`` and sets ``completed_at``.
    * ``'failed'``    -- also writes ``error_message`` and sets ``completed_at``.
    * any other value -- only the ``status`` column changes.

    Keyword arguments not relevant to the given status are ignored.
    """
    # status -> (extra SET fragment, values bound to that fragment's placeholders)
    extra_by_status = {
        'running': (
            "server_url = ?, prompt_id = ?, client_id = ?, started_at = CURRENT_TIMESTAMP",
            (server_url, prompt_id, client_id),
        ),
        'completed': ("completed_at = CURRENT_TIMESTAMP, result = ?", (result,)),
        'failed': ("error_message = ?, completed_at = CURRENT_TIMESTAMP", (error_message,)),
    }
    extra_sql, extra_params = extra_by_status.get(status, ("", ()))
    set_clause = f"status = ?, {extra_sql}" if extra_sql else "status = ?"
    # Only the constant fragments above are interpolated; every value is a bound parameter.
    statement = f"UPDATE workflow_run SET {set_clause} WHERE id = ?"

    async with aiosqlite.connect(DATABASE_FILE) as conn:
        await conn.execute(statement, (status, *extra_params, workflow_run_id))
        await conn.commit()
|
|
|
|
async def create_workflow_run_nodes(workflow_run_id: str, nodes_data: list):
    """Insert one workflow_run_nodes row per entry of ``nodes_data``.

    Each entry must provide ``"id"`` (stored as node_id) and ``"type"``
    (stored as node_type). Every row gets a fresh random UUID as its primary
    key and starts with the schema's default status ``'pending'``. All rows
    are committed together.
    """
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        rows = [
            (str(uuid.uuid4()), workflow_run_id, node["id"], node["type"])
            for node in nodes_data
        ]
        await conn.executemany(
            "INSERT INTO workflow_run_nodes (id, workflow_run_id, node_id, node_type) "
            "VALUES (?, ?, ?, ?)",
            rows,
        )
        await conn.commit()
|
|
|
|
async def update_workflow_run_node_status(
    workflow_run_id: str,
    node_id: str,
    status: str,
    output_data: Optional[str] = None,
    error_message: Optional[str] = None,
):
    """Set the status of one node within a workflow run.

    The row is matched on (workflow_run_id, node_id).

    * ``'running'``   -- also sets ``started_at`` to the current timestamp.
    * ``'completed'`` -- also writes ``output_data`` and sets ``completed_at``.
    * ``'failed'``    -- also writes ``error_message`` and sets ``completed_at``.
    * any other value -- only the ``status`` column changes.
    """
    # status -> (extra SET fragment, values bound to that fragment's placeholders)
    extra_by_status = {
        'running': ("started_at = CURRENT_TIMESTAMP", ()),
        'completed': ("output_data = ?, completed_at = CURRENT_TIMESTAMP", (output_data,)),
        'failed': ("error_message = ?, completed_at = CURRENT_TIMESTAMP", (error_message,)),
    }
    extra_sql, extra_params = extra_by_status.get(status, ("", ()))
    set_clause = f"status = ?, {extra_sql}" if extra_sql else "status = ?"
    # Only the constant fragments above are interpolated; every value is a bound parameter.
    statement = (
        f"UPDATE workflow_run_nodes SET {set_clause} "
        "WHERE workflow_run_id = ? AND node_id = ?"
    )

    async with aiosqlite.connect(DATABASE_FILE) as conn:
        await conn.execute(statement, (status, *extra_params, workflow_run_id, node_id))
        await conn.commit()
|
|
|
|
async def get_workflow_run(workflow_run_id: str) -> dict | None:
    """Return the workflow_run row with id ``workflow_run_id`` as a dict, or None."""
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(
            "SELECT * FROM workflow_run WHERE id = ?",
            (workflow_run_id,),
        )
        record = await cur.fetchone()

    if record is None:
        return None
    return dict(record)
|
|
|
|
async def get_workflow_run_nodes(workflow_run_id: str) -> list[dict]:
    """Return all workflow_run_nodes rows belonging to one run, as dicts.

    The query has no ORDER BY, so row order is not guaranteed.
    """
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(
            "SELECT * FROM workflow_run_nodes WHERE workflow_run_id = ?",
            (workflow_run_id,),
        )
        node_rows = await cur.fetchall()
    return list(map(dict, node_rows))
|
|
|
|
async def get_pending_workflow_runs() -> list[dict]:
    """Return every workflow_run still in status 'pending', oldest created_at first."""
    pending_sql = "SELECT * FROM workflow_run WHERE status = 'pending' ORDER BY created_at ASC"
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(pending_sql)
        pending = await cur.fetchall()
    return list(map(dict, pending))
|
|
|
|
async def get_running_workflow_runs() -> list[dict]:
    """Return every workflow_run whose status is 'running', as dicts.

    The query has no ORDER BY, so row order is not guaranteed.
    """
    running_sql = "SELECT * FROM workflow_run WHERE status = 'running'"
    async with aiosqlite.connect(DATABASE_FILE) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(running_sql)
        running = await cur.fetchall()
    return list(map(dict, running))