ComfyUI-CustomNode/utils/object_storage/storage_interface.py

388 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
对象存储接口抽象层
本模块定义了对象存储服务的抽象接口,支持多种云存储服务的统一操作。
采用抽象工厂模式和策略模式,确保系统的可扩展性和可维护性。
支持的存储类型:
- AWS S3
- 腾讯云COS
- 其他云存储服务(可扩展)
设计原则:
- 开闭原则:对扩展开放,对修改关闭
- 依赖倒置原则:依赖抽象而非具体实现
- 单一职责原则:每个类只负责一个职责
"""
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List
from dataclasses import dataclass
import torch
@dataclass
class UploadResult:
"""上传结果封装类"""
success: bool
key: str
url: Optional[str] = None
size: Optional[int] = None
message: Optional[str] = None
error: Optional[Exception] = None
@dataclass
class DownloadResult:
"""下载结果封装类"""
success: bool
local_path: Optional[str] = None
data: Optional[bytes] = None
message: Optional[str] = None
error: Optional[Exception] = None
class StorageProvider(ABC):
"""
对象存储提供者抽象基类
定义了所有对象存储服务必须实现的基本操作接口。
遵循里氏替换原则,所有子类都可以替换基类使用。
"""
def __init__(self, config: Dict[str, Any]):
"""
初始化存储提供者
Args:
config: 存储配置字典,包含认证信息和其他设置
"""
self.config = config
self._validate_config()
@abstractmethod
def _validate_config(self) -> None:
"""
验证配置信息的完整性和有效性
Raises:
ValueError: 配置信息缺失或无效时抛出异常
"""
pass
@abstractmethod
def upload_file(self, local_path: str, remote_key: str, **kwargs) -> UploadResult:
"""
上传本地文件到远程存储
Args:
local_path: 本地文件路径
remote_key: 远程存储中的键名
**kwargs: 额外的上传参数(如内容类型、元数据等)
Returns:
UploadResult: 上传操作结果
"""
pass
@abstractmethod
def upload_bytes(self, data: bytes, remote_key: str, **kwargs) -> UploadResult:
"""
上传字节数据到远程存储
Args:
data: 字节数据
remote_key: 远程存储中的键名
**kwargs: 额外的上传参数
Returns:
UploadResult: 上传操作结果
"""
pass
@abstractmethod
def upload_tensor(
self, tensor: torch.Tensor, remote_key: str, format: str = "PNG", **kwargs
) -> UploadResult:
"""
上传PyTorch张量作为图像到远程存储
Args:
tensor: PyTorch张量
remote_key: 远程存储中的键名
format: 图像格式PNG, JPEG等
**kwargs: 额外的上传参数
Returns:
UploadResult: 上传操作结果
"""
pass
@abstractmethod
def download_file(
self, remote_key: str, local_path: str, **kwargs
) -> DownloadResult:
"""
从远程存储下载文件到本地
Args:
remote_key: 远程存储中的键名
local_path: 本地保存路径
**kwargs: 额外的下载参数
Returns:
DownloadResult: 下载操作结果
"""
pass
@abstractmethod
def download_bytes(self, remote_key: str, **kwargs) -> DownloadResult:
"""
从远程存储下载文件为字节数据
Args:
remote_key: 远程存储中的键名
**kwargs: 额外的下载参数
Returns:
DownloadResult: 下载操作结果数据包含在data字段中
"""
pass
@abstractmethod
def delete_file(self, remote_key: str, **kwargs) -> bool:
"""
删除远程存储中的文件
Args:
remote_key: 远程存储中的键名
**kwargs: 额外的删除参数
Returns:
bool: 删除是否成功
"""
pass
@abstractmethod
def file_exists(self, remote_key: str, **kwargs) -> bool:
"""
检查远程存储中文件是否存在
Args:
remote_key: 远程存储中的键名
**kwargs: 额外的检查参数
Returns:
bool: 文件是否存在
"""
pass
@abstractmethod
def list_files(
self, prefix: str = "", max_keys: int = 1000, **kwargs
) -> List[Dict[str, Any]]:
"""
列出远程存储中的文件
Args:
prefix: 文件前缀过滤
max_keys: 最大返回数量
**kwargs: 额外的列表参数
Returns:
List[Dict[str, Any]]: 文件信息列表
"""
pass
@abstractmethod
def get_file_url(self, remote_key: str, expires_in: int = 3600, **kwargs) -> str:
"""
获取文件的访问URL
Args:
remote_key: 远程存储中的键名
expires_in: URL过期时间
**kwargs: 额外的URL生成参数
Returns:
str: 文件访问URL
"""
pass
def get_provider_name(self) -> str:
"""
获取存储提供者名称
Returns:
str: 提供者名称
"""
return self.__class__.__name__
class StorageFactory(ABC):
"""
存储提供者工厂抽象基类
采用抽象工厂模式,负责创建具体的存储提供者实例。
支持不同类型的存储服务创建。
"""
@abstractmethod
def create_provider(self, config: Dict[str, Any]) -> StorageProvider:
"""
创建存储提供者实例
Args:
config: 存储配置字典
Returns:
StorageProvider: 存储提供者实例
Raises:
ValueError: 配置无效时抛出异常
NotImplementedError: 不支持的存储类型时抛出异常
"""
pass
@abstractmethod
def get_supported_types(self) -> List[str]:
"""
获取支持的存储类型列表
Returns:
List[str]: 支持的存储类型
"""
pass
class StorageManager:
"""
存储管理器
采用策略模式,统一管理不同的存储提供者。
提供统一的存储操作接口,支持动态切换存储提供者。
"""
def __init__(self):
"""初始化存储管理器"""
self._factories: Dict[str, StorageFactory] = {}
self._providers: Dict[str, StorageProvider] = {}
self._default_provider: Optional[str] = None
def register_factory(self, storage_type: str, factory: StorageFactory) -> None:
"""
注册存储工厂
Args:
storage_type: 存储类型标识符(如:'s3', 'cos'
factory: 存储工厂实例
"""
self._factories[storage_type] = factory
def create_provider(
self,
storage_type: str,
config: Dict[str, Any],
provider_id: Optional[str] = None,
) -> StorageProvider:
"""
创建存储提供者
Args:
storage_type: 存储类型
config: 存储配置
provider_id: 提供者唯一标识符默认使用storage_type
Returns:
StorageProvider: 存储提供者实例
Raises:
ValueError: 不支持的存储类型时抛出异常
"""
if storage_type not in self._factories:
available_types = list(self._factories.keys())
raise ValueError(
f"不支持的存储类型: {storage_type}. " f"可用类型: {available_types}"
)
factory = self._factories[storage_type]
provider = factory.create_provider(config)
# 缓存提供者实例
provider_id = provider_id or storage_type
self._providers[provider_id] = provider
# 设置默认提供者
if self._default_provider is None:
self._default_provider = provider_id
return provider
def get_provider(self, provider_id: Optional[str] = None) -> StorageProvider:
"""
获取存储提供者
Args:
provider_id: 提供者标识符,默认返回默认提供者
Returns:
StorageProvider: 存储提供者实例
Raises:
ValueError: 提供者不存在时抛出异常
"""
if provider_id is None:
provider_id = self._default_provider
if provider_id is None or provider_id not in self._providers:
available_providers = list(self._providers.keys())
raise ValueError(
f"存储提供者不存在: {provider_id}. "
f"可用提供者: {available_providers}"
)
return self._providers[provider_id]
def set_default_provider(self, provider_id: str) -> None:
"""
设置默认存储提供者
Args:
provider_id: 提供者标识符
Raises:
ValueError: 提供者不存在时抛出异常
"""
if provider_id not in self._providers:
available_providers = list(self._providers.keys())
raise ValueError(
f"存储提供者不存在: {provider_id}. "
f"可用提供者: {available_providers}"
)
self._default_provider = provider_id
def get_available_providers(self) -> List[str]:
"""
获取所有可用的提供者列表
Returns:
List[str]: 提供者标识符列表
"""
return list(self._providers.keys())
def get_supported_types(self) -> List[str]:
"""
获取所有支持的存储类型
Returns:
List[str]: 存储类型列表
"""
return list(self._factories.keys())
# 全局存储管理器实例
storage_manager = StorageManager()