# Template table - PostgreSQL implementation
import uuid
import json
import os
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime
from contextlib import contextmanager

from python_core.config import settings
from python_core.utils.logger import setup_logger

from .types import TemplateInfo

# Try to import psycopg2. If that fails, remember the error so that
# TemplateTablePostgres() can raise a friendly installation hint instead.
try:
    import psycopg2
    import psycopg2.extras

    PSYCOPG2_AVAILABLE = True
    PSYCOPG2_ERROR = ""
except ImportError as e:
    PSYCOPG2_AVAILABLE = False
    PSYCOPG2_ERROR = str(e)

logger = setup_logger(__name__)


class TemplateTablePostgres:
    """
    Template table - PostgreSQL implementation.

    Provides CRUD, search, tag and draft-content helpers for the ``templates``
    table. Every public method opens its own short-lived connection through
    ``_get_connection``.
    """

    # Columns update_template() may write. Column names are interpolated into
    # the UPDATE statement, so they must come from this fixed whitelist and
    # never directly from caller-supplied dict keys (SQL injection).
    _UPDATABLE_COLUMNS = frozenset({
        "name",
        "description",
        "thumbnail_path",
        "draft_content_path",
        "draft_content",
        "resources_path",
        "canvas_config",
        "duration",
        "material_count",
        "track_count",
        "tags",
        "is_cloud",
        "user_id",
        "updated_at",
    })

    # Keys silently ignored by update_template().
    _PROTECTED_COLUMNS = frozenset({"id", "created_at"})

    # JSONB columns whose values are serialized with json.dumps() before writing.
    _JSON_COLUMNS = ("canvas_config", "tags", "draft_content")

    def __init__(self):
        # Fail fast with installation instructions when the driver is missing.
        if not PSYCOPG2_AVAILABLE:
            error_msg = (
                "\n"
                "PostgreSQL 驱动 psycopg2 未安装。请安装:\n"
                "\n"
                "方案1(推荐):\n"
                "pip install psycopg2-binary\n"
                "\n"
                "方案2:\n"
                "# Ubuntu/Debian\n"
                "sudo apt-get install postgresql libpq-dev python3-dev\n"
                "pip install psycopg2\n"
                "\n"
                "# CentOS/RHEL\n"
                "sudo yum install postgresql postgresql-devel python3-devel\n"
                "pip install psycopg2\n"
                "\n"
                "# macOS\n"
                "brew install postgresql\n"
                "pip install psycopg2\n"
                "\n"
                f"原始错误: {PSYCOPG2_ERROR}\n"
            )
            logger.error(error_msg)
            raise ImportError(error_msg)

        self.db_url = settings.db
        self.table_name = "templates"
        # NOTE: the schema is NOT created automatically; call
        # _init_template_table() explicitly when provisioning a database.

    @contextmanager
    def _get_connection(self):
        """
        Yield a new psycopg2 connection and always close it afterwards.

        If the ``with`` body raises, the open transaction is rolled back, the
        error is logged and the original exception is re-raised. Callers must
        call ``conn.commit()`` themselves to persist changes.
        """
        conn = None
        try:
            conn = psycopg2.connect(self.db_url)
            yield conn
        except Exception as e:
            # Rolling back an already-closed/broken connection raises its own
            # error, which would hide the real one; guard and log instead.
            if conn is not None and not conn.closed:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    logger.error(f"Rollback failed: {rollback_error}")
            logger.error(f"Database connection error: {e}")
            raise
        finally:
            if conn is not None:
                conn.close()

    def _init_template_table(self):
        """Create the ``templates`` table and its indexes if they do not exist."""
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # Create the templates table (if it does not exist).
                    create_table_sql = """
                        CREATE TABLE IF NOT EXISTS templates (
                            id VARCHAR(36) PRIMARY KEY,
                            name VARCHAR(255) NOT NULL,
                            description TEXT DEFAULT '',
                            thumbnail_path TEXT DEFAULT '',
                            draft_content_path TEXT DEFAULT '',
                            draft_content JSONB DEFAULT '{}',
                            resources_path TEXT DEFAULT '',
                            canvas_config JSONB DEFAULT '{}',
                            duration INTEGER DEFAULT 0,
                            material_count INTEGER DEFAULT 0,
                            track_count INTEGER DEFAULT 0,
                            tags JSONB DEFAULT '[]',
                            is_cloud BOOLEAN DEFAULT FALSE,
                            user_id VARCHAR(36) NOT NULL,
                            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                        );
                    """
                    cursor.execute(create_table_sql)

                    # Create indexes. GIN indexes serve the JSONB containment
                    # queries (tags @> ...).
                    indexes = [
                        "CREATE INDEX IF NOT EXISTS idx_templates_name ON templates(name);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_user_id ON templates(user_id);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_is_cloud ON templates(is_cloud);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_created_at ON templates(created_at);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_tags ON templates USING GIN(tags);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_draft_content ON templates USING GIN(draft_content);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_user_name ON templates(user_id, name);",
                    ]
                    for index_sql in indexes:
                        cursor.execute(index_sql)

                    conn.commit()
                    logger.info("Template table initialized")
        except Exception as e:
            logger.error(f"Failed to initialize template table: {e}")
            raise

    # ------------------------------------------------------------------
    # Private query helpers
    # ------------------------------------------------------------------

    def _query_templates(self, sql: str, params: Tuple[Any, ...]) -> List[TemplateInfo]:
        """
        Run a ``SELECT * FROM templates ...`` query and map every row.

        Exceptions propagate; public callers log them and return ``[]``.
        """
        with self._get_connection() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                cursor.execute(sql, params)
                return [self._row_to_template_info(row) for row in cursor.fetchall()]

    @staticmethod
    def _visibility_filter(user_id: Optional[str], include_cloud: bool) -> Tuple[str, Tuple[Any, ...]]:
        """
        Build the WHERE fragment restricting rows to what ``user_id`` may see.

        Returns ``(sql_fragment, params)``:
        - falsy ``user_id``: ``("", ())`` - no restriction, all templates;
        - ``user_id`` and ``include_cloud``: the user's templates plus cloud ones;
        - ``user_id`` only: the user's own templates.
        """
        if not user_id:
            return "", ()
        if include_cloud:
            return "(user_id = %s OR is_cloud = TRUE)", (user_id,)
        return "user_id = %s", (user_id,)

    def _select_visible(self, condition: str, condition_params: Tuple[Any, ...],
                        user_id: Optional[str], include_cloud: bool, limit: int) -> List[TemplateInfo]:
        """Select templates matching ``condition`` within the visibility scope, newest first."""
        scope_sql, scope_params = self._visibility_filter(user_id, include_cloud)
        where = f"{scope_sql} AND {condition}" if scope_sql else condition
        sql = f"SELECT * FROM templates WHERE {where} ORDER BY created_at DESC LIMIT %s"
        return self._query_templates(sql, (*scope_params, *condition_params, limit))

    @staticmethod
    def _escape_like(text: str) -> str:
        """
        Escape LIKE/ILIKE wildcards so ``text`` is matched literally.

        Backslash is PostgreSQL's default LIKE escape character, so it is
        escaped first, then ``%`` and ``_``.
        """
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def create_template(self, template_info: TemplateInfo) -> str:
        """
        Create a template.

        Args:
            template_info: template information; ``id`` is generated when empty.

        Returns:
            The template ID.

        Raises:
            ValueError: when the user already has a template with this name.
            Exception: any database error is logged and re-raised.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # Reject duplicate names for the same user.
                    # NOTE(review): this check-then-insert is not atomic and the
                    # (user_id, name) index is not UNIQUE, so concurrent creates
                    # can still produce duplicates.
                    cursor.execute(
                        "SELECT id FROM templates WHERE name = %s AND user_id = %s",
                        (template_info.name, template_info.user_id),
                    )
                    if cursor.fetchone():
                        raise ValueError(f"Template name '{template_info.name}' already exists for this user")

                    # Generate a template ID if none was provided.
                    template_id = template_info.id or str(uuid.uuid4())

                    insert_sql = """
                        INSERT INTO templates (
                            id, name, description, thumbnail_path, draft_content_path,
                            draft_content, resources_path, canvas_config, duration,
                            material_count, track_count, tags, is_cloud, user_id,
                            created_at, updated_at
                        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    """

                    now = datetime.now()

                    # A missing draft_content is stored as an empty JSON object.
                    draft_content = template_info.draft_content
                    if draft_content is None:
                        draft_content = {}

                    logger.debug(f"Storing draft_content for template {template_id}: {type(draft_content)} with {len(draft_content)} keys")

                    cursor.execute(insert_sql, (
                        template_id,
                        template_info.name,
                        template_info.description,
                        template_info.thumbnail_path,
                        template_info.draft_content_path,
                        json.dumps(draft_content),
                        template_info.resources_path,
                        json.dumps(template_info.canvas_config),
                        template_info.duration,
                        template_info.material_count,
                        template_info.track_count,
                        json.dumps(template_info.tags),
                        template_info.is_cloud,
                        template_info.user_id,
                        now,
                        now,
                    ))

                    conn.commit()
                    logger.info(f"Created template: {template_info.name} (ID: {template_id})")
                    return template_id
        except Exception as e:
            logger.error(f"Failed to create template '{template_info.name}': {e}")
            raise

    def get_template_by_id(self, template_id: str) -> Optional[TemplateInfo]:
        """
        Fetch a template by ID.

        Args:
            template_id: template ID

        Returns:
            The template, or ``None`` if it does not exist or a database
            error occurred (the error is logged).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    cursor.execute("SELECT * FROM templates WHERE id = %s", (template_id,))
                    row = cursor.fetchone()
                    return self._row_to_template_info(row) if row else None
        except Exception as e:
            logger.error(f"Failed to get template by ID '{template_id}': {e}")
            return None

    def get_template_by_name(self, name: str, user_id: Optional[str] = None) -> Optional[TemplateInfo]:
        """
        Fetch a template by name.

        Args:
            name: template name
            user_id: optional; when given, only that user's templates match

        Returns:
            The first matching template, or ``None`` if none exists or a
            database error occurred (the error is logged).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    if user_id:
                        cursor.execute(
                            "SELECT * FROM templates WHERE name = %s AND user_id = %s LIMIT 1",
                            (name, user_id),
                        )
                    else:
                        cursor.execute("SELECT * FROM templates WHERE name = %s LIMIT 1", (name,))
                    row = cursor.fetchone()
                    return self._row_to_template_info(row) if row else None
        except Exception as e:
            logger.error(f"Failed to get template by name '{name}': {e}")
            return None

    def update_template(self, template_id: str, updates: Dict[str, Any]) -> bool:
        """
        Update template fields.

        ``id`` and ``created_at`` are ignored and ``updated_at`` is always set
        to the current time. JSONB fields are serialized automatically.

        Args:
            template_id: template ID
            updates: column name -> new value. Only known template columns
                are accepted; any unknown key rejects the whole update.

        Returns:
            True if a row was updated; False if the template does not exist,
            nothing valid was supplied, an unknown field was supplied, or a
            database error occurred (the error is logged).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # Make sure the template exists.
                    cursor.execute("SELECT id FROM templates WHERE id = %s", (template_id,))
                    if not cursor.fetchone():
                        logger.warning(f"Template not found: {template_id}")
                        return False

                    # Drop fields that must never be updated directly.
                    filtered_updates = {
                        k: v for k, v in updates.items() if k not in self._PROTECTED_COLUMNS
                    }
                    if not filtered_updates:
                        logger.warning("No valid fields to update")
                        return False

                    # Keys become column names in the SQL text, so only
                    # whitelisted columns may pass (prevents SQL injection).
                    unknown = [k for k in filtered_updates if k not in self._UPDATABLE_COLUMNS]
                    if unknown:
                        logger.warning(f"Refusing to update unknown template fields: {unknown}")
                        return False

                    filtered_updates["updated_at"] = datetime.now()

                    for column in self._JSON_COLUMNS:
                        if column in filtered_updates:
                            filtered_updates[column] = json.dumps(filtered_updates[column])

                    set_clause = ", ".join(f"{column} = %s" for column in filtered_updates)
                    values = [*filtered_updates.values(), template_id]
                    cursor.execute(f"UPDATE templates SET {set_clause} WHERE id = %s", values)
                    conn.commit()

                    if cursor.rowcount > 0:
                        logger.info(f"Updated template: {template_id}")
                        return True
                    logger.warning(f"No rows updated for template: {template_id}")
                    return False
        except Exception as e:
            logger.error(f"Failed to update template '{template_id}': {e}")
            return False

    def delete_template(self, template_id: str) -> bool:
        """
        Delete a template.

        Args:
            template_id: template ID

        Returns:
            True if a row was deleted; False if none matched or a database
            error occurred (the error is logged).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("DELETE FROM templates WHERE id = %s", (template_id,))
                    conn.commit()

                    if cursor.rowcount > 0:
                        logger.info(f"Deleted template: {template_id}")
                        return True
                    logger.warning(f"No template found to delete: {template_id}")
                    return False
        except Exception as e:
            logger.error(f"Failed to delete template '{template_id}': {e}")
            return False

    # ------------------------------------------------------------------
    # Listing and search
    # ------------------------------------------------------------------

    def get_templates_by_user(self, user_id: str, include_cloud: bool = True, limit: int = 100) -> List[TemplateInfo]:
        """
        List a user's templates, newest first.

        Args:
            user_id: user ID
            include_cloud: also include cloud (public) templates
            limit: maximum number of results

        Returns:
            Template list; ``[]`` on database error (the error is logged).
        """
        try:
            if include_cloud:
                sql = """
                    SELECT * FROM templates
                    WHERE user_id = %s OR is_cloud = TRUE
                    ORDER BY created_at DESC
                    LIMIT %s
                """
            else:
                sql = """
                    SELECT * FROM templates
                    WHERE user_id = %s
                    ORDER BY created_at DESC
                    LIMIT %s
                """
            return self._query_templates(sql, (user_id, limit))
        except Exception as e:
            logger.error(f"Failed to get templates for user '{user_id}': {e}")
            return []

    def search_templates(self, query: str, user_id: Optional[str] = None,
                         include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]:
        """
        Search templates by keyword (case-insensitive substring match).

        Args:
            query: keyword matched against name, description and the tags JSON
                text. ``%``, ``_`` and ``\\`` are matched literally.
            user_id: optional; restricts results to this user's templates
            include_cloud: with ``user_id``, also include cloud templates
            limit: maximum number of results

        Returns:
            Matching templates, newest first; ``[]`` on database error.
        """
        try:
            pattern = f"%{self._escape_like(str(query))}%"
            return self._select_visible(
                "(name ILIKE %s OR description ILIKE %s OR tags::text ILIKE %s)",
                (pattern, pattern, pattern),
                user_id,
                include_cloud,
                limit,
            )
        except Exception as e:
            logger.error(f"Failed to search templates with query '{query}': {e}")
            return []

    def get_templates_by_tag(self, tag: str, user_id: Optional[str] = None,
                             include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]:
        """
        List templates whose ``tags`` array contains ``tag``.

        Args:
            tag: tag name (exact match)
            user_id: optional; restricts results to this user's templates
            include_cloud: with ``user_id``, also include cloud templates
            limit: maximum number of results

        Returns:
            Matching templates, newest first; ``[]`` on database error.
        """
        try:
            # A JSON array "contains" a bare primitive it holds, so
            # '["a","b"]' @> '"a"' is true; the cast makes the type explicit.
            return self._select_visible(
                "tags @> %s::jsonb",
                (json.dumps(tag),),
                user_id,
                include_cloud,
                limit,
            )
        except Exception as e:
            logger.error(f"Failed to get templates by tag '{tag}': {e}")
            return []

    def get_cloud_templates(self, limit: int = 100) -> List[TemplateInfo]:
        """
        List cloud (public) templates, newest first.

        Args:
            limit: maximum number of results

        Returns:
            Cloud templates; ``[]`` on database error (the error is logged).
        """
        try:
            sql = """
                SELECT * FROM templates
                WHERE is_cloud = TRUE
                ORDER BY created_at DESC
                LIMIT %s
            """
            return self._query_templates(sql, (limit,))
        except Exception as e:
            logger.error(f"Failed to get cloud templates: {e}")
            return []

    def get_template_count(self, user_id: Optional[str] = None, include_cloud: bool = True) -> int:
        """
        Count templates.

        Args:
            user_id: optional; ``None`` counts every template
            include_cloud: with ``user_id``, also count cloud templates

        Returns:
            The count; 0 on database error (the error is logged).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    if user_id is None:
                        cursor.execute("SELECT COUNT(*) FROM templates")
                    elif include_cloud:
                        cursor.execute(
                            "SELECT COUNT(*) FROM templates WHERE user_id = %s OR is_cloud = TRUE",
                            (user_id,),
                        )
                    else:
                        cursor.execute("SELECT COUNT(*) FROM templates WHERE user_id = %s", (user_id,))
                    result = cursor.fetchone()
                    return result[0] if result else 0
        except Exception as e:
            logger.error(f"Failed to get template count: {e}")
            return 0

    def batch_import_templates(self, templates: List[TemplateInfo]) -> Dict[str, Any]:
        """
        Import templates one by one, skipping IDs that already exist.

        Args:
            templates: templates to import (``None`` is treated as empty)

        Returns:
            Summary dict with ``total``, ``success``, ``failed``, ``skipped``
            and ``failed_templates`` (list of ``{"name", "error"}``); on an
            unexpected error, ``total``, ``success``, ``failed``, ``skipped``
            and ``error``.
        """
        templates = templates or []
        try:
            success_count = 0
            failed_count = 0
            skipped_count = 0
            failed_templates = []

            for template in templates:
                try:
                    # Skip templates whose ID already exists.
                    if template.id and self.get_template_by_id(template.id):
                        logger.warning(f"Template already exists, skipping: {template.name}")
                        skipped_count += 1
                        continue

                    self.create_template(template)
                    success_count += 1
                except Exception as e:
                    logger.error(f"Failed to import template '{template.name}': {e}")
                    failed_count += 1
                    failed_templates.append({
                        "name": template.name,
                        "error": str(e),
                    })

            result = {
                "total": len(templates),
                "success": success_count,
                "failed": failed_count,
                "skipped": skipped_count,
                "failed_templates": failed_templates,
            }
            logger.info(f"Batch import completed: {success_count} success, {failed_count} failed")
            return result
        except Exception as e:
            logger.error(f"Failed to batch import templates: {e}")
            return {
                "total": len(templates),
                "success": 0,
                "failed": len(templates),
                "skipped": 0,
                "error": str(e),
            }

    def get_popular_tags(self, user_id: Optional[str] = None, limit: int = 20) -> List[Dict[str, Any]]:
        """
        Return the most frequently used tags.

        Args:
            user_id: optional; counts only this user's templates plus cloud ones
            limit: maximum number of tags

        Returns:
            List of ``{"tag": str, "count": int}``, most used first; ``[]`` on
            database error (the error is logged).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # jsonb_array_elements_text() raises on non-array values
                    # (e.g. a 'null' written for tags=None), which would fail
                    # the whole query, so such rows are filtered out first.
                    if user_id:
                        sql = """
                            SELECT tag, COUNT(*) as count
                            FROM (
                                SELECT jsonb_array_elements_text(tags) as tag
                                FROM templates
                                WHERE (user_id = %s OR is_cloud = TRUE)
                                  AND jsonb_typeof(tags) = 'array'
                            ) t
                            GROUP BY tag
                            ORDER BY count DESC
                            LIMIT %s
                        """
                        cursor.execute(sql, (user_id, limit))
                    else:
                        sql = """
                            SELECT tag, COUNT(*) as count
                            FROM (
                                SELECT jsonb_array_elements_text(tags) as tag
                                FROM templates
                                WHERE jsonb_typeof(tags) = 'array'
                            ) t
                            GROUP BY tag
                            ORDER BY count DESC
                            LIMIT %s
                        """
                        cursor.execute(sql, (limit,))

                    return [{"tag": tag, "count": count} for tag, count in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Failed to get popular tags: {e}")
            return []

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _row_to_template_info(self, row: Dict[str, Any]) -> TemplateInfo:
        """
        Convert a RealDictCursor row into a TemplateInfo.

        Non-dict ``canvas_config``/``draft_content`` and non-list ``tags``
        fall back to empty containers; timestamps become ISO-8601 strings
        ("" when NULL).
        """
        template_info = TemplateInfo(
            id=row['id'],
            name=row['name'],
            description=row['description'],
            thumbnail_path=row['thumbnail_path'],
            draft_content_path=row['draft_content_path'],
            resources_path=row['resources_path'],
            canvas_config=row['canvas_config'] if isinstance(row['canvas_config'], dict) else {},
            duration=row['duration'],
            material_count=row['material_count'],
            track_count=row['track_count'],
            tags=row['tags'] if isinstance(row['tags'], list) else [],
            is_cloud=row['is_cloud'],
            user_id=row['user_id'],
            created_at=row['created_at'].isoformat() if row['created_at'] else "",
            updated_at=row['updated_at'].isoformat() if row['updated_at'] else "",
        )

        # draft_content is assigned after construction rather than passed to
        # the constructor (kept as in the original code).
        draft_content = row.get('draft_content', {})
        template_info.draft_content = draft_content if isinstance(draft_content, dict) else {}

        return template_info

    def update_draft_content(self, template_id: str, draft_content: Dict[str, Any]) -> bool:
        """
        Replace a template's draft content.

        Args:
            template_id: template ID
            draft_content: draft content data (JSON-serializable)

        Returns:
            True on success, False otherwise.
        """
        try:
            return self.update_template(template_id, {"draft_content": draft_content})
        except Exception as e:
            logger.error(f"Failed to update draft content for template '{template_id}': {e}")
            return False

    def get_draft_content(self, template_id: str) -> Optional[Dict[str, Any]]:
        """
        Return a template's draft content.

        Args:
            template_id: template ID

        Returns:
            The draft content dict, or ``None`` if the template does not exist.
        """
        try:
            template = self.get_template_by_id(template_id)
            if template and hasattr(template, 'draft_content'):
                return template.draft_content
            return None
        except Exception as e:
            logger.error(f"Failed to get draft content for template '{template_id}': {e}")
            return None

    def save_draft_content_from_file(self, template_id: str, draft_file_path: str) -> bool:
        """
        Load draft content from a UTF-8 JSON file and store it in the database.

        Args:
            template_id: template ID
            draft_file_path: path to the draft content JSON file

        Returns:
            True on success; False if the file is missing, unreadable or
            invalid JSON, or the update fails.
        """
        try:
            if not os.path.exists(draft_file_path):
                logger.error(f"Draft content file not found: {draft_file_path}")
                return False

            with open(draft_file_path, 'r', encoding='utf-8') as f:
                draft_content = json.load(f)

            return self.update_draft_content(template_id, draft_content)
        except Exception as e:
            logger.error(f"Failed to save draft content from file '{draft_file_path}': {e}")
            return False

    def export_draft_content_to_file(self, template_id: str, output_path: str) -> bool:
        """
        Write a template's draft content to a UTF-8 JSON file (pretty-printed).

        Missing parent directories are created.

        Args:
            template_id: template ID
            output_path: destination file path

        Returns:
            True on success; False if there is no draft content or writing fails.
        """
        try:
            draft_content = self.get_draft_content(template_id)
            if draft_content is None:
                logger.error(f"No draft content found for template: {template_id}")
                return False

            # os.makedirs("") raises FileNotFoundError, so only create a
            # directory when the path actually has one (not e.g. "out.json").
            directory = os.path.dirname(output_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(draft_content, f, ensure_ascii=False, indent=2)

            logger.info(f"Draft content exported to: {output_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to export draft content to file '{output_path}': {e}")
            return False


# Global template table instance (raises ImportError at import time if
# psycopg2 is not installed).
template_table = TemplateTablePostgres()