824 lines
27 KiB
Python
824 lines
27 KiB
Python
"""
|
||
对象存储节点模块
|
||
|
||
本模块提供ComfyUI自定义节点,支持多种云存储服务的文件上传和下载操作。
|
||
采用统一的存储抽象层,支持AWS S3和腾讯云COS等存储服务。
|
||
|
||
节点类型:
|
||
- COSDownload: 腾讯云COS文件下载
|
||
- COSUpload: 腾讯云COS文件上传
|
||
- S3Download: AWS S3文件下载
|
||
- S3Upload: AWS S3文件上传
|
||
- S3UploadURL: AWS S3文件上传并返回URL
|
||
- S3UploadIMAGEURL: AWS S3图像上传并返回URL
|
||
|
||
设计特性:
|
||
- 统一的错误处理和日志记录
|
||
- 支持重试机制
|
||
- 类型安全的参数和返回值
|
||
- 完整的文档和注释
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Optional, Tuple, Union
|
||
|
||
import loguru
|
||
import torch
|
||
|
||
from ..utils.object_storage import DownloadResult, UploadResult, get_provider
|
||
|
||
|
||
def construct_storage_key(
|
||
subfolder: str,
|
||
filename: Optional[str] = None,
|
||
path: Optional[Union[str, Path]] = None,
|
||
) -> str:
|
||
"""
|
||
构造存储键名的通用工具函数
|
||
|
||
该函数用于生成云存储中的对象键名,支持多种输入方式:
|
||
1. 直接指定文件名
|
||
2. 从本地文件路径中提取文件名
|
||
3. 自动生成基于时间戳的文件名
|
||
|
||
Args:
|
||
subfolder (str): 存储中的子文件夹路径
|
||
filename (Optional[str]): 指定的文件名,优先级最高
|
||
path (Optional[Union[str, Path]]): 本地文件路径,用于提取文件名
|
||
|
||
Returns:
|
||
str: 完整的存储键名,格式为 "subfolder/filename"
|
||
|
||
Examples:
|
||
>>> construct_storage_key("images", filename="photo.jpg")
|
||
'images/photo.jpg'
|
||
>>> construct_storage_key("docs", path="/home/user/document.pdf")
|
||
'docs/document.pdf'
|
||
>>> construct_storage_key("temp") # 自动生成时间戳文件名
|
||
'temp/file_1640995200'
|
||
"""
|
||
# 参数验证
|
||
if not subfolder:
|
||
raise ValueError("subfolder参数不能为空")
|
||
|
||
# 优先使用指定的文件名
|
||
if filename is not None:
|
||
if not filename or not filename.strip():
|
||
raise ValueError("filename参数不能为空字符串")
|
||
return f"{subfolder.strip('/')}/{filename.strip()}"
|
||
|
||
# 从路径中提取文件名
|
||
if path:
|
||
path_obj = Path(path)
|
||
extracted_filename = path_obj.name
|
||
if not extracted_filename:
|
||
raise ValueError(f"无法从路径中提取文件名: {path}")
|
||
return f"{subfolder.strip('/')}/{extracted_filename}"
|
||
|
||
# 生成基于时间戳的默认文件名
|
||
timestamp = int(time.time())
|
||
return f"{subfolder.strip('/')}/file_{timestamp}"
|
||
|
||
|
||
class COSDownload:
|
||
"""
|
||
腾讯云COS文件下载节点
|
||
|
||
提供从腾讯云COS存储服务下载文件到本地的功能。该节点基于统一的存储抽象层实现,
|
||
具有以下特性:
|
||
- 统一的错误处理和详细的日志记录
|
||
- 自动创建本地目录结构
|
||
- 完整的参数验证机制
|
||
- 支持各种文件类型的下载
|
||
|
||
节点配置:
|
||
- 输入: COS存储桶名称、文件键名
|
||
- 输出: 本地文件路径
|
||
- 类别: 腾讯云COS存储操作
|
||
"""
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
"""定义节点输入类型"""
|
||
return {
|
||
"required": {
|
||
"cos_bucket": ("STRING", {"default": "bwkj-cos-1324682537"}),
|
||
"cos_key": ("STRING", {"multiline": True}),
|
||
}
|
||
}
|
||
|
||
RETURN_TYPES = ("STRING",)
|
||
RETURN_NAMES = ("文件存储路径",)
|
||
FUNCTION = "download"
|
||
CATEGORY = "不忘科技-自定义节点🛍/对象存储/COS"
|
||
|
||
def download(self, cos_bucket: str, cos_key: str) -> Tuple[str]:
|
||
"""
|
||
从腾讯云COS下载文件到本地
|
||
|
||
该方法执行完整的COS文件下载流程:
|
||
1. 验证输入参数
|
||
2. 构造本地保存路径
|
||
3. 创建必要的目录结构
|
||
4. 执行文件下载
|
||
5. 验证下载结果
|
||
|
||
Args:
|
||
cos_bucket (str): COS存储桶名称,不能为空
|
||
cos_key (str): COS中的文件键名,不能为空
|
||
|
||
Returns:
|
||
Tuple[str]: 包含本地文件完整路径的单元素元组
|
||
|
||
Raises:
|
||
ValueError: 当输入参数无效时
|
||
Exception: 当下载操作失败时
|
||
|
||
Examples:
|
||
>>> node = COSDownload()
|
||
>>> result = node.download("my-bucket", "images/photo.jpg")
|
||
>>> print(result[0]) # 输出本地文件路径
|
||
"""
|
||
# 输入参数验证
|
||
if not cos_bucket or not cos_bucket.strip():
|
||
raise ValueError("COS存储桶名称不能为空")
|
||
if not cos_key or not cos_key.strip():
|
||
raise ValueError("COS文件键名不能为空")
|
||
|
||
try:
|
||
# 构造安全的本地保存路径
|
||
cos_key_normalized = cos_key.strip().replace("/", os.sep)
|
||
base_dir = Path(__file__).parent.parent
|
||
download_dir = base_dir / "download"
|
||
|
||
# 构建完整的目标路径
|
||
destination_path = download_dir / cos_key_normalized
|
||
destination = str(destination_path)
|
||
|
||
# 确保目录存在
|
||
destination_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
loguru.logger.info(
|
||
f"开始从COS下载文件",
|
||
extra={
|
||
"bucket": cos_bucket,
|
||
"key": cos_key,
|
||
"destination": destination,
|
||
},
|
||
)
|
||
|
||
# 获取COS存储提供者实例
|
||
provider = get_provider("cos", {"bucket_name": cos_bucket})
|
||
|
||
# 执行文件下载操作
|
||
result: DownloadResult = provider.download_file(cos_key, destination)
|
||
|
||
# 检查下载结果
|
||
if not result.success:
|
||
error_msg = f"COS下载操作失败: {result.message}"
|
||
loguru.logger.error(
|
||
error_msg,
|
||
extra={
|
||
"bucket": cos_bucket,
|
||
"key": cos_key,
|
||
"error_details": result.message,
|
||
},
|
||
)
|
||
raise Exception(error_msg)
|
||
|
||
loguru.logger.info(
|
||
f"COS文件下载成功",
|
||
extra={
|
||
"bucket": cos_bucket,
|
||
"key": cos_key,
|
||
"local_path": destination,
|
||
"file_size": (
|
||
os.path.getsize(destination)
|
||
if os.path.exists(destination)
|
||
else "未知"
|
||
),
|
||
},
|
||
)
|
||
|
||
return (destination,)
|
||
|
||
except ValueError:
|
||
# 重新抛出参数验证错误
|
||
raise
|
||
except Exception as e:
|
||
error_msg = (
|
||
f"COS下载过程中发生错误: bucket={cos_bucket}, "
|
||
f"key={cos_key}, error={str(e)}"
|
||
)
|
||
loguru.logger.error(
|
||
error_msg,
|
||
extra={
|
||
"bucket": cos_bucket,
|
||
"key": cos_key,
|
||
"exception_type": type(e).__name__,
|
||
"exception_details": str(e),
|
||
},
|
||
)
|
||
raise Exception(error_msg)
|
||
|
||
|
||
class COSUpload:
|
||
"""
|
||
腾讯云COS文件上传节点
|
||
|
||
提供上传本地文件到腾讯云COS存储服务的功能。该节点基于统一的存储抽象层实现,
|
||
具有以下特性:
|
||
- 统一的错误处理和详细的日志记录
|
||
- 完整的文件存在性和权限验证
|
||
- 灵活的目标路径和命名策略
|
||
- 支持各种文件类型和大小
|
||
|
||
节点配置:
|
||
- 输入: COS存储桶名称、本地文件路径、子文件夹
|
||
- 输出: COS中的文件键名
|
||
- 类别: 腾讯云COS存储操作
|
||
"""
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
"""定义节点输入类型"""
|
||
return {
|
||
"required": {
|
||
"cos_bucket": ("STRING", {"default": "bwkj-cos-1324682537"}),
|
||
"path": ("STRING", {"multiline": True}),
|
||
"subfolder": ("STRING", {"default": "test"}),
|
||
}
|
||
}
|
||
|
||
RETURN_TYPES = ("STRING",)
|
||
RETURN_NAMES = ("COS文件Key",)
|
||
FUNCTION = "upload"
|
||
CATEGORY = "不忘科技-自定义节点🛍/对象存储/COS"
|
||
|
||
def upload(self, cos_bucket: str, path: str, subfolder: str) -> Tuple[str]:
|
||
"""
|
||
上传本地文件到腾讯云COS存储
|
||
|
||
该方法执行完整的文件上传流程:
|
||
1. 验证输入参数和文件存在性
|
||
2. 构造目标存储键名
|
||
3. 执行文件上传操作
|
||
4. 返回上传后的文件键名
|
||
|
||
Args:
|
||
cos_bucket (str): COS存储桶名称,不能为空
|
||
path (str): 本地文件的完整路径,文件必须存在
|
||
subfolder (str): 目标子文件夹名称,不能为空
|
||
|
||
Returns:
|
||
Tuple[str]: 包含COS中文件键名的单元素元组
|
||
|
||
Raises:
|
||
ValueError: 当输入参数无效或文件不存在时
|
||
Exception: 当上传操作失败时
|
||
|
||
Examples:
|
||
>>> node = COSUpload()
|
||
>>> result = node.upload("my-bucket", "/local/image.jpg", "images")
|
||
>>> print(result[0]) # 输出: "images/image.jpg"
|
||
"""
|
||
# 输入参数验证
|
||
if not cos_bucket or not cos_bucket.strip():
|
||
raise ValueError("COS存储桶名称不能为空")
|
||
if not path or not path.strip():
|
||
raise ValueError("本地文件路径不能为空")
|
||
if not subfolder or not subfolder.strip():
|
||
raise ValueError("子文件夹名称不能为空")
|
||
|
||
# 文件存在性检查
|
||
file_path = Path(path)
|
||
if not file_path.exists():
|
||
raise ValueError(f"本地文件不存在: {path}")
|
||
if not file_path.is_file():
|
||
raise ValueError(f"指定路径不是文件: {path}")
|
||
|
||
try:
|
||
# 构造目标存储键名
|
||
dest_key = construct_storage_key(subfolder, path=path)
|
||
file_size = file_path.stat().st_size
|
||
|
||
loguru.logger.info(
|
||
"开始COS文件上传",
|
||
extra={
|
||
"local_path": path,
|
||
"bucket": cos_bucket,
|
||
"target_key": dest_key,
|
||
"file_size_bytes": file_size,
|
||
},
|
||
)
|
||
|
||
# 获取COS存储提供者实例
|
||
provider = get_provider("cos", {"bucket_name": cos_bucket})
|
||
|
||
# 执行文件上传操作
|
||
result: UploadResult = provider.upload_file(path, dest_key)
|
||
|
||
# 检查上传结果
|
||
if not result.success:
|
||
error_msg = f"COS上传操作失败: {result.message}"
|
||
loguru.logger.error(
|
||
error_msg,
|
||
extra={
|
||
"local_path": path,
|
||
"bucket": cos_bucket,
|
||
"target_key": dest_key,
|
||
"error_details": result.message,
|
||
},
|
||
)
|
||
raise Exception(error_msg)
|
||
|
||
loguru.logger.info(
|
||
"COS文件上传成功",
|
||
extra={
|
||
"local_path": path,
|
||
"bucket": cos_bucket,
|
||
"uploaded_key": dest_key,
|
||
"file_size_bytes": file_size,
|
||
},
|
||
)
|
||
|
||
return (dest_key,)
|
||
|
||
except ValueError:
|
||
# 重新抛出参数验证错误
|
||
raise
|
||
except Exception as e:
|
||
error_msg = (
|
||
f"COS上传过程中发生错误: bucket={cos_bucket}, "
|
||
f"path={path}, subfolder={subfolder}, error={str(e)}"
|
||
)
|
||
loguru.logger.error(
|
||
error_msg,
|
||
extra={
|
||
"local_path": path,
|
||
"bucket": cos_bucket,
|
||
"subfolder": subfolder,
|
||
"exception_type": type(e).__name__,
|
||
"exception_details": str(e),
|
||
},
|
||
)
|
||
raise Exception(error_msg)
|
||
|
||
|
||
class S3Download:
|
||
"""
|
||
AWS S3文件下载节点
|
||
|
||
提供从Amazon S3存储服务下载文件到本地的功能。该节点基于统一的存储抽象层实现,
|
||
具有以下特性:
|
||
- 统一的错误处理和详细的日志记录
|
||
- 自动创建本地目录结构
|
||
- 完整的参数验证机制
|
||
- 支持各种文件类型的下载
|
||
|
||
节点配置:
|
||
- 输入: S3存储桶名称、文件键名
|
||
- 输出: 本地文件路径
|
||
- 类别: AWS S3存储操作
|
||
"""
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
"""定义节点输入类型"""
|
||
return {
|
||
"required": {
|
||
"s3_bucket": ("STRING", {"default": "bw-comfyui-input"}),
|
||
"s3_key": ("STRING", {"multiline": True}),
|
||
}
|
||
}
|
||
|
||
RETURN_TYPES = ("STRING",)
|
||
RETURN_NAMES = ("文件存储路径",)
|
||
FUNCTION = "download"
|
||
CATEGORY = "不忘科技-自定义节点🛍/对象存储/S3"
|
||
|
||
def download(self, s3_bucket: str, s3_key: str) -> Tuple[str]:
|
||
"""
|
||
从AWS S3存储服务下载文件到本地
|
||
|
||
执行完整的S3文件下载流程,包括参数验证、路径构造、
|
||
目录创建和文件下载等操作。
|
||
|
||
Args:
|
||
s3_bucket (str): S3存储桶名称,不能为空
|
||
s3_key (str): S3中的文件键名,不能为空
|
||
|
||
Returns:
|
||
Tuple[str]: 包含本地文件完整路径的单元素元组
|
||
|
||
Raises:
|
||
ValueError: 当输入参数无效时
|
||
Exception: 当下载操作失败时
|
||
"""
|
||
# 输入参数验证
|
||
if not s3_bucket or not s3_bucket.strip():
|
||
raise ValueError("S3存储桶名称不能为空")
|
||
if not s3_key or not s3_key.strip():
|
||
raise ValueError("S3文件键名不能为空")
|
||
|
||
try:
|
||
# 构造安全的本地保存路径
|
||
s3_key_normalized = s3_key.strip().replace("/", os.sep)
|
||
base_dir = Path(__file__).parent.parent
|
||
download_dir = base_dir / "download"
|
||
|
||
# 构建完整的目标路径
|
||
destination_path = download_dir / s3_key_normalized
|
||
destination = str(destination_path)
|
||
|
||
# 确保目录结构存在
|
||
destination_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
loguru.logger.info(
|
||
"开始从S3下载文件",
|
||
extra={"bucket": s3_bucket, "key": s3_key, "destination": destination},
|
||
)
|
||
|
||
# 获取S3存储提供者实例
|
||
provider = get_provider("s3", {"bucket_name": s3_bucket})
|
||
|
||
# 执行文件下载操作
|
||
result: DownloadResult = provider.download_file(s3_key, destination)
|
||
|
||
# 验证下载结果
|
||
if not result.success:
|
||
error_msg = f"S3下载操作失败: {result.message}"
|
||
loguru.logger.error(
|
||
error_msg,
|
||
extra={
|
||
"bucket": s3_bucket,
|
||
"key": s3_key,
|
||
"error_details": result.message,
|
||
},
|
||
)
|
||
raise Exception(error_msg)
|
||
|
||
loguru.logger.info(
|
||
"S3文件下载成功",
|
||
extra={
|
||
"bucket": s3_bucket,
|
||
"key": s3_key,
|
||
"local_path": destination,
|
||
"file_size": (
|
||
os.path.getsize(destination)
|
||
if os.path.exists(destination)
|
||
else "未知"
|
||
),
|
||
},
|
||
)
|
||
|
||
return (destination,)
|
||
|
||
except ValueError:
|
||
# 重新抛出参数验证错误
|
||
raise
|
||
except Exception as e:
|
||
error_msg = (
|
||
f"S3下载过程中发生错误: bucket={s3_bucket}, "
|
||
f"key={s3_key}, error={str(e)}"
|
||
)
|
||
loguru.logger.error(
|
||
error_msg,
|
||
extra={
|
||
"bucket": s3_bucket,
|
||
"key": s3_key,
|
||
"exception_type": type(e).__name__,
|
||
"exception_details": str(e),
|
||
},
|
||
)
|
||
raise Exception(error_msg)
|
||
|
||
|
||
class S3Upload:
|
||
"""
|
||
AWS S3文件上传节点
|
||
|
||
提供上传本地文件到Amazon S3存储服务的功能。该节点基于统一的存储抽象层实现,
|
||
具有以下特性:
|
||
- 统一的错误处理和详细的日志记录
|
||
- 完整的文件存在性和权限验证
|
||
- 灵活的目标路径和命名策略
|
||
- 支持各种文件类型和大小
|
||
|
||
节点配置:
|
||
- 输入: S3存储桶名称、本地文件路径、子文件夹
|
||
- 输出: S3中的文件键名
|
||
- 类别: Amazon S3存储操作
|
||
"""
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
"""定义节点输入类型"""
|
||
return {
|
||
"required": {
|
||
"s3_bucket": ("STRING", {"default": "bw-comfyui-output"}),
|
||
"path": ("STRING", {"multiline": True}),
|
||
"subfolder": ("STRING", {"default": "test"}),
|
||
}
|
||
}
|
||
|
||
RETURN_TYPES = ("STRING",)
|
||
RETURN_NAMES = ("S3文件Key",)
|
||
FUNCTION = "upload"
|
||
CATEGORY = "不忘科技-自定义节点🛍/对象存储/S3"
|
||
|
||
def upload(self, s3_bucket: str, path: str, subfolder: str) -> Tuple[str]:
|
||
"""
|
||
上传文件到AWS S3
|
||
|
||
Args:
|
||
s3_bucket: S3存储桶名称
|
||
path: 本地文件路径
|
||
subfolder: 子文件夹名称
|
||
|
||
Returns:
|
||
Tuple[str]: 包含S3文件键名的元组
|
||
|
||
Raises:
|
||
Exception: 上传失败时抛出异常
|
||
"""
|
||
try:
|
||
# 构造存储键名
|
||
dest_key = construct_storage_key(subfolder, path=path)
|
||
|
||
loguru.logger.info(f"开始S3上传: {path} -> {s3_bucket}/{dest_key}")
|
||
|
||
# 获取S3存储提供者
|
||
provider = get_provider("s3", {"bucket_name": s3_bucket})
|
||
|
||
# 执行上传
|
||
result = provider.upload_file(path, dest_key)
|
||
|
||
if not result.success:
|
||
error_msg = f"S3上传失败: {result.message}"
|
||
loguru.logger.error(error_msg)
|
||
raise Exception(error_msg)
|
||
|
||
loguru.logger.info(f"S3上传成功: {dest_key}")
|
||
return (dest_key,)
|
||
|
||
except Exception as e:
|
||
error_msg = f"S3上传失败: bucket={s3_bucket}, path={path}, subfolder={subfolder}, error={str(e)}"
|
||
loguru.logger.error(error_msg)
|
||
raise Exception(error_msg)
|
||
|
||
|
||
class S3UploadURL:
|
||
"""
|
||
AWS S3文件上传节点(返回URL)
|
||
|
||
提供上传文件到Amazon S3并返回公开访问URL的功能。该节点专门为需要
|
||
即时分享文件的场景设计,具有以下特性:
|
||
- 自动使用预配置的媒体存储桶
|
||
- 支持CDN加速和全球访问
|
||
- 提供可靠的URL回退机制
|
||
- 统一的错误处理和日志记录
|
||
|
||
节点配置:
|
||
- 输入: 本地文件路径、子文件夹
|
||
- 输出: 公开访问URL
|
||
- 类别: Amazon S3存储操作
|
||
"""
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
"""定义节点输入类型"""
|
||
return {
|
||
"required": {
|
||
"path": ("STRING", {"multiline": True}),
|
||
"subfolder": ("STRING", {"default": "test"}),
|
||
}
|
||
}
|
||
|
||
RETURN_TYPES = ("STRING",)
|
||
RETURN_NAMES = ("URL",)
|
||
FUNCTION = "upload"
|
||
CATEGORY = "不忘科技-自定义节点🛍/对象存储/S3"
|
||
|
||
def upload(self, path: str, subfolder: str) -> Tuple[str]:
|
||
"""
|
||
上传文件到AWS S3并返回访问URL
|
||
|
||
Args:
|
||
path: 本地文件路径
|
||
subfolder: 子文件夹名称
|
||
|
||
Returns:
|
||
Tuple[str]: 包含文件访问URL的元组
|
||
|
||
Raises:
|
||
Exception: 上传失败时抛出异常
|
||
"""
|
||
try:
|
||
# 使用默认的媒体缓存桶
|
||
s3_bucket = "modal-media-cache"
|
||
|
||
# 构造存储键名
|
||
dest_key = construct_storage_key(subfolder, path=path)
|
||
|
||
loguru.logger.info(f"开始S3上传: {path} -> {s3_bucket}/{dest_key}")
|
||
|
||
# 获取S3存储提供者
|
||
provider = get_provider("s3", {"bucket_name": s3_bucket})
|
||
|
||
# 执行上传
|
||
result = provider.upload_file(path, dest_key)
|
||
|
||
if not result.success:
|
||
error_msg = f"S3上传失败: {result.message}"
|
||
loguru.logger.error(error_msg)
|
||
raise Exception(error_msg)
|
||
|
||
# 返回CDN URL
|
||
url = result.url or f"https://cdn.roasmax.cn/{dest_key}"
|
||
|
||
loguru.logger.info(f"S3上传成功,URL: {url}")
|
||
return (url,)
|
||
|
||
except Exception as e:
|
||
error_msg = (
|
||
f"S3上传失败: path={path}, subfolder={subfolder}, error={str(e)}"
|
||
)
|
||
loguru.logger.error(error_msg)
|
||
raise Exception(error_msg)
|
||
|
||
|
||
class S3UploadIMAGEURL:
|
||
"""
|
||
AWS S3图像张量上传节点(返回URL)
|
||
|
||
提供上传PyTorch图像张量到Amazon S3并返回公开访问URL的专业功能。
|
||
该节点专为深度学习和图像处理工作流设计,具有以下特性:
|
||
- 全面的PyTorch张量校验和预处理
|
||
- 支持2D、3D、4D张量的自动识别和转换
|
||
- 内置NaN和无限大值检测
|
||
- 高质量PNG格式输出和CDN加速
|
||
- 自动生成唯一文件名防止冲突
|
||
|
||
节点配置:
|
||
- 输入: PyTorch图像张量、子文件夹
|
||
- 输出: 图像公开访问URL
|
||
- 类别: Amazon S3存储操作
|
||
"""
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
"""定义节点输入类型"""
|
||
return {
|
||
"required": {
|
||
"image": ("IMAGE", {}),
|
||
"subfolder": ("STRING", {"default": "test"}),
|
||
}
|
||
}
|
||
|
||
RETURN_TYPES = ("STRING",)
|
||
RETURN_NAMES = ("URL",)
|
||
FUNCTION = "upload"
|
||
CATEGORY = "不忘科技-自定义节点🛍/对象存储/S3"
|
||
|
||
def upload(self, image: torch.Tensor, subfolder: str) -> Tuple[str]:
|
||
"""
|
||
上传PyTorch张量作为图像到AWS S3存储并返回可访问的URL
|
||
|
||
该方法专门处理图像张量的上传操作:
|
||
1. 验证输入张量的有效性和维度
|
||
2. 生成唯一的文件名(基于时间戳)
|
||
3. 将张量转换为指定格式并上传
|
||
4. 返回可公开访问的CDN URL
|
||
|
||
Args:
|
||
image (torch.Tensor): PyTorch图像张量,应为有效的图像数据
|
||
支持的维度格式:(H, W, C) 或 (B, H, W, C)
|
||
subfolder (str): 目标子文件夹名称,用于组织存储结构
|
||
|
||
Returns:
|
||
Tuple[str]: 包含图像可访问URL的单元素元组
|
||
|
||
Raises:
|
||
ValueError: 当输入参数无效或张量格式不正确时
|
||
Exception: 当上传操作失败时
|
||
|
||
Examples:
|
||
>>> import torch
|
||
>>> node = S3UploadIMAGEURL()
|
||
>>> tensor = torch.randn(256, 256, 3) # RGB图像
|
||
>>> result = node.upload(tensor, "generated_images")
|
||
>>> print(result[0]) # 输出CDN URL
|
||
"""
|
||
# 输入参数验证
|
||
if not isinstance(image, torch.Tensor):
|
||
raise ValueError("输入必须是PyTorch张量")
|
||
if not subfolder or not subfolder.strip():
|
||
raise ValueError("子文件夹名称不能为空")
|
||
|
||
# 张量维度验证
|
||
if image.dim() < 2 or image.dim() > 4:
|
||
raise ValueError(f"不支持的张量维度: {image.dim()}D,支持2D、3D或4D张量")
|
||
|
||
# 张量数值验证
|
||
if torch.isnan(image).any():
|
||
raise ValueError("张量包含NaN值")
|
||
if torch.isinf(image).any():
|
||
raise ValueError("张量包含无限大值")
|
||
|
||
try:
|
||
# 使用配置的默认媒体存储桶
|
||
s3_bucket = "modal-media-cache"
|
||
|
||
# 生成基于时间戳的唯一文件名
|
||
timestamp = int(time.time())
|
||
filename = f"image_{timestamp}.png"
|
||
|
||
# 构造完整的存储键名
|
||
dest_key = f"{subfolder.strip('/')}/{filename}"
|
||
|
||
loguru.logger.info(
|
||
"开始S3图像张量上传",
|
||
extra={
|
||
"tensor_shape": list(image.shape),
|
||
"tensor_dtype": str(image.dtype),
|
||
"bucket": s3_bucket,
|
||
"target_key": dest_key,
|
||
"tensor_size_mb": image.numel()
|
||
* image.element_size()
|
||
/ (1024 * 1024),
|
||
},
|
||
)
|
||
|
||
# 获取S3存储提供者实例
|
||
provider = get_provider("s3", {"bucket_name": s3_bucket})
|
||
|
||
# 执行张量上传操作(转换为PNG格式)
|
||
result: UploadResult = provider.upload_tensor(image, dest_key, format="PNG")
|
||
|
||
# 验证上传结果
|
||
if not result.success:
|
||
error_msg = f"S3图像上传操作失败: {result.message}"
|
||
loguru.logger.error(
|
||
error_msg,
|
||
extra={
|
||
"tensor_shape": list(image.shape),
|
||
"bucket": s3_bucket,
|
||
"target_key": dest_key,
|
||
"error_details": result.message,
|
||
},
|
||
)
|
||
raise Exception(error_msg)
|
||
|
||
# 构造最终的访问URL(优先使用返回的URL,否则使用默认CDN)
|
||
final_url = result.url or f"https://cdn.roasmax.cn/{dest_key}"
|
||
|
||
loguru.logger.info(
|
||
"S3图像张量上传成功",
|
||
extra={
|
||
"tensor_shape": list(image.shape),
|
||
"bucket": s3_bucket,
|
||
"uploaded_key": dest_key,
|
||
"access_url": final_url,
|
||
},
|
||
)
|
||
|
||
return (final_url,)
|
||
|
||
except ValueError:
|
||
# 重新抛出参数验证错误
|
||
raise
|
||
except Exception as e:
|
||
error_msg = (
|
||
f"S3图像上传过程中发生错误: subfolder={subfolder}, "
|
||
f"张量形状={list(image.shape)}, error={str(e)}"
|
||
)
|
||
loguru.logger.error(
|
||
error_msg,
|
||
extra={
|
||
"subfolder": subfolder,
|
||
"tensor_shape": list(image.shape),
|
||
"tensor_dtype": str(image.dtype),
|
||
"exception_type": type(e).__name__,
|
||
"exception_details": str(e),
|
||
},
|
||
)
|
||
raise Exception(error_msg)
|
||
|
||
|
||
# 公共API导出
|
||
__all__ = [
|
||
# 工具函数
|
||
"construct_storage_key",
|
||
# 腾讯云COS节点
|
||
"COSDownload",
|
||
"COSUpload",
|
||
# Amazon S3节点
|
||
"S3Download",
|
||
"S3Upload",
|
||
"S3UploadURL",
|
||
"S3UploadIMAGEURL",
|
||
]
|