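"""Face occlusion classification inference.

``face_occu_detect_single`` classifies a single image (or every PNG in a
directory) from the command line; ``face_occu_detect`` classifies a batch of
frames given as one tensor and groups consecutive non-occluded frames into
fixed-length periods.
"""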
import glob
import os.path
from os.path import isdir

from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms import Resize

from ..utils.modal_utils import load_weight
from ..utils.model_module import Model

# Constants: ImageNet normalisation statistics, model input size, class labels
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
SIZE = [224, 224]
CLASSES = {0: "non-occluded",
           1: "occluded"}


def face_occu_detect_single(opt):
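    """Classify ``opt.image`` (a single image file or a directory of PNGs)
    with the model named ``opt.model`` and the checkpoint at ``opt.weight``,
    printing the predicted class and its confidence for each image.
    """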
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(opt.model, 2, False).to(device)
    model = load_weight(model, opt.weight)
    model.eval()

    # Preprocessing: resize, convert to tensor, normalise with ImageNet statistics
    transform = transforms.Compose([
        transforms.Resize(SIZE),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ])
    # Classify a single image file, or every PNG inside a directory
    if isdir(opt.image):
        paths = glob.glob(os.path.join(opt.image, "*.png"))
    else:
        paths = [opt.image]

    for path in paths:
        img = Image.open(path).convert("RGB")
        img = transform(img).to(device)
        with torch.no_grad():
            output = model(img.unsqueeze(0))
        output = torch.softmax(output, 1)
        prob, pred = torch.max(output, 1)

        print("Image {} is {} - {:.2f} %".format(
            path, CLASSES[pred.item()], prob.item() * 100
        ))


def face_occu_detect(image: torch.Tensor, length=10, thres=95, model_name="convnext_tiny"):
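    """Classify a batch of frames and group non-occluded frames into periods.

    Args:
        image: frames as one tensor in (N, H, W, C) order; assumed to be a
            float tensor scaled the way the model was trained (note that,
            unlike ``face_occu_detect_single``, no mean/std normalisation is
            applied here).
        length: exact window length, in frames, of each returned period.
        thres: confidence threshold in percent for keeping a frame.
        model_name: "convnext_tiny" or "convnext_base".
    """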
    weight_dic = {
        "convnext_tiny": "best_convnext_tiny.pth",
        "convnext_base": "best_convnext_base.pth"
    }
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(model_name, 2, False).to(device)
    weight = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                          "model", weight_dic[model_name])
    if not os.path.exists(weight):
        raise FileNotFoundError(
            "Please download the selected weight file from "
            "https://github.com/LamKser/face-occlusion-classification "
            "into the 'model' folder!"
        )
    model = load_weight(model, weight)
    model.eval()

    # NHWC -> NCHW, resize to the model input size, then run one forward pass
    image = image.permute(0, 3, 1, 2)
    torch_resize = Resize([224, 224])
    with torch.no_grad():
        output = model(torch_resize(image).to(device))
    output = torch.softmax(output, 1)
    prob, pred = torch.max(output, 1)
    probs, preds = [round(i.item() * 100, 2) for i in prob], [CLASSES[i.item()] for i in pred]
    print("Image is {} - {} %".format(
        preds, probs
    ))
    # Indices of frames confidently predicted as non-occluded
    nums = []
    for idx, (a, b) in enumerate(zip(preds, probs)):
        if a == "non-occluded" and b > thres:
            nums.append(idx)
    # Merge consecutive indices into runs [start, end] and keep only runs
    # spanning at least `length` frames
    start = -1
    end = -1
    period = []
    for idx in range(len(nums)):
        if idx == 0:
            start = nums[idx]
            end = nums[idx]
        elif nums[idx] == end + 1:
            end = nums[idx]
        else:
            if end - start + 1 >= length:
                period.append([start, end])
            start = nums[idx]
            end = nums[idx]
    if nums and end - start + 1 >= length:
        period.append([start, end])
    # Split each run into as many non-overlapping windows of exactly
    # `length` frames as fit inside it
    temp_period = []
    for i in period:
        a = i[0]
        while a + length - 1 <= i[1]:
            temp_period.append([a, a + length - 1])
            a = a + length

    # Restore (N, H, W, C) order before returning
    image = image.permute(0, 2, 3, 1)
    return (
        image, image[nums, :, :, :], str(preds), str(probs), str(nums), temp_period
    )
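

# Example usage for face_occu_detect (a sketch; `frames` and its shape are
# illustrative, and the chosen checkpoint must already exist in the `model`
# folder):
#
#     frames = torch.rand(30, 224, 224, 3)  # (N, H, W, C), float in [0, 1]
#     all_frames, kept, preds, probs, nums, periods = face_occu_detect(
#         frames, length=10, thres=95, model_name="convnext_tiny"
#     )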


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model", type=str, help="Model name")
    parser.add_argument("--weight", type=str, help="Weight path (.pth)")
    parser.add_argument("--image", type=str, help="Image path or directory of PNGs")
    args = parser.parse_args()

    face_occu_detect_single(args)
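# Note: because of the relative imports (`from ..utils...`), this file must be
# run as a module from the package root, e.g. (module path is illustrative):
#     python -m package_name.this_module --model convnext_tiny \
#         --weight model/best_convnext_tiny.pth --image face.png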