//! WebSocket client for ComfyUI real-time events use std::sync::Arc; use tokio::sync::{RwLock, mpsc}; use tokio_tungstenite::{connect_async, tungstenite::Message}; use futures_util::StreamExt; use url::Url; use uuid::Uuid; use chrono::Utc; use crate::types::{ ComfyUIClientConfig, WSMessage, ExecutionProgress, ExecutionResult, ExecutionError, ExecutionCallbacks }; use crate::error::{ComfyUIError, Result}; use crate::utils::event_emitter::EventEmitter; /// WebSocket client for real-time ComfyUI events pub struct WebSocketClient { config: ComfyUIClientConfig, client_id: String, event_emitter: EventEmitter, is_connected: Arc>, shutdown_tx: Option>, } impl WebSocketClient { /// Creates a new WebSocket client pub fn new(config: ComfyUIClientConfig) -> Self { Self { config, client_id: Uuid::new_v4().to_string(), event_emitter: EventEmitter::new(), is_connected: Arc::new(RwLock::new(false)), shutdown_tx: None, } } /// Gets the client ID pub fn client_id(&self) -> &str { &self.client_id } /// Checks if connected pub async fn is_connected(&self) -> bool { *self.is_connected.read().await } /// Connects to the WebSocket server pub async fn connect(&mut self) -> Result<()> { if self.is_connected().await { return Ok(()); } let ws_url = self.build_websocket_url()?; log::info!("Connecting to WebSocket: {ws_url}"); let (ws_stream, _) = connect_async(&ws_url).await?; let (_ws_sender, mut ws_receiver) = ws_stream.split(); // Set connected state *self.is_connected.write().await = true; // Create shutdown channel let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel(); self.shutdown_tx = Some(shutdown_tx); // Clone necessary data for the message handling task let event_emitter = self.event_emitter.clone(); let is_connected = self.is_connected.clone(); // Spawn message handling task tokio::spawn(async move { loop { tokio::select! { // Handle incoming messages message = ws_receiver.next() => { match message { Some(Ok(Message::Text(text))) => { if let Err(e) = Self::handle_message(&text, &event_emitter).await { log::error!("Error handling WebSocket message: {e}"); } } Some(Ok(Message::Close(_))) => { log::info!("WebSocket connection closed by server"); break; } Some(Err(e)) => { log::error!("WebSocket error: {e}"); break; } None => { log::info!("WebSocket stream ended"); break; } _ => {} // Ignore other message types } } // Handle shutdown signal _ = shutdown_rx.recv() => { log::info!("WebSocket shutdown requested"); break; } } } // Set disconnected state *is_connected.write().await = false; log::info!("WebSocket connection closed"); }); Ok(()) } /// Disconnects from the WebSocket server pub async fn disconnect(&mut self) -> Result<()> { if !self.is_connected().await { return Ok(()); } if let Some(shutdown_tx) = &self.shutdown_tx { let _ = shutdown_tx.send(()); } // Wait for disconnection let mut attempts = 0; while self.is_connected().await && attempts < 50 { tokio::time::sleep(std::time::Duration::from_millis(100)).await; attempts += 1; } self.shutdown_tx = None; Ok(()) } /// Registers execution callbacks pub async fn register_callbacks(&self, callbacks: Arc) -> String { let callback_id = Uuid::new_v4().to_string(); self.event_emitter.register_callbacks(callback_id.clone(), callbacks).await; callback_id } /// Unregisters execution callbacks pub async fn unregister_callbacks(&self, callback_id: &str) { self.event_emitter.unregister_callbacks(callback_id).await; } /// Builds the WebSocket URL fn build_websocket_url(&self) -> Result { let base_url = Url::parse(&self.config.base_url)?; let scheme = match base_url.scheme() { "http" => "ws", "https" => "wss", _ => return Err(ComfyUIError::new("Invalid URL scheme")), }; let ws_url = format!( "{}://{}:{}/ws?clientId={}", scheme, base_url.host_str().unwrap_or("localhost"), base_url.port().unwrap_or(8188), self.client_id ); Ok(ws_url) } /// Handles incoming WebSocket messages async fn handle_message(text: &str, event_emitter: &EventEmitter) -> Result<()> { // 打印原始消息用于调试 log::debug!("Received WebSocket message: {}", text); let message: WSMessage = serde_json::from_str(text) .map_err(|e| { log::error!("Failed to parse WebSocket message: {}", e); log::error!("Raw message content: {}", text); ComfyUIError::new(format!("Failed to parse WebSocket message: {e}")) })?; match message { WSMessage::Progress { data } => { let progress = ExecutionProgress { node_id: data.node.unwrap_or_else(|| "unknown".to_string()), progress: data.value, max: data.max, timestamp: Utc::now(), }; event_emitter.emit_progress(progress).await; } WSMessage::Executing { data } => { if let Some(node_id) = data.node { event_emitter.emit_executing(node_id).await; } } WSMessage::Executed { data } => { let result = ExecutionResult { prompt_id: data.prompt_id, outputs: data.output, execution_time: 0, // WebSocket doesn't provide execution time timestamp: Utc::now(), }; event_emitter.emit_executed(result).await; } WSMessage::ExecutionError { data } => { // 构建错误消息,尝试多个可能的字段 let error_message = data.message .or(data.error) .or(data.exception_message) .or_else(|| { // 尝试从extra字段中提取错误信息 data.extra.get("exception_message") .and_then(|v| v.as_str()) .map(|s| s.to_string()) }) .or_else(|| { data.extra.get("error") .and_then(|v| v.as_str()) .map(|s| s.to_string()) }) .unwrap_or_else(|| "Unknown execution error".to_string()); // 获取节点ID let node_id = data.node_id .or(data.node) .or_else(|| { data.extra.get("node") .and_then(|v| v.as_str()) .map(|s| s.to_string()) }); let error = ExecutionError { node_id, message: error_message, details: data.details, timestamp: Utc::now(), }; log::error!("ComfyUI execution error: {}", error.message); event_emitter.emit_error(error).await; } WSMessage::Unknown => { log::debug!("Received unknown WebSocket message: {text}"); } } Ok(()) } /// Gets the event emitter for direct access pub fn event_emitter(&self) -> &EventEmitter { &self.event_emitter } /// Gets the last execution error and clears it pub async fn get_last_error(&self) -> Option { self.event_emitter.get_last_error().await } /// Clears the last error pub async fn clear_last_error(&self) { self.event_emitter.clear_last_error().await } } impl Drop for WebSocketClient { fn drop(&mut self) { if let Some(shutdown_tx) = &self.shutdown_tx { let _ = shutdown_tx.send(()); } } }