mixvideo-v2/cargos/comfyui-sdk/client/comfyui_client.rs

284 lines
9.5 KiB
Rust

//! Main ComfyUI client that combines HTTP and WebSocket functionality
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::time::sleep;
use crate::types::{
ComfyUIClientConfig, PromptRequest, TemplateExecutionResult,
ExecutionOptions, ExecutionCallbacks, ExecutionError, ParameterValues
};
use crate::templates::{TemplateManager, WorkflowTemplate, WorkflowInstance};
use crate::client::{HTTPClient, WebSocketClient};
use crate::error::{ComfyUIError, Result};
/// Main ComfyUI client.
///
/// Bundles an [`HTTPClient`] (request/response endpoints such as queueing prompts and
/// reading history), a [`WebSocketClient`] (connection state, execution callbacks and
/// error reporting) and a [`TemplateManager`] behind a single handle.
pub struct ComfyUIClient {
/// HTTP client used for request/response endpoints (queue, history, uploads, images).
http_client: HTTPClient,
/// WebSocket client; its `client_id` is attached to submitted prompts and its last
/// reported error is checked while waiting for a prompt to complete.
ws_client: WebSocketClient,
/// Workflow template store, exposed through `templates()` / `templates_ref()`.
template_manager: TemplateManager,
/// Configuration this client was built from (a clone is also handed to each sub-client).
config: ComfyUIClientConfig,
}
impl ComfyUIClient {
/// Creates a new ComfyUI client
pub fn new(config: ComfyUIClientConfig) -> Result<Self> {
let http_client = HTTPClient::new(config.clone())?;
let ws_client = WebSocketClient::new(config.clone());
let template_manager = TemplateManager::new();
Ok(Self {
http_client,
ws_client,
template_manager,
config,
})
}
/// Gets the HTTP client
pub fn http(&self) -> &HTTPClient {
&self.http_client
}
/// Gets the WebSocket client
pub fn ws(&self) -> &WebSocketClient {
&self.ws_client
}
/// Gets the template manager
pub fn templates(&mut self) -> &mut TemplateManager {
&mut self.template_manager
}
/// Gets the template manager (read-only)
pub fn templates_ref(&self) -> &TemplateManager {
&self.template_manager
}
/// Gets the client configuration
pub fn config(&self) -> &ComfyUIClientConfig {
&self.config
}
/// Connects to ComfyUI server (both HTTP and WebSocket)
pub async fn connect(&mut self) -> Result<()> {
// Test HTTP connection first
self.http_client.get_queue().await
.map_err(|e| ComfyUIError::connection(format!("HTTP connection failed: {e}")))?;
// Connect WebSocket
self.ws_client.connect().await
.map_err(|e| ComfyUIError::connection(format!("WebSocket connection failed: {e}")))?;
log::info!("Successfully connected to ComfyUI server");
Ok(())
}
/// Disconnects from ComfyUI server
pub async fn disconnect(&mut self) -> Result<()> {
self.ws_client.disconnect().await?;
log::info!("Disconnected from ComfyUI server");
Ok(())
}
/// Checks if connected to the server
pub async fn is_connected(&self) -> bool {
self.ws_client.is_connected().await
}
/// Executes a workflow template
pub async fn execute_template(
&self,
template: &WorkflowTemplate,
parameters: ParameterValues,
options: ExecutionOptions,
) -> Result<TemplateExecutionResult> {
let start_time = Instant::now();
// Create workflow instance
let instance = template.create_instance(parameters)?;
// Convert to prompt request
let prompt_request = self.instance_to_prompt_request(&instance)?;
// Submit prompt
let prompt_response = self.http_client.queue_prompt(&prompt_request).await?;
let prompt_id = prompt_response.prompt_id.clone();
log::info!("Submitted prompt with ID: {prompt_id}");
// Wait for completion if WebSocket is connected
if self.ws_client.is_connected().await {
match self.wait_for_completion(&prompt_id, options.timeout).await {
Ok(outputs) => {
let execution_time = start_time.elapsed().as_millis() as u64;
Ok(TemplateExecutionResult::success(prompt_id, outputs, execution_time))
}
Err(error) => {
let execution_time = start_time.elapsed().as_millis() as u64;
let exec_error = ExecutionError {
node_id: None,
message: error.to_string(),
details: None,
timestamp: chrono::Utc::now(),
};
Ok(TemplateExecutionResult::failure(prompt_id, exec_error, execution_time))
}
}
} else {
// If no WebSocket, just return the prompt ID
let execution_time = start_time.elapsed().as_millis() as u64;
Ok(TemplateExecutionResult::success(prompt_id, HashMap::new(), execution_time))
}
}
/// Executes a template with callbacks
pub async fn execute_template_with_callbacks(
&self,
template: &WorkflowTemplate,
parameters: ParameterValues,
options: ExecutionOptions,
callbacks: Arc<dyn ExecutionCallbacks>,
) -> Result<TemplateExecutionResult> {
// Register callbacks
let callback_id = self.ws_client.register_callbacks(callbacks).await;
// Execute template
let result = self.execute_template(template, parameters, options).await;
// Unregister callbacks
self.ws_client.unregister_callbacks(&callback_id).await;
result
}
/// Converts a workflow instance to a prompt request
fn instance_to_prompt_request(&self, instance: &WorkflowInstance) -> Result<PromptRequest> {
let workflow_json = instance.to_workflow_json()?;
let mut prompt = HashMap::new();
if let serde_json::Value::Object(workflow_obj) = workflow_json {
for (key, value) in workflow_obj {
prompt.insert(key, value);
}
}
Ok(PromptRequest {
prompt,
client_id: Some(self.ws_client.client_id().to_string()),
extra_data: None,
})
}
/// Waits for prompt completion
async fn wait_for_completion(
&self,
prompt_id: &str,
timeout: Option<Duration>,
) -> Result<HashMap<String, serde_json::Value>> {
// Clear any previous errors before starting
self.ws_client.clear_last_error().await;
let timeout_duration = timeout.unwrap_or(Duration::from_secs(300)); // 5 minutes default
let start_time = Instant::now();
let check_interval = Duration::from_millis(1000); // Check every second
loop {
// Check timeout
if start_time.elapsed() > timeout_duration {
return Err(ComfyUIError::timeout(format!(
"Execution timed out after {timeout_duration:?}"
)));
}
// Check for WebSocket execution errors first
if let Some(error) = self.ws_client.get_last_error().await {
return Err(ComfyUIError::execution(format!(
"Execution failed: {}", error.message
)));
}
// Check history for completion
match self.http_client.get_history_by_prompt(prompt_id).await {
Ok(history) => {
if let Some(history_item) = history.get(prompt_id) {
if history_item.status.completed {
// Extract outputs
let mut outputs = HashMap::new();
for (node_id, node_outputs) in &history_item.outputs {
let mut node_output = HashMap::new();
for (output_name, output_list) in node_outputs {
node_output.insert(output_name.clone(), serde_json::to_value(output_list)?);
}
outputs.insert(node_id.clone(), serde_json::to_value(node_output)?);
}
return Ok(outputs);
}
}
}
Err(e) => {
log::warn!("Error checking history: {e}");
}
}
sleep(check_interval).await;
}
}
/// Converts outputs to HTTP URLs
pub fn outputs_to_urls(&self, outputs: &HashMap<String, serde_json::Value>) -> Vec<String> {
self.http_client.outputs_to_urls(outputs)
}
/// Interrupts the current execution
pub async fn interrupt(&self) -> Result<()> {
self.http_client.interrupt().await
}
/// Clears the execution queue
pub async fn clear_queue(&self) -> Result<()> {
self.http_client.clear_queue().await
}
/// Gets the current queue status
pub async fn get_queue_status(&self) -> Result<crate::types::QueueStatus> {
self.http_client.get_queue().await
}
/// Gets system statistics
pub async fn get_system_stats(&self) -> Result<crate::types::SystemStats> {
self.http_client.get_system_stats().await
}
/// Gets available nodes information
pub async fn get_object_info(&self) -> Result<crate::types::ObjectInfo> {
self.http_client.get_object_info().await
}
/// Uploads an image file
pub async fn upload_image<P: AsRef<std::path::Path>>(
&self,
image_path: P,
overwrite: bool,
) -> Result<crate::types::UploadImageResponse> {
self.http_client.upload_image(image_path, overwrite).await
}
/// Gets an image by filename
pub async fn get_image(
&self,
filename: &str,
subfolder: Option<&str>,
image_type: Option<&str>,
) -> Result<bytes::Bytes> {
self.http_client.get_image(filename, subfolder, image_type).await
}
/// Frees memory
pub async fn free_memory(&self) -> Result<()> {
self.http_client.free_memory().await
}
}