-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
109 lines (87 loc) · 3.03 KB
/
train.py
File metadata and controls
109 lines (87 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import pandas as pd
import numpy as np
import glob
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import pickle
from model import Model
from loss import Loss
# Dataset wrapper around a list of image file paths; images are loaded
# lazily from disk and optionally transformed on access.
class ImageDataset(Dataset):
    """Map-style dataset over image paths with an optional transform."""

    def __init__(self, data_list, transform=None):
        self.data_list = data_list
        self.transform = transform

    def __len__(self):
        # One sample per stored path.
        return len(self.data_list)

    def __getitem__(self, idx):
        """Read the image at position idx and apply the transform, if any."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        path = self.data_list[idx]
        img = io.imread(path)
        return self.transform(img) if self.transform else img
# Transform that resizes an image: an int target scales the shorter side to
# that size (aspect ratio preserved); a tuple target gives exact (H, W).
class Rescale(object):
    """Rescale an image to `output_size` (int: shorter side; tuple: (H, W))."""

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        h, w = sample.shape[:2]
        if isinstance(self.output_size, int):
            # Scale the shorter dimension to output_size, keeping the ratio.
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        return transform.resize(sample, (int(new_h), int(new_w)))
# Transform that converts an H x W x C numpy image into a C x H x W tensor.
class ToTensor(object):
    """Convert a numpy image (H, W, C) into a torch tensor (C, H, W)."""

    def __call__(self, sample):
        # torch convolutions expect channels first, so move axis 2 to front.
        chw = sample.transpose((2, 0, 1))
        return torch.from_numpy(chw)
# Train for one epoch: enhance each batch with the iterative curve formula
# and update the model from the resulting loss.
def train(model, dataloader, optimizer, batch_size, n, image_size):
    """Run one training epoch over `dataloader`.

    Args:
        model: network producing the curve parameter maps A for a batch.
        dataloader: yields image batches; assumed (batch, 3, H, W) floats —
            TODO confirm against ToTensor/Rescale output.
        optimizer: torch optimizer over model.parameters().
        batch_size: batch dimension used to reshape A.
        n: number of curve-estimation iterations encoded in A.
        image_size: spatial size of the (square) input images.
    """
    for i_batch, sample in enumerate(dataloader):
        optimizer.zero_grad()
        A = model(sample)
        LE = sample
        # One 3-channel curve map per iteration: (batch, n, 3, H, W).
        A_n = A.reshape(batch_size, n, 3, image_size, image_size)
        # Iterative enhancement: LE <- LE + A_i * LE * (1 - LE), n times.
        for i in range(n):
            # BUGFIX: the original used A_n[:][iter], which is identical to
            # A_n[iter] (the [:] slice is a no-op) and therefore indexed the
            # BATCH axis — wrong maps, and an IndexError whenever the final
            # batch is smaller than n. Select the i-th iteration map instead.
            A_i = A_n[:, i]
            # torch.ones_like keeps dtype/device of LE (same values as before).
            LE = LE + torch.mul(torch.mul(A_i, LE), torch.ones_like(LE) - LE)
        # Backward propagation
        l = Loss()
        loss = l.compute_losses(sample, LE, A, image_size, n)
        loss.backward()
        optimizer.step()
        print('Batch: ', str(i_batch), ' ------ Loss: ', str(loss.data))
if __name__ == '__main__':
    # Fetch image locations from the dataset directory and keep the first
    # 80% as the training split.
    dataset_directory = 'Dataset'
    image_locs = glob.glob(dataset_directory + '/*.jpg')
    train_set = image_locs[:int(0.8 * len(image_locs))]
    # Pipeline: rescale shorter side to 512 px, then convert to a tensor.
    dataset = ImageDataset(train_set,
                           transform=transforms.Compose([Rescale(512), ToTensor()]))
    # Training hyper-parameters.
    image_size = 512
    batch_size = 8
    n_epochs = 10
    n = 8  # curve-estimation iterations expected in the model's output
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=1)
    model = Model()
    print('Model created')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(n_epochs):
        print('Epoch Number: ' + str(epoch))
        train(model.float(), dataloader, optimizer, batch_size, n, image_size)
    # Persist the trained model. BUGFIX: the original passed a bare open()
    # to pickle.dump, leaking the handle with no guaranteed flush/close;
    # a context manager closes it deterministically.
    # NOTE(review): assumes the 'model/' directory already exists — confirm.
    with open('model/trained_model.pkl', 'wb') as f:
        pickle.dump(model, f)