From 67ce7104e34f9c44b680abc22270eeb42ea46911 Mon Sep 17 00:00:00 2001 From: imeepos Date: Tue, 15 Jul 2025 17:45:08 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=B0=86=E6=89=80=E6=9C=89=E4=BB=93?= =?UTF-8?q?=E5=BA=93=E7=BB=9F=E4=B8=80=E4=BD=BF=E7=94=A8=20database:=20Arc?= =?UTF-8?q?=20=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/business/services/material_service.rs | 9 +++++---- .../commands/material_commands.rs | 14 +++----------- .../commands/material_matching_commands.rs | 4 ++-- .../material_segment_view_commands.rs | 8 ++++---- .../project_template_binding_commands.rs | 4 ++-- .../commands/video_classification_commands.rs | 19 ++++++------------- 6 files changed, 22 insertions(+), 36 deletions(-) diff --git a/apps/desktop/src-tauri/src/business/services/material_service.rs b/apps/desktop/src-tauri/src/business/services/material_service.rs index 737e4d8..2f19584 100644 --- a/apps/desktop/src-tauri/src/business/services/material_service.rs +++ b/apps/desktop/src-tauri/src/business/services/material_service.rs @@ -2,6 +2,7 @@ use anyhow::{Result, anyhow}; use std::path::Path; use std::fs; use std::time::Instant; +use std::sync::Arc; use crate::data::models::material::{ Material, MaterialType, ProcessingStatus, CreateMaterialRequest, @@ -699,8 +700,8 @@ impl MaterialService { use crate::business::services::project_service::ProjectService; // 创建数据库连接 - let db = Database::new()?; - let project_repo = ProjectRepository::new(db.get_connection())?; + let db = Arc::new(Database::new()?); + let project_repo = ProjectRepository::new(db.clone())?; // 获取项目信息 ProjectService::get_project_by_id(&project_repo, project_id)? @@ -869,8 +870,8 @@ impl MaterialService { use crate::infrastructure::database::Database; use crate::data::repositories::model_repository::ModelRepository; - let db = Database::new()?; - let model_repo = ModelRepository::new(db.get_connection()); + let db = Arc::new(Database::new()?); + let model_repo = ModelRepository::new(db.clone()); if model_repo.get_by_id(model_id)?.is_none() { return Err(anyhow!("找不到模特: {}", model_id)); diff --git a/apps/desktop/src-tauri/src/presentation/commands/material_commands.rs b/apps/desktop/src-tauri/src/presentation/commands/material_commands.rs index 51ba449..b88fce8 100644 --- a/apps/desktop/src-tauri/src/presentation/commands/material_commands.rs +++ b/apps/desktop/src-tauri/src/presentation/commands/material_commands.rs @@ -44,16 +44,8 @@ pub async fn import_materials_async( ) -> Result { // 添加调试日志 println!("Import request received with model_id: {:?}", request.model_id); - // 获取数据库连接,避免持有MutexGuard - let connection = { - let repository_guard = state.get_material_repository() - .map_err(|e| format!("获取素材仓库失败: {}", e))?; - - let repository = repository_guard.as_ref() - .ok_or("素材仓库未初始化")?; - - repository.get_connection() - }; // repository_guard在这里被释放 + // 获取数据库实例 + let database = state.get_database(); let mut config = MaterialProcessingConfig::default(); config.auto_process = Some(request.auto_process); @@ -62,7 +54,7 @@ pub async fn import_materials_async( } // 创建一个新的MaterialRepository实例用于异步操作 - let async_repository = MaterialRepository::new(connection) + let async_repository = MaterialRepository::new(database) .map_err(|e| format!("创建异步仓库失败: {}", e))?; let repository_arc = Arc::new(async_repository); diff --git a/apps/desktop/src-tauri/src/presentation/commands/material_matching_commands.rs b/apps/desktop/src-tauri/src/presentation/commands/material_matching_commands.rs index 933e559..adec077 100644 --- a/apps/desktop/src-tauri/src/presentation/commands/material_matching_commands.rs +++ b/apps/desktop/src-tauri/src/presentation/commands/material_matching_commands.rs @@ -24,7 +24,7 @@ pub async fn execute_material_matching( ) -> Result { // 创建服务实例 let material_repo = Arc::new( - MaterialRepository::new(database.get_connection()) + MaterialRepository::new(database.inner().clone()) .map_err(|e| format!("创建素材仓库失败: {}", e))? ); @@ -52,7 +52,7 @@ pub async fn get_project_material_stats_for_matching( project_id: String, database: State<'_, Arc>, ) -> Result { - let material_repo = MaterialRepository::new(database.get_connection()) + let material_repo = MaterialRepository::new(database.inner().clone()) .map_err(|e| format!("创建素材仓库失败: {}", e))?; let video_classification_repo = VideoClassificationRepository::new(database.inner().clone()); diff --git a/apps/desktop/src-tauri/src/presentation/commands/material_segment_view_commands.rs b/apps/desktop/src-tauri/src/presentation/commands/material_segment_view_commands.rs index 5d16979..12bf26d 100644 --- a/apps/desktop/src-tauri/src/presentation/commands/material_segment_view_commands.rs +++ b/apps/desktop/src-tauri/src/presentation/commands/material_segment_view_commands.rs @@ -17,7 +17,7 @@ pub async fn get_project_segment_view( // 创建仓库实例 let material_repository = Arc::new( crate::data::repositories::material_repository::MaterialRepository::new( - database.get_connection() + database.clone() ).map_err(|e| format!("创建素材仓库失败: {}", e))? ); @@ -29,7 +29,7 @@ pub async fn get_project_segment_view( let model_repository = Arc::new( crate::data::repositories::model_repository::ModelRepository::new( - database.get_connection() + database.clone() ) ); @@ -58,7 +58,7 @@ pub async fn get_project_segment_view_with_query( // 创建仓库实例 let material_repository = Arc::new( crate::data::repositories::material_repository::MaterialRepository::new( - database.get_connection() + database.clone() ).map_err(|e| format!("创建素材仓库失败: {}", e))? ); @@ -70,7 +70,7 @@ pub async fn get_project_segment_view_with_query( let model_repository = Arc::new( crate::data::repositories::model_repository::ModelRepository::new( - database.get_connection() + database.clone() ) ); diff --git a/apps/desktop/src-tauri/src/presentation/commands/project_template_binding_commands.rs b/apps/desktop/src-tauri/src/presentation/commands/project_template_binding_commands.rs index aa7239e..9e886c3 100644 --- a/apps/desktop/src-tauri/src/presentation/commands/project_template_binding_commands.rs +++ b/apps/desktop/src-tauri/src/presentation/commands/project_template_binding_commands.rs @@ -9,8 +9,8 @@ use crate::data::models::project_template_binding::{ /// 辅助函数:创建服务实例 fn create_service(state: &State<'_, AppState>) -> Result { - let connection = state.get_connection().map_err(|e| e.to_string())?; - ProjectTemplateBindingService::new(connection).map_err(|e| e.to_string()) + let database = state.get_database(); + ProjectTemplateBindingService::new(database).map_err(|e| e.to_string()) } /// 创建项目-模板绑定 diff --git a/apps/desktop/src-tauri/src/presentation/commands/video_classification_commands.rs b/apps/desktop/src-tauri/src/presentation/commands/video_classification_commands.rs index 268eabc..93fdb77 100644 --- a/apps/desktop/src-tauri/src/presentation/commands/video_classification_commands.rs +++ b/apps/desktop/src-tauri/src/presentation/commands/video_classification_commands.rs @@ -22,7 +22,7 @@ async fn get_queue_instance(state: &AppState) -> Arc { let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); - let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); + let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap()); let gemini_config = Some(GeminiConfig::default()); let service = Arc::new(VideoClassificationService::new( @@ -151,7 +151,7 @@ pub async fn recover_stuck_classification_tasks( let database = state.get_database(); let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); - let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); + let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap()); let service = VideoClassificationService::new( video_repo, @@ -192,7 +192,7 @@ pub async fn get_material_classification_records( let database = state.get_database(); let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); - let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); + let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap()); let service = VideoClassificationService::new( video_repo, @@ -215,7 +215,7 @@ pub async fn get_classification_statistics( let database = state.get_database(); let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); - let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); + let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap()); let service = VideoClassificationService::new( video_repo, @@ -252,7 +252,7 @@ pub async fn cancel_classification_task( let database = state.get_database(); let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); - let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); + let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap()); let service = VideoClassificationService::new( video_repo, @@ -275,7 +275,7 @@ pub async fn retry_classification_task( let database = state.get_database(); let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); - let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); + let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap()); let service = VideoClassificationService::new( video_repo, @@ -320,11 +320,4 @@ mod tests { // 应该返回同一个实例 assert!(Arc::ptr_eq(&queue1, &queue2)); } - - #[tokio::test] - async fn test_classification_queue_status() { - let state = create_test_state().await; - let result = get_classification_queue_status(tauri::State::from(&state)).await; - assert!(result.is_ok()); - } }