// Source path (truncated by the code viewer this file was copied from):
//   mixvideo-v2/apps/desktop/src-tauri/src/business/services/video_classification_servic...
// Viewer metadata: 372 lines, 14 KiB, Rust.
//
// Viewer warning: this file contains Unicode characters that might be confused with
// other characters. The non-ASCII text (Chinese comments and messages, emoji in log
// output) appears to be intentional, in which case the warning can be safely ignored.

use crate::data::models::video_classification::*;
use crate::data::models::ai_classification::generate_full_prompt;
use crate::data::repositories::video_classification_repository::VideoClassificationRepository;
use crate::data::repositories::ai_classification_repository::AiClassificationRepository;
use crate::data::repositories::material_repository::MaterialRepository;
use crate::infrastructure::gemini_service::{GeminiService, GeminiConfig};
use crate::infrastructure::file_system::FileSystemService;
use anyhow::{Result, anyhow};
use std::sync::Arc;
use std::path::Path;
use tokio::sync::Mutex;
use serde_json;
/// Business service for AI-driven video classification.
/// Follows the business-layer design pattern of the Tauri development guidelines.
pub struct VideoClassificationService {
/// Repository used to create, update and query classification tasks and records.
video_repo: Arc<VideoClassificationRepository>,
/// Repository queried for the active AI classification definitions that feed the prompt.
ai_classification_repo: Arc<AiClassificationRepository>,
/// Repository used to look up materials, projects and segments, and to update segment paths.
material_repo: Arc<MaterialRepository>,
/// Shared Gemini client behind an async mutex; it is locked before each classification call.
gemini_service: Arc<Mutex<GeminiService>>,
}
impl VideoClassificationService {
/// Builds a service from its three repositories and an optional Gemini configuration.
///
/// The Gemini client is constructed here from `gemini_config` and wrapped in
/// `Arc<Mutex<_>>` before being stored on the service.
pub fn new(
    video_repo: Arc<VideoClassificationRepository>,
    ai_classification_repo: Arc<AiClassificationRepository>,
    material_repo: Arc<MaterialRepository>,
    gemini_config: Option<GeminiConfig>,
) -> Self {
    let gemini_client = GeminiService::new(gemini_config);

    Self {
        video_repo,
        ai_classification_repo,
        material_repo,
        gemini_service: Arc::new(Mutex::new(gemini_client)),
    }
}
/// Creates one classification task for every eligible segment of a material.
///
/// A segment is skipped when `request.overwrite_existing` is `false` and the
/// repository reports it as already classified, or when its video file does not
/// exist on disk. All created tasks carry the project ID from the request.
///
/// # Errors
/// Returns an error when the material does not exist, when the requested
/// project ID differs from the material's and cannot be found, when the
/// material has no segments, or when any repository call fails. Repository
/// failures during the project lookup are propagated as-is rather than being
/// reported as an invalid project ID.
pub async fn create_batch_classification_tasks(&self, request: BatchClassificationRequest) -> Result<Vec<VideoClassificationTask>> {
    println!("🎬 创建批量分类任务");
    println!(" 素材ID: {}", request.material_id);
    println!(" 项目ID: {}", request.project_id);

    // Load the material the request refers to.
    let material = self.material_repo.get_by_id(&request.material_id)?
        .ok_or_else(|| anyhow!("素材不存在: {}", request.material_id))?;
    println!(" 素材项目ID: {}", material.project_id);

    // When the IDs disagree, the requested project wins, provided it exists.
    if material.project_id != request.project_id {
        println!("⚠️ 项目ID不匹配尝试修复: 素材项目ID={}, 请求项目ID={}", material.project_id, request.project_id);

        // `?` keeps a lookup failure distinguishable from "project not found".
        let requested_project = self.material_repo.get_project_by_id(&request.project_id).await?;
        if requested_project.is_none() {
            return Err(anyhow!("请求的项目ID无效: {}", request.project_id));
        }
        println!("✅ 请求的项目ID有效将使用请求的项目ID进行分类");
    }

    // Fetch every segment cut from this material.
    let segments = self.material_repo.get_segments(&request.material_id)?;
    if segments.is_empty() {
        return Err(anyhow!("素材没有切分片段"));
    }

    let mut tasks = Vec::new();
    for segment in segments {
        // Skip segments that already have a classification unless overwriting.
        if !request.overwrite_existing && self.video_repo.is_segment_classified(&segment.id).await? {
            continue;
        }

        // Skip segments whose video file is missing on disk.
        if !Path::new(&segment.file_path).exists() {
            println!("警告: 视频文件不存在,跳过: {}", segment.file_path);
            continue;
        }

        let task = VideoClassificationTask::new(
            segment.id.clone(),
            request.material_id.clone(),
            request.project_id.clone(),
            segment.file_path.clone(),
            request.priority,
        );
        let created_task = self.video_repo.create_classification_task(task).await?;
        tasks.push(created_task);
    }

    Ok(tasks)
}
/// Runs a single classification task end to end.
///
/// Loads the task, calls `start_processing()` on it and persists it, builds the
/// prompt, classifies the video, then persists the task as completed or failed.
/// After a successful classification the video file is moved into its category
/// folder; a failure of that move is only logged.
///
/// # Errors
/// Returns an error when the task cannot be loaded or persisted, or when prompt
/// generation or classification fails. In the latter case the task is marked
/// failed first and the original error is returned.
pub async fn process_classification_task(&self, task_id: &str) -> Result<VideoClassificationRecord> {
    let mut task = self.get_task_by_id(task_id).await?;

    task.start_processing();
    self.video_repo.update_classification_task(&task).await?;

    // Prompt generation belongs to the fallible section: if it fails, the task
    // must be marked failed instead of being left in its started state.
    let outcome = match self.generate_classification_prompt().await {
        Ok(prompt) => self.classify_video_with_gemini(&mut task, &prompt).await,
        Err(e) => Err(e),
    };

    match outcome {
        Ok(record) => {
            task.complete();
            self.video_repo.update_classification_task(&task).await?;

            // Moving the file is best-effort; the classification itself succeeded.
            if let Err(e) = self.move_video_to_category_folder(&record).await {
                println!("警告: 移动视频文件失败: {}", e);
            }
            Ok(record)
        }
        Err(e) => {
            task.fail(e.to_string());

            // Keep the root-cause error even if the status update itself fails.
            if let Err(update_err) = self.video_repo.update_classification_task(&task).await {
                println!("警告: 更新失败任务状态时出错: {}", update_err);
            }
            Err(e)
        }
    }
}
/// Classifies the task's video with Gemini and stores the resulting record.
///
/// Calls the Gemini service with the task's video path and `prompt`, records the
/// returned file URI and prompt on the task, parses the raw response, logs the
/// result, flags the record for manual review when `needs_review()` says so, and
/// saves the record through the repository.
///
/// # Errors
/// Returns an error when the Gemini call or any repository write fails.
async fn classify_video_with_gemini(&self, task: &mut VideoClassificationTask, prompt: &str) -> Result<VideoClassificationRecord> {
    // Hold the Gemini lock only for the API call so that repository writes and
    // logging below do not block other users of the shared client.
    let (file_uri, raw_response) = {
        let mut gemini_service = self.gemini_service.lock().await;
        gemini_service.classify_video(&task.video_file_path, prompt).await?
    };

    // Record the uploaded file URI and the prompt on the task.
    task.set_analyzing(file_uri.clone(), prompt.to_string());
    self.video_repo.update_classification_task(task).await?;

    // Turn the raw model output into structured data.
    let gemini_response = self.parse_gemini_response(&raw_response)?;

    // Log the classification result.
    println!("🎯 AI分类结果:");
    println!(" 📁 视频文件: {}", task.video_file_path);
    println!(" 🏷️ 分类类别: {}", gemini_response.category);
    println!(" 📊 置信度: {:.1}%", gemini_response.confidence * 100.0);
    println!(" ⭐ 质量评分: {:.1}/10", gemini_response.quality_score * 10.0);
    println!(" 💭 分类理由: {}", gemini_response.reasoning);
    if !gemini_response.features.is_empty() {
        println!(" 🔍 识别特征: {}", gemini_response.features.join(", "));
    }

    // Build the record from the task identifiers and the parsed response.
    let mut record = VideoClassificationRecord::new(
        task.segment_id.clone(),
        task.material_id.clone(),
        task.project_id.clone(),
        gemini_response,
        Some(file_uri),
        Some(raw_response),
    );

    // Flag the record for manual review when the record itself asks for it.
    if record.needs_review() {
        println!(" ⚠️ 需要人工审核 (置信度或质量评分较低)");
        record.mark_as_needs_review("置信度或质量评分较低,建议人工审核".to_string());
    } else {
        println!(" ✅ 分类质量良好");
    }

    // Persist and return the stored record.
    let saved_record = self.video_repo.create_classification_record(record).await?;
    Ok(saved_record)
}
/// Parses a raw Gemini response into a `GeminiClassificationResponse`.
///
/// The JSON payload is taken as the span from the first `{` to the last `}`.
/// This function never returns `Err`: when no such span exists, or when the
/// span does not deserialize, a fallback "未分类" response is returned whose
/// `reasoning` describes the problem.
///
/// The span is only used when the opening brace precedes the closing brace, and
/// the response preview in the fallback is cut on a UTF-8 character boundary,
/// so multi-byte text or stray braces cannot cause a slicing panic.
fn parse_gemini_response(&self, raw_response: &str) -> Result<GeminiClassificationResponse> {
    // Upper bound, in bytes, of the raw response echoed in the fallback reasoning.
    const PREVIEW_MAX_BYTES: usize = 200;

    // `{` and `}` are single-byte ASCII, so both indices are char boundaries.
    let json_span = match (raw_response.find('{'), raw_response.rfind('}')) {
        (Some(start), Some(end)) if start < end => Some(&raw_response[start..=end]),
        _ => None,
    };

    match json_span {
        Some(json_str) => match serde_json::from_str::<GeminiClassificationResponse>(json_str) {
            Ok(response) => Ok(response),
            Err(e) => {
                // The payload looked like JSON but did not match the expected schema.
                Ok(GeminiClassificationResponse {
                    category: "未分类".to_string(),
                    confidence: 0.5,
                    reasoning: format!("AI响应解析失败: {} - 原始JSON: {}", e, json_str),
                    features: vec!["解析失败".to_string()],
                    product_match: false,
                    quality_score: 0.5,
                })
            }
        },
        None => {
            // Back off to the nearest char boundary at or below the byte limit;
            // index 0 is always a boundary, so the loop terminates.
            let mut cut = raw_response.len().min(PREVIEW_MAX_BYTES);
            while !raw_response.is_char_boundary(cut) {
                cut -= 1;
            }

            Ok(GeminiClassificationResponse {
                category: "未分类".to_string(),
                confidence: 0.3,
                reasoning: format!("AI响应格式异常未找到JSON: {}", &raw_response[..cut]),
                features: vec!["格式异常".to_string()],
                product_match: false,
                quality_score: 0.3,
            })
        }
    }
}
/// Builds the classification prompt from the active AI classifications.
///
/// Queries the repository for active classifications ordered by `sort_order`
/// ascending and passes them to `generate_full_prompt`.
///
/// # Errors
/// Returns an error when the repository query fails or yields no classification.
async fn generate_classification_prompt(&self) -> Result<String> {
    // Only active classifications, in their configured order.
    let active_sorted = crate::data::models::ai_classification::AiClassificationQuery {
        active_only: Some(true),
        sort_by: Some("sort_order".to_string()),
        sort_order: Some("ASC".to_string()),
        ..Default::default()
    };

    let classifications = self.ai_classification_repo.get_all(Some(active_sorted)).await?;

    if !classifications.is_empty() {
        Ok(generate_full_prompt(&classifications))
    } else {
        Err(anyhow!("没有激活的AI分类无法生成提示词"))
    }
}
/// Moves a classified segment's video file into `<project.path>/assets/<category>`
/// and stores the new path on the segment.
///
/// The category comes from the AI response, so it is only accepted when every
/// path component is a plain name (or `.`); `..`, root and prefix components are
/// rejected to keep the target inside the project's `assets` directory.
///
/// If the segment path cannot be updated after the file was moved, the file is
/// moved back so that the stored path keeps pointing at an existing file.
///
/// # Errors
/// Returns an error when the category is not a safe relative path, when the
/// project or segment does not exist, when a path is not valid UTF-8, when a
/// file-system operation fails, or when the segment path update fails.
async fn move_video_to_category_folder(&self, record: &VideoClassificationRecord) -> Result<()> {
    println!("🔍 开始移动文件到分类文件夹");
    println!(" 项目ID: {}", record.project_id);
    println!(" 分类类别: {}", record.category);
    println!(" 片段ID: {}", record.segment_id);

    // Reject categories that could escape the assets directory (`..`, absolute paths).
    let category_is_safe = Path::new(&record.category)
        .components()
        .all(|c| matches!(c, std::path::Component::Normal(_) | std::path::Component::CurDir));
    if !category_is_safe {
        return Err(anyhow!("分类类别不是合法的相对目录名: {}", record.category));
    }

    // Look up the project and the segment the record refers to.
    let project = self.material_repo.get_project_by_id(&record.project_id).await?
        .ok_or_else(|| anyhow!("项目不存在: {}", record.project_id))?;
    let segment = self.material_repo.get_segment_by_id(&record.segment_id).await?
        .ok_or_else(|| anyhow!("片段不存在: {}", record.segment_id))?;

    // Target directory: <project>/assets/<category>.
    let category_dir = Path::new(&project.path)
        .join("assets")
        .join(&record.category);
    let category_dir_str = category_dir.to_str()
        .ok_or_else(|| anyhow!("分类目录路径不是有效的UTF-8: {}", category_dir.display()))?;
    FileSystemService::create_directory_if_not_exists(category_dir_str)?;

    // Target file keeps the source file name.
    let source_path = Path::new(&segment.file_path);
    let file_name = source_path.file_name()
        .ok_or_else(|| anyhow!("无法获取文件名"))?;
    let target_path = category_dir.join(file_name);
    let new_path = target_path.to_str()
        .ok_or_else(|| anyhow!("目标文件路径不是有效的UTF-8: {}", target_path.display()))?
        .to_string();

    FileSystemService::move_file(&segment.file_path, &new_path)?;
    println!("✅ 视频文件已移动到分类文件夹: {} -> {}", segment.file_path, target_path.display());

    // Store the new location; undo the move if that fails.
    if let Err(e) = self.material_repo.update_segment_file_path(&record.segment_id, &new_path).await {
        println!("⚠️ 更新片段路径失败: {}", e);
        return match FileSystemService::move_file(&new_path, &segment.file_path) {
            Ok(_) => Err(anyhow!("更新片段路径失败,文件已移回原位置: {}", e)),
            Err(rollback_err) => Err(anyhow!("更新片段路径失败且文件回滚失败: {} (回滚错误: {})", e, rollback_err)),
        };
    }
    println!("✅ 片段路径已更新: {}", new_path);

    Ok(())
}
/// Returns pending classification tasks by delegating to the video classification repository.
///
/// `limit` is forwarded unchanged; presumably it caps the number of tasks
/// returned — confirm against the repository implementation.
pub async fn get_pending_tasks(&self, limit: Option<i32>) -> Result<Vec<VideoClassificationTask>> {
self.video_repo.get_pending_tasks(limit).await
}
/// Returns classification statistics by delegating to the video classification repository.
///
/// `project_id` is forwarded unchanged; presumably `None` means "all projects" —
/// confirm against the repository implementation.
pub async fn get_classification_stats(&self, project_id: Option<&str>) -> Result<ClassificationStats> {
self.video_repo.get_classification_stats(project_id).await
}
/// Returns the classification records for a material ID by delegating to the
/// repository's `get_by_material_id`.
pub async fn get_classifications_by_material(&self, material_id: &str) -> Result<Vec<VideoClassificationRecord>> {
self.video_repo.get_by_material_id(material_id).await
}
/// Loads a classification task by ID.
///
/// # Errors
/// Returns an error when the repository lookup fails or when no task with
/// `task_id` exists.
async fn get_task_by_id(&self, task_id: &str) -> Result<VideoClassificationTask> {
    match self.video_repo.get_task_by_id(task_id).await? {
        Some(found) => Ok(found),
        None => Err(anyhow!("任务不存在: {}", task_id)),
    }
}
/// Cancels a classification task.
///
/// Cancellation is implemented by calling the repository's `delete_task` for `task_id`.
pub async fn cancel_task(&self, task_id: &str) -> Result<()> {
self.video_repo.delete_task(task_id).await
}
/// Recovers tasks stuck in an intermediate state by delegating to the repository.
///
/// Returns the `usize` reported by the repository; presumably the number of
/// recovered tasks — confirm against the repository implementation.
pub async fn recover_stuck_tasks(&self) -> Result<usize> {
self.video_repo.recover_stuck_tasks().await
}
/// Fixes date-format problems in the database by delegating to the repository.
///
/// Returns the `usize` reported by the repository; presumably the number of
/// fixed rows — confirm against the repository implementation.
pub async fn fix_date_formats(&self) -> Result<usize> {
self.video_repo.fix_date_formats().await
}
/// Retries a failed task.
///
/// Not implemented yet: the task ID is ignored and an error is always returned.
pub async fn retry_failed_task(&self, _task_id: &str) -> Result<()> {
// The retry logic still needs to be implemented here.
// Return an error for now; to be implemented later.
Err(anyhow!("重试任务功能待实现"))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that a JSON object embedded between leading and trailing text is
    /// extracted and deserialized.
    ///
    /// Ignored for now: `create_test_service()` is still a `todo!()` stub, so the
    /// test would panic unconditionally and fail every `cargo test` run. Remove
    /// the attribute once a real test fixture exists.
    #[test]
    #[ignore = "create_test_service() is still a todo!() stub"]
    fn test_parse_gemini_response() {
        let service = create_test_service();
        let json_response = r#"
这是一些前置文本
{
"category": "全身",
"confidence": 0.85,
"reasoning": "视频显示完整的人体轮廓",
"features": ["全身可见", "清晰度高"],
"product_match": true,
"quality_score": 0.9
}
这是一些后置文本
"#;
        let result = service.parse_gemini_response(json_response).unwrap();
        assert_eq!(result.category, "全身");
        assert_eq!(result.confidence, 0.85);
        assert!(result.product_match);
    }

    /// Builds a service instance for tests.
    ///
    /// Placeholder: a test-ready service (with test repositories) still has to
    /// be constructed here.
    fn create_test_service() -> VideoClassificationService {
        todo!("实现测试用的service创建")
    }
}