""" 对象存储接口抽象层 本模块定义了对象存储服务的抽象接口,支持多种云存储服务的统一操作。 采用抽象工厂模式和策略模式,确保系统的可扩展性和可维护性。 支持的存储类型: - AWS S3 - 腾讯云COS - 其他云存储服务(可扩展) 设计原则: - 开闭原则:对扩展开放,对修改关闭 - 依赖倒置原则:依赖抽象而非具体实现 - 单一职责原则:每个类只负责一个职责 """ from abc import ABC, abstractmethod from typing import Optional, Dict, Any, List from dataclasses import dataclass import torch @dataclass class UploadResult: """上传结果封装类""" success: bool key: str url: Optional[str] = None size: Optional[int] = None message: Optional[str] = None error: Optional[Exception] = None @dataclass class DownloadResult: """下载结果封装类""" success: bool local_path: Optional[str] = None data: Optional[bytes] = None message: Optional[str] = None error: Optional[Exception] = None class StorageProvider(ABC): """ 对象存储提供者抽象基类 定义了所有对象存储服务必须实现的基本操作接口。 遵循里氏替换原则,所有子类都可以替换基类使用。 """ def __init__(self, config: Dict[str, Any]): """ 初始化存储提供者 Args: config: 存储配置字典,包含认证信息和其他设置 """ self.config = config self._validate_config() @abstractmethod def _validate_config(self) -> None: """ 验证配置信息的完整性和有效性 Raises: ValueError: 配置信息缺失或无效时抛出异常 """ pass @abstractmethod def upload_file(self, local_path: str, remote_key: str, **kwargs) -> UploadResult: """ 上传本地文件到远程存储 Args: local_path: 本地文件路径 remote_key: 远程存储中的键名 **kwargs: 额外的上传参数(如内容类型、元数据等) Returns: UploadResult: 上传操作结果 """ pass @abstractmethod def upload_bytes(self, data: bytes, remote_key: str, **kwargs) -> UploadResult: """ 上传字节数据到远程存储 Args: data: 字节数据 remote_key: 远程存储中的键名 **kwargs: 额外的上传参数 Returns: UploadResult: 上传操作结果 """ pass @abstractmethod def upload_tensor( self, tensor: torch.Tensor, remote_key: str, format: str = "PNG", **kwargs ) -> UploadResult: """ 上传PyTorch张量作为图像到远程存储 Args: tensor: PyTorch张量 remote_key: 远程存储中的键名 format: 图像格式(PNG, JPEG等) **kwargs: 额外的上传参数 Returns: UploadResult: 上传操作结果 """ pass @abstractmethod def download_file( self, remote_key: str, local_path: str, **kwargs ) -> DownloadResult: """ 从远程存储下载文件到本地 Args: remote_key: 远程存储中的键名 local_path: 本地保存路径 **kwargs: 额外的下载参数 Returns: DownloadResult: 下载操作结果 """ pass @abstractmethod def download_bytes(self, remote_key: str, **kwargs) -> DownloadResult: """ 从远程存储下载文件为字节数据 Args: remote_key: 远程存储中的键名 **kwargs: 额外的下载参数 Returns: DownloadResult: 下载操作结果,数据包含在data字段中 """ pass @abstractmethod def delete_file(self, remote_key: str, **kwargs) -> bool: """ 删除远程存储中的文件 Args: remote_key: 远程存储中的键名 **kwargs: 额外的删除参数 Returns: bool: 删除是否成功 """ pass @abstractmethod def file_exists(self, remote_key: str, **kwargs) -> bool: """ 检查远程存储中文件是否存在 Args: remote_key: 远程存储中的键名 **kwargs: 额外的检查参数 Returns: bool: 文件是否存在 """ pass @abstractmethod def list_files( self, prefix: str = "", max_keys: int = 1000, **kwargs ) -> List[Dict[str, Any]]: """ 列出远程存储中的文件 Args: prefix: 文件前缀过滤 max_keys: 最大返回数量 **kwargs: 额外的列表参数 Returns: List[Dict[str, Any]]: 文件信息列表 """ pass @abstractmethod def get_file_url(self, remote_key: str, expires_in: int = 3600, **kwargs) -> str: """ 获取文件的访问URL Args: remote_key: 远程存储中的键名 expires_in: URL过期时间(秒) **kwargs: 额外的URL生成参数 Returns: str: 文件访问URL """ pass def get_provider_name(self) -> str: """ 获取存储提供者名称 Returns: str: 提供者名称 """ return self.__class__.__name__ class StorageFactory(ABC): """ 存储提供者工厂抽象基类 采用抽象工厂模式,负责创建具体的存储提供者实例。 支持不同类型的存储服务创建。 """ @abstractmethod def create_provider(self, config: Dict[str, Any]) -> StorageProvider: """ 创建存储提供者实例 Args: config: 存储配置字典 Returns: StorageProvider: 存储提供者实例 Raises: ValueError: 配置无效时抛出异常 NotImplementedError: 不支持的存储类型时抛出异常 """ pass @abstractmethod def get_supported_types(self) -> List[str]: """ 获取支持的存储类型列表 Returns: List[str]: 支持的存储类型 """ pass class StorageManager: """ 存储管理器 采用策略模式,统一管理不同的存储提供者。 提供统一的存储操作接口,支持动态切换存储提供者。 """ def __init__(self): """初始化存储管理器""" self._factories: Dict[str, StorageFactory] = {} self._providers: Dict[str, StorageProvider] = {} self._default_provider: Optional[str] = None def register_factory(self, storage_type: str, factory: StorageFactory) -> None: """ 注册存储工厂 Args: storage_type: 存储类型标识符(如:'s3', 'cos') factory: 存储工厂实例 """ self._factories[storage_type] = factory def create_provider( self, storage_type: str, config: Dict[str, Any], provider_id: Optional[str] = None, ) -> StorageProvider: """ 创建存储提供者 Args: storage_type: 存储类型 config: 存储配置 provider_id: 提供者唯一标识符,默认使用storage_type Returns: StorageProvider: 存储提供者实例 Raises: ValueError: 不支持的存储类型时抛出异常 """ if storage_type not in self._factories: available_types = list(self._factories.keys()) raise ValueError( f"不支持的存储类型: {storage_type}. " f"可用类型: {available_types}" ) factory = self._factories[storage_type] provider = factory.create_provider(config) # 缓存提供者实例 provider_id = provider_id or storage_type self._providers[provider_id] = provider # 设置默认提供者 if self._default_provider is None: self._default_provider = provider_id return provider def get_provider(self, provider_id: Optional[str] = None) -> StorageProvider: """ 获取存储提供者 Args: provider_id: 提供者标识符,默认返回默认提供者 Returns: StorageProvider: 存储提供者实例 Raises: ValueError: 提供者不存在时抛出异常 """ if provider_id is None: provider_id = self._default_provider if provider_id is None or provider_id not in self._providers: available_providers = list(self._providers.keys()) raise ValueError( f"存储提供者不存在: {provider_id}. " f"可用提供者: {available_providers}" ) return self._providers[provider_id] def set_default_provider(self, provider_id: str) -> None: """ 设置默认存储提供者 Args: provider_id: 提供者标识符 Raises: ValueError: 提供者不存在时抛出异常 """ if provider_id not in self._providers: available_providers = list(self._providers.keys()) raise ValueError( f"存储提供者不存在: {provider_id}. " f"可用提供者: {available_providers}" ) self._default_provider = provider_id def get_available_providers(self) -> List[str]: """ 获取所有可用的提供者列表 Returns: List[str]: 提供者标识符列表 """ return list(self._providers.keys()) def get_supported_types(self) -> List[str]: """ 获取所有支持的存储类型 Returns: List[str]: 存储类型列表 """ return list(self._factories.keys()) # 全局存储管理器实例 storage_manager = StorageManager()