mixvideo-v2/apps/desktop/src-tauri/src/business/services/comfyui_manager.rs

273 lines
8.8 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! ComfyUI 管理器
//! 统一的 SDK 管理器,负责连接管理、健康检查和配置管理
use anyhow::{Result, anyhow};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{info, warn, error, debug};
use comfyui_sdk::client::ComfyUIClient;
use comfyui_sdk::types::{ComfyUIClientConfig, SystemStats, QueueStatus, ObjectInfo};
use crate::data::models::comfyui::{ComfyUIConfig, ValidationResult};
/// Lifecycle state of the connection to the ComfyUI service.
#[derive(Clone, Debug, PartialEq)]
pub enum ConnectionStatus {
    /// No live connection: the initial state and the state after a disconnect.
    Disconnected,
    /// A connection attempt is currently in progress.
    Connecting,
    /// The service answered a connection probe successfully.
    Connected,
    /// A connection attempt or health check failed; carries the error message.
    Failed(String),
}
/// Central manager for the ComfyUI service.
///
/// Bundles the SDK client with its configuration, the current connection
/// status and the time of the last successful health check. Each piece of
/// state lives behind its own `Arc<RwLock<..>>`, so the async methods only
/// need `&self`.
pub struct ComfyUIManager {
    /// SDK client; `None` until a connection succeeds and again after a disconnect.
    client: Arc<RwLock<Option<ComfyUIClient>>>,
    /// Current configuration (validated before it is stored).
    config: Arc<RwLock<ComfyUIConfig>>,
    /// Current connection state.
    status: Arc<RwLock<ConnectionStatus>>,
    /// When the service last answered a probe successfully, if ever.
    last_health_check: Arc<RwLock<Option<std::time::Instant>>>,
}
impl ComfyUIManager {
/// 创建新的 ComfyUI 管理器
pub fn new(config: ComfyUIConfig) -> Result<Self> {
// 验证配置
let validation = config.validate();
if !validation.valid {
return Err(anyhow!("配置验证失败: {:?}", validation.errors));
}
Ok(Self {
client: Arc::new(RwLock::new(None)),
config: Arc::new(RwLock::new(config)),
status: Arc::new(RwLock::new(ConnectionStatus::Disconnected)),
last_health_check: Arc::new(RwLock::new(None)),
})
}
/// 连接到 ComfyUI 服务
pub async fn connect(&self) -> Result<()> {
info!("开始连接 ComfyUI 服务");
// 更新状态为连接中
*self.status.write().await = ConnectionStatus::Connecting;
let config = self.config.read().await;
let sdk_config = config.to_sdk_config();
match ComfyUIClient::new(sdk_config) {
Ok(client) => {
// 测试连接
match client.get_system_stats().await {
Ok(stats) => {
*self.client.write().await = Some(client);
*self.status.write().await = ConnectionStatus::Connected;
*self.last_health_check.write().await = Some(std::time::Instant::now());
info!("ComfyUI 连接成功,系统信息: OS={}, Python={}",
stats.system.os, stats.system.python_version);
Ok(())
}
Err(e) => {
let error_msg = format!("连接测试失败: {}", e);
*self.status.write().await = ConnectionStatus::Failed(error_msg.clone());
Err(anyhow!(error_msg))
}
}
}
Err(e) => {
let error_msg = format!("创建客户端失败: {}", e);
*self.status.write().await = ConnectionStatus::Failed(error_msg.clone());
Err(anyhow!(error_msg))
}
}
}
/// 断开连接
pub async fn disconnect(&self) -> Result<()> {
info!("断开 ComfyUI 连接");
*self.client.write().await = None;
*self.status.write().await = ConnectionStatus::Disconnected;
*self.last_health_check.write().await = None;
Ok(())
}
/// 检查是否已连接
pub async fn is_connected(&self) -> bool {
matches!(*self.status.read().await, ConnectionStatus::Connected)
}
/// 获取连接状态
pub async fn get_connection_status(&self) -> ConnectionStatus {
self.status.read().await.clone()
}
/// 健康检查
pub async fn health_check(&self) -> Result<bool> {
let client_guard = self.client.read().await;
let client = match client_guard.as_ref() {
Some(client) => client,
None => {
warn!("客户端未连接,健康检查失败");
return Ok(false);
}
};
match client.get_system_stats().await {
Ok(_) => {
*self.last_health_check.write().await = Some(std::time::Instant::now());
debug!("健康检查通过");
Ok(true)
}
Err(e) => {
warn!("健康检查失败: {}", e);
// 如果健康检查失败,更新连接状态
*self.status.write().await = ConnectionStatus::Failed(format!("健康检查失败: {}", e));
Ok(false)
}
}
}
/// 获取系统信息
pub async fn get_system_info(&self) -> Result<SystemStats> {
let client_guard = self.client.read().await;
let client = match client_guard.as_ref() {
Some(client) => client,
None => return Err(anyhow!("客户端未连接")),
};
client.get_system_stats().await
.map_err(|e| anyhow!("获取系统信息失败: {}", e))
}
/// 获取队列状态
pub async fn get_queue_status(&self) -> Result<QueueStatus> {
let client_guard = self.client.read().await;
let client = match client_guard.as_ref() {
Some(client) => client,
None => return Err(anyhow!("客户端未连接")),
};
client.get_queue_status().await
.map_err(|e| anyhow!("获取队列状态失败: {}", e))
}
/// 获取对象信息
pub async fn get_object_info(&self) -> Result<ObjectInfo> {
let client_guard = self.client.read().await;
let client = match client_guard.as_ref() {
Some(client) => client,
None => return Err(anyhow!("客户端未连接")),
};
client.get_object_info().await
.map_err(|e| anyhow!("获取对象信息失败: {}", e))
}
/// 更新配置
pub async fn update_config(&self, new_config: ComfyUIConfig) -> Result<()> {
// 验证新配置
let validation = new_config.validate();
if !validation.valid {
return Err(anyhow!("配置验证失败: {:?}", validation.errors));
}
let mut config = self.config.write().await;
let old_base_url = config.base_url.clone();
*config = new_config;
// 如果 URL 发生变化,需要重新连接
if old_base_url != config.base_url {
drop(config); // 释放锁
info!("配置已更新URL 发生变化,需要重新连接");
self.disconnect().await?;
} else {
info!("配置已更新");
}
Ok(())
}
/// 获取当前配置
pub async fn get_config(&self) -> ComfyUIConfig {
self.config.read().await.clone()
}
/// 检查客户端是否连接
pub async fn is_connected(&self) -> bool {
let client_guard = self.client.read().await;
client_guard.is_some()
}
/// 自动重连
pub async fn auto_reconnect(&self) -> Result<()> {
info!("尝试自动重连");
// 先断开现有连接
self.disconnect().await?;
// 等待一段时间后重连
tokio::time::sleep(Duration::from_secs(2)).await;
// 重新连接
self.connect().await
}
/// 检查是否需要健康检查
pub async fn should_health_check(&self, interval: Duration) -> bool {
let last_check = self.last_health_check.read().await;
match *last_check {
Some(last) => last.elapsed() >= interval,
None => true,
}
}
/// 获取连接统计信息
pub async fn get_connection_stats(&self) -> ConnectionStats {
let status = self.status.read().await.clone();
let last_check = *self.last_health_check.read().await;
let config = self.config.read().await;
ConnectionStats {
status,
base_url: config.base_url.clone(),
last_health_check: last_check,
timeout_seconds: config.timeout_seconds,
retry_attempts: config.retry_attempts,
}
}
}
/// Point-in-time snapshot of the manager's connection state and key settings.
#[derive(Clone, Debug)]
pub struct ConnectionStats {
    /// Connection state at the moment the snapshot was taken.
    pub status: ConnectionStatus,
    /// Service base URL from the configuration.
    pub base_url: String,
    /// Time of the last successful health check, if any.
    pub last_health_check: Option<std::time::Instant>,
    /// Timeout setting from the configuration (seconds, per the field name).
    pub timeout_seconds: u64,
    /// Retry-attempt setting from the configuration.
    pub retry_attempts: u32,
}
impl std::fmt::Display for ConnectionStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConnectionStatus::Disconnected => write!(f, "未连接"),
ConnectionStatus::Connecting => write!(f, "连接中"),
ConnectionStatus::Connected => write!(f, "已连接"),
ConnectionStatus::Failed(msg) => write!(f, "连接失败: {}", msg),
}
}
}