# Image <-> PyTorch tensor conversion utilities.

import base64
import io

import torch
from PIL import Image
from torchvision import transforms
def base64_to_tensor(base64_data: str) -> torch.Tensor:
    """Convert a "data:image/xxx;base64,xxx" image string to a PyTorch tensor.

    Args:
        base64_data: Base64-encoded image, with or without the data-URI
            prefix ("data:image/png;base64,...").

    Returns:
        torch.Tensor: shape [1, H, W, C] (batch-first, channels-last),
        values in [0, 1].  NOTE: this NHWC layout is what
        tensor_to_image_bytes expects for 4-D input.
    """
    # Strip the data-URI prefix if present; keep only the base64 payload.
    if ';base64,' in base64_data:
        _, encoded = base64_data.split(';base64,', 1)
    else:
        encoded = base64_data  # assume the raw base64 payload was passed

    # Decode the base64 data and open it as an image.
    decoded_data = base64.b64decode(encoded)
    image = Image.open(io.BytesIO(decoded_data))

    # Normalize to RGB (drops PNG alpha channel, handles palette/WebP modes).
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # ToTensor: PIL [H, W, C] uint8 -> float tensor [C, H, W] in [0, 1].
    tensor = transforms.ToTensor()(image)

    # Add a batch dimension and move channels last: [C, H, W] -> [1, H, W, C].
    return tensor.unsqueeze(0).permute(0, 2, 3, 1)
def tensor_to_image_bytes(tensor: torch.Tensor, format: str = 'PNG') -> bytes:
    """Serialize an image tensor to encoded image bytes.

    Args:
        tensor: Either a [C, H, W] channels-first image, or a 4-D batch
            [N, H, W, C] channels-last (the layout produced by
            base64_to_tensor); values in [0, 1] or already in [0, 255].
        format: Output image format understood by PIL ('PNG', 'JPEG', ...).

    Returns:
        bytes: The encoded image data.
    """
    if tensor.dim() == 4:
        if tensor.shape[0] > 1:
            print("警告:输入张量包含多个图像,仅使用第一个")
        tensor = tensor[0]                # first image of the batch, [H, W, C]
        tensor = tensor.permute(2, 0, 1)  # -> [C, H, W] for ToPILImage

    # Scale [0, 1] floats up to [0, 255]; values above 1 are assumed to be
    # in the byte range already.
    if tensor.max() <= 1.0:
        tensor = tensor * 255

    # Round (instead of truncating in .byte()) and clamp before the uint8
    # cast, so e.g. 0.999 * 255 = 254.745 encodes as 255, not 254.
    tensor = tensor.round().clamp(0, 255)

    # Convert to a PIL image and encode into an in-memory buffer.
    image = transforms.ToPILImage()(tensor.byte())
    buffer = io.BytesIO()
    image.save(buffer, format=format)
    return buffer.getvalue()