//! ComfyUI 数据模型 //! 基于 comfyui-sdk 重新设计的统一数据模型 use serde::{Deserialize, Serialize}; use chrono::{DateTime, Utc}; use std::collections::HashMap; use uuid::Uuid; // 重新导出 SDK 类型 pub use comfyui_sdk::types::{ ComfyUIWorkflow, ParameterSchema, ParameterType, ParameterValues, ValidationResult, ValidationError as SDKValidationError, TemplateMetadata, WorkflowTemplateData, QueueStatus, SystemStats, ObjectInfo }; /// 工作流数据模型 /// 存储完整的工作流信息,包括元数据和工作流定义 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WorkflowModel { /// 工作流唯一标识符 pub id: String, /// 工作流名称 pub name: String, /// 工作流描述 pub description: Option, /// 工作流 JSON 数据 pub workflow_data: ComfyUIWorkflow, /// 工作流版本 pub version: String, /// 创建时间 pub created_at: DateTime, /// 更新时间 pub updated_at: DateTime, /// 是否启用 pub enabled: bool, /// 标签 pub tags: Vec, /// 分类 pub category: Option, } impl WorkflowModel { /// 创建新的工作流模型 pub fn new(name: String, workflow_data: ComfyUIWorkflow) -> Self { let now = Utc::now(); Self { id: Uuid::new_v4().to_string(), name, description: None, workflow_data, version: "1.0".to_string(), created_at: now, updated_at: now, enabled: true, tags: Vec::new(), category: None, } } /// 更新工作流数据 pub fn update_workflow(&mut self, workflow_data: ComfyUIWorkflow) { self.workflow_data = workflow_data; self.updated_at = Utc::now(); } /// 更新元数据 pub fn update_metadata(&mut self, name: Option, description: Option, tags: Option>) { if let Some(name) = name { self.name = name; } if let Some(description) = description { self.description = Some(description); } if let Some(tags) = tags { self.tags = tags; } self.updated_at = Utc::now(); } } /// 模板数据模型 /// 存储工作流模板信息,包括参数定义和模板数据 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TemplateModel { /// 模板唯一标识符 pub id: String, /// 模板名称 pub name: String, /// 模板分类 pub category: Option, /// 模板描述 pub description: Option, /// 模板数据 pub template_data: WorkflowTemplateData, /// 参数定义 pub parameter_schema: HashMap, /// 创建时间 pub created_at: DateTime, /// 更新时间 pub updated_at: DateTime, /// 是否启用 pub enabled: bool, /// 标签 pub tags: Vec, /// 作者 pub author: Option, /// 版本 pub version: String, } impl TemplateModel { /// 创建新的模板模型 pub fn new(name: String, template_data: WorkflowTemplateData, parameter_schema: HashMap) -> Self { let now = Utc::now(); Self { id: Uuid::new_v4().to_string(), name, category: None, description: None, template_data, parameter_schema, created_at: now, updated_at: now, enabled: true, tags: Vec::new(), author: None, version: "1.0".to_string(), } } /// 验证参数 pub fn validate_parameters(&self, params: &ParameterValues) -> ValidationResult { let mut result = ValidationResult::success(); // 检查必需参数 for (param_name, schema) in &self.parameter_schema { if schema.required.unwrap_or(false) && !params.contains_key(param_name) { result.add_error(ValidationError::new( param_name, format!("Required parameter '{}' is missing", param_name) )); } } // 检查参数类型和值 for (param_name, value) in params { if let Some(schema) = self.parameter_schema.get(param_name) { if let Err(error) = self.validate_parameter_value(param_name, value, schema) { result.add_error(error); } } } result } /// 验证单个参数值 fn validate_parameter_value(&self, param_name: &str, value: &serde_json::Value, schema: &ParameterSchema) -> Result<(), ValidationError> { match (&schema.param_type, value) { (ParameterType::String, serde_json::Value::String(_)) => Ok(()), (ParameterType::Number, serde_json::Value::Number(_)) => Ok(()), (ParameterType::Boolean, serde_json::Value::Bool(_)) => Ok(()), (ParameterType::Array, serde_json::Value::Array(_)) => Ok(()), (ParameterType::Object, serde_json::Value::Object(_)) => Ok(()), _ => Err(ValidationError::with_value( param_name, format!("Parameter '{}' has invalid type, expected {:?}", param_name, schema.param_type), value.clone() )), } } /// 更新模板数据 pub fn update_template(&mut self, template_data: WorkflowTemplateData, parameter_schema: HashMap) { self.template_data = template_data; self.parameter_schema = parameter_schema; self.updated_at = Utc::now(); } } /// 执行状态枚举 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ExecutionStatus { /// 等待中 Pending, /// 运行中 Running, /// 已完成 Completed, /// 失败 Failed, /// 已取消 Cancelled, } impl std::fmt::Display for ExecutionStatus { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ExecutionStatus::Pending => write!(f, "pending"), ExecutionStatus::Running => write!(f, "running"), ExecutionStatus::Completed => write!(f, "completed"), ExecutionStatus::Failed => write!(f, "failed"), ExecutionStatus::Cancelled => write!(f, "cancelled"), } } } /// 执行记录模型 /// 存储工作流执行的详细信息和结果 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExecutionModel { /// 执行唯一标识符 pub id: String, /// 关联的工作流 ID(可选) pub workflow_id: Option, /// 关联的模板 ID(可选) pub template_id: Option, /// ComfyUI 提示 ID pub prompt_id: String, /// 执行状态 pub status: ExecutionStatus, /// 执行参数 pub parameters: Option, /// 执行结果 pub results: Option>, /// 输出文件 URLs pub output_urls: Vec, /// 错误信息 pub error_message: Option, /// 执行时间(毫秒) pub execution_time: Option, /// 创建时间 pub created_at: DateTime, /// 完成时间 pub completed_at: Option>, /// 客户端 ID pub client_id: Option, /// 节点输出详情 pub node_outputs: Option>, } impl ExecutionModel { /// 创建新的执行记录 pub fn new(prompt_id: String) -> Self { Self { id: Uuid::new_v4().to_string(), workflow_id: None, template_id: None, prompt_id, status: ExecutionStatus::Pending, parameters: None, results: None, output_urls: Vec::new(), error_message: None, execution_time: None, created_at: Utc::now(), completed_at: None, client_id: None, node_outputs: None, } } /// 创建工作流执行记录 pub fn for_workflow(workflow_id: String, prompt_id: String, parameters: Option) -> Self { let mut execution = Self::new(prompt_id); execution.workflow_id = Some(workflow_id); execution.parameters = parameters; execution } /// 创建模板执行记录 pub fn for_template(template_id: String, prompt_id: String, parameters: ParameterValues) -> Self { let mut execution = Self::new(prompt_id); execution.template_id = Some(template_id); execution.parameters = Some(parameters); execution } /// 更新执行状态 pub fn update_status(&mut self, status: ExecutionStatus) { self.status = status; if matches!(status, ExecutionStatus::Completed | ExecutionStatus::Failed | ExecutionStatus::Cancelled) { self.completed_at = Some(Utc::now()); if let Some(created_at) = self.created_at.timestamp_millis().try_into().ok() { if let Some(completed_at) = self.completed_at.and_then(|t| t.timestamp_millis().try_into().ok()) { self.execution_time = Some(completed_at - created_at); } } } } /// 设置执行结果 pub fn set_results(&mut self, results: HashMap, output_urls: Vec) { self.results = Some(results); self.output_urls = output_urls; self.update_status(ExecutionStatus::Completed); } /// 设置执行错误 pub fn set_error(&mut self, error_message: String) { self.error_message = Some(error_message); self.update_status(ExecutionStatus::Failed); } /// 检查是否已完成 pub fn is_completed(&self) -> bool { matches!(self.status, ExecutionStatus::Completed | ExecutionStatus::Failed | ExecutionStatus::Cancelled) } /// 检查是否成功 pub fn is_successful(&self) -> bool { self.status == ExecutionStatus::Completed } } /// 队列状态模型 /// 存储 ComfyUI 队列的状态信息 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueueModel { /// 队列快照 ID pub id: String, /// 正在运行的任务 pub running_tasks: Vec, /// 等待中的任务 pub pending_tasks: Vec, /// 快照时间 pub snapshot_time: DateTime, /// 总任务数 pub total_tasks: u32, /// 已完成任务数 pub completed_tasks: u32, } /// 队列任务信息 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueueTaskInfo { /// 提示 ID pub prompt_id: String, /// 任务编号 pub number: u32, /// 任务状态 pub status: String, /// 提交时间 pub submitted_at: Option>, /// 开始时间 pub started_at: Option>, /// 预估剩余时间(秒) pub estimated_remaining: Option, } impl QueueModel { /// 从 SDK 队列状态创建模型 pub fn from_queue_status(queue_status: &QueueStatus) -> Self { let running_tasks: Vec = queue_status.queue_running .iter() .map(|item| QueueTaskInfo { prompt_id: item.prompt_id.clone(), number: item.number, status: "running".to_string(), submitted_at: None, started_at: Some(Utc::now()), // 假设正在运行的任务刚开始 estimated_remaining: None, }) .collect(); let pending_tasks: Vec = queue_status.queue_pending .iter() .map(|item| QueueTaskInfo { prompt_id: item.prompt_id.clone(), number: item.number, status: "pending".to_string(), submitted_at: Some(Utc::now()), // 假设等待中的任务刚提交 started_at: None, estimated_remaining: None, }) .collect(); let total_tasks = (running_tasks.len() + pending_tasks.len()) as u32; Self { id: Uuid::new_v4().to_string(), running_tasks, pending_tasks, snapshot_time: Utc::now(), total_tasks, completed_tasks: 0, // 这个需要从历史记录中计算 } } /// 获取队列统计信息 pub fn get_statistics(&self) -> QueueStatistics { QueueStatistics { total_running: self.running_tasks.len() as u32, total_pending: self.pending_tasks.len() as u32, total_tasks: self.total_tasks, completed_tasks: self.completed_tasks, average_wait_time: self.calculate_average_wait_time(), } } /// 计算平均等待时间 fn calculate_average_wait_time(&self) -> Option { if self.pending_tasks.is_empty() { return None; } let now = Utc::now(); let total_wait_time: i64 = self.pending_tasks .iter() .filter_map(|task| task.submitted_at) .map(|submitted| (now - submitted).num_seconds()) .sum(); if total_wait_time > 0 { Some((total_wait_time / self.pending_tasks.len() as i64) as u64) } else { None } } } /// 队列统计信息 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueueStatistics { /// 正在运行的任务数 pub total_running: u32, /// 等待中的任务数 pub total_pending: u32, /// 总任务数 pub total_tasks: u32, /// 已完成任务数 pub completed_tasks: u32, /// 平均等待时间(秒) pub average_wait_time: Option, } /// 旧版 ComfyUI 配置(用于 Infrastructure 服务) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ComfyuiConfig { /// 服务基础 URL pub base_url: String, /// 连接超时时间(秒) pub timeout: Option, /// 重试次数 pub retry_attempts: Option, /// 是否启用缓存 pub enable_cache: Option, /// 最大并发数 pub max_concurrency: Option, } impl Default for ComfyuiConfig { fn default() -> Self { Self { base_url: "http://localhost:8188".to_string(), timeout: Some(300), retry_attempts: Some(3), enable_cache: Some(true), max_concurrency: Some(4), } } } /// ComfyUI 服务配置 /// 基于 SDK 的统一配置结构 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ComfyUIConfig { /// 服务基础 URL pub base_url: String, /// 连接超时时间(秒) pub timeout_seconds: u64, /// 重试次数 pub retry_attempts: u32, /// 重试延迟(毫秒) pub retry_delay_ms: u64, /// 是否启用 WebSocket pub enable_websocket: bool, /// 是否启用缓存 pub enable_cache: bool, /// 最大并发数 pub max_concurrency: u32, /// 自定义请求头 pub custom_headers: Option>, } impl Default for ComfyUIConfig { fn default() -> Self { Self { base_url: "http://localhost:8188".to_string(), timeout_seconds: 300, // 5分钟默认超时 retry_attempts: 3, retry_delay_ms: 1000, enable_websocket: true, enable_cache: true, max_concurrency: 4, custom_headers: None, } } } impl ComfyUIConfig { /// 转换为 SDK 客户端配置 pub fn to_sdk_config(&self) -> comfyui_sdk::types::ComfyUIClientConfig { comfyui_sdk::types::ComfyUIClientConfig { base_url: self.base_url.clone(), timeout: Some(std::time::Duration::from_secs(self.timeout_seconds)), retry_attempts: Some(self.retry_attempts), retry_delay: Some(std::time::Duration::from_millis(self.retry_delay_ms)), headers: self.custom_headers.clone(), } } /// 验证配置 pub fn validate(&self) -> ValidationResult { let mut result = ValidationResult::success(); // 验证 URL if self.base_url.is_empty() { result.add_error(SDKValidationError::new("base_url".to_string(), "Base URL cannot be empty".to_string(), None)); } else if url::Url::parse(&self.base_url).is_err() { result.add_error(SDKValidationError::new("base_url".to_string(), "Invalid URL format".to_string(), None)); } // 验证超时时间 if self.timeout_seconds == 0 { result.add_error(SDKValidationError::new("timeout_seconds".to_string(), "Timeout must be greater than 0".to_string(), None)); } // 验证重试次数 if self.retry_attempts > 10 { result.add_error(SDKValidationError::new("retry_attempts".to_string(), "Retry attempts should not exceed 10".to_string(), None)); } // 验证并发数 if self.max_concurrency == 0 { result.add_error(SDKValidationError::new("max_concurrency".to_string(), "Max concurrency must be greater than 0".to_string(), None)); } else if self.max_concurrency > 100 { result.add_error(SDKValidationError::new("max_concurrency".to_string(), "Max concurrency should not exceed 100".to_string(), None)); } result } } /// 工作流对象 - 用于 GET /api/workflow 响应 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Workflow { /// 工作流的完整唯一名称,例如 'my_workflow [20250101120000]' pub name: String, /// 工作流的基础名称 pub base_name: Option, /// 工作流版本 pub version: Option, /// 工作流创建时间 pub created_at: Option>, /// 工作流更新时间 pub updated_at: Option>, /// 工作流描述 pub description: Option, /// 工作流配置数据 #[serde(flatten)] pub additional_properties: HashMap, } /// 发布工作流请求 - 用于 POST /api/workflow #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PublishWorkflowRequest { /// 工作流名称 pub name: String, /// 工作流配置数据 pub workflow_data: serde_json::Value, /// 工作流描述 pub description: Option, /// 工作流版本 pub version: Option, } /// 发布工作流响应 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PublishWorkflowResponse { /// 操作是否成功 pub success: bool, /// 响应消息 pub message: Option, /// 创建的工作流信息 pub workflow: Option, } /// 删除工作流响应 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DeleteWorkflowResponse { /// 操作是否成功 pub success: bool, /// 响应消息 pub message: Option, /// 删除的工作流名称 pub deleted_workflow: Option, } /// 执行工作流请求 - 用于 POST /api/run/ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExecuteWorkflowRequest { /// 工作流基础名称 pub base_name: String, /// 工作流版本(可选) pub version: Option, /// 请求数据 pub request_data: serde_json::Value, } /// 执行工作流响应 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExecuteWorkflowResponse { /// 执行任务ID pub task_id: Option, /// 执行状态 pub status: Option, /// 响应消息 pub message: Option, /// 执行结果数据 pub result: Option, /// 错误信息 pub error: Option, } /// 获取工作流规范请求 - 用于 GET /api/spec/ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GetWorkflowSpecRequest { /// 工作流基础名称 pub base_name: String, /// 工作流版本(可选) pub version: Option, } /// 获取工作流规范响应 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GetWorkflowSpecResponse { /// 工作流规范数据 pub spec: serde_json::Value, /// 工作流名称 pub name: Option, /// 工作流版本 pub version: Option, /// 规范描述 pub description: Option, } /// 服务器队列详情 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServerQueueDetails { /// 正在运行的任务数量 pub running_count: i32, /// 等待中的任务数量 pub pending_count: i32, } /// 服务器状态信息 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServerStatus { /// 服务器在配置列表中的索引 pub server_index: i32, /// HTTP URL pub http_url: String, /// WebSocket URL pub ws_url: String, /// 输入目录路径 pub input_dir: String, /// 输出目录路径 pub output_dir: String, /// 服务器是否可达 pub is_reachable: bool, /// 服务器是否空闲 pub is_free: bool, /// 队列详情 pub queue_details: ServerQueueDetails, } /// 文件详情 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FileDetails { /// 文件名 pub name: String, /// 文件大小(KB) pub size_kb: f64, /// 修改时间 pub modified_at: DateTime, } /// 服务器文件信息 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServerFiles { /// 服务器在配置列表中的索引 pub server_index: i32, /// HTTP URL pub http_url: String, /// 输入文件列表 pub input_files: Vec, /// 输出文件列表 pub output_files: Vec, } /// API 根响应 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ApiRootResponse { /// API 使用指南 pub guide: String, /// API 版本 pub version: Option, /// 可用端点 pub endpoints: Option>, } /// HTTP 验证错误详情 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct HttpValidationError { /// 错误位置 pub loc: Vec, /// 错误消息 pub msg: String, /// 错误类型 #[serde(rename = "type")] pub error_type: String, } /// HTTP 验证错误响应 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct HTTPValidationError { /// 错误详情列表 pub detail: Vec, } /// ComfyUI API 错误类型 #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ComfyuiError { /// 网络连接错误 NetworkError { message: String }, /// HTTP 错误 HttpError { status: u16, message: String }, /// 验证错误 ValidationError { errors: Vec }, /// 服务器内部错误 ServerError { message: String }, /// 工作流不存在 WorkflowNotFound { workflow_name: String }, /// 配置错误 ConfigError { message: String }, /// 超时错误 TimeoutError { message: String }, /// 未知错误 UnknownError { message: String }, }