修改保存图片节点

This commit is contained in:
杨平 2025-06-11 17:59:22 +08:00
parent a06de2445c
commit a37d35b8b8
1 changed files with 15 additions and 1 deletions

View File

@ -2,6 +2,8 @@ import os
import uuid import uuid
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import torch # 添加这一行,用于类型检查
class SaveImagePath: class SaveImagePath:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -10,24 +12,36 @@ class SaveImagePath:
"image_path": ("IMAGE", {"forceInput": True}), "image_path": ("IMAGE", {"forceInput": True}),
} }
} }
RETURN_TYPES = ("STRING",) RETURN_TYPES = ("STRING",)
FUNCTION = "load" FUNCTION = "load"
CATEGORY = "不忘科技-自定义节点🚩" CATEGORY = "不忘科技-自定义节点🚩"
def load(self, image_path): def load(self, image_path):
# 如果是torch.Tensor类型转换为numpy数组
if isinstance(image_path, torch.Tensor):
image_path = image_path.cpu().numpy()
# 确保数据类型为uint8且在0 - 255范围内 # 确保数据类型为uint8且在0 - 255范围内
if image_path.dtype != np.uint8: if image_path.dtype != np.uint8:
image_path = np.clip(image_path, 0, 255).astype(np.uint8) image_path = np.clip(image_path, 0, 255).astype(np.uint8)
# 如果是单通道图像转换为3通道 # 如果是单通道图像转换为3通道
if len(image_path.shape) == 2: if len(image_path.shape) == 2:
image_path = np.stack([image_path] * 3, axis=-1) image_path = np.stack([image_path] * 3, axis=-1)
# 如果是通道优先格式 (C, H, W),转换为通道最后格式 (H, W, C) # 如果是通道优先格式 (C, H, W),转换为通道最后格式 (H, W, C)
elif len(image_path.shape) == 3 and image_path.shape[0] <= 4: elif len(image_path.shape) == 3 and image_path.shape[0] <= 4:
image_path = np.transpose(image_path, (1, 2, 0)) image_path = np.transpose(image_path, (1, 2, 0))
pil_image = Image.fromarray(image_path) pil_image = Image.fromarray(image_path)
output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output") output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
file_name = "%s.jpg" % str(uuid.uuid4()) file_name = "%s.jpg" % str(uuid.uuid4())
p = os.path.join(output_dir, file_name) p = os.path.join(output_dir, file_name)
pil_image.save(p) pil_image.save(p)
return (p,) return (p,)