ComfyUI-CustomNode/ext/image.py

72 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import uuid
from PIL import Image
import numpy as np
import torch
class SaveImagePath:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image_path": ("IMAGE", {"forceInput": True}),
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "load"
CATEGORY = "不忘科技-自定义节点🚩"
def load(self, image_path):
# 如果是torch.Tensor类型转换为numpy数组
if isinstance(image_path, torch.Tensor):
image_path = image_path.cpu().numpy()
# 去除多余的维度,如果形状是(1, 1, height, width, channels)或(1, height, width, channels)等情况
while len(image_path.shape) > 3:
image_path = image_path.squeeze(0)
# 如果是通道优先格式 (C, H, W),转换为通道最后格式 (H, W, C)
if len(image_path.shape) == 3 and image_path.shape[0] <= 4:
image_path = np.transpose(image_path, (1, 2, 0))
# 如果是单通道图像转换为3通道
if len(image_path.shape) == 2:
image_path = np.stack([image_path] * 3, axis=-1)
# 数据范围和类型转换 - 这是关键修复
if image_path.dtype == np.float32 or image_path.dtype == np.float64:
# ComfyUI图像数据通常是0-1范围的浮点数
if image_path.max() <= 1.0:
# 从0-1范围转换到0-255范围
image_path = (image_path * 255.0).astype(np.uint8)
else:
# 如果已经是0-255范围直接转换类型
image_path = np.clip(image_path, 0, 255).astype(np.uint8)
elif image_path.dtype != np.uint8:
# 其他数据类型确保在0-255范围内
image_path = np.clip(image_path, 0, 255).astype(np.uint8)
pil_image = Image.fromarray(image_path)
output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
# base_dir = os.path.dirname(os.path.abspath(__file__))
# output_dir = '/root/comfy/ComfyUI/output'
if not os.path.exists(output_dir):
os.makedirs(output_dir)
file_name = "%s.png" % str(uuid.uuid4())
p = os.path.join(output_dir, file_name)
pil_image.save(p)
return (p,)
# 节点类定义结束以下是用于注册节点的字典结构通常在实际使用中由ComfyUI等框架来解析和注册
NODE_CLASS_MAPPINGS = {
"SaveImagePath": SaveImagePath
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SaveImagePath": "保存图片路径"
}