# ComfyUI-CustomNode/test_single_image.py
#
# 113 lines
# 3.5 KiB
# Python

import glob
import os.path
from os.path import isdir
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms import Resize
from .utils import load_weight
from .model import Model
# Constants shared by the inference helpers in this module.

# Per-channel mean and standard deviation handed to transforms.Normalize.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Target size handed to transforms.Resize.
SIZE = [224, 224]

# Maps a predicted class index to the label that gets reported.
CLASSES = {
    0: "non-occluded",
    1: "occluded",
}
def test_image(opt):
    """Classify one image, or every ``*.png`` file in a directory, and print the result.

    Args:
        opt: Namespace-like object. Three attributes are read:
            ``model`` (first argument to ``Model``), ``weight`` (passed to
            ``load_weight``) and ``image`` (path to an image file, or to a
            directory whose ``*.png`` files are all classified).

    For every image one line is printed containing the path, the label
    looked up in ``CLASSES`` and the softmax probability as a percentage.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 2 matches the number of entries in CLASSES; presumably the class count
    # -- confirm against Model's signature.
    model = Model(opt.model, 2, False).to(device)
    model = load_weight(model, opt.weight)
    model.eval()

    # Resize, convert to a tensor, then normalise with MEAN / STD.
    transform = transforms.Compose([
        transforms.Resize(SIZE),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    # A directory expands to its PNG files; anything else is treated as a
    # single image path. Both cases then share one inference loop.
    if isdir(opt.image):
        paths = glob.glob(os.path.join(opt.image, "*.png"))
    else:
        paths = [opt.image]

    # Pure inference: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        for path in paths:
            # convert() returns a new image, so the file can be closed
            # as soon as the conversion is done.
            with Image.open(path) as handle:
                img = handle.convert("RGB")
            batch = transform(img).to(device).unsqueeze(0)
            output = torch.softmax(model(batch), 1)
            prob, pred = torch.max(output, 1)
            print("Image {} is {} - {:.2f} %".format(
                path, CLASSES[pred.item()], prob.item() * 100
            ))
def _consecutive_runs(indices, min_len):
    """Return ``[start, end]`` pairs for runs of consecutive integers in *indices* at least *min_len* long."""
    runs = []
    if not indices:
        # No qualifying frames: there is no run to report, whatever min_len is.
        return runs
    start = end = indices[0]
    for idx in indices[1:]:
        if idx == end + 1:
            end = idx
            continue
        # The current run is broken; keep it only if it is long enough.
        if end - start + 1 >= min_len:
            runs.append([start, end])
        start = end = idx
    # Flush the run that was still open when the input ended.
    if end - start + 1 >= min_len:
        runs.append([start, end])
    return runs


def test_node(image: torch.Tensor, length=10, thres=95, model_name="convnext_tiny"):
    """Classify every frame of a batch and pick out the confident "non-occluded" ones.

    Args:
        image: 4-D tensor. It is permuted with ``(0, 3, 1, 2)`` before being
            fed to the model, so it is assumed to be laid out as
            (batch, height, width, channels) -- confirm with the caller.
        length: Minimum number of consecutive selected frame indices for a
            run to be reported in the returned periods.
        thres: Threshold, in percent, that a frame's probability must exceed
            for the frame to be selected.
        model_name: Key into the weight table below; also passed to ``Model``.
            An unknown name raises ``KeyError``.

    Returns:
        A 6-tuple: the input batch, the selected frames, ``str`` of the
        per-frame labels, ``str`` of the per-frame probabilities (percent,
        rounded to 2 decimals), ``str`` of the selected indices, and a list
        of ``[start, end]`` index pairs for runs of at least ``length``
        consecutive selected frames.
    """
    weight_dic = {
        "convnext_tiny": "best_convnext_tiny.pth",
        "convnext_base": "best_convnext_base.pth",
    }
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(model_name, 2, False).to(device)
    # Weight files are looked up next to this source file.
    weight = os.path.join(os.path.dirname(os.path.abspath(__file__)), weight_dic[model_name])
    model = load_weight(model, weight)
    model.eval()

    # Channel-first view for the model; `image` itself keeps its layout so it
    # can be returned and indexed directly below.
    channels_first = image.permute(0, 3, 1, 2)
    torch_resize = Resize([224, 224])
    # NOTE(review): test_image() also normalises with MEAN / STD before the
    # model, this path only resizes -- confirm whether that is intentional.
    # Pure inference: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        output = model(torch_resize(channels_first).to(device))
        output = torch.softmax(output, 1)
    prob, pred = torch.max(output, 1)
    probs = [round(p * 100, 2) for p in prob.tolist()]
    preds = [CLASSES[i] for i in pred.tolist()]
    print("Image is {} - {} %".format(
        preds, probs
    ))

    # Indices of frames labelled "non-occluded" with probability above thres.
    nums = [
        idx
        for idx, (label, pct) in enumerate(zip(preds, probs))
        if label == "non-occluded" and pct > thres
    ]
    period = _consecutive_runs(nums, length)
    return (image, image[nums, :, :, :], str(preds), str(probs), str(nums), period)
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line entry point: parse the options and hand them to test_image().
    parser = ArgumentParser()
    parser.add_argument("--model", type=str, help="Model name")
    parser.add_argument("--weight", type=str, help="Weight path (.pth)")
    # test_image() passes this value straight to isdir(), which raises
    # TypeError on None, so a missing --image is rejected here with a usage
    # message instead.
    parser.add_argument("--image", type=str, required=True,
                        help="Image path, or a directory of .png images")
    args = parser.parse_args()
    test_image(args)