272 lines
9.2 KiB
Rust
272 lines
9.2 KiB
Rust
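//! ComfyUI SDK Examples
//!
//! This example demonstrates how to use the ComfyUI SDK for Rust: registering
//! and executing a built-in workflow template against a ComfyUI server,
//! validating template parameters, and querying server status. The server
//! address is hard-coded in each example's `base_url`; change it to point at
//! your own ComfyUI instance.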
|
|
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use comfyui_sdk::{
|
|
ComfyUIClient, ComfyUIClientConfig, ExecutionOptions,
|
|
AI_MODEL_FACE_HAIR_FIX_TEMPLATE
|
|
};
|
|
use comfyui_sdk::utils::SimpleCallbacks;
|
|
use serde_json::json;
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
// Initialize logging (optional)
|
|
// env_logger::init();
|
|
|
|
println!("🚀 Starting ComfyUI SDK Rust Examples");
|
|
|
|
// Run the AI model face hair fix example
|
|
ai_model_face_hair_fix_example().await?;
|
|
|
|
// Run template validation example
|
|
template_validation_example().await?;
|
|
|
|
println!("✅ All examples completed successfully!");
|
|
Ok(())
|
|
}
|
|
|
|
/// AI Model Face & Hair Detail Fix Example
|
|
async fn ai_model_face_hair_fix_example() -> Result<(), Box<dyn std::error::Error>> {
|
|
println!("\n🎨 AI Model Face & Hair Detail Fix Example");
|
|
println!("===========================================");
|
|
|
|
// Initialize client
|
|
let mut client = ComfyUIClient::new(ComfyUIClientConfig {
|
|
base_url: "http://192.168.0.193:8188".to_string(),
|
|
..Default::default()
|
|
})?;
|
|
|
|
// Connect to ComfyUI server
|
|
println!("📡 Connecting to ComfyUI server...");
|
|
client.connect().await?;
|
|
println!("✅ Connected successfully!");
|
|
|
|
// Register the built-in template
|
|
println!("📝 Registering AI Model Face & Hair Fix template...");
|
|
client.templates().register_from_data(AI_MODEL_FACE_HAIR_FIX_TEMPLATE.clone())?;
|
|
println!("✅ Template registered!");
|
|
|
|
// Get the registered template
|
|
let template = client.templates_ref()
|
|
.get_by_id(&AI_MODEL_FACE_HAIR_FIX_TEMPLATE.metadata.id)
|
|
.ok_or("Failed to get template")?;
|
|
|
|
println!("📋 Template ID: {}", template.id());
|
|
println!("📋 Template Name: {}", template.name());
|
|
if let Some(description) = template.description() {
|
|
println!("📋 Template Description: {}", description);
|
|
}
|
|
|
|
// Display template parameters
|
|
println!("\n🔍 Template Parameters:");
|
|
for (name, schema) in template.parameters() {
|
|
println!(" - {}: {:?} ({})",
|
|
name,
|
|
schema.param_type,
|
|
schema.description.as_deref().unwrap_or("No description")
|
|
);
|
|
}
|
|
|
|
// Create execution callbacks
|
|
let callbacks = Arc::new(
|
|
SimpleCallbacks::new()
|
|
.with_progress(|progress| {
|
|
let percentage = (progress.progress as f64 / progress.max as f64 * 100.0) as u32;
|
|
println!("⏳ Progress: {}% ({}/{}) - Node: {}",
|
|
percentage, progress.progress, progress.max, progress.node_id);
|
|
})
|
|
.with_executing(|node_id| {
|
|
println!("🔄 Executing node: {}", node_id);
|
|
})
|
|
.with_executed(|result| {
|
|
println!("✅ Node executed - Prompt ID: {}", result.prompt_id);
|
|
})
|
|
.with_error(|error| {
|
|
println!("❌ Execution error: {}", error.message);
|
|
})
|
|
);
|
|
|
|
// Execute the template with test parameters
|
|
println!("\n🎨 Executing AI Model Face & Hair Enhancement...");
|
|
let mut parameters = HashMap::new();
|
|
parameters.insert("input_image".to_string(), json!("20250806-190606.jpg"));
|
|
parameters.insert("face_prompt".to_string(), json!("beautiful woman, perfect skin, detailed facial features"));
|
|
parameters.insert("face_denoise".to_string(), json!("0.25"));
|
|
|
|
let execution_options = ExecutionOptions {
|
|
timeout: Some(std::time::Duration::from_secs(120)), // 2 minutes timeout
|
|
priority: None,
|
|
};
|
|
|
|
let result = client.execute_template_with_callbacks(
|
|
template,
|
|
parameters,
|
|
execution_options,
|
|
callbacks,
|
|
).await?;
|
|
|
|
if result.success {
|
|
println!("\n🎉 AI Model Enhancement completed successfully!");
|
|
println!("📊 Execution time: {}ms", result.execution_time);
|
|
println!("🆔 Prompt ID: {}", result.prompt_id);
|
|
|
|
if let Some(outputs) = &result.outputs {
|
|
println!("📁 Outputs: {}", serde_json::to_string_pretty(outputs)?);
|
|
|
|
// Convert outputs to HTTP URLs
|
|
let image_urls = client.outputs_to_urls(outputs);
|
|
if !image_urls.is_empty() {
|
|
println!("🖼️ Enhanced Image URLs:");
|
|
for url in image_urls {
|
|
println!(" - {}", url);
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
println!("❌ AI Model Enhancement failed");
|
|
if let Some(error) = &result.error {
|
|
println!(" Error: {}", error.message);
|
|
}
|
|
}
|
|
|
|
// Disconnect
|
|
println!("\n🔌 Disconnecting...");
|
|
client.disconnect().await?;
|
|
println!("👋 Disconnected!");
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Template validation example
|
|
async fn template_validation_example() -> Result<(), Box<dyn std::error::Error>> {
|
|
println!("\n🔍 Template Validation Example");
|
|
println!("==============================");
|
|
|
|
let mut client = ComfyUIClient::new(ComfyUIClientConfig {
|
|
base_url: "http://192.168.0.193:8188".to_string(),
|
|
..Default::default()
|
|
})?;
|
|
|
|
// Register template
|
|
client.templates().register_from_data(AI_MODEL_FACE_HAIR_FIX_TEMPLATE.clone())?;
|
|
let template = client.templates_ref()
|
|
.get_by_id(&AI_MODEL_FACE_HAIR_FIX_TEMPLATE.metadata.id)
|
|
.ok_or("Failed to get template")?;
|
|
|
|
println!("📝 Testing parameter validation...");
|
|
|
|
// Test valid parameters
|
|
let mut valid_params = HashMap::new();
|
|
valid_params.insert("input_image".to_string(), json!("test.jpg"));
|
|
valid_params.insert("face_prompt".to_string(), json!("beautiful face"));
|
|
valid_params.insert("face_denoise".to_string(), json!("0.3"));
|
|
|
|
let validation_result = template.validate(&valid_params);
|
|
println!("✅ Valid parameters test: {}",
|
|
if validation_result.valid { "PASSED" } else { "FAILED" });
|
|
|
|
if !validation_result.valid {
|
|
println!("❌ Validation errors:");
|
|
for error in &validation_result.errors {
|
|
println!(" - {}: {}", error.path, error.message);
|
|
}
|
|
}
|
|
|
|
// Test invalid parameters
|
|
let mut invalid_params = HashMap::new();
|
|
invalid_params.insert("input_image".to_string(), json!("")); // Empty required field
|
|
invalid_params.insert("face_prompt".to_string(), json!("test"));
|
|
invalid_params.insert("unknown_param".to_string(), json!("value")); // Unknown parameter
|
|
|
|
let invalid_validation = template.validate(&invalid_params);
|
|
println!("🚫 Invalid parameters test: {}",
|
|
if !invalid_validation.valid { "PASSED" } else { "FAILED" });
|
|
|
|
if !invalid_validation.valid {
|
|
println!("📋 Expected validation errors:");
|
|
for error in &invalid_validation.errors {
|
|
println!(" - {}: {}", error.path, error.message);
|
|
}
|
|
}
|
|
|
|
// Test template instance creation
|
|
println!("\n🏗️ Testing template instance creation...");
|
|
match template.create_instance(valid_params) {
|
|
Ok(instance) => {
|
|
println!("✅ Template instance created successfully");
|
|
println!(" Template ID: {}", instance.template_id());
|
|
println!(" Template Name: {}", instance.template_name());
|
|
println!(" Node count: {}", instance.node_count());
|
|
}
|
|
Err(e) => {
|
|
println!("❌ Failed to create template instance: {}", e);
|
|
}
|
|
}
|
|
|
|
// Test template manager features
|
|
println!("\n📚 Testing template manager features...");
|
|
let templates = client.templates_ref();
|
|
println!(" Total templates: {}", templates.count());
|
|
println!(" Template IDs: {:?}", templates.list_ids());
|
|
|
|
let categories = templates.get_categories();
|
|
if !categories.is_empty() {
|
|
println!(" Categories: {:?}", categories);
|
|
}
|
|
|
|
let tags = templates.get_tags();
|
|
if !tags.is_empty() {
|
|
println!(" Tags: {:?}", tags);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Basic usage example (simpler version)
|
|
#[allow(dead_code)]
|
|
async fn basic_usage_example() -> Result<(), Box<dyn std::error::Error>> {
|
|
println!("\n📖 Basic Usage Example");
|
|
println!("======================");
|
|
|
|
// Create client
|
|
let mut client = ComfyUIClient::new(ComfyUIClientConfig {
|
|
base_url: "http://localhost:8188".to_string(),
|
|
..Default::default()
|
|
})?;
|
|
|
|
// Connect
|
|
client.connect().await?;
|
|
println!("✅ Connected to ComfyUI!");
|
|
|
|
// Get system stats
|
|
match client.get_system_stats().await {
|
|
Ok(stats) => {
|
|
println!("💻 System Info:");
|
|
println!(" OS: {}", stats.system.os);
|
|
println!(" Python: {}", stats.system.python_version);
|
|
println!(" Devices: {}", stats.devices.len());
|
|
}
|
|
Err(e) => {
|
|
println!("❌ Failed to get system stats: {}", e);
|
|
}
|
|
}
|
|
|
|
// Get queue status
|
|
match client.get_queue_status().await {
|
|
Ok(queue) => {
|
|
println!("📋 Queue Status:");
|
|
println!(" Running: {}", queue.queue_running.len());
|
|
println!(" Pending: {}", queue.queue_pending.len());
|
|
}
|
|
Err(e) => {
|
|
println!("❌ Failed to get queue status: {}", e);
|
|
}
|
|
}
|
|
|
|
client.disconnect().await?;
|
|
Ok(())
|
|
}
|