import glob
import os.path
from os.path import isdir

import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import Resize

from ..utils.modal_utils import load_weight
from ..utils.model_module import Model

# CONSTANT
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
SIZE = [224, 224]
CLASSES = {0: "non-occluded", 1: "occluded"}


def _classify_file(model, transform, path, device):
    """Run *model* on the image file at *path*.

    Returns:
        ``(confidence, label)`` where ``confidence`` is the softmax probability
        (0..1) of the predicted class and ``label`` is its ``CLASSES`` name.
    """
    # The context manager closes the file handle; convert() returns a new
    # image, so the pixels stay usable after the file is closed.
    with Image.open(path) as pil_img:
        img = transform(pil_img.convert("RGB")).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = torch.softmax(model(img.unsqueeze(0)), 1)
    prob, pred = torch.max(output, 1)
    return prob.item(), CLASSES[pred.item()]


def face_occu_detect_single(opt):
    """Classify one image, or every ``*.png`` in a directory, and print results.

    Args:
        opt: namespace with ``model`` (architecture name passed to ``Model``),
            ``weight`` (checkpoint path passed to ``load_weight``) and
            ``image`` (an image file, or a directory searched for ``*.png``).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): meaning of Model's third positional argument (False) is
    # not visible here — confirm in utils.model_module.
    model = Model(opt.model, len(CLASSES), False).to(device)
    model = load_weight(model, opt.weight)
    model.eval()

    # transform data: same preprocessing as used for every image below
    transform = transforms.Compose([
        transforms.Resize(SIZE),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    if isdir(opt.image):
        paths = glob.glob(os.path.join(opt.image, "*.png"))
    else:
        paths = [opt.image]

    for path in paths:
        prob, label = _classify_file(model, transform, path, device)
        print("Image {} is {} - {:.2f} %".format(path, label, prob * 100))


def _consecutive_runs(indices):
    """Group ascending, duplicate-free *indices* into maximal runs of
    consecutive integers, returned as inclusive ``[start, end]`` pairs."""
    runs = []
    for idx in indices:
        if runs and idx == runs[-1][1] + 1:
            runs[-1][1] = idx
        else:
            runs.append([idx, idx])
    return runs


def _fixed_windows(indices, length):
    """Cut each consecutive run in *indices* into back-to-back inclusive
    ``[start, end]`` windows of exactly *length* items.

    Runs shorter than *length* produce nothing; a trailing remainder shorter
    than *length* at the end of a run is dropped.
    """
    windows = []
    for start, end in _consecutive_runs(indices):
        # Last valid window start is end - length + 1 (inclusive).
        for a in range(start, end - length + 2, length):
            windows.append([a, a + length - 1])
    return windows


def face_occu_detect(image: torch.Tensor, length=10, thres=95, model_name="convnext_tiny"):
    """Classify a batch of images and find runs of confidently non-occluded ones.

    Args:
        image: batch tensor; the ``permute(0, 3, 1, 2)`` below implies a
            channels-last ``(N, H, W, C)`` layout. NOTE(review): value range
            (presumably float in [0, 1]) is not established here — confirm.
        length: size of each returned window of consecutive indices; must be
            >= 1.
        thres: confidence threshold in percent (compared against the
            softmax probability * 100).
        model_name: key of the bundled weights (``"convnext_tiny"`` or
            ``"convnext_base"``); an unknown name raises ``KeyError``.

    Returns:
        Tuple of ``(all images (N, H, W, C), images at qualifying indices,
        str(labels), str(confidences in %), str(qualifying indices),
        list of [start, end] inclusive windows)``.

    Raises:
        ValueError: if ``length < 1`` (previously this looped forever).
    """
    if length < 1:
        raise ValueError("length must be >= 1, got {}".format(length))

    weight_dic = {
        "convnext_tiny": "best_convnext_tiny.pth",
        "convnext_base": "best_convnext_base.pth",
    }
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(model_name, len(CLASSES), False).to(device)
    weight = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model", weight_dic[model_name])
    model = load_weight(model, weight)
    model.eval()

    # channels-last -> channels-first for the model.
    image = image.permute(0, 3, 1, 2)
    # NOTE(review): unlike face_occu_detect_single, no Normalize(MEAN, STD)
    # is applied here — confirm this matches how the weights were trained.
    with torch.no_grad():
        output = model(Resize(SIZE)(image).to(device))
        output = torch.softmax(output, 1)
    prob, pred = torch.max(output, 1)

    probs = [round(p * 100, 2) for p in prob.tolist()]
    preds = [CLASSES[i] for i in pred.tolist()]
    print("Image is {} - {} %".format(
        preds, probs
    ))

    # Indices (ascending) classified non-occluded above the threshold.
    nums = [
        idx
        for idx, (label, confidence) in enumerate(zip(preds, probs))
        if label == "non-occluded" and confidence > thres
    ]
    temp_period = _fixed_windows(nums, length)

    frames = image.permute(0, 2, 3, 1)
    return (
        frames,
        frames[nums, :, :, :],
        str(preds),
        str(probs),
        str(nums),
        temp_period)


if __name__ == "__main__":
    # The relative imports above mean this must be launched as a module
    # (python -m <package>.<module>), not as a plain script path.
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model", type=str, help="Model name")
    parser.add_argument("--weight", type=str, help="Weight path (.pth)")
    parser.add_argument("--image", type=str, required=True, help="Image path")
    args = parser.parse_args()
    face_occu_detect_single(args)