ComfyUI-CustomNode/nodes/save_node.py

168 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import mimetypes
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from urllib.parse import unquote, urlparse

import requests
import torch
from PIL import Image

import folder_paths
class ExtSaveNode:
    """ComfyUI node that downloads files from URLs and/or saves an IMAGE batch.

    Everything is written under ComfyUI's output directory, optionally inside
    ``subdirectory``. Both outputs are ", "-joined strings. An item that failed
    appears in place as ``"Download Error: ..."`` or ``"Save Error: ..."``, so a
    single bad item does not abort the whole node.
    """

    def __init__(self):
        # Pool reused across executions of this node instance for parallel downloads.
        self.executor = ThreadPoolExecutor(max_workers=10)
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's (all optional) inputs for the ComfyUI front end."""
        return {
            "required": {},
            "optional": {
                # multiline=True only enlarges the UI text box; URLs may be
                # separated by newlines and/or spaces.
                "url_input": ("STRING", {"multiline": True, "default": ""}),
                "image_tensor_input": ("IMAGE",),
                "subdirectory": ("STRING", {"multiline": False, "default": ""}),
                "download_file_type": (["auto", "image", "video", "other"],),
                "image_file_prefix": ("STRING", {"multiline": False, "default": "ComfyUI_Image_"}),
                "image_file_format": (["png", "jpeg"],),
                "jpeg_quality": ("INT", {"default": 90, "min": 1, "max": 100}),
            },
        }

    RETURN_TYPES = ("STRING", "STRING")
    RETURN_NAMES = ("downloaded_paths", "saved_image_paths")
    FUNCTION = "process_inputs"
    CATEGORY = "不忘科技-自定义节点🚩/utils/文件保存"

    def _get_save_path(self, subdirectory: str) -> str:
        """Create (if needed) and return ``output_dir/subdirectory`` as an absolute path.

        Surrounding whitespace and leading path separators are ignored, so
        ``"/sub"`` is treated as relative to the output directory.

        Raises:
            ValueError: if the resolved path escapes the output directory
                (e.g. via ``..`` or a drive-qualified path).
        """
        relative = (subdirectory or "").strip().lstrip("/\\")
        base = os.path.abspath(self.output_dir)
        full_path = os.path.abspath(os.path.join(base, relative))
        try:
            inside = os.path.commonpath([base, full_path]) == base
        except ValueError:  # e.g. paths on different drives on Windows
            inside = False
        if not inside:
            raise ValueError(f"Refusing to save outside the output directory: {subdirectory!r}")
        os.makedirs(full_path, exist_ok=True)
        return full_path

    @staticmethod
    def _parse_urls(url_input: str) -> list:
        """Return the http(s) URLs found in ``url_input``, in input order.

        Tokens are split on any whitespace, so newline-separated, space-separated,
        or mixed lists all work. Tokens that are not http(s) URLs are dropped.
        """
        if not url_input:
            return []
        return [token for token in url_input.split() if token.startswith(("http://", "https://"))]

    @staticmethod
    def _extension_from_content_type(content_type) -> str:
        """Map a Content-Type header value to a file extension, or "" if unknown.

        Parameters such as ``; charset=utf-8`` are stripped first, because
        ``mimetypes.guess_extension`` only recognises the bare media type.
        """
        if not content_type:
            return ""
        media_type = content_type.split(";", 1)[0].strip().lower()
        return mimetypes.guess_extension(media_type) or ""

    @staticmethod
    def _open_unique(save_path: str, filename: str):
        """Atomically create a new file in ``save_path``; return ``(path, binary handle)``.

        Exclusive-create mode ("xb") guarantees that concurrent download threads
        never pick the same name and overwrite each other. On a clash, the name
        gets a millisecond timestamp suffix, then an extra counter if needed.
        Other OSErrors propagate.
        """
        name, ext = os.path.splitext(filename)
        attempt = 0
        while True:
            if attempt == 0:
                candidate = filename
            else:
                timestamp = datetime.now().strftime("_%Y%m%d%H%M%S%f")[:-3]
                counter = "" if attempt == 1 else f"_{attempt}"
                candidate = f"{name}{timestamp}{counter}{ext}"
            file_path = os.path.join(save_path, candidate)
            try:
                return file_path, open(file_path, "xb")
            except FileExistsError:
                attempt += 1

    def _download_file_threaded(self, url, save_path, file_type):
        """Download ``url`` into ``save_path``; return the saved path or an error string.

        Intended to run on ``self.executor``. ``file_type`` (the
        ``download_file_type`` input) is accepted but not currently used.

        Filename selection:
            The name comes from the percent-decoded URL path. If that has no
            extension, a random name is generated instead, using an extension
            guessed from a HEAD request's Content-Type.

        Error handling:
            Never raises. Any failure is returned as ``"Download Error: ..."``,
            and a partially written file is deleted.
        """
        file_path = None
        try:
            # Decode %XX escapes (e.g. non-ASCII names) *before* basename so an
            # encoded "%2F" cannot smuggle a directory component into the name.
            filename = os.path.basename(unquote(urlparse(url).path))
            if not filename or "." not in filename:
                try:
                    with requests.head(url, allow_redirects=True, timeout=60) as head:
                        head.raise_for_status()
                        ext = self._extension_from_content_type(head.headers.get("content-type"))
                    filename = f"downloaded_file_{os.urandom(4).hex()}{ext}"
                except requests.exceptions.RequestException as e:
                    print(f"Could not determine filename from headers for {url}: {e}")
                    filename = f"downloaded_file_{os.urandom(4).hex()}"

            with requests.get(url, stream=True, timeout=300) as response:
                response.raise_for_status()
                # Reserve the target only after the server answered OK, so HTTP
                # errors don't leave empty files behind.
                file_path, handle = self._open_unique(save_path, filename)
                print(f"Starting download of {url} to {file_path}")
                with handle:
                    for chunk in response.iter_content(chunk_size=8192):
                        handle.write(chunk)
            print(f"Finished downloading {url} to {file_path}")
            return file_path
        except Exception as e:
            print(f"Error downloading {url}: {e}")
            if file_path is not None:
                # Remove the truncated file so it is not mistaken for a good download.
                try:
                    os.remove(file_path)
                except OSError:
                    pass
            return f"Download Error: {e}"

    def _save_image_tensor(self, images: torch.Tensor, save_path: str, file_prefix: str, file_format: str,
                           jpeg_quality: int):
        """Save each image of a batch and return the written paths joined by ", ".

        Input tensor:
            Assumes ComfyUI's IMAGE layout: float values nominally in [0, 1],
            shaped (batch, height, width, channels). A single 3-D
            (height, width, channels) tensor is treated as a batch of one.
            1, 2, 3 or 4 channels are saved as L, LA, RGB or RGBA.

        Formats:
            Supported formats are "png" and "jpeg" (or "jpg"). For JPEG, alpha
            is dropped, because the format cannot store it.

        Errors:
            Per-image failures are reported in place as "Save Error: ...".
            An unsupported format yields a single "Save Error: ..." entry and
            nothing is written.
        """
        extension = str(file_format).lower()
        pil_format = {"png": "PNG", "jpeg": "JPEG", "jpg": "JPEG"}.get(extension)
        if pil_format is None:
            print(f"Unsupported image format: {file_format}")
            return f"Save Error: unsupported image format {file_format!r}"

        if images.dim() == 3:
            images = images.unsqueeze(0)

        saved_paths = []
        for i, image_tensor in enumerate(images):
            try:
                # Clip before the uint8 cast: out-of-range floats (e.g. 1.01 * 255)
                # would otherwise wrap around and produce speckled pixels.
                pixels = (image_tensor.detach().cpu().float().numpy() * 255).clip(0, 255).astype("uint8")
                if pixels.shape[2] == 1:
                    pixels = pixels.squeeze(axis=2)
                # Pillow infers the mode from the uint8 array shape:
                # (H, W) -> L, (H, W, 2) -> LA, (H, W, 3) -> RGB, (H, W, 4) -> RGBA.
                img_pil = Image.fromarray(pixels)
                if pil_format == "JPEG" and img_pil.mode not in ("L", "RGB"):
                    # Pillow cannot write images with alpha as JPEG.
                    img_pil = img_pil.convert("L" if img_pil.mode == "LA" else "RGB")

                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
                full_path = os.path.join(save_path, f"{file_prefix}{timestamp}_{i}.{extension}")
                if pil_format == "JPEG":
                    img_pil.save(full_path, format="JPEG", quality=jpeg_quality)
                else:
                    img_pil.save(full_path, format="PNG")
                saved_paths.append(full_path)
            except Exception as e:
                print(f"Error saving image tensor {i}: {e}")
                saved_paths.append(f"Save Error: {e}")
        return ", ".join(saved_paths)

    def process_inputs(self,
                       url_input: str = "",
                       image_tensor_input: torch.Tensor = None,
                       subdirectory: str = "",
                       download_file_type: str = "auto",
                       image_file_prefix: str = "ComfyUI_Image_",
                       image_file_format: str = "png",
                       jpeg_quality: int = 90):
        """Run the node: download every URL and/or save the image batch.

        Returns:
            (downloaded_paths, saved_image_paths): each a ", "-joined string, or
            "" when the corresponding input was not provided. Download results
            are listed in the same order as the URLs in ``url_input``.

        Raises:
            ValueError: if ``subdirectory`` points outside the output directory.
        """
        downloaded_paths_output = ""
        saved_image_paths_output = ""
        final_save_path = self._get_save_path(subdirectory)

        urls_to_download = self._parse_urls(url_input)
        if urls_to_download:
            print(f"Found {len(urls_to_download)} URLs to download. Saving to: {final_save_path}")
            futures = [
                self.executor.submit(self._download_file_threaded, url, final_save_path, download_file_type)
                for url in urls_to_download
            ]
            # Collect in submission order (not completion order) so each output
            # entry lines up with its input URL.
            downloaded_paths_output = ", ".join(future.result() for future in futures)

        if isinstance(image_tensor_input, torch.Tensor) and image_tensor_input.numel() > 0:
            print(f"Detected Image Tensor input, will save to: {final_save_path}")
            saved_image_paths_output = self._save_image_tensor(
                image_tensor_input,
                final_save_path,
                image_file_prefix,
                image_file_format,
                jpeg_quality,
            )
        return (downloaded_paths_output, saved_image_paths_output)
# NODE_CLASS_MAPPINGS = {
# "UniversalSaver": ExtSaveNode
# }
#
# NODE_DISPLAY_NAME_MAPPINGS = {
# "UniversalSaver": "通用文件保存"
# }