import glob
import os.path
from os.path import isdir

from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms import Resize

from .utils import load_weight
from .model import Model

# CONSTANT
# Per-channel mean / std handed to transforms.Normalize in test_image.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# Spatial size every image is resized to before it is fed to the model.
SIZE = [224, 224]
# Maps the predicted class index to its human-readable label.
CLASSES = {0: "non-occluded", 1: "occluded"}


def _select_device():
    """Return the CUDA device when available, otherwise the CPU device."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _classify_file(model, transform, device, path):
    """Classify the image file at *path* and print its label and confidence."""
    # The context manager closes the file handle right away; convert()
    # returns an independent, fully loaded RGB copy.
    with Image.open(path) as raw:
        img = raw.convert("RGB")
    img = transform(img).to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(img.unsqueeze(0))
    output = torch.softmax(output, 1)
    prob, pred = torch.max(output, 1)
    print("Image {} is {} - {:.2f} %".format(
        path,
        CLASSES[pred.item()],
        prob.item() * 100
    ))


def _fixed_length_windows(indices, length):
    """Cut runs of consecutive indices into back-to-back windows.

    *indices* must be in ascending order and *length* must be >= 1.
    Each run of consecutive integers yields as many non-overlapping
    ``[start, end]`` windows (inclusive, exactly *length* long) as fit,
    counted from the start of the run; leftover indices at the end of a
    run are dropped.  An empty *indices* yields an empty list.
    """
    windows = []
    window_start = None
    previous = None
    for idx in indices:
        # A gap (or the very first index) starts a fresh run.
        if previous is None or idx != previous + 1:
            window_start = idx
        previous = idx
        if idx - window_start + 1 == length:
            windows.append([window_start, idx])
            # The next window of the same run begins right after this one.
            window_start = idx + 1
    return windows


def test_image(opt):
    """Classify a single image, or every ``*.png`` in a directory, and print results.

    Args:
        opt: Object exposing ``model`` (model name), ``weight`` (path to the
            weight file) and ``image`` (an image path or a directory).

    For each image one line is printed with the path, the predicted label
    from ``CLASSES`` and the softmax confidence in percent.
    """
    device = _select_device()
    # NOTE(review): the trailing ``2, False`` arguments are presumably the
    # class count and a "pretrained" flag - confirm against Model.
    model = Model(opt.model, 2, False).to(device)
    model = load_weight(model, opt.weight)
    model.eval()

    # transform data
    transform = transforms.Compose([
        transforms.Resize(SIZE),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ])

    # Image
    if isdir(opt.image):
        paths = glob.glob(os.path.join(opt.image, "*.png"))
    else:
        paths = [opt.image]
    for path in paths:
        _classify_file(model, transform, device, path)


def test_node(image: torch.Tensor, length=10, thres=95, model_name="convnext_tiny"):
    """Classify a batch of frames and find stretches of confident non-occluded frames.

    Args:
        image: 4-D tensor; it is permuted with ``(0, 3, 1, 2)`` before
            inference, so it is assumed to be channels-last
            (batch, height, width, channels) - TODO confirm with callers.
        length: Window size in frames; must be >= 1.
        thres: Confidence threshold in percent; a frame qualifies when it is
            predicted "non-occluded" with a confidence strictly above it.
        model_name: Key selecting the bundled weight file
            ("convnext_tiny" or "convnext_base").

    Returns:
        A 6-tuple ``(frames, selected, preds, probs, nums, windows)``:
        the input frames in their original layout, the qualifying frames,
        the ``str()`` of the label list, of the confidence list (percent,
        rounded to 2 decimals) and of the qualifying indices, and a list of
        inclusive ``[start, end]`` index windows, each exactly *length*
        consecutive qualifying frames long.

    Raises:
        ValueError: If *length* is smaller than 1.
        KeyError: If *model_name* has no bundled weight file.
    """
    if length < 1:
        # A non-positive window can never be filled; the previous code
        # looped forever on it.
        raise ValueError("length must be >= 1, got {}".format(length))

    weight_dic = {
        "convnext_tiny": "best_convnext_tiny.pth",
        "convnext_base": "best_convnext_base.pth"
    }
    device = _select_device()
    model = Model(model_name, 2, False).to(device)
    # Weights are looked up in the "model" directory next to this file.
    weight = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "model",
        weight_dic[model_name]
    )
    model = load_weight(model, weight)
    model.eval()

    image = image.permute(0, 3, 1, 2)
    torch_resize = Resize(SIZE)
    # NOTE(review): unlike test_image, no Normalize(MEAN, STD) is applied
    # here - confirm whether callers pass already-normalised frames.
    with torch.no_grad():
        output = model(torch_resize(image).to(device))
    output = torch.softmax(output, 1)
    prob, pred = torch.max(output, 1)
    probs = [round(i.item() * 100, 2) for i in prob]
    preds = [CLASSES[i.item()] for i in pred]
    print("Image is {} - {} %".format(
        preds,
        probs
    ))

    # Indices of frames confidently classified as non-occluded.
    nums = [
        idx
        for idx, (label, confidence) in enumerate(zip(preds, probs))
        if label == "non-occluded" and confidence > thres
    ]
    temp_period = _fixed_length_windows(nums, length)

    # Hand the frames back in the layout they arrived in (not resized).
    frames = image.permute(0, 2, 3, 1)
    return (frames,
            frames[nums, :, :, :],
            str(preds),
            str(probs),
            str(nums),
            temp_period)


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model", type=str, help="Model name")
    parser.add_argument("--weight", type=str, help="Weight path (.pth)")
    parser.add_argument("--image", type=str, help="Image path")
    args = parser.parse_args()

    test_image(args)