import json
import os

import cv2
import numpy as np
import torch
from comfy import model_management
from ultralytics import YOLO

from ..utils.download_utils import download_file
from ..utils.face_occu_detect import face_occu_detect


class FaceDetect:
    """
    Face occlusion detection.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                # Seed used to pick one of the qualifying clip periods.
                "main_seed": (
                    "INT",
                    {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF},
                ),
                # Backbone choice, forwarded to face_occu_detect as `model_name`.
                "model": (["convnext_tiny", "convnext_base"],),
                # Forwarded to face_occu_detect as `length` (frames).
                "length": ("INT", {"default": 10, "min": 3, "max": 60, "step": 1}),
                # Threshold on a 0-100 scale, forwarded to face_occu_detect as `thres`.
                "threshold": (
                    "FLOAT",
                    {"default": 94, "min": 55, "max": 99, "step": 0.1},
                ),
            },
        }

    RETURN_TYPES = (
        "IMAGE",
        "IMAGE",
        "STRING",
        "STRING",
        "STRING",
        "STRING",
        "STRING",
        "INT",
        "INT",
    )
    # Output labels, in order: image, selected face, class, probability,
    # selected frame indices, all frame periods, clip config, start frame
    # index, frame count.
    RETURN_NAMES = (
        "图像",
        "选中人脸",
        "分类",
        "概率",
        "采用帧序号",
        "全部帧序列",
        "剪辑配置",
        "起始帧序号",
        "帧数量",
    )

    FUNCTION = "predict"

    CATEGORY = "不忘科技-自定义节点🚩/图片/人脸"

    def predict(self, image, main_seed, model, length, threshold):
        # face_occu_detect returns the frames, the selected face crops, the
        # class/probability strings, the frame indices used, and a list of
        # (start, end) frame periods that satisfy the threshold.
        image, image_selected, cls, prob, nums, period = face_occu_detect(
            image, length=length, thres=threshold, model_name=model
        )
        print("All frame periods:", period)
        if len(period) > 0:
            # The seed picks one of the qualifying periods; bounds are inclusive.
            start, end = period[main_seed % len(period)]
            config = {"start": start, "end": end}
        else:
            raise RuntimeError("No video segment matching the requirements was found")
        return (
            image,
            image_selected,
            cls,
            prob,
            nums,
            str(period),
            json.dumps(config),
            start,
            end - start + 1,
        )
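

# Illustrative sketch (not used by the node): a downstream consumer can parse the
# clip-config JSON produced by FaceDetect.predict and trim a frame batch with it.
# The helper name is hypothetical; it only assumes the (frames, H, W, C) layout
# and the inclusive start/end convention used above.
def _trim_frames_example(image, clip_config_json):
    cfg = json.loads(clip_config_json)
    start, end = cfg["start"], cfg["end"]
    # `end` is inclusive, matching the frame-count output `end - start + 1`.
    return image[start : end + 1]

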
class FaceExtract:
    """Face extraction by YOLO."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    # Output label: image (the 512x512 face crops).
    RETURN_NAMES = ("图片",)

    FUNCTION = "crop"

    CATEGORY = "不忘科技-自定义节点🚩/图片/人脸"

    def crop(self, image):
        device = model_management.get_torch_device()
        # ComfyUI IMAGE tensors are float in [0, 1]; scale to a 0-255 float
        # array of shape (frames, H, W, C).
        image_np = 255.0 * image.cpu().numpy()
        model_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            "model",
            "yolov8n-face-lindevs.pt",
        )
        # Fetch the YOLOv8 face weights on first use.
        if not os.path.exists(model_path):
            download_file(
                "https://github.com/lindevs/yolov8-face/releases/latest/download/yolov8n-face-lindevs.pt",
                model_path,
            )
        model = YOLO(model=model_path)
        total_images = image_np.shape[0]
        # Pre-allocate a zero-filled output so frames without exactly one
        # detection stay black instead of holding uninitialized memory.
        out_images = np.zeros(shape=(total_images, 512, 512, 3))
        print("shape", image_np.shape)
        for idx, image_item in enumerate(image_np):
            results = model.predict(
                image_item, imgsz=640, conf=0.75, iou=0.7, device=device, verbose=False
            )
            n = 512
            r = results[0]
            if len(r.boxes.data.cpu().numpy()) == 1:
                # boxes.data rows are (x1, y1, x2, y2, conf, cls) in pixel
                # coordinates; the unpack order below swaps the names, so from
                # here on x1/x2 index rows and y1/y2 index columns of orig_img.
                y1, x1, y2, x2, p, cls = r.boxes.data.cpu().numpy()[0]
                # Expand the box to a square centered on the face.
                face_size = int(max(y2 - y1, x2 - x1))
                center = (x1 + x2) // 2, (y1 + y2) // 2
                x1, x2, y1, y2 = (
                    center[0] - face_size // 2,
                    center[0] + face_size // 2,
                    center[1] - face_size // 2,
                    center[1] + face_size // 2,
                )
                # Copy the square region pixel by pixel onto a dark canvas so
                # parts of the box that fall outside the frame stay padded.
                template = np.ndarray(shape=(face_size, face_size, 3))
                template.fill(20)
                for a, a1 in zip(list(range(int(x1), int(x2))), list(range(face_size))):
                    for b, b1 in zip(
                        list(range(int(y1), int(y2))), list(range(face_size))
                    ):
                        if (a >= 0 and a < r.orig_img.shape[0]) and (
                            b >= 0 and b < r.orig_img.shape[1]
                        ):
                            template[a1][b1] = r.orig_img[a][b]
                print(int(x1), int(x2), int(y1), int(y2))
                img = cv2.resize(template, (n, n))
                out_images[idx] = img
        # Back to a [0, 1] float tensor in ComfyUI's (frames, H, W, C) layout.
        cropped_face = np.array(out_images).astype(np.float32) / 255.0
        cropped_face = torch.from_numpy(cropped_face)
        return (cropped_face,)
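

# Illustrative numpy-slicing equivalent of the pixel-copy loop in FaceExtract.crop
# (a hypothetical helper, not called by the node): crop a square centred on the
# detection box, padding out-of-frame areas with a fill value.
def _padded_square_crop_example(img, row1, col1, row2, col2, fill=20):
    size = int(max(row2 - row1, col2 - col1))
    c_row, c_col = int((row1 + row2) // 2), int((col1 + col2) // 2)
    top, left = c_row - size // 2, c_col - size // 2
    canvas = np.full((size, size, 3), fill, dtype=img.dtype)
    # Clamp the source window to the image, then paste it at the matching offset.
    src_t, src_l = max(top, 0), max(left, 0)
    src_b, src_r = min(top + size, img.shape[0]), min(left + size, img.shape[1])
    canvas[src_t - top : src_b - top, src_l - left : src_r - left] = img[
        src_t:src_b, src_l:src_r
    ]
    return canvas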
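

# Registration sketch: ComfyUI discovers custom nodes through a NODE_CLASS_MAPPINGS
# dict, normally defined in the package __init__ rather than in this module. The
# mapping below only illustrates that shape; the keys and display names are assumptions.
EXAMPLE_NODE_CLASS_MAPPINGS = {
    "FaceDetect": FaceDetect,
    "FaceExtract": FaceExtract,
}
EXAMPLE_NODE_DISPLAY_NAME_MAPPINGS = {
    "FaceDetect": "Face Occlusion Detect",
    "FaceExtract": "Face Extract (YOLO)",
}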