import base64 import io import torch from PIL import Image from torchvision import transforms def base64_to_tensor(base64_data: str) -> torch.Tensor: """ 将"data:image/xxx;base64,xxx"格式的图像数据转换为PyTorch张量 参数: base64_data: 图像的Base64编码字符串 返回: torch.Tensor: 形状为[C, H, W]的张量,取值范围为[0, 1] """ # 分离数据前缀和实际Base64编码部分 if ';base64,' in base64_data: _, encoded = base64_data.split(';base64,', 1) else: encoded = base64_data # 假设直接提供了Base64编码部分 # 解码Base64数据 decoded_data = base64.b64decode(encoded) # 使用PIL打开图像 image = Image.open(io.BytesIO(decoded_data)) # 转换为RGB模式(处理PNG的Alpha通道和WebP格式) if image.mode != 'RGB': image = image.convert('RGB') # 转换为PyTorch张量 from torchvision import transforms transform = transforms.Compose([ transforms.ToTensor() # [H, W, C] -> [C, H, W],并归一化到[0, 1] ]) tensor = transform(image) return tensor.unsqueeze(0).permute(0, 2, 3, 1) def tensor_to_image_bytes(tensor: torch.Tensor, format: str = 'PNG') -> bytes: """ 将PyTorch张量转换为图像字节流 参数: tensor: 形状为[C, H, W]的图像张量,取值范围为[0, 1] format: 图像格式,可选'PNG'、'JPEG'等 返回: bytes: 图像的字节流数据 """ if tensor.dim() == 4: if tensor.shape[0] > 1: print("警告:输入张量包含多个图像,仅使用第一个") tensor = tensor[0] # 取批量中的第一张图像 tensor = tensor.permute(2, 0, 1) # 确保张量在[0, 255]范围内 if tensor.max() <= 1.0: tensor = tensor * 255 # 转换为PIL图像 image = transforms.ToPILImage()(tensor.byte()) # 保存为字节流 buffer = io.BytesIO() image.save(buffer, format=format) buffer.seek(0) # 重置指针到开始位置 return buffer.getvalue()