//! 实时监控服务 //! 负责 WebSocket 连接和实时事件处理 use anyhow::{Result, anyhow}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::sync::{RwLock, mpsc, broadcast}; use tracing::{info, warn, error, debug}; use comfyui_sdk::types::events::{ExecutionProgress, ExecutionResult, ExecutionError, ExecutionCallbacks}; use comfyui_sdk::client::websocket_client::WebSocketClient; use crate::business::services::comfyui_manager::ComfyUIManager; use crate::data::models::comfyui::{ExecutionModel, ExecutionStatus}; use crate::data::repositories::comfyui_repository::ComfyUIRepository; /// 实时事件类型 #[derive(Debug, Clone)] pub enum RealtimeEvent { /// 执行开始 ExecutionStarted { prompt_id: String, execution_id: Option, }, /// 执行进度更新 ExecutionProgress { prompt_id: String, execution_id: Option, progress: f32, current_node: Option, total_nodes: Option, }, /// 执行完成 ExecutionCompleted { prompt_id: String, execution_id: Option, outputs: HashMap, output_urls: Vec, }, /// 执行失败 ExecutionFailed { prompt_id: String, execution_id: Option, error: String, }, /// 队列状态更新 QueueUpdated { running_count: u32, pending_count: u32, }, /// 连接状态变化 ConnectionChanged { connected: bool, message: String, }, } /// 实时监控服务 pub struct RealtimeMonitor { /// ComfyUI 管理器 manager: Arc, /// 数据仓库 repository: Arc, /// 事件广播器 event_sender: broadcast::Sender, /// WebSocket 连接状态 websocket_connected: Arc>, /// 提示 ID 到执行 ID 的映射 prompt_execution_map: Arc>>, /// 监控任务句柄 monitor_handle: Arc>>>, } impl RealtimeMonitor { /// 创建新的实时监控服务 pub fn new( manager: Arc, repository: Arc, ) -> Self { let (event_sender, _) = broadcast::channel(1000); Self { manager, repository, event_sender, websocket_connected: Arc::new(RwLock::new(false)), prompt_execution_map: Arc::new(RwLock::new(HashMap::new())), monitor_handle: Arc::new(RwLock::new(None)), } } /// 启动实时监控 pub async fn start(&self) -> Result<()> { info!("启动实时监控服务"); // 检查是否已经在运行 { let handle = self.monitor_handle.read().await; if handle.is_some() { return Err(anyhow!("实时监控服务已在运行")); } } // 启动监控任务 let monitor_task = self.spawn_monitor_task().await?; { let mut handle = self.monitor_handle.write().await; *handle = Some(monitor_task); } info!("实时监控服务已启动"); Ok(()) } /// 停止实时监控 pub async fn stop(&self) -> Result<()> { info!("停止实时监控服务"); let handle = { let mut handle_guard = self.monitor_handle.write().await; handle_guard.take() }; if let Some(handle) = handle { handle.abort(); info!("实时监控服务已停止"); } // 更新连接状态 *self.websocket_connected.write().await = false; Ok(()) } /// 订阅实时事件 pub fn subscribe(&self) -> broadcast::Receiver { self.event_sender.subscribe() } /// 注册执行映射 pub async fn register_execution(&self, prompt_id: String, execution_id: String) { let mut map = self.prompt_execution_map.write().await; map.insert(prompt_id, execution_id); } /// 获取执行 ID async fn get_execution_id(&self, prompt_id: &str) -> Option { let map = self.prompt_execution_map.read().await; map.get(prompt_id).cloned() } /// 生成监控任务 async fn spawn_monitor_task(&self) -> Result> { let manager = Arc::clone(&self.manager); let repository = Arc::clone(&self.repository); let event_sender = self.event_sender.clone(); let websocket_connected = Arc::clone(&self.websocket_connected); let prompt_execution_map = Arc::clone(&self.prompt_execution_map); let handle = tokio::spawn(async move { let mut reconnect_interval = Duration::from_secs(5); let max_reconnect_interval = Duration::from_secs(60); loop { // 检查 ComfyUI 连接状态 if !manager.is_connected().await { warn!("ComfyUI 未连接,等待连接..."); tokio::time::sleep(Duration::from_secs(5)).await; continue; } // 尝试建立 WebSocket 连接 match Self::connect_websocket(&manager, &event_sender, &websocket_connected, &prompt_execution_map, &repository).await { Ok(_) => { info!("WebSocket 连接已建立"); reconnect_interval = Duration::from_secs(5); // 重置重连间隔 } Err(e) => { error!("WebSocket 连接失败: {}", e); // 发送连接状态事件 let _ = event_sender.send(RealtimeEvent::ConnectionChanged { connected: false, message: format!("WebSocket 连接失败: {}", e), }); // 等待后重试 tokio::time::sleep(reconnect_interval).await; // 增加重连间隔(指数退避) reconnect_interval = std::cmp::min( reconnect_interval * 2, max_reconnect_interval, ); } } } }); Ok(handle) } /// 连接 WebSocket async fn connect_websocket( manager: &Arc, event_sender: &broadcast::Sender, websocket_connected: &Arc>, prompt_execution_map: &Arc>>, repository: &Arc, ) -> Result<()> { // 获取客户端 let client = manager.get_client().await?; // 建立 WebSocket 连接 let mut websocket = client.websocket().connect().await?; // 更新连接状态 *websocket_connected.write().await = true; // 发送连接状态事件 let _ = event_sender.send(RealtimeEvent::ConnectionChanged { connected: true, message: "WebSocket 连接已建立".to_string(), }); // 处理 WebSocket 消息 while let Some(event) = websocket.next_event().await { match event { Ok(comfyui_event) => { if let Err(e) = Self::handle_comfyui_event( comfyui_event, event_sender, prompt_execution_map, repository, ).await { error!("处理 ComfyUI 事件失败: {}", e); } } Err(e) => { error!("WebSocket 错误: {}", e); break; } } } // 连接断开 *websocket_connected.write().await = false; let _ = event_sender.send(RealtimeEvent::ConnectionChanged { connected: false, message: "WebSocket 连接已断开".to_string(), }); Err(anyhow!("WebSocket 连接断开")) } /// 处理 ComfyUI 事件 async fn handle_comfyui_event( event: serde_json::Value, // 临时使用 Value 类型 event_sender: &broadcast::Sender, prompt_execution_map: &Arc>>, repository: &Arc, ) -> Result<()> { // TODO: 实现事件处理逻辑 info!("Received ComfyUI event: {:?}", event); Ok(()) } // 处理执行事件 (临时禁用) /* async fn handle_execution_event( event: ExecutionEvent, event_sender: &broadcast::Sender, prompt_execution_map: &Arc>>, repository: &Arc, ) -> Result<()> { let execution_id = { let map = prompt_execution_map.read().await; map.get(&event.prompt_id).cloned() }; match event.event_type.as_str() { "execution_start" => { let _ = event_sender.send(RealtimeEvent::ExecutionStarted { prompt_id: event.prompt_id.clone(), execution_id: execution_id.clone(), }); // 更新数据库中的执行状态 if let Some(exec_id) = execution_id { if let Ok(Some(mut execution)) = repository.get_execution(&exec_id).await { execution.update_status(ExecutionStatus::Running); let _ = repository.update_execution(&execution).await; } } } "execution_complete" => { // 提取输出信息 let outputs = event.data.unwrap_or_default(); let output_urls = Self::extract_output_urls(&outputs); let _ = event_sender.send(RealtimeEvent::ExecutionCompleted { prompt_id: event.prompt_id.clone(), execution_id: execution_id.clone(), outputs: outputs.clone(), output_urls: output_urls.clone(), }); // 更新数据库中的执行状态 if let Some(exec_id) = execution_id { if let Ok(Some(mut execution)) = repository.get_execution(&exec_id).await { execution.set_results(outputs, output_urls); execution.node_outputs = Some(outputs); let _ = repository.update_execution(&execution).await; } } // 清理映射 { let mut map = prompt_execution_map.write().await; map.remove(&event.prompt_id); } } "execution_error" => { let error_msg = event.data .and_then(|data| data.get("error")) .and_then(|e| e.as_str()) .unwrap_or("未知错误") .to_string(); let _ = event_sender.send(RealtimeEvent::ExecutionFailed { prompt_id: event.prompt_id.clone(), execution_id: execution_id.clone(), error: error_msg.clone(), }); // 更新数据库中的执行状态 if let Some(exec_id) = execution_id { if let Ok(Some(mut execution)) = repository.get_execution(&exec_id).await { execution.set_error(error_msg); let _ = repository.update_execution(&execution).await; } } // 清理映射 { let mut map = prompt_execution_map.write().await; map.remove(&event.prompt_id); } } _ => { debug!("未处理的执行事件类型: {}", event.event_type); } } Ok(()) } /// 处理进度事件 async fn handle_progress_event( event: ProgressEvent, event_sender: &broadcast::Sender, prompt_execution_map: &Arc>>, ) -> Result<()> { let execution_id = { let map = prompt_execution_map.read().await; map.get(&event.prompt_id).cloned() }; let _ = event_sender.send(RealtimeEvent::ExecutionProgress { prompt_id: event.prompt_id, execution_id, progress: event.progress, current_node: event.current_node, total_nodes: event.total_nodes, }); Ok(()) } /// 处理队列事件 async fn handle_queue_event( event: QueueEvent, event_sender: &broadcast::Sender, ) -> Result<()> { let _ = event_sender.send(RealtimeEvent::QueueUpdated { running_count: event.running_count, pending_count: event.pending_count, }); Ok(()) } /// 提取输出 URLs fn extract_output_urls(outputs: &HashMap) -> Vec { let mut urls = Vec::new(); for (_, output) in outputs { if let Some(images) = output.get("images").and_then(|v| v.as_array()) { for image in images { if let Some(filename) = image.get("filename").and_then(|v| v.as_str()) { // 构建完整的 URL(这里需要根据实际的 ComfyUI 配置调整) let url = format!("/view?filename={}", filename); urls.push(url); } } } } urls } /// 检查 WebSocket 连接状态 pub async fn is_websocket_connected(&self) -> bool { *self.websocket_connected.read().await } /// 获取监控统计信息 pub async fn get_monitor_stats(&self) -> MonitorStats { let websocket_connected = *self.websocket_connected.read().await; let prompt_map_size = self.prompt_execution_map.read().await.len(); let is_running = self.monitor_handle.read().await.is_some(); MonitorStats { is_running, websocket_connected, tracked_executions: prompt_map_size, event_subscribers: self.event_sender.receiver_count(), } } */ } /// 监控统计信息 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct MonitorStats { pub is_running: bool, pub websocket_connected: bool, pub tracked_executions: usize, pub event_subscribers: usize, }