//! Realtime monitoring service.
//!
//! Responsible for the WebSocket connection and realtime event handling.
|
||
|
||
use anyhow::{Result, anyhow};
|
||
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
use std::time::Duration;
|
||
use tokio::sync::{RwLock, mpsc, broadcast};
|
||
use tracing::{info, warn, error, debug};
|
||
|
||
use comfyui_sdk::types::events::{ExecutionProgress, ExecutionResult, ExecutionError, ExecutionCallbacks};
|
||
use comfyui_sdk::client::websocket_client::WebSocketClient;
|
||
|
||
use crate::business::services::comfyui_manager::ComfyUIManager;
|
||
use crate::data::models::comfyui::{ExecutionModel, ExecutionStatus};
|
||
use crate::data::repositories::comfyui_repository::ComfyUIRepository;
|
||
|
||
/// 实时事件类型
|
||
#[derive(Debug, Clone)]
|
||
pub enum RealtimeEvent {
|
||
/// 执行开始
|
||
ExecutionStarted {
|
||
prompt_id: String,
|
||
execution_id: Option<String>,
|
||
},
|
||
/// 执行进度更新
|
||
ExecutionProgress {
|
||
prompt_id: String,
|
||
execution_id: Option<String>,
|
||
progress: f32,
|
||
current_node: Option<String>,
|
||
total_nodes: Option<u32>,
|
||
},
|
||
/// 执行完成
|
||
ExecutionCompleted {
|
||
prompt_id: String,
|
||
execution_id: Option<String>,
|
||
outputs: HashMap<String, serde_json::Value>,
|
||
output_urls: Vec<String>,
|
||
},
|
||
/// 执行失败
|
||
ExecutionFailed {
|
||
prompt_id: String,
|
||
execution_id: Option<String>,
|
||
error: String,
|
||
},
|
||
/// 队列状态更新
|
||
QueueUpdated {
|
||
running_count: u32,
|
||
pending_count: u32,
|
||
},
|
||
/// 连接状态变化
|
||
ConnectionChanged {
|
||
connected: bool,
|
||
message: String,
|
||
},
|
||
}
|
||
|
||
/// 实时监控服务
|
||
pub struct RealtimeMonitor {
|
||
/// ComfyUI 管理器
|
||
manager: Arc<ComfyUIManager>,
|
||
/// 数据仓库
|
||
repository: Arc<ComfyUIRepository>,
|
||
/// 事件广播器
|
||
event_sender: broadcast::Sender<RealtimeEvent>,
|
||
/// WebSocket 连接状态
|
||
websocket_connected: Arc<RwLock<bool>>,
|
||
/// 提示 ID 到执行 ID 的映射
|
||
prompt_execution_map: Arc<RwLock<HashMap<String, String>>>,
|
||
/// 监控任务句柄
|
||
monitor_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
|
||
}
|
||
|
||
impl RealtimeMonitor {
|
||
/// 创建新的实时监控服务
|
||
pub fn new(
|
||
manager: Arc<ComfyUIManager>,
|
||
repository: Arc<ComfyUIRepository>,
|
||
) -> Self {
|
||
let (event_sender, _) = broadcast::channel(1000);
|
||
|
||
Self {
|
||
manager,
|
||
repository,
|
||
event_sender,
|
||
websocket_connected: Arc::new(RwLock::new(false)),
|
||
prompt_execution_map: Arc::new(RwLock::new(HashMap::new())),
|
||
monitor_handle: Arc::new(RwLock::new(None)),
|
||
}
|
||
}
|
||
|
||
/// 启动实时监控
|
||
pub async fn start(&self) -> Result<()> {
|
||
info!("启动实时监控服务");
|
||
|
||
// 检查是否已经在运行
|
||
{
|
||
let handle = self.monitor_handle.read().await;
|
||
if handle.is_some() {
|
||
return Err(anyhow!("实时监控服务已在运行"));
|
||
}
|
||
}
|
||
|
||
// 启动监控任务
|
||
let monitor_task = self.spawn_monitor_task().await?;
|
||
|
||
{
|
||
let mut handle = self.monitor_handle.write().await;
|
||
*handle = Some(monitor_task);
|
||
}
|
||
|
||
info!("实时监控服务已启动");
|
||
Ok(())
|
||
}
|
||
|
||
/// 停止实时监控
|
||
pub async fn stop(&self) -> Result<()> {
|
||
info!("停止实时监控服务");
|
||
|
||
let handle = {
|
||
let mut handle_guard = self.monitor_handle.write().await;
|
||
handle_guard.take()
|
||
};
|
||
|
||
if let Some(handle) = handle {
|
||
handle.abort();
|
||
info!("实时监控服务已停止");
|
||
}
|
||
|
||
// 更新连接状态
|
||
*self.websocket_connected.write().await = false;
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 订阅实时事件
|
||
pub fn subscribe(&self) -> broadcast::Receiver<RealtimeEvent> {
|
||
self.event_sender.subscribe()
|
||
}
|
||
|
||
/// 注册执行映射
|
||
pub async fn register_execution(&self, prompt_id: String, execution_id: String) {
|
||
let mut map = self.prompt_execution_map.write().await;
|
||
map.insert(prompt_id, execution_id);
|
||
}
|
||
|
||
/// 获取执行 ID
|
||
async fn get_execution_id(&self, prompt_id: &str) -> Option<String> {
|
||
let map = self.prompt_execution_map.read().await;
|
||
map.get(prompt_id).cloned()
|
||
}
|
||
|
||
/// 生成监控任务
|
||
async fn spawn_monitor_task(&self) -> Result<tokio::task::JoinHandle<()>> {
|
||
let manager = Arc::clone(&self.manager);
|
||
let repository = Arc::clone(&self.repository);
|
||
let event_sender = self.event_sender.clone();
|
||
let websocket_connected = Arc::clone(&self.websocket_connected);
|
||
let prompt_execution_map = Arc::clone(&self.prompt_execution_map);
|
||
|
||
let handle = tokio::spawn(async move {
|
||
let mut reconnect_interval = Duration::from_secs(5);
|
||
let max_reconnect_interval = Duration::from_secs(60);
|
||
|
||
loop {
|
||
// 检查 ComfyUI 连接状态
|
||
if !manager.is_connected().await {
|
||
warn!("ComfyUI 未连接,等待连接...");
|
||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||
continue;
|
||
}
|
||
|
||
// 尝试建立 WebSocket 连接
|
||
match Self::connect_websocket(&manager, &event_sender, &websocket_connected, &prompt_execution_map, &repository).await {
|
||
Ok(_) => {
|
||
info!("WebSocket 连接已建立");
|
||
reconnect_interval = Duration::from_secs(5); // 重置重连间隔
|
||
}
|
||
Err(e) => {
|
||
error!("WebSocket 连接失败: {}", e);
|
||
|
||
// 发送连接状态事件
|
||
let _ = event_sender.send(RealtimeEvent::ConnectionChanged {
|
||
connected: false,
|
||
message: format!("WebSocket 连接失败: {}", e),
|
||
});
|
||
|
||
// 等待后重试
|
||
tokio::time::sleep(reconnect_interval).await;
|
||
|
||
// 增加重连间隔(指数退避)
|
||
reconnect_interval = std::cmp::min(
|
||
reconnect_interval * 2,
|
||
max_reconnect_interval,
|
||
);
|
||
}
|
||
}
|
||
}
|
||
});
|
||
|
||
Ok(handle)
|
||
}
|
||
|
||
/// 连接 WebSocket
|
||
async fn connect_websocket(
|
||
manager: &Arc<ComfyUIManager>,
|
||
event_sender: &broadcast::Sender<RealtimeEvent>,
|
||
websocket_connected: &Arc<RwLock<bool>>,
|
||
prompt_execution_map: &Arc<RwLock<HashMap<String, String>>>,
|
||
repository: &Arc<ComfyUIRepository>,
|
||
) -> Result<()> {
|
||
// 获取客户端
|
||
let client = manager.get_client().await?;
|
||
|
||
// 建立 WebSocket 连接
|
||
let mut websocket = client.websocket().connect().await?;
|
||
|
||
// 更新连接状态
|
||
*websocket_connected.write().await = true;
|
||
|
||
// 发送连接状态事件
|
||
let _ = event_sender.send(RealtimeEvent::ConnectionChanged {
|
||
connected: true,
|
||
message: "WebSocket 连接已建立".to_string(),
|
||
});
|
||
|
||
// 处理 WebSocket 消息
|
||
while let Some(event) = websocket.next_event().await {
|
||
match event {
|
||
Ok(comfyui_event) => {
|
||
if let Err(e) = Self::handle_comfyui_event(
|
||
comfyui_event,
|
||
event_sender,
|
||
prompt_execution_map,
|
||
repository,
|
||
).await {
|
||
error!("处理 ComfyUI 事件失败: {}", e);
|
||
}
|
||
}
|
||
Err(e) => {
|
||
error!("WebSocket 错误: {}", e);
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
|
||
// 连接断开
|
||
*websocket_connected.write().await = false;
|
||
|
||
let _ = event_sender.send(RealtimeEvent::ConnectionChanged {
|
||
connected: false,
|
||
message: "WebSocket 连接已断开".to_string(),
|
||
});
|
||
|
||
Err(anyhow!("WebSocket 连接断开"))
|
||
}
|
||
|
||
/// 处理 ComfyUI 事件
|
||
async fn handle_comfyui_event(
|
||
event: serde_json::Value, // 临时使用 Value 类型
|
||
event_sender: &broadcast::Sender<RealtimeEvent>,
|
||
prompt_execution_map: &Arc<RwLock<HashMap<String, String>>>,
|
||
repository: &Arc<ComfyUIRepository>,
|
||
) -> Result<()> {
|
||
// TODO: 实现事件处理逻辑
|
||
info!("Received ComfyUI event: {:?}", event);
|
||
Ok(())
|
||
}
|
||
|
||
// 处理执行事件 (临时禁用)
|
||
/*
|
||
async fn handle_execution_event(
|
||
event: ExecutionEvent,
|
||
event_sender: &broadcast::Sender<RealtimeEvent>,
|
||
prompt_execution_map: &Arc<RwLock<HashMap<String, String>>>,
|
||
repository: &Arc<ComfyUIRepository>,
|
||
) -> Result<()> {
|
||
let execution_id = {
|
||
let map = prompt_execution_map.read().await;
|
||
map.get(&event.prompt_id).cloned()
|
||
};
|
||
|
||
match event.event_type.as_str() {
|
||
"execution_start" => {
|
||
let _ = event_sender.send(RealtimeEvent::ExecutionStarted {
|
||
prompt_id: event.prompt_id.clone(),
|
||
execution_id: execution_id.clone(),
|
||
});
|
||
|
||
// 更新数据库中的执行状态
|
||
if let Some(exec_id) = execution_id {
|
||
if let Ok(Some(mut execution)) = repository.get_execution(&exec_id).await {
|
||
execution.update_status(ExecutionStatus::Running);
|
||
let _ = repository.update_execution(&execution).await;
|
||
}
|
||
}
|
||
}
|
||
"execution_complete" => {
|
||
// 提取输出信息
|
||
let outputs = event.data.unwrap_or_default();
|
||
let output_urls = Self::extract_output_urls(&outputs);
|
||
|
||
let _ = event_sender.send(RealtimeEvent::ExecutionCompleted {
|
||
prompt_id: event.prompt_id.clone(),
|
||
execution_id: execution_id.clone(),
|
||
outputs: outputs.clone(),
|
||
output_urls: output_urls.clone(),
|
||
});
|
||
|
||
// 更新数据库中的执行状态
|
||
if let Some(exec_id) = execution_id {
|
||
if let Ok(Some(mut execution)) = repository.get_execution(&exec_id).await {
|
||
execution.set_results(outputs, output_urls);
|
||
execution.node_outputs = Some(outputs);
|
||
let _ = repository.update_execution(&execution).await;
|
||
}
|
||
}
|
||
|
||
// 清理映射
|
||
{
|
||
let mut map = prompt_execution_map.write().await;
|
||
map.remove(&event.prompt_id);
|
||
}
|
||
}
|
||
"execution_error" => {
|
||
let error_msg = event.data
|
||
.and_then(|data| data.get("error"))
|
||
.and_then(|e| e.as_str())
|
||
.unwrap_or("未知错误")
|
||
.to_string();
|
||
|
||
let _ = event_sender.send(RealtimeEvent::ExecutionFailed {
|
||
prompt_id: event.prompt_id.clone(),
|
||
execution_id: execution_id.clone(),
|
||
error: error_msg.clone(),
|
||
});
|
||
|
||
// 更新数据库中的执行状态
|
||
if let Some(exec_id) = execution_id {
|
||
if let Ok(Some(mut execution)) = repository.get_execution(&exec_id).await {
|
||
execution.set_error(error_msg);
|
||
let _ = repository.update_execution(&execution).await;
|
||
}
|
||
}
|
||
|
||
// 清理映射
|
||
{
|
||
let mut map = prompt_execution_map.write().await;
|
||
map.remove(&event.prompt_id);
|
||
}
|
||
}
|
||
_ => {
|
||
debug!("未处理的执行事件类型: {}", event.event_type);
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 处理进度事件
|
||
async fn handle_progress_event(
|
||
event: ProgressEvent,
|
||
event_sender: &broadcast::Sender<RealtimeEvent>,
|
||
prompt_execution_map: &Arc<RwLock<HashMap<String, String>>>,
|
||
) -> Result<()> {
|
||
let execution_id = {
|
||
let map = prompt_execution_map.read().await;
|
||
map.get(&event.prompt_id).cloned()
|
||
};
|
||
|
||
let _ = event_sender.send(RealtimeEvent::ExecutionProgress {
|
||
prompt_id: event.prompt_id,
|
||
execution_id,
|
||
progress: event.progress,
|
||
current_node: event.current_node,
|
||
total_nodes: event.total_nodes,
|
||
});
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 处理队列事件
|
||
async fn handle_queue_event(
|
||
event: QueueEvent,
|
||
event_sender: &broadcast::Sender<RealtimeEvent>,
|
||
) -> Result<()> {
|
||
let _ = event_sender.send(RealtimeEvent::QueueUpdated {
|
||
running_count: event.running_count,
|
||
pending_count: event.pending_count,
|
||
});
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 提取输出 URLs
|
||
fn extract_output_urls(outputs: &HashMap<String, serde_json::Value>) -> Vec<String> {
|
||
let mut urls = Vec::new();
|
||
|
||
for (_, output) in outputs {
|
||
if let Some(images) = output.get("images").and_then(|v| v.as_array()) {
|
||
for image in images {
|
||
if let Some(filename) = image.get("filename").and_then(|v| v.as_str()) {
|
||
// 构建完整的 URL(这里需要根据实际的 ComfyUI 配置调整)
|
||
let url = format!("/view?filename={}", filename);
|
||
urls.push(url);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
urls
|
||
}
|
||
|
||
/// 检查 WebSocket 连接状态
|
||
pub async fn is_websocket_connected(&self) -> bool {
|
||
*self.websocket_connected.read().await
|
||
}
|
||
|
||
/// 获取监控统计信息
|
||
pub async fn get_monitor_stats(&self) -> MonitorStats {
|
||
let websocket_connected = *self.websocket_connected.read().await;
|
||
let prompt_map_size = self.prompt_execution_map.read().await.len();
|
||
let is_running = self.monitor_handle.read().await.is_some();
|
||
|
||
MonitorStats {
|
||
is_running,
|
||
websocket_connected,
|
||
tracked_executions: prompt_map_size,
|
||
event_subscribers: self.event_sender.receiver_count(),
|
||
}
|
||
}
|
||
*/
|
||
}
|
||
|
||
/// 监控统计信息
|
||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||
pub struct MonitorStats {
|
||
pub is_running: bool,
|
||
pub websocket_connected: bool,
|
||
pub tracked_executions: usize,
|
||
pub event_subscribers: usize,
|
||
}
|