# ComfyUI-CustomNode/utils/object_storage/providers/s3_provider.py

"""
AWS S3存储提供者实现
本模块实现了AWS S3的具体存储操作继承自抽象存储接口。
提供完整的S3文件上传、下载、删除等功能。
特性:
- 支持文件、字节数据、PyTorch张量的上传
- 支持预签名URL生成
- 完整的错误处理和重试机制
- 统一的日志记录
"""
import os
import boto3
from typing import Optional, Dict, Any, List
import torch
import loguru
from ..storage_interface import (
StorageProvider,
StorageFactory,
UploadResult,
DownloadResult,
)
from ...config_utils import config
class S3StorageProvider(StorageProvider):
    """AWS S3 storage provider.

    Implements every method of the ``StorageProvider`` interface against a
    single S3 bucket: uploads (local file / bytes / PyTorch tensor),
    downloads, deletion, existence checks, listing and URL generation.
    The boto3 client is created lazily on first access.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize the S3 storage provider.

        Args:
            config: S3 configuration dict; must contain ``access_key_id``
                and ``secret_access_key`` (checked by ``_validate_config``,
                presumably invoked by the base class — TODO confirm).
                Optional keys: ``bucket_name``, ``region``, ``cdn_base_url``.
        """
        super().__init__(config)
        self.bucket_name = config.get("bucket_name", "modal-media-cache")
        self.region = config.get("region", "us-east-1")
        self.cdn_base_url = config.get("cdn_base_url", "https://cdn.roasmax.cn")
        # boto3 client, created on first use via the ``client`` property.
        self._client = None

    def _validate_config(self) -> None:
        """Validate that the required S3 credentials are present.

        Raises:
            ValueError: if any required configuration key is missing or empty.
        """
        required_keys = ["access_key_id", "secret_access_key"]
        missing_keys = [key for key in required_keys if not self.config.get(key)]
        if missing_keys:
            raise ValueError(
                f"S3配置缺失必要参数: {missing_keys}. " f"请检查配置文件或环境变量"
            )

    @property
    def client(self):
        """Lazily create and return the boto3 S3 client.

        Returns:
            boto3.client: the cached S3 client instance.
        """
        if self._client is None:
            try:
                self._client = boto3.client(
                    "s3",
                    aws_access_key_id=self.config["access_key_id"],
                    aws_secret_access_key=self.config["secret_access_key"],
                    region_name=self.region,
                )
                loguru.logger.info(f"S3客户端初始化成功区域: {self.region}")
            except Exception as e:
                loguru.logger.error(f"S3客户端初始化失败: {e}")
                raise
        return self._client

    def upload_file(
        self,
        local_path: str,
        remote_key: str,
        content_type: Optional[str] = None,
        **kwargs,
    ) -> UploadResult:
        """Upload a local file to S3.

        Args:
            local_path: path of the local file to upload.
            remote_key: destination object key in the bucket.
            content_type: optional MIME type, stored as ``ContentType``.
            **kwargs: additional boto3 ``ExtraArgs`` entries.

        Returns:
            UploadResult: outcome of the upload; on success ``url`` points
                at the CDN front of the bucket.
        """
        try:
            if not os.path.exists(local_path):
                error_msg = f"本地文件不存在: {local_path}"
                loguru.logger.error(error_msg)
                return UploadResult(
                    success=False,
                    key=remote_key,
                    message=error_msg,
                    error=FileNotFoundError(error_msg),
                )
            file_size = os.path.getsize(local_path)
            loguru.logger.info(
                f"开始上传文件到S3: {local_path} -> s3://{self.bucket_name}/{remote_key} "
                f"({file_size} bytes)"
            )
            extra_args = kwargs.copy()
            if content_type:
                extra_args["ContentType"] = content_type
            # boto3 expects None (not an empty dict) when there are no
            # extra arguments.
            self.client.upload_file(
                local_path,
                self.bucket_name,
                remote_key,
                ExtraArgs=extra_args or None,
            )
            url = f"{self.cdn_base_url}/{remote_key}"
            success_msg = f"文件上传成功: s3://{self.bucket_name}/{remote_key}"
            loguru.logger.info(success_msg)
            return UploadResult(
                success=True,
                key=remote_key,
                url=url,
                size=file_size,
                message=success_msg,
            )
        except Exception as e:
            error_msg = f"S3文件上传失败: {str(e)}"
            loguru.logger.error(f"{error_msg} (文件: {local_path}, 键: {remote_key})")
            return UploadResult(
                success=False, key=remote_key, message=error_msg, error=e
            )

    def upload_bytes(
        self, data: bytes, remote_key: str, content_type: Optional[str] = None, **kwargs
    ) -> UploadResult:
        """Upload raw bytes to S3.

        Args:
            data: the bytes to store.
            remote_key: destination object key in the bucket.
            content_type: optional MIME type, stored as ``ContentType``.
            **kwargs: additional keyword arguments passed straight to
                ``put_object`` (unlike ``upload_file``, no ``ExtraArgs``
                wrapper here).

        Returns:
            UploadResult: outcome of the upload.
        """
        try:
            loguru.logger.info(
                f"开始上传字节数据到S3: s3://{self.bucket_name}/{remote_key} "
                f"({len(data)} bytes)"
            )
            extra_args = kwargs.copy()
            if content_type:
                extra_args["ContentType"] = content_type
            self.client.put_object(
                Bucket=self.bucket_name, Key=remote_key, Body=data, **extra_args
            )
            url = f"{self.cdn_base_url}/{remote_key}"
            success_msg = f"字节数据上传成功: s3://{self.bucket_name}/{remote_key}"
            loguru.logger.info(success_msg)
            return UploadResult(
                success=True,
                key=remote_key,
                url=url,
                size=len(data),
                message=success_msg,
            )
        except Exception as e:
            error_msg = f"S3字节数据上传失败: {str(e)}"
            loguru.logger.error(
                f"{error_msg} (键: {remote_key}, 大小: {len(data)} bytes)"
            )
            return UploadResult(
                success=False, key=remote_key, message=error_msg, error=e
            )

    def upload_tensor(
        self, tensor: torch.Tensor, remote_key: str, format: str = "PNG", **kwargs
    ) -> UploadResult:
        """Encode a PyTorch tensor as an image and upload it to S3.

        Args:
            tensor: image tensor (expected layout is defined by
                ``tensor_to_tempfile`` — TODO confirm shape convention).
            remote_key: destination object key in the bucket.
            format: image format name (``"PNG"``, ``"JPEG"``, ...).
            **kwargs: extra upload arguments forwarded to ``upload_file``.

        Returns:
            UploadResult: outcome of the upload.
        """
        try:
            from ...image_utils import tensor_to_tempfile

            loguru.logger.info(
                f"开始上传张量到S3: s3://{self.bucket_name}/{remote_key} "
                f"(形状: {tensor.shape}, 格式: {format})"
            )
            # Serialize the tensor into a temporary image file on disk.
            temp_file = tensor_to_tempfile(tensor, format=format)
            temp_path = temp_file.name
            try:
                # Fix: both "JPEG" and "JPG" map to the registered MIME type
                # "image/jpeg" (the old code produced "image/jpg" for "JPG").
                if format.upper() in ("JPEG", "JPG"):
                    content_type = "image/jpeg"
                else:
                    content_type = f"image/{format.lower()}"
                # Delegate the actual transfer to upload_file.
                result = self.upload_file(
                    temp_path, remote_key, content_type=content_type, **kwargs
                )
                if result.success:
                    loguru.logger.info(
                        f"张量上传成功: s3://{self.bucket_name}/{remote_key}"
                    )
                return result
            finally:
                # Fix: close the handle before unlinking so the file
                # descriptor is not leaked (and so deletion works on
                # Windows, where open files cannot be removed).
                temp_file.close()
                if os.path.exists(temp_path):
                    os.unlink(temp_path)
                    loguru.logger.debug(f"临时文件已清理: {temp_path}")
        except Exception as e:
            error_msg = f"S3张量上传失败: {str(e)}"
            loguru.logger.error(
                f"{error_msg} (键: {remote_key}, 张量形状: {tensor.shape})"
            )
            return UploadResult(
                success=False, key=remote_key, message=error_msg, error=e
            )

    def download_file(
        self, remote_key: str, local_path: str, **kwargs
    ) -> DownloadResult:
        """Download an S3 object to a local file.

        Args:
            remote_key: source object key in the bucket.
            local_path: local destination path; parent directories are
                created if needed.
            **kwargs: extra download arguments (currently unused).

        Returns:
            DownloadResult: outcome; ``local_path`` is set on success.
        """
        try:
            loguru.logger.info(
                f"开始从S3下载文件: s3://{self.bucket_name}/{remote_key} -> {local_path}"
            )
            # Ensure the destination directory exists.  Fix: guard against a
            # bare filename, where dirname() is "" and makedirs("") raises.
            local_dir = os.path.dirname(local_path)
            if local_dir:
                os.makedirs(local_dir, exist_ok=True)
            # Check existence first to return a clean FileNotFoundError
            # instead of a raw botocore ClientError.
            if not self.file_exists(remote_key):
                error_msg = f"S3文件不存在: s3://{self.bucket_name}/{remote_key}"
                loguru.logger.error(error_msg)
                return DownloadResult(
                    success=False, message=error_msg, error=FileNotFoundError(error_msg)
                )
            self.client.download_file(self.bucket_name, remote_key, local_path)
            success_msg = f"文件下载成功: {local_path}"
            loguru.logger.info(success_msg)
            return DownloadResult(
                success=True, local_path=local_path, message=success_msg
            )
        except Exception as e:
            error_msg = f"S3文件下载失败: {str(e)}"
            loguru.logger.error(
                f"{error_msg} (键: {remote_key}, 本地路径: {local_path})"
            )
            return DownloadResult(success=False, message=error_msg, error=e)

    def download_bytes(self, remote_key: str, **kwargs) -> DownloadResult:
        """Download an S3 object into memory.

        Args:
            remote_key: source object key in the bucket.
            **kwargs: extra download arguments (currently unused).

        Returns:
            DownloadResult: outcome; the payload is in the ``data`` field.
        """
        try:
            loguru.logger.info(
                f"开始从S3下载字节数据: s3://{self.bucket_name}/{remote_key}"
            )
            # Check existence first for a clean FileNotFoundError result.
            if not self.file_exists(remote_key):
                error_msg = f"S3文件不存在: s3://{self.bucket_name}/{remote_key}"
                loguru.logger.error(error_msg)
                return DownloadResult(
                    success=False, message=error_msg, error=FileNotFoundError(error_msg)
                )
            response = self.client.get_object(Bucket=self.bucket_name, Key=remote_key)
            data = response["Body"].read()
            success_msg = f"字节数据下载成功: {len(data)} bytes"
            loguru.logger.info(success_msg)
            return DownloadResult(success=True, data=data, message=success_msg)
        except Exception as e:
            error_msg = f"S3字节数据下载失败: {str(e)}"
            loguru.logger.error(f"{error_msg} (键: {remote_key})")
            return DownloadResult(success=False, message=error_msg, error=e)

    def delete_file(self, remote_key: str, **kwargs) -> bool:
        """Delete an object from S3.

        Args:
            remote_key: object key to delete.
            **kwargs: extra delete arguments (currently unused).

        Returns:
            bool: True on success, False if the delete call raised.
        """
        try:
            loguru.logger.info(f"删除S3文件: s3://{self.bucket_name}/{remote_key}")
            self.client.delete_object(Bucket=self.bucket_name, Key=remote_key)
            loguru.logger.info(f"文件删除成功: s3://{self.bucket_name}/{remote_key}")
            return True
        except Exception as e:
            loguru.logger.error(f"S3文件删除失败: {e} (键: {remote_key})")
            return False

    def file_exists(self, remote_key: str, **kwargs) -> bool:
        """Check whether an object exists in S3.

        Args:
            remote_key: object key to probe.
            **kwargs: extra arguments (currently unused).

        Returns:
            bool: True when ``head_object`` succeeds.

        NOTE(review): the broad ``except`` also returns False on network or
        credential errors, not just on a missing key — intentional
        best-effort behavior, kept as-is.
        """
        try:
            self.client.head_object(Bucket=self.bucket_name, Key=remote_key)
            return True
        except Exception:
            return False

    def list_files(
        self, prefix: str = "", max_keys: int = 1000, **kwargs
    ) -> List[Dict[str, Any]]:
        """List objects in the bucket.

        Args:
            prefix: key prefix filter.
            max_keys: maximum number of keys returned (single page only;
                no pagination beyond what ``list_objects_v2`` returns).
            **kwargs: extra arguments forwarded to ``list_objects_v2``.

        Returns:
            List[Dict[str, Any]]: one dict per object with ``key``, ``size``,
                ``last_modified``, ``etag`` and ``storage_class``; empty list
                on error.
        """
        try:
            loguru.logger.info(
                f"列出S3文件: s3://{self.bucket_name}/{prefix} (最大: {max_keys})"
            )
            response = self.client.list_objects_v2(
                Bucket=self.bucket_name, Prefix=prefix, MaxKeys=max_keys, **kwargs
            )
            files = []
            # "Contents" is absent when the listing is empty.
            if "Contents" in response:
                for obj in response["Contents"]:
                    files.append(
                        {
                            "key": obj["Key"],
                            "size": obj["Size"],
                            "last_modified": obj["LastModified"],
                            # ETags come back wrapped in literal quotes.
                            "etag": obj.get("ETag", "").strip('"'),
                            "storage_class": obj.get("StorageClass", "STANDARD"),
                        }
                    )
            loguru.logger.info(f"找到 {len(files)} 个文件")
            return files
        except Exception as e:
            loguru.logger.error(f"S3文件列表获取失败: {e} (前缀: {prefix})")
            return []

    def get_file_url(self, remote_key: str, expires_in: int = 3600, **kwargs) -> str:
        """Return an access URL for an S3 object.

        Args:
            remote_key: object key.
            expires_in: presigned-URL lifetime in seconds.
            **kwargs: extra arguments forwarded to
                ``generate_presigned_url``.

        Returns:
            str: a CDN URL for long-lived requests (expires_in > 3600 when a
                CDN base is configured — note this URL never expires), or a
                presigned S3 URL otherwise; falls back to the CDN URL if
                presigning fails.

        Raises:
            Exception: re-raised from presigning when no CDN fallback is
                configured.
        """
        try:
            # Long-lived URLs go through the CDN when one is configured.
            if self.cdn_base_url and expires_in > 3600:
                return f"{self.cdn_base_url}/{remote_key}"
            # Otherwise generate a presigned URL.
            url = self.client.generate_presigned_url(
                "get_object",
                Params={"Bucket": self.bucket_name, "Key": remote_key},
                ExpiresIn=expires_in,
                **kwargs,
            )
            return url
        except Exception as e:
            loguru.logger.error(f"S3 URL生成失败: {e} (键: {remote_key})")
            # Fall back to the CDN URL if presigning failed.
            if self.cdn_base_url:
                return f"{self.cdn_base_url}/{remote_key}"
            raise
class S3StorageFactory(StorageFactory):
    """Factory that builds ``S3StorageProvider`` instances.

    Validates/augments the caller-supplied configuration, falling back to
    the global AWS configuration when credentials are missing.
    """

    def create_provider(self, config_dict: Dict[str, Any]) -> StorageProvider:
        """Create an S3 storage provider.

        Args:
            config_dict: S3 configuration dict; may be None or empty, in
                which case credentials come from the global config.

        Returns:
            StorageProvider: a configured ``S3StorageProvider``.

        Raises:
            ValueError: when no credentials are supplied and the global
                configuration has no AWS credentials either.
        """
        # Fall back to the global configuration when no usable credentials
        # were provided.
        if not config_dict or not config_dict.get("access_key_id"):
            if config.has_aws_config():
                aws_config = config.get_aws_config()
                # Fixes vs. the original:
                #  * config_dict may be None here; ``**None`` raised TypeError.
                #  * spread config_dict FIRST so a present-but-empty
                #    access_key_id (which triggered this fallback) cannot
                #    clobber the valid global credentials.
                config_dict = {
                    **(config_dict or {}),
                    "access_key_id": aws_config["access_key_id"],
                    "secret_access_key": aws_config["secret_access_key"],
                }
            else:
                raise ValueError("未提供有效的S3配置且全局配置中也没有AWS配置")
        return S3StorageProvider(config_dict)

    def get_supported_types(self) -> List[str]:
        """Return the storage-type aliases this factory handles.

        Returns:
            List[str]: supported storage type identifiers.
        """
        return ["s3", "aws", "amazon"]