469 lines
18 KiB
Rust
469 lines
18 KiB
Rust
use crate::data::models::video_classification::*;
|
||
use crate::data::models::ai_classification::generate_full_prompt;
|
||
use crate::data::repositories::video_classification_repository::VideoClassificationRepository;
|
||
use crate::data::repositories::ai_classification_repository::AiClassificationRepository;
|
||
use crate::data::repositories::material_repository::MaterialRepository;
|
||
use crate::infrastructure::gemini_service::{GeminiService, GeminiConfig};
|
||
use crate::infrastructure::file_system::FileSystemService;
|
||
use anyhow::{Result, anyhow};
|
||
use std::sync::Arc;
|
||
use std::path::Path;
|
||
use serde_json;
|
||
|
||
/// AI视频分类业务服务
|
||
/// 遵循 Tauri 开发规范的业务层设计模式
|
||
pub struct VideoClassificationService {
|
||
video_repo: Arc<VideoClassificationRepository>,
|
||
ai_classification_repo: Arc<AiClassificationRepository>,
|
||
material_repo: Arc<MaterialRepository>,
|
||
gemini_config: Option<GeminiConfig>,
|
||
}
|
||
|
||
impl VideoClassificationService {
|
||
/// 创建新的视频分类服务实例
|
||
pub fn new(
|
||
video_repo: Arc<VideoClassificationRepository>,
|
||
ai_classification_repo: Arc<AiClassificationRepository>,
|
||
material_repo: Arc<MaterialRepository>,
|
||
gemini_config: Option<GeminiConfig>,
|
||
) -> Self {
|
||
Self {
|
||
video_repo,
|
||
ai_classification_repo,
|
||
material_repo,
|
||
gemini_config,
|
||
}
|
||
}
|
||
|
||
/// 为素材创建批量分类任务
|
||
pub async fn create_batch_classification_tasks(&self, request: BatchClassificationRequest) -> Result<Vec<VideoClassificationTask>> {
|
||
println!("🎬 创建批量分类任务");
|
||
println!(" 素材ID: {}", request.material_id);
|
||
println!(" 项目ID: {}", request.project_id);
|
||
|
||
// 获取素材信息
|
||
let material = self.material_repo.get_by_id(&request.material_id)?
|
||
.ok_or_else(|| anyhow!("素材不存在: {}", request.material_id))?;
|
||
|
||
println!(" 素材项目ID: {}", material.project_id);
|
||
|
||
// 验证项目ID是否匹配,如果不匹配则尝试修复
|
||
if material.project_id != request.project_id {
|
||
println!("⚠️ 项目ID不匹配,尝试修复: 素材项目ID={}, 请求项目ID={}", material.project_id, request.project_id);
|
||
|
||
// 验证请求的项目ID是否存在
|
||
if let Ok(Some(_)) = self.material_repo.get_project_by_id(&request.project_id).await {
|
||
println!("✅ 请求的项目ID有效,将使用请求的项目ID进行分类");
|
||
// 使用请求的项目ID,而不是素材中的项目ID
|
||
} else {
|
||
return Err(anyhow!("请求的项目ID无效: {}", request.project_id));
|
||
}
|
||
}
|
||
|
||
// 获取素材的所有片段
|
||
let segments = self.material_repo.get_segments(&request.material_id)?;
|
||
|
||
if segments.is_empty() {
|
||
return Err(anyhow!("素材没有切分片段"));
|
||
}
|
||
|
||
let mut tasks = Vec::new();
|
||
|
||
for segment in segments {
|
||
// 检查是否已经分类(如果不覆盖现有分类)
|
||
if !request.overwrite_existing {
|
||
if self.video_repo.is_segment_classified(&segment.id).await? {
|
||
continue;
|
||
}
|
||
}
|
||
|
||
// 检查视频文件是否存在
|
||
if !Path::new(&segment.file_path).exists() {
|
||
println!("警告: 视频文件不存在,跳过: {}", segment.file_path);
|
||
continue;
|
||
}
|
||
|
||
// 创建分类任务
|
||
let task = VideoClassificationTask::new(
|
||
segment.id.clone(),
|
||
request.material_id.clone(),
|
||
request.project_id.clone(),
|
||
segment.file_path.clone(),
|
||
request.priority,
|
||
);
|
||
|
||
let created_task = self.video_repo.create_classification_task(task).await?;
|
||
tasks.push(created_task);
|
||
}
|
||
|
||
Ok(tasks)
|
||
}
|
||
|
||
/// 为项目创建一键批量分类任务
|
||
pub async fn create_project_batch_classification_tasks(&self, request: ProjectBatchClassificationRequest) -> Result<ProjectBatchClassificationResponse> {
|
||
println!("🚀 开始项目一键分类");
|
||
println!(" 项目ID: {}", request.project_id);
|
||
|
||
// 验证项目是否存在
|
||
let _project = self.material_repo.get_project_by_id(&request.project_id).await?
|
||
.ok_or_else(|| anyhow!("项目不存在: {}", request.project_id))?;
|
||
|
||
// 获取项目所有素材
|
||
let all_materials = self.material_repo.get_by_project_id(&request.project_id)?;
|
||
let total_materials = all_materials.len() as u32;
|
||
|
||
println!(" 项目总素材数: {}", total_materials);
|
||
|
||
// 过滤符合条件的素材
|
||
let material_types = request.material_types.unwrap_or_else(|| vec![crate::data::models::material::MaterialType::Video]);
|
||
let overwrite_existing = request.overwrite_existing.unwrap_or(false);
|
||
|
||
let mut eligible_materials = Vec::new();
|
||
let mut skipped_materials = Vec::new();
|
||
|
||
for material in all_materials {
|
||
// 检查素材类型
|
||
if !material_types.contains(&material.material_type) {
|
||
continue;
|
||
}
|
||
|
||
// 检查处理状态 - 只处理已完成处理的素材
|
||
if material.processing_status != crate::data::models::material::ProcessingStatus::Completed {
|
||
continue;
|
||
}
|
||
|
||
// 获取素材片段
|
||
let segments = self.material_repo.get_segments(&material.id)?;
|
||
if segments.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
// 检查是否已有分类记录
|
||
if !overwrite_existing {
|
||
let mut has_classification = false;
|
||
for segment in &segments {
|
||
if self.video_repo.is_segment_classified(&segment.id).await? {
|
||
has_classification = true;
|
||
break;
|
||
}
|
||
}
|
||
if has_classification {
|
||
skipped_materials.push(material.id.clone());
|
||
continue;
|
||
}
|
||
}
|
||
|
||
eligible_materials.push(material);
|
||
}
|
||
|
||
let eligible_count = eligible_materials.len() as u32;
|
||
println!(" 符合条件的素材数: {}", eligible_count);
|
||
println!(" 跳过的素材数: {}", skipped_materials.len());
|
||
|
||
// 为每个符合条件的素材创建批量分类任务
|
||
let mut all_task_ids = Vec::new();
|
||
let mut created_tasks_count = 0u32;
|
||
|
||
for material in eligible_materials {
|
||
let batch_request = BatchClassificationRequest {
|
||
material_id: material.id.clone(),
|
||
project_id: request.project_id.clone(),
|
||
overwrite_existing,
|
||
priority: request.priority,
|
||
};
|
||
|
||
match self.create_batch_classification_tasks(batch_request).await {
|
||
Ok(tasks) => {
|
||
let task_ids: Vec<String> = tasks.iter().map(|t| t.id.clone()).collect();
|
||
created_tasks_count += task_ids.len() as u32;
|
||
all_task_ids.extend(task_ids);
|
||
println!(" 为素材 {} 创建了 {} 个分类任务", material.name, tasks.len());
|
||
}
|
||
Err(e) => {
|
||
println!(" 为素材 {} 创建分类任务失败: {}", material.name, e);
|
||
// 继续处理其他素材,不因单个素材失败而中断整个流程
|
||
}
|
||
}
|
||
}
|
||
|
||
println!("✅ 项目一键分类任务创建完成");
|
||
println!(" 总共创建任务数: {}", created_tasks_count);
|
||
|
||
Ok(ProjectBatchClassificationResponse {
|
||
total_materials,
|
||
eligible_materials: eligible_count,
|
||
created_tasks: created_tasks_count,
|
||
task_ids: all_task_ids,
|
||
skipped_materials,
|
||
})
|
||
}
|
||
|
||
/// 处理单个分类任务
|
||
pub async fn process_classification_task(&self, task_id: &str) -> Result<VideoClassificationRecord> {
|
||
// 获取任务
|
||
let mut task = self.get_task_by_id(task_id).await?;
|
||
|
||
// 开始处理
|
||
task.start_processing();
|
||
self.video_repo.update_classification_task(&task).await?;
|
||
|
||
// 获取AI分类提示词
|
||
let prompt = self.generate_classification_prompt().await?;
|
||
|
||
let result = self.classify_video_with_gemini(&mut task, &prompt).await;
|
||
|
||
match result {
|
||
Ok(record) => {
|
||
// 任务完成
|
||
task.complete();
|
||
self.video_repo.update_classification_task(&task).await?;
|
||
|
||
// 移动视频文件到分类文件夹
|
||
if let Err(e) = self.move_video_to_category_folder(&record).await {
|
||
println!("警告: 移动视频文件失败: {}", e);
|
||
}
|
||
|
||
Ok(record)
|
||
}
|
||
Err(e) => {
|
||
// 任务失败
|
||
task.fail(e.to_string());
|
||
self.video_repo.update_classification_task(&task).await?;
|
||
Err(e)
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 使用Gemini进行视频分类
|
||
async fn classify_video_with_gemini(&self, task: &mut VideoClassificationTask, prompt: &str) -> Result<VideoClassificationRecord> {
|
||
// 为每个任务创建独立的GeminiService实例,避免并发瓶颈
|
||
let mut gemini_service = GeminiService::new(self.gemini_config.clone());
|
||
|
||
// 调用Gemini API进行分类
|
||
let (file_uri, raw_response) = gemini_service.classify_video(&task.video_file_path, prompt).await?;
|
||
|
||
// 更新任务状态
|
||
task.set_analyzing(file_uri.clone(), prompt.to_string());
|
||
self.video_repo.update_classification_task(task).await?;
|
||
|
||
// 解析Gemini响应
|
||
let gemini_response = self.parse_gemini_response(&raw_response)?;
|
||
|
||
// 输出分类结果
|
||
println!("🎯 AI分类结果:");
|
||
println!(" 📁 视频文件: {}", task.video_file_path);
|
||
println!(" 🏷️ 分类类别: {}", gemini_response.category);
|
||
println!(" 📊 置信度: {:.1}%", gemini_response.confidence * 100.0);
|
||
println!(" ⭐ 质量评分: {:.1}/10", gemini_response.quality_score * 10.0);
|
||
println!(" 💭 分类理由: {}", gemini_response.reasoning);
|
||
if !gemini_response.features.is_empty() {
|
||
println!(" 🔍 识别特征: {}", gemini_response.features.join(", "));
|
||
}
|
||
|
||
// 创建分类记录
|
||
let mut record = VideoClassificationRecord::new(
|
||
task.segment_id.clone(),
|
||
task.material_id.clone(),
|
||
task.project_id.clone(),
|
||
gemini_response,
|
||
Some(file_uri),
|
||
Some(raw_response),
|
||
);
|
||
|
||
// 检查是否需要人工审核
|
||
if record.needs_review() {
|
||
println!(" ⚠️ 需要人工审核 (置信度或质量评分较低)");
|
||
record.mark_as_needs_review("置信度或质量评分较低,建议人工审核".to_string());
|
||
} else {
|
||
println!(" ✅ 分类质量良好");
|
||
}
|
||
|
||
// 保存分类记录
|
||
let saved_record = self.video_repo.create_classification_record(record).await?;
|
||
|
||
Ok(saved_record)
|
||
}
|
||
|
||
/// 解析Gemini响应为结构化数据
|
||
fn parse_gemini_response(&self, raw_response: &str) -> Result<GeminiClassificationResponse> {
|
||
// 尝试从响应中提取JSON
|
||
let json_start = raw_response.find('{');
|
||
let json_end = raw_response.rfind('}');
|
||
|
||
if let (Some(start), Some(end)) = (json_start, json_end) {
|
||
let json_str = &raw_response[start..=end];
|
||
|
||
match serde_json::from_str::<GeminiClassificationResponse>(json_str) {
|
||
Ok(response) => Ok(response),
|
||
Err(e) => {
|
||
// 如果解析失败,创建一个默认响应
|
||
Ok(GeminiClassificationResponse {
|
||
category: "未分类".to_string(),
|
||
confidence: 0.5,
|
||
reasoning: format!("AI响应解析失败: {} - 原始JSON: {}", e, json_str),
|
||
features: vec!["解析失败".to_string()],
|
||
product_match: false,
|
||
quality_score: 0.5,
|
||
})
|
||
}
|
||
}
|
||
} else {
|
||
// 没有找到JSON格式,创建默认响应
|
||
Ok(GeminiClassificationResponse {
|
||
category: "未分类".to_string(),
|
||
confidence: 0.3,
|
||
reasoning: format!("AI响应格式异常,未找到JSON: {}", &raw_response[..std::cmp::min(200, raw_response.len())]),
|
||
features: vec!["格式异常".to_string()],
|
||
product_match: false,
|
||
quality_score: 0.3,
|
||
})
|
||
}
|
||
}
|
||
|
||
/// 生成分类提示词
|
||
async fn generate_classification_prompt(&self) -> Result<String> {
|
||
// 获取激活的AI分类
|
||
let classifications = self.ai_classification_repo.get_all(Some(
|
||
crate::data::models::ai_classification::AiClassificationQuery {
|
||
active_only: Some(true),
|
||
sort_by: Some("sort_order".to_string()),
|
||
sort_order: Some("ASC".to_string()),
|
||
..Default::default()
|
||
}
|
||
)).await?;
|
||
|
||
if classifications.is_empty() {
|
||
return Err(anyhow!("没有激活的AI分类,无法生成提示词"));
|
||
}
|
||
|
||
Ok(generate_full_prompt(&classifications))
|
||
}
|
||
|
||
/// 移动视频文件到分类文件夹
|
||
async fn move_video_to_category_folder(&self, record: &VideoClassificationRecord) -> Result<()> {
|
||
println!("🔍 开始移动文件到分类文件夹");
|
||
println!(" 项目ID: {}", record.project_id);
|
||
println!(" 分类类别: {}", record.category);
|
||
println!(" 片段ID: {}", record.segment_id);
|
||
|
||
// 获取项目信息
|
||
let project = self.material_repo.get_project_by_id(&record.project_id).await?
|
||
.ok_or_else(|| anyhow!("项目不存在: {}", record.project_id))?;
|
||
|
||
// 获取片段信息
|
||
let segment = self.material_repo.get_segment_by_id(&record.segment_id).await?
|
||
.ok_or_else(|| anyhow!("片段不存在: {}", record.segment_id))?;
|
||
|
||
// 构建目标目录路径
|
||
let category_dir = Path::new(&project.path)
|
||
.join("assets")
|
||
.join(&record.category);
|
||
|
||
// 创建分类目录
|
||
FileSystemService::create_directory_if_not_exists(category_dir.to_str().unwrap())?;
|
||
|
||
// 构建目标文件路径
|
||
let source_path = Path::new(&segment.file_path);
|
||
let file_name = source_path.file_name()
|
||
.ok_or_else(|| anyhow!("无法获取文件名"))?;
|
||
let target_path = category_dir.join(file_name);
|
||
|
||
// 移动文件
|
||
FileSystemService::move_file(&segment.file_path, target_path.to_str().unwrap())?;
|
||
|
||
println!("✅ 视频文件已移动到分类文件夹: {} -> {}", segment.file_path, target_path.display());
|
||
|
||
// 更新片段的文件路径
|
||
let new_path = target_path.to_str().unwrap().to_string();
|
||
match self.material_repo.update_segment_file_path(&record.segment_id, &new_path).await {
|
||
Ok(()) => {
|
||
println!("✅ 片段路径已更新: {}", new_path);
|
||
}
|
||
Err(e) => {
|
||
println!("⚠️ 更新片段路径失败: {}", e);
|
||
// 注意:文件已经移动,但数据库更新失败
|
||
// 这种情况下应该考虑回滚文件移动或者记录错误
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 获取待处理的任务
|
||
pub async fn get_pending_tasks(&self, limit: Option<i32>) -> Result<Vec<VideoClassificationTask>> {
|
||
self.video_repo.get_pending_tasks(limit).await
|
||
}
|
||
|
||
/// 获取分类统计信息
|
||
pub async fn get_classification_stats(&self, project_id: Option<&str>) -> Result<ClassificationStats> {
|
||
self.video_repo.get_classification_stats(project_id).await
|
||
}
|
||
|
||
/// 根据素材ID获取分类记录
|
||
pub async fn get_classifications_by_material(&self, material_id: &str) -> Result<Vec<VideoClassificationRecord>> {
|
||
self.video_repo.get_by_material_id(material_id).await
|
||
}
|
||
|
||
/// 获取任务详情
|
||
async fn get_task_by_id(&self, task_id: &str) -> Result<VideoClassificationTask> {
|
||
self.video_repo.get_task_by_id(task_id).await?
|
||
.ok_or_else(|| anyhow!("任务不存在: {}", task_id))
|
||
}
|
||
|
||
/// 取消分类任务
|
||
pub async fn cancel_task(&self, task_id: &str) -> Result<()> {
|
||
self.video_repo.delete_task(task_id).await
|
||
}
|
||
|
||
/// 恢复卡住的任务状态
|
||
pub async fn recover_stuck_tasks(&self) -> Result<usize> {
|
||
self.video_repo.recover_stuck_tasks().await
|
||
}
|
||
|
||
/// 修复数据库中的日期格式问题
|
||
pub async fn fix_date_formats(&self) -> Result<usize> {
|
||
self.video_repo.fix_date_formats().await
|
||
}
|
||
|
||
/// 重试失败的任务
|
||
pub async fn retry_failed_task(&self, _task_id: &str) -> Result<()> {
|
||
// 这里需要实现重试逻辑
|
||
// 暂时返回错误,后续实现
|
||
Err(anyhow!("重试任务功能待实现"))
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that a JSON object surrounded by extra text is extracted and parsed.
    ///
    /// Ignored for now: `create_test_service()` is still a `todo!()` stub, so
    /// running this test would always panic before reaching the assertions.
    #[test]
    #[ignore = "create_test_service() is not implemented yet"]
    fn test_parse_gemini_response() {
        let service = create_test_service();

        let json_response = r#"
        这是一些前置文本
        {
            "category": "全身",
            "confidence": 0.85,
            "reasoning": "视频显示完整的人体轮廓",
            "features": ["全身可见", "清晰度高"],
            "product_match": true,
            "quality_score": 0.9
        }
        这是一些后置文本
        "#;

        let result = service.parse_gemini_response(json_response).unwrap();
        assert_eq!(result.category, "全身");
        assert_eq!(result.confidence, 0.85);
        assert!(result.product_match);
    }

    /// Builds a service instance for tests.
    ///
    /// Placeholder: a real implementation needs test doubles for the repositories.
    fn create_test_service() -> VideoClassificationService {
        todo!("实现测试用的service创建")
    }
}
|