Skip to content

Commit 8dd7a99

Browse files
Rewrote model classes for simplicity
1 parent ef9b7d5 commit 8dd7a99

1 file changed

Lines changed: 66 additions & 59 deletions

File tree

models.py

Lines changed: 66 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,33 @@
44

55

66
class UnetBlock(nn.Module):
    """One encoder/decoder stage of a recursively nested U-Net.

    The block wraps ``submodule`` between a downsampling 4x4 stride-2 conv
    (encoder side) and an upsampling 4x4 stride-2 transposed conv (decoder
    side). ``nf`` is the channel count produced by the up path, ``ni`` the
    channel count produced by the down conv. ``input_channels`` defaults to
    ``nf`` when not given. ``innermost`` builds the bottleneck (no submodule);
    ``outermost`` builds the top stage (no input activation, Tanh output).
    """

    def __init__(self, nf, ni, submodule=None, input_channels=None, dropout=False, innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        in_ch = nf if input_channels is None else input_channels

        # Encoder conv is shared by all three variants; created first so the
        # parameter-init RNG order matches the reference implementation.
        encode = nn.Conv2d(in_ch, ni, kernel_size=4, stride=2, padding=1, bias=False)

        if outermost:
            # Top stage: raw conv in, Tanh out, transposed conv keeps a bias.
            decode = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1)
            layers = [encode, submodule, nn.ReLU(True), decode, nn.Tanh()]
        elif innermost:
            # Bottleneck: no submodule, no norm on the encoder side.
            decode = nn.ConvTranspose2d(ni, nf, kernel_size=4, stride=2, padding=1, bias=False)
            layers = [nn.LeakyReLU(0.2, True), encode, nn.ReLU(True), decode, nn.BatchNorm2d(nf)]
        else:
            # Intermediate stage: LeakyReLU/conv/norm down, ReLU/deconv/norm up.
            decode = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1, bias=False)
            layers = [nn.LeakyReLU(0.2, True), encode, nn.BatchNorm2d(ni),
                      submodule,
                      nn.ReLU(True), decode, nn.BatchNorm2d(nf)]
            if dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)
3736

@@ -43,40 +42,42 @@ def forward(self, x):
4342

4443

4544
class Unet(nn.Module):
    """U-Net generator built by nesting UnetBlock stages inside-out.

    Args:
        nfg: base number of generator filters; the deepest stages use
            ``nfg * 8`` channels.
        input_channels: channels of the input image (default 1, the L channel
            of Lab color space in this colorization setup).
        output_channels: channels of the predicted image (default 2, the ab
            channels).
        n_down: total number of downsampling stages (default 8, matching the
            previously hard-coded depth: 1 innermost + ``n_down - 5`` dropout
            stages + 3 widening stages + 1 outermost).
    """

    def __init__(self, nfg=64, input_channels=1, output_channels=2, n_down=8):
        super().__init__()
        # Innermost bottleneck stage.
        block = UnetBlock(nfg * 8, nfg * 8, innermost=True)
        # Constant-width middle stages with dropout (3 when n_down == 8).
        for _ in range(n_down - 5):
            block = UnetBlock(nfg * 8, nfg * 8, submodule=block, dropout=True)
        # Three stages that halve the filter count on the way out.
        out_filters = nfg * 8
        for _ in range(3):
            block = UnetBlock(out_filters // 2, out_filters, submodule=block)
            out_filters //= 2
        # Outermost stage maps the input image in and the prediction out.
        self.model = UnetBlock(output_channels, out_filters, input_channels=input_channels,
                               submodule=block, outermost=True)

    def forward(self, x):
        return self.model(x)
5958

6059

6160
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator for 3-channel (Lab) images.

    Produces a single-channel map of patch logits: each spatial position
    classifies one receptive-field patch of the input as real or fake.

    Args:
        nfd: base number of discriminator filters.
    """

    def __init__(self, nfd=64):
        super().__init__()
        self.model = nn.Sequential(
            # First block: no normalization (standard PatchGAN convention).
            nn.Conv2d(3, nfd, 4, 2, 1, bias=True),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(nfd, nfd * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nfd * 2),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(nfd * 2, nfd * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nfd * 4),
            nn.LeakyReLU(0.2, True),

            # Stride 1 on the last hidden block: grows the receptive field
            # without shrinking the output map further.
            nn.Conv2d(nfd * 4, nfd * 8, 4, 1, 1, bias=False),
            nn.BatchNorm2d(nfd * 8),
            nn.LeakyReLU(0.2, True),

            # Head: ONE output channel of patch logits, no norm, no
            # activation. (The rewrite emitted nfd * 16 channels here — a
            # regression from the original, which reduced to 1 channel so the
            # GAN loss gets a single real/fake logit per patch.)
            nn.Conv2d(nfd * 8, 1, 4, 1, 1, bias=True),
        )

    def forward(self, x):
        return self.model(x)
@@ -111,21 +112,27 @@ def init_model(model, device):
111112

112113

113114
class MainModel(nn.Module):
    """Conditional GAN for colorization: U-Net generator vs PatchGAN discriminator."""

    def __init__(self,
                 net_generator=None,
                 lr_generator=2e-4,
                 lr_discriminator=2e-4,
                 beta1=0.5,
                 beta2=0.999,
                 lambda_l1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Weight of the L1 reconstruction term relative to the GAN term.
        self.lambda_L1 = lambda_l1

        # Use the caller's generator when given; otherwise build the default U-Net.
        if net_generator is not None:
            self.net_G = net_generator.to(self.device)
        else:
            self.net_G = init_model(Unet(nfg=64), self.device)
        self.net_D = init_model(PatchDiscriminator(nfd=64), self.device)
        self.GAN_criterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1_criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_generator, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_discriminator, betas=(beta1, beta2))
129136

130137
def set_requires_grad(self, model, requires_grad=True):
131138
for p in model.parameters():
@@ -140,19 +147,19 @@ def forward(self):
140147

141148
def backward_D(self):
    """Discriminator loss: average of the fake-rejection and real-acceptance terms."""
    # Detach the generator output so no gradient flows into the generator here.
    fake = torch.cat([self.L, self.fake_color], dim=1).detach()
    self.loss_D_fake = self.GAN_criterion(self.net_D(fake), False)
    real = torch.cat([self.L, self.ab], dim=1)
    self.loss_D_real = self.GAN_criterion(self.net_D(real), True)
    self.loss_D = 0.5 * (self.loss_D_fake + self.loss_D_real)
    self.loss_D.backward()
150157

151158
def backward_G(self):
    """Generator loss: fool the discriminator plus weighted L1 to ground truth."""
    fake = torch.cat([self.L, self.fake_color], dim=1)
    # The generator wants its output labeled real (True) by the discriminator.
    self.loss_G_GAN = self.GAN_criterion(self.net_D(fake), True)
    self.loss_G_L1 = self.L1_criterion(self.fake_color, self.ab) * self.lambda_L1
    self.loss_G = self.loss_G_GAN + self.loss_G_L1
    self.loss_G.backward()
158165

0 commit comments

Comments
 (0)