PERF 合并读取文本文件, 读取图片同类型节点
This commit is contained in:
parent
42f3db768c
commit
0b542b5d3f
14
__init__.py
14
__init__.py
|
|
@ -2,10 +2,10 @@ from .nodes.image_modal_nodes import ModalEditCustom, ModalClothesMask, ModalMid
|
|||
ModalMidJourneyDescribeImage
|
||||
from .nodes.image_face_nodes import FaceDetect, FaceExtract
|
||||
from .nodes.image_gesture_nodes import JMGestureCorrect
|
||||
from .nodes.image_nodes import SaveImagePath, LoadNetImg, SaveImageWithOutput
|
||||
from .nodes.image_nodes import SaveImagePath, SaveImageWithOutput, LoadImgOptional
|
||||
from .nodes.llm_nodes import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor, Jinja2RenderTemplate
|
||||
from .nodes.object_storage_nodes import COSUpload, COSDownload, S3Download, S3Upload, S3UploadURL
|
||||
from .nodes.text_nodes import StringEmptyJudgement, LoadTextLocal, LoadTextOnline, RandomLineSelector
|
||||
from .nodes.text_nodes import StringEmptyJudgement, LoadText, RandomLineSelector
|
||||
from .nodes.util_nodes import LogToDB, TaskIdGenerate, TraverseFolder, UnloadAllModels, VodToLocalNode, \
|
||||
PlugAndPlayWebhook
|
||||
from .nodes.video_lipsync_nodes import HeyGemF2F, HeyGemF2FFromFile
|
||||
|
|
@ -28,12 +28,11 @@ NODE_CLASS_MAPPINGS = {
|
|||
"StringEmptyJudgement": StringEmptyJudgement,
|
||||
"unloadAllModels": UnloadAllModels,
|
||||
"TraverseFolder": TraverseFolder,
|
||||
"LoadTextCustom": LoadTextLocal,
|
||||
"LoadTextCustomOnline": LoadTextOnline,
|
||||
"LoadTextCustom": LoadText,
|
||||
"HeyGemF2F": HeyGemF2F,
|
||||
"HeyGemF2FFromFile": HeyGemF2FFromFile,
|
||||
"SaveImagePath": SaveImagePath,
|
||||
"LoadNetImg": LoadNetImg,
|
||||
"LoadImgCustom": LoadImgOptional,
|
||||
"TaskIdGenerate": TaskIdGenerate,
|
||||
"RandomLineSelector": RandomLineSelector,
|
||||
"PlugAndPlayWebhook": PlugAndPlayWebhook,
|
||||
|
|
@ -66,12 +65,11 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||
"StringEmptyJudgement": "字符串是否为空",
|
||||
"unloadAllModels": "卸载所有已加载模型",
|
||||
"TraverseFolder": "遍历文件夹",
|
||||
"LoadTextCustom": "读取文本文件(本地)",
|
||||
"LoadTextCustomOnline": "读取文本文件(线上)",
|
||||
"LoadTextCustom": "读取文本文件(file_path优先)",
|
||||
"HeyGemF2F": "HeyGem口型同步(API, 传入文件Tensor)",
|
||||
"HeyGemF2FFromFile": "HeyGem口型同步(API, 传入文件路径)",
|
||||
"SaveImagePath": "保存图片",
|
||||
"LoadNetImg": "加载网络图片",
|
||||
"LoadImgCustom": "加载图片(URL/本地, URL优先)",
|
||||
"TaskIdGenerate": "TaskID生成器",
|
||||
"RandomLineSelector": "随机选择一行内容",
|
||||
"PlugAndPlayWebhook": "Webhook转发器",
|
||||
|
|
|
|||
|
|
@ -2,23 +2,29 @@ import os
|
|||
import uuid
|
||||
from io import BytesIO
|
||||
|
||||
import loguru
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
import folder_paths
|
||||
|
||||
|
||||
# 定义节点类
|
||||
class LoadNetImg:
|
||||
class LoadImgOptional:
|
||||
# 定义节点输入类型
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
files = folder_paths.filter_files_content_types(files, ["image"])
|
||||
return {
|
||||
"required": {
|
||||
"image_url": ("STRING", {
|
||||
"default": "https://example.com/sample.jpg",
|
||||
"multiline": False
|
||||
}),
|
||||
"image": (sorted(files), {"image_upload": True})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -29,28 +35,27 @@ class LoadNetImg:
|
|||
OUTPUT_NODE = False # 不允许该节点直接作为最终输出节点
|
||||
CATEGORY = "不忘科技-自定义节点🚩/图片" # 节点所属类别(在 ComfyUI 界面中分类)
|
||||
|
||||
def load_image_task(self, image_url):
|
||||
def load_image_task(self, image_url, image):
|
||||
try:
|
||||
if not image_url or not image_url.strip():
|
||||
raise ValueError("需要提供图片URL")
|
||||
|
||||
# 下载网络图片
|
||||
response = requests.get(image_url)
|
||||
response.raise_for_status() # 请求失败时抛出异常
|
||||
image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
if not image_url or len(image_url.strip()) == 0 or image_url == "https://example.com/sample.jpg":
|
||||
loguru.logger.info("读取本地文件")
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
with open(image_path, "rb") as image_file:
|
||||
image = image_file.read()
|
||||
else:
|
||||
loguru.logger.info("读取线上文件")
|
||||
response = requests.get(image_url)
|
||||
response.raise_for_status()
|
||||
image = response.content
|
||||
image = Image.open(BytesIO(image)).convert("RGB")
|
||||
|
||||
# 按照官方格式转换图像数据
|
||||
# Convert to numpy array and normalize to 0-1
|
||||
image_array = np.array(image).astype(np.float32) / 255.0
|
||||
# Convert to torch tensor and add batch dimension
|
||||
image_tensor = torch.from_numpy(image_array)[None,]
|
||||
image_tensor = torch.from_numpy(image_array).unsqueeze(0)
|
||||
|
||||
return (image_tensor,) # 返回torch张量
|
||||
return (image_tensor,)
|
||||
except Exception as e:
|
||||
print(f"Error loading image: {e}")
|
||||
# 返回一个空的黑色图片作为错误处理
|
||||
empty_image = torch.zeros((1, 512, 512, 3), dtype=torch.float32)
|
||||
return (empty_image,)
|
||||
raise Exception(f"Error loading image: {e}")
|
||||
|
||||
|
||||
class SaveImagePath:
|
||||
|
|
|
|||
|
|
@ -5,89 +5,16 @@ import time
|
|||
import folder_paths
|
||||
|
||||
|
||||
def get_allowed_dirs():
|
||||
return {
|
||||
"input": "$input/**/*.txt",
|
||||
"output": "$output/**/*.txt",
|
||||
"temp": "$temp/**/*.txt"
|
||||
}
|
||||
|
||||
|
||||
def get_valid_dirs():
|
||||
return get_allowed_dirs().keys()
|
||||
|
||||
|
||||
def get_dir_from_name(name):
|
||||
dirs = get_allowed_dirs()
|
||||
if name not in dirs:
|
||||
raise KeyError(name + " dir not found")
|
||||
|
||||
path = dirs[name]
|
||||
path = path.replace("$input", folder_paths.get_input_directory())
|
||||
path = path.replace("$output", folder_paths.get_output_directory())
|
||||
path = path.replace("$temp", folder_paths.get_temp_directory())
|
||||
return path
|
||||
|
||||
|
||||
def is_child_dir(parent_path, child_path):
|
||||
parent_path = os.path.abspath(parent_path)
|
||||
child_path = os.path.abspath(child_path)
|
||||
return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
|
||||
|
||||
|
||||
def get_real_path(dir):
|
||||
dir = dir.replace("/**/", "/")
|
||||
dir = os.path.abspath(dir)
|
||||
dir = os.path.split(dir)[0]
|
||||
return dir
|
||||
|
||||
|
||||
def get_file(root_dir, file):
|
||||
if file == "[none]" or not file or not file.strip():
|
||||
raise ValueError("No file")
|
||||
|
||||
root_dir = get_dir_from_name(root_dir)
|
||||
root_dir = get_real_path(root_dir)
|
||||
if not os.path.exists(root_dir):
|
||||
os.mkdir(root_dir)
|
||||
full_path = os.path.join(root_dir, file)
|
||||
|
||||
if not is_child_dir(root_dir, full_path):
|
||||
raise ReferenceError()
|
||||
|
||||
return full_path
|
||||
|
||||
|
||||
class LoadTextLocal:
|
||||
class LoadText:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
files = folder_paths.filter_files_content_types(files, ["text"])
|
||||
return {
|
||||
"required": {
|
||||
"root_dir": (list(get_valid_dirs()), {}),
|
||||
"file": (["[none]"], {
|
||||
"pysssss.binding": [{
|
||||
"source": "root_dir",
|
||||
"callback": [{
|
||||
"type": "set",
|
||||
"target": "$this.disabled",
|
||||
"value": True
|
||||
}, {
|
||||
"type": "fetch",
|
||||
"url": "/pysssss/text-file/{$source.value}",
|
||||
"then": [{
|
||||
"type": "set",
|
||||
"target": "$this.options.values",
|
||||
"value": "$result"
|
||||
}, {
|
||||
"type": "validate-combo"
|
||||
}, {
|
||||
"type": "set",
|
||||
"target": "$this.disabled",
|
||||
"value": False
|
||||
}]
|
||||
}],
|
||||
}]
|
||||
}),
|
||||
"file_path": ("STRING", {"default": "/path/to/file"}),
|
||||
"file": (sorted(files),),
|
||||
"encoding": ("STRING", {"default": "utf-8"}),
|
||||
},
|
||||
}
|
||||
|
|
@ -98,35 +25,9 @@ class LoadTextLocal:
|
|||
|
||||
CATEGORY = "不忘科技-自定义节点🚩/文本"
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(self, root_dir, file, **kwargs):
|
||||
if file == "[none]" or not file or not file.strip():
|
||||
return True
|
||||
get_file(root_dir, file)
|
||||
return True
|
||||
|
||||
def load(self, root_dir, file, encoding):
|
||||
with open(get_file(root_dir, file), "r", encoding=encoding) as f:
|
||||
return (f.read(),)
|
||||
|
||||
|
||||
class LoadTextOnline:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"file_path": ("STRING", {"default": "input/"}),
|
||||
"encoding": ("STRING", {"default": "utf-8"}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
|
||||
FUNCTION = "load"
|
||||
|
||||
CATEGORY = "不忘科技-自定义节点🚩/文本"
|
||||
|
||||
def load(self, file_path, encoding):
|
||||
def load(self, file_path, file, encoding):
|
||||
if not (file_path and len(file_path.strip()) > 0 and file_path != "/path/to/file"):
|
||||
file_path = folder_paths.get_annotated_filepath(file)
|
||||
with open(file_path, "r", encoding=encoding) as f:
|
||||
return (f.read(),)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue