修改保存图片节点

This commit is contained in:
杨平 2025-06-11 17:54:17 +08:00
parent f52db91713
commit a06de2445c
1 changed files with 22 additions and 25 deletions

View File

@ -1,10 +1,7 @@
import os.path import os
import uuid import uuid
from PIL import Image
import torch import numpy as np
import torchvision
class SaveImagePath: class SaveImagePath:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -13,24 +10,24 @@ class SaveImagePath:
"image_path": ("IMAGE", {"forceInput": True}), "image_path": ("IMAGE", {"forceInput": True}),
} }
} }
RETURN_TYPES = ("STRING",) RETURN_TYPES = ("STRING",)
FUNCTION = "load" FUNCTION = "load"
CATEGORY = "不忘科技-自定义节点🚩" CATEGORY = "不忘科技-自定义节点🚩"
def load(self, image_path):
def load(self, image_path:torch.Tensor): # 确保数据类型为uint8且在0 - 255范围内
image_path = image_path.float() if image_path.dtype != np.uint8:
# 假设数据范围在 [0, 255],进行归一化 image_path = np.clip(image_path, 0, 255).astype(np.uint8)
image_path = image_path / 255.0 # 如果是单通道图像转换为3通道
# 检查并调整形状(这里只是示例,具体调整需根据实际情况) if len(image_path.shape) == 2:
if len(image_path.shape) == 3: image_path = np.stack([image_path] * 3, axis=-1)
image_path = image_path.unsqueeze(0) # 添加批次维度 # 如果是通道优先格式 (C, H, W),转换为通道最后格式 (H, W, C)
u = uuid.uuid4() elif len(image_path.shape) == 3 and image_path.shape[0] <= 4:
p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output", "%s.jpg" % str(u)) image_path = np.transpose(image_path, (1, 2, 0))
torchvision.utils.save_image(image_path, p) pil_image = Image.fromarray(image_path)
# u = uuid.uuid4() output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
# p = os.path.join(os.path.dirname(os.path.abspath(__file__)),"output","%s.jpg" % str(u)) if not os.path.exists(output_dir):
# torchvision.utils.save_image(image_path, p) os.makedirs(output_dir)
file_name = "%s.jpg" % str(uuid.uuid4())
p = os.path.join(output_dir, file_name)
pil_image.save(p)
return (p,) return (p,)