mixvideo-v2/apps/desktop/src-tauri/src/business/services/ai_classification_service.rs

276 lines
9.4 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::data::models::ai_classification::{
AiClassification, CreateAiClassificationRequest, UpdateAiClassificationRequest,
AiClassificationQuery, AiClassificationPreview, generate_full_prompt
};
use crate::data::repositories::ai_classification_repository::AiClassificationRepository;
use crate::business::errors::BusinessError;
use anyhow::Result;
use std::sync::Arc;
/// Business service for AI classifications.
///
/// Follows the Tauri project's business-logic-layer conventions: it validates
/// incoming requests, enforces name uniqueness, and delegates persistence to
/// [`AiClassificationRepository`].
pub struct AiClassificationService {
    repository: Arc<AiClassificationRepository>,
}

impl AiClassificationService {
    /// Maximum classification name length, counted in characters (not bytes).
    const MAX_NAME_CHARS: usize = 100;
    /// Maximum prompt text length, counted in characters (not bytes).
    const MAX_PROMPT_CHARS: usize = 1000;
    /// Maximum description length, counted in characters (not bytes).
    const MAX_DESCRIPTION_CHARS: usize = 500;

    /// Creates a new service backed by the given repository.
    pub fn new(repository: Arc<AiClassificationRepository>) -> Self {
        Self { repository }
    }

    /// Creates an AI classification.
    ///
    /// # Errors
    /// - `BusinessError::InvalidInput` if the request fails validation.
    /// - `BusinessError::DuplicateName` if a classification with the same name exists.
    /// - Any error propagated from the repository.
    pub async fn create_classification(&self, request: CreateAiClassificationRequest) -> Result<AiClassification> {
        self.validate_create_request(&request)?;

        // NOTE(review): the existence check and the insert are separate repository
        // calls, so two concurrent creates with the same name could both pass this
        // check unless the database enforces uniqueness — confirm.
        if self.repository.exists_by_name(&request.name, None).await? {
            return Err(BusinessError::DuplicateName(request.name).into());
        }

        let classification = self.repository.create(request).await?;
        Ok(classification)
    }

    /// Lists AI classifications; `query` is forwarded unchanged to the repository.
    pub async fn get_classifications(&self, query: Option<AiClassificationQuery>) -> Result<Vec<AiClassification>> {
        self.repository.get_all(query).await
    }

    /// Fetches a classification by ID (`Ok(None)` presumably means "not found").
    ///
    /// # Errors
    /// `BusinessError::InvalidInput` if `id` is empty or whitespace-only.
    pub async fn get_classification_by_id(&self, id: &str) -> Result<Option<AiClassification>> {
        Self::validate_id(id)?;
        self.repository.get_by_id(id).await
    }

    /// Updates a classification with the fields set to `Some` in `request`.
    ///
    /// When the name changes, it is checked for collisions via
    /// `exists_by_name(name, Some(id))`. The second argument presumably excludes
    /// the row being updated, so renaming to its own current name is allowed.
    ///
    /// # Errors
    /// - `BusinessError::InvalidInput` if `id` is blank or a provided field is invalid.
    /// - `BusinessError::DuplicateName` if the new name is already taken.
    /// - Any error propagated from the repository.
    pub async fn update_classification(&self, id: &str, request: UpdateAiClassificationRequest) -> Result<Option<AiClassification>> {
        // Reject blank IDs up front, consistent with get/delete.
        Self::validate_id(id)?;
        self.validate_update_request(&request)?;

        if let Some(ref name) = request.name {
            if self.repository.exists_by_name(name, Some(id)).await? {
                return Err(BusinessError::DuplicateName(name.clone()).into());
            }
        }

        self.repository.update(id, request).await
    }

    /// Deletes a classification.
    ///
    /// The returned flag comes straight from the repository (presumably `true`
    /// when a row was removed — confirm).
    ///
    /// # Errors
    /// `BusinessError::InvalidInput` if `id` is empty or whitespace-only.
    pub async fn delete_classification(&self, id: &str) -> Result<bool> {
        Self::validate_id(id)?;
        self.repository.delete(id).await
    }

    /// Counts classifications; `active_only` is forwarded to the repository.
    pub async fn get_classification_count(&self, active_only: Option<bool>) -> Result<i64> {
        self.repository.count(active_only).await
    }

    /// Builds a preview of the active classifications.
    ///
    /// Classifications are sorted by ascending `sort_order` and returned
    /// together with the full prompt generated from them.
    pub async fn generate_preview(&self) -> Result<AiClassificationPreview> {
        let classifications = self.get_classifications(Some(AiClassificationQuery {
            active_only: Some(true),
            sort_by: Some("sort_order".to_string()),
            sort_order: Some("ASC".to_string()),
            ..Default::default()
        })).await?;

        let full_prompt = generate_full_prompt(&classifications);
        Ok(AiClassificationPreview {
            classifications,
            full_prompt,
        })
    }

    /// Sets `sort_order` for each `(id, sort_order)` pair, one update at a time.
    ///
    /// Returns the classifications the repository reported as updated. IDs for
    /// which `update` yields `None` are skipped silently. No transaction is
    /// opened here, so an error part-way through presumably leaves the earlier
    /// updates applied.
    pub async fn update_sort_orders(&self, updates: Vec<(String, i32)>) -> Result<Vec<AiClassification>> {
        let mut results = Vec::with_capacity(updates.len());
        for (id, sort_order) in updates {
            let request = UpdateAiClassificationRequest {
                name: None,
                prompt_text: None,
                description: None,
                is_active: None,
                sort_order: Some(sort_order),
            };
            if let Some(classification) = self.repository.update(&id, request).await? {
                results.push(classification);
            }
        }
        Ok(results)
    }

    /// Flips the `is_active` flag of a classification.
    ///
    /// Returns `Ok(None)` when `get_by_id` finds no classification for `id`.
    /// NOTE(review): the read and the write are separate repository calls, so
    /// concurrent toggles of the same row may race.
    ///
    /// # Errors
    /// `BusinessError::InvalidInput` if `id` is empty or whitespace-only.
    pub async fn toggle_classification_status(&self, id: &str) -> Result<Option<AiClassification>> {
        Self::validate_id(id)?;

        match self.repository.get_by_id(id).await? {
            Some(classification) => {
                let request = UpdateAiClassificationRequest {
                    name: None,
                    prompt_text: None,
                    description: None,
                    is_active: Some(!classification.is_active),
                    sort_order: None,
                };
                self.repository.update(id, request).await
            }
            None => Ok(None),
        }
    }

    /// Validates every field of a create request.
    fn validate_create_request(&self, request: &CreateAiClassificationRequest) -> Result<()> {
        Self::validate_name(&request.name)?;
        Self::validate_prompt_text(&request.prompt_text)?;
        if let Some(ref description) = request.description {
            Self::validate_description(description)?;
        }
        Ok(())
    }

    /// Validates only the fields present (`Some`) in an update request.
    fn validate_update_request(&self, request: &UpdateAiClassificationRequest) -> Result<()> {
        if let Some(ref name) = request.name {
            Self::validate_name(name)?;
        }
        if let Some(ref prompt_text) = request.prompt_text {
            Self::validate_prompt_text(prompt_text)?;
        }
        if let Some(ref description) = request.description {
            Self::validate_description(description)?;
        }
        Ok(())
    }

    /// Rejects empty or whitespace-only IDs.
    fn validate_id(id: &str) -> Result<()> {
        Self::require_non_blank(id, "ID")
    }

    /// Name must be non-blank and at most `MAX_NAME_CHARS` characters.
    fn validate_name(name: &str) -> Result<()> {
        Self::require_non_blank(name, "分类名称")?;
        Self::require_max_chars(name, Self::MAX_NAME_CHARS, "分类名称")
    }

    /// Prompt text must be non-blank and at most `MAX_PROMPT_CHARS` characters.
    fn validate_prompt_text(prompt_text: &str) -> Result<()> {
        Self::require_non_blank(prompt_text, "提示词")?;
        Self::require_max_chars(prompt_text, Self::MAX_PROMPT_CHARS, "提示词")
    }

    /// Description may be empty but is at most `MAX_DESCRIPTION_CHARS` characters.
    fn validate_description(description: &str) -> Result<()> {
        Self::require_max_chars(description, Self::MAX_DESCRIPTION_CHARS, "描述")
    }

    /// Fails with `InvalidInput("<field>不能为空")` when `value` is blank after trimming.
    fn require_non_blank(value: &str, field: &str) -> Result<()> {
        if value.trim().is_empty() {
            return Err(BusinessError::InvalidInput(format!("{}不能为空", field)).into());
        }
        Ok(())
    }

    /// Fails with `InvalidInput("<field>不能超过<max>个字符")` when `value` is too long.
    fn require_max_chars(value: &str, max_chars: usize, field: &str) -> Result<()> {
        // Count characters, not bytes: `str::len` is the UTF-8 byte length, and
        // each CJK character is 3 bytes. A byte-based check would therefore reject
        // Chinese text far shorter than the advertised character limit.
        if value.chars().count() > max_chars {
            return Err(BusinessError::InvalidInput(format!("{}不能超过{}个字符", field, max_chars)).into());
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::database::Database;
    use std::sync::Arc;

    /// Builds a service whose repository sits on a new `:memory:` database.
    async fn create_test_service() -> AiClassificationService {
        let db = Database::new_with_path(":memory:").unwrap();
        let repo = AiClassificationRepository::new(Arc::new(db));
        AiClassificationService::new(Arc::new(repo))
    }

    #[tokio::test]
    async fn test_create_classification() {
        let service = create_test_service().await;

        let created = service
            .create_classification(CreateAiClassificationRequest {
                name: "全身".to_string(),
                prompt_text: "头顶到脚底完整入镜肢体可见度≥90%".to_string(),
                description: Some("全身分类描述".to_string()),
                sort_order: Some(1),
            })
            .await
            .expect("creating a valid classification should succeed");

        assert_eq!(created.name, "全身");
        // A freshly created classification is expected to be active.
        assert!(created.is_active);
    }

    #[tokio::test]
    async fn test_duplicate_name_validation() {
        let service = create_test_service().await;
        let original = CreateAiClassificationRequest {
            name: "全身".to_string(),
            prompt_text: "头顶到脚底完整入镜".to_string(),
            description: None,
            sort_order: None,
        };

        // The first insert claims the name ...
        service
            .create_classification(original.clone())
            .await
            .expect("first create should succeed");

        // ... so inserting the same name again must be rejected.
        let duplicate = service.create_classification(original).await;
        assert!(duplicate.is_err());
    }

    #[tokio::test]
    async fn test_generate_preview() {
        let service = create_test_service().await;

        // (name, prompt_text, sort_order) for each seeded classification.
        let seeds = [
            ("全身", "头顶到脚底完整入镜", 1),
            ("上半身", "头部到腰部", 2),
        ];
        for &(name, prompt_text, order) in &seeds {
            let seeded = CreateAiClassificationRequest {
                name: name.to_string(),
                prompt_text: prompt_text.to_string(),
                description: None,
                sort_order: Some(order),
            };
            service.create_classification(seeded).await.unwrap();
        }

        let preview = service.generate_preview().await.unwrap();

        assert_eq!(preview.classifications.len(), seeds.len());
        for &(name, _, _) in &seeds {
            assert!(preview.full_prompt.contains(name));
        }
    }
}