275 lines
9.2 KiB
Rust
275 lines
9.2 KiB
Rust
//! Main ComfyUI client, combining the HTTP and WebSocket clients with workflow template management.
|
|
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, Instant};
|
|
use tokio::time::sleep;
|
|
|
|
use crate::types::{
|
|
ComfyUIClientConfig, PromptRequest, TemplateExecutionResult,
|
|
ExecutionOptions, ExecutionCallbacks, ExecutionError, ParameterValues
|
|
};
|
|
use crate::templates::{TemplateManager, WorkflowTemplate, WorkflowInstance};
|
|
use crate::client::{HTTPClient, WebSocketClient};
|
|
use crate::error::{ComfyUIError, Result};
|
|
|
|
|
|
/// Main ComfyUI client
|
|
pub struct ComfyUIClient {
|
|
http_client: HTTPClient,
|
|
ws_client: WebSocketClient,
|
|
template_manager: TemplateManager,
|
|
config: ComfyUIClientConfig,
|
|
}
|
|
|
|
impl ComfyUIClient {
|
|
/// Creates a new ComfyUI client
|
|
pub fn new(config: ComfyUIClientConfig) -> Result<Self> {
|
|
let http_client = HTTPClient::new(config.clone())?;
|
|
let ws_client = WebSocketClient::new(config.clone());
|
|
let template_manager = TemplateManager::new();
|
|
|
|
Ok(Self {
|
|
http_client,
|
|
ws_client,
|
|
template_manager,
|
|
config,
|
|
})
|
|
}
|
|
|
|
/// Gets the HTTP client
|
|
pub fn http(&self) -> &HTTPClient {
|
|
&self.http_client
|
|
}
|
|
|
|
/// Gets the WebSocket client
|
|
pub fn ws(&self) -> &WebSocketClient {
|
|
&self.ws_client
|
|
}
|
|
|
|
/// Gets the template manager
|
|
pub fn templates(&mut self) -> &mut TemplateManager {
|
|
&mut self.template_manager
|
|
}
|
|
|
|
/// Gets the template manager (read-only)
|
|
pub fn templates_ref(&self) -> &TemplateManager {
|
|
&self.template_manager
|
|
}
|
|
|
|
/// Gets the client configuration
|
|
pub fn config(&self) -> &ComfyUIClientConfig {
|
|
&self.config
|
|
}
|
|
|
|
/// Connects to ComfyUI server (both HTTP and WebSocket)
|
|
pub async fn connect(&mut self) -> Result<()> {
|
|
// Test HTTP connection first
|
|
self.http_client.get_queue().await
|
|
.map_err(|e| ComfyUIError::connection(format!("HTTP connection failed: {}", e)))?;
|
|
|
|
// Connect WebSocket
|
|
self.ws_client.connect().await
|
|
.map_err(|e| ComfyUIError::connection(format!("WebSocket connection failed: {}", e)))?;
|
|
|
|
log::info!("Successfully connected to ComfyUI server");
|
|
Ok(())
|
|
}
|
|
|
|
/// Disconnects from ComfyUI server
|
|
pub async fn disconnect(&mut self) -> Result<()> {
|
|
self.ws_client.disconnect().await?;
|
|
log::info!("Disconnected from ComfyUI server");
|
|
Ok(())
|
|
}
|
|
|
|
/// Checks if connected to the server
|
|
pub async fn is_connected(&self) -> bool {
|
|
self.ws_client.is_connected().await
|
|
}
|
|
|
|
/// Executes a workflow template
|
|
pub async fn execute_template(
|
|
&self,
|
|
template: &WorkflowTemplate,
|
|
parameters: ParameterValues,
|
|
options: ExecutionOptions,
|
|
) -> Result<TemplateExecutionResult> {
|
|
let start_time = Instant::now();
|
|
|
|
// Create workflow instance
|
|
let instance = template.create_instance(parameters)?;
|
|
|
|
// Convert to prompt request
|
|
let prompt_request = self.instance_to_prompt_request(&instance)?;
|
|
|
|
// Submit prompt
|
|
let prompt_response = self.http_client.queue_prompt(&prompt_request).await?;
|
|
let prompt_id = prompt_response.prompt_id.clone();
|
|
|
|
log::info!("Submitted prompt with ID: {}", prompt_id);
|
|
|
|
// Wait for completion if WebSocket is connected
|
|
if self.ws_client.is_connected().await {
|
|
match self.wait_for_completion(&prompt_id, options.timeout).await {
|
|
Ok(outputs) => {
|
|
let execution_time = start_time.elapsed().as_millis() as u64;
|
|
Ok(TemplateExecutionResult::success(prompt_id, outputs, execution_time))
|
|
}
|
|
Err(error) => {
|
|
let execution_time = start_time.elapsed().as_millis() as u64;
|
|
let exec_error = ExecutionError {
|
|
node_id: None,
|
|
message: error.to_string(),
|
|
details: None,
|
|
timestamp: chrono::Utc::now(),
|
|
};
|
|
Ok(TemplateExecutionResult::failure(prompt_id, exec_error, execution_time))
|
|
}
|
|
}
|
|
} else {
|
|
// If no WebSocket, just return the prompt ID
|
|
let execution_time = start_time.elapsed().as_millis() as u64;
|
|
Ok(TemplateExecutionResult::success(prompt_id, HashMap::new(), execution_time))
|
|
}
|
|
}
|
|
|
|
/// Executes a template with callbacks
|
|
pub async fn execute_template_with_callbacks(
|
|
&self,
|
|
template: &WorkflowTemplate,
|
|
parameters: ParameterValues,
|
|
options: ExecutionOptions,
|
|
callbacks: Arc<dyn ExecutionCallbacks>,
|
|
) -> Result<TemplateExecutionResult> {
|
|
// Register callbacks
|
|
let callback_id = self.ws_client.register_callbacks(callbacks).await;
|
|
|
|
// Execute template
|
|
let result = self.execute_template(template, parameters, options).await;
|
|
|
|
// Unregister callbacks
|
|
self.ws_client.unregister_callbacks(&callback_id).await;
|
|
|
|
result
|
|
}
|
|
|
|
/// Converts a workflow instance to a prompt request
|
|
fn instance_to_prompt_request(&self, instance: &WorkflowInstance) -> Result<PromptRequest> {
|
|
let workflow_json = instance.to_workflow_json()?;
|
|
|
|
let mut prompt = HashMap::new();
|
|
if let serde_json::Value::Object(workflow_obj) = workflow_json {
|
|
for (key, value) in workflow_obj {
|
|
prompt.insert(key, value);
|
|
}
|
|
}
|
|
|
|
Ok(PromptRequest {
|
|
prompt,
|
|
client_id: Some(self.ws_client.client_id().to_string()),
|
|
extra_data: None,
|
|
})
|
|
}
|
|
|
|
/// Waits for prompt completion
|
|
async fn wait_for_completion(
|
|
&self,
|
|
prompt_id: &str,
|
|
timeout: Option<Duration>,
|
|
) -> Result<HashMap<String, serde_json::Value>> {
|
|
let timeout_duration = timeout.unwrap_or(Duration::from_secs(300)); // 5 minutes default
|
|
let start_time = Instant::now();
|
|
let check_interval = Duration::from_millis(1000); // Check every second
|
|
|
|
loop {
|
|
// Check timeout
|
|
if start_time.elapsed() > timeout_duration {
|
|
return Err(ComfyUIError::timeout(format!(
|
|
"Execution timed out after {:?}",
|
|
timeout_duration
|
|
)));
|
|
}
|
|
|
|
// Check history for completion
|
|
match self.http_client.get_history_by_prompt(prompt_id).await {
|
|
Ok(history) => {
|
|
if let Some(history_item) = history.get(prompt_id) {
|
|
if history_item.status.completed {
|
|
// Extract outputs
|
|
let mut outputs = HashMap::new();
|
|
for (node_id, node_outputs) in &history_item.outputs {
|
|
let mut node_output = HashMap::new();
|
|
for (output_name, output_list) in node_outputs {
|
|
node_output.insert(output_name.clone(), serde_json::to_value(output_list)?);
|
|
}
|
|
outputs.insert(node_id.clone(), serde_json::to_value(node_output)?);
|
|
}
|
|
return Ok(outputs);
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
log::warn!("Error checking history: {}", e);
|
|
}
|
|
}
|
|
|
|
sleep(check_interval).await;
|
|
}
|
|
}
|
|
|
|
/// Converts outputs to HTTP URLs
|
|
pub fn outputs_to_urls(&self, outputs: &HashMap<String, serde_json::Value>) -> Vec<String> {
|
|
self.http_client.outputs_to_urls(outputs)
|
|
}
|
|
|
|
/// Interrupts the current execution
|
|
pub async fn interrupt(&self) -> Result<()> {
|
|
self.http_client.interrupt().await
|
|
}
|
|
|
|
/// Clears the execution queue
|
|
pub async fn clear_queue(&self) -> Result<()> {
|
|
self.http_client.clear_queue().await
|
|
}
|
|
|
|
/// Gets the current queue status
|
|
pub async fn get_queue_status(&self) -> Result<crate::types::QueueStatus> {
|
|
self.http_client.get_queue().await
|
|
}
|
|
|
|
/// Gets system statistics
|
|
pub async fn get_system_stats(&self) -> Result<crate::types::SystemStats> {
|
|
self.http_client.get_system_stats().await
|
|
}
|
|
|
|
/// Gets available nodes information
|
|
pub async fn get_object_info(&self) -> Result<crate::types::ObjectInfo> {
|
|
self.http_client.get_object_info().await
|
|
}
|
|
|
|
/// Uploads an image file
|
|
pub async fn upload_image<P: AsRef<std::path::Path>>(
|
|
&self,
|
|
image_path: P,
|
|
overwrite: bool,
|
|
) -> Result<crate::types::UploadImageResponse> {
|
|
self.http_client.upload_image(image_path, overwrite).await
|
|
}
|
|
|
|
/// Gets an image by filename
|
|
pub async fn get_image(
|
|
&self,
|
|
filename: &str,
|
|
subfolder: Option<&str>,
|
|
image_type: Option<&str>,
|
|
) -> Result<bytes::Bytes> {
|
|
self.http_client.get_image(filename, subfolder, image_type).await
|
|
}
|
|
|
|
/// Frees memory
|
|
pub async fn free_memory(&self) -> Result<()> {
|
|
self.http_client.free_memory().await
|
|
}
|
|
}
|