66class UnetBlock (nn .Module ):
7+ """
8+ Defines a U-Net block, a fundamental building block for the generator in an image-to-image translation model.
9+ """
710 def __init__ (self , nf , ni , submodule = None , input_channels = None , dropout = False , innermost = False , outermost = False ):
11+ """
12+ Initialize UnetBlock.
13+
14+ :param nf: Number of filters.
15+ :type nf: int
16+ :param ni: Number of input channels.
17+ :type ni: int
18+ :param submodule: Submodule to be included inside the block. Default is None.
19+ :type submodule: nn.Module, optional
20+ :param input_channels: Number of input channels. Default is None.
21+ :type input_channels: int, optional
22+ :param dropout: Whether to apply dropout. Default is False.
23+ :type dropout: bool, optional
24+ :param innermost: Whether the block is innermost. Default is False.
25+ :type innermost: bool, optional
26+ :param outermost: Whether the block is outermost. Default is False.
27+ :type outermost: bool, optional
28+ """
829 super ().__init__ ()
930 self .outermost = outermost
1031 if input_channels is None :
@@ -35,13 +56,25 @@ def __init__(self, nf, ni, submodule=None, input_channels=None, dropout=False, i
3556 self .model = nn .Sequential (* model )
3657
def forward(self, x):
    """Run the block, adding the U-Net skip connection unless outermost.

    :param x: Input tensor.
    :type x: torch.Tensor

    :return: The submodule output alone for the outermost block; otherwise
        the input concatenated with the submodule output along the channel
        dimension (the skip link).
    :rtype: torch.Tensor
    """
    out = self.model(x)
    if self.outermost:
        return out
    # Skip connection: channel-wise concat of input and block output.
    return torch.cat([x, out], 1)
4272
4373
4474class Unet (nn .Module ):
75+ """
76+ Defines a U-Net model constructed using U-Net Blocks
77+ """
4578 def __init__ (self , nfg = 64 ):
4679 super ().__init__ ()
4780 unet_block = UnetBlock (nfg * 8 , nfg * 8 , innermost = True )
@@ -54,11 +87,29 @@ def __init__(self, nfg=64):
5487 self .model = UnetBlock (2 , out_filters , input_channels = 1 , submodule = unet_block , outermost = True )
5588
def forward(self, x):
    """Apply the full stack of nested U-Net blocks to *x*.

    :param x: Input tensor.
    :type x: torch.Tensor

    :return: Output tensor produced by the outermost U-Net block.
    :rtype: torch.Tensor
    """
    output = self.model(x)
    return output
58100
59101
60102class PatchDiscriminator (nn .Module ):
103+ """
104+ Defines a PatchGAN discriminator for image-to-image translation tasks.
105+ """
61106 def __init__ (self , nfd = 64 ):
107+ """
108+ Initialize PatchDiscriminator.
109+
110+ :param nfd: Number of initial filters. Default is 64.
111+ :type nfd: int, optional
112+ """
62113 super ().__init__ ()
63114 # No normalization in first block
64115 model = [self .get_layers (3 , nfd , normalization = False )]
@@ -70,6 +121,27 @@ def __init__(self, nfd=64):
70121 self .model = nn .Sequential (* model )
71122
72123 def get_layers (self , ni , nf , k = 4 , s = 2 , p = 1 , normalization = True , action = True ):
124+ """
125+ Helper method to create layers for the PatchGAN discriminator.
126+
127+ :param ni: Number of input channels.
128+ :type ni: int
129+ :param nf: Number of filters.
130+ :type nf: int
131+ :param k: Kernel size. Default is 4.
132+ :type k: int, optional
133+ :param s: Stride. Default is 2.
134+ :type s: int, optional
135+ :param p: Padding. Default is 1.
136+ :type p: int, optional
137+ :param normalization: Whether to apply batch normalization. Default is True.
138+ :type normalization: bool, optional
139+ :param action: Whether to apply activation function. Default is True.
140+ :type action: bool, optional
141+
142+ :return: Sequential model representing the layers.
143+ :rtype: nn.Sequential
144+ """
73145 layers = [
74146 nn .Conv2d (ni , nf , k , s , p , bias = not normalization )]
75147 if normalization :
@@ -79,10 +151,32 @@ def get_layers(self, ni, nf, k=4, s=2, p=1, normalization=True, action=True):
79151 return nn .Sequential (* layers )
80152
def forward(self, x):
    """Score *x* with the PatchGAN discriminator.

    :param x: Input tensor.
    :type x: torch.Tensor

    :return: Per-patch prediction map.
    :rtype: torch.Tensor
    """
    predictions = self.model(x)
    return predictions
83164
84165
85166def init_weights (net , init = 'norm' , gain = 0.02 ):
167+ """
168+ Initialize weights for the neural network.
169+
170+ :param net: Neural network model.
171+ :type net: nn.Module
172+ :param init: Initialization method. Default is 'norm'.
173+ :type init: str, optional
174+ :param gain: Gain factor for weight initialization. Default is 0.02.
175+ :type gain: float, optional
176+
177+ :return: Initialized neural network model.
178+ :rtype: nn.Module
179+ """
86180 def init_func (m ):
87181 classname = m .__class__ .__name__
88182 if hasattr (m , 'weight' ) and 'Conv' in classname :
@@ -105,19 +199,49 @@ def init_func(m):
105199
106200
def init_model(model, device):
    """Move *model* onto *device* and initialize its weights.

    :param model: Neural network model.
    :type model: nn.Module
    :param device: Device to which the model will be moved.
    :type device: torch.device

    :return: The model, placed on *device* with freshly initialized weights
        (uses init_weights' default 'norm' scheme).
    :rtype: nn.Module
    """
    return init_weights(model.to(device))
111216
112217
113218class MainModel (nn .Module ):
219+ """
220+ Main model for image-to-image translation tasks using a conditional GAN with a U-Net generator.
221+ """
114222 def __init__ (self ,
115223 net_generator = None ,
116224 lr_generator = 2e-4 ,
117225 lr_discriminator = 2e-4 ,
118226 beta1 = 0.5 ,
119227 beta2 = 0.999 ,
120228 lambda_l1 = 100. ):
229+ """
230+ Initialize MainModel.
231+
232+ :param net_generator: Predefined generator network. Default is None.
233+ :type net_generator: nn.Module, optional
234+ :param lr_generator: Learning rate for the generator. Default is 2e-4.
235+ :type lr_generator: float, optional
236+ :param lr_discriminator: Learning rate for the discriminator. Default is 2e-4.
237+ :type lr_discriminator: float, optional
238+ :param beta1: Beta1 parameter for Adam optimizer. Default is 0.5.
239+ :type beta1: float, optional
240+ :param beta2: Beta2 parameter for Adam optimizer. Default is 0.999.
241+ :type beta2: float, optional
242+ :param lambda_l1: Weight for L1 loss term. Default is 100.
243+ :type lambda_l1: float, optional
244+ """
121245 super ().__init__ ()
122246
123247 self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
@@ -134,17 +258,37 @@ def __init__(self,
134258 self .opt_D = optim .Adam (self .net_D .parameters (), lr = lr_discriminator , betas = (beta1 , beta2 ))
135259
def set_requires_grad(self, model, requires_grad=True):
    """Toggle gradient tracking for every parameter of *model*.

    Used to freeze the discriminator while the generator is being updated,
    and to unfreeze it again for its own step.

    :param model: Model whose parameters are toggled.
    :type model: nn.Module
    :param requires_grad: Target value for requires_grad. Default is True.
    :type requires_grad: bool, optional
    """
    for param in model.parameters():
        param.requires_grad = requires_grad
139271
def setup_input(self, data):
    """Stash a training batch on the model's device.

    :param data: Batch dict with an 'L' (lightness) tensor and an 'ab'
        (color channels) tensor in LAB colorspace.
    :type data: dict
    """
    device = self.device
    self.L = data['L'].to(device)
    self.ab = data['ab'].to(device)
143281
def forward(self):
    """Run the generator on the stored L channel and cache the fake ab output."""
    predicted_ab = self.net_G(self.L)
    self.fake_color = predicted_ab
146287
147288 def backward_D (self ):
289+ """
290+ Backward pass and optimization for the discriminator.
291+ """
148292 fake_image = torch .cat ([self .L , self .fake_color ], dim = 1 )
149293 fake_predictions = self .net_D (fake_image .detach ())
150294 self .loss_D_fake = self .GAN_criterion (fake_predictions , False )
@@ -155,6 +299,9 @@ def backward_D(self):
155299 self .loss_D .backward ()
156300
157301 def backward_G (self ):
302+ """
303+ Backward pass for the generator.
304+ """
158305 fake_image = torch .cat ([self .L , self .fake_color ], dim = 1 )
159306 fake_predictions = self .net_D (fake_image )
160307 self .loss_G_GAN = self .GAN_criterion (fake_predictions , True )
@@ -163,6 +310,20 @@ def backward_G(self):
163310 self .loss_G .backward ()
164311
165312 def optimize (self ):
313+ """
314+ Optimization of the whole model.
315+
316+ This method performs a single optimization step for both the generator and the discriminator.
317+ It includes the forward pass, backward pass, and parameter updates.
318+
319+ Steps:
320+ 1. Perform the forward pass through the generator.
321+ 2. Set the discriminator to training mode and enable gradients for its parameters.
322+ 3. Zero the gradients of the discriminator optimizer.
323+ 4. Perform the backward pass for the discriminator and update its parameters.
324+ 5. Set the generator to training mode and disable gradients for the discriminator parameters.
325+ 6. Zero the gradients of the generator optimizer.
326+ """
166327 self .forward ()
167328 self .net_D .train ()
168329 self .set_requires_grad (self .net_D , True )