276 lines
9.4 KiB
Rust
276 lines
9.4 KiB
Rust
use crate::data::models::ai_classification::{
|
||
AiClassification, CreateAiClassificationRequest, UpdateAiClassificationRequest,
|
||
AiClassificationQuery, AiClassificationPreview, generate_full_prompt
|
||
};
|
||
use crate::data::repositories::ai_classification_repository::AiClassificationRepository;
|
||
use crate::business::errors::BusinessError;
|
||
use anyhow::Result;
|
||
use std::sync::Arc;
|
||
|
||
/// AI分类业务服务
|
||
/// 遵循 Tauri 开发规范的业务逻辑层设计原则
|
||
pub struct AiClassificationService {
|
||
repository: Arc<AiClassificationRepository>,
|
||
}
|
||
|
||
impl AiClassificationService {
|
||
/// 创建新的AI分类服务实例
|
||
pub fn new(repository: Arc<AiClassificationRepository>) -> Self {
|
||
Self { repository }
|
||
}
|
||
|
||
/// 创建AI分类
|
||
pub async fn create_classification(&self, request: CreateAiClassificationRequest) -> Result<AiClassification> {
|
||
// 验证输入数据
|
||
self.validate_create_request(&request)?;
|
||
|
||
// 检查名称是否已存在
|
||
if self.repository.exists_by_name(&request.name, None).await? {
|
||
return Err(BusinessError::DuplicateName(request.name).into());
|
||
}
|
||
|
||
// 创建分类
|
||
let classification = self.repository.create(request).await?;
|
||
|
||
Ok(classification)
|
||
}
|
||
|
||
/// 获取AI分类列表
|
||
pub async fn get_classifications(&self, query: Option<AiClassificationQuery>) -> Result<Vec<AiClassification>> {
|
||
self.repository.get_all(query).await
|
||
}
|
||
|
||
/// 根据ID获取AI分类
|
||
pub async fn get_classification_by_id(&self, id: &str) -> Result<Option<AiClassification>> {
|
||
if id.trim().is_empty() {
|
||
return Err(BusinessError::InvalidInput("ID不能为空".to_string()).into());
|
||
}
|
||
|
||
self.repository.get_by_id(id).await
|
||
}
|
||
|
||
/// 更新AI分类
|
||
pub async fn update_classification(&self, id: &str, request: UpdateAiClassificationRequest) -> Result<Option<AiClassification>> {
|
||
// 验证输入数据
|
||
self.validate_update_request(&request)?;
|
||
|
||
// 如果更新名称,检查是否与其他分类重复
|
||
if let Some(ref name) = request.name {
|
||
if self.repository.exists_by_name(name, Some(id)).await? {
|
||
return Err(BusinessError::DuplicateName(name.clone()).into());
|
||
}
|
||
}
|
||
|
||
// 更新分类
|
||
self.repository.update(id, request).await
|
||
}
|
||
|
||
/// 删除AI分类
|
||
pub async fn delete_classification(&self, id: &str) -> Result<bool> {
|
||
if id.trim().is_empty() {
|
||
return Err(BusinessError::InvalidInput("ID不能为空".to_string()).into());
|
||
}
|
||
|
||
self.repository.delete(id).await
|
||
}
|
||
|
||
/// 获取分类总数
|
||
pub async fn get_classification_count(&self, active_only: Option<bool>) -> Result<i64> {
|
||
self.repository.count(active_only).await
|
||
}
|
||
|
||
/// 生成AI分类预览
|
||
pub async fn generate_preview(&self) -> Result<AiClassificationPreview> {
|
||
let classifications = self.get_classifications(Some(AiClassificationQuery {
|
||
active_only: Some(true),
|
||
sort_by: Some("sort_order".to_string()),
|
||
sort_order: Some("ASC".to_string()),
|
||
..Default::default()
|
||
})).await?;
|
||
|
||
let full_prompt = generate_full_prompt(&classifications);
|
||
|
||
Ok(AiClassificationPreview {
|
||
classifications,
|
||
full_prompt,
|
||
})
|
||
}
|
||
|
||
/// 批量更新分类排序
|
||
pub async fn update_sort_orders(&self, updates: Vec<(String, i32)>) -> Result<Vec<AiClassification>> {
|
||
let mut results = Vec::new();
|
||
|
||
for (id, sort_order) in updates {
|
||
let request = UpdateAiClassificationRequest {
|
||
name: None,
|
||
prompt_text: None,
|
||
description: None,
|
||
is_active: None,
|
||
sort_order: Some(sort_order),
|
||
};
|
||
|
||
if let Some(classification) = self.repository.update(&id, request).await? {
|
||
results.push(classification);
|
||
}
|
||
}
|
||
|
||
Ok(results)
|
||
}
|
||
|
||
/// 切换分类激活状态
|
||
pub async fn toggle_classification_status(&self, id: &str) -> Result<Option<AiClassification>> {
|
||
// 先获取当前状态
|
||
if let Some(classification) = self.repository.get_by_id(id).await? {
|
||
let request = UpdateAiClassificationRequest {
|
||
name: None,
|
||
prompt_text: None,
|
||
description: None,
|
||
is_active: Some(!classification.is_active),
|
||
sort_order: None,
|
||
};
|
||
|
||
self.repository.update(id, request).await
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// 验证创建请求
|
||
fn validate_create_request(&self, request: &CreateAiClassificationRequest) -> Result<()> {
|
||
if request.name.trim().is_empty() {
|
||
return Err(BusinessError::InvalidInput("分类名称不能为空".to_string()).into());
|
||
}
|
||
|
||
if request.name.len() > 100 {
|
||
return Err(BusinessError::InvalidInput("分类名称不能超过100个字符".to_string()).into());
|
||
}
|
||
|
||
if request.prompt_text.trim().is_empty() {
|
||
return Err(BusinessError::InvalidInput("提示词不能为空".to_string()).into());
|
||
}
|
||
|
||
if request.prompt_text.len() > 1000 {
|
||
return Err(BusinessError::InvalidInput("提示词不能超过1000个字符".to_string()).into());
|
||
}
|
||
|
||
if let Some(ref description) = request.description {
|
||
if description.len() > 500 {
|
||
return Err(BusinessError::InvalidInput("描述不能超过500个字符".to_string()).into());
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 验证更新请求
|
||
fn validate_update_request(&self, request: &UpdateAiClassificationRequest) -> Result<()> {
|
||
if let Some(ref name) = request.name {
|
||
if name.trim().is_empty() {
|
||
return Err(BusinessError::InvalidInput("分类名称不能为空".to_string()).into());
|
||
}
|
||
if name.len() > 100 {
|
||
return Err(BusinessError::InvalidInput("分类名称不能超过100个字符".to_string()).into());
|
||
}
|
||
}
|
||
|
||
if let Some(ref prompt_text) = request.prompt_text {
|
||
if prompt_text.trim().is_empty() {
|
||
return Err(BusinessError::InvalidInput("提示词不能为空".to_string()).into());
|
||
}
|
||
if prompt_text.len() > 1000 {
|
||
return Err(BusinessError::InvalidInput("提示词不能超过1000个字符".to_string()).into());
|
||
}
|
||
}
|
||
|
||
if let Some(ref description) = request.description {
|
||
if description.len() > 500 {
|
||
return Err(BusinessError::InvalidInput("描述不能超过500个字符".to_string()).into());
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::database::Database;
    use std::sync::Arc;

    /// Builds a service over a `:memory:` database.
    ///
    /// Assumes each call yields a fresh, empty store; confirm against `Database`.
    fn create_test_service() -> AiClassificationService {
        let database = Arc::new(
            Database::new_with_path(":memory:").expect("failed to open in-memory test database"),
        );
        let repository = Arc::new(AiClassificationRepository::new(database));
        AiClassificationService::new(repository)
    }

    /// Asserts that `result` failed specifically with [`BusinessError::InvalidInput`].
    fn assert_invalid_input<T>(result: Result<T>) {
        match result {
            Ok(_) => panic!("expected BusinessError::InvalidInput, got Ok"),
            Err(err) => assert!(
                matches!(
                    err.downcast_ref::<BusinessError>(),
                    Some(BusinessError::InvalidInput(_))
                ),
                "expected BusinessError::InvalidInput, got: {:?}",
                err
            ),
        }
    }

    #[tokio::test]
    async fn test_create_classification() {
        let service = create_test_service();

        let request = CreateAiClassificationRequest {
            name: "全身".to_string(),
            prompt_text: "头顶到脚底完整入镜,肢体可见度≥90%".to_string(),
            description: Some("全身分类描述".to_string()),
            sort_order: Some(1),
        };

        let classification = service
            .create_classification(request)
            .await
            .expect("creating a valid classification should succeed");

        assert_eq!(classification.name, "全身");
        assert!(classification.is_active, "new classifications are expected to be active");
    }

    #[tokio::test]
    async fn test_duplicate_name_validation() {
        let service = create_test_service();

        let request = CreateAiClassificationRequest {
            name: "全身".to_string(),
            prompt_text: "头顶到脚底完整入镜".to_string(),
            description: None,
            sort_order: None,
        };

        // The first create must succeed...
        service
            .create_classification(request.clone())
            .await
            .expect("first create should succeed");

        // ...and a second create with the same name must fail as a *duplicate*, not as
        // some unrelated error (e.g. a database failure).
        match service.create_classification(request).await {
            Ok(_) => panic!("creating a duplicate name should fail"),
            Err(err) => assert!(
                matches!(
                    err.downcast_ref::<BusinessError>(),
                    Some(BusinessError::DuplicateName(name)) if name.as_str() == "全身"
                ),
                "expected BusinessError::DuplicateName(\"全身\"), got: {:?}",
                err
            ),
        }
    }

    #[tokio::test]
    async fn test_blank_id_is_rejected() {
        let service = create_test_service();

        assert_invalid_input(service.get_classification_by_id("").await);
        assert_invalid_input(service.get_classification_by_id("   ").await);
        assert_invalid_input(service.delete_classification("").await);
    }

    #[tokio::test]
    async fn test_generate_preview() {
        let service = create_test_service();

        // Create two classifications with distinct sort orders.
        let requests = vec![
            CreateAiClassificationRequest {
                name: "全身".to_string(),
                prompt_text: "头顶到脚底完整入镜".to_string(),
                description: None,
                sort_order: Some(1),
            },
            CreateAiClassificationRequest {
                name: "上半身".to_string(),
                prompt_text: "头部到腰部".to_string(),
                description: None,
                sort_order: Some(2),
            },
        ];

        for request in requests {
            service
                .create_classification(request)
                .await
                .expect("seeding a classification should succeed");
        }

        let preview = service
            .generate_preview()
            .await
            .expect("building the preview should succeed");

        // `generate_preview` requests ascending `sort_order`, so 全身 (1) precedes 上半身 (2).
        let names: Vec<&str> = preview
            .classifications
            .iter()
            .map(|classification| classification.name.as_str())
            .collect();
        assert_eq!(names, ["全身", "上半身"]);

        assert!(preview.full_prompt.contains("全身"));
        assert!(preview.full_prompt.contains("上半身"));
    }
}
|