# ComfyUI-CustomNode/utils/face_occu_detect.py

import glob
import os.path
from os.path import isdir
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms import Resize
from ..utils.modal_utils import load_weight
from ..utils.model_module import Model
# Constants: ImageNet normalization statistics and the model input size
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
SIZE = [224, 224]
CLASSES = {0: "non-occluded",
           1: "occluded"}

def face_occu_detect_single(opt):
    """Classify a single image (or every .png in a directory) as occluded / non-occluded."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(opt.model, 2, False).to(device)
    model = load_weight(model, opt.weight)
    model.eval()
    # Preprocessing: resize, convert to tensor, normalize with ImageNet statistics
    transform = transforms.Compose([
        transforms.Resize(SIZE),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ])
    # Input is either a directory of .png files or a single image path
    if isdir(opt.image):
        imgs = glob.glob(os.path.join(opt.image, "*.png"))
        for path in imgs:
            img = Image.open(path).convert("RGB")
            img = transform(img).to(device)
            with torch.no_grad():
                output = model(img.unsqueeze(0))
            output = torch.softmax(output, 1)
            prob, pred = torch.max(output, 1)
            print("Image {} is {} - {:.2f} %".format(
                path, CLASSES[pred.item()], prob.item() * 100
            ))
    else:
        img = Image.open(opt.image).convert("RGB")
        img = transform(img).to(device)
        with torch.no_grad():
            output = model(img.unsqueeze(0))
        output = torch.softmax(output, 1)
        prob, pred = torch.max(output, 1)
        print("Image {} is {} - {:.2f} %".format(
            opt.image, CLASSES[pred.item()], prob.item() * 100
        ))

def face_occu_detect(image: torch.Tensor, length=10, thres=95, model_name="convnext_tiny"):
    """Classify a batch of frames and return the contiguous non-occluded segments.

    `image` is expected in ComfyUI's IMAGE layout: [B, H, W, C], float in [0, 1].
    `thres` is a confidence threshold in percent; `length` is the window size in frames.
    """
    weight_dic = {
        "convnext_tiny": "best_convnext_tiny.pth",
        "convnext_base": "best_convnext_base.pth"
    }
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(model_name, 2, False).to(device)
    weight = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model", weight_dic[model_name])
    model = load_weight(model, weight)
    model.eval()
    # [B, H, W, C] -> [B, C, H, W] for the model, then resize to the expected input size
    image = image.permute(0, 3, 1, 2)
    torch_resize = Resize(SIZE)
    with torch.no_grad():
        output = model(torch_resize(image).to(device))
    output = torch.softmax(output, 1)
    prob, pred = torch.max(output, 1)
    probs = [round(i.item() * 100, 2) for i in prob]
    preds = [CLASSES[i.item()] for i in pred]
    print("Image is {} - {} %".format(preds, probs))
    # Indices of frames confidently classified as non-occluded
    nums = []
    for idx, a, b in zip(range(len(probs)), preds, probs):
        if a == "non-occluded" and b > thres:
            nums.append(idx)
    # Merge consecutive indices into runs of at least `length` frames
    start = -1
    end = -1
    period = []
    for idx in range(len(nums)):
        if idx == 0:
            start = nums[idx]
            end = nums[idx]
        else:
            if nums[idx] == end + 1:
                end = nums[idx]
            else:
                if end - start + 1 >= length:
                    period.append([start, end])
                start = nums[idx]
                end = nums[idx]
    if nums and end - start + 1 >= length:
        period.append([start, end])
    # Split each run into fixed-size windows of exactly `length` frames
    temp_period = []
    for i in period:
        a = i[0]
        while a + length - 1 <= i[1]:
            temp_period.append([a, a + length - 1])
            a = a + length
    return (
        image.permute(0, 2, 3, 1),
        image.permute(0, 2, 3, 1)[nums, :, :, :],
        str(preds), str(probs), str(nums), temp_period
    )

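# Usage sketch (assumptions, not part of the original API): a ComfyUI node would pass its
# IMAGE tensor ([B, H, W, C], float32 in [0, 1]) straight in; `image_batch` below is a
# hypothetical name for that tensor:
#
#     frames, kept_frames, preds, probs, kept_idx, windows = face_occu_detect(
#         image_batch, length=10, thres=95, model_name="convnext_tiny")
#
# `windows` is a list of [start, end] index pairs, each spanning exactly `length`
# consecutive frames classified as non-occluded with confidence above `thres`.
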
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model", type=str, help="Model name")
    parser.add_argument("--weight", type=str, help="Weight path (.pth)")
    parser.add_argument("--image", type=str, help="Image path")
    args = parser.parse_args()
    face_occu_detect_single(args)
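
# Example CLI invocation (hypothetical paths; the relative imports above mean this
# module has to be run from within the package, e.g. via `python -m`):
#   python -m <package>.utils.face_occu_detect --model convnext_tiny \
#       --weight ./model/best_convnext_tiny.pth --image ./faces/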