mixvideo-v2/apps/desktop/src-tauri/src/business/services/video_classification_servic...

469 lines
18 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::data::models::video_classification::*;
use crate::data::models::ai_classification::generate_full_prompt;
use crate::data::repositories::video_classification_repository::VideoClassificationRepository;
use crate::data::repositories::ai_classification_repository::AiClassificationRepository;
use crate::data::repositories::material_repository::MaterialRepository;
use crate::infrastructure::gemini_service::{GeminiService, GeminiConfig};
use crate::infrastructure::file_system::FileSystemService;
use anyhow::{Result, anyhow};
use std::sync::Arc;
use std::path::Path;
use serde_json;
/// Business-layer service for AI-driven video classification.
///
/// Follows the business-layer design pattern of the project's Tauri
/// development guidelines. The service owns no state of its own beyond the
/// shared repositories and the optional Gemini configuration below.
pub struct VideoClassificationService {
/// Repository used for classification tasks, classification records and statistics.
video_repo: Arc<VideoClassificationRepository>,
/// Repository queried for the active AI classifications when building the prompt.
ai_classification_repo: Arc<AiClassificationRepository>,
/// Repository used to look up materials, their segments and projects.
material_repo: Arc<MaterialRepository>,
/// Optional Gemini configuration; cloned into a new `GeminiService` for each classified task.
gemini_config: Option<GeminiConfig>,
}
impl VideoClassificationService {
/// 创建新的视频分类服务实例
pub fn new(
video_repo: Arc<VideoClassificationRepository>,
ai_classification_repo: Arc<AiClassificationRepository>,
material_repo: Arc<MaterialRepository>,
gemini_config: Option<GeminiConfig>,
) -> Self {
Self {
video_repo,
ai_classification_repo,
material_repo,
gemini_config,
}
}
/// Creates one classification task per eligible segment of a material.
///
/// A segment is skipped when it is already classified (unless
/// `request.overwrite_existing` is set) or when its video file is missing on
/// disk. When the material's stored project ID differs from the requested
/// one, the requested project is verified to exist and is then used for the
/// created tasks.
///
/// # Errors
///
/// Fails when the material does not exist, when the requested project does
/// not exist, when the material has no segments, or when any repository call
/// fails. A repository failure during the project lookup is propagated as-is
/// instead of being reported as an invalid project ID.
pub async fn create_batch_classification_tasks(&self, request: BatchClassificationRequest) -> Result<Vec<VideoClassificationTask>> {
    println!("🎬 创建批量分类任务");
    println!(" 素材ID: {}", request.material_id);
    println!(" 项目ID: {}", request.project_id);

    // Load the material the tasks are created for.
    let material = self.material_repo.get_by_id(&request.material_id)?
        .ok_or_else(|| anyhow!("素材不存在: {}", request.material_id))?;
    println!(" 素材项目ID: {}", material.project_id);

    // On a project-ID mismatch, trust the request as long as that project exists.
    if material.project_id != request.project_id {
        println!("⚠️ 项目ID不匹配尝试修复: 素材项目ID={}, 请求项目ID={}", material.project_id, request.project_id);
        // `?` keeps a lookup failure distinguishable from "project not found".
        let requested_project = self.material_repo.get_project_by_id(&request.project_id).await?;
        if requested_project.is_none() {
            return Err(anyhow!("请求的项目ID无效: {}", request.project_id));
        }
        println!("✅ 请求的项目ID有效将使用请求的项目ID进行分类");
    }

    // A material without segments cannot be classified.
    let segments = self.material_repo.get_segments(&request.material_id)?;
    if segments.is_empty() {
        return Err(anyhow!("素材没有切分片段"));
    }

    let mut tasks = Vec::new();
    for segment in segments {
        // Keep existing classifications unless the caller asked to overwrite them.
        if !request.overwrite_existing
            && self.video_repo.is_segment_classified(&segment.id).await?
        {
            continue;
        }

        // A task for a missing video file could never succeed; skip it.
        if !Path::new(&segment.file_path).exists() {
            println!("警告: 视频文件不存在,跳过: {}", segment.file_path);
            continue;
        }

        let task = VideoClassificationTask::new(
            segment.id.clone(),
            request.material_id.clone(),
            request.project_id.clone(),
            segment.file_path.clone(),
            request.priority,
        );
        tasks.push(self.video_repo.create_classification_task(task).await?);
    }

    Ok(tasks)
}
/// Creates classification tasks for every eligible material of a project.
///
/// A material is eligible when its type is in `request.material_types`
/// (video only when unset), its processing status is `Completed`, it has at
/// least one segment and, unless overwriting is requested, none of its
/// segments is classified yet. Materials rejected only because of an existing
/// classification are reported in `skipped_materials`.
///
/// A failure while creating tasks for one material is logged and does not
/// abort the remaining materials.
///
/// # Errors
///
/// Fails when the project does not exist or when a repository lookup fails
/// during the eligibility scan.
pub async fn create_project_batch_classification_tasks(&self, request: ProjectBatchClassificationRequest) -> Result<ProjectBatchClassificationResponse> {
    println!("🚀 开始项目一键分类");
    println!(" 项目ID: {}", request.project_id);

    // Fail fast when the project is unknown.
    if self.material_repo.get_project_by_id(&request.project_id).await?.is_none() {
        return Err(anyhow!("项目不存在: {}", request.project_id));
    }

    let project_materials = self.material_repo.get_by_project_id(&request.project_id)?;
    let total_materials = project_materials.len() as u32;
    println!(" 项目总素材数: {}", total_materials);

    // Resolve the optional request settings to their defaults.
    let allowed_types = request.material_types.unwrap_or_else(|| vec![crate::data::models::material::MaterialType::Video]);
    let overwrite_existing = request.overwrite_existing.unwrap_or(false);

    let mut eligible_materials = Vec::new();
    let mut skipped_materials = Vec::new();
    for material in project_materials {
        // Only materials of an accepted type that finished processing qualify.
        let type_allowed = allowed_types.contains(&material.material_type);
        let fully_processed = material.processing_status == crate::data::models::material::ProcessingStatus::Completed;
        if !(type_allowed && fully_processed) {
            continue;
        }

        let segments = self.material_repo.get_segments(&material.id)?;
        if segments.is_empty() {
            continue;
        }

        if !overwrite_existing {
            // One classified segment is enough to treat the material as done.
            let mut already_classified = false;
            for segment in &segments {
                already_classified = self.video_repo.is_segment_classified(&segment.id).await?;
                if already_classified {
                    break;
                }
            }
            if already_classified {
                skipped_materials.push(material.id.clone());
                continue;
            }
        }

        eligible_materials.push(material);
    }

    let eligible_count = eligible_materials.len() as u32;
    println!(" 符合条件的素材数: {}", eligible_count);
    println!(" 跳过的素材数: {}", skipped_materials.len());

    // Delegate to the per-material batch creation and collect the task IDs.
    let mut all_task_ids: Vec<String> = Vec::new();
    let mut created_tasks_count = 0u32;
    for material in eligible_materials {
        let outcome = self
            .create_batch_classification_tasks(BatchClassificationRequest {
                material_id: material.id.clone(),
                project_id: request.project_id.clone(),
                overwrite_existing,
                priority: request.priority,
            })
            .await;

        match outcome {
            Err(e) => {
                // Best effort: one failing material must not stop the others.
                println!(" 为素材 {} 创建分类任务失败: {}", material.name, e);
            }
            Ok(tasks) => {
                let created = tasks.len();
                created_tasks_count += created as u32;
                all_task_ids.extend(tasks.iter().map(|task| task.id.clone()));
                println!(" 为素材 {} 创建了 {} 个分类任务", material.name, created);
            }
        }
    }

    println!("✅ 项目一键分类任务创建完成");
    println!(" 总共创建任务数: {}", created_tasks_count);

    Ok(ProjectBatchClassificationResponse {
        total_materials,
        eligible_materials: eligible_count,
        created_tasks: created_tasks_count,
        task_ids: all_task_ids,
        skipped_materials,
    })
}
/// Processes a single classification task end to end.
///
/// Steps: load the task, persist it as "processing", build the prompt,
/// classify the video with Gemini, persist the final task state and finally
/// try to move the video file into its category folder. A failed file move is
/// only logged.
///
/// Any failure after the task was marked as processing, including a failure
/// to build the prompt, marks the task as failed so it is not left stuck in
/// the processing state.
///
/// # Errors
///
/// Returns the error that caused the classification to fail. If persisting
/// the failed state also fails, that secondary error is logged and the
/// original cause is still returned.
pub async fn process_classification_task(&self, task_id: &str) -> Result<VideoClassificationRecord> {
    let mut task = self.get_task_by_id(task_id).await?;

    // Persist the "processing" state before doing any slow work.
    task.start_processing();
    self.video_repo.update_classification_task(&task).await?;

    // Prompt generation is part of the guarded section: returning early here
    // would leave the task persisted as "processing" forever.
    let outcome = match self.generate_classification_prompt().await {
        Ok(prompt) => self.classify_video_with_gemini(&mut task, &prompt).await,
        Err(e) => Err(e),
    };

    match outcome {
        Ok(record) => {
            task.complete();
            self.video_repo.update_classification_task(&task).await?;

            // Moving the file is best effort; the classification itself succeeded.
            if let Err(e) = self.move_video_to_category_folder(&record).await {
                println!("警告: 移动视频文件失败: {}", e);
            }
            Ok(record)
        }
        Err(e) => {
            task.fail(e.to_string());
            // Do not let a bookkeeping failure hide the root cause.
            if let Err(update_err) = self.video_repo.update_classification_task(&task).await {
                println!("警告: 更新失败任务状态时出错: {}", update_err);
            }
            Err(e)
        }
    }
}
/// Classifies the task's video with Gemini and stores the resulting record.
///
/// A fresh `GeminiService` is created per call (the original author's note:
/// to avoid a concurrency bottleneck). After the Gemini call returns, the
/// task is switched to its analyzing state with the returned file URI and the
/// prompt, the raw response is parsed, the result is logged, and the record
/// is flagged for manual review when `needs_review()` says so.
///
/// # Errors
///
/// Propagates failures from the Gemini call, the task update, response
/// parsing and saving the record.
async fn classify_video_with_gemini(&self, task: &mut VideoClassificationTask, prompt: &str) -> Result<VideoClassificationRecord> {
    let mut gemini = GeminiService::new(self.gemini_config.clone());
    let (file_uri, raw_response) = gemini.classify_video(&task.video_file_path, prompt).await?;

    // Record the uploaded file URI and the prompt on the task.
    task.set_analyzing(file_uri.clone(), prompt.to_string());
    self.video_repo.update_classification_task(task).await?;

    let parsed = self.parse_gemini_response(&raw_response)?;

    // Log a human-readable summary of the classification.
    println!("🎯 AI分类结果:");
    println!(" 📁 视频文件: {}", task.video_file_path);
    println!(" 🏷️ 分类类别: {}", parsed.category);
    println!(" 📊 置信度: {:.1}%", parsed.confidence * 100.0);
    println!(" ⭐ 质量评分: {:.1}/10", parsed.quality_score * 10.0);
    println!(" 💭 分类理由: {}", parsed.reasoning);
    if !parsed.features.is_empty() {
        let feature_list = parsed.features.join(", ");
        println!(" 🔍 识别特征: {}", feature_list);
    }

    let mut record = VideoClassificationRecord::new(
        task.segment_id.clone(),
        task.material_id.clone(),
        task.project_id.clone(),
        parsed,
        Some(file_uri),
        Some(raw_response),
    );

    // Flag records that need manual review.
    if !record.needs_review() {
        println!(" ✅ 分类质量良好");
    } else {
        println!(" ⚠️ 需要人工审核 (置信度或质量评分较低)");
        record.mark_as_needs_review("置信度或质量评分较低,建议人工审核".to_string());
    }

    Ok(self.video_repo.create_classification_record(record).await?)
}
/// Parses a raw Gemini reply into a structured classification response.
///
/// The JSON object is taken as the span from the first `{` to the last `}`,
/// so text before or after the object is ignored. This function never returns
/// `Err`: when no object is found or it cannot be deserialized, a fallback
/// response in the "未分类" category is returned with the reason in `reasoning`.
///
/// Two panic paths are avoided:
/// * the last `}` appearing before the first `{` is treated as "no JSON
///   found" instead of building an inverted slice range;
/// * the preview in the fallback message is cut at 200 characters rather
///   than 200 bytes, so multi-byte UTF-8 text is never split mid-character.
fn parse_gemini_response(&self, raw_response: &str) -> Result<GeminiClassificationResponse> {
    // `{` and `}` are single-byte ASCII, so these byte offsets are always
    // valid char boundaries for slicing.
    let json_str = match (raw_response.find('{'), raw_response.rfind('}')) {
        (Some(start), Some(end)) if start < end => &raw_response[start..=end],
        _ => {
            // No usable JSON object: return a low-confidence fallback.
            let preview: String = raw_response.chars().take(200).collect();
            return Ok(GeminiClassificationResponse {
                category: "未分类".to_string(),
                confidence: 0.3,
                reasoning: format!("AI响应格式异常未找到JSON: {}", preview),
                features: vec!["格式异常".to_string()],
                product_match: false,
                quality_score: 0.3,
            });
        }
    };

    match serde_json::from_str::<GeminiClassificationResponse>(json_str) {
        Ok(response) => Ok(response),
        // Found something JSON-like that does not match the expected shape.
        Err(e) => Ok(GeminiClassificationResponse {
            category: "未分类".to_string(),
            confidence: 0.5,
            reasoning: format!("AI响应解析失败: {} - 原始JSON: {}", e, json_str),
            features: vec!["解析失败".to_string()],
            product_match: false,
            quality_score: 0.5,
        }),
    }
}
/// Builds the classification prompt from the currently active AI classifications.
///
/// The classifications are requested with `active_only = true`, sorted by
/// `sort_order` ascending, and passed to `generate_full_prompt`.
///
/// # Errors
///
/// Fails when the repository query fails or when no active classification exists.
async fn generate_classification_prompt(&self) -> Result<String> {
    let query = crate::data::models::ai_classification::AiClassificationQuery {
        active_only: Some(true),
        sort_by: Some("sort_order".to_string()),
        sort_order: Some("ASC".to_string()),
        ..Default::default()
    };
    let active_classifications = self.ai_classification_repo.get_all(Some(query)).await?;

    // Without at least one active classification there is nothing to prompt with.
    if active_classifications.is_empty() {
        Err(anyhow!("没有激活的AI分类无法生成提示词"))
    } else {
        Ok(generate_full_prompt(&active_classifications))
    }
}
/// Moves a classified segment's video file into `<project path>/assets/<category>`
/// and updates the segment's stored file path.
///
/// The category comes from the AI response, so it is checked before being
/// joined into a path: categories containing `..`, a root or a drive/UNC
/// prefix are rejected, because joining them could place the file outside
/// the project's `assets` directory.
///
/// If the database update fails after the file was moved, the failure is only
/// logged; the file stays at its new location.
///
/// # Errors
///
/// Fails when the project or segment does not exist, the category is not a
/// safe relative path, a path is not valid UTF-8, the source path has no file
/// name, or creating the directory / moving the file fails.
async fn move_video_to_category_folder(&self, record: &VideoClassificationRecord) -> Result<()> {
    println!("🔍 开始移动文件到分类文件夹");
    println!(" 项目ID: {}", record.project_id);
    println!(" 分类类别: {}", record.category);
    println!(" 片段ID: {}", record.segment_id);

    let project = self.material_repo.get_project_by_id(&record.project_id).await?
        .ok_or_else(|| anyhow!("项目不存在: {}", record.project_id))?;
    let segment = self.material_repo.get_segment_by_id(&record.segment_id).await?
        .ok_or_else(|| anyhow!("片段不存在: {}", record.segment_id))?;

    // `Path::join` with an absolute path replaces the base, and `..` walks
    // out of it; refuse both so the target stays under `assets`.
    let escapes_assets_dir = Path::new(&record.category).components().any(|part| {
        matches!(
            part,
            std::path::Component::ParentDir
                | std::path::Component::RootDir
                | std::path::Component::Prefix(_)
        )
    });
    if escapes_assets_dir {
        return Err(anyhow!("分类类别不是有效的文件夹名称: {}", record.category));
    }

    // Target directory: <project>/assets/<category>.
    let category_dir = Path::new(&project.path)
        .join("assets")
        .join(&record.category);
    let category_dir_str = category_dir.to_str()
        .ok_or_else(|| anyhow!("分类目录路径不是有效的UTF-8: {}", category_dir.display()))?;
    FileSystemService::create_directory_if_not_exists(category_dir_str)?;

    // Keep the original file name inside the category directory.
    let source_path = Path::new(&segment.file_path);
    let file_name = source_path.file_name()
        .ok_or_else(|| anyhow!("无法获取文件名"))?;
    let target_path = category_dir.join(file_name);
    let target_path_str = target_path.to_str()
        .ok_or_else(|| anyhow!("目标文件路径不是有效的UTF-8: {}", target_path.display()))?;

    FileSystemService::move_file(&segment.file_path, target_path_str)?;
    println!("✅ 视频文件已移动到分类文件夹: {} -> {}", segment.file_path, target_path.display());

    // Point the segment at its new location.
    let new_path = target_path_str.to_string();
    match self.material_repo.update_segment_file_path(&record.segment_id, &new_path).await {
        Ok(()) => {
            println!("✅ 片段路径已更新: {}", new_path);
        }
        Err(e) => {
            println!("⚠️ 更新片段路径失败: {}", e);
            // NOTE: the file has already been moved but the database still
            // holds the old path. Rolling back the move or persisting the
            // error should be considered here.
        }
    }

    Ok(())
}
/// Fetches classification tasks that are still pending.
///
/// Thin pass-through to the video classification repository. `limit` is
/// forwarded unchanged (presumably a cap on the number of tasks returned —
/// confirm against the repository implementation).
pub async fn get_pending_tasks(&self, limit: Option<i32>) -> Result<Vec<VideoClassificationTask>> {
self.video_repo.get_pending_tasks(limit).await
}
/// Returns classification statistics from the video classification repository.
///
/// `project_id` is forwarded unchanged (presumably `None` means "all
/// projects" — confirm against the repository implementation).
pub async fn get_classification_stats(&self, project_id: Option<&str>) -> Result<ClassificationStats> {
self.video_repo.get_classification_stats(project_id).await
}
/// Returns the classification records stored for the given material ID.
///
/// Thin pass-through to the video classification repository.
pub async fn get_classifications_by_material(&self, material_id: &str) -> Result<Vec<VideoClassificationRecord>> {
self.video_repo.get_by_material_id(material_id).await
}
/// Loads a classification task by ID.
///
/// # Errors
///
/// Fails when the repository lookup fails or when no task with this ID exists.
async fn get_task_by_id(&self, task_id: &str) -> Result<VideoClassificationTask> {
    match self.video_repo.get_task_by_id(task_id).await? {
        Some(task) => Ok(task),
        None => Err(anyhow!("任务不存在: {}", task_id)),
    }
}
/// Cancels a classification task.
///
/// Cancellation is implemented by deleting the task through the repository.
pub async fn cancel_task(&self, task_id: &str) -> Result<()> {
self.video_repo.delete_task(task_id).await
}
/// Recovers tasks whose state is stuck.
///
/// Thin pass-through to the repository. The returned `usize` is presumably
/// the number of tasks that were recovered — confirm against the repository.
pub async fn recover_stuck_tasks(&self) -> Result<usize> {
self.video_repo.recover_stuck_tasks().await
}
/// Repairs date-format problems in the database.
///
/// Thin pass-through to the repository. The returned `usize` is presumably
/// the number of rows that were fixed — confirm against the repository.
pub async fn fix_date_formats(&self) -> Result<usize> {
self.video_repo.fix_date_formats().await
}
/// Retries a failed task.
///
/// Not implemented yet: the task ID is ignored and an error is always returned.
pub async fn retry_failed_task(&self, _task_id: &str) -> Result<()> {
// The retry logic still has to be implemented here.
// Return an error for now; to be implemented later.
Err(anyhow!("重试任务功能待实现"))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that a JSON object surrounded by extra text is extracted and parsed.
    ///
    /// Ignored for now: `create_test_service` is still a `todo!()` stub, so the
    /// test would panic unconditionally and fail every `cargo test` run.
    /// Remove the `#[ignore]` once the stub builds a real service.
    #[test]
    #[ignore = "create_test_service is not implemented yet"]
    fn test_parse_gemini_response() {
        let service = create_test_service();
        let json_response = r#"
这是一些前置文本
{
"category": "全身",
"confidence": 0.85,
"reasoning": "视频显示完整的人体轮廓",
"features": ["全身可见", "清晰度高"],
"product_match": true,
"quality_score": 0.9
}
这是一些后置文本
"#;
        let result = service.parse_gemini_response(json_response).unwrap();
        assert_eq!(result.category, "全身");
        assert_eq!(result.confidence, 0.85);
        assert!(result.product_match);
    }

    /// Builds the service instance used by the tests.
    ///
    /// Placeholder: a test service still has to be constructed here.
    fn create_test_service() -> VideoClassificationService {
        todo!("实现测试用的service创建")
    }
}