388 lines
10 KiB
Python
388 lines
10 KiB
Python
"""
对象存储接口抽象层

本模块定义了对象存储服务的抽象接口，支持多种云存储服务的统一操作。
采用抽象工厂模式和策略模式，确保系统的可扩展性和可维护性。

支持的存储类型：
- AWS S3
- 腾讯云COS
- 其他云存储服务（可扩展）

设计原则：
- 开闭原则：对扩展开放，对修改关闭
- 依赖倒置原则：依赖抽象而非具体实现
- 单一职责原则：每个类只负责一个职责
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import Optional, Dict, Any, List
|
||
from dataclasses import dataclass
|
||
import torch
|
||
|
||
|
||
@dataclass
class UploadResult:
    """Outcome of an upload operation.

    Only ``success`` and ``key`` are required; every other field is
    optional and defaults to ``None`` when not supplied.
    """

    # True when the upload completed successfully.
    success: bool
    # Remote key (object name) the upload targeted.
    key: str
    # Access URL for the uploaded object, when one is provided.
    url: Optional[str] = None
    # Size of the uploaded payload (presumably bytes -- TODO confirm with providers).
    size: Optional[int] = None
    # Human-readable status / diagnostic message.
    message: Optional[str] = None
    # Exception captured on failure, if any.
    error: Optional[Exception] = None
|
||
|
||
|
||
@dataclass
class DownloadResult:
    """Outcome of a download operation.

    Only ``success`` is required; every other field is optional and
    defaults to ``None`` when not supplied.
    """

    # True when the download completed successfully.
    success: bool
    # Local filesystem path the object was written to, if any.
    local_path: Optional[str] = None
    # Raw downloaded content, for in-memory downloads.
    data: Optional[bytes] = None
    # Human-readable status / diagnostic message.
    message: Optional[str] = None
    # Exception captured on failure, if any.
    error: Optional[Exception] = None
|
||
|
||
|
||
class StorageProvider(ABC):
    """Abstract base class for object-storage backends.

    Declares the operations every concrete provider (e.g. AWS S3, Tencent
    Cloud COS) must implement. Callers program against this interface, so
    any subclass can stand in for the base class (Liskov substitution).
    """

    def __init__(self, config: Dict[str, Any]):
        """Keep the provider configuration and validate it right away.

        Args:
            config: Provider settings such as credentials; which keys are
                required is decided by the subclass's ``_validate_config``.

        Raises:
            ValueError: Propagated from ``_validate_config`` when the
                configuration is missing or invalid.
        """
        # Stored by reference (not copied) and assigned *before*
        # validation, because ``_validate_config`` takes no arguments and
        # can only inspect the configuration through ``self.config``.
        self.config = config
        self._validate_config()

    @abstractmethod
    def _validate_config(self) -> None:
        """Check that ``self.config`` is complete and valid.

        Invoked from ``__init__`` once ``self.config`` has been set.

        Raises:
            ValueError: If required settings are missing or invalid.
        """

    @abstractmethod
    def upload_file(self, local_path: str, remote_key: str, **kwargs) -> UploadResult:
        """Upload a local file to remote storage.

        Args:
            local_path: Path of the local file to upload.
            remote_key: Key (object name) to store it under remotely.
            **kwargs: Provider-specific upload options (e.g. content type,
                metadata).

        Returns:
            UploadResult: Outcome of the upload.
        """

    @abstractmethod
    def upload_bytes(self, data: bytes, remote_key: str, **kwargs) -> UploadResult:
        """Upload an in-memory byte payload to remote storage.

        Args:
            data: Bytes to upload.
            remote_key: Key (object name) to store them under remotely.
            **kwargs: Provider-specific upload options.

        Returns:
            UploadResult: Outcome of the upload.
        """

    @abstractmethod
    def upload_tensor(
        self, tensor: torch.Tensor, remote_key: str, format: str = "PNG", **kwargs
    ) -> UploadResult:
        """Encode a PyTorch tensor as an image and upload it.

        NOTE(review): the expected tensor layout (e.g. CHW vs. HWC), dtype
        and value range are not specified by this interface -- they are
        implementation-defined; confirm against concrete providers.

        Args:
            tensor: PyTorch tensor holding the image.
            remote_key: Key (object name) to store the image under.
            format: Image format name such as "PNG" or "JPEG".
            **kwargs: Provider-specific upload options.

        Returns:
            UploadResult: Outcome of the upload.
        """

    @abstractmethod
    def download_file(
        self, remote_key: str, local_path: str, **kwargs
    ) -> DownloadResult:
        """Download a remote object into a local file.

        Args:
            remote_key: Key (object name) of the remote object.
            local_path: Local path to save the object to.
            **kwargs: Provider-specific download options.

        Returns:
            DownloadResult: Outcome of the download.
        """

    @abstractmethod
    def download_bytes(self, remote_key: str, **kwargs) -> DownloadResult:
        """Download a remote object into memory.

        Args:
            remote_key: Key (object name) of the remote object.
            **kwargs: Provider-specific download options.

        Returns:
            DownloadResult: Outcome of the download; the content is carried
            in its ``data`` field.
        """

    @abstractmethod
    def delete_file(self, remote_key: str, **kwargs) -> bool:
        """Delete an object from remote storage.

        Args:
            remote_key: Key (object name) of the object to delete.
            **kwargs: Provider-specific delete options.

        Returns:
            bool: True if the deletion succeeded.
        """

    @abstractmethod
    def file_exists(self, remote_key: str, **kwargs) -> bool:
        """Tell whether an object exists in remote storage.

        Args:
            remote_key: Key (object name) to look up.
            **kwargs: Provider-specific lookup options.

        Returns:
            bool: True if the object exists.
        """

    @abstractmethod
    def list_files(
        self, prefix: str = "", max_keys: int = 1000, **kwargs
    ) -> List[Dict[str, Any]]:
        """List objects in remote storage.

        Args:
            prefix: Only objects whose key starts with this prefix are listed.
            max_keys: Upper bound on the number of entries returned.
            **kwargs: Provider-specific listing options.

        Returns:
            List[Dict[str, Any]]: One dict per object.
            NOTE(review): the dict keys are not fixed by this interface;
            each provider defines its own schema -- confirm before relying
            on specific keys.
        """

    @abstractmethod
    def get_file_url(self, remote_key: str, expires_in: int = 3600, **kwargs) -> str:
        """Produce an access URL for a remote object.

        Args:
            remote_key: Key (object name) of the remote object.
            expires_in: URL lifetime in seconds.
            **kwargs: Provider-specific URL options.

        Returns:
            str: The access URL.
        """

    def get_provider_name(self) -> str:
        """Return this provider's name, i.e. the name of its concrete class.

        Returns:
            str: The concrete class name.
        """
        concrete_cls = self.__class__
        return concrete_cls.__name__
|
||
|
||
|
||
class StorageFactory(ABC):
    """Abstract factory that builds concrete StorageProvider instances.

    Each concrete factory turns a configuration dict into a provider for
    the storage type(s) it supports (abstract factory pattern).
    """

    @abstractmethod
    def create_provider(self, config: Dict[str, Any]) -> StorageProvider:
        """Build a storage provider from ``config``.

        Args:
            config: Storage configuration dict.

        Returns:
            StorageProvider: The newly built provider.

        Raises:
            ValueError: If the configuration is invalid.
            NotImplementedError: If the requested storage type is not
                supported by this factory.
        """

    @abstractmethod
    def get_supported_types(self) -> List[str]:
        """Report which storage types this factory can build.

        Returns:
            List[str]: Supported storage type identifiers.
        """
|
||
|
||
|
||
class StorageManager:
    """Registry that maps storage types to factories and caches providers.

    Factories are registered per storage type (e.g. ``'s3'``, ``'cos'``).
    Providers built through :meth:`create_provider` are cached under a
    provider id. The first provider created becomes the default, until
    :meth:`set_default_provider` selects another one (strategy pattern:
    callers switch backends by switching the active provider).
    """

    def __init__(self):
        """Start with no factories, no providers and no default provider."""
        self._factories: Dict[str, StorageFactory] = {}
        self._providers: Dict[str, StorageProvider] = {}
        self._default_provider: Optional[str] = None

    def _provider_not_found(self, provider_id: Optional[str]) -> ValueError:
        """Build the error reported when ``provider_id`` is not cached."""
        return ValueError(
            f"存储提供者不存在: {provider_id}. 可用提供者: {list(self._providers)}"
        )

    def register_factory(self, storage_type: str, factory: StorageFactory) -> None:
        """Register ``factory`` as the builder for ``storage_type``.

        Registering the same type again replaces the earlier factory.

        Args:
            storage_type: Storage type identifier (e.g. 's3', 'cos').
            factory: Factory instance that builds providers of that type.
        """
        self._factories[storage_type] = factory

    def create_provider(
        self,
        storage_type: str,
        config: Dict[str, Any],
        provider_id: Optional[str] = None,
    ) -> StorageProvider:
        """Build a provider with the registered factory and cache it.

        Args:
            storage_type: Storage type whose factory should be used.
            config: Configuration passed to the factory.
            provider_id: Id to cache the provider under. A falsy value
                (None or an empty string) falls back to ``storage_type``.
                An id that is already cached gets its provider replaced.

        Returns:
            StorageProvider: The newly built provider.

        Raises:
            ValueError: If no factory is registered for ``storage_type``.
        """
        if storage_type not in self._factories:
            raise ValueError(
                f"不支持的存储类型: {storage_type}. 可用类型: {list(self._factories)}"
            )

        provider = self._factories[storage_type].create_provider(config)

        cache_key = provider_id if provider_id else storage_type
        self._providers[cache_key] = provider

        # The default is assigned only while no default exists yet.
        if self._default_provider is None:
            self._default_provider = cache_key
        return provider

    def get_provider(self, provider_id: Optional[str] = None) -> StorageProvider:
        """Look up a cached provider.

        Args:
            provider_id: Id of the provider; None selects the default one.

        Returns:
            StorageProvider: The cached provider.

        Raises:
            ValueError: If the id (or the default, when none was given) is
                not cached, including when no default has been set yet.
        """
        target = self._default_provider if provider_id is None else provider_id
        if target is None or target not in self._providers:
            raise self._provider_not_found(target)
        return self._providers[target]

    def set_default_provider(self, provider_id: str) -> None:
        """Make an already-cached provider the default.

        Args:
            provider_id: Id of a provider previously created.

        Raises:
            ValueError: If ``provider_id`` is not cached.
        """
        if provider_id not in self._providers:
            raise self._provider_not_found(provider_id)
        self._default_provider = provider_id

    def get_available_providers(self) -> List[str]:
        """Return the ids of all cached providers, in insertion order.

        Returns:
            List[str]: Provider ids.
        """
        return list(self._providers)

    def get_supported_types(self) -> List[str]:
        """Return every storage type that has a registered factory.

        Returns:
            List[str]: Storage type identifiers, in registration order.
        """
        return list(self._factories)
|
||
|
||
|
||
# Module-level (global) StorageManager instance; every importer of this
# module shares this single registry of factories and providers.
storage_manager: StorageManager = StorageManager()
|