257 lines
8.0 KiB
Rust
257 lines
8.0 KiB
Rust
//! Template parsing utilities for parameter substitution
|
|
|
|
use std::collections::HashMap;
|
|
use regex::Regex;
|
|
use crate::types::{ComfyUIWorkflow, ParameterValues, ComfyUINode, VariableSyntax};
|
|
use crate::error::{ComfyUIError, Result};
|
|
|
|
/// Applies parameters to a workflow template
|
|
pub fn apply_parameters(
|
|
workflow: &ComfyUIWorkflow,
|
|
parameters: &ParameterValues,
|
|
) -> Result<ComfyUIWorkflow> {
|
|
let mut resolved_workflow = HashMap::new();
|
|
|
|
for (node_id, node) in workflow {
|
|
let resolved_node = apply_parameters_to_node(node, parameters)?;
|
|
resolved_workflow.insert(node_id.clone(), resolved_node);
|
|
}
|
|
|
|
Ok(resolved_workflow)
|
|
}
|
|
|
|
/// Applies parameters to a single node
|
|
fn apply_parameters_to_node(
|
|
node: &ComfyUINode,
|
|
parameters: &ParameterValues,
|
|
) -> Result<ComfyUINode> {
|
|
let mut resolved_inputs = HashMap::new();
|
|
|
|
for (input_name, input_value) in &node.inputs {
|
|
let resolved_value = apply_parameters_to_value(input_value, parameters)?;
|
|
resolved_inputs.insert(input_name.clone(), resolved_value);
|
|
}
|
|
|
|
Ok(ComfyUINode {
|
|
class_type: node.class_type.clone(),
|
|
inputs: resolved_inputs,
|
|
_meta: node._meta.clone(),
|
|
})
|
|
}
|
|
|
|
/// Applies parameters to a JSON value
|
|
fn apply_parameters_to_value(
|
|
value: &serde_json::Value,
|
|
parameters: &ParameterValues,
|
|
) -> Result<serde_json::Value> {
|
|
match value {
|
|
serde_json::Value::String(s) => {
|
|
let resolved_string = substitute_variables(s, parameters, VariableSyntax::DoubleBrace)?;
|
|
Ok(serde_json::Value::String(resolved_string))
|
|
}
|
|
serde_json::Value::Array(arr) => {
|
|
let mut resolved_array = Vec::new();
|
|
for item in arr {
|
|
resolved_array.push(apply_parameters_to_value(item, parameters)?);
|
|
}
|
|
Ok(serde_json::Value::Array(resolved_array))
|
|
}
|
|
serde_json::Value::Object(obj) => {
|
|
let mut resolved_object = serde_json::Map::new();
|
|
for (key, val) in obj {
|
|
let resolved_key = substitute_variables(key, parameters, VariableSyntax::DoubleBrace)?;
|
|
let resolved_val = apply_parameters_to_value(val, parameters)?;
|
|
resolved_object.insert(resolved_key, resolved_val);
|
|
}
|
|
Ok(serde_json::Value::Object(resolved_object))
|
|
}
|
|
_ => Ok(value.clone()),
|
|
}
|
|
}
|
|
|
|
/// Substitutes variables in a string using the specified syntax
|
|
pub fn substitute_variables(
|
|
template: &str,
|
|
parameters: &ParameterValues,
|
|
syntax: VariableSyntax,
|
|
) -> Result<String> {
|
|
let pattern = match syntax {
|
|
VariableSyntax::DoubleBrace => r"\{\{([^}]+)\}\}",
|
|
VariableSyntax::DollarBrace => r"\$\{([^}]+)\}",
|
|
VariableSyntax::AtBrace => r"@\{([^}]+)\}",
|
|
};
|
|
|
|
let regex = Regex::new(pattern)
|
|
.map_err(|e| ComfyUIError::new(format!("Invalid regex pattern: {e}")))?;
|
|
|
|
let mut result = template.to_string();
|
|
let mut offset = 0i32;
|
|
|
|
for captures in regex.captures_iter(template) {
|
|
let full_match = captures.get(0).unwrap();
|
|
let var_name = captures.get(1).unwrap().as_str().trim();
|
|
|
|
// Get parameter value
|
|
let replacement = match parameters.get(var_name) {
|
|
Some(value) => value_to_string(value)?,
|
|
None => {
|
|
return Err(ComfyUIError::template_validation(
|
|
format!("Parameter '{var_name}' not found")
|
|
));
|
|
}
|
|
};
|
|
|
|
// Calculate positions with offset
|
|
let start = (full_match.start() as i32 + offset) as usize;
|
|
let end = (full_match.end() as i32 + offset) as usize;
|
|
|
|
// Replace the variable
|
|
result.replace_range(start..end, &replacement);
|
|
|
|
// Update offset
|
|
offset += replacement.len() as i32 - full_match.len() as i32;
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Converts a JSON value to string for substitution
|
|
fn value_to_string(value: &serde_json::Value) -> Result<String> {
|
|
match value {
|
|
serde_json::Value::String(s) => Ok(s.clone()),
|
|
serde_json::Value::Number(n) => Ok(n.to_string()),
|
|
serde_json::Value::Bool(b) => Ok(b.to_string()),
|
|
serde_json::Value::Null => Ok("null".to_string()),
|
|
_ => {
|
|
// For complex types, serialize to JSON
|
|
serde_json::to_string(value)
|
|
.map_err(ComfyUIError::from)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Extracts variable names from a template string
|
|
pub fn extract_variables(template: &str, syntax: VariableSyntax) -> Result<Vec<String>> {
|
|
let pattern = match syntax {
|
|
VariableSyntax::DoubleBrace => r"\{\{([^}]+)\}\}",
|
|
VariableSyntax::DollarBrace => r"\$\{([^}]+)\}",
|
|
VariableSyntax::AtBrace => r"@\{([^}]+)\}",
|
|
};
|
|
|
|
let regex = Regex::new(pattern)
|
|
.map_err(|e| ComfyUIError::new(format!("Invalid regex pattern: {e}")))?;
|
|
|
|
let mut variables = Vec::new();
|
|
for captures in regex.captures_iter(template) {
|
|
if let Some(var_match) = captures.get(1) {
|
|
let var_name = var_match.as_str().trim().to_string();
|
|
if !variables.contains(&var_name) {
|
|
variables.push(var_name);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(variables)
|
|
}
|
|
|
|
/// Extracts all variables from a workflow
|
|
pub fn extract_workflow_variables(workflow: &ComfyUIWorkflow) -> Result<Vec<String>> {
|
|
let mut all_variables = Vec::new();
|
|
|
|
for node in workflow.values() {
|
|
let node_variables = extract_node_variables(node)?;
|
|
for var in node_variables {
|
|
if !all_variables.contains(&var) {
|
|
all_variables.push(var);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(all_variables)
|
|
}
|
|
|
|
/// Extracts variables from a single node
|
|
fn extract_node_variables(node: &ComfyUINode) -> Result<Vec<String>> {
|
|
let mut variables = Vec::new();
|
|
|
|
for input_value in node.inputs.values() {
|
|
let value_variables = extract_value_variables(input_value)?;
|
|
for var in value_variables {
|
|
if !variables.contains(&var) {
|
|
variables.push(var);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(variables)
|
|
}
|
|
|
|
/// Extracts variables from a JSON value
|
|
fn extract_value_variables(value: &serde_json::Value) -> Result<Vec<String>> {
|
|
let mut variables = Vec::new();
|
|
|
|
match value {
|
|
serde_json::Value::String(s) => {
|
|
let string_vars = extract_variables(s, VariableSyntax::DoubleBrace)?;
|
|
variables.extend(string_vars);
|
|
}
|
|
serde_json::Value::Array(arr) => {
|
|
for item in arr {
|
|
let item_vars = extract_value_variables(item)?;
|
|
variables.extend(item_vars);
|
|
}
|
|
}
|
|
serde_json::Value::Object(obj) => {
|
|
for (key, val) in obj {
|
|
let key_vars = extract_variables(key, VariableSyntax::DoubleBrace)?;
|
|
variables.extend(key_vars);
|
|
|
|
let val_vars = extract_value_variables(val)?;
|
|
variables.extend(val_vars);
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
// Remove duplicates
|
|
variables.sort();
|
|
variables.dedup();
|
|
|
|
Ok(variables)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_substitute_variables() {
        let mut parameters = HashMap::new();
        parameters.insert("name".to_string(), json!("test"));
        parameters.insert("value".to_string(), json!(42));

        let template = "Hello {{name}}, value is {{value}}";
        let result =
            substitute_variables(template, &parameters, VariableSyntax::DoubleBrace).unwrap();

        assert_eq!(result, "Hello test, value is 42");
    }

    #[test]
    fn test_substitute_trims_repeats_and_keeps_multibyte_text() {
        let mut parameters = HashMap::new();
        parameters.insert("name".to_string(), json!("Zoë"));

        let result =
            substitute_variables("ä{{ name }}-{{name}}ö", &parameters, VariableSyntax::DoubleBrace)
                .unwrap();

        assert_eq!(result, "äZoë-Zoëö");
    }

    #[test]
    fn test_substituted_value_is_not_reexpanded() {
        let mut parameters = HashMap::new();
        parameters.insert("a".to_string(), json!("{{b}}"));

        let result =
            substitute_variables("x{{a}}y", &parameters, VariableSyntax::DoubleBrace).unwrap();

        assert_eq!(result, "x{{b}}y");
    }

    #[test]
    fn test_other_syntaxes_and_composite_values() {
        let mut parameters = HashMap::new();
        parameters.insert("list".to_string(), json!([1, 2]));
        parameters.insert("flag".to_string(), json!(true));
        parameters.insert("nothing".to_string(), json!(null));

        assert_eq!(
            substitute_variables("${list}", &parameters, VariableSyntax::DollarBrace).unwrap(),
            "[1,2]"
        );
        assert_eq!(
            substitute_variables("@{flag}/@{nothing}", &parameters, VariableSyntax::AtBrace)
                .unwrap(),
            "true/null"
        );
    }

    #[test]
    fn test_template_without_placeholders_is_unchanged() {
        let parameters = HashMap::new();

        let result =
            substitute_variables("plain text", &parameters, VariableSyntax::DoubleBrace).unwrap();

        assert_eq!(result, "plain text");
    }

    #[test]
    fn test_extract_variables() {
        let template = "{{var1}} and {{var2}} and {{var1}} again";
        let variables = extract_variables(template, VariableSyntax::DoubleBrace).unwrap();

        assert_eq!(variables, vec!["var1", "var2"]);
    }

    #[test]
    fn test_extract_value_variables_scans_keys_and_values() {
        let value = json!({ "{{k}}": ["{{b}}", "{{a}}", 3], "plain": "{{a}}" });

        let variables = extract_value_variables(&value).unwrap();

        assert_eq!(variables, vec!["a", "b", "k"]);
    }

    #[test]
    fn test_apply_parameters_to_value_resolves_nested_json() {
        let mut parameters = HashMap::new();
        parameters.insert("name".to_string(), json!("test"));

        let value = json!({ "text": "{{name}}", "n": 5, "list": ["{{name}}", false] });
        let resolved = apply_parameters_to_value(&value, &parameters).unwrap();

        assert_eq!(resolved, json!({ "text": "test", "n": 5, "list": ["test", false] }));
    }

    #[test]
    fn test_missing_parameter() {
        let parameters = HashMap::new();
        let template = "Hello {{missing}}";

        let result = substitute_variables(template, &parameters, VariableSyntax::DoubleBrace);
        assert!(result.is_err());
    }
}
|