mixvideo-v2/apps/desktop/src/services/templateSegmentWeightServic...

252 lines
7.0 KiB
TypeScript

import { invoke } from '@tauri-apps/api/core';
import {
TemplateSegmentWeight,
CreateTemplateSegmentWeightRequest,
UpdateTemplateSegmentWeightRequest,
BatchUpdateTemplateSegmentWeightRequest,
} from '../types/template';
import { AiClassification } from '../types/aiClassification';
/**
* 模板片段权重配置服务
* 遵循前端开发规范的服务层设计原则
*/
export class TemplateSegmentWeightService {
  // Service for template segment weight configurations, following the
  // frontend service-layer conventions. Every method is static and delegates
  // to a Tauri command via `invoke`. Methods that call sibling methods refer to
  // the class by name rather than `this`, so they keep working when detached
  // (destructured or passed as callbacks), where `this` would be undefined.

  /**
   * Create a template segment weight configuration.
   *
   * @param request - Forwarded unchanged to the `create_template_segment_weight` command.
   * @returns The weight record returned by the backend.
   */
  static async createTemplateSegmentWeight(
    request: CreateTemplateSegmentWeightRequest
  ): Promise<TemplateSegmentWeight> {
    return invoke<TemplateSegmentWeight>('create_template_segment_weight', { request });
  }

  /**
   * Get the weight configuration of a template segment, including default values.
   *
   * @param templateId - Template the segment belongs to.
   * @param trackSegmentId - Track segment to look up.
   * @returns A map from AI classification id to weight.
   */
  static async getSegmentWeightsWithDefaults(
    templateId: string,
    trackSegmentId: string
  ): Promise<Record<string, number>> {
    return invoke<Record<string, number>>('get_segment_weights_with_defaults', {
      templateId,
      trackSegmentId,
    });
  }

  /**
   * Get the AI classifications of a template segment ordered by their weight.
   * The sort direction is decided by the backend command.
   *
   * @param templateId - Template the segment belongs to.
   * @param trackSegmentId - Track segment to look up.
   */
  static async getClassificationsBySegmentWeight(
    templateId: string,
    trackSegmentId: string
  ): Promise<AiClassification[]> {
    return invoke<AiClassification[]>('get_classifications_by_segment_weight', {
      templateId,
      trackSegmentId,
    });
  }

  /**
   * Batch-update the weight configuration of a template segment.
   *
   * @param request - Forwarded unchanged to the `batch_update_template_segment_weights` command.
   * @returns The weight records returned by the backend.
   */
  static async batchUpdateTemplateSegmentWeights(
    request: BatchUpdateTemplateSegmentWeightRequest
  ): Promise<TemplateSegmentWeight[]> {
    return invoke<TemplateSegmentWeight[]>('batch_update_template_segment_weights', { request });
  }

  /**
   * Initialize the default weight configuration for a template segment.
   *
   * @param templateId - Template the segment belongs to.
   * @param trackSegmentId - Track segment to initialize.
   */
  static async initializeDefaultSegmentWeights(
    templateId: string,
    trackSegmentId: string
  ): Promise<TemplateSegmentWeight[]> {
    return invoke<TemplateSegmentWeight[]>('initialize_default_segment_weights', {
      templateId,
      trackSegmentId,
    });
  }

  /**
   * Reset a template segment's weight configuration to the global defaults.
   *
   * @param templateId - Template the segment belongs to.
   * @param trackSegmentId - Track segment to reset.
   */
  static async resetSegmentWeightsToGlobal(
    templateId: string,
    trackSegmentId: string
  ): Promise<TemplateSegmentWeight[]> {
    return invoke<TemplateSegmentWeight[]>('reset_segment_weights_to_global', {
      templateId,
      trackSegmentId,
    });
  }

  /**
   * Get every weight configuration stored for a template.
   *
   * @param templateId - Template to look up.
   */
  static async getTemplateWeights(templateId: string): Promise<TemplateSegmentWeight[]> {
    return invoke<TemplateSegmentWeight[]>('get_template_weights', { templateId });
  }

  /**
   * Delete every weight configuration of a template.
   *
   * @param templateId - Template whose weights are deleted.
   * @returns The number reported by the backend (presumably the count of
   *   deleted records; confirm against the command implementation).
   */
  static async deleteTemplateWeights(templateId: string): Promise<number> {
    return invoke<number>('delete_template_weights', { templateId });
  }

  /**
   * Update a single weight configuration.
   *
   * @param id - Id of the weight record to update.
   * @param request - Forwarded unchanged to the `update_template_segment_weight` command.
   * @returns The updated record, or `null` (presumably when no record matches
   *   `id`; confirm against the backend).
   */
  static async updateTemplateSegmentWeight(
    id: string,
    request: UpdateTemplateSegmentWeightRequest
  ): Promise<TemplateSegmentWeight | null> {
    return invoke<TemplateSegmentWeight | null>('update_template_segment_weight', { id, request });
  }

  /**
   * Check whether a template segment has a custom weight configuration.
   *
   * @param templateId - Template the segment belongs to.
   * @param trackSegmentId - Track segment to check.
   */
  static async hasCustomSegmentWeights(
    templateId: string,
    trackSegmentId: string
  ): Promise<boolean> {
    return invoke<boolean>('has_custom_segment_weights', {
      templateId,
      trackSegmentId,
    });
  }

  /**
   * Get statistics about a template's weight configurations.
   * The keys are defined by the backend. `getTemplateSegmentWeightSummary`
   * reads `total_configurations` and `unique_classifications` from this map.
   *
   * @param templateId - Template to summarize.
   */
  static async getTemplateWeightStatistics(
    templateId: string
  ): Promise<Record<string, number>> {
    return invoke<Record<string, number>>('get_template_weight_statistics', { templateId });
  }

  /**
   * Convenience wrapper that sets a segment's weights from a plain map.
   *
   * @param templateId - Template the segment belongs to.
   * @param trackSegmentId - Track segment to update.
   * @param weights - Map from AI classification id to weight. Entries are sent
   *   in `Object.entries` order.
   * @returns The records returned by the batch update.
   */
  static async setSegmentWeights(
    templateId: string,
    trackSegmentId: string,
    weights: Readonly<Record<string, number>>
  ): Promise<TemplateSegmentWeight[]> {
    const request: BatchUpdateTemplateSegmentWeightRequest = {
      template_id: templateId,
      track_segment_id: trackSegmentId,
      weights: Object.entries(weights).map(([aiClassificationId, weight]) => ({
        ai_classification_id: aiClassificationId,
        weight,
      })),
    };
    return TemplateSegmentWeightService.batchUpdateTemplateSegmentWeights(request);
  }

  /**
   * Convenience alias for {@link TemplateSegmentWeightService.getSegmentWeightsWithDefaults}.
   *
   * @param templateId - Template the segment belongs to.
   * @param trackSegmentId - Track segment to look up.
   */
  static async getSegmentWeightMap(
    templateId: string,
    trackSegmentId: string
  ): Promise<Record<string, number>> {
    return TemplateSegmentWeightService.getSegmentWeightsWithDefaults(templateId, trackSegmentId);
  }

  /**
   * Copy one segment's weights (defaults included) onto other segments.
   *
   * The targets are written concurrently. If any write rejects, the returned
   * promise rejects, and this method does not roll back writes that already
   * succeeded.
   *
   * @param sourceTemplateId - Template of the source segment.
   * @param sourceTrackSegmentId - Segment whose weights are copied.
   * @param targetSegments - Segments that receive the weights.
   * @returns One batch-update result per target, in `targetSegments` order.
   */
  static async copyWeightsToSegments(
    sourceTemplateId: string,
    sourceTrackSegmentId: string,
    targetSegments: ReadonlyArray<{ templateId: string; trackSegmentId: string }>
  ): Promise<TemplateSegmentWeight[][]> {
    const sourceWeights = await TemplateSegmentWeightService.getSegmentWeightsWithDefaults(
      sourceTemplateId,
      sourceTrackSegmentId
    );
    return Promise.all(
      targetSegments.map(({ templateId, trackSegmentId }) =>
        TemplateSegmentWeightService.setSegmentWeights(templateId, trackSegmentId, sourceWeights)
      )
    );
  }

  /**
   * Reset several segments to the global default weights concurrently.
   *
   * @param segments - Segments to reset. An empty list resolves to `[]`
   *   without invoking any command.
   * @returns One reset result per segment, in input order.
   */
  static async resetMultipleSegmentsToGlobal(
    segments: ReadonlyArray<{ templateId: string; trackSegmentId: string }>
  ): Promise<TemplateSegmentWeight[][]> {
    return Promise.all(
      segments.map(({ templateId, trackSegmentId }) =>
        TemplateSegmentWeightService.resetSegmentWeightsToGlobal(templateId, trackSegmentId)
      )
    );
  }

  /**
   * Get the weights of the given classifications for a segment (used by the
   * sequential matching rules).
   *
   * @param templateId - Template the segment belongs to.
   * @param trackSegmentId - Track segment to look up.
   * @param categoryIds - Classification ids whose weights are requested.
   */
  static async getSegmentWeightsForCategories(
    templateId: string,
    trackSegmentId: string,
    categoryIds: readonly string[]
  ): Promise<Record<string, number>> {
    return invoke<Record<string, number>>('get_segment_weights_for_categories', {
      templateId,
      trackSegmentId,
      categoryIds,
    });
  }

  /**
   * Summarize the weight configuration of every segment in a template.
   *
   * `averageWeightPerClassification` is the mean of the stored weights per
   * classification id, rounded to two decimal places. Classifications with no
   * stored rows are absent.
   *
   * NOTE(review): `totalSegments` is filled from the backend's
   * `total_configurations` and `segmentsWithCustomWeights` from
   * `unique_classifications`. These keys do not obviously match the field
   * names. Confirm against the `get_template_weight_statistics` command.
   * Missing keys yield 0.
   *
   * @param templateId - Template to summarize.
   */
  static async getTemplateSegmentWeightSummary(templateId: string): Promise<{
    totalSegments: number;
    segmentsWithCustomWeights: number;
    averageWeightPerClassification: Record<string, number>;
  }> {
    const [weights, statistics] = await Promise.all([
      TemplateSegmentWeightService.getTemplateWeights(templateId),
      TemplateSegmentWeightService.getTemplateWeightStatistics(templateId),
    ]);

    // Single pass: sum and count per classification. A Map avoids the
    // prototype-key pitfalls of a plain-object dictionary.
    const totals = new Map<string, { sum: number; count: number }>();
    for (const { ai_classification_id: classificationId, weight } of weights) {
      let entry = totals.get(classificationId);
      if (entry === undefined) {
        entry = { sum: 0, count: 0 };
        totals.set(classificationId, entry);
      }
      entry.sum += weight;
      entry.count += 1;
    }

    const averageWeightPerClassification: Record<string, number> = Object.fromEntries(
      Array.from(totals, ([classificationId, { sum, count }]) => [
        classificationId,
        Math.round((sum / count) * 100) / 100,
      ])
    );

    return {
      totalSegments: statistics.total_configurations ?? 0,
      segmentsWithCustomWeights: statistics.unique_classifications ?? 0,
      averageWeightPerClassification,
    };
  }
}