ComfyUI-CustomNode/utils/image_utils.py

74 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import io
import torch
from PIL import Image
from torchvision import transforms
def base64_to_tensor(base64_data: str) -> torch.Tensor:
    """Decode a Base64-encoded image into a batched, channels-last tensor.

    Args:
        base64_data: Base64 text of an encoded image file. If the string
            contains a data-URI style header ending in ``;base64,`` (for
            example ``data:image/png;base64,``), everything up to and
            including the first such marker is discarded. Otherwise the
            whole string is treated as the Base64 payload.

    Returns:
        A tensor of shape ``[1, H, W, 3]`` (batch of one, channels last).
        ``ToTensor`` scales 8-bit RGB data to floats in ``[0, 1]``.

    Raises:
        Whatever ``base64.b64decode`` or ``PIL.Image.open`` raise for a
        malformed payload or undecodable image data; nothing is caught here.
    """
    # Split off the data-URI header. ``partition`` returns an empty marker
    # when ';base64,' is absent, in which case the input is used as-is.
    _, marker, payload = base64_data.partition(';base64,')
    encoded = payload if marker else base64_data

    image = Image.open(io.BytesIO(base64.b64decode(encoded)))

    # Normalise to three channels so palette, grayscale and alpha-carrying
    # images (e.g. PNG, WebP) all yield the same tensor layout.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # ToTensor yields [C, H, W]; add the batch axis and move channels last.
    chw = transforms.ToTensor()(image)
    return chw.unsqueeze(0).permute(0, 2, 3, 1)
def tensor_to_image_bytes(tensor: torch.Tensor, format: str = 'PNG') -> bytes:
    """Encode an image tensor as image-file bytes.

    Args:
        tensor: Either a single image laid out as ``[C, H, W]`` or a batch
            laid out as ``[B, H, W, C]``. For a batch only the first image
            is encoded, and a warning is printed when ``B > 1``. Values may
            be in ``[0, 1]`` or in ``[0, 255]`` (see the scaling note below).
        format: Image format name handed to ``PIL.Image.Image.save``, for
            example ``'PNG'`` or ``'JPEG'``.

    Returns:
        The encoded image file contents.

    Raises:
        Whatever ``ToPILImage`` or ``Image.save`` raise for unsupported
        tensor shapes or format names; nothing is caught here.
    """
    if tensor.dim() == 4:
        if tensor.shape[0] > 1:
            print("警告:输入张量包含多个图像,仅使用第一个")
        # NOTE(review): the channels-last -> channels-first permute is
        # applied only to batched input, so that a 3-D input is taken as
        # [C, H, W] as documented. Confirm against callers that pass 3-D
        # tensors.
        tensor = tensor[0].permute(2, 0, 1)

    # Range heuristic: a maximum of at most 1.0 is taken to mean unit-range
    # data and is scaled up; anything larger is assumed to be 0-255 already.
    if tensor.max() <= 1.0:
        tensor = tensor * 255

    # Clamp before the uint8 cast so out-of-range values saturate instead of
    # wrapping around (or hitting undefined behaviour for negatives).
    pixels = tensor.clamp(0, 255).byte()
    image = transforms.ToPILImage()(pixels)

    # JPEG cannot store an alpha channel; drop it rather than let save() fail.
    if format.upper() == 'JPEG' and image.mode in ('RGBA', 'LA'):
        image = image.convert('RGB' if image.mode == 'RGBA' else 'L')

    buffer = io.BytesIO()
    image.save(buffer, format=format)
    # getvalue() returns the whole buffer regardless of the stream position.
    return buffer.getvalue()