272 lines
9.1 KiB
Rust
272 lines
9.1 KiB
Rust
//! WebSocket client for ComfyUI real-time events
|
|
|
|
use std::sync::Arc;
|
|
use tokio::sync::{RwLock, mpsc};
|
|
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
|
use futures_util::StreamExt;
|
|
use url::Url;
|
|
use uuid::Uuid;
|
|
use chrono::Utc;
|
|
|
|
use crate::types::{
|
|
ComfyUIClientConfig, WSMessage, ExecutionProgress, ExecutionResult,
|
|
ExecutionError, ExecutionCallbacks
|
|
};
|
|
use crate::error::{ComfyUIError, Result};
|
|
use crate::utils::event_emitter::EventEmitter;
|
|
|
|
/// WebSocket client for real-time ComfyUI events
|
|
pub struct WebSocketClient {
|
|
config: ComfyUIClientConfig,
|
|
client_id: String,
|
|
event_emitter: EventEmitter,
|
|
is_connected: Arc<RwLock<bool>>,
|
|
shutdown_tx: Option<mpsc::UnboundedSender<()>>,
|
|
}
|
|
|
|
impl WebSocketClient {
|
|
/// Creates a new WebSocket client
|
|
pub fn new(config: ComfyUIClientConfig) -> Self {
|
|
Self {
|
|
config,
|
|
client_id: Uuid::new_v4().to_string(),
|
|
event_emitter: EventEmitter::new(),
|
|
is_connected: Arc::new(RwLock::new(false)),
|
|
shutdown_tx: None,
|
|
}
|
|
}
|
|
|
|
/// Gets the client ID
|
|
pub fn client_id(&self) -> &str {
|
|
&self.client_id
|
|
}
|
|
|
|
/// Checks if connected
|
|
pub async fn is_connected(&self) -> bool {
|
|
*self.is_connected.read().await
|
|
}
|
|
|
|
/// Connects to the WebSocket server
|
|
pub async fn connect(&mut self) -> Result<()> {
|
|
if self.is_connected().await {
|
|
return Ok(());
|
|
}
|
|
|
|
let ws_url = self.build_websocket_url()?;
|
|
log::info!("Connecting to WebSocket: {ws_url}");
|
|
|
|
let (ws_stream, _) = connect_async(&ws_url).await?;
|
|
let (_ws_sender, mut ws_receiver) = ws_stream.split();
|
|
|
|
// Set connected state
|
|
*self.is_connected.write().await = true;
|
|
|
|
// Create shutdown channel
|
|
let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel();
|
|
self.shutdown_tx = Some(shutdown_tx);
|
|
|
|
// Clone necessary data for the message handling task
|
|
let event_emitter = self.event_emitter.clone();
|
|
let is_connected = self.is_connected.clone();
|
|
|
|
// Spawn message handling task
|
|
tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
// Handle incoming messages
|
|
message = ws_receiver.next() => {
|
|
match message {
|
|
Some(Ok(Message::Text(text))) => {
|
|
if let Err(e) = Self::handle_message(&text, &event_emitter).await {
|
|
log::error!("Error handling WebSocket message: {e}");
|
|
}
|
|
}
|
|
Some(Ok(Message::Close(_))) => {
|
|
log::info!("WebSocket connection closed by server");
|
|
break;
|
|
}
|
|
Some(Err(e)) => {
|
|
log::error!("WebSocket error: {e}");
|
|
break;
|
|
}
|
|
None => {
|
|
log::info!("WebSocket stream ended");
|
|
break;
|
|
}
|
|
_ => {} // Ignore other message types
|
|
}
|
|
}
|
|
// Handle shutdown signal
|
|
_ = shutdown_rx.recv() => {
|
|
log::info!("WebSocket shutdown requested");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Set disconnected state
|
|
*is_connected.write().await = false;
|
|
log::info!("WebSocket connection closed");
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Disconnects from the WebSocket server
|
|
pub async fn disconnect(&mut self) -> Result<()> {
|
|
if !self.is_connected().await {
|
|
return Ok(());
|
|
}
|
|
|
|
if let Some(shutdown_tx) = &self.shutdown_tx {
|
|
let _ = shutdown_tx.send(());
|
|
}
|
|
|
|
// Wait for disconnection
|
|
let mut attempts = 0;
|
|
while self.is_connected().await && attempts < 50 {
|
|
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
|
attempts += 1;
|
|
}
|
|
|
|
self.shutdown_tx = None;
|
|
Ok(())
|
|
}
|
|
|
|
/// Registers execution callbacks
|
|
pub async fn register_callbacks(&self, callbacks: Arc<dyn ExecutionCallbacks>) -> String {
|
|
let callback_id = Uuid::new_v4().to_string();
|
|
self.event_emitter.register_callbacks(callback_id.clone(), callbacks).await;
|
|
callback_id
|
|
}
|
|
|
|
/// Unregisters execution callbacks
|
|
pub async fn unregister_callbacks(&self, callback_id: &str) {
|
|
self.event_emitter.unregister_callbacks(callback_id).await;
|
|
}
|
|
|
|
/// Builds the WebSocket URL
|
|
fn build_websocket_url(&self) -> Result<String> {
|
|
let base_url = Url::parse(&self.config.base_url)?;
|
|
let scheme = match base_url.scheme() {
|
|
"http" => "ws",
|
|
"https" => "wss",
|
|
_ => return Err(ComfyUIError::new("Invalid URL scheme")),
|
|
};
|
|
|
|
let ws_url = format!(
|
|
"{}://{}:{}/ws?clientId={}",
|
|
scheme,
|
|
base_url.host_str().unwrap_or("localhost"),
|
|
base_url.port().unwrap_or(8188),
|
|
self.client_id
|
|
);
|
|
|
|
Ok(ws_url)
|
|
}
|
|
|
|
/// Handles incoming WebSocket messages
|
|
async fn handle_message(text: &str, event_emitter: &EventEmitter) -> Result<()> {
|
|
// 打印原始消息用于调试
|
|
log::debug!("Received WebSocket message: {}", text);
|
|
|
|
let message: WSMessage = serde_json::from_str(text)
|
|
.map_err(|e| {
|
|
log::error!("Failed to parse WebSocket message: {}", e);
|
|
log::error!("Raw message content: {}", text);
|
|
ComfyUIError::new(format!("Failed to parse WebSocket message: {e}"))
|
|
})?;
|
|
|
|
match message {
|
|
WSMessage::Progress { data } => {
|
|
let progress = ExecutionProgress {
|
|
node_id: data.node.unwrap_or_else(|| "unknown".to_string()),
|
|
progress: data.value,
|
|
max: data.max,
|
|
timestamp: Utc::now(),
|
|
};
|
|
event_emitter.emit_progress(progress).await;
|
|
}
|
|
WSMessage::Executing { data } => {
|
|
if let Some(node_id) = data.node {
|
|
event_emitter.emit_executing(node_id).await;
|
|
}
|
|
}
|
|
WSMessage::Executed { data } => {
|
|
let result = ExecutionResult {
|
|
prompt_id: data.prompt_id,
|
|
outputs: data.output,
|
|
execution_time: 0, // WebSocket doesn't provide execution time
|
|
timestamp: Utc::now(),
|
|
};
|
|
event_emitter.emit_executed(result).await;
|
|
}
|
|
WSMessage::ExecutionError { data } => {
|
|
// 构建错误消息,尝试多个可能的字段
|
|
let error_message = data.message
|
|
.or(data.error)
|
|
.or(data.exception_message)
|
|
.or_else(|| {
|
|
// 尝试从extra字段中提取错误信息
|
|
data.extra.get("exception_message")
|
|
.and_then(|v| v.as_str())
|
|
.map(|s| s.to_string())
|
|
})
|
|
.or_else(|| {
|
|
data.extra.get("error")
|
|
.and_then(|v| v.as_str())
|
|
.map(|s| s.to_string())
|
|
})
|
|
.unwrap_or_else(|| "Unknown execution error".to_string());
|
|
|
|
// 获取节点ID
|
|
let node_id = data.node_id
|
|
.or(data.node)
|
|
.or_else(|| {
|
|
data.extra.get("node")
|
|
.and_then(|v| v.as_str())
|
|
.map(|s| s.to_string())
|
|
});
|
|
|
|
let error = ExecutionError {
|
|
node_id,
|
|
message: error_message,
|
|
details: data.details,
|
|
timestamp: Utc::now(),
|
|
};
|
|
|
|
log::error!("ComfyUI execution error: {}", error.message);
|
|
event_emitter.emit_error(error).await;
|
|
}
|
|
WSMessage::Unknown => {
|
|
log::debug!("Received unknown WebSocket message: {text}");
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Gets the event emitter for direct access
|
|
pub fn event_emitter(&self) -> &EventEmitter {
|
|
&self.event_emitter
|
|
}
|
|
|
|
/// Gets the last execution error and clears it
|
|
pub async fn get_last_error(&self) -> Option<ExecutionError> {
|
|
self.event_emitter.get_last_error().await
|
|
}
|
|
|
|
/// Clears the last error
|
|
pub async fn clear_last_error(&self) {
|
|
self.event_emitter.clear_last_error().await
|
|
}
|
|
}
|
|
|
|
impl Drop for WebSocketClient {
|
|
fn drop(&mut self) {
|
|
if let Some(shutdown_tx) = &self.shutdown_tx {
|
|
let _ = shutdown_tx.send(());
|
|
}
|
|
}
|
|
}
|