594 lines
23 KiB
Python
594 lines
23 KiB
Python
"""
模板管理CLI命令 - 简化版本,只传输JSON RPC
"""
|
||
|
||
from pathlib import Path
|
||
from typing import Optional, List
|
||
from dataclasses import asdict
|
||
from datetime import datetime
|
||
import typer
|
||
from python_core.utils.jsonrpc_enhanced import create_response_handler, create_progress_reporter
|
||
from python_core.services.template_manager_cloud import TemplateManagerCloud, TemplateInfo
|
||
from python_core.utils.logger import logger
|
||
from uuid import uuid4
|
||
|
||
# Typer sub-application that every template-management command in this module
# is registered on (application name "template").
template_app = typer.Typer(
    help="模板管理命令",
    name="template",
)
|
||
|
||
|
||
@template_app.command("batch-import")
def batch_import(
    source_folder: str = typer.Argument(..., help="包含模板子文件夹的源文件夹"),
    user_id: Optional[str] = typer.Option(None, "--user-id", help="用户ID"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="详细输出"),
    json_output: bool = typer.Option(True, "--json", help="JSON格式输出")
):
    """Batch-import templates from the sub-folders of ``source_folder``.

    Progress updates from the manager are forwarded through a JSON-RPC
    progress reporter. The final summary, or an error, is emitted through
    the same reporter.

    Args:
        source_folder: Folder whose sub-folders are imported as templates.
        user_id: Owner passed to the template manager; ``"default"`` if omitted.
        verbose: Accepted for CLI compatibility; not used by this command.
        json_output: Accepted for CLI compatibility; not used by this command.
    """
    response = create_progress_reporter()

    try:
        # Validate the source folder before constructing the manager.
        source_path = Path(source_folder)
        if not source_path.exists():
            response.error(-32603, f"源文件夹不存在: {source_folder}")
            return
        if not source_path.is_dir():
            response.error(-32603, f"路径不是文件夹: {source_folder}")
            return

        manager = TemplateManagerCloud(user_id=user_id or "default")

        def progress_callback(current: int, total: int, message: str):
            # Relay the manager's progress straight to the JSON-RPC client.
            response.progress(current, total, message)

        result = manager.batch_import_templates(source_folder, progress_callback)

        if result.get('status'):
            # Read optional summary fields defensively. A missing key must not
            # turn an import that already happened into a reported failure.
            response.success({
                "imported_count": result.get('imported_count', 0),
                "failed_count": result.get('failed_count', 0),
                "imported_templates": result.get('imported_templates', []),
                "failed_templates": result.get('failed_templates', []),
                "message": result.get('msg', '')
            })
        else:
            response.error(-32603, result.get('msg') or "批量导入失败")

    except Exception as e:
        logger.error(f"批量导入失败: {e}")
        response.error(-32603, f"批量导入失败: {str(e)}")
|
||
|
||
|
||
@template_app.command("list")
def list_templates(
    user_id: Optional[str] = typer.Option(None, "--user-id", help="用户ID"),
    include_cloud: bool = typer.Option(True, "--include-cloud", help="包含云端模板"),
    limit: int = typer.Option(100, "--limit", "-l", help="显示数量限制"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="详细输出"),
    json_output: bool = typer.Option(True, "--json", help="JSON格式输出")
):
    """List templates and emit at most ``limit`` of them.

    ``TemplateInfo`` results are converted with ``dataclasses.asdict``; any
    other item returned by the manager is passed through unchanged.
    ``total_count`` is the number of templates returned after truncation.

    Args:
        user_id: Owner passed to the template manager; ``"default"`` if omitted.
        include_cloud: Forwarded to ``get_templates``.
        limit: Maximum number of templates to return; must be >= 0.
        verbose: Accepted for CLI compatibility; not used by this command.
        json_output: Accepted for CLI compatibility; not used by this command.
    """
    response = create_response_handler()

    try:
        # A negative limit would make ``templates[:limit]`` slice from the
        # end and silently drop the last templates, so reject it explicitly.
        if limit < 0:
            response.error(-32602, f"limit 必须为非负整数: {limit}")
            return

        manager = TemplateManagerCloud(user_id=user_id or "default")
        templates = manager.get_templates(include_cloud=include_cloud)

        if len(templates) > limit:
            templates = templates[:limit]

        templates_data = [
            asdict(template) if isinstance(template, TemplateInfo) else template
            for template in templates
        ]

        response.success({
            "templates": templates_data,
            "total_count": len(templates_data),
            "limit": limit
        })

    except Exception as e:
        logger.error(f"获取模板列表失败: {e}")
        response.error(-32603, f"获取模板列表失败: {str(e)}")
|
||
|
||
|
||
@template_app.command("get")
def get_template(
    template_id: str = typer.Argument(..., help="模板ID"),
    user_id: Optional[str] = typer.Option(None, "--user-id", help="用户ID"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="详细输出"),
    json_output: bool = typer.Option(True, "--json", help="JSON格式输出")
):
    """Look up a single template through the manager and emit it.

    Emits JSON-RPC error -32604 when the manager returns nothing for
    ``template_id``, and -32603 for any unexpected exception. A
    ``TemplateInfo`` result is converted with ``dataclasses.asdict``;
    anything else is emitted as-is.
    """
    response = create_response_handler()

    try:
        found = TemplateManagerCloud(user_id=user_id or "default").get_template(template_id)

        if found:
            payload = asdict(found) if isinstance(found, TemplateInfo) else found
            response.success(payload)
        else:
            response.error(-32604, f"未找到模板: {template_id}")

    except Exception as exc:
        response.error(-32603, f"获取模板详情失败: {str(exc)}")
|
||
|
||
|
||
@template_app.command("detail")
def get_template_detail(
    template_id: str = typer.Argument(..., help="模板ID"),
    user_id: Optional[str] = typer.Option(None, "--user-id", help="用户ID"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="详细输出"),
    json_output: bool = typer.Option(True, "--json", help="JSON格式输出")
):
    """Emit a template's details, including its tracks and segments.

    Reads the template row and its ``draft_content`` from the PostgreSQL
    template table and flattens the draft into a list of tracks, each with
    its segments. Two draft layouts are accepted:

    * the real draft structure: ``materials`` is a dict of typed lists
      (``videos``, ``audios``, ...), and segment times live under
      ``target_timerange`` in microseconds;
    * a simplified structure: ``materials`` is a flat list, and segments
      carry ``start``/``end`` in milliseconds.

    Args:
        template_id: ID of the template to describe.
        user_id: Accepted for CLI compatibility; not used by this command.
        verbose: Accepted for CLI compatibility; not used by this command.
        json_output: Accepted for CLI compatibility; not used by this command.
    """
    response = create_response_handler()

    try:
        # Imported inside the command, as before. Presumably this keeps the
        # database layer from loading for other commands (TODO confirm).
        from python_core.database.template_postgres import template_table

        template = template_table.get_template_by_id(template_id)
        if not template:
            response.error(-32603, f"模板不存在: {template_id}")
            return

        draft_content = template_table.get_draft_content(template_id)
        if draft_content is None:
            response.error(-32603, f"模板详细信息不存在: {template_id}")
            return

        if not isinstance(draft_content, dict):
            logger.warning(f"draft_content is not a dict for template {template_id}: {type(draft_content)}")
            draft_content = {}

        # Guard against a missing or non-dict canvas_config. Calling .get on
        # None or a list would otherwise abort the whole command.
        canvas = draft_content.get('canvas_config')
        if not isinstance(canvas, dict):
            canvas = {}

        detail = {
            'id': template.id,
            'name': template.name,
            'description': template.description,
            'canvas_config': template.canvas_config,
            'duration': template.duration,
            'fps': canvas.get('fps', 30),
            'sample_rate': canvas.get('sample_rate'),
            'tracks': []
        }

        tracks_data = draft_content.get('tracks', [])
        if not isinstance(tracks_data, list):
            logger.warning(f"tracks_data is not a list: {type(tracks_data)}")
            tracks_data = []

        materials_lookup = _build_materials_lookup(draft_content.get('materials', {}))

        for idx, track_data in enumerate(tracks_data):
            if not isinstance(track_data, dict):
                logger.warning(f"track_data is not a dict: {type(track_data)}")
                continue
            detail['tracks'].append(_build_track(idx, track_data, materials_lookup))

        response.success(detail)

    except Exception as e:
        logger.error(f"获取模板详情失败: {e}")
        response.error(-32603, f"获取模板详情失败: {str(e)}")


# Material categories of the real draft ``materials`` dict, mapped to the
# segment type reported for segments that reference them. Iteration order
# matters: when two categories share an id, the later category wins.
_MATERIAL_SEGMENT_TYPES = {
    'videos': 'video',
    'audios': 'audio',
    'images': 'image',
    'texts': 'text',
    'stickers': 'sticker',
}


def _build_materials_lookup(materials_data):
    """Index draft materials by their ``id``.

    For the dict-of-lists layout, each indexed material is a shallow copy
    with an added ``material_type`` key naming its category (e.g.
    ``'videos'``). For the flat-list layout, materials are indexed as-is.
    Entries that are not dicts or have no truthy ``id`` are skipped.
    """
    lookup = {}
    if isinstance(materials_data, dict):
        for category in _MATERIAL_SEGMENT_TYPES:
            material_list = materials_data.get(category, [])
            if not isinstance(material_list, list):
                continue
            for material in material_list:
                if isinstance(material, dict) and material.get('id', ''):
                    lookup[material['id']] = {**material, 'material_type': category}
    elif isinstance(materials_data, list):
        for material in materials_data:
            if isinstance(material, dict) and material.get('id', ''):
                lookup[material['id']] = material
    else:
        logger.warning(f"materials_data is neither dict nor list: {type(materials_data)}")
    return lookup


def _segment_time_range(segment_data):
    """Return ``(start_time, end_time, duration)`` of a segment in seconds."""
    target_timerange = segment_data.get('target_timerange', {})
    if target_timerange:
        # Real draft layout: microsecond start + duration.
        start_time = target_timerange.get('start', 0) / 1000000.0
        duration = target_timerange.get('duration', 0) / 1000000.0
        return start_time, start_time + duration, duration
    # Simplified layout: millisecond start + end.
    start_time = segment_data.get('start', 0) / 1000.0
    end_time = segment_data.get('end', 0) / 1000.0
    return start_time, end_time, end_time - start_time


def _segment_type_for(material):
    """Map a material's category/type to the segment type reported to clients."""
    material_type = material.get('material_type', material.get('type', 'video'))
    return _MATERIAL_SEGMENT_TYPES.get(material_type, material_type)


def _build_track(idx, track_data, materials_lookup):
    """Flatten one raw track dict (at position ``idx``) into the response shape."""
    track = {
        'id': track_data.get('id', f'track_{idx}'),
        'name': track_data.get('name', f'轨道 {idx + 1}'),
        'type': track_data.get('type', 'video'),
        'index': idx,
        'segments': [],
        'properties': track_data.get('properties', {})
    }

    segments_data = track_data.get('segments', [])
    if not isinstance(segments_data, list):
        logger.warning(f"segments_data is not a list: {type(segments_data)}")
        segments_data = []

    for seg_idx, segment_data in enumerate(segments_data):
        if not isinstance(segment_data, dict):
            logger.warning(f"segment_data is not a dict: {type(segment_data)}")
            continue
        track['segments'].append(
            _build_segment(seg_idx, len(track['segments']), segment_data, materials_lookup)
        )
    return track


def _build_segment(seg_idx, position, segment_data, materials_lookup):
    """Flatten one raw segment dict into the response shape.

    Args:
        seg_idx: Index of the segment in the track's raw ``segments`` list.
            It is used for the fallback id, so the id agrees with the
            ``segment_{index}`` fallback that the update-segment command
            searches for.
        position: Number of segments already emitted for this track; used
            only for the fallback display name.
        segment_data: Raw segment dict from ``draft_content``.
        materials_lookup: Index built by ``_build_materials_lookup``.
    """
    start_time, end_time, duration = _segment_time_range(segment_data)
    material = materials_lookup.get(segment_data.get('material_id', ''), {})

    segment_type = 'video'  # default when no material is linked
    resource_path = ''
    if material:
        segment_type = _segment_type_for(material)
        resource_path = material.get('path', '')

    # A non-empty custom segment name wins. Otherwise use the material's
    # name, then a positional default.
    segment_name = segment_data.get('name')
    if not segment_name:
        material_name = material.get('material_name', material.get('name'))
        segment_name = material_name or f'片段 {position + 1}'

    # NOTE(review): fallback ids such as 'segment_0' repeat across tracks,
    # and update-segment renames the first match it finds. Consider a
    # track-qualified id if segments without their own id are common.
    return {
        'id': segment_data.get('id', f'segment_{seg_idx}'),
        'type': segment_type,
        'name': segment_name,
        'start_time': start_time,
        'end_time': end_time,
        'duration': duration,
        'resource_path': resource_path,
        'properties': segment_data.get('properties', {}),
        'effects': segment_data.get('effects', [])
    }
|
||
|
||
|
||
@template_app.command("delete")
def delete_template(
    template_id: str = typer.Argument(..., help="模板ID"),
    user_id: Optional[str] = typer.Option(None, "--user-id", help="用户ID"),
    force: bool = typer.Option(False, "--force", "-f", help="强制删除,不询问确认"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="详细输出"),
    json_output: bool = typer.Option(True, "--json", help="JSON格式输出")
):
    """Delete a template and report its id and name.

    Emits JSON-RPC error -32604 if the template does not exist, and -32603
    if the manager reports failure or an exception occurs.

    Args:
        template_id: ID of the template to delete.
        user_id: Owner passed to the template manager; ``"default"`` if omitted.
        force: Accepted for CLI compatibility. This command never prompts for
            confirmation, so the flag currently has no effect.
        verbose: Accepted for CLI compatibility; not used by this command.
        json_output: Accepted for CLI compatibility; not used by this command.
    """
    response = create_response_handler()

    try:
        manager = TemplateManagerCloud(user_id=user_id or "default")

        template = manager.get_template(template_id)
        if not template:
            response.error(-32604, f"未找到模板: {template_id}")
            return

        # The manager may return a TemplateInfo object or a plain dict (other
        # commands in this module handle both). getattr() on a dict would
        # always yield 'Unknown', so read the name according to the shape.
        if isinstance(template, dict):
            template_name = template.get('name', 'Unknown')
        else:
            template_name = getattr(template, 'name', 'Unknown')

        if manager.delete_template(template_id):
            response.success({
                "deleted": True,
                "template_id": template_id,
                "template_name": template_name
            })
        else:
            response.error(-32603, "模板删除失败")

    except Exception as e:
        logger.error(f"删除模板失败: {e}")
        response.error(-32603, f"删除模板失败: {str(e)}")
|
||
|
||
|
||
@template_app.command("search")
def search_templates(
    query: str = typer.Argument(..., help="搜索关键词"),
    user_id: Optional[str] = typer.Option(None, "--user-id", help="用户ID"),
    include_cloud: bool = typer.Option(True, "--include-cloud", help="包含云端模板"),
    limit: int = typer.Option(50, "--limit", "-l", help="显示数量限制"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="详细输出"),
    json_output: bool = typer.Option(True, "--json", help="JSON格式输出")
):
    """Search templates by keyword and emit the matches.

    ``include_cloud`` and ``limit`` are forwarded to the manager's
    ``search_templates``. ``TemplateInfo`` results are converted with
    ``dataclasses.asdict``; other items are passed through unchanged.
    """
    response = create_response_handler()

    try:
        matches = TemplateManagerCloud(user_id=user_id or "default").search_templates(
            query, include_cloud=include_cloud, limit=limit
        )

        serialized = [
            asdict(item) if isinstance(item, TemplateInfo) else item
            for item in matches
        ]

        response.success({
            "templates": serialized,
            "query": query,
            "total_count": len(serialized),
            "limit": limit
        })

    except Exception as err:
        response.error(-32603, f"搜索模板失败: {str(err)}")
|
||
|
||
|
||
@template_app.command("stats")
def get_template_stats(
    user_id: Optional[str] = typer.Option(None, "--user-id", help="用户ID"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="详细输出"),
    json_output: bool = typer.Option(True, "--json", help="JSON格式输出")
):
    """Emit aggregate statistics over all templates (including cloud ones).

    Counts come from ``get_template_count``. Material, track and duration
    totals are summed over ``get_templates(include_cloud=True)``.
    ``average_duration`` uses floor division by the total template count,
    and is 0 when there are no templates.
    """
    response = create_response_handler()

    def _numeric_field(template, field):
        # Templates may be TemplateInfo objects or plain dicts (the list and
        # search commands handle both). getattr() on a dict always returned
        # the default, and a None value would crash sum(), so both cases
        # count as 0.
        if isinstance(template, dict):
            value = template.get(field)
        else:
            value = getattr(template, field, None)
        return value or 0

    try:
        manager = TemplateManagerCloud(user_id=user_id or "default")

        template_count = manager.get_template_count(include_cloud=True)
        user_template_count = manager.get_template_count(include_cloud=False)

        templates = manager.get_templates(include_cloud=True)

        total_materials = sum(_numeric_field(t, 'material_count') for t in templates)
        total_tracks = sum(_numeric_field(t, 'track_count') for t in templates)
        total_duration = sum(_numeric_field(t, 'duration') for t in templates)

        response.success({
            "total_templates": template_count,
            "user_templates": user_template_count,
            "cloud_templates": template_count - user_template_count,
            "total_materials": total_materials,
            "total_tracks": total_tracks,
            "total_duration": total_duration,
            "average_duration": total_duration // template_count if template_count > 0 else 0
        })

    except Exception as e:
        logger.error(f"获取统计信息失败: {e}")
        response.error(-32603, f"获取统计信息失败: {str(e)}")
|
||
|
||
|
||
@template_app.command("tags")
def get_popular_tags(
    user_id: Optional[str] = typer.Option(None, "--user-id", help="用户ID"),
    limit: int = typer.Option(20, "--limit", "-l", help="显示数量限制"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="详细输出"),
    json_output: bool = typer.Option(True, "--json", help="JSON格式输出")
):
    """Emit the popular tags reported by the template manager.

    ``limit`` is forwarded to ``get_popular_tags``; the response echoes it
    together with the number of tags actually returned.
    """
    response = create_response_handler()

    try:
        tags = TemplateManagerCloud(user_id=user_id or "default").get_popular_tags(limit=limit)
        payload = {
            "tags": tags,
            "total_count": len(tags),
            "limit": limit
        }
        response.success(payload)

    except Exception as exc:
        response.error(-32603, f"获取热门标签失败: {str(exc)}")
|
||
|
||
|
||
@template_app.command("update-segment")
def update_segment_name(
    template_id: str = typer.Argument(..., help="模板ID"),
    segment_id: str = typer.Argument(..., help="片段ID"),
    new_name: str = typer.Argument(..., help="新名称"),
    user_id: Optional[str] = typer.Option(None, "--user-id", help="用户ID"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="详细输出"),
    json_output: bool = typer.Option(True, "--json", help="JSON格式输出")
):
    """Rename one segment inside a template's ``draft_content`` and persist it.

    Error codes emitted:
        -32604  template not found
        -32605  ``user_id`` given but not the template's owner
        -32606  template has no draft_content
        -32607  no segment with ``segment_id``
        -32608  database update reported failure
        -32603  unexpected exception

    Args:
        template_id: Template whose draft is edited.
        segment_id: Segment id to match. Segments without an ``id`` are
            matched by the fallback ``segment_{index}``, where index is the
            position in the track's raw ``segments`` list.
        new_name: Name to store on the segment.
        user_id: When given, must equal the template's ``user_id``.
        verbose: Accepted for CLI compatibility; not used by this command.
        json_output: Accepted for CLI compatibility; not used by this command.
    """
    response = create_response_handler()

    try:
        from python_core.database.template_postgres import template_table

        template = template_table.get_template_by_id(template_id)
        if not template:
            response.error(-32604, f"模板不存在: {template_id}")
            return

        # Ownership is only checked when a user_id is supplied.
        if user_id and template.user_id != user_id:
            response.error(-32605, f"无权限修改模板: {template_id}")
            return

        draft_content = template_table.get_draft_content(template_id)
        if draft_content is None:
            response.error(-32606, f"模板详细信息不存在: {template_id}")
            return

        if not isinstance(draft_content, dict):
            logger.warning(f"draft_content is not a dict for template {template_id}: {type(draft_content)}")
            draft_content = {}

        tracks_data = draft_content.get('tracks', [])
        track_count = len(tracks_data) if isinstance(tracks_data, list) else 0
        logger.info(f"Searching for segment {segment_id} in {track_count} tracks")

        segment = _find_segment_by_id(tracks_data, segment_id)
        if segment is None:
            logger.warning(f"Segment not found: {segment_id}")
            response.error(-32607, f"片段不存在: {segment_id}")
            return

        # The segment dict is part of draft_content, so renaming it in place
        # updates the draft that is written back below.
        old_name = segment.get('name', 'Unknown')
        segment['name'] = new_name
        logger.info(f"Updated segment {segment_id}: '{old_name}' -> '{new_name}'")

        success = template_table.update_template(template_id, {
            'draft_content': draft_content,
            'updated_at': datetime.now().isoformat()
        })

        if success:
            response.success({
                'template_id': template_id,
                'segment_id': segment_id,
                'new_name': new_name,
                'updated': True
            })
        else:
            response.error(-32608, "更新失败")

    except Exception as e:
        logger.error(f"更新片段名称失败: {e}")
        response.error(-32603, f"更新片段名称失败: {str(e)}")


def _find_segment_by_id(tracks_data, segment_id):
    """Return the first segment dict whose id matches ``segment_id``, else None.

    Segments without an ``id`` use the fallback ``segment_{index}``, where
    index is the position in the track's raw ``segments`` list. Non-list
    containers and non-dict entries are skipped instead of raising.
    """
    if not isinstance(tracks_data, list):
        return None
    for track in tracks_data:
        if not isinstance(track, dict):
            continue
        segments = track.get('segments', [])
        if not isinstance(segments, list):
            continue
        for seg_idx, segment in enumerate(segments):
            if isinstance(segment, dict) and segment.get('id', f'segment_{seg_idx}') == segment_id:
                return segment
    return None
|
||
|
||
|
||
@template_app.command("by-tag")
def get_templates_by_tag(
    tag: str = typer.Argument(..., help="标签名称"),
    user_id: Optional[str] = typer.Option(None, "--user-id", help="用户ID"),
    include_cloud: bool = typer.Option(True, "--include-cloud", help="包含云端模板"),
    limit: int = typer.Option(50, "--limit", "-l", help="显示数量限制"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="详细输出"),
    json_output: bool = typer.Option(True, "--json", help="JSON格式输出")
):
    """Emit the templates the manager returns for ``tag``.

    ``include_cloud`` and ``limit`` are forwarded to
    ``get_templates_by_tag``. ``TemplateInfo`` results are converted with
    ``dataclasses.asdict``; other items are passed through unchanged.
    """
    response = create_response_handler()

    try:
        manager = TemplateManagerCloud(user_id=user_id or "default")
        tagged = manager.get_templates_by_tag(tag, include_cloud=include_cloud, limit=limit)

        def to_plain(item):
            # Dataclass results become dicts; anything else is already plain.
            return asdict(item) if isinstance(item, TemplateInfo) else item

        templates_data = list(map(to_plain, tagged))

        response.success({
            "templates": templates_data,
            "tag": tag,
            "total_count": len(templates_data),
            "limit": limit
        })

    except Exception as exc:
        response.error(-32603, f"根据标签获取模板失败: {str(exc)}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Executing this module directly runs the template sub-application as its own CLI.
    template_app()
|