diff --git a/nodes/image.py b/nodes/image.py index f9d88ad..8551788 100644 --- a/nodes/image.py +++ b/nodes/image.py @@ -1,36 +1,33 @@ -import os.path +import os import uuid - -import torch -import torchvision - - +from PIL import Image +import numpy as np class SaveImagePath: @classmethod def INPUT_TYPES(s): return { "required": { - "image_path":("IMAGE", {"forceInput": True}), + "image_path": ("IMAGE", {"forceInput": True}), } } - RETURN_TYPES = ("STRING",) - FUNCTION = "load" - CATEGORY = "不忘科技-自定义节点🚩" - - def load(self, image_path:torch.Tensor): - image_path = image_path.float() - # 假设数据范围在 [0, 255],进行归一化 - image_path = image_path / 255.0 - # 检查并调整形状(这里只是示例,具体调整需根据实际情况) - if len(image_path.shape) == 3: - image_path = image_path.unsqueeze(0) # 添加批次维度 - u = uuid.uuid4() - p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output", "%s.jpg" % str(u)) - torchvision.utils.save_image(image_path, p) - # u = uuid.uuid4() - # p = os.path.join(os.path.dirname(os.path.abspath(__file__)),"output","%s.jpg" % str(u)) - # torchvision.utils.save_image(image_path, p) - return (p,) + def load(self, image_path): + # 确保数据类型为uint8且在0 - 255范围内 + if image_path.dtype != np.uint8: + image_path = np.clip(image_path, 0, 255).astype(np.uint8) + # 如果是单通道图像,转换为3通道 + if len(image_path.shape) == 2: + image_path = np.stack([image_path] * 3, axis=-1) + # 如果是通道优先格式 (C, H, W),转换为通道最后格式 (H, W, C) + elif len(image_path.shape) == 3 and image_path.shape[0] <= 4: + image_path = np.transpose(image_path, (1, 2, 0)) + pil_image = Image.fromarray(image_path) + output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + file_name = "%s.jpg" % str(uuid.uuid4()) + p = os.path.join(output_dir, file_name) + pil_image.save(p) + return (p,) \ No newline at end of file