ComfyUI-CustomNode/utils/image_utils.py

134 lines
4.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import io
import tempfile
import torch
from PIL import Image
from torchvision import transforms
def base64_to_tensor(base64_data: str) -> torch.Tensor:
"""
""格式的图像数据转换为PyTorch张量
参数:
base64_data: 图像的Base64编码字符串
返回:
torch.Tensor: 形状为[C, H, W]的张量,取值范围为[0, 1]
"""
# 分离数据前缀和实际Base64编码部分
if ';base64,' in base64_data:
_, encoded = base64_data.split(';base64,', 1)
else:
encoded = base64_data # 假设直接提供了Base64编码部分
# 解码Base64数据
decoded_data = base64.b64decode(encoded)
# 使用PIL打开图像
image = Image.open(io.BytesIO(decoded_data))
# 转换为RGB模式处理PNG的Alpha通道和WebP格式
if image.mode != 'RGB':
image = image.convert('RGB')
# 转换为PyTorch张量
from torchvision import transforms
transform = transforms.Compose([
transforms.ToTensor() # [H, W, C] -> [C, H, W],并归一化到[0, 1]
])
tensor = transform(image)
return tensor.unsqueeze(0).permute(0, 2, 3, 1)
def tensor_to_image_bytes(tensor: torch.Tensor, format: str = 'PNG') -> bytes:
"""
将PyTorch张量转换为图像字节流
参数:
tensor: 形状为[C, H, W]的图像张量,取值范围为[0, 1]
format: 图像格式,可选'PNG''JPEG'
返回:
bytes: 图像的字节流数据
"""
if tensor.dim() == 4:
if tensor.shape[0] > 1:
print("警告:输入张量包含多个图像,仅使用第一个")
tensor = tensor[0] # 取批量中的第一张图像
tensor = tensor.permute(2, 0, 1)
# 确保张量在[0, 255]范围内
if tensor.max() <= 1.0:
tensor = tensor * 255
# 转换为PIL图像
image = transforms.ToPILImage()(tensor.byte())
# 保存为字节流
buffer = io.BytesIO()
image.save(buffer, format=format)
buffer.seek(0) # 重置指针到开始位置
return buffer.getvalue()
def tensor_to_tempfile(tensor: torch.Tensor, format: str = 'PNG',
normalize: bool = True, range=None) -> tempfile.NamedTemporaryFile:
"""
将PyTorch张量转换为图像并保存到临时文件
参数:
tensor: 输入的PyTorch张量可以是4D(BCHW)、3D(CHW)或2D(HW)
format: 图像格式,如'PNG''JPEG'
normalize: 是否对张量进行归一化处理
range: 归一化范围,元组(min, max),默认为张量的最小值和最大值
返回:
临时文件对象,关闭后会自动删除
"""
# 处理4D张量 (BCHW),只取第一个样本
if tensor.dim() == 4:
if tensor.size(0) > 1:
print(f"警告: 输入张量包含多个样本,仅使用第一个样本 ({tensor.size(0)} -> 1)")
tensor = tensor[0]
# 确保张量在CPU上
tensor = tensor.cpu()
tensor = tensor.permute(2,0,1)
# 归一化处理
if normalize:
if range is None:
min_val, max_val = tensor.min(), tensor.max()
else:
min_val, max_val = range
if max_val > min_val:
tensor = (tensor - min_val) / (max_val - min_val)
else:
tensor = torch.zeros_like(tensor)
# 转换为PIL图像
if tensor.dim() == 2: # HW格式 (灰度图)
pil_img = transforms.ToPILImage()(tensor.unsqueeze(0)) # 添加通道维度
elif tensor.dim() == 3: # CHW格式
pil_img = transforms.ToPILImage()(tensor)
else:
raise ValueError(f"不支持的张量维度: {tensor.dim()}")
# 创建临时文件
temp_file = tempfile.NamedTemporaryFile(suffix=f'.{format.lower()}', delete=False)
try:
# 保存图像到临时文件
pil_img.save(temp_file, format=format)
except Exception as e:
# 发生错误时删除临时文件
temp_file.close()
raise e
temp_file.close() # 关闭文件但不删除
return temp_file