feat: 实现项目一键AI分类功能

- 添加ProjectBatchClassificationRequest和ProjectBatchClassificationResponse数据模型
- 在VideoClassificationService中实现create_project_batch_classification_tasks方法
- 添加start_project_batch_classification Tauri命令接口
- 在前端添加startProjectBatchClassification方法和相关类型定义
- 在项目详情页面添加一键AI分类按钮和队列状态监控
- 支持批量处理项目下所有符合条件的视频素材
- 集成现有的AI分类队列系统,确保兼容性
This commit is contained in:
imeepos 2025-07-14 18:26:03 +08:00
parent 5d8e2df8d0
commit eeeef4ead4
9 changed files with 425 additions and 8 deletions

View File

@ -43,7 +43,7 @@ pub struct TaskProgress {
/// AI视频分类任务队列
/// 遵循 Tauri 开发规范的业务层设计模式
pub struct VideoClassificationQueue {
service: Arc<VideoClassificationService>,
pub service: Arc<VideoClassificationService>,
status: Arc<RwLock<QueueStatus>>,
current_task: Arc<Mutex<Option<String>>>,
task_progress: Arc<RwLock<HashMap<String, TaskProgress>>>,

View File

@ -102,6 +102,105 @@ impl VideoClassificationService {
Ok(tasks)
}
/// 为项目创建一键批量分类任务
pub async fn create_project_batch_classification_tasks(&self, request: ProjectBatchClassificationRequest) -> Result<ProjectBatchClassificationResponse> {
println!("🚀 开始项目一键分类");
println!(" 项目ID: {}", request.project_id);
// 验证项目是否存在
let _project = self.material_repo.get_project_by_id(&request.project_id).await?
.ok_or_else(|| anyhow!("项目不存在: {}", request.project_id))?;
// 获取项目所有素材
let all_materials = self.material_repo.get_by_project_id(&request.project_id)?;
let total_materials = all_materials.len() as u32;
println!(" 项目总素材数: {}", total_materials);
// 过滤符合条件的素材
let material_types = request.material_types.unwrap_or_else(|| vec![crate::data::models::material::MaterialType::Video]);
let overwrite_existing = request.overwrite_existing.unwrap_or(false);
let mut eligible_materials = Vec::new();
let mut skipped_materials = Vec::new();
for material in all_materials {
// 检查素材类型
if !material_types.contains(&material.material_type) {
continue;
}
// 检查处理状态 - 只处理已完成处理的素材
if material.processing_status != crate::data::models::material::ProcessingStatus::Completed {
continue;
}
// 获取素材片段
let segments = self.material_repo.get_segments(&material.id)?;
if segments.is_empty() {
continue;
}
// 检查是否已有分类记录
if !overwrite_existing {
let mut has_classification = false;
for segment in &segments {
if self.video_repo.is_segment_classified(&segment.id).await? {
has_classification = true;
break;
}
}
if has_classification {
skipped_materials.push(material.id.clone());
continue;
}
}
eligible_materials.push(material);
}
let eligible_count = eligible_materials.len() as u32;
println!(" 符合条件的素材数: {}", eligible_count);
println!(" 跳过的素材数: {}", skipped_materials.len());
// 为每个符合条件的素材创建批量分类任务
let mut all_task_ids = Vec::new();
let mut created_tasks_count = 0u32;
for material in eligible_materials {
let batch_request = BatchClassificationRequest {
material_id: material.id.clone(),
project_id: request.project_id.clone(),
overwrite_existing,
priority: request.priority,
};
match self.create_batch_classification_tasks(batch_request).await {
Ok(tasks) => {
let task_ids: Vec<String> = tasks.iter().map(|t| t.id.clone()).collect();
created_tasks_count += task_ids.len() as u32;
all_task_ids.extend(task_ids);
println!(" 为素材 {} 创建了 {} 个分类任务", material.name, tasks.len());
}
Err(e) => {
println!(" 为素材 {} 创建分类任务失败: {}", material.name, e);
// 继续处理其他素材,不因单个素材失败而中断整个流程
}
}
}
println!("✅ 项目一键分类任务创建完成");
println!(" 总共创建任务数: {}", created_tasks_count);
Ok(ProjectBatchClassificationResponse {
total_materials,
eligible_materials: eligible_count,
created_tasks: created_tasks_count,
task_ids: all_task_ids,
skipped_materials,
})
}
/// 处理单个分类任务
pub async fn process_classification_task(&self, task_id: &str) -> Result<VideoClassificationRecord> {
// 获取任务

View File

@ -117,6 +117,34 @@ pub struct BatchClassificationRequest {
pub priority: Option<i32>,
}
/// 项目一键分类请求
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectBatchClassificationRequest {
/// 项目ID
pub project_id: String,
/// 是否覆盖已有分类
pub overwrite_existing: Option<bool>,
/// 要处理的素材类型(可选,默认只处理视频)
pub material_types: Option<Vec<crate::data::models::material::MaterialType>>,
/// 任务优先级
pub priority: Option<i32>,
}
/// 项目一键分类响应
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectBatchClassificationResponse {
/// 项目中总素材数
pub total_materials: u32,
/// 符合条件的素材数
pub eligible_materials: u32,
/// 创建的任务数
pub created_tasks: u32,
/// 创建的任务ID列表
pub task_ids: Vec<String>,
/// 跳过的素材ID列表已有分类
pub skipped_materials: Vec<String>,
}
/// 分类任务查询参数
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClassificationTaskQuery {

View File

@ -106,6 +106,7 @@ pub fn run() {
commands::ai_classification_commands::validate_ai_classification_name,
// AI视频分类命令
commands::video_classification_commands::start_video_classification,
commands::video_classification_commands::start_project_batch_classification,
commands::video_classification_commands::get_classification_queue_status,
commands::video_classification_commands::get_project_classification_queue_status,
commands::video_classification_commands::get_classification_task_progress,

View File

@ -61,6 +61,31 @@ pub async fn start_video_classification(
Ok(task_ids)
}
/// 启动项目一键AI视频分类
/// 遍历项目下所有符合条件的素材并添加到分类队列
#[command]
pub async fn start_project_batch_classification(
request: ProjectBatchClassificationRequest,
state: State<'_, AppState>,
) -> Result<ProjectBatchClassificationResponse, String> {
let queue = get_queue_instance(&state).await;
// 创建项目批量分类任务
let response = queue.service.create_project_batch_classification_tasks(request)
.await
.map_err(|e| e.to_string())?;
// 启动队列(如果尚未启动)
if let Err(e) = queue.start().await {
// 如果队列已经在运行,忽略错误
if !e.to_string().contains("已经在运行中") {
return Err(e.to_string());
}
}
Ok(response)
}
/// 获取分类队列状态
#[command]
pub async fn get_classification_queue_status(

View File

@ -127,10 +127,10 @@ export const VideoClassificationProgress: React.FC<VideoClassificationProgressPr
// 计算进度百分比
const getOverallProgress = useCallback(() => {
const getOverallProgress = () => {
if (!typedQueueStats || typedQueueStats.total_tasks === 0) return 0;
return Math.round(((typedQueueStats.completed_tasks + typedQueueStats.failed_tasks) / typedQueueStats.total_tasks) * 100);
}, [typedQueueStats]);
};
// 过滤相关任务
const relevantTasks = materialId
@ -142,9 +142,7 @@ export const VideoClassificationProgress: React.FC<VideoClassificationProgressPr
}
const statusInfo = typedQueueStats ? getStatusInfo(typedQueueStats.status) : null;
const overallProgress = useMemo(() => {
return getOverallProgress()
}, [getOverallProgress]);
const overallProgress = getOverallProgress()
return (
<div className="bg-white rounded-lg shadow-sm border border-gray-200 overflow-hidden">

View File

@ -1,10 +1,12 @@
import React, { useEffect, useState } from 'react';
import { useParams, useNavigate } from 'react-router-dom';
import { ArrowLeft, FolderOpen, Settings, Upload, FileVideo, FileAudio, FileImage, HardDrive } from 'lucide-react';
import { ArrowLeft, FolderOpen, Settings, Upload, FileVideo, FileAudio, FileImage, HardDrive, Brain, Loader2 } from 'lucide-react';
import { useProjectStore } from '../store/projectStore';
import { useMaterialStore } from '../store/materialStore';
import { useVideoClassificationStore } from '../store/videoClassificationStore';
import { Project } from '../types/project';
import { MaterialImportResult } from '../types/material';
import { ProjectBatchClassificationRequest, ProjectBatchClassificationResponse } from '../types/videoClassification';
import { LoadingSpinner } from '../components/LoadingSpinner';
import { ErrorMessage } from '../components/ErrorMessage';
import { MaterialImportDialog } from '../components/MaterialImportDialog';
@ -28,9 +30,17 @@ export const ProjectDetails: React.FC = () => {
loadMaterialStats,
isLoading: materialsLoading
} = useMaterialStore();
const {
startProjectBatchClassification,
isLoading: classificationLoading,
error: classificationError,
queueStats,
getProjectQueueStatus
} = useVideoClassificationStore();
const [project, setProject] = useState<Project | null>(null);
const [showImportDialog, setShowImportDialog] = useState(false);
const [activeTab, setActiveTab] = useState<'materials' | 'debug' | 'ai-logs'>('materials');
const [batchClassificationResult, setBatchClassificationResult] = useState<ProjectBatchClassificationResponse | null>(null);
// 加载项目详情
useEffect(() => {
@ -53,6 +63,21 @@ export const ProjectDetails: React.FC = () => {
}
}, [id, projects, loadMaterials, loadMaterialStats]);
// 监控AI分类队列状态
useEffect(() => {
if (!project) return;
// 初始加载队列状态
getProjectQueueStatus(project.id);
// 设置定时刷新
const interval = setInterval(() => {
getProjectQueueStatus(project.id);
}, 3000); // 每3秒刷新一次
return () => clearInterval(interval);
}, [project, getProjectQueueStatus]);
// 返回项目列表
const handleBack = () => {
navigate('/');
@ -109,6 +134,40 @@ export const ProjectDetails: React.FC = () => {
}
};
// 一键AI分类处理
const handleBatchClassification = async () => {
if (!project) return;
try {
const request: ProjectBatchClassificationRequest = {
project_id: project.id,
overwrite_existing: false, // 默认不覆盖已有分类
material_types: undefined, // 使用默认值(只处理视频)
priority: undefined, // 使用默认优先级
};
const response = await startProjectBatchClassification(request);
setBatchClassificationResult(response);
// 显示结果提示
const message = `一键分类启动成功!\n` +
`项目总素材数: ${response.total_materials}\n` +
`符合条件的素材数: ${response.eligible_materials}\n` +
`创建的任务数: ${response.created_tasks}\n` +
`跳过的素材数: ${response.skipped_materials.length}`;
alert(message);
// 刷新队列状态
if (project) {
getProjectQueueStatus(project.id);
}
} catch (error) {
console.error('一键分类失败:', error);
alert(`一键分类失败: ${error}`);
}
};
if (isLoading) {
return (
<div className="flex items-center justify-center min-h-[400px]">
@ -168,6 +227,21 @@ export const ProjectDetails: React.FC = () => {
<span className="hidden sm:inline ml-2"></span>
</button>
<button
onClick={handleBatchClassification}
disabled={classificationLoading}
className="inline-flex items-center px-3 sm:px-4 py-2 bg-gradient-to-r from-purple-500 to-pink-500 text-white rounded-lg hover:from-purple-600 hover:to-pink-600 disabled:opacity-50 disabled:cursor-not-allowed transition-all duration-200 shadow-sm hover:shadow-md text-sm"
>
{classificationLoading ? (
<Loader2 className="w-4 h-4 animate-spin" />
) : (
<Brain className="w-4 h-4" />
)}
<span className="hidden sm:inline ml-2">
{classificationLoading ? '分类中...' : '一键AI分类'}
</span>
</button>
<button className="inline-flex items-center px-3 sm:px-4 py-2 bg-gray-100 text-gray-700 rounded-lg hover:bg-gray-200 transition-colors text-sm">
<Settings className="w-4 h-4" />
<span className="hidden sm:inline ml-2"></span>
@ -177,7 +251,7 @@ export const ProjectDetails: React.FC = () => {
</div>
{/* 项目统计概览 */}
<div className="grid grid-cols-2 md:grid-cols-4 gap-4 md:gap-6 mb-6">
<div className="grid grid-cols-2 md:grid-cols-5 gap-4 md:gap-6 mb-6">
{/* 总素材数 */}
<div className="bg-white rounded-lg shadow-sm border border-gray-200 p-4 md:p-6">
<div className="flex items-center justify-between">
@ -229,6 +303,29 @@ export const ProjectDetails: React.FC = () => {
</div>
</div>
</div>
{/* AI分类状态 */}
<div className="bg-white rounded-lg shadow-sm border border-gray-200 p-4 md:p-6">
<div className="flex items-center justify-between">
<div className="min-w-0 flex-1">
<p className="text-xs md:text-sm font-medium text-gray-600 truncate">AI分类队列</p>
<div className="flex items-center space-x-1">
<p className="text-xl md:text-2xl font-bold text-gray-900">
{queueStats?.pending_tasks || 0}
</p>
<span className="text-xs text-gray-500"></span>
</div>
{queueStats?.processing_tasks && queueStats.processing_tasks > 0 && (
<p className="text-xs text-blue-600 mt-1">
{queueStats.processing_tasks}
</p>
)}
</div>
<div className="w-10 h-10 md:w-12 md:h-12 bg-gradient-to-br from-purple-100 to-pink-100 rounded-lg flex items-center justify-center ml-2">
<Brain className="w-5 h-5 md:w-6 md:h-6 text-purple-600" />
</div>
</div>
</div>
</div>
{/* 主要内容区域 */}

View File

@ -1,5 +1,9 @@
import { create } from 'zustand';
import { invoke } from '@tauri-apps/api/core';
import {
ProjectBatchClassificationRequest,
ProjectBatchClassificationResponse
} from '../types/videoClassification';
// 类型定义
export interface VideoClassificationRecord {
@ -70,6 +74,7 @@ interface VideoClassificationState {
// Actions
startClassification: (request: BatchClassificationRequest) => Promise<string[]>;
startProjectBatchClassification: (request: ProjectBatchClassificationRequest) => Promise<ProjectBatchClassificationResponse>;
getQueueStatus: () => Promise<QueueStats>;
getProjectQueueStatus: (projectId: string) => Promise<QueueStats>;
getTaskProgress: (taskId: string) => Promise<TaskProgress | null>;
@ -118,6 +123,24 @@ export const useVideoClassificationStore = create<VideoClassificationState>((set
}
},
startProjectBatchClassification: async (request: ProjectBatchClassificationRequest) => {
set({ isLoading: true, error: null });
try {
const response = await invoke<ProjectBatchClassificationResponse>('start_project_batch_classification', { request });
// 刷新队列状态
await get().refreshQueueStatus();
await get().refreshTaskProgress();
set({ isLoading: false });
return response;
} catch (error) {
const errorMessage = typeof error === 'string' ? error : '启动项目一键分类失败';
set({ error: errorMessage, isLoading: false });
throw new Error(errorMessage);
}
},
getQueueStatus: async () => {
try {
const stats = await invoke<QueueStats>('get_classification_queue_status');

View File

@ -0,0 +1,146 @@
import { MaterialType } from './material';
/**
*
*/
export enum TaskStatus {
Pending = 'Pending',
Uploading = 'Uploading',
Analyzing = 'Analyzing',
Completed = 'Completed',
Failed = 'Failed',
Cancelled = 'Cancelled',
}
/**
*
*/
export enum QueueStatus {
Stopped = 'Stopped',
Running = 'Running',
Paused = 'Paused',
}
/**
*
*/
export enum ClassificationStatus {
Classified = 'Classified',
Failed = 'Failed',
NeedsReview = 'NeedsReview',
}
/**
*
*/
export interface BatchClassificationRequest {
material_id: string;
project_id: string;
overwrite_existing: boolean;
priority?: number;
}
/**
*
*/
export interface ProjectBatchClassificationRequest {
project_id: string;
overwrite_existing?: boolean;
material_types?: MaterialType[];
priority?: number;
}
/**
*
*/
export interface ProjectBatchClassificationResponse {
total_materials: number;
eligible_materials: number;
created_tasks: number;
task_ids: string[];
skipped_materials: string[];
}
/**
*
*/
export interface TaskProgress {
task_id: string;
status: TaskStatus;
progress_percentage: number;
current_step: string;
error_message?: string;
started_at?: string;
estimated_completion?: string;
}
/**
*
*/
export interface QueueStats {
status: QueueStatus;
total_tasks: number;
pending_tasks: number;
processing_tasks: number;
completed_tasks: number;
failed_tasks: number;
current_task_id?: string;
processing_rate: number;
}
/**
*
*/
export interface VideoClassificationRecord {
id: string;
task_id: string;
segment_id: string;
material_id: string;
project_id: string;
video_file_path: string;
classification_result: string;
confidence_score: number;
reasoning: string;
features: string[];
product_match: boolean;
quality_score: number;
gemini_file_uri: string;
raw_response: string;
status: ClassificationStatus;
created_at: string;
updated_at: string;
}
/**
*
*/
export interface ClassificationStats {
total_tasks: number;
pending_tasks: number;
processing_tasks: number;
completed_tasks: number;
failed_tasks: number;
total_classifications: number;
average_confidence_score: number;
average_quality_score: number;
}
/**
*
*/
export interface VideoClassificationTask {
id: string;
segment_id: string;
material_id: string;
project_id: string;
video_file_path: string;
status: TaskStatus;
priority: number;
gemini_file_uri?: string;
prompt_text?: string;
error_message?: string;
created_at: string;
updated_at: string;
started_at?: string;
completed_at?: string;
}