//! Main ComfyUI client that combines HTTP and WebSocket functionality use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::time::sleep; use crate::types::{ ComfyUIClientConfig, PromptRequest, TemplateExecutionResult, ExecutionOptions, ExecutionCallbacks, ExecutionError, ParameterValues }; use crate::templates::{TemplateManager, WorkflowTemplate, WorkflowInstance}; use crate::client::{HTTPClient, WebSocketClient}; use crate::error::{ComfyUIError, Result}; /// Main ComfyUI client pub struct ComfyUIClient { http_client: HTTPClient, ws_client: WebSocketClient, template_manager: TemplateManager, config: ComfyUIClientConfig, } impl ComfyUIClient { /// Creates a new ComfyUI client pub fn new(config: ComfyUIClientConfig) -> Result { let http_client = HTTPClient::new(config.clone())?; let ws_client = WebSocketClient::new(config.clone()); let template_manager = TemplateManager::new(); Ok(Self { http_client, ws_client, template_manager, config, }) } /// Gets the HTTP client pub fn http(&self) -> &HTTPClient { &self.http_client } /// Gets the WebSocket client pub fn ws(&self) -> &WebSocketClient { &self.ws_client } /// Gets the template manager pub fn templates(&mut self) -> &mut TemplateManager { &mut self.template_manager } /// Gets the template manager (read-only) pub fn templates_ref(&self) -> &TemplateManager { &self.template_manager } /// Gets the client configuration pub fn config(&self) -> &ComfyUIClientConfig { &self.config } /// Connects to ComfyUI server (both HTTP and WebSocket) pub async fn connect(&mut self) -> Result<()> { // Test HTTP connection first self.http_client.get_queue().await .map_err(|e| ComfyUIError::connection(format!("HTTP connection failed: {}", e)))?; // Connect WebSocket self.ws_client.connect().await .map_err(|e| ComfyUIError::connection(format!("WebSocket connection failed: {}", e)))?; log::info!("Successfully connected to ComfyUI server"); Ok(()) } /// Disconnects from ComfyUI server pub async fn disconnect(&mut self) -> Result<()> { self.ws_client.disconnect().await?; log::info!("Disconnected from ComfyUI server"); Ok(()) } /// Checks if connected to the server pub async fn is_connected(&self) -> bool { self.ws_client.is_connected().await } /// Executes a workflow template pub async fn execute_template( &self, template: &WorkflowTemplate, parameters: ParameterValues, options: ExecutionOptions, ) -> Result { let start_time = Instant::now(); // Create workflow instance let instance = template.create_instance(parameters)?; // Convert to prompt request let prompt_request = self.instance_to_prompt_request(&instance)?; // Submit prompt let prompt_response = self.http_client.queue_prompt(&prompt_request).await?; let prompt_id = prompt_response.prompt_id.clone(); log::info!("Submitted prompt with ID: {}", prompt_id); // Wait for completion if WebSocket is connected if self.ws_client.is_connected().await { match self.wait_for_completion(&prompt_id, options.timeout).await { Ok(outputs) => { let execution_time = start_time.elapsed().as_millis() as u64; Ok(TemplateExecutionResult::success(prompt_id, outputs, execution_time)) } Err(error) => { let execution_time = start_time.elapsed().as_millis() as u64; let exec_error = ExecutionError { node_id: None, message: error.to_string(), details: None, timestamp: chrono::Utc::now(), }; Ok(TemplateExecutionResult::failure(prompt_id, exec_error, execution_time)) } } } else { // If no WebSocket, just return the prompt ID let execution_time = start_time.elapsed().as_millis() as u64; Ok(TemplateExecutionResult::success(prompt_id, HashMap::new(), execution_time)) } } /// Executes a template with callbacks pub async fn execute_template_with_callbacks( &self, template: &WorkflowTemplate, parameters: ParameterValues, options: ExecutionOptions, callbacks: Arc, ) -> Result { // Register callbacks let callback_id = self.ws_client.register_callbacks(callbacks).await; // Execute template let result = self.execute_template(template, parameters, options).await; // Unregister callbacks self.ws_client.unregister_callbacks(&callback_id).await; result } /// Converts a workflow instance to a prompt request fn instance_to_prompt_request(&self, instance: &WorkflowInstance) -> Result { let workflow_json = instance.to_workflow_json()?; let mut prompt = HashMap::new(); if let serde_json::Value::Object(workflow_obj) = workflow_json { for (key, value) in workflow_obj { prompt.insert(key, value); } } Ok(PromptRequest { prompt, client_id: Some(self.ws_client.client_id().to_string()), extra_data: None, }) } /// Waits for prompt completion async fn wait_for_completion( &self, prompt_id: &str, timeout: Option, ) -> Result> { let timeout_duration = timeout.unwrap_or(Duration::from_secs(300)); // 5 minutes default let start_time = Instant::now(); let check_interval = Duration::from_millis(1000); // Check every second loop { // Check timeout if start_time.elapsed() > timeout_duration { return Err(ComfyUIError::timeout(format!( "Execution timed out after {:?}", timeout_duration ))); } // Check history for completion match self.http_client.get_history_by_prompt(prompt_id).await { Ok(history) => { if let Some(history_item) = history.get(prompt_id) { if history_item.status.completed { // Extract outputs let mut outputs = HashMap::new(); for (node_id, node_outputs) in &history_item.outputs { let mut node_output = HashMap::new(); for (output_name, output_list) in node_outputs { node_output.insert(output_name.clone(), serde_json::to_value(output_list)?); } outputs.insert(node_id.clone(), serde_json::to_value(node_output)?); } return Ok(outputs); } } } Err(e) => { log::warn!("Error checking history: {}", e); } } sleep(check_interval).await; } } /// Converts outputs to HTTP URLs pub fn outputs_to_urls(&self, outputs: &HashMap) -> Vec { self.http_client.outputs_to_urls(outputs) } /// Interrupts the current execution pub async fn interrupt(&self) -> Result<()> { self.http_client.interrupt().await } /// Clears the execution queue pub async fn clear_queue(&self) -> Result<()> { self.http_client.clear_queue().await } /// Gets the current queue status pub async fn get_queue_status(&self) -> Result { self.http_client.get_queue().await } /// Gets system statistics pub async fn get_system_stats(&self) -> Result { self.http_client.get_system_stats().await } /// Gets available nodes information pub async fn get_object_info(&self) -> Result { self.http_client.get_object_info().await } /// Uploads an image file pub async fn upload_image>( &self, image_path: P, overwrite: bool, ) -> Result { self.http_client.upload_image(image_path, overwrite).await } /// Gets an image by filename pub async fn get_image( &self, filename: &str, subfolder: Option<&str>, image_type: Option<&str>, ) -> Result { self.http_client.get_image(filename, subfolder, image_type).await } /// Frees memory pub async fn free_memory(&self) -> Result<()> { self.http_client.free_memory().await } }