# ComfyUI-CustomNode/__init__.py
# (310 lines, 11 KiB, Python)

import glob
import json
import os
import shutil
import traceback
import uuid
from datetime import datetime
import cv2
import numpy as np
import torch
import yaml
from ultralytics import YOLO
from comfy import model_management
from qcloud_cos import CosConfig, CosClientError, CosServiceError
from qcloud_cos import CosS3Client
from .test_single_image import test_node
import ffmpy
# File extensions presumably recognised as video inputs by this package.
# NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
video_extensions = ['webm', 'mp4', 'mkv', 'gif', 'mov']
class FaceDetect:
    """
    Face occlusion detection node.

    Runs ``test_node`` over the incoming frames, then picks one of the
    candidate frame periods it reports (chosen by ``main_seed``) and exposes
    it as a JSON start/end clip configuration plus start index and frame count.
    """

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "image": ("IMAGE",),
            "main_seed": ("INT:seed", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "model": (["convnext_tiny", "convnext_base"],),
            "length": ("INT", {"default": 10, "min": 3, "max": 60, "step": 1}),
            "threshold": ("FLOAT", {"default": 94, "min": 55, "max": 99, "step": 0.1}),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE", "IMAGE", "STRING", "STRING", "STRING", "STRING", "STRING", "INT", "INT")
    RETURN_NAMES = ("图像", "选中人脸", "分类", "概率", "采用帧序号", "全部帧序列", "剪辑配置", "起始帧序号", "帧数量")
    FUNCTION = "predict"
    CATEGORY = "不忘科技-自定义节点🚩"

    def predict(self, image, main_seed, model, length, threshold):
        """Detect occlusion and select one candidate period.

        Returns the nine outputs declared in RETURN_TYPES; the last two are the
        selected period's start index and its inclusive length (end - start + 1).

        Raises:
            RuntimeError: when ``test_node`` reports no candidate periods.
        """
        frames, selected_face, label, probability, used_indices, periods = test_node(
            image, length=length, thres=threshold, model_name=model)
        print("全部帧序列", periods)
        if len(periods) == 0:
            raise RuntimeError("未找到符合要求的视频片段")
        # The seed deterministically cycles through the available periods.
        first, last = periods[main_seed % len(periods)]
        clip_config = json.dumps({"start": first, "end": last})
        frame_count = last - first + 1
        return (frames, selected_face, label, probability, used_indices,
                str(periods), clip_config, first, frame_count)
class FaceExtract:
    """Face extraction node backed by a YOLO face detector.

    For every input frame, faces are detected; when exactly one face is found,
    a square window centred on the face box (side = longer box edge) is cut out
    and resized to 512x512. Parts of the window outside the frame, and frames
    with zero or several detections, are filled with a constant value so the
    output batch always has one entry per input frame.
    """

    # Side length (pixels) of each output face image.
    _OUTPUT_SIZE = 512
    # Fill value (0-255 scale) for padding and for frames without a single face.
    _FILL_VALUE = 20.0
    # Detector shared across calls; loading the weights on every run is expensive.
    _model = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("图片",)
    FUNCTION = "crop"
    CATEGORY = "不忘科技-自定义节点🚩"

    @classmethod
    def _get_model(cls):
        """Return the YOLO face model, loading it from ./model on first use."""
        if cls._model is None:
            weights = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model",
                                   "yolov8n-face-lindevs.pt")
            cls._model = YOLO(model=weights)
        return cls._model

    @staticmethod
    def _square_crop(frame, center_row, center_col, size, fill_value):
        """Copy a size x size window centred on (center_row, center_col) out of
        ``frame`` (H, W, C); pixels outside the frame keep ``fill_value``."""
        height, width = frame.shape[:2]
        top = int(center_row) - size // 2
        left = int(center_col) - size // 2
        window = np.full((size, size, frame.shape[2]), fill_value, dtype=np.float64)
        # Intersection of the window with the frame, in frame coordinates.
        src_top, src_bottom = max(top, 0), min(top + size, height)
        src_left, src_right = max(left, 0), min(left + size, width)
        if src_top < src_bottom and src_left < src_right:
            window[src_top - top:src_bottom - top, src_left - left:src_right - left] = \
                frame[src_top:src_bottom, src_left:src_right]
        return window

    def crop(self, image):
        """Crop the single detected face out of every frame of ``image``.

        Args:
            image: ComfyUI IMAGE tensor, assumed (batch, H, W, 3) RGB floats in
                [0, 1] — TODO confirm for all upstream nodes.

        Returns:
            1-tuple with a float32 tensor of shape (batch, 512, 512, 3) in [0, 1].
        """
        device = model_management.get_torch_device()
        model = self._get_model()
        frames = 255. * image.cpu().numpy()
        size = self._OUTPUT_SIZE
        # Pre-filled (not uninitialised) so skipped frames are well defined.
        out_images = np.full((frames.shape[0], size, size, 3), self._FILL_VALUE, dtype=np.float64)
        for idx, frame in enumerate(frames):
            # Ultralytics expects numpy sources as HWC, BGR, uint8; ComfyUI frames are RGB floats.
            bgr = np.ascontiguousarray(frame[..., ::-1].clip(0, 255).astype(np.uint8))
            results = model.predict(
                bgr,
                imgsz=640,
                conf=0.75,
                iou=0.7,
                device=device,
                verbose=False
            )
            boxes = results[0].boxes.data.cpu().numpy()
            # Only frames with exactly one detected face are cropped.
            if len(boxes) != 1:
                continue
            # Each row is (x1, y1, x2, y2, confidence, class) in pixel coordinates.
            x1, y1, x2, y2 = boxes[0][:4]
            face_size = int(max(x2 - x1, y2 - y1))
            if face_size <= 0:
                continue
            # Crop from the original RGB float frame (not the BGR uint8 copy).
            face = self._square_crop(frame, (y1 + y2) // 2, (x1 + x2) // 2, face_size, self._FILL_VALUE)
            out_images[idx] = cv2.resize(face, (size, size))
        cropped_face = torch.from_numpy(out_images.astype(np.float32) / 255.0)
        return (cropped_face,)
class COSDownload:
    """Tencent Cloud COS download node."""
    # TODO: add Tencent Cloud VOD video download support.

    # Number of download attempts before giving up.
    _MAX_ATTEMPTS = 10

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cos_key": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "download"
    CATEGORY = "不忘科技-自定义节点🚩"

    def download(self, cos_key):
        """Download ``cos_key`` from the bucket configured in ``config.yaml``.

        The object is saved under ``<package dir>/download/``, mirroring the
        key's directory structure.

        Returns:
            1-tuple with the local file path.

        Raises:
            ValueError: if the key is empty or would resolve outside the
                download directory.
            RuntimeError: if every attempt failed with a COS client/service error.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        download_root = os.path.join(base_dir, "download")
        # A leading separator would make os.path.join discard download_root and
        # write to an absolute path, so strip it for the local path only.
        relative_key = cos_key.lstrip("/\\")
        if not os.path.basename(relative_key):
            raise ValueError(f"cos_key must name a file: {cos_key!r}")
        dest_path = os.path.join(download_root, os.path.dirname(relative_key),
                                 os.path.basename(relative_key))
        # Reject keys such as "../../x" that escape the download directory.
        root_real = os.path.realpath(download_root)
        if os.path.commonpath([root_real, os.path.realpath(dest_path)]) != root_real:
            raise ValueError(f"cos_key resolves outside {download_root}: {cos_key!r}")
        # Always create the destination folder (also for keys without a directory part).
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)

        with open(os.path.join(base_dir, "config.yaml"), encoding="utf-8", mode="r") as f:
            yaml_config = yaml.safe_load(f)
        config = CosConfig(Region=yaml_config["region"], SecretId=yaml_config["secret_id"],
                           SecretKey=yaml_config["secret_key"])
        client = CosS3Client(config)

        last_error = None
        for _ in range(self._MAX_ATTEMPTS):
            try:
                client.download_file(
                    Bucket=yaml_config["bucket"],
                    Key=cos_key,
                    DestFilePath=dest_path)
                return (dest_path,)
            # A tuple is required: `A or B` would evaluate to A alone.
            except (CosClientError, CosServiceError) as e:
                last_error = e
                print(f"下载失败 {e}")
        raise RuntimeError(f"COS下载失败，已重试{self._MAX_ATTEMPTS}次: {cos_key}") from last_error
class COSUpload:
    """Tencent Cloud COS upload node."""
    # TODO: add Tencent Cloud VOD video upload support.

    # Number of upload attempts before giving up.
    _MAX_ATTEMPTS = 10

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "path": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("COS文件Key",)
    FUNCTION = "upload"
    CATEGORY = "不忘科技-自定义节点🚩"

    def upload(self, path):
        """Upload local file ``path`` to ``<subfolder>/<file name>`` in the configured bucket.

        Bucket, region, credentials and ``subfolder`` are read from
        ``config.yaml`` next to this file.

        Returns:
            1-tuple with the COS object key.

        Raises:
            FileNotFoundError: if ``path`` is not an existing file.
            RuntimeError: if every attempt failed with a COS client/service error.
        """
        if not os.path.isfile(path):
            raise FileNotFoundError(f"upload source not found: {path!r}")
        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml"),
                  encoding="utf-8", mode="r") as f:
            yaml_config = yaml.safe_load(f)
        # Accept both POSIX and Windows separators (even mixed) when taking the name.
        file_name = path.replace("\\", "/").rsplit("/", 1)[-1]
        cos_key = "/".join([yaml_config["subfolder"], file_name])
        config = CosConfig(Region=yaml_config["region"], SecretId=yaml_config["secret_id"],
                           SecretKey=yaml_config["secret_key"])
        client = CosS3Client(config)

        last_error = None
        for _ in range(self._MAX_ATTEMPTS):
            try:
                client.upload_file(
                    Bucket=yaml_config["bucket"],
                    Key=cos_key,
                    LocalFilePath=path)
                return (cos_key,)
            # A tuple is required: `A or B` would evaluate to A alone.
            except (CosClientError, CosServiceError) as e:
                last_error = e
                print(e)
        raise RuntimeError(f"COS上传失败，已重试{self._MAX_ATTEMPTS}次: {path}") from last_error
class VideoCut:
    """FFmpeg video cutting node -- deprecated: the output has stuttering issues."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "config": ("STRING",),
                "video_path": ("STRING",),
                "mod": ("INT",),
                "fps": ("FLOAT",),
                "period_length": ("INT", {"default": 10, "min": 4, "max": 100, "step": 1, "forceInput": True})
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频路径",)
    FUNCTION = "cut"
    CATEGORY = "不忘科技-自定义节点🚩"

    def cut(self, config, video_path, mod, fps, period_length):
        """Cut the frame range in ``config`` out of ``video_path`` with ffmpeg's segment muxer.

        Args:
            config: JSON object string with "start"/"end" frame indices; an empty
                object means there is nothing to cut.
            video_path: source video; copied to a uuid-named temp file first.
            mod: multiplier; times in seconds are computed as frame * mod / fps.
            fps: frame rate used for the frame -> seconds conversion.
            period_length: passed verbatim to ffmpeg's ``-segment_times``.

        Returns:
            1-tuple with ``str(list_of_output_files)``, or with an error message
            string when there is nothing to cut or the input cannot be copied.
        """
        # Original file name without extension, used to rename the outputs back.
        origin_fname = ".".join(video_path.split(os.sep)[-1].split(".")[:-1])
        mul = mod / fps
        print("fps", fps)
        config = json.loads(config)
        if not config:
            return ("无法生成符合要求的片段",)
        start, end = config["start"], config["end"]

        # Work on a uuid-named copy so ffmpeg never sees the original file name.
        uid = uuid.uuid1()
        temp_fname = os.sep.join([*video_path.split(os.sep)[:-1], "%s.%s" % (str(uid), video_path.split(".")[-1])])
        try:
            shutil.copy(video_path, temp_fname)
        except OSError:
            return ("请检查输入文件权限",)
        video_path = temp_fname

        # Output name: <uuid>_output_%03d_<timestamp>.<ext>; %03d is filled by the segment muxer.
        name_parts = video_path.split(os.sep)[-1].split(".")
        output_name = ".".join([*name_parts[:-2],
                                name_parts[-2] + "_output_%%03d_%s" % datetime.now().strftime('%Y%m%d_%H%M%S'),
                                name_parts[-1]])
        # Sources under ComfyUI/input are written to ComfyUI/output. The generated
        # name contains no spaces, and ffmpy passes arguments as a list, so the
        # directory part is left intact (stripping spaces there broke the path).
        output = (os.sep.join([*video_path.split(os.sep)[:-1], output_name])
                  .replace(os.sep.join(["ComfyUI", "input"]), os.sep.join(["ComfyUI", "output"])))

        ff = ffmpy.FFmpeg(
            inputs={video_path: ['-accurate_seek']},
            outputs={output: [
                '-f', 'segment',
                '-ss', str(round(start * mul, 3)),
                '-to', str(round(end * mul, 3)),
                '-segment_times', str(period_length),
                '-c', 'copy',
                '-map', '0',
                '-avoid_negative_ts', '1'
            ]}
        )
        print(ff.cmd)
        try:
            ff.run()
        finally:
            # Always drop the temporary copy, even when ffmpeg fails.
            try:
                os.remove(temp_fname)
            except OSError:
                pass

        # Rename the uuid-based outputs back to the original file name (best effort).
        pattern = output.replace("%03d", "*")
        try:
            for file in glob.glob(pattern):
                shutil.move(file, file.replace(str(uid), origin_fname))
            files = glob.glob(output.replace(str(uid), origin_fname).replace("%03d", "*"))
        except OSError:
            files = glob.glob(pattern)
            traceback.print_exc()
        return (str(files),)
# Add custom API routes, using router
from aiohttp import web
from server import PromptServer
@PromptServer.instance.routes.get("/hello")
async def get_hello(request):
    """Minimal example route: GET /hello replies with the JSON-encoded string "hello"."""
    greeting = "hello"
    return web.json_response(greeting)
# Registry of every node exported by this package: node type name -> class.
# NOTE: the type names must be unique across all installed custom nodes.
NODE_CLASS_MAPPINGS = dict(
    FaceOccDetect=FaceDetect,
    FaceExtract=FaceExtract,
    COSUpload=COSUpload,
    COSDownload=COSDownload,
    VideoCutCustom=VideoCut,
)
# Friendly, human-readable titles, keyed by the same node type names as above.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    FaceOccDetect="面部遮挡检测",
    FaceExtract="面部提取",
    COSUpload="COS上传",
    COSDownload="COS下载",
    VideoCutCustom="视频剪裁",
)