This commit is contained in:
root 2025-07-12 22:08:19 +08:00
parent 0b6958b7b6
commit c326dd1837
2 changed files with 706 additions and 2 deletions

View File

@ -0,0 +1,705 @@
# 模板表 - PostgreSQL 版本
import uuid
import json
from typing import Dict, List, Any, Optional
from datetime import datetime
from contextlib import contextmanager
from python_core.config import settings
from python_core.utils.logger import setup_logger
from .types import TemplateInfo
# 尝试导入 psycopg2如果失败则提供友好的错误信息
try:
import psycopg2
import psycopg2.extras
PSYCOPG2_AVAILABLE = True
except ImportError as e:
PSYCOPG2_AVAILABLE = False
PSYCOPG2_ERROR = str(e)
logger = setup_logger(__name__)
class TemplateTablePostgres:
"""
模板表类 - PostgreSQL 版本
基于 PostgreSQL 数据库实现的模板管理功能
"""
def __init__(self):
# 检查 psycopg2 是否可用
if not PSYCOPG2_AVAILABLE:
error_msg = f"""
PostgreSQL 驱动 psycopg2 未安装请安装
方案1推荐
pip install psycopg2-binary
方案2
# Ubuntu/Debian
sudo apt-get install postgresql libpq-dev python3-dev
pip install psycopg2
# CentOS/RHEL
sudo yum install postgresql postgresql-devel python3-devel
pip install psycopg2
# macOS
brew install postgresql
pip install psycopg2
原始错误: {PSYCOPG2_ERROR}
"""
logger.error(error_msg)
raise ImportError(error_msg)
self.db_url = settings.db
self.table_name = "templates"
# 初始化模板表
self._init_template_table()
@contextmanager
def _get_connection(self):
"""获取数据库连接的上下文管理器"""
conn = None
try:
conn = psycopg2.connect(self.db_url)
yield conn
except Exception as e:
if conn:
conn.rollback()
logger.error(f"Database connection error: {e}")
raise e
finally:
if conn:
conn.close()
def _init_template_table(self):
"""初始化模板表"""
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# 创建模板表(如果不存在)
create_table_sql = """
CREATE TABLE IF NOT EXISTS templates (
id VARCHAR(36) PRIMARY KEY,
name VARCHAR(255) NOT NULL,
description TEXT DEFAULT '',
thumbnail_path TEXT DEFAULT '',
draft_content_path TEXT DEFAULT '',
draft_content JSONB DEFAULT '{}',
resources_path TEXT DEFAULT '',
canvas_config JSONB DEFAULT '{}',
duration INTEGER DEFAULT 0,
material_count INTEGER DEFAULT 0,
track_count INTEGER DEFAULT 0,
tags JSONB DEFAULT '[]',
is_cloud BOOLEAN DEFAULT FALSE,
user_id VARCHAR(36) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
"""
cursor.execute(create_table_sql)
# 创建索引
indexes = [
"CREATE INDEX IF NOT EXISTS idx_templates_name ON templates(name);",
"CREATE INDEX IF NOT EXISTS idx_templates_user_id ON templates(user_id);",
"CREATE INDEX IF NOT EXISTS idx_templates_is_cloud ON templates(is_cloud);",
"CREATE INDEX IF NOT EXISTS idx_templates_created_at ON templates(created_at);",
"CREATE INDEX IF NOT EXISTS idx_templates_tags ON templates USING GIN(tags);",
"CREATE INDEX IF NOT EXISTS idx_templates_draft_content ON templates USING GIN(draft_content);",
"CREATE INDEX IF NOT EXISTS idx_templates_user_name ON templates(user_id, name);"
]
for index_sql in indexes:
cursor.execute(index_sql)
conn.commit()
logger.info("Template table initialized")
except Exception as e:
logger.error(f"Failed to initialize template table: {e}")
raise e
def create_template(self, template_info: TemplateInfo) -> str:
"""
创建模板
Args:
template_info: 模板信息
Returns:
模板ID
"""
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# 检查模板名称是否已存在(同一用户下)
cursor.execute(
"SELECT id FROM templates WHERE name = %s AND user_id = %s",
(template_info.name, template_info.user_id)
)
if cursor.fetchone():
raise ValueError(f"Template name '{template_info.name}' already exists for this user")
# 生成模板ID如果没有提供
template_id = template_info.id or str(uuid.uuid4())
# 插入模板记录
insert_sql = """
INSERT INTO templates (
id, name, description, thumbnail_path, draft_content_path,
draft_content, resources_path, canvas_config, duration, material_count,
track_count, tags, is_cloud, user_id, created_at, updated_at
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
now = datetime.now()
# 处理 draft_content如果 template_info 有 draft_content 属性则使用,否则为空对象
draft_content = getattr(template_info, 'draft_content', {})
cursor.execute(insert_sql, (
template_id,
template_info.name,
template_info.description,
template_info.thumbnail_path,
template_info.draft_content_path,
json.dumps(draft_content),
template_info.resources_path,
json.dumps(template_info.canvas_config),
template_info.duration,
template_info.material_count,
template_info.track_count,
json.dumps(template_info.tags),
template_info.is_cloud,
template_info.user_id,
now,
now
))
conn.commit()
logger.info(f"Created template: {template_info.name} (ID: {template_id})")
return template_id
except Exception as e:
logger.error(f"Failed to create template '{template_info.name}': {e}")
raise e
def get_template_by_id(self, template_id: str) -> Optional[TemplateInfo]:
"""
根据ID获取模板
Args:
template_id: 模板ID
Returns:
模板信息如果不存在返回None
"""
try:
with self._get_connection() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
cursor.execute("SELECT * FROM templates WHERE id = %s", (template_id,))
row = cursor.fetchone()
if row:
return self._row_to_template_info(row)
return None
except Exception as e:
logger.error(f"Failed to get template by ID '{template_id}': {e}")
return None
def get_template_by_name(self, name: str, user_id: str = None) -> Optional[TemplateInfo]:
"""
根据名称获取模板
Args:
name: 模板名称
user_id: 用户ID可选用于过滤用户模板
Returns:
模板信息如果不存在返回None
"""
try:
with self._get_connection() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
if user_id:
cursor.execute(
"SELECT * FROM templates WHERE name = %s AND user_id = %s LIMIT 1",
(name, user_id)
)
else:
cursor.execute("SELECT * FROM templates WHERE name = %s LIMIT 1", (name,))
row = cursor.fetchone()
if row:
return self._row_to_template_info(row)
return None
except Exception as e:
logger.error(f"Failed to get template by name '{name}': {e}")
return None
def update_template(self, template_id: str, updates: Dict[str, Any]) -> bool:
"""
更新模板信息
Args:
template_id: 模板ID
updates: 要更新的字段
Returns:
更新成功返回True
"""
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# 检查模板是否存在
cursor.execute("SELECT id FROM templates WHERE id = %s", (template_id,))
if not cursor.fetchone():
logger.warning(f"Template not found: {template_id}")
return False
# 移除不应该直接更新的字段
protected_fields = ["id", "created_at"]
filtered_updates = {k: v for k, v in updates.items() if k not in protected_fields}
if not filtered_updates:
logger.warning("No valid fields to update")
return False
# 添加更新时间
filtered_updates["updated_at"] = datetime.now()
# 处理 JSON 字段
if "canvas_config" in filtered_updates:
filtered_updates["canvas_config"] = json.dumps(filtered_updates["canvas_config"])
if "tags" in filtered_updates:
filtered_updates["tags"] = json.dumps(filtered_updates["tags"])
if "draft_content" in filtered_updates:
filtered_updates["draft_content"] = json.dumps(filtered_updates["draft_content"])
# 构建更新SQL
set_clauses = []
values = []
for key, value in filtered_updates.items():
set_clauses.append(f"{key} = %s")
values.append(value)
values.append(template_id) # WHERE条件的参数
update_sql = f"UPDATE templates SET {', '.join(set_clauses)} WHERE id = %s"
cursor.execute(update_sql, values)
conn.commit()
if cursor.rowcount > 0:
logger.info(f"Updated template: {template_id}")
return True
else:
logger.warning(f"No rows updated for template: {template_id}")
return False
except Exception as e:
logger.error(f"Failed to update template '{template_id}': {e}")
return False
def delete_template(self, template_id: str) -> bool:
"""
删除模板
Args:
template_id: 模板ID
Returns:
删除成功返回True
"""
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute("DELETE FROM templates WHERE id = %s", (template_id,))
conn.commit()
if cursor.rowcount > 0:
logger.info(f"Deleted template: {template_id}")
return True
else:
logger.warning(f"No template found to delete: {template_id}")
return False
except Exception as e:
logger.error(f"Failed to delete template '{template_id}': {e}")
return False
def get_templates_by_user(self, user_id: str, include_cloud: bool = True, limit: int = 100) -> List[TemplateInfo]:
"""
获取用户的模板列表
Args:
user_id: 用户ID
include_cloud: 是否包含云端公共模板
limit: 最大返回数量
Returns:
模板列表
"""
try:
with self._get_connection() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
if include_cloud:
sql = """
SELECT * FROM templates
WHERE user_id = %s OR is_cloud = TRUE
ORDER BY created_at DESC
LIMIT %s
"""
cursor.execute(sql, (user_id, limit))
else:
sql = """
SELECT * FROM templates
WHERE user_id = %s
ORDER BY created_at DESC
LIMIT %s
"""
cursor.execute(sql, (user_id, limit))
rows = cursor.fetchall()
templates = []
for row in rows:
template_info = self._row_to_template_info(row)
templates.append(template_info)
return templates
except Exception as e:
logger.error(f"Failed to get templates for user '{user_id}': {e}")
return []
def search_templates(self, query: str, user_id: str = None, include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]:
"""
搜索模板
Args:
query: 搜索关键词匹配名称描述标签
user_id: 用户ID可选用于过滤用户模板
include_cloud: 是否包含云端公共模板
limit: 最大返回数量
Returns:
匹配的模板列表
"""
try:
with self._get_connection() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
search_pattern = f"%{query}%"
if user_id and include_cloud:
sql = """
SELECT * FROM templates
WHERE (user_id = %s OR is_cloud = TRUE)
AND (name ILIKE %s OR description ILIKE %s OR tags::text ILIKE %s)
ORDER BY created_at DESC
LIMIT %s
"""
cursor.execute(sql, (user_id, search_pattern, search_pattern, search_pattern, limit))
elif user_id:
sql = """
SELECT * FROM templates
WHERE user_id = %s
AND (name ILIKE %s OR description ILIKE %s OR tags::text ILIKE %s)
ORDER BY created_at DESC
LIMIT %s
"""
cursor.execute(sql, (user_id, search_pattern, search_pattern, search_pattern, limit))
else:
sql = """
SELECT * FROM templates
WHERE name ILIKE %s OR description ILIKE %s OR tags::text ILIKE %s
ORDER BY created_at DESC
LIMIT %s
"""
cursor.execute(sql, (search_pattern, search_pattern, search_pattern, limit))
rows = cursor.fetchall()
templates = []
for row in rows:
template_info = self._row_to_template_info(row)
templates.append(template_info)
return templates
except Exception as e:
logger.error(f"Failed to search templates with query '{query}': {e}")
return []
def get_templates_by_tag(self, tag: str, user_id: str = None, include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]:
"""
根据标签获取模板
Args:
tag: 标签名称
user_id: 用户ID可选
include_cloud: 是否包含云端公共模板
limit: 最大返回数量
Returns:
匹配的模板列表
"""
try:
with self._get_connection() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
tag_json = json.dumps(tag)
if user_id and include_cloud:
sql = """
SELECT * FROM templates
WHERE (user_id = %s OR is_cloud = TRUE)
AND tags @> %s
ORDER BY created_at DESC
LIMIT %s
"""
cursor.execute(sql, (user_id, tag_json, limit))
elif user_id:
sql = """
SELECT * FROM templates
WHERE user_id = %s
AND tags @> %s
ORDER BY created_at DESC
LIMIT %s
"""
cursor.execute(sql, (user_id, tag_json, limit))
else:
sql = """
SELECT * FROM templates
WHERE tags @> %s
ORDER BY created_at DESC
LIMIT %s
"""
cursor.execute(sql, (tag_json, limit))
rows = cursor.fetchall()
templates = []
for row in rows:
template_info = self._row_to_template_info(row)
templates.append(template_info)
return templates
except Exception as e:
logger.error(f"Failed to get templates by tag '{tag}': {e}")
return []
def get_cloud_templates(self, limit: int = 100) -> List[TemplateInfo]:
"""
获取云端公共模板
Args:
limit: 最大返回数量
Returns:
云端模板列表
"""
try:
with self._get_connection() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
sql = """
SELECT * FROM templates
WHERE is_cloud = TRUE
ORDER BY created_at DESC
LIMIT %s
"""
cursor.execute(sql, (limit,))
rows = cursor.fetchall()
templates = []
for row in rows:
template_info = self._row_to_template_info(row)
templates.append(template_info)
return templates
except Exception as e:
logger.error(f"Failed to get cloud templates: {e}")
return []
def get_template_count(self, user_id: str = None, include_cloud: bool = True) -> int:
"""
获取模板数量
Args:
user_id: 用户ID可选
include_cloud: 是否包含云端公共模板
Returns:
模板数量
"""
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
if user_id is None:
cursor.execute("SELECT COUNT(*) FROM templates")
elif include_cloud:
cursor.execute(
"SELECT COUNT(*) FROM templates WHERE user_id = %s OR is_cloud = TRUE",
(user_id,)
)
else:
cursor.execute("SELECT COUNT(*) FROM templates WHERE user_id = %s", (user_id,))
result = cursor.fetchone()
return result[0] if result else 0
except Exception as e:
logger.error(f"Failed to get template count: {e}")
return 0
def batch_import_templates(self, templates: List[TemplateInfo]) -> Dict[str, Any]:
"""
批量导入模板
Args:
templates: 模板列表
Returns:
导入结果统计
"""
try:
success_count = 0
failed_count = 0
failed_templates = []
for template in templates:
try:
# 检查是否已存在根据unique id属性
existing = self.get_template_by_id(template.id)
if existing:
logger.warning(f"Template already exists, skipping: {template.name}")
continue
# 创建模板
self.create_template(template)
success_count += 1
except Exception as e:
logger.error(f"Failed to import template '{template.name}': {e}")
failed_count += 1
failed_templates.append({
"name": template.name,
"error": str(e)
})
result = {
"total": len(templates),
"success": success_count,
"failed": failed_count,
"failed_templates": failed_templates
}
logger.info(f"Batch import completed: {success_count} success, {failed_count} failed")
return result
except Exception as e:
logger.error(f"Failed to batch import templates: {e}")
return {
"total": len(templates),
"success": 0,
"failed": len(templates),
"error": str(e)
}
def get_popular_tags(self, user_id: str = None, limit: int = 20) -> List[Dict[str, Any]]:
"""
获取热门标签
Args:
user_id: 用户ID可选
limit: 最大返回数量
Returns:
标签列表包含标签名称和使用次数
"""
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
if user_id:
sql = """
SELECT tag, COUNT(*) as count
FROM (
SELECT jsonb_array_elements_text(tags) as tag
FROM templates
WHERE user_id = %s OR is_cloud = TRUE
) t
GROUP BY tag
ORDER BY count DESC
LIMIT %s
"""
cursor.execute(sql, (user_id, limit))
else:
sql = """
SELECT tag, COUNT(*) as count
FROM (
SELECT jsonb_array_elements_text(tags) as tag
FROM templates
) t
GROUP BY tag
ORDER BY count DESC
LIMIT %s
"""
cursor.execute(sql, (limit,))
rows = cursor.fetchall()
popular_tags = []
for row in rows:
popular_tags.append({
"tag": row[0],
"count": row[1]
})
return popular_tags
except Exception as e:
logger.error(f"Failed to get popular tags: {e}")
return []
# 辅助方法
def _row_to_template_info(self, row: Dict[str, Any]) -> TemplateInfo:
"""将数据库行转换为TemplateInfo对象"""
template_info = TemplateInfo(
id=row['id'],
name=row['name'],
description=row['description'],
thumbnail_path=row['thumbnail_path'],
draft_content_path=row['draft_content_path'],
resources_path=row['resources_path'],
canvas_config=row['canvas_config'] if isinstance(row['canvas_config'], dict) else {},
duration=row['duration'],
material_count=row['material_count'],
track_count=row['track_count'],
tags=row['tags'] if isinstance(row['tags'], list) else [],
is_cloud=row['is_cloud'],
user_id=row['user_id'],
created_at=row['created_at'].isoformat() if row['created_at'] else "",
updated_at=row['updated_at'].isoformat() if row['updated_at'] else ""
)
# 添加 draft_content 属性(动态添加)
draft_content = row.get('draft_content', {})
if isinstance(draft_content, dict):
template_info.draft_content = draft_content
else:
template_info.draft_content = {}
return template_info
# 创建全局模板表实例
template_table = TemplateTablePostgres()

View File

@ -9,13 +9,12 @@ import shutil
import uuid
from pathlib import Path
from typing import Dict, List, Any, Optional
from dataclasses import asdict
from datetime import datetime
from ..utils.logger import setup_logger
from ..config import settings
from python_core.database.types import TemplateInfo, MaterialInfo
from python_core.database.template import template_table
from python_core.database.template_postgres import template_table
logger = setup_logger(__name__)