# ComfyUI-CustomNode/nodes/file_upload.py
#
# 199 lines
# 8.3 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import mimetypes
import uuid
import boto3
import os
from botocore.config import Config
import asyncio
import torch
import numpy as np
from PIL import Image
import folder_paths
# Optional dependency: try to import scipy's WAV writer. When scipy is not
# installed, print an install hint instead of failing the module import; the
# name `wavfile` is then left undefined.
try:
    from scipy.io import wavfile
except ImportError:
    print(
        "------------------------------------------------------------------------------------",
        "[FileUploadNode] 提示: Scipy 库未安装, 如果需要处理音频输入, 请运行: pip install scipy",
        "------------------------------------------------------------------------------------",
        sep="\n",
    )
# --- AWS S3 configuration ---
# Credentials are read from the environment rather than being committed to the
# repository. When AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY are unset (or
# empty) the values are None, and boto3 then falls back to its default
# credential provider chain (shared credentials file, instance role, ...).
# NOTE: the key pair that used to be hard-coded here was exposed in version
# control and should be rotated.
aws_settings = {
    'access_key_id': os.environ.get('AWS_ACCESS_KEY_ID') or None,
    'secret_access_key': os.environ.get('AWS_SECRET_ACCESS_KEY') or None,
    # Non-secret values keep their previous defaults and can be overridden.
    'bucket_name': os.environ.get('FILE_UPLOAD_S3_BUCKET') or 'modal-media-cache',
    'region': os.environ.get('FILE_UPLOAD_S3_REGION') or 'ap-northeast-2',
    # The key keeps its existing "cnd" spelling because other code looks it up.
    'cnd_endpoint': os.environ.get('FILE_UPLOAD_CDN_ENDPOINT') or 'https://cdn.roasmax.cn',
}
# --- Core upload logic ---
async def upload_file_s3_v2(file_path: str, remove: bool = False, perpetual: bool = False):
    """Upload a local file to S3 with boto3 and return its CDN URL.

    The boto3 transfer is a blocking call, so it runs in the event loop's
    default executor; the loop stays free while the upload is in progress.

    Args:
        file_path: Path of the local file to upload.
        remove: Delete the local file after a successful upload.
        perpetual: Store the object under ``material/`` instead of ``upload/``.

    Returns:
        A dict with the keys ``status`` (bool), ``data`` (the CDN URL on
        success, otherwise ``''``) and ``msg`` (a human-readable result).
        Failures are reported through this dict instead of being raised.
    """
    resp_data = {'status': False, 'data': '', 'msg': ''}
    if not os.path.isfile(file_path):
        resp_data['msg'] = f'文件不存在: {file_path}'
        print(f"[FileUploadNode ERROR] {resp_data['msg']}")
        return resp_data

    def _blocking_upload() -> str:
        """Create the S3 client, push the file, and return the object key."""
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=aws_settings['access_key_id'],
            aws_secret_access_key=aws_settings['secret_access_key'],
            region_name=aws_settings['region'],
            endpoint_url="https://s3-accelerate.amazonaws.com",
            config=Config(s3={'addressing_style': 'virtual'}, signature_version='s3v4'),
        )
        # A random object name avoids collisions; the original extension is
        # kept so the URL still reveals the media type.
        suffix = os.path.splitext(file_path)[-1]
        bucket_suffix = 'material/' if perpetual else 'upload/'
        s3_key = f"{bucket_suffix}{uuid.uuid4().hex}{suffix}"
        mime_type, _ = mimetypes.guess_type(file_path)
        extra_args = {'ContentType': mime_type if mime_type else 'application/octet-stream'}
        print(
            f"[FileUploadNode INFO] 开始上传文件 {os.path.basename(file_path)} 到 S3 bucket '{aws_settings['bucket_name']}'...")
        s3_client.upload_file(
            Filename=os.path.abspath(file_path),
            Bucket=aws_settings['bucket_name'],
            Key=s3_key,
            ExtraArgs=extra_args,
        )
        return s3_key

    try:
        s3_key = await asyncio.get_running_loop().run_in_executor(None, _blocking_upload)
        cdn_url = f"{aws_settings['cnd_endpoint'].rstrip('/')}/{s3_key.lstrip('/')}"
        resp_data.update(status=True, data=cdn_url, msg='文件成功上传到S3')
        print(f"[FileUploadNode INFO] 文件成功上传到S3: {cdn_url}")
    except Exception as e:
        # Boundary handler: every upload failure is folded into the result dict.
        print(f"[FileUploadNode ERROR] 上传文件到S3时出错: {e}")
        resp_data['msg'] = f'上传文件到S3时出错: {e}'

    # The source file is only deleted once the upload is known to have succeeded.
    if remove and resp_data['status']:
        try:
            os.remove(file_path)
            print(f"[FileUploadNode INFO] 源文件已根据选项删除: {file_path}")
        except FileNotFoundError:
            # Already gone; nothing left to clean up.
            pass
        except OSError as e:
            print(f"[FileUploadNode ERROR] 删除源文件失败: {e}")
    return resp_data
class FileUploadNode:
    """ComfyUI node that uploads a video, image, or audio input to S3.

    Only one input is uploaded per run; inputs are checked in the order
    video, image, audio. The selected input is written to a file in the
    ComfyUI temp directory and handed to ``upload_file_s3_v2``; the node
    outputs the resulting URL as a string.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("file_url",)
    FUNCTION = "upload_file"
    CATEGORY = "不忘科技-自定义节点🚩/utils/通用文件上传"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs: two required flags and three optional media sockets."""
        return {
            "required": {
                "perpetual": ("BOOLEAN", {"default": False}),
                "remove_source_file": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "video": ("*",),
                "image": ("IMAGE",),
                "audio": ("AUDIO",),
            },
        }

    # Annotations naming third-party types are quoted so that they are not
    # evaluated while the class body is being executed.

    @staticmethod
    def _temp_path(prefix: str, extension: str) -> str:
        """Return a path in the ComfyUI temp directory for *prefix* with *extension*."""
        output_dir = folder_paths.get_temp_directory()
        full_output_folder, filename, _, _, _ = folder_paths.get_save_image_path(prefix, output_dir)
        return os.path.join(full_output_folder, f"{filename}{extension}")

    @staticmethod
    def _pixels_to_uint8(pixels: "np.ndarray") -> "np.ndarray":
        """Scale float pixels by 255 and convert to uint8.

        Values are clamped to 0..255 first; without the clamp, inputs outside
        [0, 1] would wrap around when cast to uint8.
        """
        return np.clip(pixels * 255, 0, 255).astype(np.uint8)

    @staticmethod
    def _waveform_to_int16(waveform_np: "np.ndarray") -> "np.ndarray":
        """Convert a float waveform array to int16 PCM samples.

        A 3-D array is reduced to its first entry. A 2-D array whose first
        axis is shorter than its second is transposed (presumably from
        (channels, samples) to (samples, channels) - confirm against the
        AUDIO producer). 1-D arrays are passed through as mono. Samples are
        clamped to [-1, 1] so that loud input saturates instead of wrapping.
        """
        if waveform_np.ndim == 3:
            waveform_np = waveform_np[0]
        if waveform_np.ndim == 2 and waveform_np.shape[0] < waveform_np.shape[1]:
            waveform_np = waveform_np.T
        return (np.clip(waveform_np, -1.0, 1.0) * 32767).astype(np.int16)

    def tensor_to_pil(self, tensor: "torch.Tensor") -> "Image.Image":
        """Convert the first entry of an image tensor to a PIL image.

        Only index 0 along the first axis is used. Assumes float pixel values
        in [0, 1] (TODO confirm against the IMAGE producer); anything outside
        that range is clamped.
        """
        pixels = tensor[0].detach().cpu().numpy()
        return Image.fromarray(self._pixels_to_uint8(pixels))

    def save_pil_to_temp(self, pil_image: "Image.Image") -> str:
        """Save *pil_image* as a PNG in the temp directory and return its path."""
        filepath = self._temp_path("uploader_temp", ".png")
        pil_image.save(filepath, 'PNG')
        return filepath

    def save_audio_tensor_to_temp(self, waveform_tensor: "torch.Tensor", sample_rate: int) -> str:
        """Write a waveform tensor to a 16-bit WAV file in the temp directory.

        Args:
            waveform_tensor: Float waveform; see ``_waveform_to_int16`` for the
                accepted layouts.
            sample_rate: Sample rate written to the WAV header.

        Returns:
            Path of the written ``.wav`` file.

        Raises:
            ImportError: If scipy could not be imported when the module loaded.
        """
        # The module-level import leaves `wavfile` undefined when scipy is missing.
        if 'wavfile' not in globals():
            raise ImportError("Scipy 库未安装。请在您的 ComfyUI 环境中运行 'pip install scipy' 来启用音频处理功能。")
        samples = self._waveform_to_int16(waveform_tensor.detach().cpu().numpy())
        filepath = self._temp_path("uploader_temp_audio", ".wav")
        wavfile.write(filepath, int(sample_rate), samples)
        return filepath

    def upload_file(self, perpetual, remove_source_file, image=None, audio=None, video=None):
        """Upload the provided media input and return its URL.

        Args:
            perpetual: Forwarded to ``upload_file_s3_v2`` as ``perpetual``.
            remove_source_file: Forwarded as ``remove``; deletes the temp file
                after a successful upload.
            image: Optional image tensor; used only when ``video`` is None.
            audio: Optional dict (or list/tuple holding one) with ``waveform``
                and ``sample_rate`` keys; used only when ``video`` and
                ``image`` are None.
            video: Optional object (or list/tuple holding one) exposing
                ``save_to(path)``.

        Returns:
            A 1-tuple holding the uploaded file's URL. Problems while
            preparing the input are reported as an ``"ERROR: ..."`` string in
            that same tuple rather than raised.

        Raises:
            ValueError: If no input is provided or the upload itself fails.
        """
        resolved_path = None
        if video is not None:
            print('[FileUploadNode INFO] 检测到视频输入...')
            video_obj = video[0] if isinstance(video, (list, tuple)) and video else video
            if not hasattr(video_obj, 'save_to'):
                return ("ERROR: 不支持的视频输入格式,无法找到 save_to() 方法。",)
            try:
                temp_video_path = self._temp_path("uploader_temp_video", ".mp4")
                video_obj.save_to(temp_video_path)
                resolved_path = temp_video_path
            except Exception as e:
                return (f"ERROR: 保存视频时出错: {e}",)
        elif image is not None:
            print('[FileUploadNode INFO] 检测到图像输入,正在保存为临时文件...')
            resolved_path = self.save_pil_to_temp(self.tensor_to_pil(image))
        elif audio is not None:
            print('[FileUploadNode INFO] 检测到音频输入...')
            audio_info = audio[0] if isinstance(audio, (list, tuple)) and audio else audio
            if not (isinstance(audio_info, dict) and 'waveform' in audio_info and 'sample_rate' in audio_info):
                return ("ERROR: 不支持的音频输入格式,需要包含 'waveform''sample_rate' 的字典。",)
            print('[FileUploadNode INFO] 正在从 waveform 数据保存为临时 .wav 文件...')
            try:
                resolved_path = self.save_audio_tensor_to_temp(audio_info['waveform'], audio_info['sample_rate'])
            except Exception as e:
                return (f"ERROR: 保存音频张量时出错: {e}",)
        else:
            raise ValueError("ERROR: 没有提供有效的媒体输入 (视频/图像/音频)。")

        if not resolved_path or not os.path.exists(resolved_path):
            return (f"ERROR: 解析后的文件路径无效或文件不存在: {resolved_path}",)
        print(f"[FileUploadNode INFO] 最终待上传文件: {resolved_path}")

        # Drive the coroutine on a private loop. The loop is deliberately not
        # installed as the thread's current loop, so no closed loop is left
        # behind for later code, and it is closed even if the coroutine raises.
        loop = asyncio.new_event_loop()
        try:
            result = loop.run_until_complete(
                upload_file_s3_v2(
                    file_path=resolved_path,
                    remove=remove_source_file,
                    perpetual=perpetual,
                )
            )
        finally:
            loop.close()

        if not result['status']:
            raise ValueError(f"上传失败: {result['msg']}")
        return (result['data'],)
# Module-level registration tables (the names follow the ComfyUI custom-node
# convention): node identifier -> implementing class, and node identifier ->
# label shown to the user.
NODE_CLASS_MAPPINGS = {"FileUploadNode": FileUploadNode}
NODE_DISPLAY_NAME_MAPPINGS = {"FileUploadNode": "文件上传(s3,gs)"}