503 lines
16 KiB
Python
503 lines
16 KiB
Python
"""
|
||
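AWS S3 storage provider implementation.

This module implements the concrete AWS S3 storage operations on top of the
abstract storage interface. It provides S3 file upload, download, deletion,
listing and URL generation.

Features:
- Upload of local files, raw bytes and PyTorch tensors
- Pre-signed URL generation
- Error handling that logs failures and reports them through result objects
  (the original text also advertises a retry mechanism; no explicit retry
  logic is visible in this module - TODO confirm where retries are configured)
- Unified logging through loguru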
|
||
"""
|
||
|
||
import os
|
||
import boto3
|
||
from typing import Optional, Dict, Any, List
|
||
import torch
|
||
import loguru
|
||
from ..storage_interface import (
|
||
StorageProvider,
|
||
StorageFactory,
|
||
UploadResult,
|
||
DownloadResult,
|
||
)
|
||
from ...config_utils import config
|
||
|
||
|
||
class S3StorageProvider(StorageProvider):
    """AWS S3 implementation of the ``StorageProvider`` interface.

    Wraps a boto3 S3 client to upload, download, delete, list and link
    objects in a single bucket. The client is created lazily on first
    access of :attr:`client`.

    The upload and download methods do not raise on failure: they log the
    error and return a result object with ``success=False``.
    """

    # S3 returns at most this many keys per ListObjectsV2 response, so larger
    # requests have to be paginated.
    _LIST_PAGE_LIMIT = 1000

    # Error codes that indicate "object does not exist" on a HEAD request.
    _NOT_FOUND_CODES = frozenset({"404", "NoSuchKey", "NotFound"})

    def __init__(self, config: Dict[str, Any]):
        """Initialise the provider from a configuration mapping.

        Args:
            config: S3 settings. ``access_key_id`` and ``secret_access_key``
                are required by ``_validate_config``; ``bucket_name``,
                ``region`` and ``cdn_base_url`` are optional and fall back
                to the defaults below.
        """
        # NOTE(review): the base class presumably stores ``config`` as
        # ``self.config`` and calls ``_validate_config`` - confirm there.
        # The parameter shadows the module-level ``config`` import inside
        # this method only.
        super().__init__(config)
        self.bucket_name = config.get("bucket_name", "modal-media-cache")
        self.region = config.get("region", "us-east-1")
        self.cdn_base_url = config.get("cdn_base_url", "https://cdn.roasmax.cn")
        self._client = None

    def _validate_config(self) -> None:
        """Check that the required credentials are present and non-empty.

        Raises:
            ValueError: If ``access_key_id`` or ``secret_access_key`` is
                missing or falsy in ``self.config``.
        """
        required_keys = ["access_key_id", "secret_access_key"]
        missing_keys = [key for key in required_keys if not self.config.get(key)]

        if missing_keys:
            raise ValueError(
                f"S3配置缺失必要参数: {missing_keys}. " f"请检查配置文件或环境变量"
            )

    @property
    def client(self):
        """Return the boto3 S3 client, creating it on first access.

        Returns:
            The cached boto3 S3 client built from the configured
            credentials and region.

        Raises:
            Exception: Whatever ``boto3.client`` raises; it is logged and
                re-raised unchanged.
        """
        if self._client is None:
            try:
                self._client = boto3.client(
                    "s3",
                    aws_access_key_id=self.config["access_key_id"],
                    aws_secret_access_key=self.config["secret_access_key"],
                    region_name=self.region,
                )
                loguru.logger.info(f"S3客户端初始化成功,区域: {self.region}")
            except Exception as e:
                loguru.logger.error(f"S3客户端初始化失败: {e}")
                raise

        return self._client

    def _cdn_url(self, remote_key: str) -> str:
        """Build the CDN URL for ``remote_key``.

        A trailing slash on the configured base URL is dropped so the result
        never contains ``//`` between the base and the key.
        """
        base = str(self.cdn_base_url or "").rstrip("/")
        return f"{base}/{remote_key}"

    @staticmethod
    def _discard_temp_file(temp_file: Any, temp_path: str) -> None:
        """Close a temporary file object (if closable) and delete its path."""
        try:
            # Assumes the helper returns a file-like object (it exposes
            # ``.name``); closing releases the OS handle before unlinking.
            close = getattr(temp_file, "close", None)
            if callable(close):
                close()
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)
                loguru.logger.debug(f"临时文件已清理: {temp_path}")

    def upload_file(
        self,
        local_path: str,
        remote_key: str,
        content_type: Optional[str] = None,
        **kwargs,
    ) -> UploadResult:
        """Upload a local file to the bucket.

        Args:
            local_path: Path of the file to upload.
            remote_key: Object key to write in the bucket.
            content_type: Optional MIME type, sent as ``ContentType``.
            **kwargs: Additional entries passed to boto3 as ``ExtraArgs``.

        Returns:
            UploadResult: ``success=True`` with the CDN URL and file size on
            success; ``success=False`` with the error otherwise. This method
            does not raise.
        """
        try:
            if not os.path.exists(local_path):
                error_msg = f"本地文件不存在: {local_path}"
                loguru.logger.error(error_msg)
                return UploadResult(
                    success=False,
                    key=remote_key,
                    message=error_msg,
                    error=FileNotFoundError(error_msg),
                )

            file_size = os.path.getsize(local_path)
            loguru.logger.info(
                f"开始上传文件到S3: {local_path} -> s3://{self.bucket_name}/{remote_key} "
                f"({file_size} bytes)"
            )

            extra_args = kwargs.copy()
            if content_type:
                extra_args["ContentType"] = content_type

            self.client.upload_file(
                local_path,
                self.bucket_name,
                remote_key,
                # boto3 expects None rather than an empty dict here.
                ExtraArgs=extra_args or None,
            )

            success_msg = f"文件上传成功: s3://{self.bucket_name}/{remote_key}"
            loguru.logger.info(success_msg)

            return UploadResult(
                success=True,
                key=remote_key,
                url=self._cdn_url(remote_key),
                size=file_size,
                message=success_msg,
            )

        except Exception as e:
            error_msg = f"S3文件上传失败: {str(e)}"
            loguru.logger.error(f"{error_msg} (文件: {local_path}, 键: {remote_key})")
            return UploadResult(
                success=False, key=remote_key, message=error_msg, error=e
            )

    def upload_bytes(
        self, data: bytes, remote_key: str, content_type: Optional[str] = None, **kwargs
    ) -> UploadResult:
        """Upload an in-memory bytes payload to the bucket.

        Args:
            data: Payload to store.
            remote_key: Object key to write in the bucket.
            content_type: Optional MIME type, sent as ``ContentType``.
            **kwargs: Additional keyword arguments forwarded to
                ``put_object``.

        Returns:
            UploadResult: ``success=True`` with the CDN URL and payload size
            on success; ``success=False`` with the error otherwise. This
            method does not raise.
        """
        try:
            loguru.logger.info(
                f"开始上传字节数据到S3: s3://{self.bucket_name}/{remote_key} "
                f"({len(data)} bytes)"
            )

            extra_args = kwargs.copy()
            if content_type:
                extra_args["ContentType"] = content_type

            self.client.put_object(
                Bucket=self.bucket_name, Key=remote_key, Body=data, **extra_args
            )

            success_msg = f"字节数据上传成功: s3://{self.bucket_name}/{remote_key}"
            loguru.logger.info(success_msg)

            return UploadResult(
                success=True,
                key=remote_key,
                url=self._cdn_url(remote_key),
                size=len(data),
                message=success_msg,
            )

        except Exception as e:
            error_msg = f"S3字节数据上传失败: {str(e)}"
            # ``data`` may be the reason we failed (e.g. None), so do not let
            # len() raise again inside the handler.
            size = len(data) if hasattr(data, "__len__") else "unknown"
            loguru.logger.error(f"{error_msg} (键: {remote_key}, 大小: {size} bytes)")
            return UploadResult(
                success=False, key=remote_key, message=error_msg, error=e
            )

    def upload_tensor(
        self, tensor: torch.Tensor, remote_key: str, format: str = "PNG", **kwargs
    ) -> UploadResult:
        """Encode a tensor as an image file and upload it.

        The tensor is written to a temporary file by
        ``image_utils.tensor_to_tempfile`` and then sent through
        :meth:`upload_file`; the temporary file is removed afterwards.

        Args:
            tensor: Tensor to encode. Shape and dtype requirements are
                defined by ``tensor_to_tempfile``, not here.
            remote_key: Object key to write in the bucket.
            format: Image format name (``"PNG"``, ``"JPEG"``, ...). ``"JPG"``
                and ``"JPEG"`` both map to the ``image/jpeg`` content type.
            **kwargs: Extra arguments forwarded to :meth:`upload_file`.

        Returns:
            UploadResult: Result of the underlying file upload, or a failed
            result if encoding or upload raised. This method does not raise.
        """
        try:
            from ...image_utils import tensor_to_tempfile

            loguru.logger.info(
                f"开始上传张量到S3: s3://{self.bucket_name}/{remote_key} "
                f"(形状: {tensor.shape}, 格式: {format})"
            )

            # Materialise the tensor as a temporary image file.
            temp_file = tensor_to_tempfile(tensor, format=format)
            temp_path = temp_file.name

            try:
                # "image/jpg" is not a registered MIME type; normalise both
                # spellings to "image/jpeg".
                image_format = format.lower()
                if image_format in ("jpg", "jpeg"):
                    content_type = "image/jpeg"
                else:
                    content_type = f"image/{image_format}"

                result = self.upload_file(
                    temp_path, remote_key, content_type=content_type, **kwargs
                )

                if result.success:
                    loguru.logger.info(
                        f"张量上传成功: s3://{self.bucket_name}/{remote_key}"
                    )

                return result

            finally:
                self._discard_temp_file(temp_file, temp_path)

        except Exception as e:
            error_msg = f"S3张量上传失败: {str(e)}"
            # The argument may not be a tensor at all; avoid raising from
            # inside the handler.
            shape = getattr(tensor, "shape", None)
            loguru.logger.error(f"{error_msg} (键: {remote_key}, 张量形状: {shape})")
            return UploadResult(
                success=False, key=remote_key, message=error_msg, error=e
            )

    def download_file(
        self, remote_key: str, local_path: str, **kwargs
    ) -> DownloadResult:
        """Download an object to a local file.

        Args:
            remote_key: Object key to read from the bucket.
            local_path: Destination path. Missing parent directories are
                created; a bare filename (no directory part) is allowed.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            DownloadResult: ``success=True`` with ``local_path`` on success;
            ``success=False`` with the error otherwise. This method does not
            raise.
        """
        try:
            loguru.logger.info(
                f"开始从S3下载文件: s3://{self.bucket_name}/{remote_key} -> {local_path}"
            )

            # Check the object first so a missing key does not leave empty
            # directories behind.
            if not self.file_exists(remote_key):
                error_msg = f"S3文件不存在: s3://{self.bucket_name}/{remote_key}"
                loguru.logger.error(error_msg)
                return DownloadResult(
                    success=False, message=error_msg, error=FileNotFoundError(error_msg)
                )

            # os.makedirs("") raises, so only create a directory when the
            # path actually has one.
            target_dir = os.path.dirname(local_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            self.client.download_file(self.bucket_name, remote_key, local_path)

            success_msg = f"文件下载成功: {local_path}"
            loguru.logger.info(success_msg)

            return DownloadResult(
                success=True, local_path=local_path, message=success_msg
            )

        except Exception as e:
            error_msg = f"S3文件下载失败: {str(e)}"
            loguru.logger.error(
                f"{error_msg} (键: {remote_key}, 本地路径: {local_path})"
            )
            return DownloadResult(success=False, message=error_msg, error=e)

    def download_bytes(self, remote_key: str, **kwargs) -> DownloadResult:
        """Download an object into memory.

        Args:
            remote_key: Object key to read from the bucket.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            DownloadResult: ``success=True`` with the payload in ``data`` on
            success; ``success=False`` with the error otherwise. This method
            does not raise.
        """
        try:
            loguru.logger.info(
                f"开始从S3下载字节数据: s3://{self.bucket_name}/{remote_key}"
            )

            if not self.file_exists(remote_key):
                error_msg = f"S3文件不存在: s3://{self.bucket_name}/{remote_key}"
                loguru.logger.error(error_msg)
                return DownloadResult(
                    success=False, message=error_msg, error=FileNotFoundError(error_msg)
                )

            response = self.client.get_object(Bucket=self.bucket_name, Key=remote_key)
            body = response["Body"]
            try:
                data = body.read()
            finally:
                # Release the underlying HTTP connection promptly.
                body.close()

            success_msg = f"字节数据下载成功: {len(data)} bytes"
            loguru.logger.info(success_msg)

            return DownloadResult(success=True, data=data, message=success_msg)

        except Exception as e:
            error_msg = f"S3字节数据下载失败: {str(e)}"
            loguru.logger.error(f"{error_msg} (键: {remote_key})")
            return DownloadResult(success=False, message=error_msg, error=e)

    def delete_file(self, remote_key: str, **kwargs) -> bool:
        """Delete an object from the bucket.

        Args:
            remote_key: Object key to delete.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            bool: ``True`` if ``delete_object`` completed without raising,
            ``False`` otherwise (the error is logged).
        """
        try:
            loguru.logger.info(f"删除S3文件: s3://{self.bucket_name}/{remote_key}")

            self.client.delete_object(Bucket=self.bucket_name, Key=remote_key)

            loguru.logger.info(f"文件删除成功: s3://{self.bucket_name}/{remote_key}")
            return True

        except Exception as e:
            loguru.logger.error(f"S3文件删除失败: {e} (键: {remote_key})")
            return False

    def file_exists(self, remote_key: str, **kwargs) -> bool:
        """Report whether an object exists, using a HEAD request.

        Args:
            remote_key: Object key to probe.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            bool: ``True`` if ``head_object`` succeeds, ``False`` on any
            error. Errors other than "not found" (credentials, network,
            permissions) are logged as warnings so they are not mistaken for
            a missing object.
        """
        try:
            self.client.head_object(Bucket=self.bucket_name, Key=remote_key)
            return True
        except Exception as e:
            # botocore client errors carry the service error code in
            # ``e.response["Error"]["Code"]``; other exceptions have none.
            response = getattr(e, "response", None)
            error_info = response.get("Error", {}) if isinstance(response, dict) else {}
            code = str(error_info.get("Code", ""))
            if code not in self._NOT_FOUND_CODES:
                loguru.logger.warning(
                    f"S3文件存在性检查失败: {e} (键: {remote_key})"
                )
            return False

    def list_files(
        self, prefix: str = "", max_keys: int = 1000, **kwargs
    ) -> List[Dict[str, Any]]:
        """List objects under a key prefix.

        Requests are paginated so that ``max_keys`` values above the
        per-response limit of 1000 are honoured.

        Args:
            prefix: Only keys starting with this prefix are returned.
            max_keys: Upper bound on the number of entries returned.
            **kwargs: Additional keyword arguments forwarded to
                ``list_objects_v2``.

        Returns:
            List[Dict[str, Any]]: One dict per object with the keys ``key``,
            ``size``, ``last_modified``, ``etag`` (quotes stripped) and
            ``storage_class``. An empty list is returned on error (the error
            is logged).
        """
        try:
            loguru.logger.info(
                f"列出S3文件: s3://{self.bucket_name}/{prefix} (最大: {max_keys})"
            )

            files: List[Dict[str, Any]] = []
            remaining = max_keys
            continuation: Dict[str, Any] = {}

            while remaining > 0:
                response = self.client.list_objects_v2(
                    Bucket=self.bucket_name,
                    Prefix=prefix,
                    MaxKeys=min(remaining, self._LIST_PAGE_LIMIT),
                    # A token from the previous page overrides any
                    # caller-supplied one.
                    **{**kwargs, **continuation},
                )

                files.extend(
                    {
                        "key": obj["Key"],
                        "size": obj["Size"],
                        "last_modified": obj["LastModified"],
                        "etag": obj.get("ETag", "").strip('"'),
                        "storage_class": obj.get("StorageClass", "STANDARD"),
                    }
                    for obj in response.get("Contents", [])[:remaining]
                )
                remaining = max_keys - len(files)

                next_token = response.get("NextContinuationToken")
                if not response.get("IsTruncated") or not next_token:
                    break
                continuation = {"ContinuationToken": next_token}

            loguru.logger.info(f"找到 {len(files)} 个文件")
            return files

        except Exception as e:
            loguru.logger.error(f"S3文件列表获取失败: {e} (前缀: {prefix})")
            return []

    def get_file_url(self, remote_key: str, expires_in: int = 3600, **kwargs) -> str:
        """Return a URL for reading an object.

        Args:
            remote_key: Object key to link to.
            expires_in: Requested lifetime in seconds. Values above 3600
                return the CDN URL (when a CDN base URL is configured)
                instead of a pre-signed URL.
            **kwargs: Additional keyword arguments forwarded to
                ``generate_presigned_url``.

        Returns:
            str: The CDN URL or a pre-signed S3 URL.

        Raises:
            Exception: Re-raised from URL generation only when no CDN base
                URL is configured to fall back to.
        """
        try:
            # Long-lived links are served through the CDN.
            if self.cdn_base_url and expires_in > 3600:
                return self._cdn_url(remote_key)

            return self.client.generate_presigned_url(
                "get_object",
                Params={"Bucket": self.bucket_name, "Key": remote_key},
                ExpiresIn=expires_in,
                **kwargs,
            )

        except Exception as e:
            loguru.logger.error(f"S3 URL生成失败: {e} (键: {remote_key})")
            # Fall back to the CDN URL when pre-signing fails.
            if self.cdn_base_url:
                return self._cdn_url(remote_key)
            raise
|
||
|
||
|
||
class S3StorageFactory(StorageFactory):
    """Factory that builds :class:`S3StorageProvider` instances.

    When the caller does not supply an access key, credentials are taken
    from the module-level ``config`` object.
    """

    def create_provider(self, config_dict: Dict[str, Any]) -> StorageProvider:
        """Create an S3 storage provider.

        Args:
            config_dict: S3 settings. May be ``None`` or empty, or lack
                ``access_key_id``; in those cases the missing credentials
                are filled in from the global AWS configuration while any
                other settings supplied by the caller are kept.

        Returns:
            StorageProvider: A new ``S3StorageProvider``.

        Raises:
            ValueError: If no access key was supplied and the global
                configuration has no AWS settings either.
        """
        if config_dict and config_dict.get("access_key_id"):
            return S3StorageProvider(config_dict)

        if not config.has_aws_config():
            raise ValueError("未提供有效的S3配置,且全局配置中也没有AWS配置")

        aws_config = config.get_aws_config()

        # Start from the caller's settings (tolerating None) and fill in only
        # the credentials that are absent or empty, so a blank value from the
        # caller can no longer overwrite a valid global credential.
        merged: Dict[str, Any] = dict(config_dict or {})
        for credential_key in ("access_key_id", "secret_access_key"):
            if not merged.get(credential_key):
                merged[credential_key] = aws_config[credential_key]

        return S3StorageProvider(merged)

    def get_supported_types(self) -> List[str]:
        """Return the storage type names this factory handles.

        Returns:
            List[str]: ``["s3", "aws", "amazon"]``.
        """
        return ["s3", "aws", "amazon"]
|