""" AWS S3存储提供者实现 本模块实现了AWS S3的具体存储操作,继承自抽象存储接口。 提供完整的S3文件上传、下载、删除等功能。 特性: - 支持文件、字节数据、PyTorch张量的上传 - 支持预签名URL生成 - 完整的错误处理和重试机制 - 统一的日志记录 """ import os import boto3 from typing import Optional, Dict, Any, List import torch import loguru from ..storage_interface import ( StorageProvider, StorageFactory, UploadResult, DownloadResult, ) from ...config_utils import config class S3StorageProvider(StorageProvider): """ AWS S3存储提供者实现 实现了StorageProvider接口的所有方法,提供完整的S3存储功能。 采用懒加载模式初始化S3客户端,提高性能。 """ def __init__(self, config: Dict[str, Any]): """ 初始化S3存储提供者 Args: config: S3配置字典,必须包含access_key_id和secret_access_key """ super().__init__(config) self.bucket_name = config.get("bucket_name", "modal-media-cache") self.region = config.get("region", "us-east-1") self.cdn_base_url = config.get("cdn_base_url", "https://cdn.roasmax.cn") self._client = None def _validate_config(self) -> None: """ 验证S3配置的完整性 Raises: ValueError: 配置信息缺失时抛出异常 """ required_keys = ["access_key_id", "secret_access_key"] missing_keys = [key for key in required_keys if not self.config.get(key)] if missing_keys: raise ValueError( f"S3配置缺失必要参数: {missing_keys}. " f"请检查配置文件或环境变量" ) @property def client(self): """ 获取S3客户端实例(懒加载模式) Returns: boto3.client: S3客户端实例 """ if self._client is None: try: self._client = boto3.client( "s3", aws_access_key_id=self.config["access_key_id"], aws_secret_access_key=self.config["secret_access_key"], region_name=self.region, ) loguru.logger.info(f"S3客户端初始化成功,区域: {self.region}") except Exception as e: loguru.logger.error(f"S3客户端初始化失败: {e}") raise return self._client def upload_file( self, local_path: str, remote_key: str, content_type: Optional[str] = None, **kwargs, ) -> UploadResult: """ 上传本地文件到S3 Args: local_path: 本地文件路径 remote_key: S3中的键名 content_type: 文件内容类型 **kwargs: 额外的上传参数 Returns: UploadResult: 上传操作结果 """ try: if not os.path.exists(local_path): error_msg = f"本地文件不存在: {local_path}" loguru.logger.error(error_msg) return UploadResult( success=False, key=remote_key, message=error_msg, error=FileNotFoundError(error_msg), ) file_size = os.path.getsize(local_path) loguru.logger.info( f"开始上传文件到S3: {local_path} -> s3://{self.bucket_name}/{remote_key} " f"({file_size} bytes)" ) extra_args = kwargs.copy() if content_type: extra_args["ContentType"] = content_type self.client.upload_file( local_path, self.bucket_name, remote_key, ExtraArgs=extra_args if extra_args else None, ) url = f"{self.cdn_base_url}/{remote_key}" success_msg = f"文件上传成功: s3://{self.bucket_name}/{remote_key}" loguru.logger.info(success_msg) return UploadResult( success=True, key=remote_key, url=url, size=file_size, message=success_msg, ) except Exception as e: error_msg = f"S3文件上传失败: {str(e)}" loguru.logger.error(f"{error_msg} (文件: {local_path}, 键: {remote_key})") return UploadResult( success=False, key=remote_key, message=error_msg, error=e ) def upload_bytes( self, data: bytes, remote_key: str, content_type: Optional[str] = None, **kwargs ) -> UploadResult: """ 上传字节数据到S3 Args: data: 字节数据 remote_key: S3中的键名 content_type: 文件内容类型 **kwargs: 额外的上传参数 Returns: UploadResult: 上传操作结果 """ try: loguru.logger.info( f"开始上传字节数据到S3: s3://{self.bucket_name}/{remote_key} " f"({len(data)} bytes)" ) extra_args = kwargs.copy() if content_type: extra_args["ContentType"] = content_type self.client.put_object( Bucket=self.bucket_name, Key=remote_key, Body=data, **extra_args ) url = f"{self.cdn_base_url}/{remote_key}" success_msg = f"字节数据上传成功: s3://{self.bucket_name}/{remote_key}" loguru.logger.info(success_msg) return UploadResult( success=True, key=remote_key, url=url, size=len(data), message=success_msg, ) except Exception as e: error_msg = f"S3字节数据上传失败: {str(e)}" loguru.logger.error( f"{error_msg} (键: {remote_key}, 大小: {len(data)} bytes)" ) return UploadResult( success=False, key=remote_key, message=error_msg, error=e ) def upload_tensor( self, tensor: torch.Tensor, remote_key: str, format: str = "PNG", **kwargs ) -> UploadResult: """ 上传PyTorch张量作为图像到S3 Args: tensor: PyTorch张量 remote_key: S3中的键名 format: 图像格式(PNG, JPEG等) **kwargs: 额外的上传参数 Returns: UploadResult: 上传操作结果 """ try: from ...image_utils import tensor_to_tempfile loguru.logger.info( f"开始上传张量到S3: s3://{self.bucket_name}/{remote_key} " f"(形状: {tensor.shape}, 格式: {format})" ) # 将张量转换为临时文件 temp_file = tensor_to_tempfile(tensor, format=format) temp_path = temp_file.name try: # 设置内容类型 content_type = f"image/{format.lower()}" if format.upper() == "JPEG": content_type = "image/jpeg" # 上传临时文件 result = self.upload_file( temp_path, remote_key, content_type=content_type, **kwargs ) if result.success: loguru.logger.info( f"张量上传成功: s3://{self.bucket_name}/{remote_key}" ) return result finally: # 清理临时文件 if os.path.exists(temp_path): os.unlink(temp_path) loguru.logger.debug(f"临时文件已清理: {temp_path}") except Exception as e: error_msg = f"S3张量上传失败: {str(e)}" loguru.logger.error( f"{error_msg} (键: {remote_key}, 张量形状: {tensor.shape})" ) return UploadResult( success=False, key=remote_key, message=error_msg, error=e ) def download_file( self, remote_key: str, local_path: str, **kwargs ) -> DownloadResult: """ 从S3下载文件到本地 Args: remote_key: S3中的键名 local_path: 本地保存路径 **kwargs: 额外的下载参数 Returns: DownloadResult: 下载操作结果 """ try: loguru.logger.info( f"开始从S3下载文件: s3://{self.bucket_name}/{remote_key} -> {local_path}" ) # 确保本地目录存在 os.makedirs(os.path.dirname(local_path), exist_ok=True) # 检查文件是否存在 if not self.file_exists(remote_key): error_msg = f"S3文件不存在: s3://{self.bucket_name}/{remote_key}" loguru.logger.error(error_msg) return DownloadResult( success=False, message=error_msg, error=FileNotFoundError(error_msg) ) self.client.download_file(self.bucket_name, remote_key, local_path) success_msg = f"文件下载成功: {local_path}" loguru.logger.info(success_msg) return DownloadResult( success=True, local_path=local_path, message=success_msg ) except Exception as e: error_msg = f"S3文件下载失败: {str(e)}" loguru.logger.error( f"{error_msg} (键: {remote_key}, 本地路径: {local_path})" ) return DownloadResult(success=False, message=error_msg, error=e) def download_bytes(self, remote_key: str, **kwargs) -> DownloadResult: """ 从S3下载文件为字节数据 Args: remote_key: S3中的键名 **kwargs: 额外的下载参数 Returns: DownloadResult: 下载操作结果,数据包含在data字段中 """ try: loguru.logger.info( f"开始从S3下载字节数据: s3://{self.bucket_name}/{remote_key}" ) # 检查文件是否存在 if not self.file_exists(remote_key): error_msg = f"S3文件不存在: s3://{self.bucket_name}/{remote_key}" loguru.logger.error(error_msg) return DownloadResult( success=False, message=error_msg, error=FileNotFoundError(error_msg) ) response = self.client.get_object(Bucket=self.bucket_name, Key=remote_key) data = response["Body"].read() success_msg = f"字节数据下载成功: {len(data)} bytes" loguru.logger.info(success_msg) return DownloadResult(success=True, data=data, message=success_msg) except Exception as e: error_msg = f"S3字节数据下载失败: {str(e)}" loguru.logger.error(f"{error_msg} (键: {remote_key})") return DownloadResult(success=False, message=error_msg, error=e) def delete_file(self, remote_key: str, **kwargs) -> bool: """ 删除S3中的文件 Args: remote_key: S3中的键名 **kwargs: 额外的删除参数 Returns: bool: 删除是否成功 """ try: loguru.logger.info(f"删除S3文件: s3://{self.bucket_name}/{remote_key}") self.client.delete_object(Bucket=self.bucket_name, Key=remote_key) loguru.logger.info(f"文件删除成功: s3://{self.bucket_name}/{remote_key}") return True except Exception as e: loguru.logger.error(f"S3文件删除失败: {e} (键: {remote_key})") return False def file_exists(self, remote_key: str, **kwargs) -> bool: """ 检查S3中文件是否存在 Args: remote_key: S3中的键名 **kwargs: 额外的检查参数 Returns: bool: 文件是否存在 """ try: self.client.head_object(Bucket=self.bucket_name, Key=remote_key) return True except Exception: return False def list_files( self, prefix: str = "", max_keys: int = 1000, **kwargs ) -> List[Dict[str, Any]]: """ 列出S3中的文件 Args: prefix: 文件前缀过滤 max_keys: 最大返回数量 **kwargs: 额外的列表参数 Returns: List[Dict[str, Any]]: 文件信息列表 """ try: loguru.logger.info( f"列出S3文件: s3://{self.bucket_name}/{prefix} (最大: {max_keys})" ) response = self.client.list_objects_v2( Bucket=self.bucket_name, Prefix=prefix, MaxKeys=max_keys, **kwargs ) files = [] if "Contents" in response: for obj in response["Contents"]: files.append( { "key": obj["Key"], "size": obj["Size"], "last_modified": obj["LastModified"], "etag": obj.get("ETag", "").strip('"'), "storage_class": obj.get("StorageClass", "STANDARD"), } ) loguru.logger.info(f"找到 {len(files)} 个文件") return files except Exception as e: loguru.logger.error(f"S3文件列表获取失败: {e} (前缀: {prefix})") return [] def get_file_url(self, remote_key: str, expires_in: int = 3600, **kwargs) -> str: """ 获取S3文件的访问URL Args: remote_key: S3中的键名 expires_in: URL过期时间(秒) **kwargs: 额外的URL生成参数 Returns: str: 文件访问URL """ try: # 如果配置了CDN,优先返回CDN URL if self.cdn_base_url and expires_in > 3600: # 长期URL使用CDN return f"{self.cdn_base_url}/{remote_key}" # 生成预签名URL url = self.client.generate_presigned_url( "get_object", Params={"Bucket": self.bucket_name, "Key": remote_key}, ExpiresIn=expires_in, **kwargs, ) return url except Exception as e: loguru.logger.error(f"S3 URL生成失败: {e} (键: {remote_key})") # 如果预签名URL生成失败,返回CDN URL作为备选 if self.cdn_base_url: return f"{self.cdn_base_url}/{remote_key}" raise class S3StorageFactory(StorageFactory): """ S3存储工厂实现 负责创建S3存储提供者实例,处理配置验证和初始化。 """ def create_provider(self, config_dict: Dict[str, Any]) -> StorageProvider: """ 创建S3存储提供者实例 Args: config_dict: S3配置字典 Returns: StorageProvider: S3存储提供者实例 """ # 如果没有提供配置,尝试从全局配置获取 if not config_dict or not config_dict.get("access_key_id"): if config.has_aws_config(): aws_config = config.get_aws_config() config_dict = { "access_key_id": aws_config["access_key_id"], "secret_access_key": aws_config["secret_access_key"], **config_dict, } else: raise ValueError("未提供有效的S3配置,且全局配置中也没有AWS配置") return S3StorageProvider(config_dict) def get_supported_types(self) -> List[str]: """ 获取支持的存储类型列表 Returns: List[str]: 支持的存储类型 """ return ["s3", "aws", "amazon"]