import mimetypes
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from urllib.parse import urlparse

import folder_paths
import requests
import torch
from PIL import Image

class ExtSaveNode:
    """ComfyUI node that downloads files from URLs and/or saves image
    tensors, writing everything below the ComfyUI output directory.

    Both inputs are optional.  Each group of written paths is returned as
    a single comma-separated string so it can feed downstream STRING
    sockets.
    """

    def __init__(self):
        # One shared pool per node instance for concurrent downloads.
        # NOTE(review): the pool is never shut down explicitly; worker
        # threads live until interpreter exit.
        self.executor = ThreadPoolExecutor(max_workers=10)
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's UI inputs (all optional) for ComfyUI."""
        return {
            "required": {},
            "optional": {
                # multiline=True gives a larger text box in the UI; the
                # parsing below accepts newline- or space-separated URLs.
                "url_input": ("STRING", {"multiline": True, "default": ""}),
                "image_tensor_input": ("IMAGE",),
                "subdirectory": ("STRING", {"multiline": False, "default": ""}),
                "download_file_type": (["auto", "image", "video", "other"],),
                "image_file_prefix": ("STRING", {"multiline": False, "default": "ComfyUI_Image_"}),
                "image_file_format": (["png", "jpeg"],),
                "jpeg_quality": ("INT", {"default": 90, "min": 1, "max": 100}),
            }
        }

    RETURN_TYPES = ("STRING", "STRING")
    RETURN_NAMES = ("downloaded_paths", "saved_image_paths")
    FUNCTION = "process_inputs"
    CATEGORY = "不忘科技-自定义节点🚩/utils/文件保存"

    def _get_save_path(self, subdirectory: str) -> str:
        """Create (if needed) and return ``output_dir/subdirectory``.

        NOTE(review): *subdirectory* is user-supplied and not sanitised;
        values such as ``../x`` escape the output directory — confirm
        whether that is acceptable for this deployment.
        """
        full_path = os.path.join(self.output_dir, subdirectory)
        os.makedirs(full_path, exist_ok=True)
        return full_path

    def _download_file_threaded(self, url, save_path, file_type):
        """Download *url* into *save_path* and return the written file path.

        On any failure a ``"Download Error: ..."`` string is returned
        instead of raising, so one bad URL never aborts the batch.
        *file_type* is accepted for interface compatibility but is
        currently unused.
        """
        try:
            # Prefer the filename embedded in the URL path.
            parsed_url = urlparse(url)
            filename = os.path.basename(parsed_url.path)

            if not filename or "." not in filename:
                # No usable name in the URL: derive an extension from the
                # server-reported Content-Type and invent a random name.
                try:
                    with requests.head(url, allow_redirects=True, timeout=60) as h:
                        h.raise_for_status()
                        content_type = h.headers.get('content-type')
                        # Strip parameters such as "; charset=utf-8":
                        # mimetypes.guess_extension only understands the
                        # bare "type/subtype" form.
                        if content_type:
                            content_type = content_type.split(';')[0].strip()
                        ext = mimetypes.guess_extension(content_type) if content_type else None
                        final_ext = ext if ext else ""
                        filename = f"downloaded_file_{os.urandom(4).hex()}{final_ext}"
                except requests.exceptions.RequestException as e:
                    print(f"Could not determine filename from headers for {url}: {e}")
                    filename = f"downloaded_file_{os.urandom(4).hex()}"

            file_path = os.path.join(save_path, filename)

            # Avoid clobbering an existing file: append a millisecond
            # timestamp to the stem.
            if os.path.exists(file_path):
                name, ext = os.path.splitext(filename)
                timestamp = datetime.now().strftime("_%Y%m%d%H%M%S%f")[:-3]
                filename = f"{name}{timestamp}{ext}"
                file_path = os.path.join(save_path, filename)

            print(f"Starting download of {url} to {file_path}")
            # Stream to disk in chunks so large files never sit in memory.
            with requests.get(url, stream=True, timeout=300) as r:
                r.raise_for_status()
                with open(file_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            print(f"Finished downloading {url} to {file_path}")
            return file_path
        except Exception as e:
            print(f"Error downloading {url}: {e}")
            return f"Download Error: {e}"

    def _save_image_tensor(self, images: torch.Tensor, save_path: str, file_prefix: str, file_format: str,
                           jpeg_quality: int):
        """Save a batch of image tensors; return the paths joined by ", ".

        *images* is expected in ComfyUI layout (batch, H, W, C) with float
        values in [0, 1] — TODO confirm against callers.  Per-image
        failures are reported inline as "Save Error: ..." entries instead
        of aborting the whole batch.
        """
        saved_paths = []
        for i, image_tensor in enumerate(images):
            try:
                # Clamp before the uint8 cast: out-of-range floats would
                # otherwise wrap around and corrupt the image.
                img_np = (image_tensor.cpu().numpy().clip(0.0, 1.0) * 255).astype('uint8')

                # Pick the PIL mode from the channel count
                # (1 -> grayscale, 4 -> RGBA, anything else -> RGB).
                if img_np.shape[2] == 1:
                    img_pil = Image.fromarray(img_np.squeeze(axis=2), mode='L')
                elif img_np.shape[2] == 4:
                    img_pil = Image.fromarray(img_np, mode='RGBA')
                else:
                    img_pil = Image.fromarray(img_np, mode='RGB')

                # Millisecond timestamp + batch index keeps names unique.
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
                filename = f"{file_prefix}{timestamp}_{i}.{file_format}"
                full_path = os.path.join(save_path, filename)

                if file_format == "png":
                    img_pil.save(full_path, format="PNG")
                elif file_format == "jpeg":
                    img_pil.save(full_path, format="JPEG", quality=jpeg_quality)

                saved_paths.append(full_path)
            except Exception as e:
                print(f"Error saving image tensor {i}: {e}")
                saved_paths.append(f"Save Error: {e}")

        return ", ".join(saved_paths)

    def process_inputs(self,
                       url_input: str = "",
                       image_tensor_input: torch.Tensor = None,
                       subdirectory: str = "",
                       download_file_type: str = "auto",
                       image_file_prefix: str = "ComfyUI_Image_",
                       image_file_format: str = "png",
                       jpeg_quality: int = 90):
        """Node entry point: download any URLs, save any image tensors.

        Returns a 2-tuple of comma-separated path strings
        ``(downloaded_paths, saved_image_paths)``; either may be "".
        """
        downloaded_paths_output = ""
        saved_image_paths_output = ""

        final_save_path = self._get_save_path(subdirectory)

        if url_input:
            # Any run of whitespace (newlines or spaces) separates
            # candidate URLs; non-http(s) tokens are silently ignored.
            urls_to_download = [token for token in url_input.split()
                                if token.startswith(('http://', 'https://'))]

            if urls_to_download:
                print(f"Found {len(urls_to_download)} URLs to download. Saving to: {final_save_path}")

                futures = {
                    self.executor.submit(self._download_file_threaded, url, final_save_path, download_file_type): url
                    for url in urls_to_download}

                # Iterate the dict (insertion order) rather than
                # as_completed so the output paths line up with the input
                # URL order deterministically; result() never raises here
                # because the worker catches all exceptions itself.
                downloaded_paths = [future.result() for future in futures]

                downloaded_paths_output = ", ".join(downloaded_paths)

        if image_tensor_input is not None and isinstance(image_tensor_input,
                                                         torch.Tensor) and image_tensor_input.numel() > 0:
            print(f"Detected Image Tensor input, will save to: {final_save_path}")
            saved_image_paths_output = self._save_image_tensor(
                image_tensor_input,
                final_save_path,
                image_file_prefix,
                image_file_format,
                jpeg_quality
            )

        return (downloaded_paths_output, saved_image_paths_output)


# NODE_CLASS_MAPPINGS = {
#     "UniversalSaver": ExtSaveNode
# }
#
# NODE_DISPLAY_NAME_MAPPINGS = {
#     "UniversalSaver": "通用文件保存"
# }