mixvideo-v2/cargos/comfyui-sdk/utils/event_emitter.rs

192 lines
5.3 KiB
Rust

//! Event emitter utilities
use std::sync::Arc;
use tokio::sync::RwLock;
use std::collections::HashMap;
use crate::types::{ExecutionProgress, ExecutionResult, ExecutionError, ExecutionCallbacks};
/// Fan-out dispatcher that forwards execution events to registered callbacks.
///
/// Both fields are `Arc`-wrapped, so cloning an emitter yields another handle
/// onto the *same* callback registry and the *same* stored error.
#[derive(Clone)]
pub struct EventEmitter {
    /// Registered callback sets, keyed by the caller-supplied ID.
    callbacks: Arc<RwLock<HashMap<String, Arc<dyn ExecutionCallbacks>>>>,
    /// Most recent error passed to `emit_error`, until it is taken or cleared.
    last_error: Arc<RwLock<Option<ExecutionError>>>,
}
impl EventEmitter {
    /// Builds an emitter with an empty callback registry and no stored error.
    pub fn new() -> Self {
        let registry = HashMap::new();
        Self {
            callbacks: Arc::new(RwLock::new(registry)),
            last_error: Arc::new(RwLock::new(None)),
        }
    }

    /// Stores `callbacks` under `id`, replacing any set already registered under that ID.
    pub async fn register_callbacks(&self, id: String, callbacks: Arc<dyn ExecutionCallbacks>) {
        self.callbacks.write().await.insert(id, callbacks);
    }

    /// Removes the callback set registered under `id`; an unknown ID is a no-op.
    pub async fn unregister_callbacks(&self, id: &str) {
        self.callbacks.write().await.remove(id);
    }

    /// Runs `notify` once for each registered callback set.
    ///
    /// The registry's read lock is held until every callback has returned. As a
    /// result, `register_callbacks`, `unregister_callbacks` and `clear` wait for
    /// the dispatch to finish. A callback must not block on one of those calls
    /// for this same emitter, because that would deadlock. Callbacks are visited
    /// in `HashMap` iteration order, which is unspecified.
    async fn dispatch_to_all<F>(&self, notify: F)
    where
        F: Fn(&Arc<dyn ExecutionCallbacks>),
    {
        let registry = self.callbacks.read().await;
        registry.values().for_each(&notify);
    }

    /// Broadcasts `progress` to every registered callback set (one clone per callback).
    pub async fn emit_progress(&self, progress: ExecutionProgress) {
        self.dispatch_to_all(move |cb| cb.on_progress(progress.clone())).await;
    }

    /// Broadcasts the ID of the node now executing to every registered callback set.
    pub async fn emit_executing(&self, node_id: String) {
        self.dispatch_to_all(move |cb| cb.on_executing(node_id.clone())).await;
    }

    /// Broadcasts an execution result to every registered callback set.
    pub async fn emit_executed(&self, result: ExecutionResult) {
        self.dispatch_to_all(move |cb| cb.on_executed(result.clone())).await;
    }

    /// Records `error` as the last error, then broadcasts it to every callback set.
    ///
    /// The stored copy remains available via `get_last_error` until it is taken
    /// or cleared. The write lock on the stored error is released before
    /// dispatching.
    pub async fn emit_error(&self, error: ExecutionError) {
        let stored = Some(error.clone());
        *self.last_error.write().await = stored;
        self.dispatch_to_all(move |cb| cb.on_error(error.clone())).await;
    }

    /// Returns how many callback sets are currently registered.
    pub async fn callback_count(&self) -> usize {
        self.callbacks.read().await.len()
    }

    /// Unregisters every callback set at once.
    pub async fn clear(&self) {
        self.callbacks.write().await.clear();
    }

    /// Takes the stored error, leaving `None` behind.
    ///
    /// Despite the `get_` prefix, this consumes the stored value: a second call
    /// returns `None` unless `emit_error` ran in between.
    pub async fn get_last_error(&self) -> Option<ExecutionError> {
        self.last_error.write().await.take()
    }

    /// Discards the stored error, if any.
    pub async fn clear_last_error(&self) {
        *self.last_error.write().await = None;
    }
}
impl Default for EventEmitter {
fn default() -> Self {
Self::new()
}
}
/// Closure-backed [`ExecutionCallbacks`] implementation, intended for testing.
///
/// Each field holds an optional handler for one event kind. When a field is
/// `None`, the corresponding event is ignored.
pub struct SimpleCallbacks {
    /// Handler invoked by `on_progress`.
    pub on_progress_fn: Option<Box<dyn Fn(ExecutionProgress) + Send + Sync>>,
    /// Handler invoked by `on_executing`; receives the node ID string.
    pub on_executing_fn: Option<Box<dyn Fn(String) + Send + Sync>>,
    /// Handler invoked by `on_executed`.
    pub on_executed_fn: Option<Box<dyn Fn(ExecutionResult) + Send + Sync>>,
    /// Handler invoked by `on_error`.
    pub on_error_fn: Option<Box<dyn Fn(ExecutionError) + Send + Sync>>,
}
impl SimpleCallbacks {
    /// Creates a callback set with no handlers attached.
    ///
    /// Every event is ignored until a handler is attached with one of the
    /// `with_*` builder methods.
    #[must_use]
    pub fn new() -> Self {
        Self {
            on_progress_fn: None,
            on_executing_fn: None,
            on_executed_fn: None,
            on_error_fn: None,
        }
    }

    /// Attaches the handler run for progress events, replacing any previous one.
    // `#[must_use]`: the builder consumes `self`; ignoring the result drops the handler.
    #[must_use = "builder methods return the updated callbacks instead of modifying in place"]
    pub fn with_progress<F>(mut self, f: F) -> Self
    where
        F: Fn(ExecutionProgress) + Send + Sync + 'static,
    {
        self.on_progress_fn = Some(Box::new(f));
        self
    }

    /// Attaches the handler run for executing events (receives the node ID),
    /// replacing any previous one.
    #[must_use = "builder methods return the updated callbacks instead of modifying in place"]
    pub fn with_executing<F>(mut self, f: F) -> Self
    where
        F: Fn(String) + Send + Sync + 'static,
    {
        self.on_executing_fn = Some(Box::new(f));
        self
    }

    /// Attaches the handler run for executed events, replacing any previous one.
    #[must_use = "builder methods return the updated callbacks instead of modifying in place"]
    pub fn with_executed<F>(mut self, f: F) -> Self
    where
        F: Fn(ExecutionResult) + Send + Sync + 'static,
    {
        self.on_executed_fn = Some(Box::new(f));
        self
    }

    /// Attaches the handler run for error events, replacing any previous one.
    #[must_use = "builder methods return the updated callbacks instead of modifying in place"]
    pub fn with_error<F>(mut self, f: F) -> Self
    where
        F: Fn(ExecutionError) + Send + Sync + 'static,
    {
        self.on_error_fn = Some(Box::new(f));
        self
    }
}
impl ExecutionCallbacks for SimpleCallbacks {
    /// Forwards the update to `on_progress_fn`; does nothing when it is unset.
    fn on_progress(&self, progress: ExecutionProgress) {
        if let Some(handler) = self.on_progress_fn.as_deref() {
            handler(progress);
        }
    }

    /// Forwards the node ID to `on_executing_fn`; does nothing when it is unset.
    fn on_executing(&self, node_id: String) {
        if let Some(handler) = self.on_executing_fn.as_deref() {
            handler(node_id);
        }
    }

    /// Forwards the result to `on_executed_fn`; does nothing when it is unset.
    fn on_executed(&self, result: ExecutionResult) {
        if let Some(handler) = self.on_executed_fn.as_deref() {
            handler(result);
        }
    }

    /// Forwards the error to `on_error_fn`; does nothing when it is unset.
    fn on_error(&self, error: ExecutionError) {
        if let Some(handler) = self.on_error_fn.as_deref() {
            handler(error);
        }
    }
}
impl Default for SimpleCallbacks {
fn default() -> Self {
Self::new()
}
}