380 lines
15 KiB
Rust
380 lines
15 KiB
Rust
use crate::data::models::comfyui::*;
|
|
use anyhow::{anyhow, Result};
|
|
use reqwest::{Client, Response};
|
|
use serde_json::Value;
|
|
use std::time::Duration;
|
|
use tracing::{debug, error, info, warn};
|
|
|
|
/// ComfyUI Infrastructure API service.
///
/// Provides a complete wrapper around the ComfyUI Workflow Service & Management API,
/// implementing its endpoints according to the OpenAPI 3.1.0 specification.
/// Workflow execution goes through the `/api/run/` endpoint.
#[derive(Debug, Clone)]
pub struct ComfyuiInfrastructureService {
    /// HTTP client used for every request; its timeout is taken from `config`
    /// when the service is constructed or its configuration is updated.
    client: Client,
    /// Service configuration (base URL, request timeout, retry attempts).
    config: ComfyuiConfig,
}
|
|
|
|
impl ComfyuiInfrastructureService {
|
|
/// 创建新的 ComfyUI Infrastructure 服务实例
|
|
pub fn new(config: ComfyuiConfig) -> Result<Self> {
|
|
let timeout = Duration::from_secs(config.timeout.unwrap_or(30));
|
|
|
|
let client = Client::builder()
|
|
.timeout(timeout)
|
|
.build()
|
|
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
|
|
|
|
info!("ComfyUI Infrastructure service initialized with base URL: {}", config.base_url);
|
|
|
|
Ok(Self { client, config })
|
|
}
|
|
|
|
/// 获取所有工作流
|
|
///
|
|
/// GET /api/workflow
|
|
/// 返回所有已发布的工作流列表
|
|
pub async fn get_all_workflows(&self) -> Result<Vec<Workflow>> {
|
|
let url = format!("{}/api/workflow", self.config.base_url.trim_end_matches('/'));
|
|
debug!("Getting all workflows from: {}", url);
|
|
|
|
let response = self.execute_get_request(&url).await?;
|
|
let workflows: Vec<Workflow> = response.json().await
|
|
.map_err(|e| anyhow!("Failed to parse workflows response: {}", e))?;
|
|
|
|
info!("Retrieved {} workflows", workflows.len());
|
|
Ok(workflows)
|
|
}
|
|
|
|
/// 发布工作流
|
|
///
|
|
/// POST /api/workflow
|
|
/// 发布新的工作流或更新现有工作流
|
|
pub async fn publish_workflow(&self, request: PublishWorkflowRequest) -> Result<PublishWorkflowResponse> {
|
|
let url = format!("{}/api/workflow", self.config.base_url.trim_end_matches('/'));
|
|
debug!("Publishing workflow '{}' to: {}", request.name, url);
|
|
|
|
let response = self.execute_post_request(&url, &request).await?;
|
|
|
|
// 根据 OpenAPI 规范,成功响应状态码是 201
|
|
if response.status().as_u16() == 201 {
|
|
// 响应体可能为空,创建成功响应
|
|
let workflow_response = PublishWorkflowResponse {
|
|
success: true,
|
|
message: Some("Workflow published successfully".to_string()),
|
|
workflow: None, // API 规范中响应体为空
|
|
};
|
|
info!("Workflow '{}' published successfully", request.name);
|
|
Ok(workflow_response)
|
|
} else {
|
|
let status = response.status();
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
Err(anyhow!("Failed to publish workflow: HTTP {} - {}", status, error_text))
|
|
}
|
|
}
|
|
|
|
/// 删除工作流
|
|
///
|
|
/// DELETE /api/workflow/{workflow_name}
|
|
/// 删除指定的工作流
|
|
pub async fn delete_workflow(&self, workflow_name: &str) -> Result<DeleteWorkflowResponse> {
|
|
let encoded_name = urlencoding::encode(workflow_name);
|
|
let url = format!("{}/api/workflow/{}", self.config.base_url.trim_end_matches('/'), encoded_name);
|
|
debug!("Deleting workflow '{}' from: {}", workflow_name, url);
|
|
|
|
let response = self.execute_delete_request(&url).await?;
|
|
|
|
if response.status().is_success() {
|
|
let delete_response = DeleteWorkflowResponse {
|
|
success: true,
|
|
message: Some("Workflow deleted successfully".to_string()),
|
|
deleted_workflow: Some(workflow_name.to_string()),
|
|
};
|
|
info!("Workflow '{}' deleted successfully", workflow_name);
|
|
Ok(delete_response)
|
|
} else {
|
|
let status = response.status();
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
Err(anyhow!("Failed to delete workflow: HTTP {} - {}", status, error_text))
|
|
}
|
|
}
|
|
|
|
/// 执行工作流
|
|
///
|
|
/// POST /api/run/
|
|
/// 执行指定的工作流
|
|
pub async fn execute_workflow(&self, request: ExecuteWorkflowRequest) -> Result<ExecuteWorkflowResponse> {
|
|
let mut url = format!("{}/api/run/", self.config.base_url.trim_end_matches('/'));
|
|
|
|
// 添加查询参数
|
|
url.push_str(&format!("?base_name={}", urlencoding::encode(&request.base_name)));
|
|
if let Some(version) = &request.version {
|
|
url.push_str(&format!("&version={}", urlencoding::encode(version)));
|
|
}
|
|
|
|
debug!("Executing workflow '{}' at: {}", request.base_name, url);
|
|
|
|
let response = self.execute_post_request(&url, &request.request_data).await?;
|
|
|
|
if response.status().is_success() {
|
|
let response_url = response.url().to_string();
|
|
info!("Response URL: {}", response_url);
|
|
|
|
info!("Workflow '{}' execution started", request.base_name);
|
|
|
|
// 总是尝试用 GET 请求获取最终结果
|
|
info!("Using GET request to fetch final result from: {}", response_url);
|
|
|
|
match self.client.get(&response_url).send().await {
|
|
Ok(final_response) => {
|
|
if final_response.status().is_success() {
|
|
match final_response.json::<Value>().await {
|
|
Ok(final_result) => {
|
|
info!("Successfully retrieved final result via GET");
|
|
let execute_response = ExecuteWorkflowResponse {
|
|
task_id: None, // GET 请求通常不返回 task_id
|
|
status: Some("completed".to_string()),
|
|
message: Some("Workflow execution completed successfully".to_string()),
|
|
result: Some(final_result),
|
|
error: None,
|
|
};
|
|
return Ok(execute_response);
|
|
}
|
|
Err(e) => {
|
|
warn!("Failed to parse final result JSON: {}", e);
|
|
}
|
|
}
|
|
} else {
|
|
warn!("Failed to fetch final result: HTTP {}", final_response.status());
|
|
}
|
|
}
|
|
Err(e) => {
|
|
warn!("Failed to request final result: {}", e);
|
|
}
|
|
}
|
|
|
|
// 如果无法获取最终结果,解析初始响应
|
|
let initial_result: Value = response.json().await
|
|
.unwrap_or_else(|_| serde_json::json!({}));
|
|
|
|
let execute_response = ExecuteWorkflowResponse {
|
|
task_id: initial_result.get("task_id").and_then(|v| v.as_str()).map(String::from),
|
|
status: initial_result.get("status").and_then(|v| v.as_str()).map(String::from),
|
|
message: Some("Workflow execution started".to_string()),
|
|
result: Some(initial_result),
|
|
error: None,
|
|
};
|
|
Ok(execute_response)
|
|
} else {
|
|
let status = response.status();
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
|
|
// 解析错误详情,如果是服务器端代理错误,提供更清晰的错误信息
|
|
let error_message = if error_text.contains("comfyui-for-waas-ui") {
|
|
format!("服务器端 ComfyUI 服务错误: {} (状态码: {})", error_text, status)
|
|
} else {
|
|
format!("HTTP {} - {}", status, error_text)
|
|
};
|
|
|
|
let execute_response = ExecuteWorkflowResponse {
|
|
task_id: None,
|
|
status: Some("failed".to_string()),
|
|
message: None,
|
|
result: None,
|
|
error: Some(error_message.clone()),
|
|
};
|
|
warn!("Workflow '{}' execution failed: {}", request.base_name, error_message);
|
|
Ok(execute_response)
|
|
}
|
|
}
|
|
|
|
/// 获取工作流规范
|
|
///
|
|
/// GET /api/spec/
|
|
/// 获取指定工作流的规范信息
|
|
pub async fn get_workflow_spec(&self, request: GetWorkflowSpecRequest) -> Result<GetWorkflowSpecResponse> {
|
|
let mut url = format!("{}/api/spec/", self.config.base_url.trim_end_matches('/'));
|
|
|
|
// 添加查询参数
|
|
url.push_str(&format!("?base_name={}", urlencoding::encode(&request.base_name)));
|
|
if let Some(version) = &request.version {
|
|
url.push_str(&format!("&version={}", urlencoding::encode(version)));
|
|
}
|
|
|
|
debug!("Getting workflow spec for '{}' from: {}", request.base_name, url);
|
|
|
|
let response = self.execute_get_request(&url).await?;
|
|
let spec: Value = response.json().await
|
|
.map_err(|e| anyhow!("Failed to parse workflow spec response: {}", e))?;
|
|
|
|
let spec_response = GetWorkflowSpecResponse {
|
|
spec: spec.clone(),
|
|
name: Some(request.base_name.clone()),
|
|
version: request.version.clone(),
|
|
description: spec.get("description").and_then(|v| v.as_str()).map(String::from),
|
|
};
|
|
|
|
info!("Retrieved workflow spec for '{}'", request.base_name);
|
|
Ok(spec_response)
|
|
}
|
|
|
|
/// 获取服务器状态
|
|
///
|
|
/// GET /api/servers/status
|
|
/// 获取所有已配置的ComfyUI服务器的配置信息和实时状态
|
|
pub async fn get_servers_status(&self) -> Result<Vec<ServerStatus>> {
|
|
let url = format!("{}/api/servers/status", self.config.base_url.trim_end_matches('/'));
|
|
debug!("Getting servers status from: {}", url);
|
|
|
|
let response = self.execute_get_request(&url).await?;
|
|
let servers: Vec<ServerStatus> = response.json().await
|
|
.map_err(|e| anyhow!("Failed to parse servers status response: {}", e))?;
|
|
|
|
info!("Retrieved status for {} servers", servers.len());
|
|
Ok(servers)
|
|
}
|
|
|
|
/// 获取服务器文件列表
|
|
///
|
|
/// GET /api/servers/{server_index}/files
|
|
/// 获取指定ComfyUI服务器的输入和输出文件夹中的文件列表
|
|
pub async fn list_server_files(&self, server_index: i32) -> Result<ServerFiles> {
|
|
let url = format!("{}/api/servers/{}/files", self.config.base_url.trim_end_matches('/'), server_index);
|
|
debug!("Getting server files for server {} from: {}", server_index, url);
|
|
|
|
let response = self.execute_get_request(&url).await?;
|
|
let files: ServerFiles = response.json().await
|
|
.map_err(|e| anyhow!("Failed to parse server files response: {}", e))?;
|
|
|
|
info!("Retrieved file list for server {}: {} input files, {} output files",
|
|
server_index, files.input_files.len(), files.output_files.len());
|
|
Ok(files)
|
|
}
|
|
|
|
/// 获取 API 根信息
|
|
///
|
|
/// GET /
|
|
/// 提供一个API的快速使用指南
|
|
pub async fn get_api_root(&self) -> Result<ApiRootResponse> {
|
|
let url = self.config.base_url.trim_end_matches('/').to_string();
|
|
debug!("Getting API root info from: {}", url);
|
|
|
|
let response = self.execute_get_request(&url).await?;
|
|
|
|
// 尝试解析为结构化响应,如果失败则返回基本信息
|
|
let result: Value = response.json().await
|
|
.unwrap_or_else(|_| serde_json::json!({"guide": "ComfyUI API"}));
|
|
|
|
let root_response = ApiRootResponse {
|
|
guide: result.get("guide").and_then(|v| v.as_str())
|
|
.unwrap_or("ComfyUI Workflow Service & Management API").to_string(),
|
|
version: result.get("version").and_then(|v| v.as_str()).map(String::from),
|
|
endpoints: result.get("endpoints").and_then(|v| v.as_array())
|
|
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect()),
|
|
};
|
|
|
|
info!("Retrieved API root information");
|
|
Ok(root_response)
|
|
}
|
|
|
|
// ========== 私有辅助方法 ==========
|
|
|
|
/// 执行 GET 请求
|
|
async fn execute_get_request(&self, url: &str) -> Result<Response> {
|
|
let response = self.client.get(url).send().await
|
|
.map_err(|e| anyhow!("GET request failed: {}", e))?;
|
|
Ok(response)
|
|
}
|
|
|
|
/// 执行 POST 请求
|
|
async fn execute_post_request<T: serde::Serialize>(&self, url: &str, body: &T) -> Result<Response> {
|
|
info!("Infrastructure service making POST request to: {}", url);
|
|
info!("Request body: {:?}", serde_json::to_value(body).unwrap_or_default());
|
|
|
|
let response = self.client.post(url).json(body).send().await
|
|
.map_err(|e| {
|
|
error!("POST request failed to {}: {}", url, e);
|
|
anyhow!("POST request failed: {}", e)
|
|
})?;
|
|
|
|
info!("Response status: {}", response.status());
|
|
info!("Response URL: {}", response.url());
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
/// 执行 DELETE 请求
|
|
async fn execute_delete_request(&self, url: &str) -> Result<Response> {
|
|
let response = self.client.delete(url).send().await
|
|
.map_err(|e| anyhow!("DELETE request failed: {}", e))?;
|
|
Ok(response)
|
|
}
|
|
|
|
|
|
|
|
/// 验证配置
|
|
pub fn validate_config(&self) -> Result<()> {
|
|
if self.config.base_url.is_empty() {
|
|
return Err(anyhow!("Base URL cannot be empty"));
|
|
}
|
|
|
|
if !self.config.base_url.starts_with("http://") && !self.config.base_url.starts_with("https://") {
|
|
return Err(anyhow!("Base URL must start with http:// or https://"));
|
|
}
|
|
|
|
if let Some(timeout) = self.config.timeout {
|
|
if timeout == 0 {
|
|
return Err(anyhow!("Timeout must be greater than 0"));
|
|
}
|
|
}
|
|
|
|
if let Some(retry_attempts) = self.config.retry_attempts {
|
|
if retry_attempts > 10 {
|
|
warn!("Retry attempts is very high ({}), consider reducing it", retry_attempts);
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// 获取服务配置
|
|
pub fn get_config(&self) -> &ComfyuiConfig {
|
|
&self.config
|
|
}
|
|
|
|
/// 更新服务配置
|
|
pub fn update_config(&mut self, config: ComfyuiConfig) -> Result<()> {
|
|
// 验证新配置
|
|
let temp_service = Self::new(config.clone())?;
|
|
temp_service.validate_config()?;
|
|
|
|
// 更新配置和客户端
|
|
self.config = config;
|
|
let timeout = Duration::from_secs(self.config.timeout.unwrap_or(30));
|
|
self.client = Client::builder()
|
|
.timeout(timeout)
|
|
.build()
|
|
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
|
|
|
|
info!("ComfyUI service configuration updated");
|
|
Ok(())
|
|
}
|
|
|
|
/// 测试连接
|
|
pub async fn test_connection(&self) -> Result<bool> {
|
|
match self.get_api_root().await {
|
|
Ok(_) => {
|
|
info!("ComfyUI service connection test successful");
|
|
Ok(true)
|
|
}
|
|
Err(e) => {
|
|
error!("ComfyUI service connection test failed: {}", e);
|
|
Ok(false)
|
|
}
|
|
}
|
|
}
|
|
}
|