802 lines
29 KiB
Python
802 lines
29 KiB
Python
# 模板表 - PostgreSQL 版本
|
||
|
||
import uuid
|
||
import json
|
||
from typing import Dict, List, Any, Optional
|
||
from datetime import datetime
|
||
from contextlib import contextmanager
|
||
from python_core.config import settings
|
||
from python_core.utils.logger import setup_logger
|
||
from .types import TemplateInfo
|
||
|
||
# Try to import psycopg2. If it is missing, remember the error text so that
# TemplateTablePostgres can raise a friendly ImportError with installation
# instructions instead of this module failing at import time.
try:
    import psycopg2
    import psycopg2.extras

    PSYCOPG2_AVAILABLE = True
    # Defined in both branches so the name always exists at module level.
    PSYCOPG2_ERROR = None
except ImportError as e:
    PSYCOPG2_AVAILABLE = False
    PSYCOPG2_ERROR = str(e)
|
||
|
||
# Module-level logger used by TemplateTablePostgres for all of its messages.
logger = setup_logger(__name__)
|
||
|
||
class TemplateTablePostgres:
    """
    Template table - PostgreSQL implementation.

    Template management backed by the PostgreSQL ``templates`` table. Every
    operation opens its own connection through ``_get_connection``.
    ``create_template`` and ``_init_template_table`` re-raise errors; the
    other public methods log them and return a neutral value (``None``,
    ``[]``, ``0``, ``False`` or an error summary dict).
    """

    # Columns that update_template() may write. Column names cannot be bound
    # as query parameters and are interpolated into the UPDATE statement, so
    # only names from this fixed set are ever allowed into the SQL text.
    _UPDATABLE_COLUMNS = frozenset({
        "name",
        "description",
        "thumbnail_path",
        "draft_content_path",
        "draft_content",
        "resources_path",
        "canvas_config",
        "duration",
        "material_count",
        "track_count",
        "tags",
        "is_cloud",
        "user_id",
        "updated_at",
    })

    # Fields that update_template() silently drops instead of writing.
    _PROTECTED_FIELDS = frozenset({"id", "created_at"})

    # JSONB columns whose Python values are serialized with json.dumps().
    _JSON_COLUMNS = ("canvas_config", "tags", "draft_content")

    def __init__(self):
        """Read the database URL from settings.

        Raises:
            ImportError: if psycopg2 could not be imported; the message
                contains installation instructions and the original error.
        """
        if not PSYCOPG2_AVAILABLE:
            error_msg = f"""
PostgreSQL 驱动 psycopg2 未安装。请安装:

方案1(推荐):
pip install psycopg2-binary

方案2:
# Ubuntu/Debian
sudo apt-get install postgresql libpq-dev python3-dev
pip install psycopg2

# CentOS/RHEL
sudo yum install postgresql postgresql-devel python3-devel
pip install psycopg2

# macOS
brew install postgresql
pip install psycopg2

原始错误: {PSYCOPG2_ERROR}
"""
            logger.error(error_msg)
            raise ImportError(error_msg)

        self.db_url = settings.db
        # NOTE: the SQL in this class hard-codes "templates" rather than
        # reading this attribute.
        self.table_name = "templates"

        # The schema is not created on construction; call
        # _init_template_table() explicitly when the table must be created.

    @contextmanager
    def _get_connection(self):
        """Yield a new psycopg2 connection and always close it afterwards.

        If the ``with`` body raises, the open transaction is rolled back, the
        error is logged and the original exception is re-raised. Callers must
        call ``conn.commit()`` themselves.
        """
        conn = None
        try:
            conn = psycopg2.connect(self.db_url)
            yield conn
        except Exception as e:
            if conn is not None:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    # A broken connection can make rollback fail too; don't
                    # let that replace the original error.
                    logger.warning(f"Rollback failed: {rollback_error}")
            logger.error(f"Database connection error: {e}")
            raise
        finally:
            if conn is not None:
                conn.close()

    def _init_template_table(self):
        """Create the ``templates`` table and its indexes if they do not exist.

        Raises:
            Exception: any database error is logged and re-raised.
        """
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS templates (
            id VARCHAR(36) PRIMARY KEY,
            name VARCHAR(255) NOT NULL,
            description TEXT DEFAULT '',
            thumbnail_path TEXT DEFAULT '',
            draft_content_path TEXT DEFAULT '',
            draft_content JSONB DEFAULT '{}',
            resources_path TEXT DEFAULT '',
            canvas_config JSONB DEFAULT '{}',
            duration INTEGER DEFAULT 0,
            material_count INTEGER DEFAULT 0,
            track_count INTEGER DEFAULT 0,
            tags JSONB DEFAULT '[]',
            is_cloud BOOLEAN DEFAULT FALSE,
            user_id VARCHAR(36) NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
        """

        indexes = [
            "CREATE INDEX IF NOT EXISTS idx_templates_name ON templates(name);",
            "CREATE INDEX IF NOT EXISTS idx_templates_user_id ON templates(user_id);",
            "CREATE INDEX IF NOT EXISTS idx_templates_is_cloud ON templates(is_cloud);",
            "CREATE INDEX IF NOT EXISTS idx_templates_created_at ON templates(created_at);",
            "CREATE INDEX IF NOT EXISTS idx_templates_tags ON templates USING GIN(tags);",
            "CREATE INDEX IF NOT EXISTS idx_templates_draft_content ON templates USING GIN(draft_content);",
            "CREATE INDEX IF NOT EXISTS idx_templates_user_name ON templates(user_id, name);"
        ]

        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute(create_table_sql)
                    for index_sql in indexes:
                        cursor.execute(index_sql)
                conn.commit()
                logger.info("Template table initialized")
        except Exception as e:
            logger.error(f"Failed to initialize template table: {e}")
            raise

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE/ILIKE wildcards so *text* is matched literally.

        Backslash is PostgreSQL's default LIKE escape character, so it is
        doubled first, then ``%`` and ``_`` are prefixed with it.
        """
        return (
            str(text)
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

    def _query_templates(self, sql: str, params: tuple) -> List[TemplateInfo]:
        """Run a SELECT over ``templates`` and convert every row.

        Exceptions propagate; the public callers log them and return a
        neutral value.
        """
        with self._get_connection() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                cursor.execute(sql, params)
                return [self._row_to_template_info(row) for row in cursor.fetchall()]

    def create_template(self, template_info: TemplateInfo) -> str:
        """
        Create a template.

        Args:
            template_info: Template to insert. ``id`` is generated (uuid4)
                when empty. ``draft_content`` (which may be absent from the
                object), ``canvas_config`` and ``tags`` default to ``{}``,
                ``{}`` and ``[]`` when they are ``None``.

        Returns:
            The ID of the created template.

        Raises:
            ValueError: if this user already has a template with the same name.
            Exception: any database error is logged and re-raised.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # Template names must be unique per user.
                    cursor.execute(
                        "SELECT id FROM templates WHERE name = %s AND user_id = %s",
                        (template_info.name, template_info.user_id),
                    )
                    if cursor.fetchone():
                        raise ValueError(f"Template name '{template_info.name}' already exists for this user")

                    template_id = template_info.id or str(uuid.uuid4())

                    insert_sql = """
                    INSERT INTO templates (
                        id, name, description, thumbnail_path, draft_content_path,
                        draft_content, resources_path, canvas_config, duration, material_count,
                        track_count, tags, is_cloud, user_id, created_at, updated_at
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    """

                    now = datetime.now()

                    # draft_content is attached dynamically by _row_to_template_info,
                    # so it may be missing on freshly built TemplateInfo objects.
                    draft_content = getattr(template_info, "draft_content", None)
                    if draft_content is None:
                        draft_content = {}

                    # json.dumps(None) would store a JSON null instead of the
                    # column defaults, which breaks jsonb_array_elements_text().
                    canvas_config = template_info.canvas_config
                    if canvas_config is None:
                        canvas_config = {}
                    tags = template_info.tags
                    if tags is None:
                        tags = []

                    logger.debug(f"Storing draft_content for template {template_id}: {type(draft_content)} with {len(draft_content)} keys")

                    cursor.execute(insert_sql, (
                        template_id,
                        template_info.name,
                        template_info.description,
                        template_info.thumbnail_path,
                        template_info.draft_content_path,
                        json.dumps(draft_content),
                        template_info.resources_path,
                        json.dumps(canvas_config),
                        template_info.duration,
                        template_info.material_count,
                        template_info.track_count,
                        json.dumps(tags),
                        template_info.is_cloud,
                        template_info.user_id,
                        now,
                        now,
                    ))

                conn.commit()
                logger.info(f"Created template: {template_info.name} (ID: {template_id})")
                return template_id

        except Exception as e:
            logger.error(f"Failed to create template '{template_info.name}': {e}")
            raise

    def get_template_by_id(self, template_id: str) -> Optional[TemplateInfo]:
        """
        Get a template by ID.

        Args:
            template_id: Template ID.

        Returns:
            The template, or None if it does not exist or the query failed.
        """
        try:
            rows = self._query_templates("SELECT * FROM templates WHERE id = %s", (template_id,))
            return rows[0] if rows else None
        except Exception as e:
            logger.error(f"Failed to get template by ID '{template_id}': {e}")
            return None

    def get_template_by_name(self, name: str, user_id: Optional[str] = None) -> Optional[TemplateInfo]:
        """
        Get a template by name.

        Args:
            name: Template name.
            user_id: Optional user ID; when truthy, only that user's templates match.

        Returns:
            The first matching template, or None if none exists or the query failed.
        """
        try:
            if user_id:
                rows = self._query_templates(
                    "SELECT * FROM templates WHERE name = %s AND user_id = %s LIMIT 1",
                    (name, user_id),
                )
            else:
                rows = self._query_templates("SELECT * FROM templates WHERE name = %s LIMIT 1", (name,))
            return rows[0] if rows else None
        except Exception as e:
            logger.error(f"Failed to get template by name '{name}': {e}")
            return None

    def update_template(self, template_id: str, updates: Dict[str, Any]) -> bool:
        """
        Update template fields.

        ``id`` and ``created_at`` are ignored, and ``updated_at`` is always
        set to the current time. Keys that are not known columns cause the
        whole update to be rejected.

        Args:
            template_id: Template ID.
            updates: Mapping of column name to new value. JSON columns
                (``canvas_config``, ``tags``, ``draft_content``) take Python
                objects and are serialized here.

        Returns:
            True if a row was updated, otherwise False (including on error).
        """
        try:
            filtered_updates = {k: v for k, v in updates.items() if k not in self._PROTECTED_FIELDS}
            if not filtered_updates:
                logger.warning("No valid fields to update")
                return False

            # Keys become SQL identifiers below, so anything outside the
            # whitelist is refused rather than interpolated (SQL injection).
            unknown_fields = [k for k in filtered_updates if k not in self._UPDATABLE_COLUMNS]
            if unknown_fields:
                logger.warning(f"Refusing to update unknown template fields: {unknown_fields}")
                return False

            filtered_updates["updated_at"] = datetime.now()

            for column in self._JSON_COLUMNS:
                if column in filtered_updates:
                    filtered_updates[column] = json.dumps(filtered_updates[column])

            set_clause = ", ".join(f"{column} = %s" for column in filtered_updates)
            values = [*filtered_updates.values(), template_id]  # last one feeds WHERE id = %s
            update_sql = f"UPDATE templates SET {set_clause} WHERE id = %s"

            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT id FROM templates WHERE id = %s", (template_id,))
                    if not cursor.fetchone():
                        logger.warning(f"Template not found: {template_id}")
                        return False

                    cursor.execute(update_sql, values)
                    conn.commit()

                    if cursor.rowcount > 0:
                        logger.info(f"Updated template: {template_id}")
                        return True
                    logger.warning(f"No rows updated for template: {template_id}")
                    return False

        except Exception as e:
            logger.error(f"Failed to update template '{template_id}': {e}")
            return False

    def delete_template(self, template_id: str) -> bool:
        """
        Delete a template.

        Args:
            template_id: Template ID.

        Returns:
            True if a row was deleted, otherwise False (including on error).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("DELETE FROM templates WHERE id = %s", (template_id,))
                    conn.commit()

                    if cursor.rowcount > 0:
                        logger.info(f"Deleted template: {template_id}")
                        return True
                    logger.warning(f"No template found to delete: {template_id}")
                    return False

        except Exception as e:
            logger.error(f"Failed to delete template '{template_id}': {e}")
            return False

    def get_templates_by_user(self, user_id: str, include_cloud: bool = True, limit: int = 100) -> List[TemplateInfo]:
        """
        List a user's templates, newest first.

        Args:
            user_id: User ID.
            include_cloud: Also include templates with ``is_cloud = TRUE``.
            limit: Maximum number of templates returned.

        Returns:
            The matching templates, or [] on error.
        """
        where = "user_id = %s OR is_cloud = TRUE" if include_cloud else "user_id = %s"
        sql = f"SELECT * FROM templates WHERE {where} ORDER BY created_at DESC LIMIT %s"
        try:
            return self._query_templates(sql, (user_id, limit))
        except Exception as e:
            logger.error(f"Failed to get templates for user '{user_id}': {e}")
            return []

    def search_templates(self, query: str, user_id: Optional[str] = None, include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]:
        """
        Search templates by substring, case-insensitively.

        Args:
            query: Text matched against name, description and tags. Wildcard
                characters (``%``, ``_``) are matched literally.
            user_id: Optional user ID. When falsy, every template is searched
                and ``include_cloud`` is ignored.
            include_cloud: With ``user_id``, also include ``is_cloud`` templates.
            limit: Maximum number of templates returned.

        Returns:
            The matching templates, newest first, or [] on error.
        """
        try:
            pattern = f"%{self._escape_like(query)}%"
            text_match = "(name ILIKE %s OR description ILIKE %s OR tags::text ILIKE %s)"

            if user_id and include_cloud:
                where = f"(user_id = %s OR is_cloud = TRUE) AND {text_match}"
                params = (user_id, pattern, pattern, pattern, limit)
            elif user_id:
                where = f"user_id = %s AND {text_match}"
                params = (user_id, pattern, pattern, pattern, limit)
            else:
                where = text_match
                params = (pattern, pattern, pattern, limit)

            sql = f"SELECT * FROM templates WHERE {where} ORDER BY created_at DESC LIMIT %s"
            return self._query_templates(sql, params)

        except Exception as e:
            logger.error(f"Failed to search templates with query '{query}': {e}")
            return []

    def get_templates_by_tag(self, tag: str, user_id: Optional[str] = None, include_cloud: bool = True, limit: int = 50) -> List[TemplateInfo]:
        """
        List templates whose ``tags`` JSONB value contains *tag*.

        Args:
            tag: Tag name.
            user_id: Optional user ID. When falsy, every template is considered
                and ``include_cloud`` is ignored.
            include_cloud: With ``user_id``, also include ``is_cloud`` templates.
            limit: Maximum number of templates returned.

        Returns:
            The matching templates, newest first, or [] on error.
        """
        try:
            tag_json = json.dumps(tag)

            if user_id and include_cloud:
                where = "(user_id = %s OR is_cloud = TRUE) AND tags @> %s"
                params = (user_id, tag_json, limit)
            elif user_id:
                where = "user_id = %s AND tags @> %s"
                params = (user_id, tag_json, limit)
            else:
                where = "tags @> %s"
                params = (tag_json, limit)

            sql = f"SELECT * FROM templates WHERE {where} ORDER BY created_at DESC LIMIT %s"
            return self._query_templates(sql, params)

        except Exception as e:
            logger.error(f"Failed to get templates by tag '{tag}': {e}")
            return []

    def get_cloud_templates(self, limit: int = 100) -> List[TemplateInfo]:
        """
        List cloud (``is_cloud = TRUE``) templates, newest first.

        Args:
            limit: Maximum number of templates returned.

        Returns:
            The cloud templates, or [] on error.
        """
        try:
            return self._query_templates(
                "SELECT * FROM templates WHERE is_cloud = TRUE ORDER BY created_at DESC LIMIT %s",
                (limit,),
            )
        except Exception as e:
            logger.error(f"Failed to get cloud templates: {e}")
            return []

    def get_template_count(self, user_id: Optional[str] = None, include_cloud: bool = True) -> int:
        """
        Count templates.

        Args:
            user_id: Optional user ID; when None every template is counted.
            include_cloud: With ``user_id``, also count ``is_cloud`` templates.

        Returns:
            The number of templates, or 0 on error.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    if user_id is None:
                        cursor.execute("SELECT COUNT(*) FROM templates")
                    elif include_cloud:
                        cursor.execute(
                            "SELECT COUNT(*) FROM templates WHERE user_id = %s OR is_cloud = TRUE",
                            (user_id,),
                        )
                    else:
                        cursor.execute("SELECT COUNT(*) FROM templates WHERE user_id = %s", (user_id,))

                    result = cursor.fetchone()
                    return result[0] if result else 0

        except Exception as e:
            logger.error(f"Failed to get template count: {e}")
            return 0

    def batch_import_templates(self, templates: List[TemplateInfo]) -> Dict[str, Any]:
        """
        Import several templates, skipping IDs that already exist.

        Args:
            templates: Templates to import (None is treated as empty).

        Returns:
            A summary dict with ``total``, ``success``, ``failed``, ``skipped``
            and ``failed_templates`` (a list of ``{"name", "error"}`` dicts).
            If the loop itself fails, the dict has ``total``, ``success``,
            ``failed`` and ``error`` instead.
        """
        templates = list(templates) if templates else []
        try:
            success_count = 0
            skipped_count = 0
            failed_templates = []

            for template in templates:
                try:
                    # Skip templates whose ID is already stored.
                    if template.id and self.get_template_by_id(template.id):
                        logger.warning(f"Template already exists, skipping: {template.name}")
                        skipped_count += 1
                        continue

                    self.create_template(template)
                    success_count += 1

                except Exception as e:
                    logger.error(f"Failed to import template '{template.name}': {e}")
                    failed_templates.append({
                        "name": template.name,
                        "error": str(e),
                    })

            failed_count = len(failed_templates)
            logger.info(f"Batch import completed: {success_count} success, {failed_count} failed, {skipped_count} skipped")
            return {
                "total": len(templates),
                "success": success_count,
                "failed": failed_count,
                "skipped": skipped_count,
                "failed_templates": failed_templates,
            }

        except Exception as e:
            logger.error(f"Failed to batch import templates: {e}")
            return {
                "total": len(templates),
                "success": 0,
                "failed": len(templates),
                "error": str(e),
            }

    def get_popular_tags(self, user_id: Optional[str] = None, limit: int = 20) -> List[Dict[str, Any]]:
        """
        Get the most used tags.

        Args:
            user_id: Optional user ID; when truthy, only that user's templates
                and cloud templates are considered.
            limit: Maximum number of tags returned.

        Returns:
            ``[{"tag": str, "count": int}, ...]`` ordered by count descending,
            or [] on error.
        """
        # jsonb_array_elements_text() raises on non-array values (e.g. a JSON
        # null stored for tags), so only array-valued rows are expanded.
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    if user_id:
                        sql = """
                        SELECT tag, COUNT(*) as count
                        FROM (
                            SELECT jsonb_array_elements_text(tags) as tag
                            FROM templates
                            WHERE (user_id = %s OR is_cloud = TRUE)
                              AND jsonb_typeof(tags) = 'array'
                        ) t
                        GROUP BY tag
                        ORDER BY count DESC
                        LIMIT %s
                        """
                        cursor.execute(sql, (user_id, limit))
                    else:
                        sql = """
                        SELECT tag, COUNT(*) as count
                        FROM (
                            SELECT jsonb_array_elements_text(tags) as tag
                            FROM templates
                            WHERE jsonb_typeof(tags) = 'array'
                        ) t
                        GROUP BY tag
                        ORDER BY count DESC
                        LIMIT %s
                        """
                        cursor.execute(sql, (limit,))

                    return [{"tag": row[0], "count": row[1]} for row in cursor.fetchall()]

        except Exception as e:
            logger.error(f"Failed to get popular tags: {e}")
            return []

    # Helper methods

    def _row_to_template_info(self, row: Dict[str, Any]) -> TemplateInfo:
        """Convert a RealDictCursor row into a TemplateInfo.

        Non-dict ``canvas_config`` / ``draft_content`` become ``{}`` and a
        non-list ``tags`` becomes ``[]``; timestamps become ISO strings ("" if NULL).
        """
        template_info = TemplateInfo(
            id=row['id'],
            name=row['name'],
            description=row['description'],
            thumbnail_path=row['thumbnail_path'],
            draft_content_path=row['draft_content_path'],
            resources_path=row['resources_path'],
            canvas_config=row['canvas_config'] if isinstance(row['canvas_config'], dict) else {},
            duration=row['duration'],
            material_count=row['material_count'],
            track_count=row['track_count'],
            tags=row['tags'] if isinstance(row['tags'], list) else [],
            is_cloud=row['is_cloud'],
            user_id=row['user_id'],
            created_at=row['created_at'].isoformat() if row['created_at'] else "",
            updated_at=row['updated_at'].isoformat() if row['updated_at'] else ""
        )

        # draft_content is not passed to the constructor; attach it afterwards.
        draft_content = row.get('draft_content', {})
        template_info.draft_content = draft_content if isinstance(draft_content, dict) else {}

        return template_info

    def update_draft_content(self, template_id: str, draft_content: Dict[str, Any]) -> bool:
        """
        Replace a template's draft content.

        Args:
            template_id: Template ID.
            draft_content: New draft content.

        Returns:
            True on success, otherwise False.
        """
        try:
            return self.update_template(template_id, {"draft_content": draft_content})
        except Exception as e:
            logger.error(f"Failed to update draft content for template '{template_id}': {e}")
            return False

    def get_draft_content(self, template_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a template's draft content.

        Args:
            template_id: Template ID.

        Returns:
            The draft content, or None if the template does not exist.
        """
        try:
            template = self.get_template_by_id(template_id)
            if template and hasattr(template, 'draft_content'):
                return template.draft_content
            return None
        except Exception as e:
            logger.error(f"Failed to get draft content for template '{template_id}': {e}")
            return None

    def save_draft_content_from_file(self, template_id: str, draft_file_path: str) -> bool:
        """
        Load draft content from a UTF-8 JSON file and store it in the database.

        Args:
            template_id: Template ID.
            draft_file_path: Path to the JSON file.

        Returns:
            True on success, otherwise False.
        """
        try:
            import os

            if not os.path.exists(draft_file_path):
                logger.error(f"Draft content file not found: {draft_file_path}")
                return False

            with open(draft_file_path, 'r', encoding='utf-8') as f:
                draft_content = json.load(f)

            return self.update_draft_content(template_id, draft_content)

        except Exception as e:
            logger.error(f"Failed to save draft content from file '{draft_file_path}': {e}")
            return False

    def export_draft_content_to_file(self, template_id: str, output_path: str) -> bool:
        """
        Write a template's draft content to a UTF-8 JSON file.

        Missing parent directories are created.

        Args:
            template_id: Template ID.
            output_path: Destination file path (a bare file name is allowed).

        Returns:
            True on success, otherwise False.
        """
        try:
            draft_content = self.get_draft_content(template_id)
            if draft_content is None:
                logger.error(f"No draft content found for template: {template_id}")
                return False

            import os

            output_dir = os.path.dirname(output_path)
            # dirname() is "" for a bare file name and os.makedirs("") raises.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(draft_content, f, ensure_ascii=False, indent=2)

            logger.info(f"Draft content exported to: {output_path}")
            return True

        except Exception as e:
            logger.error(f"Failed to export draft content to file '{output_path}': {e}")
            return False
|
||
|
||
|
||
# 创建全局模板表实例
|
||
# Shared module-level instance, created at import time. TemplateTablePostgres.__init__
# raises ImportError when psycopg2 is unavailable, so importing this module fails then.
template_table = TemplateTablePostgres()
|