134 lines
4.0 KiB
Python
134 lines
4.0 KiB
Python
import base64
|
||
import io
|
||
import tempfile
|
||
|
||
import torch
|
||
from PIL import Image
|
||
from torchvision import transforms
|
||
|
||
|
||
def base64_to_tensor(base64_data: str) -> torch.Tensor:
    """Decode a base64-encoded image into a float image tensor.

    Accepts either a data URL ("data:image/xxx;base64,<payload>") or a bare
    base64 payload. Whitespace inside the payload (e.g. line breaks inserted
    by some encoders) is ignored, and missing ``=`` padding is restored
    before decoding.

    Args:
        base64_data: data URL or raw base64 string of an image file.

    Returns:
        torch.Tensor: float tensor of shape [1, H, W, 3] (batch, height,
        width, RGB channels) with values in [0, 1].

    Raises:
        binascii.Error (a ValueError): if the payload is not valid base64.
        Errors from ``PIL.Image.open`` if the decoded bytes are not an image
        Pillow can read.
    """
    # Keep only the payload after the data-URL header, if one is present.
    if ';base64,' in base64_data:
        _, encoded = base64_data.split(';base64,', 1)
    else:
        encoded = base64_data  # assume the bare base64 payload was given

    # Remove embedded whitespace, then restore padding that some clients
    # strip; b64decode rejects input whose length is not a multiple of 4.
    encoded = ''.join(encoded.split())
    encoded += '=' * (-len(encoded) % 4)
    decoded_data = base64.b64decode(encoded)

    # convert() returns a new, fully loaded image (a copy when the source is
    # already RGB), so the source can be closed when the with-block exits.
    # Converting to RGB drops any alpha channel and expands palette /
    # greyscale images to three channels.
    with Image.open(io.BytesIO(decoded_data)) as source:
        image = source.convert('RGB')

    # ToTensor: [H, W, C] uint8 image -> [C, H, W] float in [0, 1].
    tensor = transforms.ToTensor()(image)

    # Add a batch axis and move channels last: [C, H, W] -> [1, H, W, C].
    return tensor.unsqueeze(0).permute(0, 2, 3, 1)
|
||
|
||
|
||
def tensor_to_image_bytes(tensor: torch.Tensor, format: str = 'PNG') -> bytes:
    """Encode an image tensor as an image file and return its bytes.

    A channels-last layout is expected, matching what ``base64_to_tensor``
    returns:

    * [B, H, W, C] - only the first image of the batch is encoded;
    * [H, W, C]    - a single image;
    * [H, W]       - a single greyscale image.

    The value range is inferred: if the maximum value is <= 1.0 the data is
    treated as [0, 1] and scaled by 255, otherwise it is assumed to already
    be in [0, 255]. Values are then rounded and clamped to [0, 255] before
    the uint8 conversion, so out-of-range values saturate instead of
    wrapping around.

    Args:
        tensor: image tensor as described above; it may live on any device
            and may require grad (a detached CPU copy is used).
        format: Pillow format name, e.g. 'PNG' or 'JPEG'.

    Returns:
        bytes: the encoded image file.

    Raises:
        ValueError: if the tensor (after dropping the batch axis) is not 2D
            or 3D.
    """
    if tensor.dim() == 4:
        if tensor.shape[0] > 1:
            print("警告:输入张量包含多个图像,仅使用第一个")
        tensor = tensor[0]  # keep only the first image of the batch

    # Detached float copy on the CPU: accepts CUDA / autograd tensors and
    # avoids integer overflow in the scaling below.
    tensor = tensor.detach().cpu().float()

    # ToPILImage expects channels first ([C, H, W]) or a plain [H, W] plane.
    if tensor.dim() == 3:
        tensor = tensor.permute(2, 0, 1)  # [H, W, C] -> [C, H, W]
    elif tensor.dim() != 2:
        raise ValueError(f"不支持的张量维度: {tensor.dim()}")

    # Range heuristic: a maximum <= 1 means the data is in [0, 1].
    if tensor.max() <= 1.0:
        tensor = tensor * 255

    # Round rather than truncate (254.9999 -> 255, not 254), and saturate
    # rather than wrap around (256 -> 255, -1 -> 0) when casting to uint8.
    tensor = tensor.round().clamp(0, 255).to(torch.uint8)

    image = transforms.ToPILImage()(tensor)

    buffer = io.BytesIO()
    image.save(buffer, format=format)
    return buffer.getvalue()
|
||
|
||
|
||
def tensor_to_tempfile(tensor: torch.Tensor, format: str = 'PNG',
                       normalize: bool = True, range=None) -> tempfile.NamedTemporaryFile:
    """Write an image tensor to a new temporary image file.

    Accepted layouts (channels last, matching what ``base64_to_tensor``
    returns):

    * 4D [B, H, W, C] - only the first sample is written;
    * 3D [H, W, C];
    * 2D [H, W]       - written as a single-channel (greyscale) image.

    Args:
        tensor: image tensor as described above. A detached CPU copy is
            used; the caller's tensor is not modified.
        format: Pillow format name, e.g. 'PNG' or 'JPEG'. Its lower-cased
            form is also used as the file suffix.
        normalize: if True, linearly rescale values from ``range`` to [0, 1],
            clipping values that fall outside ``range``. A degenerate range
            (max <= min) yields an all-zero (black) image. If False, the
            tensor is handed to ``ToPILImage`` unchanged.
        range: optional (min, max) pair used for normalization; defaults to
            the tensor's own minimum and maximum.

    Returns:
        The ``NamedTemporaryFile`` wrapper, already closed. It is created
        with ``delete=False``, so the file at ``.name`` persists until the
        caller removes it (e.g. ``os.unlink(f.name)``).

    Raises:
        ValueError: if the tensor (after dropping the batch axis) is not 2D
            or 3D.
        Any exception raised while saving is re-raised after the partially
        written temporary file has been deleted.
    """
    import os  # local import: only needed to clean up after a failed save

    # 4D input: keep only the first sample of the batch.
    if tensor.dim() == 4:
        if tensor.size(0) > 1:
            print(f"警告: 输入张量包含多个样本,仅使用第一个样本 ({tensor.size(0)} -> 1)")
        tensor = tensor[0]

    # Work on a detached CPU copy.
    tensor = tensor.detach().cpu()

    # Bring the tensor into the channels-first layout ToPILImage accepts.
    if tensor.dim() == 3:
        tensor = tensor.permute(2, 0, 1)  # [H, W, C] -> [C, H, W]
    elif tensor.dim() == 2:
        tensor = tensor.unsqueeze(0)  # [H, W] -> [1, H, W] (greyscale)
    else:
        raise ValueError(f"不支持的张量维度: {tensor.dim()}")

    if normalize:
        # Integer/bool tensors would underflow (uint8) or fail (bool) in the
        # subtraction below, so normalize in floating point.
        if not tensor.is_floating_point():
            tensor = tensor.float()

        if range is None:
            min_val, max_val = tensor.min(), tensor.max()
        else:
            min_val, max_val = range

        if max_val > min_val:
            # Clip so values outside an explicit range saturate instead of
            # wrapping around in ToPILImage's float -> uint8 conversion.
            tensor = ((tensor - min_val) / (max_val - min_val)).clamp(0, 1)
        else:
            tensor = torch.zeros_like(tensor)

    pil_img = transforms.ToPILImage()(tensor)

    # delete=False: the file must outlive close() so the caller can use .name.
    temp_file = tempfile.NamedTemporaryFile(suffix=f'.{format.lower()}', delete=False)
    try:
        pil_img.save(temp_file, format=format)
    except BaseException:
        # With delete=False, close() alone would leave the file on disk.
        temp_file.close()
        try:
            os.unlink(temp_file.name)
        except OSError:
            pass  # best effort: never mask the original save error
        raise

    temp_file.close()  # close but keep the file on disk
    return temp_file