mixvideo-v2/apps/desktop/src-tauri/src/business/services/template_engine.rs

302 lines
10 KiB
Rust

//! 模板引擎
//! 负责工作流模板的管理、验证和实例化
use anyhow::{Result, anyhow};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{info, warn, error, debug};
use comfyui_sdk::templates::{WorkflowTemplate, WorkflowInstance, TemplateManager};
use comfyui_sdk::types::{ParameterValues, ValidationResult, ValidationError, ComfyUIWorkflow};
use crate::data::models::comfyui::{TemplateModel, WorkflowModel};
use crate::data::repositories::comfyui_repository::ComfyUIRepository;
/// 模板引擎
/// 提供模板管理和实例化功能
pub struct TemplateEngine {
/// 数据仓库
repository: Arc<ComfyUIRepository>,
/// SDK 模板管理器
template_manager: TemplateManager,
/// 内存中的模板缓存
template_cache: Arc<tokio::sync::RwLock<HashMap<String, TemplateModel>>>,
}
impl TemplateEngine {
/// 创建新的模板引擎
pub fn new(repository: Arc<ComfyUIRepository>) -> Self {
Self {
repository,
template_manager: TemplateManager::new(),
template_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
}
}
/// 加载模板
pub async fn load_template(&self, template_id: &str) -> Result<TemplateModel> {
// 先检查缓存
{
let cache = self.template_cache.read().await;
if let Some(template) = cache.get(template_id) {
debug!("从缓存加载模板: {}", template_id);
return Ok(template.clone());
}
}
// 从数据库加载
match self.repository.get_template(template_id).await? {
Some(template) => {
// 更新缓存
{
let mut cache = self.template_cache.write().await;
cache.insert(template_id.to_string(), template.clone());
}
debug!("从数据库加载模板: {}", template_id);
Ok(template)
}
None => Err(anyhow!("模板不存在: {}", template_id)),
}
}
/// 获取所有模板
pub async fn list_templates(&self, enabled_only: bool) -> Result<Vec<TemplateModel>> {
self.repository.list_templates(enabled_only).await
}
/// 按分类获取模板
pub async fn get_templates_by_category(&self, category: &str) -> Result<Vec<TemplateModel>> {
self.repository.get_templates_by_category(category).await
}
/// 创建模板
pub async fn create_template(&self, template: TemplateModel) -> Result<()> {
// 验证模板
self.validate_template_structure(&template)?;
// 保存到数据库
self.repository.create_template(&template).await?;
// 更新缓存
{
let mut cache = self.template_cache.write().await;
cache.insert(template.id.clone(), template.clone());
}
info!("创建模板成功: {} ({})", template.name, template.id);
Ok(())
}
/// 更新模板
pub async fn update_template(&self, template: TemplateModel) -> Result<()> {
// 验证模板
self.validate_template_structure(&template)?;
// 更新数据库
self.repository.update_template(&template).await?;
// 更新缓存
{
let mut cache = self.template_cache.write().await;
cache.insert(template.id.clone(), template.clone());
}
info!("更新模板成功: {} ({})", template.name, template.id);
Ok(())
}
/// 删除模板
pub async fn delete_template(&self, template_id: &str) -> Result<()> {
// 从数据库删除
self.repository.delete_template(template_id).await?;
// 从缓存删除
{
let mut cache = self.template_cache.write().await;
cache.remove(template_id);
}
info!("删除模板成功: {}", template_id);
Ok(())
}
/// 验证模板参数
pub async fn validate_parameters(&self, template_id: &str, parameters: &ParameterValues) -> Result<ValidationResult> {
let template = self.load_template(template_id).await?;
Ok(template.validate_parameters(parameters))
}
/// 创建模板实例
pub async fn create_instance(&self, template_id: &str, parameters: ParameterValues) -> Result<WorkflowInstance> {
let template = self.load_template(template_id).await?;
// 验证参数
let validation = template.validate_parameters(&parameters);
if !validation.valid {
return Err(anyhow!("参数验证失败: {:?}", validation.errors));
}
// 使用 SDK 创建工作流模板
let workflow_template = WorkflowTemplate::new(
template.template_data.clone(),
template.parameter_schema.clone(),
)?;
// 创建实例
let instance = workflow_template.create_instance(parameters)?;
debug!("创建模板实例成功: {} -> {}", template_id, instance.get_id());
Ok(instance)
}
/// 从工作流创建模板
pub async fn create_template_from_workflow(&self, workflow: &WorkflowModel, parameter_schema: HashMap<String, comfyui_sdk::types::ParameterSchema>) -> Result<TemplateModel> {
// 创建模板元数据
let template_metadata = comfyui_sdk::types::TemplateMetadata {
id: uuid::Uuid::new_v4().to_string(),
name: format!("{} Template", workflow.name),
description: workflow.description.clone(),
version: Some(workflow.version.clone()),
author: None,
tags: Some(workflow.tags.clone()),
category: workflow.category.clone(),
created_at: Some(chrono::Utc::now()),
updated_at: Some(chrono::Utc::now()),
};
// 创建模板数据
let template_data = comfyui_sdk::types::WorkflowTemplateData {
metadata: template_metadata.clone(),
workflow: workflow.workflow_data.clone(),
parameters: parameter_schema.clone(),
};
// 创建模板模型
let template = TemplateModel::new(
template_metadata.name.clone(),
template_data,
parameter_schema,
);
Ok(template)
}
/// 预览模板实例(不执行)
pub async fn preview_instance(&self, template_id: &str, parameters: ParameterValues) -> Result<ComfyUIWorkflow> {
let instance = self.create_instance(template_id, parameters).await?;
Ok(instance.get_workflow().clone())
}
/// 获取模板的参数定义
pub async fn get_parameter_schema(&self, template_id: &str) -> Result<HashMap<String, comfyui_sdk::types::ParameterSchema>> {
let template = self.load_template(template_id).await?;
Ok(template.parameter_schema)
}
/// 验证模板结构
fn validate_template_structure(&self, template: &TemplateModel) -> Result<()> {
// 检查模板名称
if template.name.trim().is_empty() {
return Err(anyhow!("模板名称不能为空"));
}
// 检查工作流数据
if template.template_data.workflow.is_empty() {
return Err(anyhow!("工作流数据不能为空"));
}
// 检查参数定义
if template.parameter_schema.is_empty() {
warn!("模板 {} 没有定义参数", template.id);
}
// 验证工作流结构
for (node_id, node) in &template.template_data.workflow {
if node.class_type.trim().is_empty() {
return Err(anyhow!("节点 {} 的 class_type 不能为空", node_id));
}
}
Ok(())
}
/// 清除模板缓存
pub async fn clear_cache(&self) {
let mut cache = self.template_cache.write().await;
cache.clear();
info!("模板缓存已清除");
}
/// 预热模板缓存
pub async fn warm_cache(&self) -> Result<()> {
let templates = self.repository.list_templates(true).await?;
{
let mut cache = self.template_cache.write().await;
for template in templates {
cache.insert(template.id.clone(), template);
}
}
info!("模板缓存预热完成");
Ok(())
}
/// 获取缓存统计信息
pub async fn get_cache_stats(&self) -> CacheStats {
let cache = self.template_cache.read().await;
CacheStats {
cached_templates: cache.len(),
cache_keys: cache.keys().cloned().collect(),
}
}
/// 搜索模板
pub async fn search_templates(&self, query: &str) -> Result<Vec<TemplateModel>> {
let all_templates = self.repository.list_templates(true).await?;
let query_lower = query.to_lowercase();
let filtered_templates: Vec<TemplateModel> = all_templates
.into_iter()
.filter(|template| {
template.name.to_lowercase().contains(&query_lower) ||
template.description.as_ref().map_or(false, |desc| desc.to_lowercase().contains(&query_lower)) ||
template.tags.iter().any(|tag| tag.to_lowercase().contains(&query_lower))
})
.collect();
Ok(filtered_templates)
}
/// 导出模板
pub async fn export_template(&self, template_id: &str) -> Result<String> {
let template = self.load_template(template_id).await?;
serde_json::to_string_pretty(&template)
.map_err(|e| anyhow!("导出模板失败: {}", e))
}
/// 导入模板
pub async fn import_template(&self, template_json: &str) -> Result<String> {
let mut template: TemplateModel = serde_json::from_str(template_json)
.map_err(|e| anyhow!("解析模板失败: {}", e))?;
// 生成新的 ID 避免冲突
template.id = uuid::Uuid::new_v4().to_string();
template.created_at = chrono::Utc::now();
template.updated_at = chrono::Utc::now();
// 创建模板
self.create_template(template.clone()).await?;
Ok(template.id)
}
}
/// 缓存统计信息
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CacheStats {
pub cached_templates: usize,
pub cache_keys: Vec<String>,
}