353 lines
14 KiB
Python
353 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
场景检测命令行接口
|
|
"""
|
|
|
|
import os
|
|
from typing import Dict, Any
|
|
from dataclasses import asdict
|
|
|
|
from .types import DetectorType, BatchDetectionConfig
|
|
from .detector import SceneDetectionService
|
|
from python_core.utils.progress import ProgressJSONRPCCommander
|
|
|
|
class SceneDetectionCommander(ProgressJSONRPCCommander):
    """Scene-detection command-line interface with progress reporting.

    Registers four commands:

    * ``detect``       -- detect scenes in a single video (no progress task).
    * ``batch_detect`` -- detect scenes in every video found in a directory.
    * ``analyze``      -- rebuild a batch result from a saved JSON file and compute statistics.
    * ``compare``      -- run every detector type at several thresholds on one video.

    ``batch_detect`` and ``compare`` report progress through ``self.create_task``.
    """

    # Detector names exposed on the CLI; each one is converted with ``DetectorType(name)``.
    _DETECTOR_CHOICES = ("content", "threshold", "adaptive")
    # Output formats exposed on the CLI for ``detect`` / ``batch_detect``.
    _FORMAT_CHOICES = ("json", "csv", "txt")
    # Commands dispatched through ``_execute_with_progress``.
    _PROGRESSIVE_COMMANDS = ("batch_detect", "compare")
    # Default threshold list for ``compare`` when none (or ``None``) is supplied.
    _DEFAULT_THRESHOLDS = "20,30,40"

    def __init__(self):
        super().__init__("scene_detection")
        # Backend that performs scanning, detection, saving and statistics.
        self.service = SceneDetectionService()

    @classmethod
    def _detection_options(cls) -> Dict[str, Dict[str, Any]]:
        """Return a new dict of the options shared by ``detect`` and ``batch_detect``."""
        return {
            "detector": {"type": str, "default": "content", "choices": list(cls._DETECTOR_CHOICES), "description": "检测器类型"},
            "threshold": {"type": float, "default": 30.0, "description": "检测阈值"},
            "min_scene_length": {"type": float, "default": 1.0, "description": "最小场景长度(秒)"},
            "output": {"type": str, "description": "输出文件路径"},
            "format": {"type": str, "default": "json", "choices": list(cls._FORMAT_CHOICES), "description": "输出格式"},
        }

    def _register_commands(self) -> None:
        """Register the ``detect``, ``batch_detect``, ``analyze`` and ``compare`` commands."""
        # Single-video detection.
        self.register_command(
            name="detect",
            description="检测单个视频的场景",
            required_args=["video_path"],
            optional_args=self._detection_options(),
        )

        # Batch detection: the shared options plus two batch-only flags.
        batch_options = self._detection_options()
        batch_options["adaptive"] = {"type": bool, "default": False, "description": "启用自适应阈值"}
        batch_options["thumbnails"] = {"type": bool, "default": False, "description": "生成缩略图"}
        self.register_command(
            name="batch_detect",
            description="批量检测目录中所有视频的场景",
            required_args=["input_directory"],
            optional_args=batch_options,
        )

        # Statistics over a previously saved result file.
        self.register_command(
            name="analyze",
            description="分析检测结果并生成统计信息",
            required_args=["result_file"],
            optional_args={
                "output": {"type": str, "description": "统计输出文件路径"}
            },
        )

        # Detector / threshold comparison on one video.
        self.register_command(
            name="compare",
            description="比较不同检测器的效果",
            required_args=["video_path"],
            optional_args={
                "thresholds": {"type": str, "default": self._DEFAULT_THRESHOLDS, "description": "测试阈值列表(逗号分隔)"},
                "output": {"type": str, "description": "比较结果输出文件"}
            },
        )

    def _is_progressive_command(self, command: str) -> bool:
        """Return True for commands that report progress (batch and compare operations)."""
        return command in self._PROGRESSIVE_COMMANDS

    def _execute_with_progress(self, command: str, args: Dict[str, Any]) -> Any:
        """Run a progress-reporting command.

        Raises:
            ValueError: if *command* is not a progressive command.
        """
        if command == "batch_detect":
            return self._batch_detect_with_progress(args)
        if command == "compare":
            return self._compare_with_progress(args)
        raise ValueError(f"Unknown progressive command: {command}")

    def _execute_simple_command(self, command: str, args: Dict[str, Any]) -> Any:
        """Run a command that does not report progress.

        Raises:
            ValueError: if *command* is not a simple command.
        """
        if command == "detect":
            return self._detect_single_video(args)
        if command == "analyze":
            return self._analyze_results(args)
        raise ValueError(f"Unknown command: {command}")

    def _detect_single_video(self, args: Dict[str, Any]) -> dict:
        """Detect scenes in ``args["video_path"]`` and return the result as a dict.

        When ``args["output"]`` is set, the single result is wrapped in a one-file
        ``BatchDetectionResult`` so that the service's batch saver can write it.
        """
        config = self._create_config(args)
        result = self.service.detect_single_video(args["video_path"], config)

        output_path = args.get("output")
        if output_path:
            from .types import BatchDetectionResult

            succeeded = result.success
            batch_result = BatchDetectionResult(
                total_files=1,
                processed_files=1 if succeeded else 0,
                failed_files=0 if succeeded else 1,
                total_scenes=result.total_scenes,
                total_duration=result.total_duration,
                # With exactly one file, the per-video average is its own scene count.
                average_scenes_per_video=result.total_scenes,
                detection_time=result.detection_time,
                results=[result] if succeeded else [],
                failed_list=[] if succeeded else [{"filename": result.filename, "error": result.error}],
                config=config,
            )
            self.service.save_results(batch_result, output_path)

        return asdict(result)

    @staticmethod
    def _parse_progress_index(message: str):
        """Return a zero-based progress index from the first ``(current/total)`` marker.

        ``current`` is assumed to be the 1-based number of the item being processed,
        so one is subtracted; the result is clamped at 0. Returns ``None`` when the
        message contains no well-formed ``(int/int)`` marker.
        """
        import re

        match = re.search(r"\(\s*(\d+)\s*/\s*\d+\s*\)", message)
        if match is None:
            return None
        return max(int(match.group(1)) - 1, 0)

    def _batch_detect_with_progress(self, args: Dict[str, Any]) -> dict:
        """Detect scenes in every video of ``args["input_directory"]`` with a progress task.

        Returns a short summary dict when the directory has no video files; otherwise
        the batch result as a dict (also saved to ``args["output"]`` when set).
        """
        config = self._create_config(args)
        input_directory = args["input_directory"]

        # Scan first so the progress task knows its total and empty dirs short-circuit.
        video_files = self.service._scan_video_files(input_directory)
        if not video_files:
            return {
                "total_files": 0,
                "processed_files": 0,
                "failed_files": 0,
                "message": "No video files found in directory"
            }

        with self.create_task("批量场景检测", len(video_files)) as task:

            def progress_callback(message: str) -> None:
                # Messages carrying a "(current/total)" marker advance the bar;
                # any other message only updates the displayed text.
                completed = self._parse_progress_index(message)
                if completed is None:
                    task.update(message=message)
                else:
                    task.update(completed, message)

            result = self.service.batch_detect_scenes(input_directory, config, progress_callback)

            output_path = args.get("output")
            if output_path:
                self.service.save_results(result, output_path)

            task.finish(f"批量检测完成: {result.processed_files} 成功, {result.failed_files} 失败")

        return asdict(result)

    @staticmethod
    def _parse_thresholds(thresholds_str: str) -> list:
        """Parse comma-separated numbers into floats; blank items (e.g. a trailing comma) are skipped.

        Raises:
            ValueError: if an item is not a number or no threshold remains.
        """
        error_message = "Invalid thresholds format. Use comma-separated numbers like '20,30,40'"
        try:
            thresholds = [float(item) for item in thresholds_str.split(",") if item.strip()]
        except ValueError as err:
            raise ValueError(error_message) from err
        if not thresholds:
            raise ValueError(error_message)
        return thresholds

    def _compare_with_progress(self, args: Dict[str, Any]) -> dict:
        """Run every detector type at each requested threshold on one video.

        Returns a dict with the raw per-run results and a summary; it is also written
        as JSON to ``args["output"]`` when set.

        Raises:
            ValueError: if ``args["thresholds"]`` is not a list of comma-separated numbers.
        """
        import itertools
        import json

        video_path = args["video_path"]
        raw_thresholds = args.get("thresholds")
        if raw_thresholds is None:
            raw_thresholds = self._DEFAULT_THRESHOLDS
        thresholds = self._parse_thresholds(raw_thresholds)

        detectors = list(self._DETECTOR_CHOICES)
        total_tests = len(detectors) * len(thresholds)

        with self.create_task("比较检测器", total_tests) as task:
            results = []
            # Detector-major order: all thresholds for the first detector, then the next.
            runs = itertools.product(detectors, thresholds)
            for completed, (detector, threshold) in enumerate(runs):
                task.update(completed, f"测试 {detector} 检测器 (阈值: {threshold})")

                config = BatchDetectionConfig(
                    detector_type=DetectorType(detector),
                    threshold=threshold,
                    min_scene_length=1.0
                )
                result = self.service.detect_single_video(video_path, config)

                results.append({
                    "detector": detector,
                    "threshold": threshold,
                    "success": result.success,
                    "total_scenes": result.total_scenes,
                    "detection_time": result.detection_time,
                    "error": result.error
                })

            task.finish("检测器比较完成")

        comparison_result = {
            "video_path": video_path,
            "total_tests": total_tests,
            "results": results,
            "summary": self._analyze_comparison(results)
        }

        output_path = args.get("output")
        if output_path:
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(comparison_result, f, indent=2, ensure_ascii=False)

        return comparison_result

    @staticmethod
    def _rebuild_batch_result(data: Dict[str, Any]):
        """Rebuild a ``BatchDetectionResult`` from a parsed JSON result file.

        Reads ``data["summary"]`` for the aggregate counters and ``data["results"]``
        for per-video entries (each with a ``scenes`` list). The detector type and
        threshold come from ``data["config"]`` when present. Every rebuilt video is
        marked successful, and a default ``BatchDetectionConfig`` stands in for the
        saved configuration.
        """
        from .types import BatchDetectionResult, VideoSceneResult, SceneInfo

        saved_config = data.get("config", {})
        detector_type = saved_config.get("detector_type", "unknown")
        threshold = saved_config.get("threshold", 0.0)

        results = []
        for video_data in data.get("results", []):
            scenes = [
                SceneInfo(
                    index=scene["index"],
                    start_time=scene["start_time"],
                    end_time=scene["end_time"],
                    duration=scene["duration"],
                    confidence=scene.get("confidence", 1.0)
                )
                for scene in video_data.get("scenes", [])
            ]
            results.append(VideoSceneResult(
                video_path=video_data["video_path"],
                filename=video_data["filename"],
                success=True,
                total_scenes=video_data["total_scenes"],
                total_duration=video_data["total_duration"],
                scenes=scenes,
                detection_time=video_data["detection_time"],
                detector_type=detector_type,
                threshold=threshold
            ))

        summary = data["summary"]
        return BatchDetectionResult(
            total_files=summary["total_files"],
            processed_files=summary["processed_files"],
            failed_files=summary["failed_files"],
            total_scenes=summary["total_scenes"],
            total_duration=summary["total_duration"],
            average_scenes_per_video=summary["average_scenes_per_video"],
            detection_time=summary["detection_time"],
            results=results,
            failed_list=data.get("failed_files", []),
            config=BatchDetectionConfig()  # simplified: the saved config is not restored
        )

    def _analyze_results(self, args: Dict[str, Any]) -> dict:
        """Load a saved JSON result file, compute statistics and optionally save them.

        Returns a dict with the source file, the computed statistics and the file's
        original summary; it is also written as JSON to ``args["output"]`` when set.

        Raises:
            ValueError: if reading, rebuilding or analysing fails; the underlying
                exception is chained as ``__cause__``.
        """
        import json

        result_file = args["result_file"]

        try:
            with open(result_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            batch_result = self._rebuild_batch_result(data)
            stats = self.service.calculate_stats(batch_result)

            analysis_result = {
                "source_file": result_file,
                "statistics": asdict(stats),
                "summary": data["summary"]
            }

            output_path = args.get("output")
            if output_path:
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump(analysis_result, f, indent=2, ensure_ascii=False)

            return analysis_result

        except Exception as e:
            raise ValueError(f"Failed to analyze results: {e}") from e

    def _analyze_comparison(self, results: list) -> dict:
        """Summarise ``compare`` runs per detector and name a "best" detector.

        Only entries whose ``success`` is true are considered. ``best_detector`` is
        the detector with the highest average scene count (the first one wins ties);
        this is a heuristic, not a quality measure.
        """
        successful_results = [r for r in results if r["success"]]
        if not successful_results:
            return {"message": "No successful detections"}

        # Group runs by detector, preserving first-seen order.
        by_detector: Dict[str, list] = {}
        for entry in successful_results:
            by_detector.setdefault(entry["detector"], []).append(entry)

        detector_analysis = {}
        for detector, detector_results in by_detector.items():
            count = len(detector_results)
            detector_analysis[detector] = {
                "average_scenes": sum(r["total_scenes"] for r in detector_results) / count,
                "average_detection_time": sum(r["detection_time"] for r in detector_results) / count,
                "test_count": count
            }

        best_detector = max(detector_analysis, key=lambda d: detector_analysis[d]["average_scenes"])

        return {
            "total_successful_tests": len(successful_results),
            "detector_analysis": detector_analysis,
            "best_detector": best_detector,
            "recommendation": f"推荐使用 {best_detector} 检测器"
        }

    def _create_config(self, args: Dict[str, Any]) -> BatchDetectionConfig:
        """Build a ``BatchDetectionConfig`` from CLI args, with fallbacks matching the registered defaults."""
        return BatchDetectionConfig(
            detector_type=DetectorType(args.get("detector", "content")),
            threshold=args.get("threshold", 30.0),
            min_scene_length=args.get("min_scene_length", 1.0),
            adaptive_threshold=args.get("adaptive", False),
            generate_thumbnails=args.get("thumbnails", False),
            output_format=args.get("format", "json")
        )
|
|
|
|
def main():
    """Script entry point: build the scene-detection commander and hand control to it."""
    SceneDetectionCommander().run()


if __name__ == "__main__":
    # Only run the CLI when executed directly, not when imported.
    main()
|