mixvideo-v2/apps/desktop/src-tauri/src/data/models/comfyui.rs

754 lines
23 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! ComfyUI 数据模型
//! 基于 comfyui-sdk 重新设计的统一数据模型
use serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use uuid::Uuid;
// 重新导出 SDK 类型
pub use comfyui_sdk::types::{
ComfyUIWorkflow, ParameterSchema, ParameterType, ParameterValues,
ValidationResult, ValidationError as SDKValidationError, TemplateMetadata, WorkflowTemplateData,
QueueStatus, SystemStats, ObjectInfo
};
/// Workflow data model.
/// Holds a complete workflow record: its metadata plus the ComfyUI workflow definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowModel {
/// Unique workflow identifier (a UUID v4 string when created via `WorkflowModel::new`).
pub id: String,
/// Workflow name.
pub name: String,
/// Optional human-readable description.
pub description: Option<String>,
/// Workflow JSON data (the SDK's `ComfyUIWorkflow` definition).
pub workflow_data: ComfyUIWorkflow,
/// Workflow version string (`WorkflowModel::new` starts at "1.0").
pub version: String,
/// Creation time (UTC).
pub created_at: DateTime<Utc>,
/// Last update time (UTC).
pub updated_at: DateTime<Utc>,
/// Whether the workflow is enabled.
pub enabled: bool,
/// Free-form tags.
pub tags: Vec<String>,
/// Optional category.
pub category: Option<String>,
}
impl WorkflowModel {
    /// Builds a new, enabled workflow record with a freshly generated UUID v4 id.
    ///
    /// The version starts at `"1.0"`, description and category are unset, tags are empty,
    /// and `created_at` / `updated_at` are both set to the same current UTC instant.
    pub fn new(name: String, workflow_data: ComfyUIWorkflow) -> Self {
        let timestamp = Utc::now();
        let id = Uuid::new_v4().to_string();
        Self {
            id,
            name,
            workflow_data,
            description: None,
            category: None,
            tags: Vec::new(),
            version: String::from("1.0"),
            enabled: true,
            created_at: timestamp,
            updated_at: timestamp,
        }
    }

    /// Replaces the workflow definition and bumps `updated_at`.
    pub fn update_workflow(&mut self, workflow_data: ComfyUIWorkflow) {
        self.updated_at = Utc::now();
        self.workflow_data = workflow_data;
    }

    /// Applies a partial metadata update and bumps `updated_at`.
    ///
    /// Each argument that is `Some` overwrites the corresponding field; `None` leaves the
    /// field untouched. In particular a `None` description cannot be used to clear an
    /// existing description.
    pub fn update_metadata(&mut self, name: Option<String>, description: Option<String>, tags: Option<Vec<String>>) {
        if let Some(new_name) = name {
            self.name = new_name;
        }
        if description.is_some() {
            self.description = description;
        }
        if let Some(new_tags) = tags {
            self.tags = new_tags;
        }
        self.updated_at = Utc::now();
    }
}
/// Template data model.
/// Holds a workflow template: its parameter definitions and the template data itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateModel {
/// Unique template identifier (a UUID v4 string when created via `TemplateModel::new`).
pub id: String,
/// Template name.
pub name: String,
/// Optional template category.
pub category: Option<String>,
/// Optional template description.
pub description: Option<String>,
/// Template data (the SDK's `WorkflowTemplateData`).
pub template_data: WorkflowTemplateData,
/// Parameter definitions keyed by parameter name; used by `validate_parameters`.
pub parameter_schema: HashMap<String, ParameterSchema>,
/// Creation time (UTC).
pub created_at: DateTime<Utc>,
/// Last update time (UTC).
pub updated_at: DateTime<Utc>,
/// Whether the template is enabled.
pub enabled: bool,
/// Free-form tags.
pub tags: Vec<String>,
/// Optional author.
pub author: Option<String>,
/// Template version string (`TemplateModel::new` starts at "1.0").
pub version: String,
}
impl TemplateModel {
/// 创建新的模板模型
pub fn new(name: String, template_data: WorkflowTemplateData, parameter_schema: HashMap<String, ParameterSchema>) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
name,
category: None,
description: None,
template_data,
parameter_schema,
created_at: now,
updated_at: now,
enabled: true,
tags: Vec::new(),
author: None,
version: "1.0".to_string(),
}
}
/// 验证参数
pub fn validate_parameters(&self, params: &ParameterValues) -> ValidationResult {
let mut result = ValidationResult::success();
// 检查必需参数
for (param_name, schema) in &self.parameter_schema {
if schema.required.unwrap_or(false) && !params.contains_key(param_name) {
result.add_error(ValidationError::new(
param_name,
format!("Required parameter '{}' is missing", param_name)
));
}
}
// 检查参数类型和值
for (param_name, value) in params {
if let Some(schema) = self.parameter_schema.get(param_name) {
if let Err(error) = self.validate_parameter_value(param_name, value, schema) {
result.add_error(error);
}
}
}
result
}
/// 验证单个参数值
fn validate_parameter_value(&self, param_name: &str, value: &serde_json::Value, schema: &ParameterSchema) -> Result<(), ValidationError> {
match (&schema.param_type, value) {
(ParameterType::String, serde_json::Value::String(_)) => Ok(()),
(ParameterType::Number, serde_json::Value::Number(_)) => Ok(()),
(ParameterType::Boolean, serde_json::Value::Bool(_)) => Ok(()),
(ParameterType::Array, serde_json::Value::Array(_)) => Ok(()),
(ParameterType::Object, serde_json::Value::Object(_)) => Ok(()),
_ => Err(ValidationError::with_value(
param_name,
format!("Parameter '{}' has invalid type, expected {:?}", param_name, schema.param_type),
value.clone()
)),
}
}
/// 更新模板数据
pub fn update_template(&mut self, template_data: WorkflowTemplateData, parameter_schema: HashMap<String, ParameterSchema>) {
self.template_data = template_data;
self.parameter_schema = parameter_schema;
self.updated_at = Utc::now();
}
}
/// Lifecycle state of a workflow/template execution.
///
/// Serialized in lowercase (`"pending"`, `"running"`, `"completed"`, `"failed"`,
/// `"cancelled"`). The enum is fieldless, so it is `Copy` and can be passed by value
/// and inspected after assignment without moves.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ExecutionStatus {
    /// Waiting to run.
    Pending,
    /// Currently running.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished with an error.
    Failed,
    /// Cancelled before completion.
    Cancelled,
}
impl std::fmt::Display for ExecutionStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExecutionStatus::Pending => write!(f, "pending"),
ExecutionStatus::Running => write!(f, "running"),
ExecutionStatus::Completed => write!(f, "completed"),
ExecutionStatus::Failed => write!(f, "failed"),
ExecutionStatus::Cancelled => write!(f, "cancelled"),
}
}
}
/// Execution record model.
/// Stores the details and results of a single workflow/template execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionModel {
/// Unique execution identifier (a UUID v4 string when created via `ExecutionModel::new`).
pub id: String,
/// Associated workflow ID (optional).
pub workflow_id: Option<String>,
/// Associated template ID (optional).
pub template_id: Option<String>,
/// ComfyUI prompt ID.
pub prompt_id: String,
/// Current execution status.
pub status: ExecutionStatus,
/// Parameters the execution was started with, if any.
pub parameters: Option<ParameterValues>,
/// Execution results, if any.
pub results: Option<HashMap<String, serde_json::Value>>,
/// Output file URLs.
pub output_urls: Vec<String>,
/// Error message, if the execution failed.
pub error_message: Option<String>,
/// Execution duration in milliseconds.
pub execution_time: Option<u64>,
/// Creation time (UTC).
pub created_at: DateTime<Utc>,
/// Completion time (UTC); set when a terminal status is recorded.
pub completed_at: Option<DateTime<Utc>>,
/// Client ID.
pub client_id: Option<String>,
/// Per-node output details.
pub node_outputs: Option<HashMap<String, serde_json::Value>>,
}
impl ExecutionModel {
    /// Creates a `Pending` execution record for the given ComfyUI `prompt_id`.
    ///
    /// The record gets a freshly generated UUID v4 id and `created_at = now`; every
    /// optional field starts as `None` and `output_urls` is empty.
    pub fn new(prompt_id: String) -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
            workflow_id: None,
            template_id: None,
            prompt_id,
            status: ExecutionStatus::Pending,
            parameters: None,
            results: None,
            output_urls: Vec::new(),
            error_message: None,
            execution_time: None,
            created_at: Utc::now(),
            completed_at: None,
            client_id: None,
            node_outputs: None,
        }
    }

    /// Creates an execution record linked to `workflow_id`, with optional parameters.
    pub fn for_workflow(workflow_id: String, prompt_id: String, parameters: Option<ParameterValues>) -> Self {
        let mut execution = Self::new(prompt_id);
        execution.workflow_id = Some(workflow_id);
        execution.parameters = parameters;
        execution
    }

    /// Creates an execution record linked to `template_id`, with the given parameters.
    pub fn for_template(template_id: String, prompt_id: String, parameters: ParameterValues) -> Self {
        let mut execution = Self::new(prompt_id);
        execution.template_id = Some(template_id);
        execution.parameters = Some(parameters);
        execution
    }

    /// Sets the execution status.
    ///
    /// For a terminal status (`Completed`, `Failed`, `Cancelled`) this also sets
    /// `completed_at` to the current UTC time and `execution_time` to the milliseconds
    /// elapsed since `created_at`. Recording a terminal status again overwrites both.
    pub fn update_status(&mut self, status: ExecutionStatus) {
        // Inspect `status` before it is moved into `self`; the original code inspected it
        // after the move, which is a use-after-move for a non-`Copy` enum.
        let terminal = Self::is_terminal(&status);
        self.status = status;
        if !terminal {
            return;
        }

        let completed_at = Utc::now();
        self.completed_at = Some(completed_at);
        // Signed difference: if the wall clock moved backwards the delta is negative, which
        // is clamped to 0 instead of underflowing an unsigned subtraction.
        let elapsed_ms = completed_at
            .signed_duration_since(self.created_at)
            .num_milliseconds();
        self.execution_time = Some(u64::try_from(elapsed_ms).unwrap_or(0));
    }

    /// Stores results and output URLs, then marks the execution `Completed`.
    pub fn set_results(&mut self, results: HashMap<String, serde_json::Value>, output_urls: Vec<String>) {
        self.results = Some(results);
        self.output_urls = output_urls;
        self.update_status(ExecutionStatus::Completed);
    }

    /// Stores the error message, then marks the execution `Failed`.
    pub fn set_error(&mut self, error_message: String) {
        self.error_message = Some(error_message);
        self.update_status(ExecutionStatus::Failed);
    }

    /// Returns `true` if the execution reached a terminal status (completed, failed or cancelled).
    pub fn is_completed(&self) -> bool {
        Self::is_terminal(&self.status)
    }

    /// Returns `true` only if the execution completed successfully.
    pub fn is_successful(&self) -> bool {
        self.status == ExecutionStatus::Completed
    }

    /// Shared definition of a terminal status, used by `update_status` and `is_completed`.
    fn is_terminal(status: &ExecutionStatus) -> bool {
        matches!(
            status,
            ExecutionStatus::Completed | ExecutionStatus::Failed | ExecutionStatus::Cancelled
        )
    }
}
/// Queue status model.
/// A snapshot of the ComfyUI queue state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueueModel {
/// Snapshot ID (a UUID v4 string when built via `from_queue_status`).
pub id: String,
/// Tasks currently running.
pub running_tasks: Vec<QueueTaskInfo>,
/// Tasks waiting to run.
pub pending_tasks: Vec<QueueTaskInfo>,
/// Time the snapshot was taken (UTC).
pub snapshot_time: DateTime<Utc>,
/// Total number of tasks (running + pending when built via `from_queue_status`).
pub total_tasks: u32,
/// Number of completed tasks (`from_queue_status` leaves this at 0).
pub completed_tasks: u32,
}
/// Information about a single queued task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueueTaskInfo {
/// Prompt ID.
pub prompt_id: String,
/// Task number in the queue.
pub number: u32,
/// Task status as a string (`from_queue_status` uses "running" or "pending").
pub status: String,
/// Submission time (UTC), if known.
pub submitted_at: Option<DateTime<Utc>>,
/// Start time (UTC), if known.
pub started_at: Option<DateTime<Utc>>,
/// Estimated remaining time in seconds.
pub estimated_remaining: Option<u64>,
}
impl QueueModel {
    /// Builds a queue snapshot from the SDK `QueueStatus`.
    ///
    /// Only `prompt_id` and `number` are read from each SDK queue item, so the timing
    /// fields are approximations: running tasks get `started_at` and pending tasks get
    /// `submitted_at` equal to the snapshot time. `completed_tasks` is always 0 because it
    /// would have to be derived from execution history.
    pub fn from_queue_status(queue_status: &QueueStatus) -> Self {
        // A single timestamp for the whole snapshot so `snapshot_time` and the
        // approximated per-task times agree with each other.
        let now = Utc::now();

        let running_tasks: Vec<QueueTaskInfo> = queue_status
            .queue_running
            .iter()
            .map(|item| QueueTaskInfo {
                prompt_id: item.prompt_id.clone(),
                number: item.number,
                status: "running".to_string(),
                submitted_at: None,
                // Assumption: a running task is treated as having just started.
                started_at: Some(now),
                estimated_remaining: None,
            })
            .collect();

        let pending_tasks: Vec<QueueTaskInfo> = queue_status
            .queue_pending
            .iter()
            .map(|item| QueueTaskInfo {
                prompt_id: item.prompt_id.clone(),
                number: item.number,
                status: "pending".to_string(),
                // Assumption: a pending task is treated as having just been submitted.
                submitted_at: Some(now),
                started_at: None,
                estimated_remaining: None,
            })
            .collect();

        let total_tasks = Self::saturating_u32(running_tasks.len() + pending_tasks.len());

        Self {
            id: Uuid::new_v4().to_string(),
            running_tasks,
            pending_tasks,
            snapshot_time: now,
            total_tasks,
            // Would need to be computed from execution history.
            completed_tasks: 0,
        }
    }

    /// Summarizes this snapshot as `QueueStatistics`.
    pub fn get_statistics(&self) -> QueueStatistics {
        QueueStatistics {
            total_running: Self::saturating_u32(self.running_tasks.len()),
            total_pending: Self::saturating_u32(self.pending_tasks.len()),
            total_tasks: self.total_tasks,
            completed_tasks: self.completed_tasks,
            average_wait_time: self.calculate_average_wait_time(),
        }
    }

    /// Average time, in whole seconds, that pending tasks with a known `submitted_at`
    /// have been waiting as of now.
    ///
    /// Returns `None` when no pending task has a `submitted_at`, or when the total wait
    /// is not positive.
    fn calculate_average_wait_time(&self) -> Option<u64> {
        let now = Utc::now();
        let (total_secs, counted) = self
            .pending_tasks
            .iter()
            .filter_map(|task| task.submitted_at)
            // A submission time in the future (clock skew) counts as zero wait instead of
            // a negative one that would drag the average down.
            .map(|submitted| (now - submitted).num_seconds().max(0))
            .fold((0i64, 0i64), |(sum, n), secs| (sum.saturating_add(secs), n + 1));

        if counted == 0 || total_secs <= 0 {
            return None;
        }
        // Divide by the tasks that actually contributed a wait time, not by every pending
        // task; tasks without `submitted_at` would otherwise understate the average.
        u64::try_from(total_secs / counted).ok()
    }

    /// Converts a collection length to `u32`, saturating at `u32::MAX` instead of truncating.
    fn saturating_u32(len: usize) -> u32 {
        u32::try_from(len).unwrap_or(u32::MAX)
    }
}
/// Aggregate queue statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueueStatistics {
/// Number of running tasks.
pub total_running: u32,
/// Number of pending tasks.
pub total_pending: u32,
/// Total number of tasks.
pub total_tasks: u32,
/// Number of completed tasks.
pub completed_tasks: u32,
/// Average wait time in seconds, if it could be computed.
pub average_wait_time: Option<u64>,
}
/// Legacy ComfyUI configuration (used by the Infrastructure service).
/// Unlike `ComfyUIConfig`, every setting except `base_url` is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComfyuiConfig {
/// Service base URL.
pub base_url: String,
/// Connection timeout in seconds.
pub timeout: Option<u64>,
/// Number of retry attempts.
pub retry_attempts: Option<u32>,
/// Whether caching is enabled.
pub enable_cache: Option<bool>,
/// Maximum concurrency.
pub max_concurrency: Option<u32>,
}
impl Default for ComfyuiConfig {
    /// Defaults for a local ComfyUI instance: `http://localhost:8188`, a 300-second timeout,
    /// 3 retries, caching enabled and a concurrency limit of 4. Every optional field is `Some`.
    fn default() -> Self {
        let base_url = String::from("http://localhost:8188");
        Self {
            base_url,
            max_concurrency: Some(4),
            enable_cache: Some(true),
            retry_attempts: Some(3),
            timeout: Some(300),
        }
    }
}
/// ComfyUI service configuration.
/// Unified configuration structure; converted to the SDK client config by `to_sdk_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComfyUIConfig {
/// Service base URL.
pub base_url: String,
/// Connection timeout in seconds.
pub timeout_seconds: u64,
/// Number of retry attempts.
pub retry_attempts: u32,
/// Delay between retries in milliseconds.
pub retry_delay_ms: u64,
/// Whether WebSocket is enabled.
pub enable_websocket: bool,
/// Whether caching is enabled.
pub enable_cache: bool,
/// Maximum concurrency.
pub max_concurrency: u32,
/// Custom request headers.
pub custom_headers: Option<HashMap<String, String>>,
}
impl Default for ComfyUIConfig {
    /// Defaults for a local ComfyUI instance: `http://localhost:8188`, a 5-minute
    /// (300-second) timeout, 3 retries spaced 1000 ms apart, WebSocket and caching enabled,
    /// a concurrency limit of 4 and no custom headers.
    fn default() -> Self {
        let base_url = String::from("http://localhost:8188");
        Self {
            base_url,
            custom_headers: None,
            max_concurrency: 4,
            enable_cache: true,
            enable_websocket: true,
            retry_delay_ms: 1000,
            retry_attempts: 3,
            timeout_seconds: 300,
        }
    }
}
impl ComfyUIConfig {
    /// Converts this configuration into the SDK client configuration.
    ///
    /// `timeout_seconds` and `retry_delay_ms` become `Duration`s, `retry_attempts` is passed
    /// through and `custom_headers` is cloned into `headers`. `enable_websocket`,
    /// `enable_cache` and `max_concurrency` are not forwarded by this conversion.
    pub fn to_sdk_config(&self) -> comfyui_sdk::types::ComfyUIClientConfig {
        comfyui_sdk::types::ComfyUIClientConfig {
            base_url: self.base_url.clone(),
            timeout: Some(std::time::Duration::from_secs(self.timeout_seconds)),
            retry_attempts: Some(self.retry_attempts),
            retry_delay: Some(std::time::Duration::from_millis(self.retry_delay_ms)),
            headers: self.custom_headers.clone(),
        }
    }

    /// Validates the configuration, collecting every problem rather than stopping at the first.
    ///
    /// Rules:
    /// - `base_url` must be non-empty, parse as a URL and use the `http` or `https` scheme;
    /// - `timeout_seconds` must be greater than 0;
    /// - `retry_attempts` must not exceed 10;
    /// - `max_concurrency` must be in `1..=100`.
    pub fn validate(&self) -> ValidationResult {
        let mut result = ValidationResult::success();
        let mut reject = |field: &str, message: &str| {
            result.add_error(SDKValidationError::new(field.to_string(), message.to_string(), None));
        };

        // Base URL
        if self.base_url.is_empty() {
            reject("base_url", "Base URL cannot be empty");
        } else {
            match url::Url::parse(&self.base_url) {
                Err(_) => reject("base_url", "Invalid URL format"),
                // `Url::parse` accepts any scheme, so a scheme-less value such as
                // "localhost:8188" (parsed with scheme "localhost") would otherwise pass.
                // Require an HTTP(S) URL like the default `http://localhost:8188`.
                Ok(parsed) if !matches!(parsed.scheme(), "http" | "https") => {
                    reject("base_url", "Base URL must use the http or https scheme");
                }
                Ok(_) => {}
            }
        }

        // Timeout
        if self.timeout_seconds == 0 {
            reject("timeout_seconds", "Timeout must be greater than 0");
        }

        // Retry attempts
        if self.retry_attempts > 10 {
            reject("retry_attempts", "Retry attempts should not exceed 10");
        }

        // Concurrency
        if self.max_concurrency == 0 {
            reject("max_concurrency", "Max concurrency must be greater than 0");
        } else if self.max_concurrency > 100 {
            reject("max_concurrency", "Max concurrency should not exceed 100");
        }

        result
    }
}
/// Workflow object, used in the GET /api/workflow response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Workflow {
/// Full unique workflow name, e.g. 'my_workflow [20250101120000]'.
pub name: String,
/// Base name of the workflow.
pub base_name: Option<String>,
/// Workflow version.
pub version: Option<String>,
/// Workflow creation time.
pub created_at: Option<DateTime<Utc>>,
/// Workflow update time.
pub updated_at: Option<DateTime<Utc>>,
/// Workflow description.
pub description: Option<String>,
/// Workflow configuration data: any remaining JSON fields, captured via `#[serde(flatten)]`.
#[serde(flatten)]
pub additional_properties: HashMap<String, serde_json::Value>,
}
/// Publish-workflow request, used for POST /api/workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PublishWorkflowRequest {
/// Workflow name.
pub name: String,
/// Workflow configuration data (arbitrary JSON).
pub workflow_data: serde_json::Value,
/// Workflow description.
pub description: Option<String>,
/// Workflow version.
pub version: Option<String>,
}
/// Publish-workflow response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PublishWorkflowResponse {
/// Whether the operation succeeded.
pub success: bool,
/// Response message.
pub message: Option<String>,
/// The created workflow, if returned.
pub workflow: Option<Workflow>,
}
/// Delete-workflow response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeleteWorkflowResponse {
/// Whether the operation succeeded.
pub success: bool,
/// Response message.
pub message: Option<String>,
/// Name of the deleted workflow.
pub deleted_workflow: Option<String>,
}
/// Execute-workflow request, used for POST /api/run/.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecuteWorkflowRequest {
/// Workflow base name.
pub base_name: String,
/// Workflow version (optional).
pub version: Option<String>,
/// Request payload (arbitrary JSON).
pub request_data: serde_json::Value,
}
/// Execute-workflow response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecuteWorkflowResponse {
/// Execution task ID.
pub task_id: Option<String>,
/// Execution status (free-form string as returned by the server).
pub status: Option<String>,
/// Response message.
pub message: Option<String>,
/// Execution result data.
pub result: Option<serde_json::Value>,
/// Error message.
pub error: Option<String>,
}
/// Get-workflow-spec request, used for GET /api/spec/.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetWorkflowSpecRequest {
/// Workflow base name.
pub base_name: String,
/// Workflow version (optional).
pub version: Option<String>,
}
/// Get-workflow-spec response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetWorkflowSpecResponse {
/// Workflow specification data (arbitrary JSON).
pub spec: serde_json::Value,
/// Workflow name.
pub name: Option<String>,
/// Workflow version.
pub version: Option<String>,
/// Specification description.
pub description: Option<String>,
}
/// Per-server queue details.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerQueueDetails {
/// Number of running tasks.
pub running_count: i32,
/// Number of pending tasks.
pub pending_count: i32,
}
/// Server status information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerStatus {
/// Index of the server in the configured server list.
pub server_index: i32,
/// HTTP URL.
pub http_url: String,
/// WebSocket URL.
pub ws_url: String,
/// Input directory path.
pub input_dir: String,
/// Output directory path.
pub output_dir: String,
/// Whether the server is reachable.
pub is_reachable: bool,
/// Whether the server is idle.
pub is_free: bool,
/// Queue details.
pub queue_details: ServerQueueDetails,
}
/// File details.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileDetails {
/// File name.
pub name: String,
/// File size in KB.
pub size_kb: f64,
/// Modification time.
pub modified_at: DateTime<Utc>,
}
/// Files present on a server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerFiles {
/// Index of the server in the configured server list.
pub server_index: i32,
/// HTTP URL.
pub http_url: String,
/// Input files.
pub input_files: Vec<FileDetails>,
/// Output files.
pub output_files: Vec<FileDetails>,
}
/// API root response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiRootResponse {
/// API usage guide.
pub guide: String,
/// API version.
pub version: Option<String>,
/// Available endpoints.
pub endpoints: Option<Vec<String>>,
}
/// One HTTP validation error entry (`loc`, `msg`, `type`).
/// NOTE(review): the shape looks like a FastAPI-style 422 `detail` item — confirm against the server.
/// Not to be confused with `HTTPValidationError`, the wrapper holding a list of these.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpValidationError {
/// Error location (mixed JSON values, e.g. field names and indices).
pub loc: Vec<serde_json::Value>,
/// Error message.
pub msg: String,
/// Error type; serialized as the JSON key `type`.
#[serde(rename = "type")]
pub error_type: String,
}
/// HTTP validation error response: a list of `HttpValidationError` entries under `detail`.
/// Note: the name differs from `HttpValidationError` only in capitalization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HTTPValidationError {
/// List of error details.
pub detail: Vec<HttpValidationError>,
}
/// ComfyUI API error type.
/// Serialized with an internal `"type"` tag naming the variant (`#[serde(tag = "type")]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ComfyuiError {
/// Network connection error.
NetworkError { message: String },
/// HTTP error with the response status code.
HttpError { status: u16, message: String },
/// Validation error.
/// NOTE(review): `Vec<ValidationError>` refers to a `ValidationError` type, not the
/// `SDKValidationError` alias imported in this file — presumably defined elsewhere in this module; confirm.
ValidationError { errors: Vec<ValidationError> },
/// Internal server error.
ServerError { message: String },
/// The named workflow does not exist.
WorkflowNotFound { workflow_name: String },
/// Configuration error.
ConfigError { message: String },
/// Timeout error.
TimeoutError { message: String },
/// Unknown error.
UnknownError { message: String },
}