mixvideo-v2/apps/desktop/src-tauri/src/business/services/realtime_monitor.rs

444 lines
15 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 实时监控服务
//! 负责 WebSocket 连接和实时事件处理
use anyhow::{Result, anyhow};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, mpsc, broadcast};
use tracing::{info, warn, error, debug};
use comfyui_sdk::types::events::{ExecutionProgress, ExecutionResult, ExecutionError, ExecutionCallbacks};
use comfyui_sdk::client::websocket_client::WebSocketClient;
use crate::business::services::comfyui_manager::ComfyUIManager;
use crate::data::models::comfyui::{ExecutionModel, ExecutionStatus};
use crate::data::repositories::comfyui_repository::ComfyUIRepository;
/// Real-time events broadcast by `RealtimeMonitor` to every subscriber.
///
/// For execution-related variants, `prompt_id` identifies the ComfyUI prompt. `execution_id`
/// is the local execution id registered for that prompt (see
/// `RealtimeMonitor::register_execution`), or `None` when no mapping was registered.
#[derive(Debug, Clone)]
pub enum RealtimeEvent {
    /// Execution of a prompt has started.
    ExecutionStarted {
        prompt_id: String,
        execution_id: Option<String>,
    },
    /// Progress update for a running prompt.
    ExecutionProgress {
        prompt_id: String,
        execution_id: Option<String>,
        // NOTE(review): scale (0..1 vs 0..100) is not established in this file — confirm
        // against the ComfyUI SDK progress payload.
        progress: f32,
        // Presumably the id of the node currently executing, if reported.
        current_node: Option<String>,
        // Presumably the total number of nodes in the workflow, if reported.
        total_nodes: Option<u32>,
    },
    /// A prompt finished successfully.
    ExecutionCompleted {
        prompt_id: String,
        execution_id: Option<String>,
        // Raw per-output JSON from ComfyUI; the key meaning (presumably node id) is not
        // established here — verify against the SDK.
        outputs: HashMap<String, serde_json::Value>,
        // URLs derived from the image filenames found in `outputs`.
        output_urls: Vec<String>,
    },
    /// A prompt failed.
    ExecutionFailed {
        prompt_id: String,
        execution_id: Option<String>,
        // Human-readable error message.
        error: String,
    },
    /// The ComfyUI queue state changed.
    QueueUpdated {
        running_count: u32,
        pending_count: u32,
    },
    /// The WebSocket connection state changed.
    ConnectionChanged {
        // `true` once a WebSocket session is established, `false` on failure/disconnect.
        connected: bool,
        // Human-readable description of the change.
        message: String,
    },
}
/// Real-time monitoring service.
///
/// Runs a background task that keeps a WebSocket session to ComfyUI open (reconnecting
/// with backoff) and re-broadcasts what it receives as [`RealtimeEvent`]s.
pub struct RealtimeMonitor {
    /// ComfyUI manager, used to check connectivity and obtain the client.
    manager: Arc<ComfyUIManager>,
    /// Repository for execution records; forwarded to the event handler.
    repository: Arc<ComfyUIRepository>,
    /// Broadcast sender; every `subscribe()` call gets its own receiver.
    event_sender: broadcast::Sender<RealtimeEvent>,
    /// Whether a WebSocket session is currently established (shared with the background task).
    websocket_connected: Arc<RwLock<bool>>,
    /// ComfyUI prompt id -> local execution id, populated via `register_execution`.
    prompt_execution_map: Arc<RwLock<HashMap<String, String>>>,
    /// Handle of the background monitor task; `Some` between `start()` and `stop()`.
    monitor_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
}
impl RealtimeMonitor {
/// 创建新的实时监控服务
pub fn new(
manager: Arc<ComfyUIManager>,
repository: Arc<ComfyUIRepository>,
) -> Self {
let (event_sender, _) = broadcast::channel(1000);
Self {
manager,
repository,
event_sender,
websocket_connected: Arc::new(RwLock::new(false)),
prompt_execution_map: Arc::new(RwLock::new(HashMap::new())),
monitor_handle: Arc::new(RwLock::new(None)),
}
}
/// 启动实时监控
pub async fn start(&self) -> Result<()> {
info!("启动实时监控服务");
// 检查是否已经在运行
{
let handle = self.monitor_handle.read().await;
if handle.is_some() {
return Err(anyhow!("实时监控服务已在运行"));
}
}
// 启动监控任务
let monitor_task = self.spawn_monitor_task().await?;
{
let mut handle = self.monitor_handle.write().await;
*handle = Some(monitor_task);
}
info!("实时监控服务已启动");
Ok(())
}
/// 停止实时监控
pub async fn stop(&self) -> Result<()> {
info!("停止实时监控服务");
let handle = {
let mut handle_guard = self.monitor_handle.write().await;
handle_guard.take()
};
if let Some(handle) = handle {
handle.abort();
info!("实时监控服务已停止");
}
// 更新连接状态
*self.websocket_connected.write().await = false;
Ok(())
}
/// 订阅实时事件
pub fn subscribe(&self) -> broadcast::Receiver<RealtimeEvent> {
self.event_sender.subscribe()
}
/// 注册执行映射
pub async fn register_execution(&self, prompt_id: String, execution_id: String) {
let mut map = self.prompt_execution_map.write().await;
map.insert(prompt_id, execution_id);
}
/// 获取执行 ID
async fn get_execution_id(&self, prompt_id: &str) -> Option<String> {
let map = self.prompt_execution_map.read().await;
map.get(prompt_id).cloned()
}
/// 生成监控任务
async fn spawn_monitor_task(&self) -> Result<tokio::task::JoinHandle<()>> {
let manager = Arc::clone(&self.manager);
let repository = Arc::clone(&self.repository);
let event_sender = self.event_sender.clone();
let websocket_connected = Arc::clone(&self.websocket_connected);
let prompt_execution_map = Arc::clone(&self.prompt_execution_map);
let handle = tokio::spawn(async move {
let mut reconnect_interval = Duration::from_secs(5);
let max_reconnect_interval = Duration::from_secs(60);
loop {
// 检查 ComfyUI 连接状态
if !manager.is_connected().await {
warn!("ComfyUI 未连接,等待连接...");
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
// 尝试建立 WebSocket 连接
match Self::connect_websocket(&manager, &event_sender, &websocket_connected, &prompt_execution_map, &repository).await {
Ok(_) => {
info!("WebSocket 连接已建立");
reconnect_interval = Duration::from_secs(5); // 重置重连间隔
}
Err(e) => {
error!("WebSocket 连接失败: {}", e);
// 发送连接状态事件
let _ = event_sender.send(RealtimeEvent::ConnectionChanged {
connected: false,
message: format!("WebSocket 连接失败: {}", e),
});
// 等待后重试
tokio::time::sleep(reconnect_interval).await;
// 增加重连间隔(指数退避)
reconnect_interval = std::cmp::min(
reconnect_interval * 2,
max_reconnect_interval,
);
}
}
}
});
Ok(handle)
}
/// 连接 WebSocket
async fn connect_websocket(
manager: &Arc<ComfyUIManager>,
event_sender: &broadcast::Sender<RealtimeEvent>,
websocket_connected: &Arc<RwLock<bool>>,
prompt_execution_map: &Arc<RwLock<HashMap<String, String>>>,
repository: &Arc<ComfyUIRepository>,
) -> Result<()> {
// 获取客户端
let client = manager.get_client().await?;
// 建立 WebSocket 连接
let mut websocket = client.websocket().connect().await?;
// 更新连接状态
*websocket_connected.write().await = true;
// 发送连接状态事件
let _ = event_sender.send(RealtimeEvent::ConnectionChanged {
connected: true,
message: "WebSocket 连接已建立".to_string(),
});
// 处理 WebSocket 消息
while let Some(event) = websocket.next_event().await {
match event {
Ok(comfyui_event) => {
if let Err(e) = Self::handle_comfyui_event(
comfyui_event,
event_sender,
prompt_execution_map,
repository,
).await {
error!("处理 ComfyUI 事件失败: {}", e);
}
}
Err(e) => {
error!("WebSocket 错误: {}", e);
break;
}
}
}
// 连接断开
*websocket_connected.write().await = false;
let _ = event_sender.send(RealtimeEvent::ConnectionChanged {
connected: false,
message: "WebSocket 连接已断开".to_string(),
});
Err(anyhow!("WebSocket 连接断开"))
}
/// 处理 ComfyUI 事件
async fn handle_comfyui_event(
event: serde_json::Value, // 临时使用 Value 类型
event_sender: &broadcast::Sender<RealtimeEvent>,
prompt_execution_map: &Arc<RwLock<HashMap<String, String>>>,
repository: &Arc<ComfyUIRepository>,
) -> Result<()> {
// TODO: 实现事件处理逻辑
info!("Received ComfyUI event: {:?}", event);
Ok(())
}
// 处理执行事件 (临时禁用)
/*
async fn handle_execution_event(
event: ExecutionEvent,
event_sender: &broadcast::Sender<RealtimeEvent>,
prompt_execution_map: &Arc<RwLock<HashMap<String, String>>>,
repository: &Arc<ComfyUIRepository>,
) -> Result<()> {
let execution_id = {
let map = prompt_execution_map.read().await;
map.get(&event.prompt_id).cloned()
};
match event.event_type.as_str() {
"execution_start" => {
let _ = event_sender.send(RealtimeEvent::ExecutionStarted {
prompt_id: event.prompt_id.clone(),
execution_id: execution_id.clone(),
});
// 更新数据库中的执行状态
if let Some(exec_id) = execution_id {
if let Ok(Some(mut execution)) = repository.get_execution(&exec_id).await {
execution.update_status(ExecutionStatus::Running);
let _ = repository.update_execution(&execution).await;
}
}
}
"execution_complete" => {
// 提取输出信息
let outputs = event.data.unwrap_or_default();
let output_urls = Self::extract_output_urls(&outputs);
let _ = event_sender.send(RealtimeEvent::ExecutionCompleted {
prompt_id: event.prompt_id.clone(),
execution_id: execution_id.clone(),
outputs: outputs.clone(),
output_urls: output_urls.clone(),
});
// 更新数据库中的执行状态
if let Some(exec_id) = execution_id {
if let Ok(Some(mut execution)) = repository.get_execution(&exec_id).await {
execution.set_results(outputs, output_urls);
execution.node_outputs = Some(outputs);
let _ = repository.update_execution(&execution).await;
}
}
// 清理映射
{
let mut map = prompt_execution_map.write().await;
map.remove(&event.prompt_id);
}
}
"execution_error" => {
let error_msg = event.data
.and_then(|data| data.get("error"))
.and_then(|e| e.as_str())
.unwrap_or("未知错误")
.to_string();
let _ = event_sender.send(RealtimeEvent::ExecutionFailed {
prompt_id: event.prompt_id.clone(),
execution_id: execution_id.clone(),
error: error_msg.clone(),
});
// 更新数据库中的执行状态
if let Some(exec_id) = execution_id {
if let Ok(Some(mut execution)) = repository.get_execution(&exec_id).await {
execution.set_error(error_msg);
let _ = repository.update_execution(&execution).await;
}
}
// 清理映射
{
let mut map = prompt_execution_map.write().await;
map.remove(&event.prompt_id);
}
}
_ => {
debug!("未处理的执行事件类型: {}", event.event_type);
}
}
Ok(())
}
/// 处理进度事件
async fn handle_progress_event(
event: ProgressEvent,
event_sender: &broadcast::Sender<RealtimeEvent>,
prompt_execution_map: &Arc<RwLock<HashMap<String, String>>>,
) -> Result<()> {
let execution_id = {
let map = prompt_execution_map.read().await;
map.get(&event.prompt_id).cloned()
};
let _ = event_sender.send(RealtimeEvent::ExecutionProgress {
prompt_id: event.prompt_id,
execution_id,
progress: event.progress,
current_node: event.current_node,
total_nodes: event.total_nodes,
});
Ok(())
}
/// 处理队列事件
async fn handle_queue_event(
event: QueueEvent,
event_sender: &broadcast::Sender<RealtimeEvent>,
) -> Result<()> {
let _ = event_sender.send(RealtimeEvent::QueueUpdated {
running_count: event.running_count,
pending_count: event.pending_count,
});
Ok(())
}
/// 提取输出 URLs
fn extract_output_urls(outputs: &HashMap<String, serde_json::Value>) -> Vec<String> {
let mut urls = Vec::new();
for (_, output) in outputs {
if let Some(images) = output.get("images").and_then(|v| v.as_array()) {
for image in images {
if let Some(filename) = image.get("filename").and_then(|v| v.as_str()) {
// 构建完整的 URL这里需要根据实际的 ComfyUI 配置调整)
let url = format!("/view?filename={}", filename);
urls.push(url);
}
}
}
}
urls
}
/// 检查 WebSocket 连接状态
pub async fn is_websocket_connected(&self) -> bool {
*self.websocket_connected.read().await
}
/// 获取监控统计信息
pub async fn get_monitor_stats(&self) -> MonitorStats {
let websocket_connected = *self.websocket_connected.read().await;
let prompt_map_size = self.prompt_execution_map.read().await.len();
let is_running = self.monitor_handle.read().await.is_some();
MonitorStats {
is_running,
websocket_connected,
tracked_executions: prompt_map_size,
event_subscribers: self.event_sender.receiver_count(),
}
}
*/
}
/// Serializable snapshot of a `RealtimeMonitor`'s state (monitoring statistics).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MonitorStats {
    /// Whether the background monitor task has been started and not yet stopped.
    pub is_running: bool,
    /// Whether a WebSocket session to ComfyUI is currently established.
    pub websocket_connected: bool,
    /// Number of prompt-id -> execution-id mappings currently tracked.
    pub tracked_executions: usize,
    /// Number of live receivers subscribed to real-time events.
    pub event_subscribers: usize,
}