Skip to content

Commit 2645aef

Browse files
Added python docs
1 parent 1a92dfc commit 2645aef

5 files changed

Lines changed: 285 additions & 3 deletions

File tree

dataset.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,18 @@
88

99

1010
class ColorizationDataset(Dataset):
11+
"""
12+
Custom class for loading the dataset into the neural network using PyTorch
13+
"""
14+
1115
def __init__(self, paths, split='train'):
16+
"""
17+
Colorization Dataset initializer
18+
:param paths: List of file paths for images
19+
:type paths: list
20+
:param split:Dataset split, either 'train' or 'test'
21+
:type split: str, optional
22+
"""
1223
if split == 'train':
1324
self.transforms = transforms.Compose([
1425
transforms.Resize((SIZE, SIZE), Image.BICUBIC),
@@ -37,6 +48,20 @@ def __len__(self):
3748

3849

3950
def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs):
    """Build a DataLoader over a ColorizationDataset.

    :param batch_size: Samples per batch. Default is 16.
    :type batch_size: int, optional
    :param n_workers: Number of worker processes for loading. Default is 4.
    :type n_workers: int, optional
    :param pin_memory: Pin host memory for faster device transfer. Default is True.
    :type pin_memory: bool, optional
    :param kwargs: Forwarded unchanged to the ColorizationDataset constructor.

    :return: DataLoader wrapping the freshly constructed dataset.
    :rtype: torch.utils.data.DataLoader
    """
    return DataLoader(
        ColorizationDataset(**kwargs),
        batch_size=batch_size,
        num_workers=n_workers,
        pin_memory=pin_memory,
    )

loss.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,18 @@
33

44

55
class GANLoss(nn.Module):
6+
"""
7+
A module for measuring the GAN Loss.
8+
"""
69
def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
10+
"""
11+
Initializes the GANLoss module.
12+
13+
:param real_label: The label value for real samples (default is 1.0).
14+
:type real_label: float
15+
:param fake_label: The label value for fake/generated samples (default is 0.0).
16+
:type fake_label: float
17+
"""
718
super().__init__()
819
self.register_buffer('real_label', torch.tensor(real_label))
920
self.register_buffer('fake_label', torch.tensor(fake_label))
@@ -13,13 +24,35 @@ def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
1324
self.loss = nn.MSELoss()
1425

1526
def get_labels(self, preds, target_is_real):
    """Build the target-label tensor for a batch of discriminator outputs.

    :param preds: The predictions from the discriminator.
    :type preds: torch.Tensor
    :param target_is_real: Whether the targets should be the real label.
    :type target_is_real: bool

    :return: The chosen label broadcast to the shape of *preds*.
    :rtype: torch.Tensor
    """
    label = self.real_label if target_is_real else self.fake_label
    return label.expand_as(preds)
2143

2244
def __call__(self, preds, target_is_real):
    """Compute the adversarial loss for a batch of discriminator outputs.

    :param preds: The predictions from the discriminator.
    :type preds: torch.Tensor
    :param target_is_real: Whether the targets should be the real label.
    :type target_is_real: bool

    :return: The computed adversarial loss.
    :rtype: torch.Tensor
    """
    targets = self.get_labels(preds, target_is_real)
    return self.loss(preds, targets)

models.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,28 @@
44

55

66
class UnetBlock(nn.Module):
7+
"""
8+
Defines a U-Net block, a fundamental building block for the generator in an image-to-image translation model.
9+
"""
710
def __init__(self, nf, ni, submodule=None, input_channels=None, dropout=False, innermost=False, outermost=False):
11+
"""
12+
Initialize UnetBlock.
13+
14+
:param nf: Number of filters.
15+
:type nf: int
16+
:param ni: Number of input channels.
17+
:type ni: int
18+
:param submodule: Submodule to be included inside the block. Default is None.
19+
:type submodule: nn.Module, optional
20+
:param input_channels: Number of input channels. Default is None.
21+
:type input_channels: int, optional
22+
:param dropout: Whether to apply dropout. Default is False.
23+
:type dropout: bool, optional
24+
:param innermost: Whether the block is innermost. Default is False.
25+
:type innermost: bool, optional
26+
:param outermost: Whether the block is outermost. Default is False.
27+
:type outermost: bool, optional
28+
"""
829
super().__init__()
930
self.outermost = outermost
1031
if input_channels is None:
@@ -35,13 +56,25 @@ def __init__(self, nf, ni, submodule=None, input_channels=None, dropout=False, i
3556
self.model = nn.Sequential(*model)
3657

3758
def forward(self, x):
    """Run the U-Net block on *x*.

    The outermost block returns the submodel output directly; inner
    blocks concatenate their input with the output along the channel
    dimension to form the U-Net skip connection.

    :param x: Input tensor.
    :type x: torch.Tensor

    :return: Output tensor.
    :rtype: torch.Tensor
    """
    out = self.model(x)
    return out if self.outermost else torch.cat([x, out], 1)
4272

4373

4474
class Unet(nn.Module):
75+
"""
76+
Defines a U-Net model constructed using U-Net Blocks
77+
"""
4578
def __init__(self, nfg=64):
4679
super().__init__()
4780
unet_block = UnetBlock(nfg * 8, nfg * 8, innermost=True)
@@ -54,11 +87,29 @@ def __init__(self, nfg=64):
5487
self.model = UnetBlock(2, out_filters, input_channels=1, submodule=unet_block, outermost=True)
5588

5689
def forward(self, x):
    """Run the full U-Net generator on *x*.

    :param x: Input tensor.
    :type x: torch.Tensor

    :return: Output tensor.
    :rtype: torch.Tensor
    """
    output = self.model(x)
    return output
58100

59101

60102
class PatchDiscriminator(nn.Module):
103+
"""
104+
Defines a PatchGAN discriminator for image-to-image translation tasks.
105+
"""
61106
def __init__(self, nfd=64):
107+
"""
108+
Initialize PatchDiscriminator.
109+
110+
:param nfd: Number of initial filters. Default is 64.
111+
:type nfd: int, optional
112+
"""
62113
super().__init__()
63114
# No normalization in first block
64115
model = [self.get_layers(3, nfd, normalization=False)]
@@ -70,6 +121,27 @@ def __init__(self, nfd=64):
70121
self.model = nn.Sequential(*model)
71122

72123
def get_layers(self, ni, nf, k=4, s=2, p=1, normalization=True, action=True):
124+
"""
125+
Helper method to create layers for the PatchGAN discriminator.
126+
127+
:param ni: Number of input channels.
128+
:type ni: int
129+
:param nf: Number of filters.
130+
:type nf: int
131+
:param k: Kernel size. Default is 4.
132+
:type k: int, optional
133+
:param s: Stride. Default is 2.
134+
:type s: int, optional
135+
:param p: Padding. Default is 1.
136+
:type p: int, optional
137+
:param normalization: Whether to apply batch normalization. Default is True.
138+
:type normalization: bool, optional
139+
:param action: Whether to apply activation function. Default is True.
140+
:type action: bool, optional
141+
142+
:return: Sequential model representing the layers.
143+
:rtype: nn.Sequential
144+
"""
73145
layers = [
74146
nn.Conv2d(ni, nf, k, s, p, bias=not normalization)]
75147
if normalization:
@@ -79,10 +151,32 @@ def get_layers(self, ni, nf, k=4, s=2, p=1, normalization=True, action=True):
79151
return nn.Sequential(*layers)
80152

81153
def forward(self, x):
    """Run the PatchGAN discriminator on *x*.

    :param x: Input tensor.
    :type x: torch.Tensor

    :return: Output tensor (a grid of per-patch scores).
    :rtype: torch.Tensor
    """
    scores = self.model(x)
    return scores
83164

84165

85166
def init_weights(net, init='norm', gain=0.02):
167+
"""
168+
Initialize weights for the neural network.
169+
170+
:param net: Neural network model.
171+
:type net: nn.Module
172+
:param init: Initialization method. Default is 'norm'.
173+
:type init: str, optional
174+
:param gain: Gain factor for weight initialization. Default is 0.02.
175+
:type gain: float, optional
176+
177+
:return: Initialized neural network model.
178+
:rtype: nn.Module
179+
"""
86180
def init_func(m):
87181
classname = m.__class__.__name__
88182
if hasattr(m, 'weight') and 'Conv' in classname:
@@ -105,19 +199,49 @@ def init_func(m):
105199

106200

107201
def init_model(model, device):
    """Move a network to *device* and initialize its weights.

    Delegates weight initialization to the module-level ``init_weights``
    helper (using its default scheme and gain).

    :param model: Neural network model.
    :type model: nn.Module
    :param device: Target device for the model.
    :type device: torch.device

    :return: The model, moved to *device* and weight-initialized.
    :rtype: nn.Module
    """
    return init_weights(model.to(device))
111216

112217

113218
class MainModel(nn.Module):
219+
"""
220+
Main model for image-to-image translation tasks using a conditional GAN with a U-Net generator.
221+
"""
114222
def __init__(self,
115223
net_generator=None,
116224
lr_generator=2e-4,
117225
lr_discriminator=2e-4,
118226
beta1=0.5,
119227
beta2=0.999,
120228
lambda_l1=100.):
229+
"""
230+
Initialize MainModel.
231+
232+
:param net_generator: Predefined generator network. Default is None.
233+
:type net_generator: nn.Module, optional
234+
:param lr_generator: Learning rate for the generator. Default is 2e-4.
235+
:type lr_generator: float, optional
236+
:param lr_discriminator: Learning rate for the discriminator. Default is 2e-4.
237+
:type lr_discriminator: float, optional
238+
:param beta1: Beta1 parameter for Adam optimizer. Default is 0.5.
239+
:type beta1: float, optional
240+
:param beta2: Beta2 parameter for Adam optimizer. Default is 0.999.
241+
:type beta2: float, optional
242+
:param lambda_l1: Weight for L1 loss term. Default is 100.
243+
:type lambda_l1: float, optional
244+
"""
121245
super().__init__()
122246

123247
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -134,17 +258,37 @@ def __init__(self,
134258
self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_discriminator, betas=(beta1, beta2))
135259

136260
def set_requires_grad(self, model, requires_grad=True):
    """Enable or disable gradient tracking for all of a model's parameters.

    :param model: Model whose parameters are updated.
    :type model: nn.Module
    :param requires_grad: Value assigned to each parameter's
        ``requires_grad`` attribute. Default is True.
    :type requires_grad: bool, optional
    """
    for param in model.parameters():
        param.requires_grad = requires_grad
139271

140272
def setup_input(self, data):
    """Unpack a data batch and move it onto the model's device.

    Stores the lightness channel as ``self.L`` and the color channels
    as ``self.ab``.

    :param data: Batch with LAB colorspace components under the keys
        ``'L'`` and ``'ab'``.
    :type data: dict
    """
    device = self.device
    self.L = data['L'].to(device)
    self.ab = data['ab'].to(device)
143281

144282
def forward(self):
    """Run the generator on the stored L channel.

    Reads ``self.L`` (set by ``setup_input``) and stores the generator
    output in ``self.fake_color``.
    """
    self.fake_color = self.net_G(self.L)
146287

147288
def backward_D(self):
289+
"""
290+
Backward pass and optimization for the discriminator.
291+
"""
148292
fake_image = torch.cat([self.L, self.fake_color], dim=1)
149293
fake_predictions = self.net_D(fake_image.detach())
150294
self.loss_D_fake = self.GAN_criterion(fake_predictions, False)
@@ -155,6 +299,9 @@ def backward_D(self):
155299
self.loss_D.backward()
156300

157301
def backward_G(self):
302+
"""
303+
Backward pass for the generator.
304+
"""
158305
fake_image = torch.cat([self.L, self.fake_color], dim=1)
159306
fake_predictions = self.net_D(fake_image)
160307
self.loss_G_GAN = self.GAN_criterion(fake_predictions, True)
@@ -163,6 +310,20 @@ def backward_G(self):
163310
self.loss_G.backward()
164311

165312
def optimize(self):
313+
"""
314+
Optimization of the whole model.
315+
316+
This method performs a single optimization step for both the generator and the discriminator.
317+
It includes the forward pass, backward pass, and parameter updates.
318+
319+
Steps:
320+
1. Perform the forward pass through the generator.
321+
2. Set the discriminator to training mode and enable gradients for its parameters.
322+
3. Zero the gradients of the discriminator optimizer.
323+
4. Perform the backward pass for the discriminator and update its parameters.
324+
5. Set the generator to training mode and disable gradients for the discriminator parameters.
325+
6. Zero the gradients of the generator optimizer.
326+
"""
166327
self.forward()
167328
self.net_D.train()
168329
self.set_requires_grad(self.net_D, True)

train.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,21 @@
11

22

33
def train_model(model, train_dataloader, epochs, display_every=100):
4-
# visualization_data = next(iter(val_dl)) # getting a batch for visualizing the model output after fixed intervals
4+
"""
5+
Train the specified model using the provided data loader.
6+
7+
:param model: (nn.Module) The PyTorch model to be trained.
8+
:type model: nn.Module
9+
:param train_dataloader: (DataLoader) The DataLoader for training data.
10+
:type train_dataloader: DataLoader
11+
:param epochs: (int) The number of training epochs.
12+
:type epochs: int
13+
:param display_every: (int) Display training progress every specified number of iterations (default is 100).
14+
:type display_every: int
15+
16+
Example:
17+
train_model(my_model, train_data_loader, epochs=10, display_every=50)
18+
"""
519
for e in range(epochs):
620
loss_meter_dict = create_loss_meters()
721
i = 0

0 commit comments

Comments
 (0)