PERF 合并读取文本文件, 读取图片同类型节点

This commit is contained in:
kyj@bowong.ai 2025-07-15 17:31:39 +08:00
parent 42f3db768c
commit 0b542b5d3f
3 changed files with 37 additions and 133 deletions

View File

@ -2,10 +2,10 @@ from .nodes.image_modal_nodes import ModalEditCustom, ModalClothesMask, ModalMid
ModalMidJourneyDescribeImage ModalMidJourneyDescribeImage
from .nodes.image_face_nodes import FaceDetect, FaceExtract from .nodes.image_face_nodes import FaceDetect, FaceExtract
from .nodes.image_gesture_nodes import JMGestureCorrect from .nodes.image_gesture_nodes import JMGestureCorrect
from .nodes.image_nodes import SaveImagePath, LoadNetImg, SaveImageWithOutput from .nodes.image_nodes import SaveImagePath, SaveImageWithOutput, LoadImgOptional
from .nodes.llm_nodes import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor, Jinja2RenderTemplate from .nodes.llm_nodes import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor, Jinja2RenderTemplate
from .nodes.object_storage_nodes import COSUpload, COSDownload, S3Download, S3Upload, S3UploadURL from .nodes.object_storage_nodes import COSUpload, COSDownload, S3Download, S3Upload, S3UploadURL
from .nodes.text_nodes import StringEmptyJudgement, LoadTextLocal, LoadTextOnline, RandomLineSelector from .nodes.text_nodes import StringEmptyJudgement, LoadText, RandomLineSelector
from .nodes.util_nodes import LogToDB, TaskIdGenerate, TraverseFolder, UnloadAllModels, VodToLocalNode, \ from .nodes.util_nodes import LogToDB, TaskIdGenerate, TraverseFolder, UnloadAllModels, VodToLocalNode, \
PlugAndPlayWebhook PlugAndPlayWebhook
from .nodes.video_lipsync_nodes import HeyGemF2F, HeyGemF2FFromFile from .nodes.video_lipsync_nodes import HeyGemF2F, HeyGemF2FFromFile
@ -28,12 +28,11 @@ NODE_CLASS_MAPPINGS = {
"StringEmptyJudgement": StringEmptyJudgement, "StringEmptyJudgement": StringEmptyJudgement,
"unloadAllModels": UnloadAllModels, "unloadAllModels": UnloadAllModels,
"TraverseFolder": TraverseFolder, "TraverseFolder": TraverseFolder,
"LoadTextCustom": LoadTextLocal, "LoadTextCustom": LoadText,
"LoadTextCustomOnline": LoadTextOnline,
"HeyGemF2F": HeyGemF2F, "HeyGemF2F": HeyGemF2F,
"HeyGemF2FFromFile": HeyGemF2FFromFile, "HeyGemF2FFromFile": HeyGemF2FFromFile,
"SaveImagePath": SaveImagePath, "SaveImagePath": SaveImagePath,
"LoadNetImg": LoadNetImg, "LoadImgCustom": LoadImgOptional,
"TaskIdGenerate": TaskIdGenerate, "TaskIdGenerate": TaskIdGenerate,
"RandomLineSelector": RandomLineSelector, "RandomLineSelector": RandomLineSelector,
"PlugAndPlayWebhook": PlugAndPlayWebhook, "PlugAndPlayWebhook": PlugAndPlayWebhook,
@ -66,12 +65,11 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"StringEmptyJudgement": "字符串是否为空", "StringEmptyJudgement": "字符串是否为空",
"unloadAllModels": "卸载所有已加载模型", "unloadAllModels": "卸载所有已加载模型",
"TraverseFolder": "遍历文件夹", "TraverseFolder": "遍历文件夹",
"LoadTextCustom": "读取文本文件(本地)", "LoadTextCustom": "读取文本文件(file_path优先)",
"LoadTextCustomOnline": "读取文本文件(线上)",
"HeyGemF2F": "HeyGem口型同步(API, 传入文件Tensor)", "HeyGemF2F": "HeyGem口型同步(API, 传入文件Tensor)",
"HeyGemF2FFromFile": "HeyGem口型同步(API, 传入文件路径)", "HeyGemF2FFromFile": "HeyGem口型同步(API, 传入文件路径)",
"SaveImagePath": "保存图片", "SaveImagePath": "保存图片",
"LoadNetImg": "加载网络图片", "LoadImgCustom": "加载图片(URL/本地, URL优先)",
"TaskIdGenerate": "TaskID生成器", "TaskIdGenerate": "TaskID生成器",
"RandomLineSelector": "随机选择一行内容", "RandomLineSelector": "随机选择一行内容",
"PlugAndPlayWebhook": "Webhook转发器", "PlugAndPlayWebhook": "Webhook转发器",

View File

@ -2,23 +2,29 @@ import os
import uuid import uuid
from io import BytesIO from io import BytesIO
import loguru
import numpy as np import numpy as np
import requests import requests
import torch import torch
from PIL import Image from PIL import Image
import folder_paths
# 定义节点类 # 定义节点类
class LoadNetImg: class LoadImgOptional:
# 定义节点输入类型 # 定义节点输入类型
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["image"])
return { return {
"required": { "required": {
"image_url": ("STRING", { "image_url": ("STRING", {
"default": "https://example.com/sample.jpg", "default": "https://example.com/sample.jpg",
"multiline": False "multiline": False
}), }),
"image": (sorted(files), {"image_upload": True})
} }
} }
@ -29,28 +35,27 @@ class LoadNetImg:
OUTPUT_NODE = False # 不允许该节点直接作为最终输出节点 OUTPUT_NODE = False # 不允许该节点直接作为最终输出节点
CATEGORY = "不忘科技-自定义节点🚩/图片" # 节点所属类别(在 ComfyUI 界面中分类) CATEGORY = "不忘科技-自定义节点🚩/图片" # 节点所属类别(在 ComfyUI 界面中分类)
def load_image_task(self, image_url): def load_image_task(self, image_url, image):
try: try:
if not image_url or not image_url.strip(): if not image_url or len(image_url.strip()) == 0 or image_url == "https://example.com/sample.jpg":
raise ValueError("需要提供图片URL") loguru.logger.info("读取本地文件")
image_path = folder_paths.get_annotated_filepath(image)
# 下载网络图片 with open(image_path, "rb") as image_file:
response = requests.get(image_url) image = image_file.read()
response.raise_for_status() # 请求失败时抛出异常 else:
image = Image.open(BytesIO(response.content)).convert("RGB") loguru.logger.info("读取线上文件")
response = requests.get(image_url)
response.raise_for_status()
image = response.content
image = Image.open(BytesIO(image)).convert("RGB")
# 按照官方格式转换图像数据 # 按照官方格式转换图像数据
# Convert to numpy array and normalize to 0-1
image_array = np.array(image).astype(np.float32) / 255.0 image_array = np.array(image).astype(np.float32) / 255.0
# Convert to torch tensor and add batch dimension image_tensor = torch.from_numpy(image_array).unsqueeze(0)
image_tensor = torch.from_numpy(image_array)[None,]
return (image_tensor,) # 返回torch张量 return (image_tensor,)
except Exception as e: except Exception as e:
print(f"Error loading image: {e}") raise Exception(f"Error loading image: {e}")
# 返回一个空的黑色图片作为错误处理
empty_image = torch.zeros((1, 512, 512, 3), dtype=torch.float32)
return (empty_image,)
class SaveImagePath: class SaveImagePath:

View File

@ -5,89 +5,16 @@ import time
import folder_paths import folder_paths
def get_allowed_dirs(): class LoadText:
return {
"input": "$input/**/*.txt",
"output": "$output/**/*.txt",
"temp": "$temp/**/*.txt"
}
def get_valid_dirs():
return get_allowed_dirs().keys()
def get_dir_from_name(name):
dirs = get_allowed_dirs()
if name not in dirs:
raise KeyError(name + " dir not found")
path = dirs[name]
path = path.replace("$input", folder_paths.get_input_directory())
path = path.replace("$output", folder_paths.get_output_directory())
path = path.replace("$temp", folder_paths.get_temp_directory())
return path
def is_child_dir(parent_path, child_path):
parent_path = os.path.abspath(parent_path)
child_path = os.path.abspath(child_path)
return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
def get_real_path(dir):
dir = dir.replace("/**/", "/")
dir = os.path.abspath(dir)
dir = os.path.split(dir)[0]
return dir
def get_file(root_dir, file):
if file == "[none]" or not file or not file.strip():
raise ValueError("No file")
root_dir = get_dir_from_name(root_dir)
root_dir = get_real_path(root_dir)
if not os.path.exists(root_dir):
os.mkdir(root_dir)
full_path = os.path.join(root_dir, file)
if not is_child_dir(root_dir, full_path):
raise ReferenceError()
return full_path
class LoadTextLocal:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["text"])
return { return {
"required": { "required": {
"root_dir": (list(get_valid_dirs()), {}), "file_path": ("STRING", {"default": "/path/to/file"}),
"file": (["[none]"], { "file": (sorted(files),),
"pysssss.binding": [{
"source": "root_dir",
"callback": [{
"type": "set",
"target": "$this.disabled",
"value": True
}, {
"type": "fetch",
"url": "/pysssss/text-file/{$source.value}",
"then": [{
"type": "set",
"target": "$this.options.values",
"value": "$result"
}, {
"type": "validate-combo"
}, {
"type": "set",
"target": "$this.disabled",
"value": False
}]
}],
}]
}),
"encoding": ("STRING", {"default": "utf-8"}), "encoding": ("STRING", {"default": "utf-8"}),
}, },
} }
@ -98,35 +25,9 @@ class LoadTextLocal:
CATEGORY = "不忘科技-自定义节点🚩/文本" CATEGORY = "不忘科技-自定义节点🚩/文本"
@classmethod def load(self, file_path, file, encoding):
def VALIDATE_INPUTS(self, root_dir, file, **kwargs): if not (file_path and len(file_path.strip()) > 0 and file_path != "/path/to/file"):
if file == "[none]" or not file or not file.strip(): file_path = folder_paths.get_annotated_filepath(file)
return True
get_file(root_dir, file)
return True
def load(self, root_dir, file, encoding):
with open(get_file(root_dir, file), "r", encoding=encoding) as f:
return (f.read(),)
class LoadTextOnline:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"file_path": ("STRING", {"default": "input/"}),
"encoding": ("STRING", {"default": "utf-8"}),
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "load"
CATEGORY = "不忘科技-自定义节点🚩/文本"
def load(self, file_path, encoding):
with open(file_path, "r", encoding=encoding) as f: with open(file_path, "r", encoding=encoding) as f:
return (f.read(),) return (f.read(),)