-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
101 lines (77 loc) · 3.41 KB
/
test.py
File metadata and controls
101 lines (77 loc) · 3.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
import os
import time
import torch
from PIL import Image
import cv2 as cv
from model import VideoSaliencyModel
from utils import load_model_to_device, torch_transform_image, save_image, blur
# Command-line interface for the saliency inference script.
parser = argparse.ArgumentParser()
parser.add_argument(
    'weight_file',
    default='',
    type=str,
    help='path to pretrained model state dict file',
)
parser.add_argument(
    '--test_data_path',
    default='E:/szkolne/praca_magisterska/ACLNet-Pytorch/validation',
    type=str,
    help='path to testing data',
)
parser.add_argument(
    '--output_path',
    default='./result',
    type=str,
    help='path for output files',
)
def main():
    """Run saliency inference over every clip directory under the test data path.

    For each clip, slides a temporal window of ``len_temporal`` frames over the
    frame sequence and writes one predicted saliency map per frame into a
    timestamped output directory. The first ``len_temporal - 1`` frames (which
    have no full preceding window) are predicted from the temporally flipped
    clip, mirroring the original ACLNet/TASED-style evaluation scheme.
    """
    args = parser.parse_args()

    # Temporal window length the model expects (frames per clip).
    len_temporal = 8

    file_weight = args.weight_file
    path_input = args.test_data_path
    # Timestamped subdirectory so repeated runs never overwrite each other.
    path_output = os.path.join(args.output_path, time.strftime("%m-%d_%H-%M-%S"))
    os.makedirs(path_output, exist_ok=True)

    model = VideoSaliencyModel()
    model.load_state_dict(torch.load(file_weight))
    model, device = load_model_to_device(model)
    model.eval()

    # Each subdirectory of the input path is one clip; sort for stable order.
    list_input_data = sorted(
        d for d in os.listdir(path_input)
        if os.path.isdir(os.path.join(path_input, d))
    )

    for data_name in list_input_data:
        print(f'Processing {data_name}...')
        images_dir = os.path.join(path_input, data_name, 'images')
        list_frames = sorted(
            f for f in os.listdir(images_dir)
            if os.path.isfile(os.path.join(images_dir, f))
        )

        os.makedirs(os.path.join(path_output, data_name), exist_ok=True)

        if len(list_frames) < 2 * len_temporal - 1:
            print('Not enough frames in input clip!')
            # BUG FIX: was `return`, which aborted the whole run on the first
            # short clip; `continue` skips only the offending clip.
            continue

        snippet = []
        for i, frame_name in enumerate(list_frames):
            img = Image.open(os.path.join(images_dir, frame_name)).convert('RGB')
            img_size = img.size
            snippet.append(torch_transform_image(img))

            if i >= len_temporal - 1:
                # (1, T, C, H, W) -> (1, C, T, H, W), as the 3D model expects.
                clip = torch.stack(snippet, dim=0).unsqueeze(0).float()
                clip = clip.permute((0, 2, 1, 3, 4))
                process_image(model, device, clip, data_name, frame_name,
                              path_output, img_size)

                # The first (len_temporal - 1) frames have no full preceding
                # window; predict them from the temporally reversed clip.
                if i < 2 * len_temporal - 2:
                    process_image(
                        model,
                        device,
                        torch.flip(clip, [2]),
                        data_name,
                        list_frames[i - len_temporal + 1],
                        path_output,
                        img_size,
                    )

                # Slide the window forward by one frame.
                del snippet[0]
def process_image(model, device, clip, data_name, frame_no, save_path, img_size):
    """Predict a saliency map for one clip and write it to disk.

    The model output is resized back to the original frame size
    (``img_size`` is a PIL ``(width, height)`` pair), blurred, and saved
    normalized under ``save_path/data_name/frame_no``.
    """
    with torch.no_grad():
        prediction = model(clip.to(device)).cpu().data[0]
        saliency = prediction.numpy()
        saliency = cv.resize(saliency, (img_size[0], img_size[1]))
        saliency = blur(saliency)
        destination = os.path.join(save_path, data_name, frame_no)
        save_image(saliency, destination, normalize=True)
# Run inference only when executed as a script, not when imported.
if __name__ == '__main__':
    main()