mixvideo-v2/cargos/comfyui-sdk/utils/error_handling.rs

216 lines
6.3 KiB
Rust

//! Error handling utilities
use crate::error::{ComfyUIError, Result};
use std::time::Duration;
use tokio::time::sleep;
/// Retry configuration
///
/// Controls how many times an operation is attempted and how long to wait
/// between attempts when passed to the retry helpers in this module.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of attempts, including the first one.
    pub max_attempts: u32,
    /// Delay before the second attempt.
    pub initial_delay: Duration,
    /// Cap applied each time the delay is grown by the backoff multiplier.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each failed attempt that will be retried.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three attempts, starting at one second and doubling, capped at 30 seconds.
    fn default() -> Self {
        let initial_delay = Duration::from_secs(1);
        let max_delay = Duration::from_secs(30);
        RetryConfig {
            max_attempts: 3,
            initial_delay,
            max_delay,
            backoff_multiplier: 2.0,
        }
    }
}
/// Retries an async operation with exponential backoff.
///
/// `operation` is invoked up to `config.max_attempts` times. After each failed
/// attempt except the last, the task sleeps for the current delay, then the
/// delay is multiplied by `config.backoff_multiplier` and capped at
/// `config.max_delay`. The first delay is `config.initial_delay`, also capped
/// at `config.max_delay`.
///
/// Every error is treated as retryable; [`retry_if_retryable`] is the variant
/// that gives up immediately on non-retryable errors.
///
/// # Errors
///
/// Returns the error from the final attempt once all attempts are exhausted.
/// If `config.max_attempts` is `0`, the operation is never called and a
/// generic "Retry failed with no error" error is returned.
pub async fn retry_with_backoff<F, Fut, T>(
    operation: F,
    config: RetryConfig,
) -> Result<T>
where
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<T>>,
{
    let mut last_error = None;
    // Never wait longer than `max_delay`, even before the first backoff step.
    let mut delay = config.initial_delay.min(config.max_delay);
    for attempt in 1..=config.max_attempts {
        match operation().await {
            Ok(result) => return Ok(result),
            Err(error) => {
                last_error = Some(error);
                if attempt < config.max_attempts {
                    log::warn!("Attempt {attempt} failed, retrying in {delay:?}");
                    sleep(delay).await;
                    // Exponential backoff computed in floating-point seconds so
                    // sub-millisecond delays are not truncated to zero. A result
                    // that is negative, NaN or too large for `Duration` falls
                    // back to `max_delay`.
                    delay = Duration::try_from_secs_f64(
                        delay.as_secs_f64() * config.backoff_multiplier,
                    )
                    .map_or(config.max_delay, |next| next.min(config.max_delay));
                }
            }
        }
    }
    Err(last_error.unwrap_or_else(|| ComfyUIError::new("Retry failed with no error")))
}
/// Checks if an error is retryable
pub fn is_retryable_error(error: &ComfyUIError) -> bool {
match error {
ComfyUIError::Http(reqwest_error) => {
// Retry on network errors, timeouts, and 5xx status codes
reqwest_error.is_timeout() ||
reqwest_error.is_connect() ||
reqwest_error.status().is_some_and(|status| status.is_server_error())
}
ComfyUIError::WebSocket(_) => true, // Most WebSocket errors are retryable
ComfyUIError::Connection(_) => true,
ComfyUIError::Timeout(_) => true,
ComfyUIError::Io(_) => true,
_ => false, // Don't retry validation errors, etc.
}
}
/// Retries an operation only if the error is retryable.
///
/// `operation` is invoked up to `config.max_attempts` times. An error for which
/// [`is_retryable_error`] returns `false` is returned immediately, without
/// further attempts. After each retryable failure except the last, the task
/// sleeps for the current delay, then the delay is multiplied by
/// `config.backoff_multiplier` and capped at `config.max_delay`. The first
/// delay is `config.initial_delay`, also capped at `config.max_delay`.
///
/// # Errors
///
/// - The first non-retryable error, as soon as it occurs.
/// - The last retryable error once `config.max_attempts` attempts have failed.
/// - A generic "Retry failed with no error" error if `config.max_attempts` is `0`.
pub async fn retry_if_retryable<F, Fut, T>(
    operation: F,
    config: RetryConfig,
) -> Result<T>
where
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<T>>,
{
    let mut last_error = None;
    // Never wait longer than `max_delay`, even before the first backoff step.
    let mut delay = config.initial_delay.min(config.max_delay);
    for attempt in 1..=config.max_attempts {
        match operation().await {
            Ok(result) => return Ok(result),
            Err(error) => {
                if !is_retryable_error(&error) {
                    return Err(error);
                }
                last_error = Some(error);
                if attempt < config.max_attempts {
                    log::warn!("Attempt {attempt} failed with retryable error, retrying in {delay:?}");
                    sleep(delay).await;
                    // Exponential backoff computed in floating-point seconds so
                    // sub-millisecond delays are not truncated to zero. A result
                    // that is negative, NaN or too large for `Duration` falls
                    // back to `max_delay`.
                    delay = Duration::try_from_secs_f64(
                        delay.as_secs_f64() * config.backoff_multiplier,
                    )
                    .map_or(config.max_delay, |next| next.min(config.max_delay));
                }
            }
        }
    }
    Err(last_error.unwrap_or_else(|| ComfyUIError::new("Retry failed with no error")))
}
/// Wraps an operation with timeout
///
/// Awaits `operation`, giving up once `timeout` has elapsed.
///
/// # Errors
///
/// Returns the operation's own error if it fails before the deadline, or an
/// error built with `ComfyUIError::timeout` when the deadline passes first.
pub async fn with_timeout<F, T>(
    operation: F,
    timeout: Duration,
) -> Result<T>
where
    F: std::future::Future<Output = Result<T>>,
{
    let outcome = tokio::time::timeout(timeout, operation).await;
    outcome.unwrap_or_else(|_elapsed| {
        Err(ComfyUIError::timeout(format!(
            "Operation timed out after {timeout:?}"
        )))
    })
}
/// Error context helper
///
/// Prefixes an error's message with a caller-supplied description.
pub trait ErrorContext<T> {
    /// On `Err`, replaces the error with a new `ComfyUIError` whose message is
    /// `"{context}: {original error}"`; `Ok` values pass through untouched.
    fn with_context(self, context: &str) -> Result<T>;
}

impl<T> ErrorContext<T> for Result<T> {
    fn with_context(self, context: &str) -> Result<T> {
        // NOTE(review): only the Display text of the original error survives;
        // its variant is not carried over, so retry classification of the
        // wrapped error depends on whatever variant `ComfyUIError::new` builds.
        match self {
            Ok(value) => Ok(value),
            Err(original) => Err(ComfyUIError::new(format!("{context}: {original}"))),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds a config with tiny delays so the retry tests stay fast.
    fn quick_config(max_attempts: u32) -> RetryConfig {
        RetryConfig {
            max_attempts,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
        }
    }

    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_op = Arc::clone(&calls);

        let outcome = retry_with_backoff(
            move || {
                let calls = Arc::clone(&calls_for_op);
                async move {
                    // Fail only on the very first call.
                    match calls.fetch_add(1, Ordering::SeqCst) {
                        0 => Err(ComfyUIError::connection("First attempt fails")),
                        _ => Ok("Success"),
                    }
                }
            },
            quick_config(3),
        )
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "Success");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_op = Arc::clone(&calls);

        let outcome: Result<&str> = retry_with_backoff(
            move || {
                let calls = Arc::clone(&calls_for_op);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err(ComfyUIError::connection("Always fails"))
                }
            },
            quick_config(2),
        )
        .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}