# -*- coding: utf-8 -*-
"""
Created on Tue Dec 21 22:04:23 2021
@author: WesSa
"""
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for our testing purposes, please do not change!
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
'''
Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in a uniform grid.
'''
image_tensor = (image_tensor + 1) / 2
image_unflat = image_tensor.detach().cpu()
image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
plt.imshow(image_grid.permute(1, 2, 0).squeeze())
if show:
plt.show()
# GENERATOR AND NOISE:
class Generator(nn.Module):
'''
Generator Class
Values:
input_dim: the dimension of the input vector, a scalar
im_chan: the number of channels of the output image, a scalar
(MNIST is black-and-white, so 1 channel is your default)
hidden_dim: the inner dimension, a scalar
'''
def __init__(self, input_dim=10, im_chan=1, hidden_dim=64):
super(Generator, self).__init__()
self.input_dim = input_dim
# Build the neural network
self.gen = nn.Sequential(
self.make_gen_block(input_dim, hidden_dim * 4),
self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
self.make_gen_block(hidden_dim * 2, hidden_dim),
self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
)
def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
'''
Function to return a sequence of operations corresponding to a generator block of DCGAN;
a transposed convolution, a batchnorm (except in the final layer), and an activation.
Parameters:
input_channels: how many channels the input feature representation has
output_channels: how many channels the output feature representation should have
kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
stride: the stride of the convolution
final_layer: a boolean, true if it is the final layer and false otherwise
(affects activation and batchnorm)
'''
if not final_layer:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
else:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
nn.Tanh(),
)
def forward(self, noise):
'''
Function for completing a forward pass of the generator: Given a noise tensor,
returns generated images.
Parameters:
noise: a noise tensor with dimensions (n_samples, input_dim)
'''
x = noise.view(len(noise), self.input_dim, 1, 1)
return self.gen(x)
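# Shape walk-through (illustrative, using out = (in - 1) * stride + kernel_size
# for unpadded ConvTranspose2d): a (n, input_dim, 1, 1) input grows through the
# four blocks as 1x1 -> 3x3 -> 6x6 -> 13x13 -> 28x28, matching MNIST's 28x28.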
def get_noise(n_samples, input_dim, device='cpu'):
'''
Function for creating noise vectors: Given the dimensions (n_samples, input_dim)
creates a tensor of that shape filled with random numbers from the normal distribution.
Parameters:
n_samples: the number of samples to generate, a scalar
input_dim: the dimension of the input vector, a scalar
device: the device type
'''
return torch.randn(n_samples, input_dim, device=device)
# DISCRIMINATOR:
class Discriminator(nn.Module):
'''
Discriminator Class
Values:
        im_chan: the number of channels of the input image, a scalar
(MNIST is black-and-white, so 1 channel is your default)
hidden_dim: the inner dimension, a scalar
'''
def __init__(self, im_chan=1, hidden_dim=64):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
self.make_disc_block(im_chan, hidden_dim),
self.make_disc_block(hidden_dim, hidden_dim * 2),
self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
)
def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
'''
Function to return a sequence of operations corresponding to a discriminator block of the DCGAN;
a convolution, a batchnorm (except in the final layer), and an activation (except in the final layer).
Parameters:
input_channels: how many channels the input feature representation has
output_channels: how many channels the output feature representation should have
kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
stride: the stride of the convolution
final_layer: a boolean, true if it is the final layer and false otherwise
(affects activation and batchnorm)
'''
if not final_layer:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size, stride),
nn.BatchNorm2d(output_channels),
nn.LeakyReLU(0.2, inplace=True),
)
else:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size, stride),
)
def forward(self, image):
'''
Function for completing a forward pass of the discriminator: Given an image tensor,
returns a 1-dimension tensor representing fake/real.
Parameters:
            image: an image tensor with dimensions (n_samples, im_chan, height, width)
'''
disc_pred = self.disc(image)
return disc_pred.view(len(disc_pred), -1)
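# Shape walk-through (illustrative, using out = (in - kernel_size) // stride + 1
# for unpadded Conv2d): a (n, im_chan, 28, 28) input shrinks through the blocks
# as 28x28 -> 13x13 -> 5x5 -> 1x1, leaving one logit per image after the view.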
# CLASS INPUT:
# UNQ_C1 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_one_hot_labels
import torch.nn.functional as F
def get_one_hot_labels(labels, n_classes):
'''
    Function for creating one-hot vectors for the labels, returns a tensor of shape (?, n_classes).
Parameters:
labels: tensor of labels from the dataloader, size (?)
n_classes: the total number of classes in the dataset, an integer scalar
'''
#### START CODE HERE ####
    return F.one_hot(labels, num_classes=n_classes)
#### END CODE HERE ####
# Test one hot labels:
assert (
get_one_hot_labels(
labels=torch.Tensor([[0, 2, 1]]).long(),
n_classes=3
).tolist() ==
[[
[1, 0, 0],
[0, 0, 1],
[0, 1, 0]
]]
)
print("Success!")
# Combine Image vector and label vectors:
# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: combine_vectors
def combine_vectors(x, y):
'''
Function for combining two vectors with shapes (n_samples, ?) and (n_samples, ?).
Parameters:
x: (n_samples, ?) the first vector.
In this assignment, this will be the noise vector of shape (n_samples, z_dim),
but you shouldn't need to know the second dimension's size.
y: (n_samples, ?) the second vector.
Once again, in this assignment this will be the one-hot class vector
with the shape (n_samples, n_classes), but you shouldn't assume this in your code.
'''
# Note: Make sure this function outputs a float no matter what inputs it receives
#### START CODE HERE ####
    combined = torch.cat((x.float(), y.float()), dim=1)
#### END CODE HERE ####
return combined
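# Quick sanity check (illustrative, not one of the graded cells): concatenating
# a (n_samples, z_dim) batch with a (n_samples, n_classes) batch along dim=1
# should yield shape (n_samples, z_dim + n_classes) as a float tensor.
_xy = combine_vectors(torch.randn(4, 64), torch.zeros(4, 10))
assert _xy.shape == (4, 74) and _xy.dtype == torch.float32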
mnist_shape = (1, 28, 28)
n_classes = 10
criterion = nn.BCEWithLogitsLoss()
n_epochs = 100
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.0002
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is available
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
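# Note: Normalize((0.5,), (0.5,)) maps pixel values from [0, 1] to [-1, 1],
# matching the Tanh output range of the generator's final layer.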
MNIST_Data = MNIST('.', download=True, transform=transform)
dataloader = DataLoader(
    MNIST_Data,
    batch_size=batch_size,
    shuffle=True)
# UNQ_C3 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_input_dimensions
def get_input_dimensions(z_dim, mnist_shape, n_classes):
'''
Function for getting the size of the conditional input dimensions
from z_dim, the image shape, and number of classes.
Parameters:
z_dim: the dimension of the noise vector, a scalar
        mnist_shape: the shape of each MNIST image as (C, H, W), which is (1, 28, 28)
n_classes: the total number of classes in the dataset, an integer scalar
(10 for MNIST)
Returns:
generator_input_dim: the input dimensionality of the conditional generator,
which takes the noise and class vectors
        discriminator_im_chan: the number of input channels to the discriminator
            (the image channels plus one channel per class, e.g. 1 + 10 = 11 for MNIST)
'''
#### START CODE HERE ####
generator_input_dim = z_dim + n_classes
discriminator_im_chan = mnist_shape[0] + n_classes
#### END CODE HERE ####
return generator_input_dim, discriminator_im_chan
def test_input_dims():
gen_dim, disc_dim = get_input_dimensions(23, (12, 23, 52), 9)
assert gen_dim == 32
assert disc_dim == 21
test_input_dims()
print("Success!")
generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, mnist_shape, n_classes)
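# With the values above: generator_input_dim = 64 + 10 = 74 and
# discriminator_im_chan = 1 + 10 = 11 (the image channel plus one channel per class).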
gen = Generator(input_dim=generator_input_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator(im_chan=discriminator_im_chan).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
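# DCGAN-style initialization: conv, transposed-conv, and batchnorm weights are
# drawn from N(0, 0.02), and batchnorm biases are zeroed.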
def weights_init(m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
if isinstance(m, nn.BatchNorm2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)
# UNQ_C4 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED CELL
cur_step = 0
generator_losses = []
discriminator_losses = []
#UNIT TEST NOTE: Initializations needed for grading
noise_and_labels = False
fake = False
fake_image_and_labels = False
real_image_and_labels = False
disc_fake_pred = False
disc_real_pred = False
for epoch in range(n_epochs):
print("Current EPOCH: %d" % (epoch))
# Dataloader returns the batches and the labels
for real, labels in tqdm(dataloader):
cur_batch_size = len(real)
# Flatten the batch of real images from the dataset
real = real.to(device)
one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
image_one_hot_labels = one_hot_labels[:, :, None, None]
image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])
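        # Shape note: (n, n_classes) -> (n, n_classes, 1, 1) -> (n, n_classes, 28, 28),
        # so the class information can be concatenated to the image channel-wise.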
### Update discriminator ###
# Zero out the discriminator gradients
disc_opt.zero_grad()
# Get noise corresponding to the current batch_size
fake_noise = get_noise(cur_batch_size, z_dim, device=device)
# Now you can get the images from the generator
# Steps: 1) Combine the noise vectors and the one-hot labels for the generator
# 2) Generate the conditioned fake images
#### START CODE HERE ####
noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
fake = gen(noise_and_labels)
#### END CODE HERE ####
# Make sure that enough images were generated
assert len(fake) == len(real)
# Now you can get the predictions from the discriminator
# Steps: 1) Create the input for the discriminator
# a) Combine the fake images with image_one_hot_labels,
# remember to detach the generator (.detach()) so you do not backpropagate through it
# b) Combine the real images with image_one_hot_labels
# 2) Get the discriminator's prediction on the fakes as disc_fake_pred
# 3) Get the discriminator's prediction on the reals as disc_real_pred
#### START CODE HERE ####
fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
real_image_and_labels = combine_vectors(real, image_one_hot_labels)
disc_fake_pred = disc(fake_image_and_labels.detach())
disc_real_pred = disc(real_image_and_labels)
#### END CODE HERE ####
# Make sure that enough predictions were made
assert len(disc_real_pred) == len(real)
# Make sure that the inputs are different
assert torch.any(fake_image_and_labels != real_image_and_labels)
disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
disc_loss = (disc_fake_loss + disc_real_loss) / 2
disc_loss.backward(retain_graph=True)
disc_opt.step()
# Keep track of the average discriminator loss
discriminator_losses += [disc_loss.item()]
### Update generator ###
# Zero out the generator gradients
gen_opt.zero_grad()
fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
# This will error if you didn't concatenate your labels to your image correctly
disc_fake_pred = disc(fake_image_and_labels)
gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
gen_loss.backward()
gen_opt.step()
# Keep track of the generator losses
generator_losses += [gen_loss.item()]
#
if cur_step % display_step == 0 and cur_step > 0:
gen_mean = sum(generator_losses[-display_step:]) / display_step
disc_mean = sum(discriminator_losses[-display_step:]) / display_step
print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
show_tensor_images(fake)
show_tensor_images(real)
step_bins = 20
x_axis = sorted([i * step_bins for i in range(len(generator_losses) // step_bins)] * step_bins)
num_examples = (len(generator_losses) // step_bins) * step_bins
plt.plot(
range(num_examples // step_bins),
torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
label="Generator Loss"
)
plt.plot(
range(num_examples // step_bins),
torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
label="Discriminator Loss"
)
plt.legend()
plt.show()
elif cur_step == 0:
print("Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!")
cur_step += 1
# Before you explore, you should put the generator
# in eval mode, both in general and so that batch norm
# doesn't cause you issues and is using its eval statistics
gen = gen.eval()
import math
### Change me! ###
n_interpolation = 9 # Choose the interpolation: how many intermediate images you want + 2 (for the start and end image)
interpolation_noise = get_noise(1, z_dim, device=device).repeat(n_interpolation, 1)
def interpolate_class(first_number, second_number):
first_label = get_one_hot_labels(torch.Tensor([first_number]).long(), n_classes)
second_label = get_one_hot_labels(torch.Tensor([second_number]).long(), n_classes)
# Calculate the interpolation vector between the two labels
percent_second_label = torch.linspace(0, 1, n_interpolation)[:, None]
interpolation_labels = first_label * (1 - percent_second_label) + second_label * percent_second_label
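    # Broadcasting the (1, n_classes) labels against the (n_interpolation, 1)
    # weights yields (n_interpolation, n_classes): each row is a convex mix of the two labels.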
# Combine the noise and the labels
noise_and_labels = combine_vectors(interpolation_noise, interpolation_labels.to(device))
fake = gen(noise_and_labels)
show_tensor_images(fake, num_images=n_interpolation, nrow=int(math.sqrt(n_interpolation)), show=False)
### Change me! ###
start_plot_number = 1 # Choose the start digit
### Change me! ###
end_plot_number = 9 # Choose the end digit
plt.figure(figsize=(8, 8))
interpolate_class(start_plot_number, end_plot_number)
_ = plt.axis('off')
### The following lines visualize a set of pairwise class
### interpolations for a collection of different numbers, all in a single grid of interpolations.
### You'll also see another visualization like this in the next code block!
plot_numbers = [2, 3, 4, 5, 7]
n_numbers = len(plot_numbers)
plt.figure(figsize=(8, 8))
for i, first_plot_number in enumerate(plot_numbers):
for j, second_plot_number in enumerate(plot_numbers):
plt.subplot(n_numbers, n_numbers, i * n_numbers + j + 1)
interpolate_class(first_plot_number, second_plot_number)
plt.axis('off')
plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
plt.show()
plt.close()
# Changing the Noise Vector:
n_interpolation = 9 # How many intermediate images you want + 2 (for the start and end image)
# This time you're interpolating between the noise instead of the labels
interpolation_label = get_one_hot_labels(torch.Tensor([9]).long(), n_classes).repeat(n_interpolation, 1).float()
def interpolate_noise(first_noise, second_noise):
# This time you're interpolating between the noise instead of the labels
percent_first_noise = torch.linspace(0, 1, n_interpolation)[:, None].to(device)
interpolation_noise = first_noise * percent_first_noise + second_noise * (1 - percent_first_noise)
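    # Note: percent_first_noise runs from 0 to 1, so each grid row starts at
    # second_noise and ends at first_noise.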
# Combine the noise and the labels again
noise_and_labels = combine_vectors(interpolation_noise, interpolation_label.to(device))
fake = gen(noise_and_labels)
show_tensor_images(fake, num_images=n_interpolation, nrow=int(math.sqrt(n_interpolation)), show=False)
# Generate noise vectors to interpolate between
### Change me! ###
n_noise = 4 # Choose the number of noise examples in the grid
plot_noises = [get_noise(1, z_dim, device=device) for i in range(n_noise)]
plt.figure(figsize=(8, 8))
for i, first_plot_noise in enumerate(plot_noises):
for j, second_plot_noise in enumerate(plot_noises):
plt.subplot(n_noise, n_noise, i * n_noise + j + 1)
interpolate_noise(first_plot_noise, second_plot_noise)
plt.axis('off')
plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
plt.show()
plt.close()
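# Optional follow-up (not part of the original assignment): persist the trained
# generator so the interpolation experiments can be rerun without retraining.
# The filename below is illustrative.
torch.save(gen.state_dict(), "conditional_gen_mnist.pth")
# To reload later:
#   gen = Generator(input_dim=generator_input_dim).to(device)
#   gen.load_state_dict(torch.load("conditional_gen_mnist.pth", map_location=device))
#   gen.eval()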