diff --git a/__init__.py b/__init__.py index a1edd01..0b61d5a 100644 --- a/__init__.py +++ b/__init__.py @@ -7,7 +7,7 @@ from .nodes.cos import COSUpload, COSDownload from .nodes.face_detect import FaceDetect from .nodes.face_extract import FaceExtract from .nodes.log2db import LogToDB -from .nodes.videocut import VideoCut +from .nodes.videocut import VideoCut, VideoCutByFramePoint from .nodes.vod2local import VodToLocalNode # A dictionary that contains all nodes you want to export with their names @@ -18,6 +18,7 @@ NODE_CLASS_MAPPINGS = { "COSUpload": COSUpload, "COSDownload": COSDownload, "VideoCutCustom": VideoCut, + "VideoCutByFramePoint": VideoCutByFramePoint, "VodToLocal": VodToLocalNode, "LogToDB": LogToDB, "VideoPointCompute": VideoStartPointDurationCompute, @@ -35,6 +36,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "COSUpload": "COS上传", "COSDownload": "COS下载", "VideoCutCustom": "视频剪裁", + "VideoCutByFramePoint": "视频剪裁(精确帧位)", "VodToLocal": "腾讯云VOD下载", "LogToDB": "状态持久化DB", "VideoPointCompute": "视频帧位计算", diff --git a/nodes/compute_video_point.py b/nodes/compute_video_point.py index 0cfb560..8351448 100644 --- a/nodes/compute_video_point.py +++ b/nodes/compute_video_point.py @@ -4,14 +4,14 @@ from math import ceil def validate_time_format(time_str): - pattern = r'^([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9]|\d{1,2})$' + pattern = r'^([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9]|\d{1,2}).(\d{3})$' return bool(re.match(pattern, time_str)) def get_duration_wave(audio): waveform, sample_rate = audio["waveform"], audio["sample_rate"] # 防止话说不完 - return ceil(waveform.shape[2] / sample_rate) + 0.1 + return waveform.shape[2] / sample_rate class VideoStartPointDurationCompute: @@ -25,7 +25,7 @@ class VideoStartPointDurationCompute: }, } - RETURN_TYPES = ("INT", "INT",) + RETURN_TYPES = ("FLOAT", "FLOAT",) RETURN_NAMES = ("起始帧位", "帧数") FUNCTION = "compute" @@ -36,10 +36,11 @@ class VideoStartPointDurationCompute: if not validate_time_format(start_time): raise ValueError("start_time或者end_time时间格式不对(start_time or end_time is not in time format)") - time_format = "%H:%M:%S" + time_format = "%H:%M:%S.%f" start_dt = datetime.strptime(start_time, time_format) start_sec = (start_dt - datetime(1900, 1, 1)).total_seconds() - start_point = int(start_sec * fps) - print("audio duration %.2f s"%get_duration_wave(audio)) + start_point = start_sec * fps + print("audio duration %.3f s"%get_duration_wave(audio)) duration = get_duration_wave(audio) * fps return (start_point, duration,) + diff --git a/nodes/videocut.py b/nodes/videocut.py index 56cad6f..c2d173e 100644 --- a/nodes/videocut.py +++ b/nodes/videocut.py @@ -6,6 +6,9 @@ import uuid from datetime import datetime import ffmpy +import torchaudio +import torchvision.io + video_extensions = ['webm', 'mp4', 'mkv', 'gif', 'mov'] @@ -18,13 +21,13 @@ class VideoCut: return { "required": { "video_path": ("STRING",{"placeholder": "X://insert/path/here.mp4", "vhs_path_extensions": video_extensions}), - "start": ("STRING", {"default": "00:00:00"}), - "end": ("STRING", {"default": "00:00:10"}), + "start": ("STRING", {"default": "00:00:00.000"}), + "end": ("STRING", {"default": "00:00:10.000"}), }, } - RETURN_TYPES = ("STRING",) - RETURN_NAMES = ("视频路径",) + RETURN_TYPES = ("IMAGE","AUDIO") + RETURN_NAMES = ("视频帧","音频") FUNCTION = "cut" @@ -76,7 +79,17 @@ class VideoCut: "-c:v", "libx264", "-c:a", - "copy" + "libmp3lame", + "-reset_timestamps", + "1", + "-sc_threshold", + "0", + "-g", + "1" + "-force_key_frames", + "expr:gte(t, n_forced * 1)", + "-v", + "-8" ] }, ) @@ -94,8 +107,124 @@ class VideoCut: files = glob.glob( output.replace(str(uid), origin_fname).replace("%03d", "*") ) - return (str(files),) except: files = glob.glob(output.replace("%03d", "*")) traceback.print_exc() - return (str(files),) + video, audio, info = torchvision.io.read_video(files[0]) + video.mul_(255) + audio.unsqueeze_(0) + try: + os.remove(files[0]) + except: + pass + return (video, {"waveform":audio,"sample_rate":info["audio_fps"]},) + +class VideoCutByFramePoint: + """FFMPEG视频剪辑-帧位""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "video_path": ("STRING",{"placeholder": "X://insert/path/here.mp4", "vhs_path_extensions": video_extensions}), + "start_point": ("FLOAT", {"default": "0.0"}), + "duration": ("FLOAT", {"default": "10.0"}), + "fps": ("INT", {"default": "25"}), + "force_match_fps": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("IMAGE","AUDIO") + RETURN_NAMES = ("视频帧","音频") + + FUNCTION = "cut" + + # OUTPUT_NODE = False + + CATEGORY = "不忘科技-自定义节点🚩" + + def cut(self, video_path, start_point, duration, fps, force_match_fps): + # 原文件名 + origin_fname = ".".join(video_path.split(os.sep)[-1].split(".")[:-1]) + # 新文件名 复制改名适配ffmpeg + uid = uuid.uuid1() + temp_fname = os.sep.join( + [ + *video_path.split(os.sep)[:-1], + "%s.%s" % (str(uid), video_path.split(".")[-1]), + ] + ) + try: + shutil.copy(video_path, temp_fname) + except: + return ("请检查输入文件权限",) + video_path = temp_fname + # 组装输出文件名 + output_name = ".".join( + [ + *video_path.split(os.sep)[-1].split(".")[:-2], + video_path.split(os.sep)[-1].split(".")[-2] + + "_output_%s" % datetime.now().strftime("%Y%m%d_%H%M%S"), + video_path.split(os.sep)[-1].split(".")[-1], + ] + ) + output = ( + os.sep.join([*video_path.split(os.sep)[:-1], output_name]) + .replace( + os.sep.join(["ComfyUI", "input"]), os.sep.join(["ComfyUI", "output"]) + ) + .replace(" ", "") + ) + # 调用ffmpeg + ff = ffmpy.FFmpeg( + inputs={video_path: None}, + outputs={ + output: [ + "-ss", + "%.3f" % (start_point/fps), + "-t", + "%.3f" % (duration/fps), + "-c:v", + "libx264", + "-c:a", + "libmp3lame", + "-reset_timestamps", + "1", + "-sc_threshold", + "0", + "-g", + "1", + "-force_key_frames", + "expr:gte(t, n_forced * 1)", + "-r" if force_match_fps else "", + "%d" % fps if force_match_fps else "", + "-v", + "-8" + ] + }, + ) + print(ff.cmd) + ff.run() + # uuid填充改回原文件名 + try: + os.remove(temp_fname) + except: + pass + try: + files = glob.glob(output.replace("%03d", "*")) + for file in files: + shutil.move(file, file.replace(str(uid), origin_fname)) + files = glob.glob( + output.replace(str(uid), origin_fname).replace("%03d", "*") + ) + except: + files = glob.glob(output.replace("%03d", "*")) + traceback.print_exc() + video, audio, info = torchvision.io.read_video(files[0]) + video.mul_(255) + audio.unsqueeze_(0) + try: + os.remove(files[0]) + except: + pass + return (video, {"waveform":audio,"sample_rate":info["audio_fps"]},)