//! 模板引擎 //! 负责工作流模板的管理、验证和实例化 use anyhow::{Result, anyhow}; use std::collections::HashMap; use std::sync::Arc; use tracing::{info, warn, error, debug}; use comfyui_sdk::templates::{WorkflowTemplate, WorkflowInstance, TemplateManager}; use comfyui_sdk::types::{ParameterValues, ValidationResult, ValidationError, ComfyUIWorkflow}; use crate::data::models::comfyui::{TemplateModel, WorkflowModel}; use crate::data::repositories::comfyui_repository::ComfyUIRepository; /// 模板引擎 /// 提供模板管理和实例化功能 pub struct TemplateEngine { /// 数据仓库 repository: Arc, /// SDK 模板管理器 template_manager: TemplateManager, /// 内存中的模板缓存 template_cache: Arc>>, } impl TemplateEngine { /// 创建新的模板引擎 pub fn new(repository: Arc) -> Self { Self { repository, template_manager: TemplateManager::new(), template_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())), } } /// 加载模板 pub async fn load_template(&self, template_id: &str) -> Result { // 先检查缓存 { let cache = self.template_cache.read().await; if let Some(template) = cache.get(template_id) { debug!("从缓存加载模板: {}", template_id); return Ok(template.clone()); } } // 从数据库加载 match self.repository.get_template(template_id).await? { Some(template) => { // 更新缓存 { let mut cache = self.template_cache.write().await; cache.insert(template_id.to_string(), template.clone()); } debug!("从数据库加载模板: {}", template_id); Ok(template) } None => Err(anyhow!("模板不存在: {}", template_id)), } } /// 获取所有模板 pub async fn list_templates(&self, enabled_only: bool) -> Result> { self.repository.list_templates(enabled_only).await } /// 按分类获取模板 pub async fn get_templates_by_category(&self, category: &str) -> Result> { self.repository.get_templates_by_category(category).await } /// 创建模板 pub async fn create_template(&self, template: TemplateModel) -> Result<()> { // 验证模板 self.validate_template_structure(&template)?; // 保存到数据库 self.repository.create_template(&template).await?; // 更新缓存 { let mut cache = self.template_cache.write().await; cache.insert(template.id.clone(), template.clone()); } info!("创建模板成功: {} ({})", template.name, template.id); Ok(()) } /// 更新模板 pub async fn update_template(&self, template: TemplateModel) -> Result<()> { // 验证模板 self.validate_template_structure(&template)?; // 更新数据库 self.repository.update_template(&template).await?; // 更新缓存 { let mut cache = self.template_cache.write().await; cache.insert(template.id.clone(), template.clone()); } info!("更新模板成功: {} ({})", template.name, template.id); Ok(()) } /// 删除模板 pub async fn delete_template(&self, template_id: &str) -> Result<()> { // 从数据库删除 self.repository.delete_template(template_id).await?; // 从缓存删除 { let mut cache = self.template_cache.write().await; cache.remove(template_id); } info!("删除模板成功: {}", template_id); Ok(()) } /// 验证模板参数 pub async fn validate_parameters(&self, template_id: &str, parameters: &ParameterValues) -> Result { let template = self.load_template(template_id).await?; Ok(template.validate_parameters(parameters)) } /// 创建模板实例 pub async fn create_instance(&self, template_id: &str, parameters: ParameterValues) -> Result { let template = self.load_template(template_id).await?; // 验证参数 let validation = template.validate_parameters(¶meters); if !validation.valid { return Err(anyhow!("参数验证失败: {:?}", validation.errors)); } // 使用 SDK 创建工作流模板 let workflow_template = WorkflowTemplate::new( template.template_data.clone(), template.parameter_schema.clone(), )?; // 创建实例 let instance = workflow_template.create_instance(parameters)?; debug!("创建模板实例成功: {} -> {}", template_id, instance.get_id()); Ok(instance) } /// 从工作流创建模板 pub async fn create_template_from_workflow(&self, workflow: &WorkflowModel, parameter_schema: HashMap) -> Result { // 创建模板元数据 let template_metadata = comfyui_sdk::types::TemplateMetadata { id: uuid::Uuid::new_v4().to_string(), name: format!("{} Template", workflow.name), description: workflow.description.clone(), version: Some(workflow.version.clone()), author: None, tags: Some(workflow.tags.clone()), category: workflow.category.clone(), created_at: Some(chrono::Utc::now()), updated_at: Some(chrono::Utc::now()), }; // 创建模板数据 let template_data = comfyui_sdk::types::WorkflowTemplateData { metadata: template_metadata.clone(), workflow: workflow.workflow_data.clone(), parameters: parameter_schema.clone(), }; // 创建模板模型 let template = TemplateModel::new( template_metadata.name.clone(), template_data, parameter_schema, ); Ok(template) } /// 预览模板实例(不执行) pub async fn preview_instance(&self, template_id: &str, parameters: ParameterValues) -> Result { let instance = self.create_instance(template_id, parameters).await?; Ok(instance.get_workflow().clone()) } /// 获取模板的参数定义 pub async fn get_parameter_schema(&self, template_id: &str) -> Result> { let template = self.load_template(template_id).await?; Ok(template.parameter_schema) } /// 验证模板结构 fn validate_template_structure(&self, template: &TemplateModel) -> Result<()> { // 检查模板名称 if template.name.trim().is_empty() { return Err(anyhow!("模板名称不能为空")); } // 检查工作流数据 if template.template_data.workflow.is_empty() { return Err(anyhow!("工作流数据不能为空")); } // 检查参数定义 if template.parameter_schema.is_empty() { warn!("模板 {} 没有定义参数", template.id); } // 验证工作流结构 for (node_id, node) in &template.template_data.workflow { if node.class_type.trim().is_empty() { return Err(anyhow!("节点 {} 的 class_type 不能为空", node_id)); } } Ok(()) } /// 清除模板缓存 pub async fn clear_cache(&self) { let mut cache = self.template_cache.write().await; cache.clear(); info!("模板缓存已清除"); } /// 预热模板缓存 pub async fn warm_cache(&self) -> Result<()> { let templates = self.repository.list_templates(true).await?; { let mut cache = self.template_cache.write().await; for template in templates { cache.insert(template.id.clone(), template); } } info!("模板缓存预热完成"); Ok(()) } /// 获取缓存统计信息 pub async fn get_cache_stats(&self) -> CacheStats { let cache = self.template_cache.read().await; CacheStats { cached_templates: cache.len(), cache_keys: cache.keys().cloned().collect(), } } /// 搜索模板 pub async fn search_templates(&self, query: &str) -> Result> { let all_templates = self.repository.list_templates(true).await?; let query_lower = query.to_lowercase(); let filtered_templates: Vec = all_templates .into_iter() .filter(|template| { template.name.to_lowercase().contains(&query_lower) || template.description.as_ref().map_or(false, |desc| desc.to_lowercase().contains(&query_lower)) || template.tags.iter().any(|tag| tag.to_lowercase().contains(&query_lower)) }) .collect(); Ok(filtered_templates) } /// 导出模板 pub async fn export_template(&self, template_id: &str) -> Result { let template = self.load_template(template_id).await?; serde_json::to_string_pretty(&template) .map_err(|e| anyhow!("导出模板失败: {}", e)) } /// 导入模板 pub async fn import_template(&self, template_json: &str) -> Result { let mut template: TemplateModel = serde_json::from_str(template_json) .map_err(|e| anyhow!("解析模板失败: {}", e))?; // 生成新的 ID 避免冲突 template.id = uuid::Uuid::new_v4().to_string(); template.created_at = chrono::Utc::now(); template.updated_at = chrono::Utc::now(); // 创建模板 self.create_template(template.clone()).await?; Ok(template.id) } } /// 缓存统计信息 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct CacheStats { pub cached_templates: usize, pub cache_keys: Vec, }