From 0b542b5d3f4a5492b15ed1a0f15e0be8071e4b59 Mon Sep 17 00:00:00 2001 From: "kyj@bowong.ai" Date: Tue, 15 Jul 2025 17:31:39 +0800 Subject: [PATCH] =?UTF-8?q?PERF=20=E5=90=88=E5=B9=B6=E8=AF=BB=E5=8F=96?= =?UTF-8?q?=E6=96=87=E6=9C=AC=E6=96=87=E4=BB=B6,=20=E8=AF=BB=E5=8F=96?= =?UTF-8?q?=E5=9B=BE=E7=89=87=E5=90=8C=E7=B1=BB=E5=9E=8B=E8=8A=82=E7=82=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __init__.py | 14 +++--- nodes/image_nodes.py | 39 ++++++++------- nodes/text_nodes.py | 117 ++++--------------------------------------- 3 files changed, 37 insertions(+), 133 deletions(-) diff --git a/__init__.py b/__init__.py index f8663a9..66471f0 100644 --- a/__init__.py +++ b/__init__.py @@ -2,10 +2,10 @@ from .nodes.image_modal_nodes import ModalEditCustom, ModalClothesMask, ModalMid ModalMidJourneyDescribeImage from .nodes.image_face_nodes import FaceDetect, FaceExtract from .nodes.image_gesture_nodes import JMGestureCorrect -from .nodes.image_nodes import SaveImagePath, LoadNetImg, SaveImageWithOutput +from .nodes.image_nodes import SaveImagePath, SaveImageWithOutput, LoadImgOptional from .nodes.llm_nodes import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor, Jinja2RenderTemplate from .nodes.object_storage_nodes import COSUpload, COSDownload, S3Download, S3Upload, S3UploadURL -from .nodes.text_nodes import StringEmptyJudgement, LoadTextLocal, LoadTextOnline, RandomLineSelector +from .nodes.text_nodes import StringEmptyJudgement, LoadText, RandomLineSelector from .nodes.util_nodes import LogToDB, TaskIdGenerate, TraverseFolder, UnloadAllModels, VodToLocalNode, \ PlugAndPlayWebhook from .nodes.video_lipsync_nodes import HeyGemF2F, HeyGemF2FFromFile @@ -28,12 +28,11 @@ NODE_CLASS_MAPPINGS = { "StringEmptyJudgement": StringEmptyJudgement, "unloadAllModels": UnloadAllModels, "TraverseFolder": TraverseFolder, - "LoadTextCustom": LoadTextLocal, - "LoadTextCustomOnline": LoadTextOnline, + "LoadTextCustom": LoadText, "HeyGemF2F": HeyGemF2F, "HeyGemF2FFromFile": HeyGemF2FFromFile, "SaveImagePath": SaveImagePath, - "LoadNetImg": LoadNetImg, + "LoadImgCustom": LoadImgOptional, "TaskIdGenerate": TaskIdGenerate, "RandomLineSelector": RandomLineSelector, "PlugAndPlayWebhook": PlugAndPlayWebhook, @@ -66,12 +65,11 @@ NODE_DISPLAY_NAME_MAPPINGS = { "StringEmptyJudgement": "字符串是否为空", "unloadAllModels": "卸载所有已加载模型", "TraverseFolder": "遍历文件夹", - "LoadTextCustom": "读取文本文件(本地)", - "LoadTextCustomOnline": "读取文本文件(线上)", + "LoadTextCustom": "读取文本文件(file_path优先)", "HeyGemF2F": "HeyGem口型同步(API, 传入文件Tensor)", "HeyGemF2FFromFile": "HeyGem口型同步(API, 传入文件路径)", "SaveImagePath": "保存图片", - "LoadNetImg": "加载网络图片", + "LoadImgCustom": "加载图片(URL/本地, URL优先)", "TaskIdGenerate": "TaskID生成器", "RandomLineSelector": "随机选择一行内容", "PlugAndPlayWebhook": "Webhook转发器", diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 5f41b4c..68aa7c7 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -2,23 +2,29 @@ import os import uuid from io import BytesIO +import loguru import numpy as np import requests import torch from PIL import Image +import folder_paths # 定义节点类 -class LoadNetImg: +class LoadImgOptional: # 定义节点输入类型 @classmethod def INPUT_TYPES(cls): + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = folder_paths.filter_files_content_types(files, ["image"]) return { "required": { "image_url": ("STRING", { "default": "https://example.com/sample.jpg", "multiline": False }), + "image": (sorted(files), {"image_upload": True}) } } @@ -29,28 +35,27 @@ class LoadNetImg: OUTPUT_NODE = False # 不允许该节点直接作为最终输出节点 CATEGORY = "不忘科技-自定义节点🚩/图片" # 节点所属类别(在 ComfyUI 界面中分类) - def load_image_task(self, image_url): + def load_image_task(self, image_url, image): try: - if not image_url or not image_url.strip(): - raise ValueError("需要提供图片URL") - - # 下载网络图片 - response = requests.get(image_url) - response.raise_for_status() # 请求失败时抛出异常 - image = Image.open(BytesIO(response.content)).convert("RGB") + if not image_url or len(image_url.strip()) == 0 or image_url == "https://example.com/sample.jpg": + loguru.logger.info("读取本地文件") + image_path = folder_paths.get_annotated_filepath(image) + with open(image_path, "rb") as image_file: + image = image_file.read() + else: + loguru.logger.info("读取线上文件") + response = requests.get(image_url) + response.raise_for_status() + image = response.content + image = Image.open(BytesIO(image)).convert("RGB") # 按照官方格式转换图像数据 - # Convert to numpy array and normalize to 0-1 image_array = np.array(image).astype(np.float32) / 255.0 - # Convert to torch tensor and add batch dimension - image_tensor = torch.from_numpy(image_array)[None,] + image_tensor = torch.from_numpy(image_array).unsqueeze(0) - return (image_tensor,) # 返回torch张量 + return (image_tensor,) except Exception as e: - print(f"Error loading image: {e}") - # 返回一个空的黑色图片作为错误处理 - empty_image = torch.zeros((1, 512, 512, 3), dtype=torch.float32) - return (empty_image,) + raise Exception(f"Error loading image: {e}") class SaveImagePath: diff --git a/nodes/text_nodes.py b/nodes/text_nodes.py index 6875c18..d949712 100644 --- a/nodes/text_nodes.py +++ b/nodes/text_nodes.py @@ -5,89 +5,16 @@ import time import folder_paths -def get_allowed_dirs(): - return { - "input": "$input/**/*.txt", - "output": "$output/**/*.txt", - "temp": "$temp/**/*.txt" - } - - -def get_valid_dirs(): - return get_allowed_dirs().keys() - - -def get_dir_from_name(name): - dirs = get_allowed_dirs() - if name not in dirs: - raise KeyError(name + " dir not found") - - path = dirs[name] - path = path.replace("$input", folder_paths.get_input_directory()) - path = path.replace("$output", folder_paths.get_output_directory()) - path = path.replace("$temp", folder_paths.get_temp_directory()) - return path - - -def is_child_dir(parent_path, child_path): - parent_path = os.path.abspath(parent_path) - child_path = os.path.abspath(child_path) - return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path]) - - -def get_real_path(dir): - dir = dir.replace("/**/", "/") - dir = os.path.abspath(dir) - dir = os.path.split(dir)[0] - return dir - - -def get_file(root_dir, file): - if file == "[none]" or not file or not file.strip(): - raise ValueError("No file") - - root_dir = get_dir_from_name(root_dir) - root_dir = get_real_path(root_dir) - if not os.path.exists(root_dir): - os.mkdir(root_dir) - full_path = os.path.join(root_dir, file) - - if not is_child_dir(root_dir, full_path): - raise ReferenceError() - - return full_path - - -class LoadTextLocal: +class LoadText: @classmethod def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = folder_paths.filter_files_content_types(files, ["text"]) return { "required": { - "root_dir": (list(get_valid_dirs()), {}), - "file": (["[none]"], { - "pysssss.binding": [{ - "source": "root_dir", - "callback": [{ - "type": "set", - "target": "$this.disabled", - "value": True - }, { - "type": "fetch", - "url": "/pysssss/text-file/{$source.value}", - "then": [{ - "type": "set", - "target": "$this.options.values", - "value": "$result" - }, { - "type": "validate-combo" - }, { - "type": "set", - "target": "$this.disabled", - "value": False - }] - }], - }] - }), + "file_path": ("STRING", {"default": "/path/to/file"}), + "file": (sorted(files),), "encoding": ("STRING", {"default": "utf-8"}), }, } @@ -98,35 +25,9 @@ class LoadTextLocal: CATEGORY = "不忘科技-自定义节点🚩/文本" - @classmethod - def VALIDATE_INPUTS(self, root_dir, file, **kwargs): - if file == "[none]" or not file or not file.strip(): - return True - get_file(root_dir, file) - return True - - def load(self, root_dir, file, encoding): - with open(get_file(root_dir, file), "r", encoding=encoding) as f: - return (f.read(),) - - -class LoadTextOnline: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "file_path": ("STRING", {"default": "input/"}), - "encoding": ("STRING", {"default": "utf-8"}), - } - } - - RETURN_TYPES = ("STRING",) - - FUNCTION = "load" - - CATEGORY = "不忘科技-自定义节点🚩/文本" - - def load(self, file_path, encoding): + def load(self, file_path, file, encoding): + if not (file_path and len(file_path.strip()) > 0 and file_path != "/path/to/file"): + file_path = folder_paths.get_annotated_filepath(file) with open(file_path, "r", encoding=encoding) as f: return (f.read(),)