mxivideo/python_core/database/template_postgres.py

809 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 模板表 - PostgreSQL 版本
import uuid
import json
from typing import Dict, List, Any, Optional
from datetime import datetime
from contextlib import contextmanager
from python_core.config import settings
from python_core.utils.logger import setup_logger
from .types import TemplateInfo
# Try to import psycopg2. If it is missing, remember the error so that
# TemplateTablePostgres.__init__ can raise a friendly installation message
# instead of failing at import time.
try:
    import psycopg2
    import psycopg2.extras
    PSYCOPG2_AVAILABLE = True
    # Define the error name on success too, so the module attribute always exists.
    PSYCOPG2_ERROR = None
except ImportError as e:
    PSYCOPG2_AVAILABLE = False
    PSYCOPG2_ERROR = str(e)
logger = setup_logger(__name__)
class TemplateTablePostgres:
    """
    Template table - PostgreSQL implementation.

    Template management (CRUD, search, tags, draft content) backed by the
    ``templates`` table of the PostgreSQL database configured by
    ``settings.db``. Every public method opens its own short-lived connection
    through ``_get_connection``.
    """

    # Columns update_template() may write. Column names are interpolated into
    # the UPDATE statement, so they must come from this fixed whitelist and
    # never directly from caller-supplied dict keys (SQL injection guard).
    _UPDATABLE_COLUMNS = frozenset({
        "name",
        "description",
        "thumbnail_path",
        "draft_content_path",
        "draft_content",
        "resources_path",
        "canvas_config",
        "duration",
        "material_count",
        "track_count",
        "tags",
        "is_cloud",
        "user_id",
        "updated_at",
    })

    # JSONB columns whose values are serialized with json.dumps before writing.
    _JSON_COLUMNS = ("canvas_config", "tags", "draft_content")

    def __init__(self):
        """Check the driver, read the DB URL and make sure the table exists.

        Raises:
            ImportError: If psycopg2 could not be imported.
        """
        # Fail fast with installation instructions when psycopg2 is missing.
        if not PSYCOPG2_AVAILABLE:
            error_msg = f"""
PostgreSQL 驱动 psycopg2 未安装。请安装:
方案1推荐
pip install psycopg2-binary
方案2
# Ubuntu/Debian
sudo apt-get install postgresql libpq-dev python3-dev
pip install psycopg2
# CentOS/RHEL
sudo yum install postgresql postgresql-devel python3-devel
pip install psycopg2
# macOS
brew install postgresql
pip install psycopg2
原始错误: {PSYCOPG2_ERROR}
"""
            logger.error(error_msg)
            raise ImportError(error_msg)
        self.db_url = settings.db
        self.table_name = "templates"
        # Create the template table and its indexes if needed.
        self._init_template_table()

    @contextmanager
    def _get_connection(self):
        """Yield a new database connection and always close it afterwards.

        If the managed block raises, the transaction is rolled back, the error
        is logged and then re-raised. Callers are responsible for committing.
        """
        conn = None
        try:
            conn = psycopg2.connect(self.db_url)
            yield conn
        except Exception as e:
            if conn is not None:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    # A failing rollback (e.g. broken connection) must not
                    # mask the original error.
                    logger.warning(f"Rollback failed: {rollback_error}")
            logger.error(f"Database connection error: {e}")
            raise
        finally:
            if conn is not None:
                conn.close()

    def _init_template_table(self):
        """Create the ``templates`` table and its indexes if they do not exist.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # Create the table (if it does not exist).
                    create_table_sql = """
                    CREATE TABLE IF NOT EXISTS templates (
                        id VARCHAR(36) PRIMARY KEY,
                        name VARCHAR(255) NOT NULL,
                        description TEXT DEFAULT '',
                        thumbnail_path TEXT DEFAULT '',
                        draft_content_path TEXT DEFAULT '',
                        draft_content JSONB DEFAULT '{}',
                        resources_path TEXT DEFAULT '',
                        canvas_config JSONB DEFAULT '{}',
                        duration INTEGER DEFAULT 0,
                        material_count INTEGER DEFAULT 0,
                        track_count INTEGER DEFAULT 0,
                        tags JSONB DEFAULT '[]',
                        is_cloud BOOLEAN DEFAULT FALSE,
                        user_id VARCHAR(36) NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    );
                    """
                    cursor.execute(create_table_sql)
                    # Create indexes (GIN indexes back the JSONB containment queries).
                    indexes = [
                        "CREATE INDEX IF NOT EXISTS idx_templates_name ON templates(name);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_user_id ON templates(user_id);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_is_cloud ON templates(is_cloud);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_created_at ON templates(created_at);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_tags ON templates USING GIN(tags);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_draft_content ON templates USING GIN(draft_content);",
                        "CREATE INDEX IF NOT EXISTS idx_templates_user_name ON templates(user_id, name);"
                    ]
                    for index_sql in indexes:
                        cursor.execute(index_sql)
                    conn.commit()
                    logger.info("Template table initialized")
        except Exception as e:
            logger.error(f"Failed to initialize template table: {e}")
            raise

    def create_template(self, template_info: TemplateInfo) -> str:
        """Insert a new template.

        Args:
            template_info: Template to store. When ``id`` is empty a UUID4 is
                generated.

        Returns:
            The ID of the stored template.

        Raises:
            ValueError: If the same user already owns a template with this name.
            Exception: Any database error is logged and re-raised.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # Names are unique per user. This is checked here; there is
                    # no DB constraint enforcing it.
                    cursor.execute(
                        "SELECT id FROM templates WHERE name = %s AND user_id = %s",
                        (template_info.name, template_info.user_id)
                    )
                    if cursor.fetchone():
                        raise ValueError(f"Template name '{template_info.name}' already exists for this user")
                    # Generate an ID if none was supplied.
                    template_id = template_info.id or str(uuid.uuid4())
                    insert_sql = """
                    INSERT INTO templates (
                        id, name, description, thumbnail_path, draft_content_path,
                        draft_content, resources_path, canvas_config, duration, material_count,
                        track_count, tags, is_cloud, user_id, created_at, updated_at
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    """
                    now = datetime.now()
                    # draft_content is attached dynamically elsewhere in this
                    # class, so TemplateInfo may not carry it; read it defensively.
                    draft_content = getattr(template_info, "draft_content", None)
                    if draft_content is None:
                        draft_content = {}
                    # Store empty containers rather than JSON null, so JSONB
                    # queries such as jsonb_array_elements_text(tags) never
                    # see a scalar.
                    canvas_config = template_info.canvas_config
                    if canvas_config is None:
                        canvas_config = {}
                    tags = template_info.tags
                    if tags is None:
                        tags = []
                    key_count = len(draft_content) if hasattr(draft_content, "__len__") else 0
                    logger.debug(f"Storing draft_content for template {template_id}: {type(draft_content)} with {key_count} keys")
                    cursor.execute(insert_sql, (
                        template_id,
                        template_info.name,
                        template_info.description,
                        template_info.thumbnail_path,
                        template_info.draft_content_path,
                        json.dumps(draft_content),
                        template_info.resources_path,
                        json.dumps(canvas_config),
                        template_info.duration,
                        template_info.material_count,
                        template_info.track_count,
                        json.dumps(tags),
                        template_info.is_cloud,
                        template_info.user_id,
                        now,
                        now
                    ))
                    conn.commit()
                    logger.info(f"Created template: {template_info.name} (ID: {template_id})")
                    return template_id
        except Exception as e:
            logger.error(f"Failed to create template '{template_info.name}': {e}")
            raise

    def get_template_by_id(self, template_id: str) -> Optional[TemplateInfo]:
        """Fetch a template by ID.

        Args:
            template_id: Template ID.

        Returns:
            The template, or None if it does not exist or the query failed
            (errors are logged, not raised).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    cursor.execute("SELECT * FROM templates WHERE id = %s", (template_id,))
                    row = cursor.fetchone()
                    if row:
                        return self._row_to_template_info(row)
                    return None
        except Exception as e:
            logger.error(f"Failed to get template by ID '{template_id}': {e}")
            return None

    def get_template_by_name(self, name: str, user_id: str = None) -> Optional[TemplateInfo]:
        """Fetch the first template with the given name.

        Args:
            name: Template name.
            user_id: Optional owner filter. When falsy, any user's template matches.

        Returns:
            The template, or None if not found or the query failed.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    if user_id:
                        cursor.execute(
                            "SELECT * FROM templates WHERE name = %s AND user_id = %s LIMIT 1",
                            (name, user_id)
                        )
                    else:
                        cursor.execute("SELECT * FROM templates WHERE name = %s LIMIT 1", (name,))
                    row = cursor.fetchone()
                    if row:
                        return self._row_to_template_info(row)
                    return None
        except Exception as e:
            logger.error(f"Failed to get template by name '{name}': {e}")
            return None

    def update_template(self, template_id: str, updates: Dict[str, Any]) -> bool:
        """Update selected columns of a template.

        ``id`` and ``created_at`` are ignored; ``updated_at`` is always set to
        now. Keys must be column names from ``_UPDATABLE_COLUMNS``; any other
        key rejects the whole update.

        Args:
            template_id: Template ID.
            updates: Mapping of column name to new value. JSONB columns take
                Python objects, which are serialized here.

        Returns:
            True if a row was updated, False otherwise (including on errors).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # Make sure the template exists.
                    cursor.execute("SELECT id FROM templates WHERE id = %s", (template_id,))
                    if not cursor.fetchone():
                        logger.warning(f"Template not found: {template_id}")
                        return False
                    # Drop fields that must never be changed directly.
                    protected_fields = ["id", "created_at"]
                    filtered_updates = {k: v for k, v in updates.items() if k not in protected_fields}
                    if not filtered_updates:
                        logger.warning("No valid fields to update")
                        return False
                    # Keys become SQL identifiers below, so only known columns
                    # are allowed.
                    unknown_fields = set(filtered_updates) - self._UPDATABLE_COLUMNS
                    if unknown_fields:
                        logger.warning(f"Refusing to update unknown template fields: {sorted(map(str, unknown_fields))}")
                        return False
                    filtered_updates["updated_at"] = datetime.now()
                    # Serialize JSONB columns.
                    for column in self._JSON_COLUMNS:
                        if column in filtered_updates:
                            filtered_updates[column] = json.dumps(filtered_updates[column])
                    set_clause = ", ".join(f"{column} = %s" for column in filtered_updates)
                    values = list(filtered_updates.values())
                    values.append(template_id)  # parameter for the WHERE clause
                    update_sql = f"UPDATE templates SET {set_clause} WHERE id = %s"
                    cursor.execute(update_sql, values)
                    conn.commit()
                    if cursor.rowcount > 0:
                        logger.info(f"Updated template: {template_id}")
                        return True
                    logger.warning(f"No rows updated for template: {template_id}")
                    return False
        except Exception as e:
            logger.error(f"Failed to update template '{template_id}': {e}")
            return False

    def delete_template(self, template_id: str) -> bool:
        """Delete a template.

        Args:
            template_id: Template ID.

        Returns:
            True if a row was deleted, False otherwise (including on errors).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("DELETE FROM templates WHERE id = %s", (template_id,))
                    conn.commit()
                    if cursor.rowcount > 0:
                        logger.info(f"Deleted template: {template_id}")
                        return True
                    logger.warning(f"No template found to delete: {template_id}")
                    return False
        except Exception as e:
            logger.error(f"Failed to delete template '{template_id}': {e}")
            return False

    def get_templates_by_user(self, user_id: str, include_cloud: bool = True, limit: int = 100) -> List[TemplateInfo]:
        """List a user's templates, newest first.

        Args:
            user_id: Owner ID.
            include_cloud: Also include templates flagged ``is_cloud``.
            limit: Maximum number of results.

        Returns:
            Matching templates; an empty list on error.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    if include_cloud:
                        sql = """
                        SELECT * FROM templates
                        WHERE user_id = %s OR is_cloud = TRUE
                        ORDER BY created_at DESC
                        LIMIT %s
                        """
                    else:
                        sql = """
                        SELECT * FROM templates
                        WHERE user_id = %s
                        ORDER BY created_at DESC
                        LIMIT %s
                        """
                    cursor.execute(sql, (user_id, limit))
                    return [self._row_to_template_info(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Failed to get templates for user '{user_id}': {e}")
            return []

    @staticmethod
    def _escape_like(value: Any) -> str:
        """Escape LIKE/ILIKE wildcards (``\\``, ``%``, ``_``) so *value* matches literally.

        Relies on PostgreSQL's default LIKE escape character, backslash.
        """
        return str(value).replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    def search_templates(self, query: str, user_id: str = None, include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]:
        """Case-insensitive substring search over name, description and tags.

        Args:
            query: Search text. It is matched literally; ``%`` and ``_`` are
                not wildcards.
            user_id: Optional owner filter. When falsy, all templates are searched.
            include_cloud: With a user_id, also include ``is_cloud`` templates.
            limit: Maximum number of results.

        Returns:
            Matching templates, newest first; an empty list on error.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    # Escape user input so LIKE wildcards in it are literal.
                    search_pattern = f"%{self._escape_like(query)}%"
                    if user_id and include_cloud:
                        sql = """
                        SELECT * FROM templates
                        WHERE (user_id = %s OR is_cloud = TRUE)
                        AND (name ILIKE %s OR description ILIKE %s OR tags::text ILIKE %s)
                        ORDER BY created_at DESC
                        LIMIT %s
                        """
                        cursor.execute(sql, (user_id, search_pattern, search_pattern, search_pattern, limit))
                    elif user_id:
                        sql = """
                        SELECT * FROM templates
                        WHERE user_id = %s
                        AND (name ILIKE %s OR description ILIKE %s OR tags::text ILIKE %s)
                        ORDER BY created_at DESC
                        LIMIT %s
                        """
                        cursor.execute(sql, (user_id, search_pattern, search_pattern, search_pattern, limit))
                    else:
                        sql = """
                        SELECT * FROM templates
                        WHERE name ILIKE %s OR description ILIKE %s OR tags::text ILIKE %s
                        ORDER BY created_at DESC
                        LIMIT %s
                        """
                        cursor.execute(sql, (search_pattern, search_pattern, search_pattern, limit))
                    return [self._row_to_template_info(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Failed to search templates with query '{query}': {e}")
            return []

    def get_templates_by_tag(self, tag: str, user_id: str = None, include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]:
        """List templates whose ``tags`` JSONB contains *tag*.

        Args:
            tag: Tag name.
            user_id: Optional owner filter.
            include_cloud: With a user_id, also include ``is_cloud`` templates.
            limit: Maximum number of results.

        Returns:
            Matching templates, newest first; an empty list on error.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    # A JSON string scalar; jsonb "@>" lets an array contain a scalar.
                    tag_json = json.dumps(tag)
                    if user_id and include_cloud:
                        sql = """
                        SELECT * FROM templates
                        WHERE (user_id = %s OR is_cloud = TRUE)
                        AND tags @> %s
                        ORDER BY created_at DESC
                        LIMIT %s
                        """
                        cursor.execute(sql, (user_id, tag_json, limit))
                    elif user_id:
                        sql = """
                        SELECT * FROM templates
                        WHERE user_id = %s
                        AND tags @> %s
                        ORDER BY created_at DESC
                        LIMIT %s
                        """
                        cursor.execute(sql, (user_id, tag_json, limit))
                    else:
                        sql = """
                        SELECT * FROM templates
                        WHERE tags @> %s
                        ORDER BY created_at DESC
                        LIMIT %s
                        """
                        cursor.execute(sql, (tag_json, limit))
                    return [self._row_to_template_info(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Failed to get templates by tag '{tag}': {e}")
            return []

    def get_cloud_templates(self, limit: int = 100) -> List[TemplateInfo]:
        """List templates flagged ``is_cloud``, newest first.

        Args:
            limit: Maximum number of results.

        Returns:
            Cloud templates; an empty list on error.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    sql = """
                    SELECT * FROM templates
                    WHERE is_cloud = TRUE
                    ORDER BY created_at DESC
                    LIMIT %s
                    """
                    cursor.execute(sql, (limit,))
                    return [self._row_to_template_info(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Failed to get cloud templates: {e}")
            return []

    def get_template_count(self, user_id: str = None, include_cloud: bool = True) -> int:
        """Count templates.

        Args:
            user_id: Optional owner filter. When None, all templates are counted.
            include_cloud: With a user_id, also count ``is_cloud`` templates.

        Returns:
            The count; 0 on error.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    if user_id is None:
                        cursor.execute("SELECT COUNT(*) FROM templates")
                    elif include_cloud:
                        cursor.execute(
                            "SELECT COUNT(*) FROM templates WHERE user_id = %s OR is_cloud = TRUE",
                            (user_id,)
                        )
                    else:
                        cursor.execute("SELECT COUNT(*) FROM templates WHERE user_id = %s", (user_id,))
                    result = cursor.fetchone()
                    return result[0] if result else 0
        except Exception as e:
            logger.error(f"Failed to get template count: {e}")
            return 0

    def batch_import_templates(self, templates: List[TemplateInfo]) -> Dict[str, Any]:
        """Import templates one by one, skipping IDs that already exist.

        Args:
            templates: Templates to import.

        Returns:
            Statistics with keys ``total``, ``success``, ``failed``,
            ``skipped`` and ``failed_templates`` (list of ``{"name", "error"}``).
            If the batch itself fails, a dict with ``error`` is returned instead.
        """
        try:
            success_count = 0
            failed_count = 0
            skipped_count = 0
            failed_templates = []
            for template in templates:
                try:
                    # Skip templates whose ID is already stored. An empty ID
                    # cannot exist yet; create_template will generate one.
                    existing = self.get_template_by_id(template.id) if template.id else None
                    if existing:
                        logger.warning(f"Template already exists, skipping: {template.name}")
                        skipped_count += 1
                        continue
                    self.create_template(template)
                    success_count += 1
                except Exception as e:
                    logger.error(f"Failed to import template '{template.name}': {e}")
                    failed_count += 1
                    failed_templates.append({
                        "name": template.name,
                        "error": str(e)
                    })
            result = {
                "total": len(templates),
                "success": success_count,
                "failed": failed_count,
                "skipped": skipped_count,
                "failed_templates": failed_templates
            }
            logger.info(f"Batch import completed: {success_count} success, {failed_count} failed")
            return result
        except Exception as e:
            logger.error(f"Failed to batch import templates: {e}")
            return {
                "total": len(templates),
                "success": 0,
                "failed": len(templates),
                "failed_templates": [],
                "error": str(e)
            }

    def get_popular_tags(self, user_id: str = None, limit: int = 20) -> List[Dict[str, Any]]:
        """Return the most frequently used tags.

        Args:
            user_id: When given, only that user's templates plus ``is_cloud``
                templates are counted.
            limit: Maximum number of tags.

        Returns:
            List of ``{"tag": str, "count": int}``, most used first; an empty
            list on error.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # jsonb_array_elements_text() raises on non-array values
                    # (e.g. JSON null), so only array rows are expanded.
                    if user_id:
                        sql = """
                        SELECT tag, COUNT(*) as count
                        FROM (
                            SELECT jsonb_array_elements_text(tags) as tag
                            FROM templates
                            WHERE (user_id = %s OR is_cloud = TRUE)
                            AND jsonb_typeof(tags) = 'array'
                        ) t
                        GROUP BY tag
                        ORDER BY count DESC
                        LIMIT %s
                        """
                        cursor.execute(sql, (user_id, limit))
                    else:
                        sql = """
                        SELECT tag, COUNT(*) as count
                        FROM (
                            SELECT jsonb_array_elements_text(tags) as tag
                            FROM templates
                            WHERE jsonb_typeof(tags) = 'array'
                        ) t
                        GROUP BY tag
                        ORDER BY count DESC
                        LIMIT %s
                        """
                        cursor.execute(sql, (limit,))
                    return [{"tag": row[0], "count": row[1]} for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Failed to get popular tags: {e}")
            return []

    # Helper methods

    def _row_to_template_info(self, row: Dict[str, Any]) -> TemplateInfo:
        """Convert a RealDictCursor row into a TemplateInfo.

        Non-dict ``canvas_config`` and non-list ``tags`` become empty
        containers. Timestamps become ISO strings ("" when NULL).
        ``draft_content`` is attached as an extra attribute and falls back
        to ``{}``.
        """
        template_info = TemplateInfo(
            id=row['id'],
            name=row['name'],
            description=row['description'],
            thumbnail_path=row['thumbnail_path'],
            draft_content_path=row['draft_content_path'],
            resources_path=row['resources_path'],
            canvas_config=row['canvas_config'] if isinstance(row['canvas_config'], dict) else {},
            duration=row['duration'],
            material_count=row['material_count'],
            track_count=row['track_count'],
            tags=row['tags'] if isinstance(row['tags'], list) else [],
            is_cloud=row['is_cloud'],
            user_id=row['user_id'],
            created_at=row['created_at'].isoformat() if row['created_at'] else "",
            updated_at=row['updated_at'].isoformat() if row['updated_at'] else ""
        )
        # Attach draft_content dynamically (not passed to the constructor).
        draft_content = row.get('draft_content', {})
        if isinstance(draft_content, dict):
            template_info.draft_content = draft_content
        elif isinstance(draft_content, str):
            # A string value is parsed as JSON.
            try:
                template_info.draft_content = json.loads(draft_content)
            except (json.JSONDecodeError, TypeError):
                logger.warning(f"Failed to parse draft_content JSON for template {row['id']}: {draft_content}")
                template_info.draft_content = {}
        else:
            template_info.draft_content = {}
        return template_info

    def update_draft_content(self, template_id: str, draft_content: Dict[str, Any]) -> bool:
        """Replace a template's draft content.

        Args:
            template_id: Template ID.
            draft_content: New draft content (stored as JSONB).

        Returns:
            True on success, False otherwise.
        """
        try:
            return self.update_template(template_id, {"draft_content": draft_content})
        except Exception as e:
            logger.error(f"Failed to update draft content for template '{template_id}': {e}")
            return False

    def get_draft_content(self, template_id: str) -> Optional[Dict[str, Any]]:
        """Return a template's draft content.

        Args:
            template_id: Template ID.

        Returns:
            The draft content, or None if the template is not found.
        """
        try:
            template = self.get_template_by_id(template_id)
            if template and hasattr(template, 'draft_content'):
                return template.draft_content
            return None
        except Exception as e:
            logger.error(f"Failed to get draft content for template '{template_id}': {e}")
            return None

    def save_draft_content_from_file(self, template_id: str, draft_file_path: str) -> bool:
        """Load draft content from a JSON file and store it on the template.

        Args:
            template_id: Template ID.
            draft_file_path: Path to a UTF-8 JSON file.

        Returns:
            True on success, False if the file is missing or anything fails.
        """
        try:
            import os
            if not os.path.exists(draft_file_path):
                logger.error(f"Draft content file not found: {draft_file_path}")
                return False
            with open(draft_file_path, 'r', encoding='utf-8') as f:
                draft_content = json.load(f)
            return self.update_draft_content(template_id, draft_content)
        except Exception as e:
            logger.error(f"Failed to save draft content from file '{draft_file_path}': {e}")
            return False

    def export_draft_content_to_file(self, template_id: str, output_path: str) -> bool:
        """Write a template's draft content to a UTF-8 JSON file.

        Parent directories are created when the path has any.

        Args:
            template_id: Template ID.
            output_path: Destination file path.

        Returns:
            True on success, False if there is no draft content or writing fails.
        """
        try:
            draft_content = self.get_draft_content(template_id)
            if draft_content is None:
                logger.error(f"No draft content found for template: {template_id}")
                return False
            import os
            output_dir = os.path.dirname(output_path)
            # dirname() is "" for a bare filename, and os.makedirs("") raises,
            # so only create directories when there are some.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(draft_content, f, ensure_ascii=False, indent=2)
            logger.info(f"Draft content exported to: {output_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to export draft content to file '{output_path}': {e}")
            return False
# Create the global template table instance.
# NOTE(review): this runs at import time, so importing this module calls
# TemplateTablePostgres(), which connects to the database and runs
# CREATE TABLE/INDEX, and raises if psycopg2 is missing.
template_table = TemplateTablePostgres()