Skip to content

Commit 7cfd12b

Browse files
author
Saurav Agarwal
committed
Add jit transformer resize
1 parent 606dd01 commit 7cfd12b

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

python/tests/data_load.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import sys
2+
import torch
3+
import pyCoverageControlTorch as cct
4+
5+
if __name__ == "__main__":
6+
# Get filename from argument
7+
filename = sys.argv[1]
8+
# Load tensor from file
9+
tensor_model = torch.jit.load(filename)
10+
tensor = list(tensor_model.parameters())[0]
11+
# Print tensor
12+
print(tensor.dtype)
13+
print(tensor.shape)
14+
print(tensor.sum())
15+
16+
27.1 KB
Binary file not shown.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import os
2+
import sys
3+
import torch
4+
import torchvision
5+
import torchvision.transforms as T
6+
7+
class Resizer(torch.nn.Module):
8+
def __init__(self, size):
9+
super().__init__()
10+
self.size = size
11+
self.T = T.Resize(size)
12+
13+
def forward(self, img):
14+
return self.T(img)
15+
16+
if __name__ == "__main__":
17+
# Get size from args
18+
map_size = int(sys.argv[1])
19+
file_name = str(sys.argv[2])
20+
scripted_resizer = torch.jit.script(Resizer([map_size,]))
21+
scripted_resizer.save(file_name)

0 commit comments

Comments
 (0)