mixvideo-v2/cargos/comfyui-sdk/client/websocket_client.rs

272 lines
9.1 KiB
Rust

//! WebSocket client for ComfyUI real-time events
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use futures_util::StreamExt;
use url::Url;
use uuid::Uuid;
use chrono::Utc;
use crate::types::{
ComfyUIClientConfig, WSMessage, ExecutionProgress, ExecutionResult,
ExecutionError, ExecutionCallbacks
};
use crate::error::{ComfyUIError, Result};
use crate::utils::event_emitter::EventEmitter;
/// WebSocket client for real-time ComfyUI events.
///
/// `connect` spawns a background reader task. The task shares the
/// `is_connected` flag with this struct and listens on a shutdown channel whose
/// sending half is stored in `shutdown_tx`.
pub struct WebSocketClient {
/// Client configuration; `base_url` is used to derive the WebSocket URL.
config: ComfyUIClientConfig,
/// Random v4 UUID, sent to the server as the `clientId` query parameter.
client_id: String,
/// Receives parsed server events (`emit_*`) and holds registered callbacks.
event_emitter: EventEmitter,
/// Connection flag shared with the reader task; the task resets it to
/// `false` when it exits.
is_connected: Arc<RwLock<bool>>,
/// Stop signal for the reader task; set by `connect`, cleared by `disconnect`.
shutdown_tx: Option<mpsc::UnboundedSender<()>>,
}
impl WebSocketClient {
/// Builds a client for `config` in the disconnected state.
///
/// A fresh random (v4 UUID) client id is generated per instance; no network
/// activity happens until `connect` is called.
pub fn new(config: ComfyUIClientConfig) -> Self {
    let client_id = Uuid::new_v4().to_string();
    let event_emitter = EventEmitter::new();
    let is_connected = Arc::new(RwLock::new(false));

    Self {
        config,
        client_id,
        event_emitter,
        is_connected,
        shutdown_tx: None,
    }
}
/// Returns the id this client sends as `clientId` when connecting.
pub fn client_id(&self) -> &str {
    self.client_id.as_str()
}
/// Reports whether the background reader task is currently considered live.
pub async fn is_connected(&self) -> bool {
    let flag = self.is_connected.read().await;
    *flag
}
/// Connects to the WebSocket server
pub async fn connect(&mut self) -> Result<()> {
if self.is_connected().await {
return Ok(());
}
let ws_url = self.build_websocket_url()?;
log::info!("Connecting to WebSocket: {ws_url}");
let (ws_stream, _) = connect_async(&ws_url).await?;
let (_ws_sender, mut ws_receiver) = ws_stream.split();
// Set connected state
*self.is_connected.write().await = true;
// Create shutdown channel
let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel();
self.shutdown_tx = Some(shutdown_tx);
// Clone necessary data for the message handling task
let event_emitter = self.event_emitter.clone();
let is_connected = self.is_connected.clone();
// Spawn message handling task
tokio::spawn(async move {
loop {
tokio::select! {
// Handle incoming messages
message = ws_receiver.next() => {
match message {
Some(Ok(Message::Text(text))) => {
if let Err(e) = Self::handle_message(&text, &event_emitter).await {
log::error!("Error handling WebSocket message: {e}");
}
}
Some(Ok(Message::Close(_))) => {
log::info!("WebSocket connection closed by server");
break;
}
Some(Err(e)) => {
log::error!("WebSocket error: {e}");
break;
}
None => {
log::info!("WebSocket stream ended");
break;
}
_ => {} // Ignore other message types
}
}
// Handle shutdown signal
_ = shutdown_rx.recv() => {
log::info!("WebSocket shutdown requested");
break;
}
}
}
// Set disconnected state
*is_connected.write().await = false;
log::info!("WebSocket connection closed");
});
Ok(())
}
/// Asks the background reader to stop, then waits briefly for it to exit.
///
/// This is a no-op when the client is not connected. The connected flag is
/// polled every 100 ms, up to 50 times (about 5 s). The stored shutdown sender
/// is released afterwards whether or not the reader finished in time.
pub async fn disconnect(&mut self) -> Result<()> {
    if !self.is_connected().await {
        return Ok(());
    }

    if let Some(tx) = self.shutdown_tx.as_ref() {
        // A send error only means the reader task has already gone away.
        let _ = tx.send(());
    }

    for _ in 0..50 {
        if !self.is_connected().await {
            break;
        }
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    }

    self.shutdown_tx = None;
    Ok(())
}
/// Attaches `callbacks` to the event emitter under a newly generated id.
///
/// Returns that id; pass it to `unregister_callbacks` to detach them again.
pub async fn register_callbacks(&self, callbacks: Arc<dyn ExecutionCallbacks>) -> String {
    let id = Uuid::new_v4().to_string();
    let key = id.clone();
    self.event_emitter.register_callbacks(key, callbacks).await;
    id
}
/// Detaches the callbacks previously registered under `callback_id`.
pub async fn unregister_callbacks(&self, callback_id: &str) {
    let emitter = &self.event_emitter;
    emitter.unregister_callbacks(callback_id).await;
}
/// Derives the WebSocket endpoint URL from `config.base_url`.
///
/// The WebSocket is served by the same server as the HTTP API. The result
/// therefore keeps the base URL's host, explicit port and path prefix. It only
/// swaps the scheme (`http`→`ws`, `https`→`wss`; `ws`/`wss` are kept as-is) and
/// appends `/ws?clientId=<id>`.
///
/// Examples (client id elided):
/// - `http://127.0.0.1:8188`     → `ws://127.0.0.1:8188/ws?clientId=…`
/// - `https://comfy.example.com` → `wss://comfy.example.com/ws?clientId=…`
/// - `http://host:8188/comfyui/` → `ws://host:8188/comfyui/ws?clientId=…`
///
/// # Errors
/// Returns an error if `base_url` does not parse, or if its scheme is not
/// `http`, `https`, `ws` or `wss`.
fn build_websocket_url(&self) -> Result<String> {
    let base_url = Url::parse(&self.config.base_url)?;

    let scheme = match base_url.scheme() {
        "http" | "ws" => "ws",
        "https" | "wss" => "wss",
        _ => return Err(ComfyUIError::new("Invalid URL scheme")),
    };

    // `host_str` already wraps IPv6 literals in brackets (e.g. `[::1]`).
    let host = base_url.host_str().unwrap_or("localhost");

    // `port()` is `None` both when no port was written and when it equals the
    // scheme default (80 / 443). ws/wss share those defaults with http/https,
    // so omitting it targets the same port as the HTTP API. Substituting a
    // fixed port here (e.g. 8188) would point `https://host` at a different
    // endpoint (`wss://host:8188`).
    let authority = match base_url.port() {
        Some(port) => format!("{host}:{port}"),
        None => host.to_owned(),
    };

    // Preserve any reverse-proxy path prefix; `path()` is at least "/" here.
    let prefix = base_url.path().trim_end_matches('/');

    Ok(format!(
        "{scheme}://{authority}{prefix}/ws?clientId={}",
        self.client_id
    ))
}
/// Handles incoming WebSocket messages
async fn handle_message(text: &str, event_emitter: &EventEmitter) -> Result<()> {
// 打印原始消息用于调试
log::debug!("Received WebSocket message: {}", text);
let message: WSMessage = serde_json::from_str(text)
.map_err(|e| {
log::error!("Failed to parse WebSocket message: {}", e);
log::error!("Raw message content: {}", text);
ComfyUIError::new(format!("Failed to parse WebSocket message: {e}"))
})?;
match message {
WSMessage::Progress { data } => {
let progress = ExecutionProgress {
node_id: data.node.unwrap_or_else(|| "unknown".to_string()),
progress: data.value,
max: data.max,
timestamp: Utc::now(),
};
event_emitter.emit_progress(progress).await;
}
WSMessage::Executing { data } => {
if let Some(node_id) = data.node {
event_emitter.emit_executing(node_id).await;
}
}
WSMessage::Executed { data } => {
let result = ExecutionResult {
prompt_id: data.prompt_id,
outputs: data.output,
execution_time: 0, // WebSocket doesn't provide execution time
timestamp: Utc::now(),
};
event_emitter.emit_executed(result).await;
}
WSMessage::ExecutionError { data } => {
// 构建错误消息,尝试多个可能的字段
let error_message = data.message
.or(data.error)
.or(data.exception_message)
.or_else(|| {
// 尝试从extra字段中提取错误信息
data.extra.get("exception_message")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
})
.or_else(|| {
data.extra.get("error")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
})
.unwrap_or_else(|| "Unknown execution error".to_string());
// 获取节点ID
let node_id = data.node_id
.or(data.node)
.or_else(|| {
data.extra.get("node")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
});
let error = ExecutionError {
node_id,
message: error_message,
details: data.details,
timestamp: Utc::now(),
};
log::error!("ComfyUI execution error: {}", error.message);
event_emitter.emit_error(error).await;
}
WSMessage::Unknown => {
log::debug!("Received unknown WebSocket message: {text}");
}
}
Ok(())
}
/// Borrows the underlying event emitter for callers that need direct access.
pub fn event_emitter(&self) -> &EventEmitter {
    let emitter = &self.event_emitter;
    emitter
}
/// Returns the most recent execution error held by the event emitter.
///
/// Delegates to `EventEmitter::get_last_error`. NOTE(review): the original doc
/// says this also clears the stored error; confirm against `EventEmitter`.
pub async fn get_last_error(&self) -> Option<ExecutionError> {
    let emitter = &self.event_emitter;
    emitter.get_last_error().await
}
/// Discards any execution error stored by the event emitter.
pub async fn clear_last_error(&self) {
    let emitter = &self.event_emitter;
    emitter.clear_last_error().await;
}
}
impl Drop for WebSocketClient {
    /// Best-effort stop signal for the reader task when the client goes away.
    ///
    /// `drop` cannot await, so it uses the synchronous `send` of the
    /// unbounded shutdown channel.
    fn drop(&mut self) {
        if let Some(tx) = self.shutdown_tx.as_ref() {
            // If the reader has already exited there is nothing to stop.
            let _ = tx.send(());
        }
    }
}