import base64 import io import tempfile import torch from PIL import Image from torchvision import transforms def base64_to_tensor(base64_data: str) -> torch.Tensor: """ 将"data:image/xxx;base64,xxx"格式的图像数据转换为PyTorch张量 参数: base64_data: 图像的Base64编码字符串 返回: torch.Tensor: 形状为[C, H, W]的张量,取值范围为[0, 1] """ # 分离数据前缀和实际Base64编码部分 if ';base64,' in base64_data: _, encoded = base64_data.split(';base64,', 1) else: encoded = base64_data # 假设直接提供了Base64编码部分 # 解码Base64数据 decoded_data = base64.b64decode(encoded) # 使用PIL打开图像 image = Image.open(io.BytesIO(decoded_data)) # 转换为RGB模式(处理PNG的Alpha通道和WebP格式) if image.mode != 'RGB': image = image.convert('RGB') # 转换为PyTorch张量 from torchvision import transforms transform = transforms.Compose([ transforms.ToTensor() # [H, W, C] -> [C, H, W],并归一化到[0, 1] ]) tensor = transform(image) return tensor.unsqueeze(0).permute(0, 2, 3, 1) def tensor_to_image_bytes(tensor: torch.Tensor, format: str = 'PNG') -> bytes: """ 将PyTorch张量转换为图像字节流 参数: tensor: 形状为[C, H, W]的图像张量,取值范围为[0, 1] format: 图像格式,可选'PNG'、'JPEG'等 返回: bytes: 图像的字节流数据 """ if tensor.dim() == 4: if tensor.shape[0] > 1: print("警告:输入张量包含多个图像,仅使用第一个") tensor = tensor[0] # 取批量中的第一张图像 tensor = tensor.permute(2, 0, 1) # 确保张量在[0, 255]范围内 if tensor.max() <= 1.0: tensor = tensor * 255 # 转换为PIL图像 image = transforms.ToPILImage()(tensor.byte()) # 保存为字节流 buffer = io.BytesIO() image.save(buffer, format=format) buffer.seek(0) # 重置指针到开始位置 return buffer.getvalue() def tensor_to_tempfile(tensor: torch.Tensor, format: str = 'PNG', normalize: bool = True, range=None) -> tempfile.NamedTemporaryFile: """ 将PyTorch张量转换为图像并保存到临时文件 参数: tensor: 输入的PyTorch张量,可以是4D(BCHW)、3D(CHW)或2D(HW) format: 图像格式,如'PNG'、'JPEG'等 normalize: 是否对张量进行归一化处理 range: 归一化范围,元组(min, max),默认为张量的最小值和最大值 返回: 临时文件对象,关闭后会自动删除 """ # 处理4D张量 (BCHW),只取第一个样本 if tensor.dim() == 4: if tensor.size(0) > 1: print(f"警告: 输入张量包含多个样本,仅使用第一个样本 ({tensor.size(0)} -> 1)") tensor = tensor[0] # 确保张量在CPU上 tensor = tensor.cpu() tensor = tensor.permute(2,0,1) # 归一化处理 if normalize: if range is None: min_val, max_val = tensor.min(), tensor.max() else: min_val, max_val = range if max_val > min_val: tensor = (tensor - min_val) / (max_val - min_val) else: tensor = torch.zeros_like(tensor) # 转换为PIL图像 if tensor.dim() == 2: # HW格式 (灰度图) pil_img = transforms.ToPILImage()(tensor.unsqueeze(0)) # 添加通道维度 elif tensor.dim() == 3: # CHW格式 pil_img = transforms.ToPILImage()(tensor) else: raise ValueError(f"不支持的张量维度: {tensor.dim()}") # 创建临时文件 temp_file = tempfile.NamedTemporaryFile(suffix=f'.{format.lower()}', delete=False) try: # 保存图像到临时文件 pil_img.save(temp_file, format=format) except Exception as e: # 发生错误时删除临时文件 temp_file.close() raise e temp_file.close() # 关闭文件但不删除 return temp_file