ComfyUI-CustomNode/nodes/object_storage_nodes.py

824 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
对象存储节点模块
本模块提供ComfyUI自定义节点支持多种云存储服务的文件上传和下载操作。
采用统一的存储抽象层支持AWS S3和腾讯云COS等存储服务。
节点类型:
- COSDownload: 腾讯云COS文件下载
- COSUpload: 腾讯云COS文件上传
- S3Download: AWS S3文件下载
- S3Upload: AWS S3文件上传
- S3UploadURL: AWS S3文件上传并返回URL
- S3UploadIMAGEURL: AWS S3图像上传并返回URL
设计特性:
- 统一的错误处理和日志记录
- 支持重试机制
- 类型安全的参数和返回值
- 完整的文档和注释
"""
import os
import time
from pathlib import Path
from typing import Optional, Tuple, Union
import loguru
import torch
from ..utils.object_storage import DownloadResult, UploadResult, get_provider
def construct_storage_key(
subfolder: str,
filename: Optional[str] = None,
path: Optional[Union[str, Path]] = None,
) -> str:
"""
构造存储键名的通用工具函数
该函数用于生成云存储中的对象键名,支持多种输入方式:
1. 直接指定文件名
2. 从本地文件路径中提取文件名
3. 自动生成基于时间戳的文件名
Args:
subfolder (str): 存储中的子文件夹路径
filename (Optional[str]): 指定的文件名,优先级最高
path (Optional[Union[str, Path]]): 本地文件路径,用于提取文件名
Returns:
str: 完整的存储键名,格式为 "subfolder/filename"
Examples:
>>> construct_storage_key("images", filename="photo.jpg")
'images/photo.jpg'
>>> construct_storage_key("docs", path="/home/user/document.pdf")
'docs/document.pdf'
>>> construct_storage_key("temp") # 自动生成时间戳文件名
'temp/file_1640995200'
"""
# 参数验证
if not subfolder:
raise ValueError("subfolder参数不能为空")
# 优先使用指定的文件名
if filename is not None:
if not filename or not filename.strip():
raise ValueError("filename参数不能为空字符串")
return f"{subfolder.strip('/')}/{filename.strip()}"
# 从路径中提取文件名
if path:
path_obj = Path(path)
extracted_filename = path_obj.name
if not extracted_filename:
raise ValueError(f"无法从路径中提取文件名: {path}")
return f"{subfolder.strip('/')}/{extracted_filename}"
# 生成基于时间戳的默认文件名
timestamp = int(time.time())
return f"{subfolder.strip('/')}/file_{timestamp}"
class COSDownload:
"""
腾讯云COS文件下载节点
提供从腾讯云COS存储服务下载文件到本地的功能。该节点基于统一的存储抽象层实现
具有以下特性:
- 统一的错误处理和详细的日志记录
- 自动创建本地目录结构
- 完整的参数验证机制
- 支持各种文件类型的下载
节点配置:
- 输入: COS存储桶名称、文件键名
- 输出: 本地文件路径
- 类别: 腾讯云COS存储操作
"""
@classmethod
def INPUT_TYPES(cls):
"""定义节点输入类型"""
return {
"required": {
"cos_bucket": ("STRING", {"default": "bwkj-cos-1324682537"}),
"cos_key": ("STRING", {"multiline": True}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("文件存储路径",)
FUNCTION = "download"
CATEGORY = "不忘科技-自定义节点🛍/对象存储/COS"
def download(self, cos_bucket: str, cos_key: str) -> Tuple[str]:
"""
从腾讯云COS下载文件到本地
该方法执行完整的COS文件下载流程
1. 验证输入参数
2. 构造本地保存路径
3. 创建必要的目录结构
4. 执行文件下载
5. 验证下载结果
Args:
cos_bucket (str): COS存储桶名称不能为空
cos_key (str): COS中的文件键名不能为空
Returns:
Tuple[str]: 包含本地文件完整路径的单元素元组
Raises:
ValueError: 当输入参数无效时
Exception: 当下载操作失败时
Examples:
>>> node = COSDownload()
>>> result = node.download("my-bucket", "images/photo.jpg")
>>> print(result[0]) # 输出本地文件路径
"""
# 输入参数验证
if not cos_bucket or not cos_bucket.strip():
raise ValueError("COS存储桶名称不能为空")
if not cos_key or not cos_key.strip():
raise ValueError("COS文件键名不能为空")
try:
# 构造安全的本地保存路径
cos_key_normalized = cos_key.strip().replace("/", os.sep)
base_dir = Path(__file__).parent.parent
download_dir = base_dir / "download"
# 构建完整的目标路径
destination_path = download_dir / cos_key_normalized
destination = str(destination_path)
# 确保目录存在
destination_path.parent.mkdir(parents=True, exist_ok=True)
loguru.logger.info(
f"开始从COS下载文件",
extra={
"bucket": cos_bucket,
"key": cos_key,
"destination": destination,
},
)
# 获取COS存储提供者实例
provider = get_provider("cos", {"bucket_name": cos_bucket})
# 执行文件下载操作
result: DownloadResult = provider.download_file(cos_key, destination)
# 检查下载结果
if not result.success:
error_msg = f"COS下载操作失败: {result.message}"
loguru.logger.error(
error_msg,
extra={
"bucket": cos_bucket,
"key": cos_key,
"error_details": result.message,
},
)
raise Exception(error_msg)
loguru.logger.info(
f"COS文件下载成功",
extra={
"bucket": cos_bucket,
"key": cos_key,
"local_path": destination,
"file_size": (
os.path.getsize(destination)
if os.path.exists(destination)
else "未知"
),
},
)
return (destination,)
except ValueError:
# 重新抛出参数验证错误
raise
except Exception as e:
error_msg = (
f"COS下载过程中发生错误: bucket={cos_bucket}, "
f"key={cos_key}, error={str(e)}"
)
loguru.logger.error(
error_msg,
extra={
"bucket": cos_bucket,
"key": cos_key,
"exception_type": type(e).__name__,
"exception_details": str(e),
},
)
raise Exception(error_msg)
class COSUpload:
"""
腾讯云COS文件上传节点
提供上传本地文件到腾讯云COS存储服务的功能。该节点基于统一的存储抽象层实现
具有以下特性:
- 统一的错误处理和详细的日志记录
- 完整的文件存在性和权限验证
- 灵活的目标路径和命名策略
- 支持各种文件类型和大小
节点配置:
- 输入: COS存储桶名称、本地文件路径、子文件夹
- 输出: COS中的文件键名
- 类别: 腾讯云COS存储操作
"""
@classmethod
def INPUT_TYPES(cls):
"""定义节点输入类型"""
return {
"required": {
"cos_bucket": ("STRING", {"default": "bwkj-cos-1324682537"}),
"path": ("STRING", {"multiline": True}),
"subfolder": ("STRING", {"default": "test"}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("COS文件Key",)
FUNCTION = "upload"
CATEGORY = "不忘科技-自定义节点🛍/对象存储/COS"
def upload(self, cos_bucket: str, path: str, subfolder: str) -> Tuple[str]:
"""
上传本地文件到腾讯云COS存储
该方法执行完整的文件上传流程:
1. 验证输入参数和文件存在性
2. 构造目标存储键名
3. 执行文件上传操作
4. 返回上传后的文件键名
Args:
cos_bucket (str): COS存储桶名称不能为空
path (str): 本地文件的完整路径,文件必须存在
subfolder (str): 目标子文件夹名称,不能为空
Returns:
Tuple[str]: 包含COS中文件键名的单元素元组
Raises:
ValueError: 当输入参数无效或文件不存在时
Exception: 当上传操作失败时
Examples:
>>> node = COSUpload()
>>> result = node.upload("my-bucket", "/local/image.jpg", "images")
>>> print(result[0]) # 输出: "images/image.jpg"
"""
# 输入参数验证
if not cos_bucket or not cos_bucket.strip():
raise ValueError("COS存储桶名称不能为空")
if not path or not path.strip():
raise ValueError("本地文件路径不能为空")
if not subfolder or not subfolder.strip():
raise ValueError("子文件夹名称不能为空")
# 文件存在性检查
file_path = Path(path)
if not file_path.exists():
raise ValueError(f"本地文件不存在: {path}")
if not file_path.is_file():
raise ValueError(f"指定路径不是文件: {path}")
try:
# 构造目标存储键名
dest_key = construct_storage_key(subfolder, path=path)
file_size = file_path.stat().st_size
loguru.logger.info(
"开始COS文件上传",
extra={
"local_path": path,
"bucket": cos_bucket,
"target_key": dest_key,
"file_size_bytes": file_size,
},
)
# 获取COS存储提供者实例
provider = get_provider("cos", {"bucket_name": cos_bucket})
# 执行文件上传操作
result: UploadResult = provider.upload_file(path, dest_key)
# 检查上传结果
if not result.success:
error_msg = f"COS上传操作失败: {result.message}"
loguru.logger.error(
error_msg,
extra={
"local_path": path,
"bucket": cos_bucket,
"target_key": dest_key,
"error_details": result.message,
},
)
raise Exception(error_msg)
loguru.logger.info(
"COS文件上传成功",
extra={
"local_path": path,
"bucket": cos_bucket,
"uploaded_key": dest_key,
"file_size_bytes": file_size,
},
)
return (dest_key,)
except ValueError:
# 重新抛出参数验证错误
raise
except Exception as e:
error_msg = (
f"COS上传过程中发生错误: bucket={cos_bucket}, "
f"path={path}, subfolder={subfolder}, error={str(e)}"
)
loguru.logger.error(
error_msg,
extra={
"local_path": path,
"bucket": cos_bucket,
"subfolder": subfolder,
"exception_type": type(e).__name__,
"exception_details": str(e),
},
)
raise Exception(error_msg)
class S3Download:
"""
AWS S3文件下载节点
提供从Amazon S3存储服务下载文件到本地的功能。该节点基于统一的存储抽象层实现
具有以下特性:
- 统一的错误处理和详细的日志记录
- 自动创建本地目录结构
- 完整的参数验证机制
- 支持各种文件类型的下载
节点配置:
- 输入: S3存储桶名称、文件键名
- 输出: 本地文件路径
- 类别: AWS S3存储操作
"""
@classmethod
def INPUT_TYPES(cls):
"""定义节点输入类型"""
return {
"required": {
"s3_bucket": ("STRING", {"default": "bw-comfyui-input"}),
"s3_key": ("STRING", {"multiline": True}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("文件存储路径",)
FUNCTION = "download"
CATEGORY = "不忘科技-自定义节点🛍/对象存储/S3"
def download(self, s3_bucket: str, s3_key: str) -> Tuple[str]:
"""
从AWS S3存储服务下载文件到本地
执行完整的S3文件下载流程包括参数验证、路径构造、
目录创建和文件下载等操作。
Args:
s3_bucket (str): S3存储桶名称不能为空
s3_key (str): S3中的文件键名不能为空
Returns:
Tuple[str]: 包含本地文件完整路径的单元素元组
Raises:
ValueError: 当输入参数无效时
Exception: 当下载操作失败时
"""
# 输入参数验证
if not s3_bucket or not s3_bucket.strip():
raise ValueError("S3存储桶名称不能为空")
if not s3_key or not s3_key.strip():
raise ValueError("S3文件键名不能为空")
try:
# 构造安全的本地保存路径
s3_key_normalized = s3_key.strip().replace("/", os.sep)
base_dir = Path(__file__).parent.parent
download_dir = base_dir / "download"
# 构建完整的目标路径
destination_path = download_dir / s3_key_normalized
destination = str(destination_path)
# 确保目录结构存在
destination_path.parent.mkdir(parents=True, exist_ok=True)
loguru.logger.info(
"开始从S3下载文件",
extra={"bucket": s3_bucket, "key": s3_key, "destination": destination},
)
# 获取S3存储提供者实例
provider = get_provider("s3", {"bucket_name": s3_bucket})
# 执行文件下载操作
result: DownloadResult = provider.download_file(s3_key, destination)
# 验证下载结果
if not result.success:
error_msg = f"S3下载操作失败: {result.message}"
loguru.logger.error(
error_msg,
extra={
"bucket": s3_bucket,
"key": s3_key,
"error_details": result.message,
},
)
raise Exception(error_msg)
loguru.logger.info(
"S3文件下载成功",
extra={
"bucket": s3_bucket,
"key": s3_key,
"local_path": destination,
"file_size": (
os.path.getsize(destination)
if os.path.exists(destination)
else "未知"
),
},
)
return (destination,)
except ValueError:
# 重新抛出参数验证错误
raise
except Exception as e:
error_msg = (
f"S3下载过程中发生错误: bucket={s3_bucket}, "
f"key={s3_key}, error={str(e)}"
)
loguru.logger.error(
error_msg,
extra={
"bucket": s3_bucket,
"key": s3_key,
"exception_type": type(e).__name__,
"exception_details": str(e),
},
)
raise Exception(error_msg)
class S3Upload:
"""
AWS S3文件上传节点
提供上传本地文件到Amazon S3存储服务的功能。该节点基于统一的存储抽象层实现
具有以下特性:
- 统一的错误处理和详细的日志记录
- 完整的文件存在性和权限验证
- 灵活的目标路径和命名策略
- 支持各种文件类型和大小
节点配置:
- 输入: S3存储桶名称、本地文件路径、子文件夹
- 输出: S3中的文件键名
- 类别: Amazon S3存储操作
"""
@classmethod
def INPUT_TYPES(cls):
"""定义节点输入类型"""
return {
"required": {
"s3_bucket": ("STRING", {"default": "bw-comfyui-output"}),
"path": ("STRING", {"multiline": True}),
"subfolder": ("STRING", {"default": "test"}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("S3文件Key",)
FUNCTION = "upload"
CATEGORY = "不忘科技-自定义节点🛍/对象存储/S3"
def upload(self, s3_bucket: str, path: str, subfolder: str) -> Tuple[str]:
"""
上传文件到AWS S3
Args:
s3_bucket: S3存储桶名称
path: 本地文件路径
subfolder: 子文件夹名称
Returns:
Tuple[str]: 包含S3文件键名的元组
Raises:
Exception: 上传失败时抛出异常
"""
try:
# 构造存储键名
dest_key = construct_storage_key(subfolder, path=path)
loguru.logger.info(f"开始S3上传: {path} -> {s3_bucket}/{dest_key}")
# 获取S3存储提供者
provider = get_provider("s3", {"bucket_name": s3_bucket})
# 执行上传
result = provider.upload_file(path, dest_key)
if not result.success:
error_msg = f"S3上传失败: {result.message}"
loguru.logger.error(error_msg)
raise Exception(error_msg)
loguru.logger.info(f"S3上传成功: {dest_key}")
return (dest_key,)
except Exception as e:
error_msg = f"S3上传失败: bucket={s3_bucket}, path={path}, subfolder={subfolder}, error={str(e)}"
loguru.logger.error(error_msg)
raise Exception(error_msg)
class S3UploadURL:
"""
AWS S3文件上传节点返回URL
提供上传文件到Amazon S3并返回公开访问URL的功能。该节点专门为需要
即时分享文件的场景设计,具有以下特性:
- 自动使用预配置的媒体存储桶
- 支持CDN加速和全球访问
- 提供可靠的URL回退机制
- 统一的错误处理和日志记录
节点配置:
- 输入: 本地文件路径、子文件夹
- 输出: 公开访问URL
- 类别: Amazon S3存储操作
"""
@classmethod
def INPUT_TYPES(cls):
"""定义节点输入类型"""
return {
"required": {
"path": ("STRING", {"multiline": True}),
"subfolder": ("STRING", {"default": "test"}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("URL",)
FUNCTION = "upload"
CATEGORY = "不忘科技-自定义节点🛍/对象存储/S3"
def upload(self, path: str, subfolder: str) -> Tuple[str]:
"""
上传文件到AWS S3并返回访问URL
Args:
path: 本地文件路径
subfolder: 子文件夹名称
Returns:
Tuple[str]: 包含文件访问URL的元组
Raises:
Exception: 上传失败时抛出异常
"""
try:
# 使用默认的媒体缓存桶
s3_bucket = "modal-media-cache"
# 构造存储键名
dest_key = construct_storage_key(subfolder, path=path)
loguru.logger.info(f"开始S3上传: {path} -> {s3_bucket}/{dest_key}")
# 获取S3存储提供者
provider = get_provider("s3", {"bucket_name": s3_bucket})
# 执行上传
result = provider.upload_file(path, dest_key)
if not result.success:
error_msg = f"S3上传失败: {result.message}"
loguru.logger.error(error_msg)
raise Exception(error_msg)
# 返回CDN URL
url = result.url or f"https://cdn.roasmax.cn/{dest_key}"
loguru.logger.info(f"S3上传成功URL: {url}")
return (url,)
except Exception as e:
error_msg = (
f"S3上传失败: path={path}, subfolder={subfolder}, error={str(e)}"
)
loguru.logger.error(error_msg)
raise Exception(error_msg)
class S3UploadIMAGEURL:
"""
AWS S3图像张量上传节点返回URL
提供上传PyTorch图像张量到Amazon S3并返回公开访问URL的专业功能。
该节点专为深度学习和图像处理工作流设计,具有以下特性:
- 全面的PyTorch张量校验和预处理
- 支持2D、3D、4D张量的自动识别和转换
- 内置NaN和无限大值检测
- 高质量PNG格式输出和CDN加速
- 自动生成唯一文件名防止冲突
节点配置:
- 输入: PyTorch图像张量、子文件夹
- 输出: 图像公开访问URL
- 类别: Amazon S3存储操作
"""
@classmethod
def INPUT_TYPES(cls):
"""定义节点输入类型"""
return {
"required": {
"image": ("IMAGE", {}),
"subfolder": ("STRING", {"default": "test"}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("URL",)
FUNCTION = "upload"
CATEGORY = "不忘科技-自定义节点🛍/对象存储/S3"
def upload(self, image: torch.Tensor, subfolder: str) -> Tuple[str]:
"""
上传PyTorch张量作为图像到AWS S3存储并返回可访问的URL
该方法专门处理图像张量的上传操作:
1. 验证输入张量的有效性和维度
2. 生成唯一的文件名(基于时间戳)
3. 将张量转换为指定格式并上传
4. 返回可公开访问的CDN URL
Args:
image (torch.Tensor): PyTorch图像张量应为有效的图像数据
支持的维度格式:(H, W, C) 或 (B, H, W, C)
subfolder (str): 目标子文件夹名称,用于组织存储结构
Returns:
Tuple[str]: 包含图像可访问URL的单元素元组
Raises:
ValueError: 当输入参数无效或张量格式不正确时
Exception: 当上传操作失败时
Examples:
>>> import torch
>>> node = S3UploadIMAGEURL()
>>> tensor = torch.randn(256, 256, 3) # RGB图像
>>> result = node.upload(tensor, "generated_images")
>>> print(result[0]) # 输出CDN URL
"""
# 输入参数验证
if not isinstance(image, torch.Tensor):
raise ValueError("输入必须是PyTorch张量")
if not subfolder or not subfolder.strip():
raise ValueError("子文件夹名称不能为空")
# 张量维度验证
if image.dim() < 2 or image.dim() > 4:
raise ValueError(f"不支持的张量维度: {image.dim()}D支持2D、3D或4D张量")
# 张量数值验证
if torch.isnan(image).any():
raise ValueError("张量包含NaN值")
if torch.isinf(image).any():
raise ValueError("张量包含无限大值")
try:
# 使用配置的默认媒体存储桶
s3_bucket = "modal-media-cache"
# 生成基于时间戳的唯一文件名
timestamp = int(time.time())
filename = f"image_{timestamp}.png"
# 构造完整的存储键名
dest_key = f"{subfolder.strip('/')}/{filename}"
loguru.logger.info(
"开始S3图像张量上传",
extra={
"tensor_shape": list(image.shape),
"tensor_dtype": str(image.dtype),
"bucket": s3_bucket,
"target_key": dest_key,
"tensor_size_mb": image.numel()
* image.element_size()
/ (1024 * 1024),
},
)
# 获取S3存储提供者实例
provider = get_provider("s3", {"bucket_name": s3_bucket})
# 执行张量上传操作转换为PNG格式
result: UploadResult = provider.upload_tensor(image, dest_key, format="PNG")
# 验证上传结果
if not result.success:
error_msg = f"S3图像上传操作失败: {result.message}"
loguru.logger.error(
error_msg,
extra={
"tensor_shape": list(image.shape),
"bucket": s3_bucket,
"target_key": dest_key,
"error_details": result.message,
},
)
raise Exception(error_msg)
# 构造最终的访问URL优先使用返回的URL否则使用默认CDN
final_url = result.url or f"https://cdn.roasmax.cn/{dest_key}"
loguru.logger.info(
"S3图像张量上传成功",
extra={
"tensor_shape": list(image.shape),
"bucket": s3_bucket,
"uploaded_key": dest_key,
"access_url": final_url,
},
)
return (final_url,)
except ValueError:
# 重新抛出参数验证错误
raise
except Exception as e:
error_msg = (
f"S3图像上传过程中发生错误: subfolder={subfolder}, "
f"张量形状={list(image.shape)}, error={str(e)}"
)
loguru.logger.error(
error_msg,
extra={
"subfolder": subfolder,
"tensor_shape": list(image.shape),
"tensor_dtype": str(image.dtype),
"exception_type": type(e).__name__,
"exception_details": str(e),
},
)
raise Exception(error_msg)
# 公共API导出
__all__ = [
# 工具函数
"construct_storage_key",
# 腾讯云COS节点
"COSDownload",
"COSUpload",
# Amazon S3节点
"S3Download",
"S3Upload",
"S3UploadURL",
"S3UploadIMAGEURL",
]