# ComfyUI-CustomNode/utils/face_occu_detect.py

import glob
import os.path
from os.path import isdir
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms import Resize
from ..utils.modal_utils import load_weight
from ..utils.model_module import Model
# CONSTANT
# ImageNet normalization statistics applied in the single-image CLI path.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# Input resolution the classifier expects (height, width).
SIZE = [224, 224]
# Mapping from model output index to human-readable label.
CLASSES = {0: "non-occluded",
1: "occluded"}
def face_occu_detect_single(opt):
    """Classify face occlusion for a single image or a directory of images.

    Loads the model named by ``opt.model`` with weights from ``opt.weight``,
    then prints the predicted class and confidence for ``opt.image`` (a file
    path) or for every PNG inside it (a directory path).

    Args:
        opt: Namespace-like object with string attributes ``model``,
            ``weight`` and ``image`` (see the argparse setup in ``__main__``).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(opt.model, 2, False).to(device)
    model = load_weight(model, opt.weight)
    model.eval()
    # Preprocessing: resize to the model's input size, then normalize with
    # ImageNet statistics (presumably matching training — TODO confirm).
    transform = transforms.Compose([
        transforms.Resize(SIZE),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ])
    # Normalize both input styles (directory vs. single file) into one list
    # so the inference loop below is written only once instead of duplicated.
    if isdir(opt.image):
        paths = glob.glob(os.path.join(opt.image, "*.png"))
    else:
        paths = [opt.image]
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        for path in paths:
            img = Image.open(path).convert("RGB")
            img = transform(img).to(device)
            output = model(img.unsqueeze(0))
            output = torch.softmax(output, 1)
            prob, pred = torch.max(output, 1)
            print("Image {} is {} - {:.2f} %".format(
                path, CLASSES[pred.item()], prob.item() * 100
            ))
def face_occu_detect(image: torch.Tensor, length=10, thres=95, model_name="convnext_tiny"):
    """Classify each frame in a batch as occluded / non-occluded and find
    contiguous runs of confidently non-occluded frames.

    Args:
        image: Batch of frames shaped (N, H, W, C) — channels-last, as the
            permutes below require; presumably values in [0, 1] per the
            ComfyUI IMAGE convention (TODO confirm against caller).
        length: Window length in frames; only runs at least this long are
            kept, and each is chopped into windows of exactly this length.
        thres: Confidence threshold in percent; a frame counts only when it
            is predicted "non-occluded" with probability strictly above this.
        model_name: Key into the bundled weight-file table
            ("convnext_tiny" or "convnext_base").

    Returns:
        Tuple of (original batch back in NHWC, the frames whose indices
        passed the filter, str of per-frame labels, str of per-frame
        confidences, str of passing indices, list of [start, end] windows
        each exactly ``length`` frames long).

    Raises:
        Exception: when the expected weight file is missing from ``model/``.
    """
    weight_dic = {
        "convnext_tiny": "best_convnext_tiny.pth",
        "convnext_base": "best_convnext_base.pth"
    }
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(model_name, 2, False).to(device)
    # Weights live in <package root>/model/ next to this utils/ directory.
    weight = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "model", weight_dic[model_name])
    if not os.path.exists(weight):
        # Message kept verbatim: it points users at the weight download page.
        raise Exception("请前往https://github.com/LamKser/face-occlusion-classification下载所选权重文件到model文件夹")
    model = load_weight(model, weight)
    model.eval()
    # NHWC -> NCHW for the conv net.
    image = image.permute(0, 3, 1, 2)
    # NOTE(review): unlike face_occu_detect_single, no MEAN/STD normalization
    # is applied here — confirm this matches how the weights were trained.
    torch_resize = Resize([224, 224])
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        output = model(torch_resize(image).to(device))
    output = torch.softmax(output, 1)
    prob, pred = torch.max(output, 1)
    probs = [round(p.item() * 100, 2) for p in prob]
    preds = [CLASSES[c.item()] for c in pred]
    print("Image is {} - {} %".format(
        preds, probs
    ))
    # Indices of frames confidently classified as non-occluded.
    nums = [i for i, (label, p) in enumerate(zip(preds, probs))
            if label == "non-occluded" and p > thres]
    period = _contiguous_runs(nums, length)
    temp_period = _fixed_windows(period, length)
    return (
        image.permute(0, 2, 3, 1), image.permute(0, 2, 3, 1)[nums, :, :, :], str(preds), str(probs), str(nums), temp_period)


def _contiguous_runs(nums, length):
    """Return [start, end] for every maximal run of consecutive integers in
    ``nums`` (assumed sorted ascending) spanning at least ``length`` items."""
    period = []
    # Guard: the original code emitted a bogus [-1, -1] run when nums was
    # empty and length <= 1 (the -1/-1 sentinels passed the length check).
    if not nums:
        return period
    start = end = nums[0]
    for n in nums[1:]:
        if n == end + 1:
            end = n
        else:
            if end - start + 1 >= length:
                period.append([start, end])
            start = end = n
    if end - start + 1 >= length:
        period.append([start, end])
    return period


def _fixed_windows(period, length):
    """Split each [start, end] run into back-to-back windows of exactly
    ``length`` frames, discarding any shorter trailing remainder."""
    windows = []
    for start, end in period:
        a = start
        while a + length - 1 <= end:
            windows.append([a, a + length - 1])
            a += length
    return windows
if __name__ == "__main__":
    # CLI entry point for the single-image / directory classifier.
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument("--model", type=str, help="Model name")
    cli.add_argument("--weight", type=str, help="Weight path (.pth)")
    cli.add_argument("--image", type=str, help="Image path")
    face_occu_detect_single(cli.parse_args())