502 lines
16 KiB
Python
502 lines
16 KiB
Python
# 模板表
|
||
|
||
import uuid
|
||
import json
|
||
from typing import Dict, List, Any, Optional
|
||
from datetime import datetime
|
||
from .db import Db
|
||
from .types import TemplateInfo
|
||
from python_core.utils.logger import setup_logger
|
||
|
||
# Module-level logger, created by the project's setup_logger helper and
# named after this module.
logger = setup_logger(__name__)
|
||
|
||
class TemplateTable(Db):
    """Template table.

    Template management built on the ``Db`` base class. Templates live in the
    ``templates`` table of the database opened with the ``"mixvideo"`` key. A
    stored record exposes its id as ``record["id"]`` and the template fields
    under ``record["data"]`` (see ``_record_to_template_info``).
    """

    def __init__(self):
        # Use the fixed database key.
        super().__init__("mixvideo")
        self.table_name = "templates"

        # Create the template table if it is missing.
        self._init_template_table()

    def _init_template_table(self):
        """Create the template table if it does not exist yet.

        Raises:
            Exception: Whatever ``_table_exists`` / ``create_table`` raised;
                the error is logged and re-raised unchanged.
        """
        try:
            # Only create the table when it is not there yet.
            if not self._table_exists(self.table_name):
                schema = {
                    "name": "string",
                    "description": "string",
                    "thumbnail_path": "string",
                    "draft_content_path": "string",
                    "resources_path": "string",
                    "canvas_config": "json",
                    "duration": "integer",
                    "material_count": "integer",
                    "track_count": "integer",
                    "tags": "json",
                    "is_cloud": "boolean",
                    "user_id": "string",
                    "created_at": "datetime",
                    "updated_at": "datetime",
                }

                self.create_table(self.table_name, schema)
                logger.info("Template table initialized")
            else:
                logger.info("Template table already exists")

        except Exception as e:
            logger.error(f"Failed to initialize template table: {e}")
            # Bare raise keeps the original traceback.
            raise

    def create_template(self, template_info: TemplateInfo) -> str:
        """Create a template.

        Args:
            template_info: Template information. ``template_info.id`` is not
                stored; the id of the new record is whatever ``insert`` returns.

        Returns:
            The new template id, as returned by ``insert``.

        Raises:
            ValueError: If the same user already has a template with this name.
            Exception: Any other error from the lookup or the insert; it is
                logged and re-raised.
        """
        try:
            # Template names must be unique per user.
            existing_template = self.get_template_by_name(template_info.name, template_info.user_id)
            if existing_template:
                raise ValueError(f"Template name '{template_info.name}' already exists for this user")

            # One timestamp for both fields so a fresh record has
            # created_at == updated_at.
            now = datetime.now().isoformat()

            template_data = self._template_info_to_dict(template_info)
            # The record id is assigned by insert(); it is not part of the data.
            del template_data["id"]
            template_data["created_at"] = now
            template_data["updated_at"] = now

            template_id = self.insert(self.table_name, template_data)

            logger.info(f"Created template: {template_info.name} (ID: {template_id})")
            return template_id

        except Exception as e:
            logger.error(f"Failed to create template '{template_info.name}': {e}")
            raise

    def get_template_by_id(self, template_id: str) -> Optional[TemplateInfo]:
        """Get a template by id.

        Args:
            template_id: Template id.

        Returns:
            The template, or None if it does not exist or the lookup failed
            (failures are logged).
        """
        try:
            record = self.get(self.table_name, template_id)
            if record:
                return self._record_to_template_info(record)
            return None

        except Exception as e:
            logger.error(f"Failed to get template by ID '{template_id}': {e}")
            return None

    def get_template_by_name(self, name: str, user_id: Optional[str] = None) -> Optional[TemplateInfo]:
        """Get a template by name.

        Args:
            name: Template name (exact match via ``find_by_field``).
            user_id: Optional user id. When given, only that user's templates
                match; when None, the first template with this name is returned
                regardless of owner.

        Returns:
            The template, or None if there is no match or the lookup failed
            (failures are logged).
        """
        try:
            # All templates with this name, from any user.
            templates = self.find_by_field(self.table_name, "name", name)

            for record in templates:
                template_info = self._record_to_template_info(record)
                if user_id is None or template_info.user_id == user_id:
                    return template_info

            return None

        except Exception as e:
            logger.error(f"Failed to get template by name '{name}': {e}")
            return None

    def update_template(self, template_id: str, updates: Dict[str, Any]) -> bool:
        """Update fields of an existing template.

        Args:
            template_id: Template id.
            updates: Fields to change. ``id``, ``created_at`` and ``updated_at``
                are ignored; ``updated_at`` is always set to the current time.
                The caller's dict is not modified.

        Returns:
            The result of ``update`` (truthy on success); False if the template
            does not exist or an error occurred (errors are logged).
        """
        try:
            existing_template = self.get_template_by_id(template_id)
            if not existing_template:
                logger.warning(f"Template not found: {template_id}")
                return False

            # Start from the stored values so untouched fields are written back
            # unchanged. The record id is not part of the stored data.
            data_fields = self._template_info_to_dict(existing_template)
            del data_fields["id"]

            # Apply the caller's changes minus protected fields. Filtering into
            # a new dict avoids deleting keys from the caller's `updates`.
            protected_fields = {"id", "created_at", "updated_at"}
            data_fields.update(
                {key: value for key, value in (updates or {}).items() if key not in protected_fields}
            )

            # create_template() stores both timestamps inside the data, so keep
            # created_at as stored and actually write the refreshed updated_at
            # (previously it was computed and then discarded).
            data_fields["updated_at"] = datetime.now().isoformat()

            result = self.update(self.table_name, template_id, data_fields)

            if result:
                logger.info(f"Updated template: {template_id}")

            return result

        except Exception as e:
            logger.error(f"Failed to update template '{template_id}': {e}")
            return False

    def delete_template(self, template_id: str) -> bool:
        """Delete a template.

        Args:
            template_id: Template id.

        Returns:
            The result of ``delete`` (truthy on success); False if an error
            occurred (errors are logged).
        """
        try:
            result = self.delete(self.table_name, template_id)

            if result:
                logger.info(f"Deleted template: {template_id}")

            return result

        except Exception as e:
            logger.error(f"Failed to delete template '{template_id}': {e}")
            return False

    def get_templates_by_user(self, user_id: str, include_cloud: bool = True, limit: int = 100) -> List[TemplateInfo]:
        """List the templates visible to a user.

        A template is visible if it belongs to ``user_id`` or, when
        ``include_cloud`` is set, if it is a cloud (public) template.

        Args:
            user_id: User id.
            include_cloud: Whether to include cloud templates.
            limit: Maximum number of templates returned. The whole table is
                scanned, so ``limit`` bounds the result size rather than the
                number of records examined.

        Returns:
            Up to ``limit`` templates in table order; [] on error (logged).
        """
        try:
            return self._collect(
                self._iter_visible_templates(user_id, include_cloud),
                lambda _template: True,
                limit,
            )

        except Exception as e:
            logger.error(f"Failed to get templates for user '{user_id}': {e}")
            return []

    def search_templates(self, query: str, user_id: Optional[str] = None, include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]:
        """Search templates by keyword.

        Args:
            query: Keyword, matched case-insensitively as a substring of the
                name, the description, or any tag. An empty query matches all.
            user_id: Optional user id; when None, only templates with an empty
                user id and (if ``include_cloud``) cloud templates are searched.
            include_cloud: Whether to include cloud templates.
            limit: Maximum number of results.

        Returns:
            Up to ``limit`` matching templates; [] on error (logged).
        """
        try:
            query_lower = query.lower()
            return self._collect(
                self._iter_visible_templates(user_id or "", include_cloud),
                lambda template: self._matches_query(template, query_lower),
                limit,
            )

        except Exception as e:
            logger.error(f"Failed to search templates with query '{query}': {e}")
            return []

    def get_templates_by_tag(self, tag: str, user_id: Optional[str] = None, include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]:
        """List templates carrying a tag (case-insensitive exact match).

        Args:
            tag: Tag name.
            user_id: Optional user id; visibility works as in
                ``search_templates``.
            include_cloud: Whether to include cloud templates.
            limit: Maximum number of results.

        Returns:
            Up to ``limit`` matching templates; [] on error (logged).
        """
        try:
            tag_lower = tag.lower()
            return self._collect(
                self._iter_visible_templates(user_id or "", include_cloud),
                lambda template: any(tag_lower == str(t).lower() for t in (template.tags or [])),
                limit,
            )

        except Exception as e:
            logger.error(f"Failed to get templates by tag '{tag}': {e}")
            return []

    def get_cloud_templates(self, limit: int = 100) -> List[TemplateInfo]:
        """List cloud (public) templates.

        Args:
            limit: Maximum number of templates returned.

        Returns:
            Up to ``limit`` cloud templates; [] on error (logged).
        """
        try:
            cloud_records = self.find_by_field(self.table_name, "is_cloud", True)
            return [self._record_to_template_info(record) for record in cloud_records[:limit]]

        except Exception as e:
            logger.error(f"Failed to get cloud templates: {e}")
            return []

    def get_template_count(self, user_id: Optional[str] = None, include_cloud: bool = True) -> int:
        """Count templates.

        Args:
            user_id: Optional user id. When None, all rows in the table are
                counted and ``include_cloud`` is ignored.
            include_cloud: Whether cloud templates count towards a user's total.

        Returns:
            The number of templates; 0 on error (logged).
        """
        try:
            if user_id is None:
                return self.count(self.table_name)
            # Count every visible template; going through get_templates_by_user
            # would silently cap the count at its default limit.
            return sum(1 for _ in self._iter_visible_templates(user_id, include_cloud))

        except Exception as e:
            logger.error(f"Failed to get template count: {e}")
            return 0

    def batch_import_templates(self, templates: List[TemplateInfo]) -> Dict[str, Any]:
        """Import several templates through ``create_template``.

        Templates whose ``id`` already exists in the table are skipped.
        Per-template errors (for example a duplicate name for the same user)
        are recorded and do not stop the import.

        Args:
            templates: Templates to import; None is treated as an empty list.

        Returns:
            ``{"total", "success", "skipped", "failed", "failed_templates"}``,
            where ``failed_templates`` is a list of ``{"name", "error"}``. If
            the import as a whole fails, ``{"total", "success", "failed",
            "error"}`` is returned instead.
        """
        templates = templates or []
        try:
            success_count = 0
            skipped_count = 0
            failed_templates = []

            for template in templates:
                try:
                    # Skip templates whose id is already stored.
                    # NOTE(review): create_template() lets insert() assign the
                    # id, so re-importing the same batch is not caught here;
                    # such duplicates typically fail the per-user name check.
                    if self.get_template_by_id(template.id):
                        logger.warning(f"Template already exists, skipping: {template.name}")
                        skipped_count += 1
                        continue

                    self.create_template(template)
                    success_count += 1

                except Exception as e:
                    logger.error(f"Failed to import template '{template.name}': {e}")
                    failed_templates.append({
                        "name": template.name,
                        "error": str(e),
                    })

            failed_count = len(failed_templates)
            result = {
                "total": len(templates),
                "success": success_count,
                "skipped": skipped_count,
                "failed": failed_count,
                "failed_templates": failed_templates,
            }

            logger.info(f"Batch import completed: {success_count} success, {failed_count} failed")
            return result

        except Exception as e:
            logger.error(f"Failed to batch import templates: {e}")
            return {
                "total": len(templates),
                "success": 0,
                "failed": len(templates),
                "error": str(e),
            }

    def get_popular_tags(self, user_id: Optional[str] = None, limit: int = 20) -> List[Dict[str, Any]]:
        """List the most used tags.

        Tags are counted over all templates visible to ``user_id`` (always
        including cloud templates).

        Args:
            user_id: Optional user id.
            limit: Maximum number of tags returned.

        Returns:
            ``[{"tag": str, "count": int}, ...]`` sorted by count, descending
            (ties keep first-seen order); [] on error (logged).
        """
        try:
            tag_counts: Dict[Any, int] = {}
            for template in self._iter_visible_templates(user_id or "", True):
                for tag in template.tags or []:
                    tag_counts[tag] = tag_counts.get(tag, 0) + 1

            # sorted() is stable, so equal counts keep first-seen order.
            sorted_tags = sorted(tag_counts.items(), key=lambda item: item[1], reverse=True)

            return [{"tag": tag, "count": count} for tag, count in sorted_tags[:limit]]

        except Exception as e:
            logger.error(f"Failed to get popular tags: {e}")
            return []

    # Helpers

    def _load_all_records(self) -> List[Dict[str, Any]]:
        """Return every record currently in the template table.

        NOTE(review): assumes the second argument of ``find_all`` is a maximum
        record count; the current row count from ``count`` is passed so the
        whole table is fetched instead of a guessed window.
        """
        total = self.count(self.table_name)
        if not total:
            return []
        return self.find_all(self.table_name, total)

    def _iter_visible_templates(self, user_id: str, include_cloud: bool):
        """Yield, in table order, templates owned by ``user_id`` or, when
        ``include_cloud`` is set, cloud templates."""
        for record in self._load_all_records():
            template_info = self._record_to_template_info(record)
            if template_info.user_id == user_id or (include_cloud and template_info.is_cloud):
                yield template_info

    @staticmethod
    def _collect(templates, predicate, limit: int) -> List[TemplateInfo]:
        """Return up to ``limit`` items of ``templates`` satisfying ``predicate``.

        Stops consuming ``templates`` as soon as ``limit`` matches are found;
        a non-positive ``limit`` yields an empty list.
        """
        collected: List[TemplateInfo] = []
        if limit <= 0:
            return collected
        for template in templates:
            if predicate(template):
                collected.append(template)
                if len(collected) >= limit:
                    break
        return collected

    @staticmethod
    def _matches_query(template: TemplateInfo, query_lower: str) -> bool:
        """True if ``query_lower`` occurs in the name, description or any tag.

        Comparison is case-insensitive; None name/description/tags are treated
        as empty so one incomplete template cannot break a whole search.
        """
        if query_lower in (template.name or "").lower():
            return True
        if query_lower in (template.description or "").lower():
            return True
        return any(query_lower in str(tag).lower() for tag in (template.tags or []))

    def _record_to_template_info(self, record: Dict[str, Any]) -> TemplateInfo:
        """Convert a stored record (``{"id": ..., "data": {...}}``) to a TemplateInfo.

        Timestamps are read from the data first and fall back to the
        record-level ``created_at`` / ``updated_at``, then to "".
        """
        data = record["data"]
        return TemplateInfo(
            id=record["id"],
            name=data["name"],
            description=data["description"],
            thumbnail_path=data["thumbnail_path"],
            draft_content_path=data["draft_content_path"],
            resources_path=data["resources_path"],
            canvas_config=data["canvas_config"],
            duration=data["duration"],
            material_count=data["material_count"],
            track_count=data["track_count"],
            tags=data["tags"],
            is_cloud=data["is_cloud"],
            user_id=data["user_id"],
            created_at=data.get("created_at", record.get("created_at", "")),
            updated_at=data.get("updated_at", record.get("updated_at", "")),
        )

    def _template_info_to_dict(self, template_info: TemplateInfo) -> Dict[str, Any]:
        """Convert a TemplateInfo to a plain dict (including ``id``)."""
        return {
            "id": template_info.id,
            "name": template_info.name,
            "description": template_info.description,
            "thumbnail_path": template_info.thumbnail_path,
            "draft_content_path": template_info.draft_content_path,
            "resources_path": template_info.resources_path,
            "canvas_config": template_info.canvas_config,
            "duration": template_info.duration,
            "material_count": template_info.material_count,
            "track_count": template_info.track_count,
            "tags": template_info.tags,
            "is_cloud": template_info.is_cloud,
            "user_id": template_info.user_id,
            "created_at": template_info.created_at,
            "updated_at": template_info.updated_at,
        }
|
||
|
||
|
||
# Create the module-wide TemplateTable instance shared by importers.
# Side effect at import time: TemplateTable.__init__ initializes the Db base
# with the "mixvideo" key and creates the templates table if it is missing.
template_table: TemplateTable = TemplateTable()