fix: 将所有仓库统一使用 database: Arc<Database> 方式

This commit is contained in:
imeepos 2025-07-15 17:45:08 +08:00
parent 54630ea2ff
commit 67ce7104e3
6 changed files with 22 additions and 36 deletions

View File

@ -2,6 +2,7 @@ use anyhow::{Result, anyhow};
use std::path::Path; use std::path::Path;
use std::fs; use std::fs;
use std::time::Instant; use std::time::Instant;
use std::sync::Arc;
use crate::data::models::material::{ use crate::data::models::material::{
Material, MaterialType, ProcessingStatus, CreateMaterialRequest, Material, MaterialType, ProcessingStatus, CreateMaterialRequest,
@ -699,8 +700,8 @@ impl MaterialService {
use crate::business::services::project_service::ProjectService; use crate::business::services::project_service::ProjectService;
// 创建数据库连接 // 创建数据库连接
let db = Database::new()?; let db = Arc::new(Database::new()?);
let project_repo = ProjectRepository::new(db.get_connection())?; let project_repo = ProjectRepository::new(db.clone())?;
// 获取项目信息 // 获取项目信息
ProjectService::get_project_by_id(&project_repo, project_id)? ProjectService::get_project_by_id(&project_repo, project_id)?
@ -869,8 +870,8 @@ impl MaterialService {
use crate::infrastructure::database::Database; use crate::infrastructure::database::Database;
use crate::data::repositories::model_repository::ModelRepository; use crate::data::repositories::model_repository::ModelRepository;
let db = Database::new()?; let db = Arc::new(Database::new()?);
let model_repo = ModelRepository::new(db.get_connection()); let model_repo = ModelRepository::new(db.clone());
if model_repo.get_by_id(model_id)?.is_none() { if model_repo.get_by_id(model_id)?.is_none() {
return Err(anyhow!("找不到模特: {}", model_id)); return Err(anyhow!("找不到模特: {}", model_id));

View File

@ -44,16 +44,8 @@ pub async fn import_materials_async(
) -> Result<MaterialImportResult, String> { ) -> Result<MaterialImportResult, String> {
// 添加调试日志 // 添加调试日志
println!("Import request received with model_id: {:?}", request.model_id); println!("Import request received with model_id: {:?}", request.model_id);
// 获取数据库连接避免持有MutexGuard // 获取数据库实例
let connection = { let database = state.get_database();
let repository_guard = state.get_material_repository()
.map_err(|e| format!("获取素材仓库失败: {}", e))?;
let repository = repository_guard.as_ref()
.ok_or("素材仓库未初始化")?;
repository.get_connection()
}; // repository_guard在这里被释放
let mut config = MaterialProcessingConfig::default(); let mut config = MaterialProcessingConfig::default();
config.auto_process = Some(request.auto_process); config.auto_process = Some(request.auto_process);
@ -62,7 +54,7 @@ pub async fn import_materials_async(
} }
// 创建一个新的MaterialRepository实例用于异步操作 // 创建一个新的MaterialRepository实例用于异步操作
let async_repository = MaterialRepository::new(connection) let async_repository = MaterialRepository::new(database)
.map_err(|e| format!("创建异步仓库失败: {}", e))?; .map_err(|e| format!("创建异步仓库失败: {}", e))?;
let repository_arc = Arc::new(async_repository); let repository_arc = Arc::new(async_repository);

View File

@ -24,7 +24,7 @@ pub async fn execute_material_matching(
) -> Result<MaterialMatchingResult, String> { ) -> Result<MaterialMatchingResult, String> {
// 创建服务实例 // 创建服务实例
let material_repo = Arc::new( let material_repo = Arc::new(
MaterialRepository::new(database.get_connection()) MaterialRepository::new(database.inner().clone())
.map_err(|e| format!("创建素材仓库失败: {}", e))? .map_err(|e| format!("创建素材仓库失败: {}", e))?
); );
@ -52,7 +52,7 @@ pub async fn get_project_material_stats_for_matching(
project_id: String, project_id: String,
database: State<'_, Arc<Database>>, database: State<'_, Arc<Database>>,
) -> Result<ProjectMaterialMatchingStats, String> { ) -> Result<ProjectMaterialMatchingStats, String> {
let material_repo = MaterialRepository::new(database.get_connection()) let material_repo = MaterialRepository::new(database.inner().clone())
.map_err(|e| format!("创建素材仓库失败: {}", e))?; .map_err(|e| format!("创建素材仓库失败: {}", e))?;
let video_classification_repo = VideoClassificationRepository::new(database.inner().clone()); let video_classification_repo = VideoClassificationRepository::new(database.inner().clone());

View File

@ -17,7 +17,7 @@ pub async fn get_project_segment_view(
// 创建仓库实例 // 创建仓库实例
let material_repository = Arc::new( let material_repository = Arc::new(
crate::data::repositories::material_repository::MaterialRepository::new( crate::data::repositories::material_repository::MaterialRepository::new(
database.get_connection() database.clone()
).map_err(|e| format!("创建素材仓库失败: {}", e))? ).map_err(|e| format!("创建素材仓库失败: {}", e))?
); );
@ -29,7 +29,7 @@ pub async fn get_project_segment_view(
let model_repository = Arc::new( let model_repository = Arc::new(
crate::data::repositories::model_repository::ModelRepository::new( crate::data::repositories::model_repository::ModelRepository::new(
database.get_connection() database.clone()
) )
); );
@ -58,7 +58,7 @@ pub async fn get_project_segment_view_with_query(
// 创建仓库实例 // 创建仓库实例
let material_repository = Arc::new( let material_repository = Arc::new(
crate::data::repositories::material_repository::MaterialRepository::new( crate::data::repositories::material_repository::MaterialRepository::new(
database.get_connection() database.clone()
).map_err(|e| format!("创建素材仓库失败: {}", e))? ).map_err(|e| format!("创建素材仓库失败: {}", e))?
); );
@ -70,7 +70,7 @@ pub async fn get_project_segment_view_with_query(
let model_repository = Arc::new( let model_repository = Arc::new(
crate::data::repositories::model_repository::ModelRepository::new( crate::data::repositories::model_repository::ModelRepository::new(
database.get_connection() database.clone()
) )
); );

View File

@ -9,8 +9,8 @@ use crate::data::models::project_template_binding::{
/// 辅助函数:创建服务实例 /// 辅助函数:创建服务实例
fn create_service(state: &State<'_, AppState>) -> Result<ProjectTemplateBindingService, String> { fn create_service(state: &State<'_, AppState>) -> Result<ProjectTemplateBindingService, String> {
let connection = state.get_connection().map_err(|e| e.to_string())?; let database = state.get_database();
ProjectTemplateBindingService::new(connection).map_err(|e| e.to_string()) ProjectTemplateBindingService::new(database).map_err(|e| e.to_string())
} }
/// 创建项目-模板绑定 /// 创建项目-模板绑定

View File

@ -22,7 +22,7 @@ async fn get_queue_instance(state: &AppState) -> Arc<VideoClassificationQueue> {
let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let video_repo = Arc::new(VideoClassificationRepository::new(database.clone()));
let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone()));
let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap());
let gemini_config = Some(GeminiConfig::default()); let gemini_config = Some(GeminiConfig::default());
let service = Arc::new(VideoClassificationService::new( let service = Arc::new(VideoClassificationService::new(
@ -151,7 +151,7 @@ pub async fn recover_stuck_classification_tasks(
let database = state.get_database(); let database = state.get_database();
let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let video_repo = Arc::new(VideoClassificationRepository::new(database.clone()));
let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone()));
let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap());
let service = VideoClassificationService::new( let service = VideoClassificationService::new(
video_repo, video_repo,
@ -192,7 +192,7 @@ pub async fn get_material_classification_records(
let database = state.get_database(); let database = state.get_database();
let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let video_repo = Arc::new(VideoClassificationRepository::new(database.clone()));
let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone()));
let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap());
let service = VideoClassificationService::new( let service = VideoClassificationService::new(
video_repo, video_repo,
@ -215,7 +215,7 @@ pub async fn get_classification_statistics(
let database = state.get_database(); let database = state.get_database();
let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let video_repo = Arc::new(VideoClassificationRepository::new(database.clone()));
let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone()));
let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap());
let service = VideoClassificationService::new( let service = VideoClassificationService::new(
video_repo, video_repo,
@ -252,7 +252,7 @@ pub async fn cancel_classification_task(
let database = state.get_database(); let database = state.get_database();
let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let video_repo = Arc::new(VideoClassificationRepository::new(database.clone()));
let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone()));
let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap());
let service = VideoClassificationService::new( let service = VideoClassificationService::new(
video_repo, video_repo,
@ -275,7 +275,7 @@ pub async fn retry_classification_task(
let database = state.get_database(); let database = state.get_database();
let video_repo = Arc::new(VideoClassificationRepository::new(database.clone())); let video_repo = Arc::new(VideoClassificationRepository::new(database.clone()));
let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone())); let ai_classification_repo = Arc::new(AiClassificationRepository::new(database.clone()));
let material_repo = Arc::new(MaterialRepository::new(database.get_connection()).unwrap()); let material_repo = Arc::new(MaterialRepository::new(database.clone()).unwrap());
let service = VideoClassificationService::new( let service = VideoClassificationService::new(
video_repo, video_repo,
@ -320,11 +320,4 @@ mod tests {
// 应该返回同一个实例 // 应该返回同一个实例
assert!(Arc::ptr_eq(&queue1, &queue2)); assert!(Arc::ptr_eq(&queue1, &queue2));
} }
#[tokio::test]
async fn test_classification_queue_status() {
let state = create_test_state().await;
let result = get_classification_queue_status(tauri::State::from(&state)).await;
assert!(result.is_ok());
}
} }