// Path: mixvideo-v2/cargos/comfyui-sdk/utils/template_parser.rs
//
// 257 lines
// 8.1 KiB
// Rust

//! Template parsing utilities for parameter substitution
use std::collections::HashMap;
use regex::Regex;
use crate::types::{ComfyUIWorkflow, ParameterValues, ComfyUINode, VariableSyntax};
use crate::error::{ComfyUIError, Result};
/// Applies parameters to a workflow template
pub fn apply_parameters(
workflow: &ComfyUIWorkflow,
parameters: &ParameterValues,
) -> Result<ComfyUIWorkflow> {
let mut resolved_workflow = HashMap::new();
for (node_id, node) in workflow {
let resolved_node = apply_parameters_to_node(node, parameters)?;
resolved_workflow.insert(node_id.clone(), resolved_node);
}
Ok(resolved_workflow)
}
/// Applies parameters to a single node.
///
/// Returns a copy of `node` whose `inputs` have every placeholder resolved.
/// `class_type` and `_meta` are cloned unchanged.
///
/// # Errors
///
/// Returns the first error produced while resolving any input value.
fn apply_parameters_to_node(
    node: &ComfyUINode,
    parameters: &ParameterValues,
) -> Result<ComfyUINode> {
    let inputs: HashMap<_, _> = node
        .inputs
        .iter()
        .map(|(name, raw)| {
            apply_parameters_to_value(raw, parameters).map(|resolved| (name.clone(), resolved))
        })
        .collect::<Result<_>>()?;

    Ok(ComfyUINode {
        class_type: node.class_type.clone(),
        inputs,
        _meta: node._meta.clone(),
    })
}
/// Applies parameters to a JSON value
fn apply_parameters_to_value(
value: &serde_json::Value,
parameters: &ParameterValues,
) -> Result<serde_json::Value> {
match value {
serde_json::Value::String(s) => {
let resolved_string = substitute_variables(s, parameters, VariableSyntax::DoubleBrace)?;
Ok(serde_json::Value::String(resolved_string))
}
serde_json::Value::Array(arr) => {
let mut resolved_array = Vec::new();
for item in arr {
resolved_array.push(apply_parameters_to_value(item, parameters)?);
}
Ok(serde_json::Value::Array(resolved_array))
}
serde_json::Value::Object(obj) => {
let mut resolved_object = serde_json::Map::new();
for (key, val) in obj {
let resolved_key = substitute_variables(key, parameters, VariableSyntax::DoubleBrace)?;
let resolved_val = apply_parameters_to_value(val, parameters)?;
resolved_object.insert(resolved_key, resolved_val);
}
Ok(serde_json::Value::Object(resolved_object))
}
_ => Ok(value.clone()),
}
}
/// Substitutes variables in a string using the specified syntax.
///
/// Every placeholder matching `syntax` (e.g. `{{ name }}` for
/// `VariableSyntax::DoubleBrace`) is replaced by the string form of the
/// matching entry in `parameters`. Whitespace just inside the delimiters is
/// trimmed before the lookup. Text outside placeholders is copied unchanged.
///
/// # Errors
///
/// - Returns a template-validation error naming the first placeholder whose
///   parameter is not present in `parameters`.
/// - Returns an error if the placeholder pattern fails to compile.
/// - Returns an error if a parameter value cannot be converted to a string.
pub fn substitute_variables(
    template: &str,
    parameters: &ParameterValues,
    syntax: VariableSyntax,
) -> Result<String> {
    let pattern = match syntax {
        VariableSyntax::DoubleBrace => r"\{\{([^}]+)\}\}",
        VariableSyntax::DollarBrace => r"\$\{([^}]+)\}",
        VariableSyntax::AtBrace => r"@\{([^}]+)\}",
    };
    let regex = Regex::new(pattern)
        .map_err(|e| ComfyUIError::new(format!("Invalid regex pattern: {}", e)))?;

    // Build the output in a single forward pass. Copy the untouched text
    // between matches, then append each replacement. This needs no signed
    // offset bookkeeping (the old `as i32` casts could truncate on very large
    // inputs) and never shifts the tail of the string, unlike in-place
    // `replace_range`.
    let mut result = String::with_capacity(template.len());
    let mut last_end = 0;
    for captures in regex.captures_iter(template) {
        let full_match = captures.get(0).expect("capture group 0 always exists");
        let var_name = captures
            .get(1)
            .expect("every pattern has one mandatory capture group")
            .as_str()
            .trim();

        let value = parameters.get(var_name).ok_or_else(|| {
            ComfyUIError::template_validation(format!("Parameter '{}' not found", var_name))
        })?;
        let replacement = value_to_string(value)?;

        // Match offsets are byte positions on char boundaries, increasing and
        // non-overlapping, so this slice is always valid.
        result.push_str(&template[last_end..full_match.start()]);
        result.push_str(&replacement);
        last_end = full_match.end();
    }
    result.push_str(&template[last_end..]);
    Ok(result)
}
/// Converts a JSON value to string for substitution
fn value_to_string(value: &serde_json::Value) -> Result<String> {
match value {
serde_json::Value::String(s) => Ok(s.clone()),
serde_json::Value::Number(n) => Ok(n.to_string()),
serde_json::Value::Bool(b) => Ok(b.to_string()),
serde_json::Value::Null => Ok("null".to_string()),
_ => {
// For complex types, serialize to JSON
serde_json::to_string(value)
.map_err(ComfyUIError::from)
}
}
}
/// Extracts variable names from a template string.
///
/// Placeholder names are trimmed of surrounding whitespace. Each name is
/// returned once, in the order it first appears in `template`.
///
/// # Errors
///
/// Returns an error if the placeholder pattern for `syntax` fails to compile.
pub fn extract_variables(template: &str, syntax: VariableSyntax) -> Result<Vec<String>> {
    let regex = Regex::new(match syntax {
        VariableSyntax::DoubleBrace => r"\{\{([^}]+)\}\}",
        VariableSyntax::DollarBrace => r"\$\{([^}]+)\}",
        VariableSyntax::AtBrace => r"@\{([^}]+)\}",
    })
    .map_err(|e| ComfyUIError::new(format!("Invalid regex pattern: {}", e)))?;

    let candidates = regex
        .captures_iter(template)
        .filter_map(|caps| caps.get(1))
        .map(|m| m.as_str().trim().to_string());

    // Keep first-occurrence order while dropping repeats.
    let mut unique: Vec<String> = Vec::new();
    for name in candidates {
        if !unique.contains(&name) {
            unique.push(name);
        }
    }
    Ok(unique)
}
/// Extracts all variables from a workflow.
///
/// Returns every placeholder name used by any node, with duplicates removed.
/// Names are kept in the order they are first encountered while visiting the
/// workflow's nodes.
///
/// # Errors
///
/// Returns the first error produced while scanning a node.
pub fn extract_workflow_variables(workflow: &ComfyUIWorkflow) -> Result<Vec<String>> {
    workflow.values().try_fold(
        Vec::new(),
        |mut seen: Vec<String>, node| -> Result<Vec<String>> {
            for name in extract_node_variables(node)? {
                if !seen.contains(&name) {
                    seen.push(name);
                }
            }
            Ok(seen)
        },
    )
}
/// Extracts variables from a single node.
///
/// Scans every input value of `node` and returns the placeholder names found,
/// each listed once, in the order first encountered.
///
/// # Errors
///
/// Returns the first error produced while scanning an input value.
fn extract_node_variables(node: &ComfyUINode) -> Result<Vec<String>> {
    let mut found: Vec<String> = Vec::new();
    for value in node.inputs.values() {
        for name in extract_value_variables(value)? {
            let is_new = found.iter().all(|existing| existing != &name);
            if is_new {
                found.push(name);
            }
        }
    }
    Ok(found)
}
/// Extracts variables from a JSON value.
///
/// Walks `value` recursively and collects `{{ name }}` placeholders found in:
/// - strings,
/// - array elements,
/// - both the keys and the values of objects.
///
/// Numbers, booleans and `null` contribute nothing. The result is sorted and
/// contains no duplicates.
///
/// # Errors
///
/// Propagates any error from `extract_variables`.
fn extract_value_variables(value: &serde_json::Value) -> Result<Vec<String>> {
    // Pushes every placeholder name found under `node` into `out`.
    // Duplicates are allowed here; the caller sorts and dedups once at the end.
    fn collect(node: &serde_json::Value, out: &mut Vec<String>) -> Result<()> {
        match node {
            serde_json::Value::String(text) => {
                out.extend(extract_variables(text, VariableSyntax::DoubleBrace)?);
            }
            serde_json::Value::Array(items) => {
                for item in items {
                    collect(item, out)?;
                }
            }
            serde_json::Value::Object(entries) => {
                for (key, entry) in entries {
                    out.extend(extract_variables(key, VariableSyntax::DoubleBrace)?);
                    collect(entry, out)?;
                }
            }
            _ => {}
        }
        Ok(())
    }

    let mut names = Vec::new();
    collect(value, &mut names)?;
    names.sort();
    names.dedup();
    Ok(names)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Both placeholders are replaced, including one with a numeric value.
    #[test]
    fn test_substitute_variables() {
        let mut params = HashMap::new();
        params.insert(String::from("name"), json!("test"));
        params.insert(String::from("value"), json!(42));

        let output = substitute_variables(
            "Hello {{name}}, value is {{value}}",
            &params,
            VariableSyntax::DoubleBrace,
        )
        .expect("all placeholders have parameters");

        assert_eq!(output, "Hello test, value is 42");
    }

    /// Repeated names are reported once, in first-appearance order.
    #[test]
    fn test_extract_variables() {
        let names = extract_variables(
            "{{var1}} and {{var2}} and {{var1}} again",
            VariableSyntax::DoubleBrace,
        )
        .expect("pattern compiles");

        assert_eq!(names, vec!["var1", "var2"]);
    }

    /// A placeholder with no matching parameter is reported as an error.
    #[test]
    fn test_missing_parameter() {
        let params = HashMap::new();
        let outcome =
            substitute_variables("Hello {{missing}}", &params, VariableSyntax::DoubleBrace);
        assert!(outcome.is_err());
    }
}