from datetime import datetime
from typing import Optional
import aiosqlite
import json
import re
import uuid
from workflow_service.config import Settings

# Path of the SQLite database file, read once from settings at import time.
DATABASE_FILE = Settings().DB_FILE


async def init_db():
    """Create all tables (and supporting indexes) if they do not exist yet.

    Tables:
      * ``workflows``          - versioned workflow definitions, keyed by
                                 ``"<base_name> [<version>]"``.
      * ``workflow_run``       - one row per workflow execution.
      * ``workflow_run_nodes`` - per-node execution state of a workflow run.

    Safe to call repeatedly: every statement uses ``IF NOT EXISTS``.
    """
    async with aiosqlite.connect(DATABASE_FILE) as db:
        await db.execute("""
            CREATE TABLE IF NOT EXISTS workflows (
                name TEXT PRIMARY KEY,
                base_name TEXT NOT NULL,
                version TEXT NOT NULL,
                workflow_json TEXT NOT NULL
            )
        """)
        # Table recording each workflow run.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS workflow_run (
                id TEXT PRIMARY KEY,
                workflow_name TEXT NOT NULL,
                prompt_id TEXT,
                client_id TEXT,
                status TEXT NOT NULL DEFAULT 'pending',
                server_url TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                started_at TIMESTAMP,
                completed_at TIMESTAMP,
                error_message TEXT,
                workflow_json TEXT NOT NULL,
                api_spec TEXT NOT NULL,
                request_data TEXT NOT NULL,
                result TEXT
            )
        """)
        # Table recording the run state of every node within a workflow run.
        # NOTE: SQLite only enforces the FOREIGN KEY below when
        # `PRAGMA foreign_keys = ON` is issued on the connection; this module
        # never issues it, so the constraint is currently not enforced.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS workflow_run_nodes (
                id TEXT PRIMARY KEY,
                workflow_run_id TEXT NOT NULL,
                node_id TEXT NOT NULL,
                node_type TEXT NOT NULL,
                status TEXT NOT NULL DEFAULT 'pending',
                started_at TIMESTAMP,
                completed_at TIMESTAMP,
                output_data TEXT,
                error_message TEXT,
                FOREIGN KEY (workflow_run_id) REFERENCES workflow_run (id)
            )
        """)
        # Indexes for the lookups this module performs repeatedly:
        # latest-version queries by base_name, and node status updates /
        # listings filtered by (workflow_run_id, node_id). Without them every
        # such query is a full table scan.
        await db.execute("""
            CREATE INDEX IF NOT EXISTS idx_workflows_base_name_version
            ON workflows (base_name, version)
        """)
        await db.execute("""
            CREATE INDEX IF NOT EXISTS idx_workflow_run_nodes_run_node
            ON workflow_run_nodes (workflow_run_id, node_id)
        """)
        await db.commit()
    print(f"数据库 '{DATABASE_FILE}' 已准备就绪。")


async def save_workflow(name: str, workflow_json: str):
    """Store a new version of workflow ``name``.

    The version is the current local time formatted as ``%Y%m%d%H%M%S``; the
    row key is ``"<name> [<version>]"``.

    NOTE: the version has one-second resolution, so two saves of the same
    ``name`` within the same second produce the same key and the later one
    replaces the earlier (``INSERT OR REPLACE``).
    NOTE: ``datetime.now()`` is local time, whereas the ``CURRENT_TIMESTAMP``
    columns in the run tables are UTC.
    """
    version = datetime.now().strftime("%Y%m%d%H%M%S")
    async with aiosqlite.connect(DATABASE_FILE) as db:
        await db.execute(
            "INSERT OR REPLACE INTO workflows (name, base_name, version, workflow_json) VALUES (?, ?, ?, ?)",
            (f"{name} [{version}]", name, version, workflow_json),
        )
        await db.commit()


async def get_all_workflows() -> list[dict]:
    """Return the latest version of every workflow, ordered by base name.

    Returns:
        A list of ``{"name": <versioned name>, "workflow": <parsed JSON>}``.
    """
    async with aiosqlite.connect(DATABASE_FILE) as db:
        db.row_factory = aiosqlite.Row
        # Group by base_name and keep the row holding each group's newest
        # version. Versions are fixed-width %Y%m%d%H%M%S strings, so MAX()
        # on the text matches chronological order.
        cursor = await db.execute("""
            SELECT w1.name, w1.workflow_json, w1.base_name, w1.version
            FROM workflows w1
            INNER JOIN (
                SELECT base_name, MAX(version) AS max_version
                FROM workflows
                GROUP BY base_name
            ) w2 ON w1.base_name = w2.base_name AND w1.version = w2.max_version
            ORDER BY w1.base_name
        """)
        rows = await cursor.fetchall()
        return [
            {"name": row["name"], "workflow": json.loads(row["workflow_json"])}
            for row in rows
        ]


async def get_latest_workflow_by_base_name(base_name: str) -> dict | None:
    """Return the newest stored version of ``base_name`` as a dict, or None."""
    async with aiosqlite.connect(DATABASE_FILE) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute(
            "SELECT * FROM workflows WHERE base_name = ? ORDER BY version DESC LIMIT 1",
            (base_name,),
        )
        row = await cursor.fetchone()
        return dict(row) if row else None


async def get_workflow_by_version(base_name: str, version: str) -> dict | None:
    """Return the exact ``version`` of ``base_name`` as a dict, or None."""
    name = f"{base_name} [{version}]"
    async with aiosqlite.connect(DATABASE_FILE) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("SELECT * FROM workflows WHERE name = ?", (name,))
        row = await cursor.fetchone()
        return dict(row) if row else None


async def get_workflow(name: str, version: Optional[str] = None) -> dict | None:
    """Return a specific version of workflow ``name``, or its latest version.

    A falsy ``version`` (None or "") selects the latest version.
    """
    if version:
        return await get_workflow_by_version(name, version)
    return await get_latest_workflow_by_base_name(name)


async def delete_workflow(name: str) -> bool:
    """Delete the row whose versioned key equals ``name``.

    ``name`` is matched against the full ``"<base_name> [<version>]"`` key.

    Returns:
        True if a row was deleted, False otherwise.
    """
    async with aiosqlite.connect(DATABASE_FILE) as db:
        cursor = await db.execute("DELETE FROM workflows WHERE name = ?", (name,))
        await db.commit()
        return cursor.rowcount > 0


# ---- workflow_run helpers ----

async def create_workflow_run(
    workflow_run_id: str,
    workflow_name: str,
    workflow_json: str,
    api_spec: str,
    request_data: str,
) -> str:
    """Insert a new workflow run record (status defaults to 'pending').

    Returns:
        The ``workflow_run_id`` that was inserted.
    """
    async with aiosqlite.connect(DATABASE_FILE) as db:
        await db.execute("""
            INSERT INTO workflow_run (id, workflow_name, workflow_json, api_spec, request_data)
            VALUES (?, ?, ?, ?, ?)
        """, (workflow_run_id, workflow_name, workflow_json, api_spec, request_data))
        await db.commit()
        return workflow_run_id


async def update_workflow_run_status(
    workflow_run_id: str,
    status: str,
    server_url: str | None = None,
    prompt_id: str | None = None,
    client_id: str | None = None,
    error_message: str | None = None,
    result: str | None = None,
):
    """Update the status of a workflow run.

    Which extra columns are written depends on ``status``:
      * 'running'   - server_url, prompt_id, client_id and started_at
                      (unpassed values are written as NULL).
      * 'completed' - result and completed_at.
      * 'failed'    - error_message and completed_at.
      * anything else - status only.
    Timestamps use SQLite's CURRENT_TIMESTAMP (UTC).
    """
    async with aiosqlite.connect(DATABASE_FILE) as db:
        if status == 'running':
            await db.execute("""
                UPDATE workflow_run
                SET status = ?, server_url = ?, prompt_id = ?, client_id = ?,
                    started_at = CURRENT_TIMESTAMP
                WHERE id = ?
            """, (status, server_url, prompt_id, client_id, workflow_run_id))
        elif status == 'completed':
            await db.execute("""
                UPDATE workflow_run
                SET status = ?, completed_at = CURRENT_TIMESTAMP, result = ?
                WHERE id = ?
            """, (status, result, workflow_run_id))
        elif status == 'failed':
            await db.execute("""
                UPDATE workflow_run
                SET status = ?, error_message = ?, completed_at = CURRENT_TIMESTAMP
                WHERE id = ?
            """, (status, error_message, workflow_run_id))
        else:
            await db.execute("""
                UPDATE workflow_run
                SET status = ?
                WHERE id = ?
            """, (status, workflow_run_id))
        await db.commit()


async def create_workflow_run_nodes(workflow_run_id: str, nodes_data: list):
    """Insert one 'pending' node row per entry of ``nodes_data``.

    Args:
        workflow_run_id: The run the nodes belong to.
        nodes_data: Mappings providing at least ``"id"`` and ``"type"`` keys.
            Each row gets a fresh random UUID as its primary key.
    """
    # Build all parameter tuples first (a missing key raises before any
    # insert), then insert them in a single executemany call.
    params = [
        (str(uuid.uuid4()), workflow_run_id, node["id"], node["type"])
        for node in nodes_data
    ]
    async with aiosqlite.connect(DATABASE_FILE) as db:
        await db.executemany("""
            INSERT INTO workflow_run_nodes (id, workflow_run_id, node_id, node_type)
            VALUES (?, ?, ?, ?)
        """, params)
        await db.commit()


async def update_workflow_run_node_status(
    workflow_run_id: str,
    node_id: str,
    status: str,
    output_data: str | None = None,
    error_message: str | None = None,
):
    """Update the status of one node (matched by run id and node id).

    Which extra columns are written depends on ``status``:
      * 'running'   - started_at.
      * 'completed' - output_data and completed_at.
      * 'failed'    - error_message and completed_at.
      * anything else - status only.
    """
    async with aiosqlite.connect(DATABASE_FILE) as db:
        if status == 'running':
            await db.execute("""
                UPDATE workflow_run_nodes
                SET status = ?, started_at = CURRENT_TIMESTAMP
                WHERE workflow_run_id = ? AND node_id = ?
            """, (status, workflow_run_id, node_id))
        elif status == 'completed':
            await db.execute("""
                UPDATE workflow_run_nodes
                SET status = ?, output_data = ?, completed_at = CURRENT_TIMESTAMP
                WHERE workflow_run_id = ? AND node_id = ?
            """, (status, output_data, workflow_run_id, node_id))
        elif status == 'failed':
            await db.execute("""
                UPDATE workflow_run_nodes
                SET status = ?, error_message = ?, completed_at = CURRENT_TIMESTAMP
                WHERE workflow_run_id = ? AND node_id = ?
            """, (status, error_message, workflow_run_id, node_id))
        else:
            await db.execute("""
                UPDATE workflow_run_nodes
                SET status = ?
                WHERE workflow_run_id = ? AND node_id = ?
            """, (status, workflow_run_id, node_id))
        await db.commit()


async def get_workflow_run(workflow_run_id: str) -> dict | None:
    """Return the workflow run with ``workflow_run_id`` as a dict, or None."""
    async with aiosqlite.connect(DATABASE_FILE) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("""
            SELECT * FROM workflow_run WHERE id = ?
        """, (workflow_run_id,))
        row = await cursor.fetchone()
        return dict(row) if row else None


async def get_workflow_run_nodes(workflow_run_id: str) -> list[dict]:
    """Return all node rows of a workflow run, in insertion order."""
    async with aiosqlite.connect(DATABASE_FILE) as db:
        db.row_factory = aiosqlite.Row
        # ORDER BY rowid makes the (otherwise undefined) result order follow
        # insertion order explicitly.
        cursor = await db.execute("""
            SELECT * FROM workflow_run_nodes
            WHERE workflow_run_id = ?
            ORDER BY rowid
        """, (workflow_run_id,))
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]


async def get_pending_workflow_runs() -> list[dict]:
    """Return all runs with status 'pending', oldest first."""
    async with aiosqlite.connect(DATABASE_FILE) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("""
            SELECT * FROM workflow_run
            WHERE status = 'pending'
            ORDER BY created_at ASC
        """)
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]


async def get_running_workflow_runs() -> list[dict]:
    """Return all runs with status 'running', oldest first."""
    async with aiosqlite.connect(DATABASE_FILE) as db:
        db.row_factory = aiosqlite.Row
        # Explicit ordering, consistent with get_pending_workflow_runs.
        cursor = await db.execute("""
            SELECT * FROM workflow_run
            WHERE status = 'running'
            ORDER BY created_at ASC
        """)
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]