class UnetBlock(nn.Module):
-    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
-                 innermost=False, outermost=False):
+    def __init__(self, nf, ni, submodule=None, input_channels=None, dropout=False, innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
-        if input_c is None: input_c = nf
-        downconv = nn.Conv2d(input_c, ni, kernel_size=4,
-                             stride=2, padding=1, bias=False)
-        downrelu = nn.LeakyReLU(0.2, True)
-        downnorm = nn.BatchNorm2d(ni)
-        uprelu = nn.ReLU(True)
-        upnorm = nn.BatchNorm2d(nf)
+        if input_channels is None:
+            input_channels = nf
+        down_convolution = nn.Conv2d(input_channels, ni, kernel_size=4, stride=2, padding=1, bias=False)
+        down_relu = nn.LeakyReLU(0.2, True)
+        down_norm = nn.BatchNorm2d(ni)
+        up_relu = nn.ReLU(True)
+        up_norm = nn.BatchNorm2d(nf)

        if outermost:
-            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1)
-            down = [downconv]
-            up = [uprelu, upconv, nn.Tanh()]
+            up_convolution = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1)
+            down = [down_convolution]
+            up = [up_relu, up_convolution, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
-            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4, stride=2, padding=1, bias=False)
-            down = [downrelu, downconv]
-            up = [uprelu, upconv, upnorm]
+            up_convolution = nn.ConvTranspose2d(ni, nf, kernel_size=4, stride=2, padding=1, bias=False)
+            down = [down_relu, down_convolution]
+            up = [up_relu, up_convolution, up_norm]
            model = down + up
        else:
-            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
-                                        stride=2, padding=1, bias=False)
-            down = [downrelu, downconv, downnorm]
-            up = [uprelu, upconv, upnorm]
-            if dropout: up += [nn.Dropout(0.5)]
+            up_convolution = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1, bias=False)
+            down = [down_relu, down_convolution, down_norm]
+            up = [up_relu, up_convolution, up_norm]
+            if dropout:
+                up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)
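
UnetBlock.forward is untouched by this commit, so the diff elides it below. In the pix2pix-style implementation being refactored here, it presumably applies the stacked layers and, for every block except the outermost, concatenates the block's input with its output to form the skip connection. A minimal sketch under that assumption:

    # Assumed shape of the elided method; not part of the diff.
    def forward(self, x):
        if self.outermost:
            return self.model(x)                       # final Tanh output, no skip
        return torch.cat([x, self.model(x)], dim=1)    # skip connection via channel concat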
@@ -43,40 +42,42 @@ def forward(self, x):


class Unet(nn.Module):
-    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
+    def __init__(self, nfg=64):
        super().__init__()
-        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
-        for _ in range(n_down - 5):
-            unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
-        out_filters = num_filters * 8
+        unet_block = UnetBlock(nfg * 8, nfg * 8, innermost=True)
+        for _ in range(3):
+            unet_block = UnetBlock(nfg * 8, nfg * 8, submodule=unet_block, dropout=True)
+        out_filters = nfg * 8
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
-        self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)
+        self.model = UnetBlock(2, out_filters, input_channels=1, submodule=unet_block, outermost=True)

    def forward(self, x):
        return self.model(x)
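
A quick sanity check on the renamed generator (a hypothetical sketch, not part of the commit): the model stacks eight stride-2 downsamplings (the innermost block, three dropout blocks, three widening blocks, and the outermost block), so input height and width must be multiples of 256, and a 1-channel L image comes back as a 2-channel ab prediction at the same resolution.

    import torch

    net_G = Unet(nfg=64)
    L = torch.randn(4, 1, 256, 256)      # batch of grayscale L-channel inputs
    ab = net_G(L)
    assert ab.shape == (4, 2, 256, 256)  # predicted ab chrominance channels
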
class PatchDiscriminator(nn.Module):
-    def __init__(self, input_c, num_filters=64, n_down=3):
+    def __init__(self, nfd=64):
        super().__init__()
-        model = [self.get_layers(input_c, num_filters, norm=False)]
-        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down - 1) else 2)
-                  for i in range(n_down)]  # the 'if' avoids using a stride of 2 for the last block in this loop
-        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False,
-                                  act=False)]  # no normalization or activation for the last layer of the model
-        self.model = nn.Sequential(*model)
+        self.model = nn.Sequential(
+            nn.Conv2d(3, nfd, 4, 2, 1, bias=True),
+            nn.LeakyReLU(0.2, True),
+
+            nn.Conv2d(nfd, nfd * 2, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nfd * 2),
+            nn.LeakyReLU(0.2, True),
+
+            nn.Conv2d(nfd * 2, nfd * 4, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(nfd * 4),
+            nn.LeakyReLU(0.2, True),
+
+            nn.Conv2d(nfd * 4, nfd * 8, 4, 1, 1, bias=False),
+            nn.BatchNorm2d(nfd * 8),
+            nn.LeakyReLU(0.2, True),

-    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True,
-                   act=True):  # when making repetitive blocks of layers,
-        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]  # it's helpful to have a separate method for that purpose
-        if norm: layers += [nn.BatchNorm2d(nf)]
-        if act: layers += [nn.LeakyReLU(0.2, True)]
-        return nn.Sequential(*layers)
+            nn.Conv2d(nfd * 8, 1, 4, 1, 1, bias=True),  # one real/fake logit per patch, matching the replaced code
+        )

    def forward(self, x):
        return self.model(x)
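
With 256x256 inputs the rewritten discriminator scores overlapping patches rather than whole images: the three stride-2 convolutions take 256 down to 32, and the two trailing stride-1 convolutions (kernel 4, padding 1) shrink that to 31 and then 30. A hypothetical shape check, assuming the 3-channel L+ab input it is trained on:

    import torch

    net_D = PatchDiscriminator(nfd=64)
    Lab = torch.randn(4, 3, 256, 256)      # L channel concatenated with ab
    scores = net_D(Lab)
    assert scores.shape == (4, 1, 30, 30)  # 256 -> 128 -> 64 -> 32 -> 31 -> 30
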
@@ -111,21 +112,27 @@ def init_model(model, device):


class MainModel(nn.Module):
-    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4, beta1=0.5, beta2=0.999, lambda_L1=100.):
+    def __init__(self,
+                 net_generator=None,
+                 lr_generator=2e-4,
+                 lr_discriminator=2e-4,
+                 beta1=0.5,
+                 beta2=0.999,
+                 lambda_l1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.lambda_L1 = lambda_L1
+        self.lambda_L1 = lambda_l1

-        if net_G is None:
-            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
+        if net_generator is None:
+            self.net_G = init_model(Unet(nfg=64), self.device)
        else:
-            self.net_G = net_G.to(self.device)
-        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
-        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
-        self.L1criterion = nn.L1Loss()
-        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
-        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
+            self.net_G = net_generator.to(self.device)
+        self.net_D = init_model(PatchDiscriminator(nfd=64), self.device)
+        self.GAN_criterion = GANLoss(gan_mode='vanilla').to(self.device)
+        self.L1_criterion = nn.L1Loss()
+        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_generator, betas=(beta1, beta2))
+        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_discriminator, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
@@ -140,19 +147,19 @@ def forward(self):

    def backward_D(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
-        fake_preds = self.net_D(fake_image.detach())
-        self.loss_D_fake = self.GANcriterion(fake_preds, False)
+        fake_predictions = self.net_D(fake_image.detach())
+        self.loss_D_fake = self.GAN_criterion(fake_predictions, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
-        real_preds = self.net_D(real_image)
-        self.loss_D_real = self.GANcriterion(real_preds, True)
+        real_predictions = self.net_D(real_image)
+        self.loss_D_real = self.GAN_criterion(real_predictions, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
-        fake_preds = self.net_D(fake_image)
-        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
-        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
+        fake_predictions = self.net_D(fake_image)
+        self.loss_G_GAN = self.GAN_criterion(fake_predictions, True)
+        self.loss_G_L1 = self.L1_criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()
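The hunk ends before the training step, and forward() is elided above. For orientation, a typical pix2pix-style step alternating the two optimizers would look roughly like the sketch below, under the assumption that forward() stores self.fake_color = self.net_G(self.L):

    # Hypothetical optimize() consistent with the methods shown; not part of the diff.
    def optimize(self):
        self.forward()                              # compute self.fake_color
        self.set_requires_grad(self.net_D, True)    # discriminator step
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()
        self.set_requires_grad(self.net_D, False)   # freeze D for the generator step
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()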