forked from Alibaba-MIIL/ML_Decoder
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinfer.py
More file actions
98 lines (82 loc) · 3.27 KB
/
infer.py
File metadata and controls
98 lines (82 loc) · 3.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import argparse
import time
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data.distributed
from src_files.helper_functions.bn_fusion import fuse_bn_recursively
from src_files.models import create_model
import matplotlib
from src_files.models.tresnet.tresnet import InplacABN_to_ABN
# matplotlib.use('TkAgg')
# matplotlib.use('TkAgg')
# import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# Command-line interface for single-image inference. The parsed namespace is
# passed as a whole to ``create_model`` in ``main``, so argument names here
# must stay in sync with what the model factory reads.
parser = argparse.ArgumentParser(description="PyTorch MS_COCO infer")
parser.add_argument(
    "--num-classes", default=80, type=int, help="number of output classes"
)
parser.add_argument(
    "--model-path",
    type=str,
    default="./models_local/TRresNet_L_448_86.6.pth",
    help="checkpoint file containing 'model' weights and 'idx_to_class'",
)
parser.add_argument(
    "--pic-path",
    type=str,
    default="./pics/000000000885.jpg",
    help="image to run inference on",
)
parser.add_argument(
    "--model-name", type=str, default="tresnet_l", help="backbone architecture name"
)
parser.add_argument(
    "--image-size",
    type=int,
    default=448,
    help="side length the input image is resized to (square)",
)
parser.add_argument(
    "--th",
    type=float,
    default=0.75,
    help="report a class only if its sigmoid score exceeds this value",
)
# int, not float: the value is used as a slice bound, and a float bound
# (e.g. from `--top-k 10` -> 10.0) raises TypeError.
parser.add_argument(
    "--top-k",
    type=int,
    default=20,
    help="consider only the k highest-scoring classes",
)
# ML-Decoder
parser.add_argument(
    "--use-ml-decoder", default=1, type=int, help="1 to use the ML-Decoder head"
)
parser.add_argument(
    "--num-of-groups",
    default=-1,
    type=int,
    help="ML-Decoder group count; -1 means full decoding",
)  # full-decoding
parser.add_argument(
    "--decoder-embedding",
    default=768,
    type=int,
    help="ML-Decoder embedding dimension",
)
parser.add_argument(
    "--zsl", default=0, type=int, help="ZSL mode flag, read by the model factory"
)
def _load_model(args):
    """Create the model, load its checkpoint and prepare it for fp16 GPU inference.

    Args:
        args: parsed command-line arguments. The whole namespace is passed to
            ``create_model``; ``args.model_path`` names the checkpoint file.

    Returns:
        ``(model, classes_list)``: the model in eval mode, on CUDA, cast to
        float16; and a numpy array of class names built from the checkpoint's
        ``"idx_to_class"`` mapping, in that mapping's iteration order.
    """
    print("creating model {}...".format(args.model_name))
    # The model is converted and fused on the CPU below, so build it on the
    # CPU as well instead of moving it to the GPU and straight back.
    model = create_model(args, load_head=True)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(args.model_path, map_location="cpu")
    # strict=False tolerates key mismatches; report them so that loading the
    # wrong checkpoint does not go unnoticed.
    incompatible = model.load_state_dict(state["model"], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            "warning: checkpoint/model mismatch: {} missing keys, "
            "{} unexpected keys".format(
                len(incompatible.missing_keys), len(incompatible.unexpected_keys)
            )
        )
    ########### eliminate BN for faster inference ###########
    model = model.cpu()
    model = InplacABN_to_ABN(model)
    model = fuse_bn_recursively(model)
    model = model.cuda().half().eval()
    #######################################################
    classes_list = np.array(list(state["idx_to_class"].values()))
    return model, classes_list


def _load_image(pic_path, image_size):
    """Load an image as a (1, 3, image_size, image_size) float16 CUDA tensor in [0, 1]."""
    with Image.open(pic_path) as im:
        # convert("RGB") normalises grayscale, palette and RGBA files, which
        # would otherwise not produce exactly three channels.
        im_resize = im.convert("RGB").resize((image_size, image_size))
    np_img = np.array(im_resize, dtype=np.uint8)
    tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float() / 255.0  # HWC to CHW
    return torch.unsqueeze(tensor_img, 0).cuda().half()  # float16 inference


def _select_classes(scores, classes_list, top_k, threshold):
    """Return the names of the ``top_k`` best-scoring classes with score > ``threshold``.

    Names are ordered from highest to lowest score. ``scores`` and
    ``classes_list`` are 1-D numpy arrays aligned by class index.
    """
    idx_top = np.argsort(-scores)[: int(top_k)]
    return classes_list[idx_top][scores[idx_top] > threshold]


def main():
    """Run multi-label inference on one image and print the detected classes.

    All settings come from the module-level ``parser``. A CUDA device is
    required: the model and the input tensor run on the GPU in float16.
    """
    print("Inference code on a single image")
    # parsing args
    args = parser.parse_args()
    # Setup model
    model, classes_list = _load_model(args)
    print("done\n")
    # doing inference
    print("loading image and doing inference...")
    tensor_batch = _load_image(args.pic_path, args.image_size)
    # No autograd graph is needed at inference time.
    with torch.no_grad():
        logits = model(tensor_batch)
    # Remove only the batch dimension so a single-class model still yields 1-D scores.
    np_output = torch.sigmoid(logits).squeeze(0).cpu().numpy()
    ## Top-k predictions, kept only if above the threshold
    detected_classes = _select_classes(np_output, classes_list, args.top_k, args.th)
    print("done\n")
    print("detected classes: {}".format(detected_classes))


if __name__ == "__main__":
    main()