mixvideo-v2/apps/desktop/src-tauri/src/infrastructure/comfyui_service.rs

380 lines
15 KiB
Rust

use crate::data::models::comfyui::*;
use anyhow::{anyhow, Result};
use reqwest::{Client, Response};
use serde_json::Value;
use std::time::Duration;
use tracing::{debug, error, info, warn};
/// Infrastructure-layer client for the ComfyUI Workflow Service & Management API.
///
/// Wraps the HTTP endpoints described by the service's OpenAPI 3.1.0 spec;
/// workflow execution goes through the `/api/run/` endpoint.
#[derive(Debug, Clone)]
pub struct ComfyuiInfrastructureService {
    /// HTTP client used for every request (built with the configured timeout).
    client: Client,
    /// Service configuration (base URL, timeout, retry settings).
    config: ComfyuiConfig,
}
impl ComfyuiInfrastructureService {
    /// Creates a new service instance from `config`.
    ///
    /// The HTTP client uses `config.timeout` seconds as its request timeout
    /// (30 seconds when unset). The configuration is NOT validated here; call
    /// [`Self::validate_config`] for that.
    ///
    /// # Errors
    /// Returns an error if the HTTP client cannot be built.
    pub fn new(config: ComfyuiConfig) -> Result<Self> {
        let client = Self::build_client(&config)?;
        info!("ComfyUI Infrastructure service initialized with base URL: {}", config.base_url);
        Ok(Self { client, config })
    }

    /// Lists all published workflows.
    ///
    /// `GET /api/workflow`
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx responses, or a body that does not
    /// deserialize into `Vec<Workflow>`.
    pub async fn get_all_workflows(&self) -> Result<Vec<Workflow>> {
        let url = self.endpoint("/api/workflow");
        debug!("Getting all workflows from: {}", url);
        let response =
            Self::ensure_success(self.execute_get_request(&url).await?, "get workflows").await?;
        let workflows: Vec<Workflow> = response
            .json()
            .await
            .map_err(|e| anyhow!("Failed to parse workflows response: {}", e))?;
        info!("Retrieved {} workflows", workflows.len());
        Ok(workflows)
    }

    /// Publishes a new workflow or updates an existing one.
    ///
    /// `POST /api/workflow`
    ///
    /// The OpenAPI spec documents `201 Created` with an empty body as the
    /// success response, so only that exact status is treated as success and
    /// the returned `workflow` is always `None`.
    ///
    /// # Errors
    /// Fails on transport errors or any status other than 201 (the error
    /// message includes the status and response body).
    pub async fn publish_workflow(&self, request: PublishWorkflowRequest) -> Result<PublishWorkflowResponse> {
        let url = self.endpoint("/api/workflow");
        debug!("Publishing workflow '{}' to: {}", request.name, url);
        let response = self.execute_post_request(&url, &request).await?;
        if response.status().as_u16() != 201 {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(anyhow!("Failed to publish workflow: HTTP {} - {}", status, error_text));
        }
        info!("Workflow '{}' published successfully", request.name);
        Ok(PublishWorkflowResponse {
            success: true,
            message: Some("Workflow published successfully".to_string()),
            // The spec defines an empty response body, so there is nothing to return here.
            workflow: None,
        })
    }

    /// Deletes the workflow named `workflow_name`.
    ///
    /// `DELETE /api/workflow/{workflow_name}` (the name is percent-encoded).
    ///
    /// # Errors
    /// Fails on transport errors or a non-2xx status.
    pub async fn delete_workflow(&self, workflow_name: &str) -> Result<DeleteWorkflowResponse> {
        let url = format!("{}/api/workflow/{}", self.base_url(), urlencoding::encode(workflow_name));
        debug!("Deleting workflow '{}' from: {}", workflow_name, url);
        let response = self.execute_delete_request(&url).await?;
        Self::ensure_success(response, "delete workflow").await?;
        info!("Workflow '{}' deleted successfully", workflow_name);
        Ok(DeleteWorkflowResponse {
            success: true,
            message: Some("Workflow deleted successfully".to_string()),
            deleted_workflow: Some(workflow_name.to_string()),
        })
    }

    /// Executes a workflow.
    ///
    /// `POST /api/run/?base_name=...[&version=...]` with `request.request_data`
    /// as the JSON body.
    ///
    /// Outcomes:
    /// - Non-2xx status: returns `Ok` with `status = "failed"` and `error` set.
    ///   The error message is specialised when the body mentions the
    ///   `comfyui-for-waas-ui` upstream.
    /// - 2xx status: issues a GET to the response's final URL. If that yields
    ///   JSON, it is returned with `status = "completed"`. Otherwise the POST
    ///   response's JSON (or `{}` if unparsable) is returned, with `task_id` and
    ///   `status` read from it.
    ///
    /// # Errors
    /// Only transport errors on the initial POST are returned as `Err`.
    pub async fn execute_workflow(&self, request: ExecuteWorkflowRequest) -> Result<ExecuteWorkflowResponse> {
        let url = self.workflow_query_url("/api/run/", &request.base_name, request.version.as_deref());
        debug!("Executing workflow '{}' at: {}", request.base_name, url);
        let response = self.execute_post_request(&url, &request.request_data).await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            // Give a clearer message when the error comes from the server-side proxy.
            let error_message = if error_text.contains("comfyui-for-waas-ui") {
                format!("服务器端 ComfyUI 服务错误: {} (状态码: {})", error_text, status)
            } else {
                format!("HTTP {} - {}", status, error_text)
            };
            warn!("Workflow '{}' execution failed: {}", request.base_name, error_message);
            return Ok(ExecuteWorkflowResponse {
                task_id: None,
                status: Some("failed".to_string()),
                message: None,
                result: None,
                error: Some(error_message),
            });
        }

        let response_url = response.url().to_string();
        info!("Response URL: {}", response_url);
        info!("Workflow '{}' execution started", request.base_name);
        // Always try a GET on the final response URL to fetch the final result.
        // NOTE(review): reqwest follows 301/302/303 redirects for POST by switching
        // to GET, so when a redirect occurred `response` may already be that GET's
        // result; when none occurred this GETs the submission endpoint itself.
        // Confirm against the server which of these is intended.
        info!("Using GET request to fetch final result from: {}", response_url);
        if let Some(final_result) = self.fetch_final_result(&response_url).await {
            return Ok(ExecuteWorkflowResponse {
                // The GET result normally carries no task_id.
                task_id: None,
                status: Some("completed".to_string()),
                message: Some("Workflow execution completed successfully".to_string()),
                result: Some(final_result),
                error: None,
            });
        }

        // Fall back to the initial POST response when the final result is unavailable.
        let initial_result: Value = response
            .json()
            .await
            .unwrap_or_else(|_| serde_json::json!({}));
        Ok(ExecuteWorkflowResponse {
            task_id: initial_result.get("task_id").and_then(|v| v.as_str()).map(String::from),
            status: initial_result.get("status").and_then(|v| v.as_str()).map(String::from),
            message: Some("Workflow execution started".to_string()),
            result: Some(initial_result),
            error: None,
        })
    }

    /// Fetches the spec of a workflow.
    ///
    /// `GET /api/spec/?base_name=...[&version=...]`
    ///
    /// `description` is taken from the spec's top-level `"description"` string,
    /// if present.
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx responses (previously a JSON error
    /// body would have been returned as if it were a spec), or a non-JSON body.
    pub async fn get_workflow_spec(&self, request: GetWorkflowSpecRequest) -> Result<GetWorkflowSpecResponse> {
        let url = self.workflow_query_url("/api/spec/", &request.base_name, request.version.as_deref());
        debug!("Getting workflow spec for '{}' from: {}", request.base_name, url);
        let response =
            Self::ensure_success(self.execute_get_request(&url).await?, "get workflow spec").await?;
        let spec: Value = response
            .json()
            .await
            .map_err(|e| anyhow!("Failed to parse workflow spec response: {}", e))?;
        let description = spec.get("description").and_then(|v| v.as_str()).map(String::from);
        info!("Retrieved workflow spec for '{}'", request.base_name);
        Ok(GetWorkflowSpecResponse {
            spec,
            name: Some(request.base_name),
            version: request.version,
            description,
        })
    }

    /// Returns configuration and live status for every configured ComfyUI server.
    ///
    /// `GET /api/servers/status`
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx responses, or an unparsable body.
    pub async fn get_servers_status(&self) -> Result<Vec<ServerStatus>> {
        let url = self.endpoint("/api/servers/status");
        debug!("Getting servers status from: {}", url);
        let response =
            Self::ensure_success(self.execute_get_request(&url).await?, "get servers status").await?;
        let servers: Vec<ServerStatus> = response
            .json()
            .await
            .map_err(|e| anyhow!("Failed to parse servers status response: {}", e))?;
        info!("Retrieved status for {} servers", servers.len());
        Ok(servers)
    }

    /// Lists the files in the input and output folders of one ComfyUI server.
    ///
    /// `GET /api/servers/{server_index}/files`
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx responses, or an unparsable body.
    pub async fn list_server_files(&self, server_index: i32) -> Result<ServerFiles> {
        let url = format!("{}/api/servers/{}/files", self.base_url(), server_index);
        debug!("Getting server files for server {} from: {}", server_index, url);
        let response =
            Self::ensure_success(self.execute_get_request(&url).await?, "list server files").await?;
        let files: ServerFiles = response
            .json()
            .await
            .map_err(|e| anyhow!("Failed to parse server files response: {}", e))?;
        info!(
            "Retrieved file list for server {}: {} input files, {} output files",
            server_index,
            files.input_files.len(),
            files.output_files.len()
        );
        Ok(files)
    }

    /// Fetches the API root (a quick usage guide).
    ///
    /// `GET /`
    ///
    /// The body is parsed best-effort: a non-JSON body, or missing fields, fall
    /// back to default values.
    ///
    /// # Errors
    /// Fails on transport errors or a non-2xx status. Previously any HTTP reply
    /// counted as success, which made [`Self::test_connection`] report `true`
    /// for error pages.
    pub async fn get_api_root(&self) -> Result<ApiRootResponse> {
        let url = self.base_url().to_string();
        debug!("Getting API root info from: {}", url);
        let response =
            Self::ensure_success(self.execute_get_request(&url).await?, "get API root").await?;
        let result: Value = response
            .json()
            .await
            .unwrap_or_else(|_| serde_json::json!({"guide": "ComfyUI API"}));
        let root_response = ApiRootResponse {
            guide: result
                .get("guide")
                .and_then(|v| v.as_str())
                .unwrap_or("ComfyUI Workflow Service & Management API")
                .to_string(),
            version: result.get("version").and_then(|v| v.as_str()).map(String::from),
            endpoints: result
                .get("endpoints")
                .and_then(|v| v.as_array())
                .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect()),
        };
        info!("Retrieved API root information");
        Ok(root_response)
    }

    // ========== Private helpers ==========

    /// Builds the HTTP client for `config` (timeout defaults to 30 seconds).
    fn build_client(config: &ComfyuiConfig) -> Result<Client> {
        let timeout = Duration::from_secs(config.timeout.unwrap_or(30));
        Client::builder()
            .timeout(timeout)
            .build()
            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))
    }

    /// Configured base URL with any trailing slashes removed.
    fn base_url(&self) -> &str {
        self.config.base_url.trim_end_matches('/')
    }

    /// Absolute URL for `path` (which must start with `/`).
    fn endpoint(&self, path: &str) -> String {
        format!("{}{}", self.base_url(), path)
    }

    /// Absolute URL for `path` with percent-encoded `base_name`/`version` query parameters.
    fn workflow_query_url(&self, path: &str, base_name: &str, version: Option<&str>) -> String {
        let mut url = format!("{}?base_name={}", self.endpoint(path), urlencoding::encode(base_name));
        if let Some(version) = version {
            url.push_str("&version=");
            url.push_str(&urlencoding::encode(version));
        }
        url
    }

    /// Passes 2xx responses through; otherwise returns
    /// `Failed to {action}: HTTP {status} - {body}`.
    async fn ensure_success(response: Response, action: &str) -> Result<Response> {
        let status = response.status();
        if status.is_success() {
            return Ok(response);
        }
        let error_text = response.text().await.unwrap_or_default();
        Err(anyhow!("Failed to {}: HTTP {} - {}", action, status, error_text))
    }

    /// Best-effort GET of `url`, returning its JSON body on a 2xx response.
    ///
    /// Transport errors, non-2xx statuses and non-JSON bodies are logged at
    /// `warn` and reported as `None` so the caller can fall back.
    async fn fetch_final_result(&self, url: &str) -> Option<Value> {
        let final_response = match self.client.get(url).send().await {
            Ok(resp) => resp,
            Err(e) => {
                warn!("Failed to request final result: {}", e);
                return None;
            }
        };
        if !final_response.status().is_success() {
            warn!("Failed to fetch final result: HTTP {}", final_response.status());
            return None;
        }
        match final_response.json::<Value>().await {
            Ok(final_result) => {
                info!("Successfully retrieved final result via GET");
                Some(final_result)
            }
            Err(e) => {
                warn!("Failed to parse final result JSON: {}", e);
                None
            }
        }
    }

    /// Sends a GET request. Does not check the HTTP status.
    async fn execute_get_request(&self, url: &str) -> Result<Response> {
        self.client
            .get(url)
            .send()
            .await
            .map_err(|e| anyhow!("GET request failed: {}", e))
    }

    /// Sends a POST request with `body` serialized as JSON. Does not check the HTTP status.
    ///
    /// Request/response details are logged at `debug`: the body can contain
    /// arbitrary workflow input data, so it should not appear in `info` logs.
    async fn execute_post_request<T: serde::Serialize>(&self, url: &str, body: &T) -> Result<Response> {
        debug!("Infrastructure service making POST request to: {}", url);
        debug!("Request body: {:?}", serde_json::to_value(body).unwrap_or_default());
        let response = self.client.post(url).json(body).send().await.map_err(|e| {
            error!("POST request failed to {}: {}", url, e);
            anyhow!("POST request failed: {}", e)
        })?;
        debug!("Response status: {}", response.status());
        debug!("Response URL: {}", response.url());
        Ok(response)
    }

    /// Sends a DELETE request. Does not check the HTTP status.
    async fn execute_delete_request(&self, url: &str) -> Result<Response> {
        self.client
            .delete(url)
            .send()
            .await
            .map_err(|e| anyhow!("DELETE request failed: {}", e))
    }

    /// Validates `config`: non-empty `http(s)://` base URL and non-zero timeout.
    /// Only warns (does not fail) when `retry_attempts` exceeds 10.
    fn validate(config: &ComfyuiConfig) -> Result<()> {
        if config.base_url.is_empty() {
            return Err(anyhow!("Base URL cannot be empty"));
        }
        if !config.base_url.starts_with("http://") && !config.base_url.starts_with("https://") {
            return Err(anyhow!("Base URL must start with http:// or https://"));
        }
        if let Some(timeout) = config.timeout {
            if timeout == 0 {
                return Err(anyhow!("Timeout must be greater than 0"));
            }
        }
        if let Some(retry_attempts) = config.retry_attempts {
            if retry_attempts > 10 {
                warn!("Retry attempts is very high ({}), consider reducing it", retry_attempts);
            }
        }
        Ok(())
    }

    /// Validates the current configuration.
    ///
    /// # Errors
    /// Fails if the base URL is empty or does not start with `http://`/`https://`,
    /// or if the timeout is `Some(0)`.
    pub fn validate_config(&self) -> Result<()> {
        Self::validate(&self.config)
    }

    /// Returns the current service configuration.
    pub fn get_config(&self) -> &ComfyuiConfig {
        &self.config
    }

    /// Replaces the configuration and rebuilds the HTTP client to match.
    ///
    /// The update is all-or-nothing. If validation or client construction
    /// fails, both the existing config and the existing client are kept.
    ///
    /// # Errors
    /// Same conditions as [`Self::validate_config`], plus client build failures.
    pub fn update_config(&mut self, config: ComfyuiConfig) -> Result<()> {
        Self::validate(&config)?;
        let client = Self::build_client(&config)?;
        self.client = client;
        self.config = config;
        info!("ComfyUI service configuration updated");
        Ok(())
    }

    /// Checks connectivity by fetching the API root.
    ///
    /// Never returns `Err`. Any failure (transport or non-2xx status) is logged
    /// and reported as `Ok(false)`.
    pub async fn test_connection(&self) -> Result<bool> {
        match self.get_api_root().await {
            Ok(_) => {
                info!("ComfyUI service connection test successful");
                Ok(true)
            }
            Err(e) => {
                error!("ComfyUI service connection test failed: {}", e);
                Ok(false)
            }
        }
    }
}