307 lines
11 KiB
Python
307 lines
11 KiB
Python
import glob
|
|
import json
|
|
import os
|
|
import shutil
|
|
import traceback
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import yaml
|
|
from ultralytics import YOLO
|
|
from comfy import model_management
|
|
from qcloud_cos import CosConfig, CosClientError, CosServiceError
|
|
from qcloud_cos import CosS3Client
|
|
|
|
from .test_single_image import test_node
|
|
import ffmpy
|
|
|
|
# File extensions presumably treated as video inputs (not referenced in the visible part of this module).
video_extensions = "webm mp4 mkv gif mov".split()
|
|
|
|
|
|
class FaceDetect:
    """Face occlusion detection node (人脸遮挡检测).

    Delegates the per-frame analysis to ``test_node`` and then picks one of
    the candidate frame periods it reports, using ``main_seed`` as the index.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required inputs for ComfyUI."""
        required = {
            "image": ("IMAGE",),
            "main_seed": ("INT:seed", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "model": (["convnext_tiny", "convnext_base"],),
            "length": ("INT", {"default": 10, "min": 3, "max": 60, "step": 1}),
            "threshold": ("FLOAT", {"default": 94, "min": 55, "max": 99, "step": 0.1}),
        }
        return {"required": required}

    RETURN_TYPES = (
        "IMAGE", "IMAGE", "STRING", "STRING", "STRING", "STRING", "STRING", "INT", "INT",
    )
    RETURN_NAMES = (
        "图像", "选中人脸", "分类", "概率", "采用帧序号", "全部帧序列", "剪辑配置", "起始帧序号", "帧数量",
    )

    FUNCTION = "predict"

    CATEGORY = "不忘科技-自定义节点🚩"

    def predict(self, image, main_seed, model, length, threshold):
        """Run the detector and select one qualifying clip.

        Returns the first five ``test_node`` results, the full period list as a
        string, the chosen ``{"start": ..., "end": ...}`` config as JSON, the
        start frame, and the inclusive frame count ``end - start + 1``.

        Raises:
            RuntimeError: if ``test_node`` reports no candidate periods.
        """
        outputs = test_node(image, length=length, thres=threshold, model_name=model)
        frames, selected, classes, probs, used_ids, periods = outputs
        print("全部帧序列", periods)

        if len(periods) == 0:
            raise RuntimeError("未找到符合要求的视频片段")

        # The seed wraps around the candidate list, so any seed selects a valid period.
        start, end = periods[main_seed % len(periods)]
        clip_config = json.dumps({"start": start, "end": end})
        frame_count = end - start + 1
        return (frames, selected, classes, probs, used_ids, str(periods), clip_config, start, frame_count)
|
|
|
|
|
|
class FaceExtract:
    """Face extraction node backed by a YOLOv8 face detector (人脸提取 By YOLO).

    For every input frame with exactly one detected face, a square crop centred
    on the face is resized to ``OUTPUT_SIZE`` x ``OUTPUT_SIZE``. Frames with zero
    or several detections stay all-zero in the output batch.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("图片",)

    FUNCTION = "crop"

    CATEGORY = "不忘科技-自定义节点🚩"

    # Side length in pixels of every output crop.
    OUTPUT_SIZE = 512
    # Fill value for crop pixels that fall outside the source frame.
    PAD_VALUE = 20

    @staticmethod
    def _crop_square(frame, box, pad_value=20):
        """Return a square crop of ``frame`` centred on ``box``, padded where it leaves the frame.

        ``frame`` is indexed ``frame[row, col, ...]``. ``box`` starts with
        ``(x1, y1, x2, y2)`` in pixel coordinates (x = column, y = row). The
        crop side is the larger of the box's width and height. Returns None when
        that side is zero.
        """
        x1, y1, x2, y2 = (float(v) for v in box[:4])
        face_size = int(max(x2 - x1, y2 - y1))
        if face_size <= 0:
            return None

        half = face_size // 2
        center_row, center_col = (y1 + y2) // 2, (x1 + x2) // 2
        row0, row1 = int(center_row - half), int(center_row + half)
        col0, col1 = int(center_col - half), int(center_col + half)

        template = np.full((face_size, face_size) + frame.shape[2:], pad_value, dtype=np.float64)

        # Intersect the requested window with the frame; rows are bounded by
        # shape[0] and columns by shape[1] (exclusive upper bounds).
        height, width = frame.shape[:2]
        n_rows = max(0, min(face_size, row1 - row0))
        n_cols = max(0, min(face_size, col1 - col0))
        src_r0, src_r1 = max(row0, 0), min(row0 + n_rows, height)
        src_c0, src_c1 = max(col0, 0), min(col0 + n_cols, width)
        if src_r0 < src_r1 and src_c0 < src_c1:
            template[src_r0 - row0:src_r1 - row0, src_c0 - col0:src_c1 - col0] = \
                frame[src_r0:src_r1, src_c0:src_c1]
        return template

    def crop(self, image):
        """Detect one face per frame and return the batch of square face crops.

        ``image`` is assumed to be a ComfyUI IMAGE tensor (batch, H, W, C) with
        values in [0, 1]; this is inferred from the ``* 255`` scaling and the
        per-frame iteration below — TODO confirm. The result is a float32 tensor
        of shape (batch, OUTPUT_SIZE, OUTPUT_SIZE, 3) scaled back to [0, 1].
        """
        device = model_management.get_torch_device()
        frames = 255. * image.cpu().numpy()
        weights = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model", "yolov8n-face-lindevs.pt")
        model = YOLO(model=weights)
        print("shape", frames.shape)

        size = self.OUTPUT_SIZE
        # np.zeros rather than np.ndarray: frames without a usable face must be
        # black, not uninitialised memory.
        out_images = np.zeros((frames.shape[0], size, size, 3))

        for idx, frame in enumerate(frames):
            # NOTE(review): ultralytics treats numpy input as BGR while ComfyUI
            # frames are usually RGB — confirm the detector copes with this.
            results = model.predict(
                frame,
                imgsz=640,
                conf=0.75,
                iou=0.7,
                device=device,
                verbose=False,
            )
            result = results[0]
            detections = result.boxes.data.cpu().numpy()
            # Only frames with exactly one face are cropped; ambiguous frames stay black.
            if len(detections) != 1:
                continue
            face = self._crop_square(result.orig_img, detections[0], self.PAD_VALUE)
            if face is None:
                continue
            out_images[idx] = cv2.resize(face, (size, size))

        cropped_face = torch.from_numpy(out_images.astype(np.float32) / 255.0)
        return (cropped_face,)
|
|
|
|
|
|
class COSDownload:
    """Tencent Cloud COS download node (腾讯云COS下载).

    Downloads the object at ``cos_key`` into a ``download`` folder next to this
    module and returns the local path.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cos_key": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "download"
    CATEGORY = "不忘科技-自定义节点🚩"

    # Number of download attempts before giving up.
    MAX_RETRIES = 10

    def download(self, cos_key):
        """Download ``cos_key`` from the bucket configured in ``config.yaml``.

        ``config.yaml`` (next to this module) must provide ``region``,
        ``secret_id``, ``secret_key`` and ``bucket``.

        Returns:
            A 1-tuple holding the local file path under ``<module dir>/download``.

        Raises:
            ValueError: if ``cos_key`` would resolve outside the download folder.
            RuntimeError: if every attempt fails with a COS client/service error.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        download_dir = os.path.join(base_dir, "download")
        # Strip leading separators: os.path.join would otherwise discard
        # download_dir for keys such as "/videos/a.mp4".
        dest_path = os.path.normpath(os.path.join(download_dir, cos_key.lstrip("/\\")))
        # Refuse keys such as "../x" that escape the download folder.
        if os.path.commonpath([download_dir, dest_path]) != download_dir:
            raise ValueError(f"非法的COS Key: {cos_key}")
        # Always create the parent dir (including "download" itself for flat keys).
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)

        # Credentials come from a local file; safe_load avoids constructing arbitrary objects.
        with open(os.path.join(base_dir, "config.yaml"), encoding="utf-8") as f:
            yaml_config = yaml.safe_load(f)
        config = CosConfig(Region=yaml_config["region"], SecretId=yaml_config["secret_id"],
                           SecretKey=yaml_config["secret_key"])
        client = CosS3Client(config)

        last_error = None
        for _ in range(self.MAX_RETRIES):
            try:
                client.download_file(
                    Bucket=yaml_config["bucket"],
                    Key=cos_key,
                    DestFilePath=dest_path)
                return (dest_path,)
            # A tuple is required: `A or B` evaluates to A alone and never catches B.
            except (CosClientError, CosServiceError) as e:
                last_error = e
                print(f"下载失败 {e}")
        raise RuntimeError(f"COS下载失败: {cos_key}") from last_error
|
|
|
|
|
|
class COSUpload:
    """Tencent Cloud COS upload node (腾讯云COS上传).

    Uploads a local file into the bucket's configured ``subfolder`` and returns
    the resulting object key.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "path": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("COS文件Key",)

    FUNCTION = "upload"
    CATEGORY = "不忘科技-自定义节点🚩"

    # Number of upload attempts before giving up.
    MAX_RETRIES = 10

    def upload(self, path):
        """Upload the file at ``path`` to ``<subfolder>/<file name>``.

        ``config.yaml`` (next to this module) must provide ``region``,
        ``secret_id``, ``secret_key``, ``bucket`` and ``subfolder``.

        Returns:
            A 1-tuple holding the COS object key.

        Raises:
            RuntimeError: if every attempt fails with a COS client/service error.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        # Credentials come from a local file; safe_load avoids constructing arbitrary objects.
        with open(os.path.join(base_dir, "config.yaml"), encoding="utf-8") as f:
            yaml_config = yaml.safe_load(f)

        # Take the file name after the last "/" or "\\", even when both appear.
        file_name = path.replace("\\", "/").rsplit("/", 1)[-1]
        cos_key = "/".join([yaml_config["subfolder"], file_name])

        config = CosConfig(Region=yaml_config["region"], SecretId=yaml_config["secret_id"],
                           SecretKey=yaml_config["secret_key"])
        client = CosS3Client(config)

        last_error = None
        for _ in range(self.MAX_RETRIES):
            try:
                client.upload_file(
                    Bucket=yaml_config["bucket"],
                    Key=cos_key,
                    LocalFilePath=path)
                return (cos_key,)
            # A tuple is required: `A or B` evaluates to A alone and never catches B.
            except (CosClientError, CosServiceError) as e:
                last_error = e
                print(e)
        raise RuntimeError(f"COS上传失败: {path}") from last_error
|
|
|
|
|
|
class VideoCut:
    """FFmpeg video clipping node — deprecated because of playback stutter (有卡顿问题 暂废弃)."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "config": ("STRING",),
                "video_path": ("STRING",),
                "mod": ("INT",),
                "fps": ("FLOAT",),
                "period_length": ("INT", {"default": 10, "min": 4, "max": 100, "step": 1, "forceInput": True})
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频路径",)

    FUNCTION = "cut"

    # OUTPUT_NODE = False

    CATEGORY = "不忘科技-自定义节点🚩"

    def cut(self, config, video_path, mod, fps, period_length):
        """Cut frames ``start``..``end`` of ``video_path`` into segments with ffmpeg.

        ``config`` is a JSON object with ``start``/``end`` frame numbers.
        Frame numbers are converted to seconds via ``mod / fps``. The input is
        first copied to a uuid-named file next to it. Output files are written
        with "ComfyUI/input" swapped for "ComfyUI/output" and spaces removed,
        then renamed back to use the original base name.

        Returns:
            A 1-tuple with either a status message or ``str(list_of_output_files)``.
        """
        # Original file name without its extension.
        origin_fname = ".".join(video_path.split(os.sep)[-1].split(".")[:-1])
        # Multiplier from frame number to seconds.
        mul = mod / fps
        print("fps", fps)
        config = json.loads(config)
        if len(config.keys()) == 0:
            return ("无法生成符合要求的片段",)
        start, end = config["start"], config["end"]

        # Copy to a uuid-named file so ffmpeg only sees a plain file name.
        uid = uuid.uuid1()
        temp_fname = os.sep.join([*video_path.split(os.sep)[:-1], "%s.%s" % (str(uid), video_path.split(".")[-1])])
        try:
            shutil.copy(video_path, temp_fname)
        except OSError:
            return ("请检查输入文件权限",)
        video_path = temp_fname

        # Build "<uuid>_output_%03d_<timestamp>.<ext>"; ffmpeg expands %03d per segment.
        name_parts = video_path.split(os.sep)[-1].split(".")
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_name = ".".join([*name_parts[:-2], name_parts[-2] + "_output_%03d_" + timestamp, name_parts[-1]])
        output = (os.sep.join([*video_path.split(os.sep)[:-1], output_name])
                  .replace(os.sep.join(["ComfyUI", "input"]), os.sep.join(["ComfyUI", "output"])).replace(" ", ""))

        try:
            ff = ffmpy.FFmpeg(
                inputs={video_path: ['-accurate_seek']},
                outputs={output: [
                    '-f', 'segment',
                    '-ss', str(round(start * mul, 3)),
                    '-to', str(round(end * mul, 3)),
                    '-segment_times', str(period_length),
                    '-c', 'copy',
                    '-map', '0',
                    '-avoid_negative_ts', '1'
                ]}
            )
            print(ff.cmd)
            ff.run()
        finally:
            # Remove the temporary copy even when ffmpeg fails.
            try:
                os.remove(temp_fname)
            except OSError:
                pass

        # Escape the literal path so "[", "*" or "?" in directory names are not
        # treated as glob syntax; only the %03d placeholder becomes a wildcard.
        segment_pattern = glob.escape(output).replace("%03d", "*")
        try:
            # Rename outputs from the uuid back to the original base name.
            for file in glob.glob(segment_pattern):
                shutil.move(file, file.replace(str(uid), origin_fname))
            renamed_pattern = glob.escape(output.replace(str(uid), origin_fname)).replace("%03d", "*")
            return (str(glob.glob(renamed_pattern)),)
        except OSError:
            files = glob.glob(segment_pattern)
            traceback.print_exc()
            return (str(files),)
|
|
|
|
|
|
# Add custom API routes, using router
|
|
from aiohttp import web
|
|
from server import PromptServer
|
|
|
|
|
|
@PromptServer.instance.routes.get("/hello")
async def get_hello(request):
    """Respond to ``GET /hello`` with the JSON-encoded string ``"hello"``."""
    greeting = "hello"
    return web.json_response(greeting)
|
|
|
|
|
|
# Every node class exported by this module, keyed by its node type name.
# NOTE: these names must be globally unique.
NODE_CLASS_MAPPINGS = dict(
    FaceOccDetect=FaceDetect,
    FaceExtract=FaceExtract,
    COSUpload=COSUpload,
    COSDownload=COSDownload,
    VideoCutCustom=VideoCut,
)

# Human-readable titles for the node types above (same keys, same order).
NODE_DISPLAY_NAME_MAPPINGS = dict(
    FaceOccDetect="面部遮挡检测",
    FaceExtract="面部提取",
    COSUpload="COS上传",
    COSDownload="COS下载",
    VideoCutCustom="视频剪裁",
)
|