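"""Inference helpers for a binary occlusion classifier.

`test_image` classifies a single image or a directory of PNGs from the
command line; `test_node` scores a batch of frames and returns the runs of
consecutive, confidently non-occluded frames.
"""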
import glob
import os.path
from os.path import isdir

from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms import Resize

from .utils import load_weight
from .model import Model

# Constants: ImageNet normalization statistics, model input size, class labels.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
SIZE = [224, 224]
CLASSES = {0: "non-occluded",
           1: "occluded"}

def test_image(opt):
    """Classify a single image, or every PNG in a directory, and print the result."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(opt.model, 2, False).to(device)
    model = load_weight(model, opt.weight)
    model.eval()

    # Preprocessing: resize to the model input size, then normalize.
    transform = transforms.Compose([
        transforms.Resize(SIZE),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ])

    # Collect the inputs: every *.png in a directory, or the single image path.
    if isdir(opt.image):
        paths = glob.glob(os.path.join(opt.image, "*.png"))
    else:
        paths = [opt.image]

    with torch.no_grad():
        for path in paths:
            img = Image.open(path).convert("RGB")
            img = transform(img).to(device)
            output = model(img.unsqueeze(0))
            output = torch.softmax(output, 1)
            prob, pred = torch.max(output, 1)

            print("Image {} is {} - {:.2f} %".format(
                path, CLASSES[pred.item()], prob.item() * 100
            ))

def test_node(image: torch.Tensor, length=10, thres=95, model_name="convnext_tiny"):
    """Score a batch of frames and locate runs of confidently non-occluded frames.

    `image` is expected in NHWC layout. Returns the original batch, the
    non-occluded frames, the labels/probabilities/indices as strings, and the
    [start, end] index runs spanning at least `length` consecutive frames.
    """
    weight_dic = {
        "convnext_tiny": "best_convnext_tiny.pth",
        "convnext_base": "best_convnext_base.pth"
    }
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(model_name, 2, False).to(device)
    weight = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model", weight_dic[model_name])
    model = load_weight(model, weight)
    model.eval()

    # NHWC -> NCHW for the model. Note that, unlike test_image, only a resize
    # is applied here; no MEAN/STD normalization.
    image = image.permute(0, 3, 1, 2)
    torch_resize = Resize(SIZE)
    with torch.no_grad():
        output = model(torch_resize(image).to(device))
    output = torch.softmax(output, 1)
    prob, pred = torch.max(output, 1)
    probs = [round(p.item() * 100, 2) for p in prob]
    preds = [CLASSES[p.item()] for p in pred]
    print("Image is {} - {} %".format(preds, probs))

    # Indices of frames classified as non-occluded above the confidence threshold.
    nums = []
    for idx, (label, score) in enumerate(zip(preds, probs)):
        if label == "non-occluded" and score > thres:
            nums.append(idx)

    # Collect maximal runs of consecutive indices at least `length` frames long.
    start = -1
    end = -1
    period = []
    for idx in range(len(nums)):
        if idx == 0:
            start = nums[idx]
            end = nums[idx]
        elif nums[idx] == end + 1:
            end = nums[idx]
        else:
            if end - start + 1 >= length:
                period.append([start, end])
            start = nums[idx]
            end = nums[idx]
    if nums and end - start + 1 >= length:  # guard: skip when no frame qualified
        period.append([start, end])

    return (image.permute(0, 2, 3, 1), image.permute(0, 2, 3, 1)[nums, :, :, :],
            str(preds), str(probs), str(nums), period)

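# A minimal usage sketch of test_node (hypothetical shapes): score 32 RGB
# frames held as a float NHWC batch in [0, 1], keeping runs of at least 10
# confidently clear frames.
#
#   frames = torch.rand(32, 256, 256, 3)
#   _, kept, preds, probs, nums, period = test_node(frames, length=10, thres=95)
#   print(period)   # e.g. [[0, 14], [20, 31]]
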
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model", type=str, help="Model name, e.g. convnext_tiny")
    parser.add_argument("--weight", type=str, help="Weight path (.pth)")
    parser.add_argument("--image", type=str, help="Path to an image file or a directory of PNGs")
    args = parser.parse_args()

    test_image(args)
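# Note: the relative imports (`from .utils import ...`) mean this script must
# be run as a module from its parent package, e.g. (package name assumed):
#
#   python -m occlusion.test --model convnext_tiny \
#       --weight best_convnext_tiny.pth --image ./frames/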