//! WebSocket client for ComfyUI real-time events use std::sync::Arc; use tokio::sync::{RwLock, mpsc}; use tokio_tungstenite::{connect_async, tungstenite::Message}; use futures_util::StreamExt; use url::Url; use uuid::Uuid; use chrono::Utc; use crate::types::{ ComfyUIClientConfig, WSMessage, ExecutionProgress, ExecutionResult, ExecutionError, ExecutionCallbacks }; use crate::error::{ComfyUIError, Result}; use crate::utils::event_emitter::EventEmitter; /// WebSocket client for real-time ComfyUI events pub struct WebSocketClient { config: ComfyUIClientConfig, client_id: String, event_emitter: EventEmitter, is_connected: Arc>, shutdown_tx: Option>, } impl WebSocketClient { /// Creates a new WebSocket client pub fn new(config: ComfyUIClientConfig) -> Self { Self { config, client_id: Uuid::new_v4().to_string(), event_emitter: EventEmitter::new(), is_connected: Arc::new(RwLock::new(false)), shutdown_tx: None, } } /// Gets the client ID pub fn client_id(&self) -> &str { &self.client_id } /// Checks if connected pub async fn is_connected(&self) -> bool { *self.is_connected.read().await } /// Connects to the WebSocket server pub async fn connect(&mut self) -> Result<()> { if self.is_connected().await { return Ok(()); } let ws_url = self.build_websocket_url()?; log::info!("Connecting to WebSocket: {}", ws_url); let (ws_stream, _) = connect_async(&ws_url).await?; let (_ws_sender, mut ws_receiver) = ws_stream.split(); // Set connected state *self.is_connected.write().await = true; // Create shutdown channel let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel(); self.shutdown_tx = Some(shutdown_tx); // Clone necessary data for the message handling task let event_emitter = self.event_emitter.clone(); let is_connected = self.is_connected.clone(); // Spawn message handling task tokio::spawn(async move { loop { tokio::select! { // Handle incoming messages message = ws_receiver.next() => { match message { Some(Ok(Message::Text(text))) => { if let Err(e) = Self::handle_message(&text, &event_emitter).await { log::error!("Error handling WebSocket message: {}", e); } } Some(Ok(Message::Close(_))) => { log::info!("WebSocket connection closed by server"); break; } Some(Err(e)) => { log::error!("WebSocket error: {}", e); break; } None => { log::info!("WebSocket stream ended"); break; } _ => {} // Ignore other message types } } // Handle shutdown signal _ = shutdown_rx.recv() => { log::info!("WebSocket shutdown requested"); break; } } } // Set disconnected state *is_connected.write().await = false; log::info!("WebSocket connection closed"); }); Ok(()) } /// Disconnects from the WebSocket server pub async fn disconnect(&mut self) -> Result<()> { if !self.is_connected().await { return Ok(()); } if let Some(shutdown_tx) = &self.shutdown_tx { let _ = shutdown_tx.send(()); } // Wait for disconnection let mut attempts = 0; while self.is_connected().await && attempts < 50 { tokio::time::sleep(std::time::Duration::from_millis(100)).await; attempts += 1; } self.shutdown_tx = None; Ok(()) } /// Registers execution callbacks pub async fn register_callbacks(&self, callbacks: Arc) -> String { let callback_id = Uuid::new_v4().to_string(); self.event_emitter.register_callbacks(callback_id.clone(), callbacks).await; callback_id } /// Unregisters execution callbacks pub async fn unregister_callbacks(&self, callback_id: &str) { self.event_emitter.unregister_callbacks(callback_id).await; } /// Builds the WebSocket URL fn build_websocket_url(&self) -> Result { let base_url = Url::parse(&self.config.base_url)?; let scheme = match base_url.scheme() { "http" => "ws", "https" => "wss", _ => return Err(ComfyUIError::new("Invalid URL scheme")), }; let ws_url = format!( "{}://{}:{}/ws?clientId={}", scheme, base_url.host_str().unwrap_or("localhost"), base_url.port().unwrap_or(8188), self.client_id ); Ok(ws_url) } /// Handles incoming WebSocket messages async fn handle_message(text: &str, event_emitter: &EventEmitter) -> Result<()> { let message: WSMessage = serde_json::from_str(text) .map_err(|e| ComfyUIError::new(format!("Failed to parse WebSocket message: {}", e)))?; match message { WSMessage::Progress { data } => { let progress = ExecutionProgress { node_id: data.node.unwrap_or_else(|| "unknown".to_string()), progress: data.value, max: data.max, timestamp: Utc::now(), }; event_emitter.emit_progress(progress).await; } WSMessage::Executing { data } => { if let Some(node_id) = data.node { event_emitter.emit_executing(node_id).await; } } WSMessage::Executed { data } => { let result = ExecutionResult { prompt_id: data.prompt_id, outputs: data.output, execution_time: 0, // WebSocket doesn't provide execution time timestamp: Utc::now(), }; event_emitter.emit_executed(result).await; } WSMessage::ExecutionError { data } => { let error = ExecutionError { node_id: data.node_id, message: data.message, details: data.details, timestamp: Utc::now(), }; event_emitter.emit_error(error).await; } WSMessage::Unknown => { log::debug!("Received unknown WebSocket message: {}", text); } } Ok(()) } /// Gets the event emitter for direct access pub fn event_emitter(&self) -> &EventEmitter { &self.event_emitter } } impl Drop for WebSocketClient { fn drop(&mut self) { if let Some(shutdown_tx) = &self.shutdown_tx { let _ = shutdown_tx.send(()); } } }