-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
74 lines (63 loc) · 2.63 KB
/
model.py
File metadata and controls
74 lines (63 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
    """A simple Convolution Block.
    Conv2d - Activation - MaxPool2d - Conv2d - Activation - MaxPool2d

    NOTE(review): the layer stack described above is not implemented yet;
    this class is still a stub. The constructor is fixed so the stub can at
    least be instantiated safely as an nn.Module.
    """
    def __init__(self, activation=F.relu_):
        # Bug fix: nn.Module.__init__ was never called, so instantiating
        # ConvBlock (or registering submodules on it later) would fail.
        super().__init__()
        # Keep the chosen activation for the eventual implementation.
        self.activation = activation
class ConvNet(nn.Module):
    """A simple Convolutional Neural Network for 10-class classification.

    Expects input of shape [batch_size, 1, 28, 28] (e.g. MNIST) and returns
    raw class scores of shape [batch_size, 10].
    """
    def __init__(self):
        super().__init__()
        # Feature extractor: two conv/pool stages, 28x28 -> 7x7 spatial.
        self.conv_module = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1),   # [8, 28, 28]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                 # [8, 14, 14]
            nn.Conv2d(8, 16, 3, padding=1),  # [16, 14, 14]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)                  # [16, 7, 7]
        )
        out_dim = 16 * 7 * 7
        # Classifier head: flattened conv features -> 10 class scores.
        self.fc_module = nn.Sequential(
            nn.Linear(out_dim, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 10)
        )
        self._initialize_params()
    def _initialize_params(self):
        """Apply Kaiming-normal initialization to every conv layer.

        Bug fix: the original iterated self.children(), which yields only
        the top-level nn.Sequential containers, so the isinstance checks
        never matched and no layer was ever re-initialized. self.modules()
        recurses into the containers and reaches the actual Conv2d layers.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight)
                # Bug fix: Conv2d always *has* a 'bias' attribute; it is
                # None when bias=False, so test the value, not hasattr.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        print("Initialized training parameters.")
    def forward(self, x):
        """Return class scores for x. x.shape = [batch_size, 1, 28, 28]."""
        x = self.conv_module(x)
        x = self.flatten(x)
        score = self.fc_module(x)
        return score
    def flatten(self, x):
        """Flatten all dims after the batch dim: [N, ...] -> [N, prod(...)]."""
        # Idiom: view(N, -1) replaces the manual product over trailing dims.
        return x.view(x.size(0), -1)
    def save(self, fp='backup.pth'):
        """Save the model's state dict to ./save/{fp}."""
        # exist_ok avoids the racy exists-then-makedirs check.
        os.makedirs('./save', exist_ok=True)
        filename = os.path.join('./save', fp)
        # Bug fix: torch.save was called without the destination path
        # (a TypeError at runtime), and the message printed a garbled
        # placeholder instead of the actual filename.
        torch.save({'state_dict': self.state_dict()}, filename)
        print(f"Saved the state dict of the model to {filename}")
    def load(self, fp):
        """Load the model's state dict from ./save/{fp}."""
        filename = os.path.join('./save', fp)
        # Bug fix: 'assert' is stripped under python -O; raise explicitly.
        if not os.path.exists(filename):
            raise FileNotFoundError(f"wrong filename: {filename}")
        checkpoint = torch.load(filename)
        self.load_state_dict(checkpoint['state_dict'])