From a37d35b8b8e203741e512084ea4910c11cf4ca82 Mon Sep 17 00:00:00 2001 From: yp Date: Wed, 11 Jun 2025 17:59:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BF=9D=E5=AD=98=E5=9B=BE?= =?UTF-8?q?=E7=89=87=E8=8A=82=E7=82=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nodes/image.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/nodes/image.py b/nodes/image.py index 8551788..c333579 100644 --- a/nodes/image.py +++ b/nodes/image.py @@ -2,6 +2,8 @@ import os import uuid from PIL import Image import numpy as np +import torch # 添加这一行,用于类型检查 + class SaveImagePath: @classmethod def INPUT_TYPES(s): @@ -10,24 +12,36 @@ class SaveImagePath: "image_path": ("IMAGE", {"forceInput": True}), } } + RETURN_TYPES = ("STRING",) FUNCTION = "load" CATEGORY = "不忘科技-自定义节点🚩" + def load(self, image_path): + # 如果是torch.Tensor类型,转换为numpy数组 + if isinstance(image_path, torch.Tensor): + image_path = image_path.cpu().numpy() + # 确保数据类型为uint8且在0 - 255范围内 if image_path.dtype != np.uint8: image_path = np.clip(image_path, 0, 255).astype(np.uint8) + # 如果是单通道图像,转换为3通道 if len(image_path.shape) == 2: image_path = np.stack([image_path] * 3, axis=-1) + # 如果是通道优先格式 (C, H, W),转换为通道最后格式 (H, W, C) elif len(image_path.shape) == 3 and image_path.shape[0] <= 4: image_path = np.transpose(image_path, (1, 2, 0)) + pil_image = Image.fromarray(image_path) + output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output") if not os.path.exists(output_dir): os.makedirs(output_dir) + file_name = "%s.jpg" % str(uuid.uuid4()) p = os.path.join(output_dir, file_name) pil_image.save(p) - return (p,) \ No newline at end of file + + return (p,)