use crate::data::models::ai_classification::{ AiClassification, CreateAiClassificationRequest, UpdateAiClassificationRequest, AiClassificationQuery, AiClassificationPreview, generate_full_prompt }; use crate::data::repositories::ai_classification_repository::AiClassificationRepository; use crate::business::errors::BusinessError; use anyhow::Result; use std::sync::Arc; /// AI分类业务服务 /// 遵循 Tauri 开发规范的业务逻辑层设计原则 pub struct AiClassificationService { repository: Arc, } impl AiClassificationService { /// 创建新的AI分类服务实例 pub fn new(repository: Arc) -> Self { Self { repository } } /// 创建AI分类 pub async fn create_classification(&self, request: CreateAiClassificationRequest) -> Result { // 验证输入数据 self.validate_create_request(&request)?; // 检查名称是否已存在 if self.repository.exists_by_name(&request.name, None).await? { return Err(BusinessError::DuplicateName(request.name).into()); } // 创建分类 let classification = self.repository.create(request).await?; Ok(classification) } /// 获取AI分类列表 pub async fn get_classifications(&self, query: Option) -> Result> { self.repository.get_all(query).await } /// 根据ID获取AI分类 pub async fn get_classification_by_id(&self, id: &str) -> Result> { if id.trim().is_empty() { return Err(BusinessError::InvalidInput("ID不能为空".to_string()).into()); } self.repository.get_by_id(id).await } /// 更新AI分类 pub async fn update_classification(&self, id: &str, request: UpdateAiClassificationRequest) -> Result> { // 验证输入数据 self.validate_update_request(&request)?; // 如果更新名称,检查是否与其他分类重复 if let Some(ref name) = request.name { if self.repository.exists_by_name(name, Some(id)).await? { return Err(BusinessError::DuplicateName(name.clone()).into()); } } // 更新分类 self.repository.update(id, request).await } /// 删除AI分类 pub async fn delete_classification(&self, id: &str) -> Result { if id.trim().is_empty() { return Err(BusinessError::InvalidInput("ID不能为空".to_string()).into()); } self.repository.delete(id).await } /// 获取分类总数 pub async fn get_classification_count(&self, active_only: Option) -> Result { self.repository.count(active_only).await } /// 生成AI分类预览 pub async fn generate_preview(&self) -> Result { let classifications = self.get_classifications(Some(AiClassificationQuery { active_only: Some(true), sort_by: Some("sort_order".to_string()), sort_order: Some("ASC".to_string()), ..Default::default() })).await?; let full_prompt = generate_full_prompt(&classifications); Ok(AiClassificationPreview { classifications, full_prompt, }) } /// 批量更新分类排序 pub async fn update_sort_orders(&self, updates: Vec<(String, i32)>) -> Result> { let mut results = Vec::new(); for (id, sort_order) in updates { let request = UpdateAiClassificationRequest { name: None, prompt_text: None, description: None, is_active: None, sort_order: Some(sort_order), }; if let Some(classification) = self.repository.update(&id, request).await? { results.push(classification); } } Ok(results) } /// 切换分类激活状态 pub async fn toggle_classification_status(&self, id: &str) -> Result> { // 先获取当前状态 if let Some(classification) = self.repository.get_by_id(id).await? { let request = UpdateAiClassificationRequest { name: None, prompt_text: None, description: None, is_active: Some(!classification.is_active), sort_order: None, }; self.repository.update(id, request).await } else { Ok(None) } } /// 验证创建请求 fn validate_create_request(&self, request: &CreateAiClassificationRequest) -> Result<()> { if request.name.trim().is_empty() { return Err(BusinessError::InvalidInput("分类名称不能为空".to_string()).into()); } if request.name.len() > 100 { return Err(BusinessError::InvalidInput("分类名称不能超过100个字符".to_string()).into()); } if request.prompt_text.trim().is_empty() { return Err(BusinessError::InvalidInput("提示词不能为空".to_string()).into()); } if request.prompt_text.len() > 1000 { return Err(BusinessError::InvalidInput("提示词不能超过1000个字符".to_string()).into()); } if let Some(ref description) = request.description { if description.len() > 500 { return Err(BusinessError::InvalidInput("描述不能超过500个字符".to_string()).into()); } } Ok(()) } /// 验证更新请求 fn validate_update_request(&self, request: &UpdateAiClassificationRequest) -> Result<()> { if let Some(ref name) = request.name { if name.trim().is_empty() { return Err(BusinessError::InvalidInput("分类名称不能为空".to_string()).into()); } if name.len() > 100 { return Err(BusinessError::InvalidInput("分类名称不能超过100个字符".to_string()).into()); } } if let Some(ref prompt_text) = request.prompt_text { if prompt_text.trim().is_empty() { return Err(BusinessError::InvalidInput("提示词不能为空".to_string()).into()); } if prompt_text.len() > 1000 { return Err(BusinessError::InvalidInput("提示词不能超过1000个字符".to_string()).into()); } } if let Some(ref description) = request.description { if description.len() > 500 { return Err(BusinessError::InvalidInput("描述不能超过500个字符".to_string()).into()); } } Ok(()) } } #[cfg(test)] mod tests { use super::*; use crate::infrastructure::database::Database; use std::sync::Arc; async fn create_test_service() -> AiClassificationService { let database = Arc::new(Database::new_with_path(":memory:").unwrap()); let repository = Arc::new(AiClassificationRepository::new(database)); AiClassificationService::new(repository) } #[tokio::test] async fn test_create_classification() { let service = create_test_service().await; let request = CreateAiClassificationRequest { name: "全身".to_string(), prompt_text: "头顶到脚底完整入镜,肢体可见度≥90%".to_string(), description: Some("全身分类描述".to_string()), sort_order: Some(1), }; let result = service.create_classification(request).await; assert!(result.is_ok()); let classification = result.unwrap(); assert_eq!(classification.name, "全身"); assert!(classification.is_active); } #[tokio::test] async fn test_duplicate_name_validation() { let service = create_test_service().await; let request = CreateAiClassificationRequest { name: "全身".to_string(), prompt_text: "头顶到脚底完整入镜".to_string(), description: None, sort_order: None, }; // 第一次创建应该成功 let result1 = service.create_classification(request.clone()).await; assert!(result1.is_ok()); // 第二次创建相同名称应该失败 let result2 = service.create_classification(request).await; assert!(result2.is_err()); } #[tokio::test] async fn test_generate_preview() { let service = create_test_service().await; // 创建几个测试分类 let requests = vec![ CreateAiClassificationRequest { name: "全身".to_string(), prompt_text: "头顶到脚底完整入镜".to_string(), description: None, sort_order: Some(1), }, CreateAiClassificationRequest { name: "上半身".to_string(), prompt_text: "头部到腰部".to_string(), description: None, sort_order: Some(2), }, ]; for request in requests { service.create_classification(request).await.unwrap(); } let preview = service.generate_preview().await.unwrap(); assert_eq!(preview.classifications.len(), 2); assert!(preview.full_prompt.contains("全身")); assert!(preview.full_prompt.contains("上半身")); } }