diff --git a/python_core/database/template_postgres.py b/python_core/database/template_postgres.py new file mode 100644 index 0000000..c5b52c5 --- /dev/null +++ b/python_core/database/template_postgres.py @@ -0,0 +1,705 @@ +# 模板表 - PostgreSQL 版本 + +import uuid +import json +from typing import Dict, List, Any, Optional +from datetime import datetime +from contextlib import contextmanager +from python_core.config import settings +from python_core.utils.logger import setup_logger +from .types import TemplateInfo + +# 尝试导入 psycopg2,如果失败则提供友好的错误信息 +try: + import psycopg2 + import psycopg2.extras + PSYCOPG2_AVAILABLE = True +except ImportError as e: + PSYCOPG2_AVAILABLE = False + PSYCOPG2_ERROR = str(e) + +logger = setup_logger(__name__) + +class TemplateTablePostgres: + """ + 模板表类 - PostgreSQL 版本 + 基于 PostgreSQL 数据库实现的模板管理功能 + """ + + def __init__(self): + # 检查 psycopg2 是否可用 + if not PSYCOPG2_AVAILABLE: + error_msg = f""" +PostgreSQL 驱动 psycopg2 未安装。请安装: + +方案1(推荐): + pip install psycopg2-binary + +方案2: + # Ubuntu/Debian + sudo apt-get install postgresql libpq-dev python3-dev + pip install psycopg2 + + # CentOS/RHEL + sudo yum install postgresql postgresql-devel python3-devel + pip install psycopg2 + + # macOS + brew install postgresql + pip install psycopg2 + +原始错误: {PSYCOPG2_ERROR} +""" + logger.error(error_msg) + raise ImportError(error_msg) + + self.db_url = settings.db + self.table_name = "templates" + + # 初始化模板表 + self._init_template_table() + + @contextmanager + def _get_connection(self): + """获取数据库连接的上下文管理器""" + conn = None + try: + conn = psycopg2.connect(self.db_url) + yield conn + except Exception as e: + if conn: + conn.rollback() + logger.error(f"Database connection error: {e}") + raise e + finally: + if conn: + conn.close() + + def _init_template_table(self): + """初始化模板表""" + try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + # 创建模板表(如果不存在) + create_table_sql = """ + CREATE TABLE IF NOT EXISTS templates ( + id VARCHAR(36) PRIMARY KEY, + name VARCHAR(255) NOT NULL, + description TEXT DEFAULT '', + thumbnail_path TEXT DEFAULT '', + draft_content_path TEXT DEFAULT '', + draft_content JSONB DEFAULT '{}', + resources_path TEXT DEFAULT '', + canvas_config JSONB DEFAULT '{}', + duration INTEGER DEFAULT 0, + material_count INTEGER DEFAULT 0, + track_count INTEGER DEFAULT 0, + tags JSONB DEFAULT '[]', + is_cloud BOOLEAN DEFAULT FALSE, + user_id VARCHAR(36) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + """ + cursor.execute(create_table_sql) + + # 创建索引 + indexes = [ + "CREATE INDEX IF NOT EXISTS idx_templates_name ON templates(name);", + "CREATE INDEX IF NOT EXISTS idx_templates_user_id ON templates(user_id);", + "CREATE INDEX IF NOT EXISTS idx_templates_is_cloud ON templates(is_cloud);", + "CREATE INDEX IF NOT EXISTS idx_templates_created_at ON templates(created_at);", + "CREATE INDEX IF NOT EXISTS idx_templates_tags ON templates USING GIN(tags);", + "CREATE INDEX IF NOT EXISTS idx_templates_draft_content ON templates USING GIN(draft_content);", + "CREATE INDEX IF NOT EXISTS idx_templates_user_name ON templates(user_id, name);" + ] + + for index_sql in indexes: + cursor.execute(index_sql) + + conn.commit() + logger.info("Template table initialized") + + except Exception as e: + logger.error(f"Failed to initialize template table: {e}") + raise e + + def create_template(self, template_info: TemplateInfo) -> str: + """ + 创建模板 + + Args: + template_info: 模板信息 + + Returns: + 模板ID + """ + try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + # 检查模板名称是否已存在(同一用户下) + cursor.execute( + "SELECT id FROM templates WHERE name = %s AND user_id = %s", + (template_info.name, template_info.user_id) + ) + if cursor.fetchone(): + raise ValueError(f"Template name '{template_info.name}' already exists for this user") + + # 生成模板ID(如果没有提供) + template_id = template_info.id or str(uuid.uuid4()) + + # 插入模板记录 + insert_sql = """ + INSERT INTO templates ( + id, name, description, thumbnail_path, draft_content_path, + draft_content, resources_path, canvas_config, duration, material_count, + track_count, tags, is_cloud, user_id, created_at, updated_at + ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """ + + now = datetime.now() + + # 处理 draft_content,如果 template_info 有 draft_content 属性则使用,否则为空对象 + draft_content = getattr(template_info, 'draft_content', {}) + + cursor.execute(insert_sql, ( + template_id, + template_info.name, + template_info.description, + template_info.thumbnail_path, + template_info.draft_content_path, + json.dumps(draft_content), + template_info.resources_path, + json.dumps(template_info.canvas_config), + template_info.duration, + template_info.material_count, + template_info.track_count, + json.dumps(template_info.tags), + template_info.is_cloud, + template_info.user_id, + now, + now + )) + + conn.commit() + logger.info(f"Created template: {template_info.name} (ID: {template_id})") + return template_id + + except Exception as e: + logger.error(f"Failed to create template '{template_info.name}': {e}") + raise e + + def get_template_by_id(self, template_id: str) -> Optional[TemplateInfo]: + """ + 根据ID获取模板 + + Args: + template_id: 模板ID + + Returns: + 模板信息,如果不存在返回None + """ + try: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor: + cursor.execute("SELECT * FROM templates WHERE id = %s", (template_id,)) + row = cursor.fetchone() + + if row: + return self._row_to_template_info(row) + return None + + except Exception as e: + logger.error(f"Failed to get template by ID '{template_id}': {e}") + return None + + def get_template_by_name(self, name: str, user_id: str = None) -> Optional[TemplateInfo]: + """ + 根据名称获取模板 + + Args: + name: 模板名称 + user_id: 用户ID(可选,用于过滤用户模板) + + Returns: + 模板信息,如果不存在返回None + """ + try: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor: + if user_id: + cursor.execute( + "SELECT * FROM templates WHERE name = %s AND user_id = %s LIMIT 1", + (name, user_id) + ) + else: + cursor.execute("SELECT * FROM templates WHERE name = %s LIMIT 1", (name,)) + + row = cursor.fetchone() + if row: + return self._row_to_template_info(row) + return None + + except Exception as e: + logger.error(f"Failed to get template by name '{name}': {e}") + return None + + def update_template(self, template_id: str, updates: Dict[str, Any]) -> bool: + """ + 更新模板信息 + + Args: + template_id: 模板ID + updates: 要更新的字段 + + Returns: + 更新成功返回True + """ + try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + # 检查模板是否存在 + cursor.execute("SELECT id FROM templates WHERE id = %s", (template_id,)) + if not cursor.fetchone(): + logger.warning(f"Template not found: {template_id}") + return False + + # 移除不应该直接更新的字段 + protected_fields = ["id", "created_at"] + filtered_updates = {k: v for k, v in updates.items() if k not in protected_fields} + + if not filtered_updates: + logger.warning("No valid fields to update") + return False + + # 添加更新时间 + filtered_updates["updated_at"] = datetime.now() + + # 处理 JSON 字段 + if "canvas_config" in filtered_updates: + filtered_updates["canvas_config"] = json.dumps(filtered_updates["canvas_config"]) + if "tags" in filtered_updates: + filtered_updates["tags"] = json.dumps(filtered_updates["tags"]) + if "draft_content" in filtered_updates: + filtered_updates["draft_content"] = json.dumps(filtered_updates["draft_content"]) + + # 构建更新SQL + set_clauses = [] + values = [] + for key, value in filtered_updates.items(): + set_clauses.append(f"{key} = %s") + values.append(value) + + values.append(template_id) # WHERE条件的参数 + + update_sql = f"UPDATE templates SET {', '.join(set_clauses)} WHERE id = %s" + cursor.execute(update_sql, values) + + conn.commit() + + if cursor.rowcount > 0: + logger.info(f"Updated template: {template_id}") + return True + else: + logger.warning(f"No rows updated for template: {template_id}") + return False + + except Exception as e: + logger.error(f"Failed to update template '{template_id}': {e}") + return False + + def delete_template(self, template_id: str) -> bool: + """ + 删除模板 + + Args: + template_id: 模板ID + + Returns: + 删除成功返回True + """ + try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute("DELETE FROM templates WHERE id = %s", (template_id,)) + conn.commit() + + if cursor.rowcount > 0: + logger.info(f"Deleted template: {template_id}") + return True + else: + logger.warning(f"No template found to delete: {template_id}") + return False + + except Exception as e: + logger.error(f"Failed to delete template '{template_id}': {e}") + return False + + def get_templates_by_user(self, user_id: str, include_cloud: bool = True, limit: int = 100) -> List[TemplateInfo]: + """ + 获取用户的模板列表 + + Args: + user_id: 用户ID + include_cloud: 是否包含云端公共模板 + limit: 最大返回数量 + + Returns: + 模板列表 + """ + try: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor: + if include_cloud: + sql = """ + SELECT * FROM templates + WHERE user_id = %s OR is_cloud = TRUE + ORDER BY created_at DESC + LIMIT %s + """ + cursor.execute(sql, (user_id, limit)) + else: + sql = """ + SELECT * FROM templates + WHERE user_id = %s + ORDER BY created_at DESC + LIMIT %s + """ + cursor.execute(sql, (user_id, limit)) + + rows = cursor.fetchall() + templates = [] + + for row in rows: + template_info = self._row_to_template_info(row) + templates.append(template_info) + + return templates + + except Exception as e: + logger.error(f"Failed to get templates for user '{user_id}': {e}") + return [] + + def search_templates(self, query: str, user_id: str = None, include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]: + """ + 搜索模板 + + Args: + query: 搜索关键词(匹配名称、描述、标签) + user_id: 用户ID(可选,用于过滤用户模板) + include_cloud: 是否包含云端公共模板 + limit: 最大返回数量 + + Returns: + 匹配的模板列表 + """ + try: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor: + search_pattern = f"%{query}%" + + if user_id and include_cloud: + sql = """ + SELECT * FROM templates + WHERE (user_id = %s OR is_cloud = TRUE) + AND (name ILIKE %s OR description ILIKE %s OR tags::text ILIKE %s) + ORDER BY created_at DESC + LIMIT %s + """ + cursor.execute(sql, (user_id, search_pattern, search_pattern, search_pattern, limit)) + elif user_id: + sql = """ + SELECT * FROM templates + WHERE user_id = %s + AND (name ILIKE %s OR description ILIKE %s OR tags::text ILIKE %s) + ORDER BY created_at DESC + LIMIT %s + """ + cursor.execute(sql, (user_id, search_pattern, search_pattern, search_pattern, limit)) + else: + sql = """ + SELECT * FROM templates + WHERE name ILIKE %s OR description ILIKE %s OR tags::text ILIKE %s + ORDER BY created_at DESC + LIMIT %s + """ + cursor.execute(sql, (search_pattern, search_pattern, search_pattern, limit)) + + rows = cursor.fetchall() + templates = [] + + for row in rows: + template_info = self._row_to_template_info(row) + templates.append(template_info) + + return templates + + except Exception as e: + logger.error(f"Failed to search templates with query '{query}': {e}") + return [] + + def get_templates_by_tag(self, tag: str, user_id: str = None, include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]: + """ + 根据标签获取模板 + + Args: + tag: 标签名称 + user_id: 用户ID(可选) + include_cloud: 是否包含云端公共模板 + limit: 最大返回数量 + + Returns: + 匹配的模板列表 + """ + try: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor: + tag_json = json.dumps(tag) + + if user_id and include_cloud: + sql = """ + SELECT * FROM templates + WHERE (user_id = %s OR is_cloud = TRUE) + AND tags @> %s + ORDER BY created_at DESC + LIMIT %s + """ + cursor.execute(sql, (user_id, tag_json, limit)) + elif user_id: + sql = """ + SELECT * FROM templates + WHERE user_id = %s + AND tags @> %s + ORDER BY created_at DESC + LIMIT %s + """ + cursor.execute(sql, (user_id, tag_json, limit)) + else: + sql = """ + SELECT * FROM templates + WHERE tags @> %s + ORDER BY created_at DESC + LIMIT %s + """ + cursor.execute(sql, (tag_json, limit)) + + rows = cursor.fetchall() + templates = [] + + for row in rows: + template_info = self._row_to_template_info(row) + templates.append(template_info) + + return templates + + except Exception as e: + logger.error(f"Failed to get templates by tag '{tag}': {e}") + return [] + + def get_cloud_templates(self, limit: int = 100) -> List[TemplateInfo]: + """ + 获取云端公共模板 + + Args: + limit: 最大返回数量 + + Returns: + 云端模板列表 + """ + try: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor: + sql = """ + SELECT * FROM templates + WHERE is_cloud = TRUE + ORDER BY created_at DESC + LIMIT %s + """ + cursor.execute(sql, (limit,)) + + rows = cursor.fetchall() + templates = [] + + for row in rows: + template_info = self._row_to_template_info(row) + templates.append(template_info) + + return templates + + except Exception as e: + logger.error(f"Failed to get cloud templates: {e}") + return [] + + def get_template_count(self, user_id: str = None, include_cloud: bool = True) -> int: + """ + 获取模板数量 + + Args: + user_id: 用户ID(可选) + include_cloud: 是否包含云端公共模板 + + Returns: + 模板数量 + """ + try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + if user_id is None: + cursor.execute("SELECT COUNT(*) FROM templates") + elif include_cloud: + cursor.execute( + "SELECT COUNT(*) FROM templates WHERE user_id = %s OR is_cloud = TRUE", + (user_id,) + ) + else: + cursor.execute("SELECT COUNT(*) FROM templates WHERE user_id = %s", (user_id,)) + + result = cursor.fetchone() + return result[0] if result else 0 + + except Exception as e: + logger.error(f"Failed to get template count: {e}") + return 0 + + def batch_import_templates(self, templates: List[TemplateInfo]) -> Dict[str, Any]: + """ + 批量导入模板 + + Args: + templates: 模板列表 + + Returns: + 导入结果统计 + """ + try: + success_count = 0 + failed_count = 0 + failed_templates = [] + + for template in templates: + try: + # 检查是否已存在(根据unique id属性) + existing = self.get_template_by_id(template.id) + if existing: + logger.warning(f"Template already exists, skipping: {template.name}") + continue + + # 创建模板 + self.create_template(template) + success_count += 1 + + except Exception as e: + logger.error(f"Failed to import template '{template.name}': {e}") + failed_count += 1 + failed_templates.append({ + "name": template.name, + "error": str(e) + }) + + result = { + "total": len(templates), + "success": success_count, + "failed": failed_count, + "failed_templates": failed_templates + } + + logger.info(f"Batch import completed: {success_count} success, {failed_count} failed") + return result + + except Exception as e: + logger.error(f"Failed to batch import templates: {e}") + return { + "total": len(templates), + "success": 0, + "failed": len(templates), + "error": str(e) + } + + def get_popular_tags(self, user_id: str = None, limit: int = 20) -> List[Dict[str, Any]]: + """ + 获取热门标签 + + Args: + user_id: 用户ID(可选) + limit: 最大返回数量 + + Returns: + 标签列表,包含标签名称和使用次数 + """ + try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + if user_id: + sql = """ + SELECT tag, COUNT(*) as count + FROM ( + SELECT jsonb_array_elements_text(tags) as tag + FROM templates + WHERE user_id = %s OR is_cloud = TRUE + ) t + GROUP BY tag + ORDER BY count DESC + LIMIT %s + """ + cursor.execute(sql, (user_id, limit)) + else: + sql = """ + SELECT tag, COUNT(*) as count + FROM ( + SELECT jsonb_array_elements_text(tags) as tag + FROM templates + ) t + GROUP BY tag + ORDER BY count DESC + LIMIT %s + """ + cursor.execute(sql, (limit,)) + + rows = cursor.fetchall() + popular_tags = [] + + for row in rows: + popular_tags.append({ + "tag": row[0], + "count": row[1] + }) + + return popular_tags + + except Exception as e: + logger.error(f"Failed to get popular tags: {e}") + return [] + + # 辅助方法 + def _row_to_template_info(self, row: Dict[str, Any]) -> TemplateInfo: + """将数据库行转换为TemplateInfo对象""" + template_info = TemplateInfo( + id=row['id'], + name=row['name'], + description=row['description'], + thumbnail_path=row['thumbnail_path'], + draft_content_path=row['draft_content_path'], + resources_path=row['resources_path'], + canvas_config=row['canvas_config'] if isinstance(row['canvas_config'], dict) else {}, + duration=row['duration'], + material_count=row['material_count'], + track_count=row['track_count'], + tags=row['tags'] if isinstance(row['tags'], list) else [], + is_cloud=row['is_cloud'], + user_id=row['user_id'], + created_at=row['created_at'].isoformat() if row['created_at'] else "", + updated_at=row['updated_at'].isoformat() if row['updated_at'] else "" + ) + + # 添加 draft_content 属性(动态添加) + draft_content = row.get('draft_content', {}) + if isinstance(draft_content, dict): + template_info.draft_content = draft_content + else: + template_info.draft_content = {} + + return template_info + + +# 创建全局模板表实例 +template_table = TemplateTablePostgres() diff --git a/python_core/services/template_manager_cloud.py b/python_core/services/template_manager_cloud.py index 91e1852..24b1f11 100644 --- a/python_core/services/template_manager_cloud.py +++ b/python_core/services/template_manager_cloud.py @@ -9,13 +9,12 @@ import shutil import uuid from pathlib import Path from typing import Dict, List, Any, Optional -from dataclasses import asdict from datetime import datetime from ..utils.logger import setup_logger from ..config import settings from python_core.database.types import TemplateInfo, MaterialInfo -from python_core.database.template import template_table +from python_core.database.template_postgres import template_table logger = setup_logger(__name__)