diff --git a/apps/desktop/src-tauri/src/business/services/template_segment_weight_service.rs b/apps/desktop/src-tauri/src/business/services/template_segment_weight_service.rs index 82775c3..c9ec486 100644 --- a/apps/desktop/src-tauri/src/business/services/template_segment_weight_service.rs +++ b/apps/desktop/src-tauri/src/business/services/template_segment_weight_service.rs @@ -168,18 +168,44 @@ impl TemplateSegmentWeightService { /// 获取权重配置的统计信息 pub async fn get_weight_statistics(&self, template_id: &str) -> Result> { let weights = self.get_template_weights(template_id).await?; - + let mut stats = HashMap::new(); stats.insert("total_configurations".to_string(), weights.len() as i32); - + // 统计每个AI分类的配置数量 let mut classification_counts = HashMap::new(); for weight in weights { *classification_counts.entry(weight.ai_classification_id).or_insert(0) += 1; } - + stats.insert("unique_classifications".to_string(), classification_counts.len() as i32); - + Ok(stats) } + + /// 获取指定分类的权重配置(用于按顺序匹配规则) + pub async fn get_segment_weights_for_categories(&self, template_id: &str, track_segment_id: &str, category_ids: &[String]) -> Result> { + // 获取模板片段的自定义权重配置 + let custom_weights = self.repository.get_weight_map_for_segment(template_id, track_segment_id).await?; + + // 获取指定分类的AI分类信息 + let mut final_weights = HashMap::new(); + + for category_id in category_ids { + // 优先使用自定义权重 + if let Some(&custom_weight) = custom_weights.get(category_id) { + final_weights.insert(category_id.clone(), custom_weight); + } else { + // 如果没有自定义权重,获取全局权重 + if let Ok(Some(classification)) = self.ai_classification_service.get_classification_by_id(category_id).await { + final_weights.insert(category_id.clone(), classification.weight); + } else { + // 如果分类不存在,使用默认权重 + final_weights.insert(category_id.clone(), 50); + } + } + } + + Ok(final_weights) + } } diff --git a/apps/desktop/src-tauri/src/lib.rs b/apps/desktop/src-tauri/src/lib.rs index 66cfaea..7d524da 100644 --- a/apps/desktop/src-tauri/src/lib.rs +++ b/apps/desktop/src-tauri/src/lib.rs @@ -385,7 +385,8 @@ pub fn run() { commands::template_segment_weight_commands::delete_template_weights, commands::template_segment_weight_commands::update_template_segment_weight, commands::template_segment_weight_commands::has_custom_segment_weights, - commands::template_segment_weight_commands::get_template_weight_statistics + commands::template_segment_weight_commands::get_template_weight_statistics, + commands::template_segment_weight_commands::get_segment_weights_for_categories ]) .setup(|app| { // 初始化日志系统 diff --git a/apps/desktop/src-tauri/src/presentation/commands/template_segment_weight_commands.rs b/apps/desktop/src-tauri/src/presentation/commands/template_segment_weight_commands.rs index 1416cd1..f1bce14 100644 --- a/apps/desktop/src-tauri/src/presentation/commands/template_segment_weight_commands.rs +++ b/apps/desktop/src-tauri/src/presentation/commands/template_segment_weight_commands.rs @@ -168,3 +168,18 @@ pub async fn get_template_weight_statistics( .await .map_err(|e| e.to_string()) } + +/// 获取指定分类的权重配置(用于按顺序匹配规则) +#[tauri::command] +pub async fn get_segment_weights_for_categories( + state: State<'_, AppState>, + template_id: String, + track_segment_id: String, + category_ids: Vec, +) -> Result, String> { + let service = create_template_segment_weight_service(&state); + + service.get_segment_weights_for_categories(&template_id, &track_segment_id, &category_ids) + .await + .map_err(|e| e.to_string()) +} diff --git a/apps/desktop/src/components/template/SegmentMatchingRuleEditor.tsx b/apps/desktop/src/components/template/SegmentMatchingRuleEditor.tsx index fc3ada6..4d4d5ef 100644 --- a/apps/desktop/src/components/template/SegmentMatchingRuleEditor.tsx +++ b/apps/desktop/src/components/template/SegmentMatchingRuleEditor.tsx @@ -57,11 +57,23 @@ export const SegmentMatchingRuleEditor: React.FC if (!templateId) return; try { - const weights = await TemplateSegmentWeightService.getSegmentWeightsWithDefaults( - templateId, - segmentId - ); - setEditingWeights({ ...weights }); + if (SegmentMatchingRuleHelper.isPriorityOrder(editingRule)) { + const selectedCategoryIds = typeof editingRule === 'object' && 'PriorityOrder' in editingRule + ? editingRule.PriorityOrder.category_ids + : []; + + if (selectedCategoryIds.length > 0) { + // 只加载选中分类的权重 + const weights = await TemplateSegmentWeightService.getSegmentWeightsForCategories( + templateId, + segmentId, + selectedCategoryIds + ); + setEditingWeights({ ...weights }); + } else { + setEditingWeights({}); + } + } } catch (error) { console.error('Failed to load weight data:', error); } diff --git a/apps/desktop/src/components/template/SegmentWeightIndicator.tsx b/apps/desktop/src/components/template/SegmentWeightIndicator.tsx index 9975ab4..3801120 100644 --- a/apps/desktop/src/components/template/SegmentWeightIndicator.tsx +++ b/apps/desktop/src/components/template/SegmentWeightIndicator.tsx @@ -49,28 +49,25 @@ export const SegmentWeightIndicator: React.FC = ({ try { setLoading(true); - const [hasCustom, weights] = await Promise.all([ - TemplateSegmentWeightService.hasCustomSegmentWeights(templateId, trackSegmentId), - TemplateSegmentWeightService.getSegmentWeightsWithDefaults(templateId, trackSegmentId), - ]); - + const hasCustom = await TemplateSegmentWeightService.hasCustomSegmentWeights(templateId, trackSegmentId); setHasCustomWeights(hasCustom); // 计算权重摘要 - 只统计实际选择的分类 let relevantWeights: Record = {}; if (segmentMatchingRule && SegmentMatchingRuleHelper.isPriorityOrder(segmentMatchingRule)) { - // 对于按顺序匹配规则,只统计选择的分类 + // 对于按顺序匹配规则,只获取选择的分类权重 const selectedCategoryIds = typeof segmentMatchingRule === 'object' && 'PriorityOrder' in segmentMatchingRule ? segmentMatchingRule.PriorityOrder.category_ids : []; - // 只包含选择的分类的权重 - relevantWeights = Object.fromEntries( - Object.entries(weights).filter(([classificationId]) => - selectedCategoryIds.includes(classificationId) - ) - ); + if (selectedCategoryIds.length > 0) { + relevantWeights = await TemplateSegmentWeightService.getSegmentWeightsForCategories( + templateId, + trackSegmentId, + selectedCategoryIds + ); + } } else { // 对于其他规则类型,不显示权重信息(因为不相关) relevantWeights = {}; @@ -223,26 +220,23 @@ export const WeightPreviewTooltip: React.FC = ({ try { setLoading(true); - const allWeights = await TemplateSegmentWeightService.getSegmentWeightsWithDefaults( - templateId, - trackSegmentId - ); - // 过滤权重数据,只显示实际选中的分类 + // 只获取实际选中的分类权重 let relevantWeights: Record = {}; if (segmentMatchingRule && SegmentMatchingRuleHelper.isPriorityOrder(segmentMatchingRule)) { - // 对于按顺序匹配规则,只显示选择的分类 + // 对于按顺序匹配规则,只获取选择的分类权重 const selectedCategoryIds = typeof segmentMatchingRule === 'object' && 'PriorityOrder' in segmentMatchingRule ? segmentMatchingRule.PriorityOrder.category_ids : []; - // 只包含选择的分类的权重 - relevantWeights = Object.fromEntries( - Object.entries(allWeights).filter(([classificationId]) => - selectedCategoryIds.includes(classificationId) - ) - ); + if (selectedCategoryIds.length > 0) { + relevantWeights = await TemplateSegmentWeightService.getSegmentWeightsForCategories( + templateId, + trackSegmentId, + selectedCategoryIds + ); + } } else { // 对于其他规则类型,不显示权重信息 relevantWeights = {}; diff --git a/apps/desktop/src/services/templateSegmentWeightService.ts b/apps/desktop/src/services/templateSegmentWeightService.ts index e99bdcb..bd7df51 100644 --- a/apps/desktop/src/services/templateSegmentWeightService.ts +++ b/apps/desktop/src/services/templateSegmentWeightService.ts @@ -199,6 +199,21 @@ export class TemplateSegmentWeightService { return results; } + /** + * 获取指定分类的权重配置(用于按顺序匹配规则) + */ + static async getSegmentWeightsForCategories( + templateId: string, + trackSegmentId: string, + categoryIds: string[] + ): Promise> { + return await invoke('get_segment_weights_for_categories', { + templateId, + trackSegmentId, + categoryIds, + }); + } + /** * 获取模板所有片段的权重统计 */