feat: 实现ComfyUI工作流批量执行文件夹处理功能

- 新增后端文件夹遍历API (get_directory_files, validate_directory_access, get_directory_info)
- 修改前端文件选择逻辑,批量模式下支持文件夹选择
- 优化批量模式UI显示,显示文件夹选择提示和文件数量统计
- 集成文件夹处理与组合生成逻辑,支持文件数组参与排列组合
- 添加错误处理和用户体验优化:权限检查、空文件夹处理、加载状态显示
- 编写单元测试用例验证功能正确性

功能特性:
 支持递归遍历文件夹及子文件夹
 按文件扩展名过滤文件类型
 批量上传文件到云端
 完整的错误处理和用户反馈
 与现有批量执行逻辑无缝集成
This commit is contained in:
imeepos 2025-08-20 17:51:32 +08:00
parent cc5812fce8
commit f32b945742
4 changed files with 523 additions and 33 deletions

View File

@ -66,6 +66,10 @@ pub fn run() {
commands::system_commands::cleanup_performance_data,
commands::system_commands::record_performance_metric,
commands::system_commands::cleanup_invalid_projects,
// 文件系统命令
commands::file_system_commands::get_directory_files,
commands::file_system_commands::validate_directory_access,
commands::file_system_commands::get_directory_info,
commands::database_commands::initialize_database,
commands::database_commands::check_database_connection,
commands::database_commands::force_release_database_connection,

View File

@ -0,0 +1,322 @@
use std::path::{Path, PathBuf};
use std::fs;
use tauri::command;
use anyhow::{Result, anyhow};
use tracing::{info, warn, error};
/// 文件夹遍历结果
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct DirectoryFilesResult {
pub files: Vec<String>,
pub total_count: usize,
pub directory_path: String,
pub filtered_extensions: Vec<String>,
}
/// 递归获取指定文件夹及子文件夹中的所有文件
///
/// # 参数
/// - `directory_path`: 文件夹路径
/// - `extensions`: 可选的文件扩展名过滤器(如 ["jpg", "png", "gif"]
/// - `recursive`: 是否递归遍历子文件夹,默认为 true
/// - `max_files`: 最大文件数量限制,默认为 1000
///
/// # 返回
/// 返回文件路径列表和统计信息
#[command]
pub async fn get_directory_files(
directory_path: String,
extensions: Option<Vec<String>>,
recursive: Option<bool>,
max_files: Option<usize>,
) -> Result<DirectoryFilesResult, String> {
info!("开始遍历文件夹: {}", directory_path);
let path = Path::new(&directory_path);
// 验证路径是否存在且为目录
if !path.exists() {
let error_msg = format!("路径不存在: {}", directory_path);
error!("{}", error_msg);
return Err(error_msg);
}
if !path.is_dir() {
let error_msg = format!("路径不是文件夹: {}", directory_path);
error!("{}", error_msg);
return Err(error_msg);
}
// 检查读取权限
match fs::read_dir(path) {
Ok(_) => {},
Err(e) => {
let error_msg = format!("无法访问文件夹 {}: {}", directory_path, e);
error!("{}", error_msg);
return Err(error_msg);
}
}
let recursive = recursive.unwrap_or(true);
let max_files = max_files.unwrap_or(1000);
let extensions = extensions.unwrap_or_default();
// 标准化扩展名(转为小写,去掉点号)
let normalized_extensions: Vec<String> = extensions
.iter()
.map(|ext| ext.trim_start_matches('.').to_lowercase())
.collect();
info!("遍历参数 - 递归: {}, 最大文件数: {}, 扩展名过滤: {:?}",
recursive, max_files, normalized_extensions);
let mut files = Vec::new();
match collect_files(path, &normalized_extensions, recursive, max_files, &mut files) {
Ok(_) => {
let total_count = files.len();
info!("文件夹遍历完成,找到 {} 个文件", total_count);
Ok(DirectoryFilesResult {
files,
total_count,
directory_path: directory_path.clone(),
filtered_extensions: normalized_extensions,
})
},
Err(e) => {
let error_msg = format!("遍历文件夹时出错: {}", e);
error!("{}", error_msg);
Err(error_msg)
}
}
}
/// 递归收集文件
fn collect_files(
dir: &Path,
extensions: &[String],
recursive: bool,
max_files: usize,
files: &mut Vec<String>,
) -> Result<()> {
if files.len() >= max_files {
warn!("已达到最大文件数量限制: {}", max_files);
return Ok(());
}
let entries = fs::read_dir(dir)
.map_err(|e| anyhow!("无法读取目录 {}: {}", dir.display(), e))?;
for entry in entries {
if files.len() >= max_files {
break;
}
let entry = entry.map_err(|e| anyhow!("读取目录项时出错: {}", e))?;
let path = entry.path();
if path.is_file() {
// 检查文件扩展名
if should_include_file(&path, extensions) {
if let Some(path_str) = path.to_str() {
files.push(path_str.to_string());
} else {
warn!("无法转换路径为字符串: {:?}", path);
}
}
} else if path.is_dir() && recursive {
// 递归遍历子目录
if let Err(e) = collect_files(&path, extensions, recursive, max_files, files) {
warn!("遍历子目录 {} 时出错: {}", path.display(), e);
// 继续遍历其他目录,不中断整个过程
}
}
}
Ok(())
}
/// 检查文件是否应该被包含
fn should_include_file(path: &Path, extensions: &[String]) -> bool {
// 如果没有指定扩展名过滤器,包含所有文件
if extensions.is_empty() {
return true;
}
// 获取文件扩展名
if let Some(ext) = path.extension() {
if let Some(ext_str) = ext.to_str() {
let ext_lower = ext_str.to_lowercase();
return extensions.contains(&ext_lower);
}
}
false
}
/// 验证文件夹路径是否有效且可访问
#[command]
pub async fn validate_directory_access(directory_path: String) -> Result<bool, String> {
let path = Path::new(&directory_path);
if !path.exists() {
return Ok(false);
}
if !path.is_dir() {
return Ok(false);
}
// 检查读取权限
match fs::read_dir(path) {
Ok(_) => Ok(true),
Err(_) => Ok(false),
}
}
/// 获取文件夹基本信息
#[command]
pub async fn get_directory_info(directory_path: String) -> Result<serde_json::Value, String> {
let path = Path::new(&directory_path);
if !path.exists() {
return Err("路径不存在".to_string());
}
if !path.is_dir() {
return Err("路径不是文件夹".to_string());
}
let mut file_count = 0;
let mut dir_count = 0;
match fs::read_dir(path) {
Ok(entries) => {
for entry in entries {
if let Ok(entry) = entry {
let path = entry.path();
if path.is_file() {
file_count += 1;
} else if path.is_dir() {
dir_count += 1;
}
}
}
},
Err(e) => {
return Err(format!("无法读取文件夹: {}", e));
}
}
let info = serde_json::json!({
"path": directory_path,
"name": path.file_name().and_then(|n| n.to_str()).unwrap_or(""),
"file_count": file_count,
"directory_count": dir_count,
"total_items": file_count + dir_count
});
Ok(info)
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
use std::fs::File;
use std::io::Write;
#[tokio::test]
async fn test_get_directory_files() {
// 创建临时目录结构
let temp_dir = TempDir::new().unwrap();
let temp_path = temp_dir.path();
// 创建测试文件
let mut file1 = File::create(temp_path.join("test1.jpg")).unwrap();
file1.write_all(b"test content").unwrap();
let mut file2 = File::create(temp_path.join("test2.png")).unwrap();
file2.write_all(b"test content").unwrap();
let mut file3 = File::create(temp_path.join("test3.txt")).unwrap();
file3.write_all(b"test content").unwrap();
// 创建子目录和文件
let sub_dir = temp_path.join("subdir");
fs::create_dir(&sub_dir).unwrap();
let mut file4 = File::create(sub_dir.join("test4.jpg")).unwrap();
file4.write_all(b"test content").unwrap();
// 测试获取所有文件
let result = get_directory_files(
temp_path.to_str().unwrap().to_string(),
None,
Some(true),
None,
).await.unwrap();
assert_eq!(result.total_count, 4);
assert!(result.files.iter().any(|f| f.contains("test1.jpg")));
assert!(result.files.iter().any(|f| f.contains("test2.png")));
assert!(result.files.iter().any(|f| f.contains("test3.txt")));
assert!(result.files.iter().any(|f| f.contains("test4.jpg")));
// 测试扩展名过滤
let result = get_directory_files(
temp_path.to_str().unwrap().to_string(),
Some(vec!["jpg".to_string(), "png".to_string()]),
Some(true),
None,
).await.unwrap();
assert_eq!(result.total_count, 3);
assert!(result.files.iter().any(|f| f.contains("test1.jpg")));
assert!(result.files.iter().any(|f| f.contains("test2.png")));
assert!(result.files.iter().any(|f| f.contains("test4.jpg")));
assert!(!result.files.iter().any(|f| f.contains("test3.txt")));
// 测试非递归
let result = get_directory_files(
temp_path.to_str().unwrap().to_string(),
None,
Some(false),
None,
).await.unwrap();
assert_eq!(result.total_count, 3);
assert!(!result.files.iter().any(|f| f.contains("test4.jpg")));
}
#[tokio::test]
async fn test_validate_directory_access() {
let temp_dir = TempDir::new().unwrap();
let temp_path = temp_dir.path().to_str().unwrap();
// 测试有效路径
let result = validate_directory_access(temp_path.to_string()).await.unwrap();
assert!(result);
// 测试无效路径
let result = validate_directory_access("/non/existent/path".to_string()).await.unwrap();
assert!(!result);
}
#[tokio::test]
async fn test_get_directory_info() {
let temp_dir = TempDir::new().unwrap();
let temp_path = temp_dir.path();
// 创建测试文件和目录
File::create(temp_path.join("test.txt")).unwrap();
fs::create_dir(temp_path.join("subdir")).unwrap();
let result = get_directory_info(temp_path.to_str().unwrap().to_string()).await.unwrap();
assert_eq!(result["file_count"], 1);
assert_eq!(result["directory_count"], 1);
assert_eq!(result["total_items"], 2);
}
}

View File

@ -40,6 +40,7 @@ pub mod system_voice_commands;
pub mod outfit_image_commands;
pub mod outfit_photo_generation_commands;
pub mod workflow_management_commands;
pub mod file_system_commands;
pub mod error_handling_commands;
pub mod volcano_video_commands;
pub mod bowong_text_video_agent_commands;

View File

@ -2,6 +2,7 @@ import React, { useState, useMemo } from 'react';
import { useForm, Controller } from 'react-hook-form';
import { ajvResolver } from '@hookform/resolvers/ajv';
import { open } from '@tauri-apps/plugin-dialog';
import { invoke } from '@tauri-apps/api/core';
import {
XMarkIcon,
InformationCircleIcon,
@ -141,7 +142,14 @@ export const ReactHookFormWorkflow: React.FC<ReactHookFormWorkflowProps> = ({
const defaults: any = {};
formFields.forEach(field => {
if (mode === 'batch') {
defaults[field.name] = field.enum ? [field.enum[0]] : [field.default || ''];
if (field.type === 'image' || field.format === 'binary') {
// 文件字段在批量模式下默认为空数组,等待用户选择文件夹
defaults[field.name] = [];
} else if (field.enum) {
defaults[field.name] = [field.enum[0]];
} else {
defaults[field.name] = [field.default || ''];
}
} else {
defaults[field.name] = field.default || (field.type === 'number' ? 0 : '');
}
@ -219,7 +227,17 @@ export const ReactHookFormWorkflow: React.FC<ReactHookFormWorkflowProps> = ({
// 生成参数组合
const generateCombinations = (data: any): any[] => {
const keys = Object.keys(data);
const values = keys.map(key => Array.isArray(data[key]) ? data[key] : [data[key]]);
const values = keys.map(key => {
if (Array.isArray(data[key])) {
// 如果是数组且不为空,使用数组值
return data[key].length > 0 ? data[key] : [''];
} else {
// 如果不是数组,转换为单元素数组
return [data[key]];
}
});
console.log('🔄 批量组合生成 - 键值对:', keys.map((key, i) => ({ key, values: values[i] })));
const combinations: any[] = [];
const generate = (current: any[], index: number) => {
@ -238,6 +256,10 @@ export const ReactHookFormWorkflow: React.FC<ReactHookFormWorkflowProps> = ({
};
generate([], 0);
console.log('🎯 生成的组合数量:', combinations.length);
console.log('📋 组合预览:', combinations.slice(0, 3));
return combinations;
};
@ -286,32 +308,138 @@ export const ReactHookFormWorkflow: React.FC<ReactHookFormWorkflowProps> = ({
setUploading(true);
setUploadProgress(0);
// 1. 选择文件
const acceptedExtensions = field.acceptedTypes || ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'webp'];
const selected = await open({
multiple: false,
filters: [
{
name: field.contentMediaType === 'image/*' ? '图像文件' : '文件',
extensions: acceptedExtensions,
},
],
});
if (selected && typeof selected === 'string') {
// 2. 上传到云端
const result = await fileUploadService.uploadFileToCloud(
selected,
undefined,
(progress: number) => setUploadProgress(progress)
);
if (mode === 'batch') {
// 批量模式:选择文件夹
const selected = await open({
directory: true,
multiple: false,
});
if (result.status === 'success' && result.url) {
setUploadedUrl(result.url);
onChange(result.url); // 使用云端URL而不是本地路径
} else {
console.error('上传失败:', result.error);
alert(`上传失败: ${result.error}`);
if (selected && typeof selected === 'string') {
console.log('选择的文件夹:', selected);
try {
// 首先验证文件夹访问权限
const isAccessible = await invoke('validate_directory_access', {
directoryPath: selected
}) as boolean;
if (!isAccessible) {
alert('无法访问选择的文件夹,请检查文件夹权限或选择其他文件夹');
return;
}
// 获取文件夹基本信息
const dirInfo = await invoke('get_directory_info', {
directoryPath: selected
}) as { file_count: number, directory_count: number, total_items: number };
console.log('文件夹信息:', dirInfo);
if (dirInfo.file_count === 0) {
alert('选择的文件夹中没有任何文件,请选择包含文件的文件夹');
return;
}
// 调用后端API获取文件夹中的所有文件
setUploadProgress(10); // 开始遍历
const result = await invoke('get_directory_files', {
directoryPath: selected,
extensions: acceptedExtensions,
recursive: true,
maxFiles: 1000
}) as { files: string[], total_count: number, directory_path: string };
console.log('文件夹遍历结果:', result);
setUploadProgress(20); // 遍历完成
if (!result.files || result.files.length === 0) {
alert(`文件夹中没有找到匹配的 ${acceptedExtensions.map(ext => ext.toUpperCase()).join('、')} 格式文件`);
return;
}
if (result.files.length > 100) {
const confirmed = confirm(`文件夹中找到 ${result.files.length} 个匹配文件,上传可能需要较长时间。是否继续?`);
if (!confirmed) {
return;
}
}
// 批量上传所有文件到云端
const uploadedFiles: string[] = [];
const failedFiles: string[] = [];
const totalFiles = result.files.length;
for (let i = 0; i < result.files.length; i++) {
const filePath = result.files[i];
const progress = 20 + Math.round((i / totalFiles) * 70); // 20-90%
setUploadProgress(progress);
try {
const uploadResult = await fileUploadService.uploadFileToCloud(
filePath,
undefined,
() => {} // 不显示单个文件的进度
);
if (uploadResult.status === 'success' && uploadResult.url) {
uploadedFiles.push(uploadResult.url);
} else {
console.warn(`文件上传失败: ${filePath}`, uploadResult.error);
failedFiles.push(filePath);
}
} catch (error) {
console.error(`文件上传异常: ${filePath}`, error);
failedFiles.push(filePath);
}
}
setUploadProgress(95); // 上传完成
if (uploadedFiles.length > 0) {
setUploadedUrl(`成功上传 ${uploadedFiles.length} 个文件${failedFiles.length > 0 ? `${failedFiles.length} 个文件失败` : ''}`);
onChange(uploadedFiles); // 批量模式返回文件数组
if (failedFiles.length > 0) {
console.warn('上传失败的文件:', failedFiles);
}
} else {
alert('所有文件上传都失败了,请检查网络连接或文件格式');
}
} catch (error) {
console.error('文件夹处理出错:', error);
alert(`文件夹处理失败: ${error}`);
}
}
} else {
// 单文件模式:选择单个文件
const selected = await open({
multiple: false,
filters: [
{
name: field.contentMediaType === 'image/*' ? '图像文件' : '文件',
extensions: acceptedExtensions,
},
],
});
if (selected && typeof selected === 'string') {
// 上传到云端
const result = await fileUploadService.uploadFileToCloud(
selected,
undefined,
(progress: number) => setUploadProgress(progress)
);
if (result.status === 'success' && result.url) {
setUploadedUrl(result.url);
onChange(result.url); // 使用云端URL而不是本地路径
} else {
console.error('上传失败:', result.error);
alert(`上传失败: ${result.error}`);
}
}
}
} catch (error) {
@ -329,7 +457,9 @@ export const ReactHookFormWorkflow: React.FC<ReactHookFormWorkflowProps> = ({
<div className="space-y-4">
<CloudArrowUpIcon className="mx-auto h-16 w-16 text-blue-500 animate-pulse" />
<div>
<p className="text-base font-medium text-blue-600 mb-2">...</p>
<p className="text-base font-medium text-blue-600 mb-2">
{mode === 'batch' ? '批量处理中...' : '上传中...'}
</p>
<div className="w-full bg-gray-200 rounded-full h-2">
<div
className="bg-blue-600 h-2 rounded-full transition-all duration-300"
@ -337,6 +467,13 @@ export const ReactHookFormWorkflow: React.FC<ReactHookFormWorkflowProps> = ({
></div>
</div>
<p className="text-sm text-gray-500 mt-1">{uploadProgress}%</p>
{mode === 'batch' && (
<div className="mt-2 text-xs text-gray-400">
{uploadProgress < 20 && '正在遍历文件夹...'}
{uploadProgress >= 20 && uploadProgress < 90 && '正在批量上传文件...'}
{uploadProgress >= 90 && '即将完成...'}
</div>
)}
</div>
</div>
) : (
@ -349,14 +486,24 @@ export const ReactHookFormWorkflow: React.FC<ReactHookFormWorkflowProps> = ({
disabled={uploading}
className="text-base font-medium text-blue-600 hover:text-blue-700 transition-colors bg-blue-50 hover:bg-blue-100 px-4 py-2 rounded-lg disabled:opacity-50"
>
{mode === 'batch' ? '选择文件夹' : '选择本地文件'}
</button>
<p className="text-sm text-gray-500 mt-2">
{field.acceptedTypes?.map(type => type.toUpperCase()).join('、') || 'JPG、PNG、GIF'}
{mode === 'batch'
? `选择包含 ${field.acceptedTypes?.map(type => type.toUpperCase()).join('、') || 'JPG、PNG、GIF'} 格式文件的文件夹,将自动遍历并上传所有匹配文件`
: `支持 ${field.acceptedTypes?.map(type => type.toUpperCase()).join('、') || 'JPG、PNG、GIF'} 格式,将自动上传到云端`
}
</p>
{field.description && (
<p className="text-xs text-gray-400 mt-1">{field.description}</p>
)}
{mode === 'batch' && (
<div className="mt-2 p-2 bg-purple-50 border border-purple-200 rounded-lg">
<p className="text-xs text-purple-700">
💡
</p>
</div>
)}
</div>
</>
)}
@ -367,11 +514,27 @@ export const ReactHookFormWorkflow: React.FC<ReactHookFormWorkflowProps> = ({
<CheckCircleIcon className="h-5 w-5 text-green-600 mr-2" />
<div className="text-left">
<p className="text-sm text-green-700 font-medium">
{mode === 'batch' ? '批量上传完成' : '已上传到云端'}
</p>
<p className="text-xs text-green-600 mt-1 break-all">
{uploadedUrl}
{mode === 'batch' && Array.isArray(value)
? `成功上传 ${value.length} 个文件到云端`
: uploadedUrl
}
</p>
{mode === 'batch' && Array.isArray(value) && value.length > 0 && (
<div className="mt-2 max-h-20 overflow-y-auto">
<p className="text-xs text-green-500 font-medium mb-1">:</p>
{value.slice(0, 3).map((url: string, index: number) => (
<p key={index} className="text-xs text-green-500 truncate">
{index + 1}. {url.split('/').pop() || url}
</p>
))}
{value.length > 3 && (
<p className="text-xs text-green-500">... {value.length - 3} </p>
)}
</div>
)}
</div>
</div>
</div>