diff --git a/__init__.py b/__init__.py index 3e9a7a9..0fd7cd3 100644 --- a/__init__.py +++ b/__init__.py @@ -1,3 +1,4 @@ +from nodes.image import SaveImagePath from .nodes.heygem import HeyGemF2F, HeyGemF2FFromFile from .nodes.s3 import S3Download, S3Upload, S3UploadURL from .nodes.text import * @@ -35,6 +36,7 @@ NODE_CLASS_MAPPINGS = { "LoadTextCustomOnline": LoadTextOnline, "HeyGemF2F": HeyGemF2F, "HeyGemF2FFromFile": HeyGemF2FFromFile, + "SaveImagePath": SaveImagePath, } # A dictionary that contains the friendly/humanly readable titles for the nodes @@ -58,5 +60,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LoadTextCustom": "读取文本文件(本地)", "LoadTextCustomOnline": "读取文本文件(线上)", "HeyGemF2F": "HeyGem口型同步(API, 传入文件Tensor)", - "HeyGemF2FFromFile": "HeyGem口型同步(API, 传入文件路径)" + "HeyGemF2FFromFile": "HeyGem口型同步(API, 传入文件路径)", + "SaveImagePath": "保存图片" } diff --git a/nodes/image.py b/nodes/image.py new file mode 100644 index 0000000..2a8c73f --- /dev/null +++ b/nodes/image.py @@ -0,0 +1,27 @@ +import os.path +import uuid + +import torch +import torchvision + + +class SaveImagePath: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image_path":("IMAGE", {"forceInput": True}), + } + } + + RETURN_TYPES = ("STRING",) + + FUNCTION = "load" + + CATEGORY = "不忘科技-自定义节点🚩" + + def load(self, image_path:torch.Tensor): + u = uuid.uuid4() + p = os.path.join(os.path.dirname(os.path.abspath(__file__)),"output","%s.jpg" % str(u)) + torchvision.utils.save_image(image_path, p) + return (p,)