diff --git a/apps/desktop/src-tauri/src/data/repositories/comfyui_repository.rs b/apps/desktop/src-tauri/src/data/repositories/comfyui_repository.rs index 20badd7..17eaefa 100644 --- a/apps/desktop/src-tauri/src/data/repositories/comfyui_repository.rs +++ b/apps/desktop/src-tauri/src/data/repositories/comfyui_repository.rs @@ -24,9 +24,14 @@ impl ComfyUIRepository { Self { database } } - /// 获取数据库连接 - fn get_connection(&self) -> Arc> { - self.database.get_connection() + /// 获取数据库连接(使用连接池) + fn get_connection(&self) -> Result { + // 强制使用连接池,如果没有连接池则报错 + if !self.database.has_pool() { + return Err(anyhow!("ComfyUI Repository 必须使用连接池模式,请启用连接池")); + } + + self.database.get_best_connection() } /// 执行数据库操作的辅助方法 @@ -34,8 +39,7 @@ impl ComfyUIRepository { where F: FnOnce(&Connection) -> Result, { - let conn = self.get_connection(); - let conn = conn.lock().map_err(|e| anyhow!("获取数据库连接失败: {}", e))?; + let mut conn = self.get_connection()?; f(&*conn) } @@ -81,7 +85,7 @@ impl ComfyUIRepository { /// 获取工作流 pub fn get_workflow(&self, id: &str) -> Result> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let mut stmt = conn.prepare( "SELECT id, name, description, workflow_data, version, created_at, updated_at, enabled, tags, category @@ -101,7 +105,7 @@ impl ComfyUIRepository { /// 获取所有工作流 pub fn list_workflows(&self, enabled_only: bool) -> Result> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let sql = if enabled_only { "SELECT id, name, description, workflow_data, version, created_at, updated_at, enabled, tags, category @@ -126,7 +130,7 @@ impl ComfyUIRepository { /// 更新工作流 pub fn update_workflow(&self, workflow: &WorkflowModel) -> Result<()> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let workflow_data_json = serde_json::to_string(&workflow.workflow_data)?; let tags_json = serde_json::to_string(&workflow.tags)?; @@ -158,7 +162,7 @@ impl ComfyUIRepository { /// 删除工作流 pub fn delete_workflow(&self, id: &str) -> Result<()> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let rows_affected = conn.execute("DELETE FROM comfyui_workflows WHERE id = ?1", [id])?; @@ -172,7 +176,7 @@ impl ComfyUIRepository { /// 按分类获取工作流 pub fn get_workflows_by_category(&self, category: &str) -> Result> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let mut stmt = conn.prepare( "SELECT id, name, description, workflow_data, version, created_at, updated_at, enabled, tags, category @@ -193,7 +197,7 @@ impl ComfyUIRepository { /// 搜索工作流 pub fn search_workflows(&self, query: &str) -> Result> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let search_pattern = format!("%{}%", query); let mut stmt = conn.prepare( @@ -245,7 +249,7 @@ impl ComfyUIRepository { /// 创建模板 pub fn create_template(&self, template: &TemplateModel) -> Result<()> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let template_data_json = serde_json::to_string(&template.template_data)?; let parameter_schema_json = serde_json::to_string(&template.parameter_schema)?; @@ -276,7 +280,7 @@ impl ComfyUIRepository { /// 获取模板 pub fn get_template(&self, id: &str) -> Result> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let mut stmt = conn.prepare( "SELECT id, name, category, description, template_data, parameter_schema, created_at, updated_at, enabled, tags, author, version @@ -296,7 +300,7 @@ impl ComfyUIRepository { /// 获取所有模板 pub fn list_templates(&self, enabled_only: bool) -> Result> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let sql = if enabled_only { "SELECT id, name, category, description, template_data, parameter_schema, created_at, updated_at, enabled, tags, author, version @@ -321,7 +325,7 @@ impl ComfyUIRepository { /// 更新模板 pub fn update_template(&self, template: &TemplateModel) -> Result<()> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let template_data_json = serde_json::to_string(&template.template_data)?; let parameter_schema_json = serde_json::to_string(&template.parameter_schema)?; @@ -356,7 +360,7 @@ impl ComfyUIRepository { /// 删除模板 pub fn delete_template(&self, id: &str) -> Result<()> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let rows_affected = conn.execute("DELETE FROM comfyui_templates WHERE id = ?1", [id])?; @@ -370,7 +374,7 @@ impl ComfyUIRepository { /// 按分类获取模板 pub fn get_templates_by_category(&self, category: &str) -> Result> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let mut stmt = conn.prepare( "SELECT id, name, category, description, template_data, parameter_schema, created_at, updated_at, enabled, tags, author, version @@ -422,7 +426,7 @@ impl ComfyUIRepository { /// 创建执行记录 pub fn create_execution(&self, execution: &ExecutionModel) -> Result<()> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let parameters_json = execution.parameters.as_ref() .map(|p| serde_json::to_string(p)) @@ -462,7 +466,7 @@ impl ComfyUIRepository { /// 获取执行记录 pub fn get_execution(&self, id: &str) -> Result> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let mut stmt = conn.prepare( "SELECT id, workflow_id, template_id, prompt_id, status, parameters, results, output_urls, error_message, execution_time, created_at, completed_at, client_id, node_outputs @@ -482,7 +486,7 @@ impl ComfyUIRepository { /// 通过 prompt_id 获取执行记录 pub fn get_execution_by_prompt_id(&self, prompt_id: &str) -> Result> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let mut stmt = conn.prepare( "SELECT id, workflow_id, template_id, prompt_id, status, parameters, results, output_urls, error_message, execution_time, created_at, completed_at, client_id, node_outputs @@ -502,7 +506,7 @@ impl ComfyUIRepository { /// 获取执行记录列表 pub fn list_executions(&self, limit: Option, status_filter: Option) -> Result> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let (sql, params): (String, Vec) = match (limit, status_filter) { (Some(limit), Some(status)) => ( @@ -542,7 +546,7 @@ impl ComfyUIRepository { /// 更新执行记录 pub fn update_execution(&self, execution: &ExecutionModel) -> Result<()> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let parameters_json = execution.parameters.as_ref() .map(|p| serde_json::to_string(p)) @@ -585,7 +589,7 @@ impl ComfyUIRepository { /// 删除执行记录 pub fn delete_execution(&self, id: &str) -> Result<()> { - let conn = self.get_connection()?; + let mut conn = self.get_connection()?; let rows_affected = conn.execute("DELETE FROM comfyui_executions WHERE id = ?1", [id])?;