//! ComfyUI manager
//!
//! Unified SDK manager responsible for connection management, health checks
//! and configuration management.

use anyhow::{Result, anyhow};
|
||
use std::sync::Arc;
|
||
use std::time::Duration;
|
||
use tokio::sync::RwLock;
|
||
use tracing::{info, warn, error, debug};
|
||
|
||
use comfyui_sdk::client::ComfyUIClient;
|
||
use comfyui_sdk::types::{ComfyUIClientConfig, SystemStats, QueueStatus, ObjectInfo};
|
||
|
||
use crate::data::models::comfyui::{ComfyUIConfig, ValidationResult};
|
||
|
||
/// Connection state of the ComfyUI service as tracked by the manager.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// No connection is established.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The connection is established.
    Connected,
    /// The connection failed; the payload carries the failure message.
    Failed(String),
}
|
||
|
||
/// ComfyUI 管理器
|
||
/// 提供统一的 ComfyUI 服务管理接口
|
||
pub struct ComfyUIManager {
|
||
/// SDK 客户端
|
||
client: Arc<RwLock<Option<ComfyUIClient>>>,
|
||
/// 配置信息
|
||
config: Arc<RwLock<ComfyUIConfig>>,
|
||
/// 连接状态
|
||
status: Arc<RwLock<ConnectionStatus>>,
|
||
/// 最后健康检查时间
|
||
last_health_check: Arc<RwLock<Option<std::time::Instant>>>,
|
||
}
|
||
|
||
impl ComfyUIManager {
|
||
/// 创建新的 ComfyUI 管理器
|
||
pub fn new(config: ComfyUIConfig) -> Result<Self> {
|
||
// 验证配置
|
||
let validation = config.validate();
|
||
if !validation.valid {
|
||
return Err(anyhow!("配置验证失败: {:?}", validation.errors));
|
||
}
|
||
|
||
Ok(Self {
|
||
client: Arc::new(RwLock::new(None)),
|
||
config: Arc::new(RwLock::new(config)),
|
||
status: Arc::new(RwLock::new(ConnectionStatus::Disconnected)),
|
||
last_health_check: Arc::new(RwLock::new(None)),
|
||
})
|
||
}
|
||
|
||
/// 连接到 ComfyUI 服务
|
||
pub async fn connect(&self) -> Result<()> {
|
||
info!("开始连接 ComfyUI 服务");
|
||
|
||
// 更新状态为连接中
|
||
*self.status.write().await = ConnectionStatus::Connecting;
|
||
|
||
let config = self.config.read().await;
|
||
let sdk_config = config.to_sdk_config();
|
||
|
||
match ComfyUIClient::new(sdk_config) {
|
||
Ok(client) => {
|
||
// 测试连接
|
||
match client.get_system_stats().await {
|
||
Ok(stats) => {
|
||
*self.client.write().await = Some(client);
|
||
*self.status.write().await = ConnectionStatus::Connected;
|
||
*self.last_health_check.write().await = Some(std::time::Instant::now());
|
||
|
||
info!("ComfyUI 连接成功,系统信息: OS={}, Python={}",
|
||
stats.system.os, stats.system.python_version);
|
||
Ok(())
|
||
}
|
||
Err(e) => {
|
||
let error_msg = format!("连接测试失败: {}", e);
|
||
*self.status.write().await = ConnectionStatus::Failed(error_msg.clone());
|
||
Err(anyhow!(error_msg))
|
||
}
|
||
}
|
||
}
|
||
Err(e) => {
|
||
let error_msg = format!("创建客户端失败: {}", e);
|
||
*self.status.write().await = ConnectionStatus::Failed(error_msg.clone());
|
||
Err(anyhow!(error_msg))
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 断开连接
|
||
pub async fn disconnect(&self) -> Result<()> {
|
||
info!("断开 ComfyUI 连接");
|
||
|
||
*self.client.write().await = None;
|
||
*self.status.write().await = ConnectionStatus::Disconnected;
|
||
*self.last_health_check.write().await = None;
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 检查是否已连接
|
||
pub async fn is_connected(&self) -> bool {
|
||
matches!(*self.status.read().await, ConnectionStatus::Connected)
|
||
}
|
||
|
||
/// 获取连接状态
|
||
pub async fn get_connection_status(&self) -> ConnectionStatus {
|
||
self.status.read().await.clone()
|
||
}
|
||
|
||
/// 健康检查
|
||
pub async fn health_check(&self) -> Result<bool> {
|
||
let client_guard = self.client.read().await;
|
||
let client = match client_guard.as_ref() {
|
||
Some(client) => client,
|
||
None => {
|
||
warn!("客户端未连接,健康检查失败");
|
||
return Ok(false);
|
||
}
|
||
};
|
||
|
||
match client.get_system_stats().await {
|
||
Ok(_) => {
|
||
*self.last_health_check.write().await = Some(std::time::Instant::now());
|
||
debug!("健康检查通过");
|
||
Ok(true)
|
||
}
|
||
Err(e) => {
|
||
warn!("健康检查失败: {}", e);
|
||
// 如果健康检查失败,更新连接状态
|
||
*self.status.write().await = ConnectionStatus::Failed(format!("健康检查失败: {}", e));
|
||
Ok(false)
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 获取系统信息
|
||
pub async fn get_system_info(&self) -> Result<SystemStats> {
|
||
let client_guard = self.client.read().await;
|
||
let client = match client_guard.as_ref() {
|
||
Some(client) => client,
|
||
None => return Err(anyhow!("客户端未连接")),
|
||
};
|
||
|
||
client.get_system_stats().await
|
||
.map_err(|e| anyhow!("获取系统信息失败: {}", e))
|
||
}
|
||
|
||
/// 获取队列状态
|
||
pub async fn get_queue_status(&self) -> Result<QueueStatus> {
|
||
let client_guard = self.client.read().await;
|
||
let client = match client_guard.as_ref() {
|
||
Some(client) => client,
|
||
None => return Err(anyhow!("客户端未连接")),
|
||
};
|
||
|
||
client.get_queue_status().await
|
||
.map_err(|e| anyhow!("获取队列状态失败: {}", e))
|
||
}
|
||
|
||
/// 获取对象信息
|
||
pub async fn get_object_info(&self) -> Result<ObjectInfo> {
|
||
let client_guard = self.client.read().await;
|
||
let client = match client_guard.as_ref() {
|
||
Some(client) => client,
|
||
None => return Err(anyhow!("客户端未连接")),
|
||
};
|
||
|
||
client.get_object_info().await
|
||
.map_err(|e| anyhow!("获取对象信息失败: {}", e))
|
||
}
|
||
|
||
/// 更新配置
|
||
pub async fn update_config(&self, new_config: ComfyUIConfig) -> Result<()> {
|
||
// 验证新配置
|
||
let validation = new_config.validate();
|
||
if !validation.valid {
|
||
return Err(anyhow!("配置验证失败: {:?}", validation.errors));
|
||
}
|
||
|
||
let mut config = self.config.write().await;
|
||
let old_base_url = config.base_url.clone();
|
||
*config = new_config;
|
||
|
||
// 如果 URL 发生变化,需要重新连接
|
||
if old_base_url != config.base_url {
|
||
drop(config); // 释放锁
|
||
info!("配置已更新,URL 发生变化,需要重新连接");
|
||
self.disconnect().await?;
|
||
} else {
|
||
info!("配置已更新");
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 获取当前配置
|
||
pub async fn get_config(&self) -> ComfyUIConfig {
|
||
self.config.read().await.clone()
|
||
}
|
||
|
||
/// 检查客户端是否连接
|
||
pub async fn is_connected(&self) -> bool {
|
||
let client_guard = self.client.read().await;
|
||
client_guard.is_some()
|
||
}
|
||
|
||
/// 自动重连
|
||
pub async fn auto_reconnect(&self) -> Result<()> {
|
||
info!("尝试自动重连");
|
||
|
||
// 先断开现有连接
|
||
self.disconnect().await?;
|
||
|
||
// 等待一段时间后重连
|
||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||
|
||
// 重新连接
|
||
self.connect().await
|
||
}
|
||
|
||
/// 检查是否需要健康检查
|
||
pub async fn should_health_check(&self, interval: Duration) -> bool {
|
||
let last_check = self.last_health_check.read().await;
|
||
match *last_check {
|
||
Some(last) => last.elapsed() >= interval,
|
||
None => true,
|
||
}
|
||
}
|
||
|
||
/// 获取连接统计信息
|
||
pub async fn get_connection_stats(&self) -> ConnectionStats {
|
||
let status = self.status.read().await.clone();
|
||
let last_check = *self.last_health_check.read().await;
|
||
let config = self.config.read().await;
|
||
|
||
ConnectionStats {
|
||
status,
|
||
base_url: config.base_url.clone(),
|
||
last_health_check: last_check,
|
||
timeout_seconds: config.timeout_seconds,
|
||
retry_attempts: config.retry_attempts,
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 连接统计信息
|
||
#[derive(Debug, Clone)]
|
||
pub struct ConnectionStats {
|
||
pub status: ConnectionStatus,
|
||
pub base_url: String,
|
||
pub last_health_check: Option<std::time::Instant>,
|
||
pub timeout_seconds: u64,
|
||
pub retry_attempts: u32,
|
||
}
|
||
|
||
impl std::fmt::Display for ConnectionStatus {
|
||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
match self {
|
||
ConnectionStatus::Disconnected => write!(f, "未连接"),
|
||
ConnectionStatus::Connecting => write!(f, "连接中"),
|
||
ConnectionStatus::Connected => write!(f, "已连接"),
|
||
ConnectionStatus::Failed(msg) => write!(f, "连接失败: {}", msg),
|
||
}
|
||
}
|
||
}
|