mxivideo/python_core/services/scene_detection/cli.py

353 lines
14 KiB
Python

#!/usr/bin/env python3
"""
场景检测命令行接口
"""
import os
from typing import Dict, Any
from dataclasses import asdict
from .types import DetectorType, BatchDetectionConfig
from .detector import SceneDetectionService
from python_core.utils.progress import ProgressJSONRPCCommander
class SceneDetectionCommander(ProgressJSONRPCCommander):
    """Scene-detection command-line interface with progress-bar support.

    Commands registered with the base commander:
        detect        -- detect scenes in a single video (simple command)
        batch_detect  -- detect scenes in every video of a directory (progress)
        analyze       -- compute statistics from a saved result file (simple)
        compare       -- run every detector at several thresholds (progress)
    """

    def __init__(self):
        super().__init__("scene_detection")
        # Detection, saving and statistics work is delegated to the service.
        self.service = SceneDetectionService()

    @staticmethod
    def _common_detection_options() -> Dict[str, Dict[str, Any]]:
        """Return a fresh dict of the options shared by ``detect`` and ``batch_detect``.

        A new dict is built on every call so the two commands never share
        (and cannot accidentally mutate) the same option specs.
        """
        return {
            "detector": {
                "type": str,
                "default": "content",
                "choices": ["content", "threshold", "adaptive"],
                "description": "检测器类型",
            },
            "threshold": {"type": float, "default": 30.0, "description": "检测阈值"},
            "min_scene_length": {"type": float, "default": 1.0, "description": "最小场景长度(秒)"},
            "output": {"type": str, "description": "输出文件路径"},
            "format": {
                "type": str,
                "default": "json",
                "choices": ["json", "csv", "txt"],
                "description": "输出格式",
            },
        }

    def _register_commands(self) -> None:
        """Register every command this commander exposes."""
        # Single-video detection.
        self.register_command(
            name="detect",
            description="检测单个视频的场景",
            required_args=["video_path"],
            optional_args=self._common_detection_options(),
        )
        # Batch detection over a directory: the shared options plus two batch-only flags.
        self.register_command(
            name="batch_detect",
            description="批量检测目录中所有视频的场景",
            required_args=["input_directory"],
            optional_args={
                **self._common_detection_options(),
                "adaptive": {"type": bool, "default": False, "description": "启用自适应阈值"},
                "thumbnails": {"type": bool, "default": False, "description": "生成缩略图"},
            },
        )
        # Statistics over a previously saved result file.
        self.register_command(
            name="analyze",
            description="分析检测结果并生成统计信息",
            required_args=["result_file"],
            optional_args={
                "output": {"type": str, "description": "统计输出文件路径"},
            },
        )
        # Side-by-side comparison of all detectors at several thresholds.
        self.register_command(
            name="compare",
            description="比较不同检测器的效果",
            required_args=["video_path"],
            optional_args={
                "thresholds": {"type": str, "default": "20,30,40", "description": "测试阈值列表(逗号分隔)"},
                "output": {"type": str, "description": "比较结果输出文件"},
            },
        )

    def _is_progressive_command(self, command: str) -> bool:
        """Return True for commands that report progress (batch and compare)."""
        return command in ("batch_detect", "compare")

    def _execute_with_progress(self, command: str, args: Dict[str, Any]) -> Any:
        """Dispatch a progress-reporting command.

        Raises:
            ValueError: if ``command`` is not a progressive command.
        """
        if command == "batch_detect":
            return self._batch_detect_with_progress(args)
        elif command == "compare":
            return self._compare_with_progress(args)
        else:
            raise ValueError(f"Unknown progressive command: {command}")

    def _execute_simple_command(self, command: str, args: Dict[str, Any]) -> Any:
        """Dispatch a command that does not report progress.

        Raises:
            ValueError: if ``command`` is not a simple command.
        """
        if command == "detect":
            return self._detect_single_video(args)
        elif command == "analyze":
            return self._analyze_results(args)
        else:
            raise ValueError(f"Unknown command: {command}")

    def _detect_single_video(self, args: Dict[str, Any]) -> dict:
        """Detect scenes in ``args["video_path"]`` and return the result as a dict.

        When ``args["output"]`` is set, the single result is wrapped in a
        one-file ``BatchDetectionResult`` so the service's batch save routine
        can write it.
        """
        config = self._create_config(args)
        result = self.service.detect_single_video(args["video_path"], config)

        if args.get("output"):
            from .types import BatchDetectionResult

            # Successful results go in ``results``; a failure goes in ``failed_list``.
            batch_result = BatchDetectionResult(
                total_files=1,
                processed_files=1 if result.success else 0,
                failed_files=0 if result.success else 1,
                total_scenes=result.total_scenes,
                total_duration=result.total_duration,
                average_scenes_per_video=result.total_scenes,
                detection_time=result.detection_time,
                results=[result] if result.success else [],
                failed_list=[] if result.success else [{"filename": result.filename, "error": result.error}],
                config=config,
            )
            self.service.save_results(batch_result, args["output"])

        return asdict(result)

    def _batch_detect_with_progress(self, args: Dict[str, Any]) -> dict:
        """Run batch detection over ``args["input_directory"]`` with progress reporting.

        Returns:
            The batch result as a dict. If the directory contains no videos,
            a small summary dict with zero counts is returned instead.
        """
        import re

        config = self._create_config(args)
        input_directory = args["input_directory"]

        # Scan up front so the progress task knows its total.
        # NOTE(review): this relies on the service's private ``_scan_video_files``.
        video_files = self.service._scan_video_files(input_directory)
        if not video_files:
            return {
                "total_files": 0,
                "processed_files": 0,
                "failed_files": 0,
                "message": "No video files found in directory",
            }

        # Matches an "(x/y)" counter; whitespace around the numbers is allowed.
        counter_pattern = re.compile(r"\(\s*(\d+)\s*/\s*(\d+)\s*\)")

        with self.create_task("批量场景检测", len(video_files)) as task:

            def progress_callback(message: str) -> None:
                # Use the first "(x/y)" counter in the message. Parentheses that do
                # not hold such a counter (e.g. "clip (1).mp4") are ignored instead
                # of breaking the parse.
                match = counter_pattern.search(message)
                if match:
                    current = int(match.group(1))
                    # The counter is 1-based; the task position is 0-based.
                    task.update(max(current - 1, 0), message)
                else:
                    task.update(message=message)

            result = self.service.batch_detect_scenes(
                input_directory, config, progress_callback
            )

            if args.get("output"):
                self.service.save_results(result, args["output"])

            task.finish(f"批量检测完成: {result.processed_files} 成功, {result.failed_files} 失败")

        return asdict(result)

    @staticmethod
    def _parse_thresholds(thresholds_str: str) -> list:
        """Parse a comma-separated threshold list such as ``"20,30,40"`` into floats.

        Blank entries (for example from a trailing comma) are skipped.

        Raises:
            ValueError: if an entry is not a number or no thresholds remain.
        """
        message = "Invalid thresholds format. Use comma-separated numbers like '20,30,40'"
        try:
            thresholds = [float(part) for part in thresholds_str.split(",") if part.strip()]
        except ValueError as exc:
            raise ValueError(message) from exc
        if not thresholds:
            raise ValueError(message)
        return thresholds

    def _compare_with_progress(self, args: Dict[str, Any]) -> dict:
        """Run every detector at every requested threshold on one video.

        Each run uses ``min_scene_length=1.0``. When ``args["output"]`` is set,
        the comparison is also written there as UTF-8 JSON.

        Raises:
            ValueError: if ``args["thresholds"]`` is not a valid number list.
        """
        import json
        from itertools import product

        video_path = args["video_path"]
        thresholds = self._parse_thresholds(args.get("thresholds", "20,30,40"))
        detectors = ["content", "threshold", "adaptive"]
        total_tests = len(detectors) * len(thresholds)

        results = []
        with self.create_task("比较检测器", total_tests) as task:
            # product() yields detector-major order: every threshold for "content" first.
            for index, (detector, threshold) in enumerate(product(detectors, thresholds)):
                task.update(index, f"测试 {detector} 检测器 (阈值: {threshold})")
                config = BatchDetectionConfig(
                    detector_type=DetectorType(detector),
                    threshold=threshold,
                    min_scene_length=1.0,
                )
                result = self.service.detect_single_video(video_path, config)
                results.append({
                    "detector": detector,
                    "threshold": threshold,
                    "success": result.success,
                    "total_scenes": result.total_scenes,
                    "detection_time": result.detection_time,
                    "error": result.error,
                })
            task.finish("检测器比较完成")

        comparison_result = {
            "video_path": video_path,
            "total_tests": total_tests,
            "results": results,
            "summary": self._analyze_comparison(results),
        }

        if args.get("output"):
            with open(args["output"], "w", encoding="utf-8") as f:
                json.dump(comparison_result, f, indent=2, ensure_ascii=False)

        return comparison_result

    def _analyze_results(self, args: Dict[str, Any]) -> dict:
        """Rebuild a saved batch result from JSON and compute statistics on it.

        ``args["result_file"]`` must be a JSON object with a ``summary`` mapping
        and optional ``results``, ``failed_files`` and ``config`` entries.
        Presumably this is the layout written by ``service.save_results`` in
        JSON format (TODO confirm). When ``args["output"]`` is set, the analysis
        is also written there as UTF-8 JSON.

        Raises:
            ValueError: wrapping any error raised while loading, rebuilding,
                analysing or saving. The original error is chained as the cause.
        """
        result_file = args["result_file"]
        try:
            import json
            from .types import BatchDetectionResult, VideoSceneResult, SceneInfo

            with open(result_file, "r", encoding="utf-8") as f:
                data = json.load(f)

            summary = data["summary"]
            # Every reloaded video is tagged with the batch-level detector settings.
            # A missing or null "config" falls back to placeholder values.
            config_data = data.get("config") or {}
            detector_type = config_data.get("detector_type", "unknown")
            threshold = config_data.get("threshold", 0.0)

            results = []
            for video_data in data.get("results", []):
                scenes = [
                    SceneInfo(
                        index=scene["index"],
                        start_time=scene["start_time"],
                        end_time=scene["end_time"],
                        duration=scene["duration"],
                        confidence=scene.get("confidence", 1.0),
                    )
                    for scene in video_data.get("scenes", [])
                ]
                results.append(
                    VideoSceneResult(
                        video_path=video_data["video_path"],
                        filename=video_data["filename"],
                        # Entries under "results" are treated as successful detections.
                        success=True,
                        total_scenes=video_data["total_scenes"],
                        total_duration=video_data["total_duration"],
                        scenes=scenes,
                        detection_time=video_data["detection_time"],
                        detector_type=detector_type,
                        threshold=threshold,
                    )
                )

            batch_result = BatchDetectionResult(
                total_files=summary["total_files"],
                processed_files=summary["processed_files"],
                failed_files=summary["failed_files"],
                total_scenes=summary["total_scenes"],
                total_duration=summary["total_duration"],
                average_scenes_per_video=summary["average_scenes_per_video"],
                detection_time=summary["detection_time"],
                results=results,
                failed_list=data.get("failed_files", []),
                # Placeholder config. NOTE(review): assumes calculate_stats does not
                # need the original settings; confirm.
                config=BatchDetectionConfig(),
            )

            stats = self.service.calculate_stats(batch_result)
            analysis_result = {
                "source_file": result_file,
                "statistics": asdict(stats),
                "summary": summary,
            }

            if args.get("output"):
                with open(args["output"], "w", encoding="utf-8") as f:
                    json.dump(analysis_result, f, indent=2, ensure_ascii=False)

            return analysis_result
        except Exception as e:
            raise ValueError(f"Failed to analyze results: {e}") from e

    def _analyze_comparison(self, results: list) -> dict:
        """Summarise the per-run dicts produced by ``_compare_with_progress``.

        Only successful runs are considered. The "best" detector is the one
        with the highest average scene count. On a tie, the detector that first
        appears among the successful runs wins.
        """
        successful_results = [r for r in results if r["success"]]
        if not successful_results:
            return {"message": "No successful detections"}

        # Group runs by detector, preserving first-seen order.
        by_detector: Dict[str, list] = {}
        for result in successful_results:
            by_detector.setdefault(result["detector"], []).append(result)

        detector_analysis = {}
        for detector, detector_results in by_detector.items():
            count = len(detector_results)
            detector_analysis[detector] = {
                "average_scenes": sum(r["total_scenes"] for r in detector_results) / count,
                "average_detection_time": sum(r["detection_time"] for r in detector_results) / count,
                "test_count": count,
            }

        # NOTE(review): "best" simply means "finds the most scenes", which may favour
        # over-segmentation; confirm that this is the intended criterion.
        best_detector = max(detector_analysis, key=lambda d: detector_analysis[d]["average_scenes"])

        return {
            "total_successful_tests": len(successful_results),
            "detector_analysis": detector_analysis,
            "best_detector": best_detector,
            "recommendation": f"推荐使用 {best_detector} 检测器",
        }

    def _create_config(self, args: Dict[str, Any]) -> BatchDetectionConfig:
        """Build a ``BatchDetectionConfig`` from command args, using the CLI defaults for missing keys."""
        return BatchDetectionConfig(
            detector_type=DetectorType(args.get("detector", "content")),
            threshold=args.get("threshold", 30.0),
            min_scene_length=args.get("min_scene_length", 1.0),
            adaptive_threshold=args.get("adaptive", False),
            generate_thumbnails=args.get("thumbnails", False),
            output_format=args.get("format", "json"),
        )
def main():
    """Script entry point: construct the scene-detection commander and run it."""
    SceneDetectionCommander().run()


if __name__ == "__main__":
    main()