mixvideo-v2/cargos/comfyui-sdk/client/http_client.rs

316 lines
10 KiB
Rust

//! HTTP client for ComfyUI API
use std::collections::HashMap;
use std::path::Path;
use reqwest::{Client, multipart};
use url::Url;
use crate::types::{
ComfyUIClientConfig, PromptRequest, PromptResponse, QueueStatus,
HistoryItem, SystemStats, ObjectInfo, UploadImageResponse, ViewMetadata
};
use crate::error::{ComfyUIError, Result};
use crate::utils::error_handling::{retry_if_retryable, RetryConfig, with_timeout};
/// HTTP client for ComfyUI API
///
/// Bundles a `reqwest` [`Client`], the server base URL parsed from the
/// configuration, and the configuration itself (retry settings, timeout and
/// optional extra headers consumed by the request methods).
#[derive(Debug, Clone)]
pub struct HTTPClient {
/// Underlying `reqwest` client; `new` builds it with `config.timeout` when set.
client: Client,
/// Server base URL, parsed from `config.base_url`; endpoints are resolved against it.
base_url: Url,
/// Configuration this client was created from.
config: ComfyUIClientConfig,
}
impl HTTPClient {
/// Creates a new HTTP client
pub fn new(config: ComfyUIClientConfig) -> Result<Self> {
let base_url = Url::parse(&config.base_url)?;
let mut client_builder = Client::builder();
if let Some(timeout) = config.timeout {
client_builder = client_builder.timeout(timeout);
}
let client = client_builder.build()?;
Ok(Self {
client,
base_url,
config,
})
}
/// Gets the base URL
pub fn base_url(&self) -> &Url {
&self.base_url
}
/// Builds a URL for an endpoint
fn build_url(&self, endpoint: &str) -> Result<Url> {
self.base_url.join(endpoint)
.map_err(ComfyUIError::from)
}
/// Executes a GET request with retry logic
async fn get_with_retry<T>(&self, endpoint: &str) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
let url = self.build_url(endpoint)?;
let retry_config = RetryConfig {
max_attempts: self.config.retry_attempts.unwrap_or(3),
initial_delay: self.config.retry_delay.unwrap_or(std::time::Duration::from_millis(1000)),
..Default::default()
};
let operation = || async {
let mut request = self.client.get(url.clone());
if let Some(headers) = &self.config.headers {
for (key, value) in headers {
request = request.header(key, value);
}
}
let response = request.send().await?;
if !response.status().is_success() {
return Err(ComfyUIError::Http(
response.error_for_status().unwrap_err()
));
}
let data: T = response.json().await?;
Ok(data)
};
if let Some(timeout) = self.config.timeout {
with_timeout(retry_if_retryable(operation, retry_config), timeout).await
} else {
retry_if_retryable(operation, retry_config).await
}
}
/// Executes a POST request with retry logic
async fn post_with_retry<T, B>(&self, endpoint: &str, body: &B) -> Result<T>
where
T: serde::de::DeserializeOwned,
B: serde::Serialize,
{
let url = self.build_url(endpoint)?;
let retry_config = RetryConfig {
max_attempts: self.config.retry_attempts.unwrap_or(3),
initial_delay: self.config.retry_delay.unwrap_or(std::time::Duration::from_millis(1000)),
..Default::default()
};
let operation = || async {
let mut request = self.client.post(url.clone()).json(body);
if let Some(headers) = &self.config.headers {
for (key, value) in headers {
request = request.header(key, value);
}
}
let response = request.send().await?;
if !response.status().is_success() {
return Err(ComfyUIError::Http(
response.error_for_status().unwrap_err()
));
}
let data: T = response.json().await?;
Ok(data)
};
if let Some(timeout) = self.config.timeout {
with_timeout(retry_if_retryable(operation, retry_config), timeout).await
} else {
retry_if_retryable(operation, retry_config).await
}
}
/// Gets the current queue status
pub async fn get_queue(&self) -> Result<QueueStatus> {
self.get_with_retry("/queue").await
}
/// Submits a prompt for execution
pub async fn queue_prompt(&self, prompt_request: &PromptRequest) -> Result<PromptResponse> {
self.post_with_retry("/prompt", prompt_request).await
}
/// Gets execution history
pub async fn get_history(&self, max_items: Option<u32>) -> Result<HashMap<String, HistoryItem>> {
let endpoint = if let Some(max) = max_items {
format!("/history?max_items={max}")
} else {
"/history".to_string()
};
self.get_with_retry(&endpoint).await
}
/// Gets history for a specific prompt
pub async fn get_history_by_prompt(&self, prompt_id: &str) -> Result<HashMap<String, HistoryItem>> {
let endpoint = format!("/history/{prompt_id}");
self.get_with_retry(&endpoint).await
}
/// Clears the execution queue
pub async fn clear_queue(&self) -> Result<()> {
let url = self.build_url("/queue")?;
let response = self.client.delete(url).send().await?;
if !response.status().is_success() {
return Err(ComfyUIError::Http(
response.error_for_status().unwrap_err()
));
}
Ok(())
}
/// Cancels a specific prompt
pub async fn cancel_prompt(&self, prompt_id: &str) -> Result<()> {
let body = serde_json::json!({ "delete": [prompt_id] });
let url = self.build_url("/queue")?;
let response = self.client.post(url).json(&body).send().await?;
if !response.status().is_success() {
return Err(ComfyUIError::Http(
response.error_for_status().unwrap_err()
));
}
Ok(())
}
/// Gets system statistics
pub async fn get_system_stats(&self) -> Result<SystemStats> {
self.get_with_retry("/system_stats").await
}
/// Gets object information (available nodes)
pub async fn get_object_info(&self) -> Result<ObjectInfo> {
self.get_with_retry("/object_info").await
}
/// Uploads an image file
pub async fn upload_image<P: AsRef<Path>>(&self, image_path: P, overwrite: bool) -> Result<UploadImageResponse> {
let path = image_path.as_ref();
let filename = path.file_name()
.and_then(|name| name.to_str())
.ok_or_else(|| ComfyUIError::new("Invalid filename"))?;
let file_bytes = tokio::fs::read(path).await?;
let part = multipart::Part::bytes(file_bytes)
.file_name(filename.to_string())
.mime_str("image/*")
.map_err(|e| ComfyUIError::new(format!("Invalid MIME type: {e}")))?;
let form = multipart::Form::new()
.part("image", part)
.text("overwrite", overwrite.to_string());
let url = self.build_url("/upload/image")?;
let response = self.client.post(url).multipart(form).send().await?;
if !response.status().is_success() {
return Err(ComfyUIError::Http(response.error_for_status().unwrap_err()));
}
let upload_response: UploadImageResponse = response.json().await?;
Ok(upload_response)
}
/// Gets an image by filename
pub async fn get_image(&self, filename: &str, subfolder: Option<&str>, image_type: Option<&str>) -> Result<bytes::Bytes> {
let mut endpoint = format!("/view?filename={filename}");
if let Some(subfolder) = subfolder {
endpoint.push_str(&format!("&subfolder={subfolder}"));
}
if let Some(image_type) = image_type {
endpoint.push_str(&format!("&type={image_type}"));
}
let url = self.build_url(&endpoint)?;
let response = self.client.get(url).send().await?;
if !response.status().is_success() {
return Err(ComfyUIError::Http(response.error_for_status().unwrap_err()));
}
let bytes = response.bytes().await?;
Ok(bytes)
}
/// Gets image metadata
pub async fn get_view_metadata(&self, filename: &str, subfolder: Option<&str>) -> Result<ViewMetadata> {
let mut endpoint = format!("/view_metadata/{filename}");
if let Some(subfolder) = subfolder {
endpoint.push_str(&format!("?subfolder={subfolder}"));
}
self.get_with_retry(&endpoint).await
}
/// Converts outputs to HTTP URLs
pub fn outputs_to_urls(&self, outputs: &HashMap<String, serde_json::Value>) -> Vec<String> {
let mut urls = Vec::new();
for output in outputs.values() {
if let Some(output_obj) = output.as_object() {
if let Some(images) = output_obj.get("images").and_then(|v| v.as_array()) {
for image in images {
if let Some(image_obj) = image.as_object() {
if let (Some(filename), Some(subfolder), Some(image_type)) = (
image_obj.get("filename").and_then(|v| v.as_str()),
image_obj.get("subfolder").and_then(|v| v.as_str()),
image_obj.get("type").and_then(|v| v.as_str()),
) {
let base_url_str = self.base_url.as_str().trim_end_matches('/');
let url = format!(
"{base_url_str}/view?filename={filename}&subfolder={subfolder}&type={image_type}"
);
urls.push(url);
}
}
}
}
}
}
urls
}
/// Interrupts the current execution
pub async fn interrupt(&self) -> Result<()> {
let url = self.build_url("/interrupt")?;
let response = self.client.post(url).send().await?;
if !response.status().is_success() {
return Err(ComfyUIError::Http(response.error_for_status().unwrap_err()));
}
Ok(())
}
/// Frees memory
pub async fn free_memory(&self) -> Result<()> {
let url = self.build_url("/free")?;
let response = self.client.post(url).send().await?;
if !response.status().is_success() {
return Err(ComfyUIError::Http(response.error_for_status().unwrap_err()));
}
Ok(())
}
}