mixvideo-v2/apps/desktop/src-tauri/src/presentation/commands/material_matching_commands.rs

328 lines
12 KiB
Rust

//! Tauri commands for material matching.
//!
//! Follows the API design principles of the Tauri development guidelines.
use tauri::{command, State};
use std::sync::Arc;
use crate::business::services::material_matching_service::{
MaterialMatchingService, MaterialMatchingRequest, MaterialMatchingResult,
BatchMatchingRequest, BatchMatchingResult,
};
use crate::business::services::template_service::TemplateService;
use crate::business::services::template_matching_result_service::TemplateMatchingResultService;
use crate::data::repositories::{
material_repository::MaterialRepository,
video_classification_repository::VideoClassificationRepository,
template_matching_result_repository::TemplateMatchingResultRepository,
};
use crate::data::models::template_matching_result::TemplateMatchingResult;
use crate::infrastructure::database::Database;
/// Runs material matching for a single request.
///
/// A fresh [`MaterialMatchingService`] is assembled on every call from
/// repositories and services that each receive their own clone of the shared
/// database handle; the request is then forwarded to `match_materials`.
///
/// # Errors
///
/// Returns a `String` describing the failure when the material repository
/// cannot be created, or the stringified service error when matching fails.
#[command]
pub async fn execute_material_matching(
    request: MaterialMatchingRequest,
    database: State<'_, Arc<Database>>,
) -> Result<MaterialMatchingResult, String> {
    use crate::business::services::ai_classification_service::AiClassificationService;
    use crate::data::repositories::ai_classification_repository::AiClassificationRepository;
    use crate::data::repositories::material_usage_repository::MaterialUsageRepository;

    // Borrow the managed handle once; each collaborator gets a cheap Arc clone.
    let db = database.inner();

    // Collaborators are built in a fixed order; only the material repository
    // constructor is fallible.
    let materials = MaterialRepository::new(Arc::clone(db))
        .map(Arc::new)
        .map_err(|e| format!("创建素材仓库失败: {}", e))?;
    let templates = Arc::new(TemplateService::new(Arc::clone(db)));
    let classifications = Arc::new(VideoClassificationRepository::new(Arc::clone(db)));
    let usage = Arc::new(MaterialUsageRepository::new(Arc::clone(db)));
    let ai_classifier = Arc::new(AiClassificationService::new(Arc::new(
        AiClassificationRepository::new(Arc::clone(db)),
    )));

    let service = MaterialMatchingService::new(
        materials,
        usage,
        templates,
        classifications,
        ai_classifier,
    );

    // Tauri commands report errors as plain strings.
    let outcome = service.match_materials(request).await;
    outcome.map_err(|e| e.to_string())
}
/// Runs material matching and saves the result in the same call.
///
/// Works like `execute_material_matching`, except that the matching service is
/// built with a [`TemplateMatchingResultService`] and the request is forwarded
/// to `match_materials_and_save` together with `result_name` and
/// `description`.
///
/// # Returns
///
/// The tuple produced by `match_materials_and_save`: the matching result and
/// an optional [`TemplateMatchingResult`] (presumably the saved record, when
/// one was written — confirm against the service).
///
/// # Errors
///
/// Returns a `String` describing the failure when the material repository
/// cannot be created, or the stringified service error otherwise.
#[command]
pub async fn execute_material_matching_with_save(
    request: MaterialMatchingRequest,
    result_name: String,
    description: Option<String>,
    database: State<'_, Arc<Database>>,
) -> Result<(MaterialMatchingResult, Option<TemplateMatchingResult>), String> {
    use crate::business::services::ai_classification_service::AiClassificationService;
    use crate::data::repositories::ai_classification_repository::AiClassificationRepository;
    use crate::data::repositories::material_usage_repository::MaterialUsageRepository;

    // Borrow the managed handle once; each collaborator gets a cheap Arc clone.
    let db = database.inner();

    // Core matching dependencies (only the material repository can fail).
    let materials = MaterialRepository::new(Arc::clone(db))
        .map(Arc::new)
        .map_err(|e| format!("创建素材仓库失败: {}", e))?;
    let templates = Arc::new(TemplateService::new(Arc::clone(db)));
    let classifications = Arc::new(VideoClassificationRepository::new(Arc::clone(db)));
    let usage = Arc::new(MaterialUsageRepository::new(Arc::clone(db)));

    // Service used to save matching results.
    let result_store = Arc::new(TemplateMatchingResultService::new(Arc::new(
        TemplateMatchingResultRepository::new(Arc::clone(db)),
    )));

    let ai_classifier = Arc::new(AiClassificationService::new(Arc::new(
        AiClassificationRepository::new(Arc::clone(db)),
    )));

    let service = MaterialMatchingService::new_with_result_service(
        materials,
        usage,
        templates,
        classifications,
        ai_classifier,
        result_store,
    );

    // Match and save; errors cross the Tauri boundary as plain strings.
    let outcome = service
        .match_materials_and_save(request, result_name, description)
        .await;
    outcome.map_err(|e| e.to_string())
}
/// Collects statistics about a project's materials that are relevant to matching.
///
/// Loads every material of `project_id`, fetches the classification records of
/// each material, and aggregates:
///
/// * the number of materials and segments,
/// * how many segments have at least one classification record,
/// * the number of distinct model ids referenced by the materials,
/// * the distinct categories of classified segments (sorted, so the output is
///   stable between calls),
/// * the ratio of classified to total segments (`0.0` when there are none).
///
/// For a segment with several classification records, only the category of the
/// first matching record is reported.
///
/// # Errors
///
/// Returns a `String` when the material repository cannot be created, when the
/// project's materials cannot be loaded, or when the classification records of
/// any material cannot be fetched.
#[command]
pub async fn get_project_material_stats_for_matching(
    project_id: String,
    database: State<'_, Arc<Database>>,
) -> Result<ProjectMaterialMatchingStats, String> {
    let material_repo = MaterialRepository::new(database.inner().clone())
        .map_err(|e| format!("创建素材仓库失败: {}", e))?;
    let video_classification_repo = VideoClassificationRepository::new(database.inner().clone());

    // Load all materials of the project.
    let materials = material_repo
        .get_by_project_id(&project_id)
        .map_err(|e| format!("获取项目素材失败: {}", e))?;

    let mut total_segments = 0usize;
    let mut classified_segments = 0usize;
    let mut available_models = std::collections::HashSet::new();
    let mut available_categories = std::collections::HashSet::new();

    for material in &materials {
        total_segments += material.segments.len();

        // Classification records are fetched once per material.
        let classification_records = video_classification_repo
            .get_by_material_id(&material.id)
            .await
            .map_err(|e| format!("获取分类记录失败: {}", e))?;

        // A single `find` both detects a classified segment and yields the
        // record whose category is reported.
        for segment in &material.segments {
            if let Some(record) = classification_records
                .iter()
                .find(|r| r.segment_id == segment.id)
            {
                classified_segments += 1;
                available_categories.insert(record.category.clone());
            }
        }

        // Track the distinct models referenced by the materials.
        if let Some(model_id) = &material.model_id {
            available_models.insert(model_id.clone());
        }
    }

    // `HashSet` iteration order is arbitrary; sort so callers get a stable list.
    let mut available_categories: Vec<String> = available_categories.into_iter().collect();
    available_categories.sort_unstable();

    let classification_rate = if total_segments > 0 {
        classified_segments as f64 / total_segments as f64
    } else {
        0.0
    };

    // NOTE: the `as u32` casts truncate above `u32::MAX`; the counts here are
    // in-memory collection sizes, so that is not expected in practice.
    Ok(ProjectMaterialMatchingStats {
        project_id,
        total_materials: materials.len() as u32,
        total_segments: total_segments as u32,
        classified_segments: classified_segments as u32,
        available_models: available_models.len() as u32,
        available_categories,
        classification_rate,
    })
}
/// Checks whether a template binding can be used for material matching.
///
/// Validation problems are reported inside the returned
/// [`TemplateBindingMatchingValidation`] rather than as an `Err`: a missing
/// binding or a failed lookup yields an invalid result with zeroed counters,
/// while an inactive binding or a missing/unloadable template adds an entry to
/// `validation_errors`.
///
/// When the template loads, `total_segments` counts every segment on every
/// track and `matchable_segments` counts those whose matching rule is not a
/// fixed material.
#[command]
pub async fn validate_template_binding_for_matching(
    binding_id: String,
    database: State<'_, Arc<Database>>,
) -> Result<TemplateBindingMatchingValidation, String> {
    use crate::business::services::template_service::TemplateService;
    use crate::data::repositories::project_template_binding_repository::ProjectTemplateBindingRepository;

    // Result used when the binding itself cannot be obtained.
    let lookup_failure = |binding_id: String, reason: String| TemplateBindingMatchingValidation {
        binding_id,
        is_valid: false,
        validation_errors: vec![reason],
        total_segments: 0,
        matchable_segments: 0,
    };

    // Fetch the binding; without it nothing else can be validated.
    let bindings = ProjectTemplateBindingRepository::new(database.inner().clone());
    let binding = match bindings.get_by_id(&binding_id) {
        Ok(Some(found)) => found,
        Ok(None) => return Ok(lookup_failure(binding_id, "模板绑定不存在".to_string())),
        Err(e) => {
            return Ok(lookup_failure(
                binding_id,
                format!("获取模板绑定失败: {}", e),
            ))
        }
    };

    let mut validation_errors = Vec::new();
    let mut total_segments = 0;
    let mut matchable_segments = 0;

    // An inactive binding is an error, but segment counting still proceeds.
    if !binding.is_active {
        validation_errors.push("模板绑定未激活".to_string());
    }

    // Load the bound template and count its segments.
    let templates = TemplateService::new(database.inner().clone());
    match templates.get_template_by_id(&binding.template_id).await {
        Ok(Some(template)) => {
            for track in &template.tracks {
                total_segments += track.segments.len() as u32;
                // Segments bound to a fixed material are excluded from matching.
                matchable_segments += track
                    .segments
                    .iter()
                    .filter(|segment| !segment.matching_rule.is_fixed_material())
                    .count() as u32;
            }
        }
        Ok(None) => validation_errors.push("关联的模板不存在".to_string()),
        Err(e) => validation_errors.push(format!("获取模板信息失败: {}", e)),
    }

    Ok(TemplateBindingMatchingValidation {
        binding_id,
        is_valid: validation_errors.is_empty(),
        validation_errors,
        total_segments,
        matchable_segments,
    })
}
/// Material statistics of a project, as returned by
/// `get_project_material_stats_for_matching`.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct ProjectMaterialMatchingStats {
/// Id of the project the statistics were computed for.
pub project_id: String,
/// Number of materials loaded for the project.
pub total_materials: u32,
/// Total number of segments across all of the project's materials.
pub total_segments: u32,
/// Number of segments that have at least one classification record.
pub classified_segments: u32,
/// Number of distinct model ids referenced by the project's materials.
pub available_models: u32,
/// Distinct categories taken from the classification records of classified segments.
pub available_categories: Vec<String>,
/// `classified_segments / total_segments`, or `0.0` when the project has no segments.
pub classification_rate: f64,
}
/// Result of validating a template binding for material matching, as returned
/// by `validate_template_binding_for_matching`.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct TemplateBindingMatchingValidation {
/// Id of the binding that was validated (echoed from the request).
pub binding_id: String,
/// `true` only when `validation_errors` is empty.
pub is_valid: bool,
/// Human-readable descriptions of every problem found during validation.
pub validation_errors: Vec<String>,
/// Number of segments across all tracks of the bound template (0 if the template was not loaded).
pub total_segments: u32,
/// Number of those segments whose matching rule is not a fixed material.
pub matchable_segments: u32,
}
/// One-click matching for all templates.
///
/// Builds a [`MaterialMatchingService`] with a result service on top of the
/// database obtained from the application state, then delegates to
/// `batch_match_all_templates_with_events`, handing over the database handle
/// and the Tauri `AppHandle`. Per the method name, the handle is presumably
/// used to emit progress events, and the service is expected to walk the
/// project's active template bindings one by one — confirm against the service.
///
/// # Errors
///
/// Returns a `String` when the material repository cannot be created, or the
/// stringified service error when the batch run fails.
#[tauri::command]
pub async fn batch_match_all_templates(
    request: BatchMatchingRequest,
    state: State<'_, crate::app_state::AppState>,
    app_handle: tauri::AppHandle,
) -> Result<BatchMatchingResult, String> {
    use crate::business::services::ai_classification_service::AiClassificationService;
    use crate::data::repositories::ai_classification_repository::AiClassificationRepository;
    use crate::data::repositories::material_usage_repository::MaterialUsageRepository;

    let database = state.get_database();

    // Core matching dependencies (only the material repository can fail).
    let materials = MaterialRepository::new(database.clone())
        .map(Arc::new)
        .map_err(|e| format!("创建素材仓储失败: {}", e))?;
    let templates = Arc::new(TemplateService::new(database.clone()));
    let classifications = Arc::new(VideoClassificationRepository::new(database.clone()));
    let usage = Arc::new(MaterialUsageRepository::new(database.clone()));

    // Service used to save matching results.
    let result_store = Arc::new(TemplateMatchingResultService::new(Arc::new(
        TemplateMatchingResultRepository::new(database.clone()),
    )));

    let ai_classifier = Arc::new(AiClassificationService::new(Arc::new(
        AiClassificationRepository::new(database.clone()),
    )));

    let service = MaterialMatchingService::new_with_result_service(
        materials,
        usage,
        templates,
        classifications,
        ai_classifier,
        result_store,
    );

    // The database handle is moved into the batch call, so it is passed last.
    let outcome = service
        .batch_match_all_templates_with_events(request, database, Some(app_handle))
        .await;
    outcome.map_err(|e| e.to_string())
}