//! Error handling utilities use crate::error::{ComfyUIError, Result}; use std::time::Duration; use tokio::time::sleep; /// Retry configuration #[derive(Debug, Clone)] pub struct RetryConfig { pub max_attempts: u32, pub initial_delay: Duration, pub max_delay: Duration, pub backoff_multiplier: f64, } impl Default for RetryConfig { fn default() -> Self { Self { max_attempts: 3, initial_delay: Duration::from_millis(1000), max_delay: Duration::from_secs(30), backoff_multiplier: 2.0, } } } /// Retries an async operation with exponential backoff pub async fn retry_with_backoff( operation: F, config: RetryConfig, ) -> Result where F: Fn() -> Fut, Fut: std::future::Future>, { let mut last_error = None; let mut delay = config.initial_delay; for attempt in 1..=config.max_attempts { match operation().await { Ok(result) => return Ok(result), Err(error) => { last_error = Some(error); if attempt < config.max_attempts { log::warn!("Attempt {attempt} failed, retrying in {delay:?}"); sleep(delay).await; // Exponential backoff delay = std::cmp::min( Duration::from_millis( (delay.as_millis() as f64 * config.backoff_multiplier) as u64 ), config.max_delay, ); } } } } Err(last_error.unwrap_or_else(|| ComfyUIError::new("Retry failed with no error"))) } /// Checks if an error is retryable pub fn is_retryable_error(error: &ComfyUIError) -> bool { match error { ComfyUIError::Http(reqwest_error) => { // Retry on network errors, timeouts, and 5xx status codes reqwest_error.is_timeout() || reqwest_error.is_connect() || reqwest_error.status().is_some_and(|status| status.is_server_error()) } ComfyUIError::WebSocket(_) => true, // Most WebSocket errors are retryable ComfyUIError::Connection(_) => true, ComfyUIError::Timeout(_) => true, ComfyUIError::Io(_) => true, _ => false, // Don't retry validation errors, etc. } } /// Retries an operation only if the error is retryable pub async fn retry_if_retryable( operation: F, config: RetryConfig, ) -> Result where F: Fn() -> Fut, Fut: std::future::Future>, { let mut last_error = None; let mut delay = config.initial_delay; for attempt in 1..=config.max_attempts { match operation().await { Ok(result) => return Ok(result), Err(error) => { if !is_retryable_error(&error) { return Err(error); } last_error = Some(error); if attempt < config.max_attempts { log::warn!("Attempt {attempt} failed with retryable error, retrying in {delay:?}"); sleep(delay).await; // Exponential backoff delay = std::cmp::min( Duration::from_millis( (delay.as_millis() as f64 * config.backoff_multiplier) as u64 ), config.max_delay, ); } } } } Err(last_error.unwrap_or_else(|| ComfyUIError::new("Retry failed with no error"))) } /// Wraps an operation with timeout pub async fn with_timeout( operation: F, timeout: Duration, ) -> Result where F: std::future::Future>, { match tokio::time::timeout(timeout, operation).await { Ok(result) => result, Err(_) => Err(ComfyUIError::timeout(format!( "Operation timed out after {timeout:?}" ))), } } /// Error context helper pub trait ErrorContext { fn with_context(self, context: &str) -> Result; } impl ErrorContext for Result { fn with_context(self, context: &str) -> Result { self.map_err(|error| { ComfyUIError::new(format!("{context}: {error}")) }) } } #[cfg(test)] mod tests { use super::*; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; #[tokio::test] async fn test_retry_success_on_second_attempt() { let counter = Arc::new(AtomicU32::new(0)); let counter_clone = counter.clone(); let config = RetryConfig { max_attempts: 3, initial_delay: Duration::from_millis(10), max_delay: Duration::from_millis(100), backoff_multiplier: 2.0, }; let result = retry_with_backoff( || { let counter = counter_clone.clone(); async move { let count = counter.fetch_add(1, Ordering::SeqCst); if count == 0 { Err(ComfyUIError::connection("First attempt fails")) } else { Ok("Success") } } }, config, ).await; assert!(result.is_ok()); assert_eq!(result.unwrap(), "Success"); assert_eq!(counter.load(Ordering::SeqCst), 2); } #[tokio::test] async fn test_retry_exhausted() { let counter = Arc::new(AtomicU32::new(0)); let counter_clone = counter.clone(); let config = RetryConfig { max_attempts: 2, initial_delay: Duration::from_millis(10), max_delay: Duration::from_millis(100), backoff_multiplier: 2.0, }; let result: Result<&str> = retry_with_backoff( || { let counter = counter_clone.clone(); async move { counter.fetch_add(1, Ordering::SeqCst); Err(ComfyUIError::connection("Always fails")) } }, config, ).await; assert!(result.is_err()); assert_eq!(counter.load(Ordering::SeqCst), 2); } }