Skip to content

Commit 019c1bd

Browse files
committed
Fix style errors in transformer code
1 parent d9f750e commit 019c1bd

File tree

5 files changed

+88
-74
lines changed

5 files changed

+88
-74
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ language: python
22
branches:
33
only:
44
- master
5+
- transformer
56
os:
67
- linux
78

examples/transformer_learn.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,12 @@
7575
print('Training on {} shots, testing on {} shots'.format(
7676
len(shot_list_train), len(shot_list_test)))
7777

78-
7978
#####################################################
8079
# TRAINING #
8180
#####################################################
8281
train(conf, shot_list_train.random_sublist(512),
8382
shot_list_validate.random_sublist(256), loader)
84-
#if not only_predict:
83+
# if not only_predict:
8584
# p = old_mp.Process(target=train,
8685
# args=(conf, shot_list_train,
8786
# shot_list_validate, loader)
@@ -115,18 +114,17 @@
115114

116115
# TODO(KGF): check tuple unpack
117116
(y_prime_train, y_gold_train, disruptive_train, roc_train,
118-
loss_train) = make_predictions_and_evaluate_gpu(
119-
conf, shot_list_train, loader, custom_path)
117+
loss_train) = make_predictions_and_evaluate_gpu(
118+
conf, shot_list_train, loader, custom_path)
120119
(y_prime_test, y_gold_test, disruptive_test, roc_test,
121-
loss_test) = make_predictions_and_evaluate_gpu(
122-
conf, shot_list_test, loader, custom_path)
120+
loss_test) = make_predictions_and_evaluate_gpu(
121+
conf, shot_list_test, loader, custom_path)
123122
print('=========Summary========')
124123
print('Train Loss: {:.3e}'.format(loss_train))
125124
print('Train ROC: {:.4f}'.format(roc_train))
126125
print('Test Loss: {:.3e}'.format(loss_test))
127126
print('Test ROC: {:.4f}'.format(roc_test))
128127

129-
130128
disruptive_train = np.array(disruptive_train)
131129
disruptive_test = np.array(disruptive_test)
132130

@@ -138,16 +136,20 @@
138136
shot_list_test.make_light()
139137
shot_list_train.make_light()
140138

141-
save_str = 'results_' + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
139+
save_str = 'results_' + datetime.datetime.now().strftime(
140+
"%Y-%m-%d-%H-%M-%S")
142141
result_base_path = conf['paths']['results_prepath']
143142
if not os.path.exists(result_base_path):
144143
os.makedirs(result_base_path)
145-
np.savez(result_base_path+save_str, y_gold=y_gold, y_gold_train=y_gold_train,
146-
y_gold_test=y_gold_test, y_prime=y_prime, y_prime_train=y_prime_train,
147-
y_prime_test=y_prime_test, disruptive=disruptive,
148-
disruptive_train=disruptive_train, disruptive_test=disruptive_test,
149-
shot_list_validate=shot_list_validate,
150-
shot_list_train=shot_list_train, shot_list_test=shot_list_test,
151-
conf=conf)
144+
np.savez(result_base_path+save_str, y_gold=y_gold,
145+
y_gold_train=y_gold_train,
146+
y_gold_test=y_gold_test,
147+
y_prime=y_prime, y_prime_train=y_prime_train,
148+
y_prime_test=y_prime_test, disruptive=disruptive,
149+
disruptive_train=disruptive_train,
150+
disruptive_test=disruptive_test,
151+
shot_list_validate=shot_list_validate,
152+
shot_list_train=shot_list_train, shot_list_test=shot_list_test,
153+
conf=conf)
152154

153155
print('finished.')

plasma/models/distributed_torch_runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,8 @@ def train(conf, shot_list_train, shot_list_validate, loader):
418418
hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)
419419
hvd.broadcast_optimizer_state(optim, root_rank=0)
420420

421-
optimizer_args = {'op': hvd.Average, 'compression': hvd.Compression.fp16, 'named_parameters': train_model.named_parameters()}
421+
optimizer_args = {'op': hvd.Average, 'compression': hvd.Compression.fp16,
422+
'named_parameters': train_model.named_parameters()}
422423
optimizer = hvd.DistributedOptimizer(optim, **optimizer_args)
423424

424425
train_model.train()
@@ -431,8 +432,8 @@ def train(conf, shot_list_train, shot_list_validate, loader):
431432
while e < num_epochs - 1:
432433
print('\nEpoch {}/{}'.format(e, num_epochs))
433434
(step, ave_loss, curr_loss, num_so_far,
434-
effective_epochs) = train_epoch(train_model, data_gen, optimizer, scheduler,
435-
loss_fn)
435+
effective_epochs) = train_epoch(train_model, data_gen, optimizer,
436+
scheduler, loss_fn)
436437
e = effective_epochs
437438
loader.verbose = False # True during the first iteration
438439
# if task_index == 0:

plasma/models/loader.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -825,16 +825,15 @@ def get_batch_size(batch_size, prediction_mode):
825825
def get_num_skips(length, skip):
826826
return 1 + (length-1)//skip
827827

828-
#FIXME Alexeys
828+
# FIXME Alexeys
829829
def simple_batch_generator(self, shot_list, max_len=2048, inference=False):
830-
831830
batch_size = self.conf['training']['batch_size']
832831
sig, res = self.get_signal_result_from_shot(shot_list.shots[0])
833832
Xbuff = np.zeros((batch_size, max_len, sig.shape[1]))
834833
Ybuff = np.zeros((batch_size, max_len, res.shape[1]))
835834

836835
num_total = len(shot_list)
837-
#num_batches = num_total//batch_size
836+
# num_batches = num_total//batch_size
838837
disr = np.zeros(batch_size, dtype=bool)
839838

840839
while True:
@@ -846,18 +845,18 @@ def simple_batch_generator(self, shot_list, max_len=2048, inference=False):
846845
for i in range(num_total):
847846
shot = self.sample_shot_from_list_given_index(shot_list, i)
848847
sig, res = self.get_signal_result_from_shot(shot)
849-
sig = sig[-max_len:,:]
850-
res = res[-max_len:,:]
851-
Xbuff[i%batch_size, -sig.shape[0]:, :] = sig
852-
Ybuff[i%batch_size, -res.shape[0]:, :] = res
853-
disr[i%batch_size] = shot.is_disruptive_shot()
848+
sig = sig[-max_len:, :]
849+
res = res[-max_len:, :]
850+
Xbuff[i % batch_size, -sig.shape[0]:, :] = sig
851+
Ybuff[i % batch_size, -res.shape[0]:, :] = res
852+
disr[i % batch_size] = shot.is_disruptive_shot()
854853

855854
if i % batch_size == 0:
856855
num_so_far += batch_size
857-
858856
yield Xbuff, Ybuff, num_so_far, num_total, disr
859-
#Xbuff = np.zeros((batch_size, max_len, sig.shape[1]))
860-
#Ybuff = np.zeros((batch_size, max_len, res.shape[1]))
857+
# Xbuff = np.zeros((batch_size, max_len, sig.shape[1]))
858+
# Ybuff = np.zeros((batch_size, max_len, res.shape[1]))
859+
861860

862861
class ProcessGenerator(object):
863862
def __init__(self, generator):

plasma/transformer/runner.py

Lines changed: 56 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,18 @@
88
from plasma.utils.performance import PerformanceAnalyzer
99
from plasma.utils.evaluation import get_loss_from_list
1010
from plasma.models.torch_runner import (
11-
#make_predictions_and_evaluate_gpu,
12-
#make_predictions,
11+
# make_predictions_and_evaluate_gpu,
12+
# make_predictions,
1313
get_signal_dimensions,
14+
calculate_conv_output_size,
1415
)
1516

1617
from functools import partial
1718
import os
1819
import numpy as np
1920
import logging
2021
import random
21-
import tqdm
22+
# import tqdm
2223

2324
model_filename = "torch_model.pt"
2425
LOGGER = logging.getLogger("plasma.transformer.runner")
@@ -32,12 +33,14 @@
3233
# else:
3334
# device = torch.device("cpu")
3435

36+
3537
def set_seed(seed):
3638
random.seed(seed)
3739
np.random.seed(seed)
3840
torch.manual_seed(seed)
3941
torch.cuda.manual_seed_all(seed)
40-
os.environ["PYTHONHASHSEED"]="0"
42+
os.environ["PYTHONHASHSEED"] = "0"
43+
4144

4245
class TransformerNet(nn.Module):
4346
def __init__(
@@ -103,14 +106,16 @@ def __init__(
103106
)
104107
self.layers.append(nn.MaxPool1d(kernel_size=self.pooling_size))
105108
self.conv_output_size = calculate_conv_output_size(
106-
self.conv_output_size, 0, 1, self.pooling_size, self.pooling_size
109+
self.conv_output_size, 0, 1, self.pooling_size,
110+
self.pooling_size
107111
)
108112
self.layers.append(nn.Dropout2d(dropout))
109113
self.net = nn.Sequential(*self.layers)
110114
self.conv_output_size = self.conv_output_size * layer_sizes[-1]
111115
self.linear_layers = []
112116

113-
print("Final feature size = {}".format(self.n_scalars + self.conv_output_size))
117+
print("Final feature size = {}".format(self.n_scalars
118+
+ self.conv_output_size))
114119
self.linear_layers.append(
115120
nn.Linear(self.conv_output_size + self.n_scalars, linear_size)
116121
)
@@ -128,7 +133,7 @@ def forward(self, x):
128133
x_profiles = x
129134
else:
130135
x_scalars = x[:, : self.n_scalars]
131-
x_profiles = x[:, self.n_scalars :]
136+
x_profiles = x[:, self.n_scalars:]
132137
x_profiles = x_profiles.contiguous().view(
133138
x.size(0), self.n_profiles, self.profile_size
134139
)
@@ -170,7 +175,8 @@ def __init__(
170175
self.__max_seq_length = max_seq_length
171176
self.__d_model = d_model
172177
# FIXME
173-
self.__positional_encodings = nn.Embedding(max_seq_length, d_model).float()
178+
self.__positional_encodings = nn.Embedding(
179+
max_seq_length, d_model).float()
174180

175181
def forward(self, x):
176182
"""
@@ -180,17 +186,19 @@ def forward(self, x):
180186
mask = (
181187
torch.arange(x.shape[1], device=device)
182188
.unsqueeze(0)
183-
.lt(torch.tensor([self.__max_seq_length], device=device).unsqueeze(-1))
189+
.lt(torch.tensor([self.__max_seq_length],
190+
device=device).unsqueeze(-1))
184191
)
185192
transformer_input = x * mask.unsqueeze(-1).float() # B x max_len x D
186193

187194
positional_encodings = self.__positional_encodings(
188195
torch.arange(x.shape[1], dtype=torch.int64, device=device)
189196
).unsqueeze(0)
190-
transformer_input = transformer_input + positional_encodings # B x max_len x D
197+
transformer_input = (transformer_input
198+
+ positional_encodings) # B x max_len x D
191199

192200
out = self.__transformer_encoder(
193-
transformer_input #.transpose(0, 1), src_key_padding_mask=~mask
201+
transformer_input # .transpose(0, 1), src_key_padding_mask=~mask
194202
)
195203
return out
196204

@@ -199,11 +207,10 @@ def build_torch_model(conf):
199207

200208
dropout = conf["model"]["dropout_prob"]
201209
n_scalars, n_profiles, profile_size = get_signal_dimensions(conf)
202-
203-
output_size = 1
204-
layer_sizes_spatial = [6, 3, 3]
210+
# output_size = 1
211+
layer_sizes_spatial = [6, 3, 3]
205212
kernel_size_spatial = 3
206-
linear_size = 5 #FIXME Alexeys there will be no linear layers
213+
linear_size = 5 # FIXME Alexeys there will be no linear layers
207214

208215
model = TransformerNet(
209216
n_scalars,
@@ -233,7 +240,7 @@ def train_epoch(model, data_gen, optimizer, scheduler, loss_fn):
233240
step = 0
234241
while True:
235242
x_, y_, num_so_far, num_total, _ = next(data_gen)
236-
243+
237244
x = torch.from_numpy(x_).float().to(device)
238245
y = torch.from_numpy(y_).float().to(device)
239246

@@ -247,75 +254,77 @@ def train_epoch(model, data_gen, optimizer, scheduler, loss_fn):
247254
scheduler.step()
248255
step += 1
249256

250-
LOGGER.info(
251-
f"[{step}] [{num_so_far}/{num_total}] loss: {loss.item()}, ave_loss: {total_loss / step}"
252-
)
257+
LOGGER.info(f"[{step}] [{num_so_far}/{num_total}] loss: {loss.item()}, ave_loss: {total_loss / step}") # noqa
253258
if num_so_far >= num_total:
254259
break
255260

256-
return step, loss.item(), total_loss, num_so_far, 1.0 * num_so_far / num_total
261+
return (step, loss.item(), total_loss, num_so_far,
262+
1.0 * num_so_far / num_total)
257263

258264

259265
def train(conf, shot_list_train, shot_list_validate, loader):
260-
#set random seed
266+
# set random seed
261267
set_seed(0)
262268
num_epochs = conf["training"]["num_epochs"]
263-
patience = conf["callbacks"]["patience"]
269+
# patience = conf["callbacks"]["patience"]
264270
lr_decay = conf["model"]["lr_decay"]
265-
batch_size = conf['training']['batch_size']
271+
# batch_size = conf['training']['batch_size']
266272
lr = conf["model"]["lr"]
267-
clipnorm = conf['model']['clipnorm']
273+
# clipnorm = conf['model']['clipnorm']
268274
e = 0
269275

270276
loader.set_inference_mode(False)
271277
train_data_gen = partial(
272278
loader.simple_batch_generator,
273279
shot_list=shot_list_train,
274280
)()
275-
valid_data_generator = partial(
281+
valid_data_generator = partial( # noqa
276282
loader.simple_batch_generator,
277283
shot_list=shot_list_validate,
278284
inference=True
279285
)()
280-
LOGGER.info(f"validate: {len(shot_list_validate)} shots, {shot_list_validate.num_disruptive()} disruptive")
281-
LOGGER.info(f"training: {len(shot_list_train)} shots, {shot_list_train.num_disruptive()} disruptive")
286+
LOGGER.info(f"validate: {len(shot_list_validate)} shots, {shot_list_validate.num_disruptive()} disruptive") # noqa
287+
LOGGER.info(f"training: {len(shot_list_train)} shots, {shot_list_train.num_disruptive()} disruptive") # noqa
282288

283289
loss_fn = nn.MSELoss(size_average=True)
284290
train_model = build_torch_model(conf)
285291

286292
optimizer = opt.Adam(train_model.parameters(), lr=lr)
287293
scheduler = opt.lr_scheduler.ExponentialLR(optimizer, lr_decay)
288-
294+
289295
model_path = get_model_path(conf)
290296
makedirs_process_safe(os.path.dirname(model_path))
291297

292298
train_model.train()
293299
LOGGER.info(f"{num_epochs - 1 - e} epochs left to go")
294300
while e < num_epochs - 1:
295301
LOGGER.info(f"Epoch {e}/{num_epochs}")
296-
(step, ave_loss, curr_loss, num_so_far, effective_epochs) = train_epoch(
302+
(step, ave_loss, curr_loss, num_so_far,
303+
effective_epochs) = train_epoch(
297304
train_model, train_data_gen, optimizer, scheduler, loss_fn
298305
)
299-
306+
300307
e = effective_epochs
301308
torch.save(train_model.state_dict(), model_path)
302-
#FIXME no validation for now as OOM
303-
#_, _, _, roc_area, loss = make_predictions_and_evaluate_gpu(
309+
# FIXME no validation for now as OOM
310+
# _, _, _, roc_area, loss = make_predictions_and_evaluate_gpu(
304311
# conf, shot_list_validate, valid_data_generator
305-
#)
312+
# )
313+
314+
# # stop_training = False
315+
# print("=========Summary======== for epoch{}".format(step))
316+
# print("Training Loss numpy: {:.3e}".format(ave_loss))
317+
# print("Validation Loss: {:.3e}".format(loss))
318+
# print("Validation ROC: {:.4f}".format(roc_area))
306319

307-
## stop_training = False
308-
#print("=========Summary======== for epoch{}".format(step))
309-
#print("Training Loss numpy: {:.3e}".format(ave_loss))
310-
#print("Validation Loss: {:.3e}".format(loss))
311-
#print("Validation ROC: {:.4f}".format(roc_area))
312320

313321
def apply_model_to_np(model, x):
314322
return model(torch.from_numpy(x).float()).data.numpy()
315323

316-
#FIXME Alexeys change
324+
325+
# FIXME Alexeys change
317326
def make_predictions(conf, shot_list, generator, custom_path=None):
318-
#generator = loader.inference_batch_generator_full_shot(shot_list)
327+
# generator = loader.inference_batch_generator_full_shot(shot_list)
319328
inference_model = build_torch_model(conf)
320329

321330
if custom_path is None:
@@ -336,11 +345,11 @@ def make_predictions(conf, shot_list, generator, custom_path=None):
336345

337346
x = torch.from_numpy(x_).float().to(device)
338347
y = torch.from_numpy(y_).float().to(device)
339-
#output = apply_model_to_np(inference_model, x)
348+
# output = apply_model_to_np(inference_model, x)
340349
output = inference_model(x)
341350

342351
for batch_idx in range(x.shape[0]):
343-
#curr_length = lengths[batch_idx]
352+
# curr_length = lengths[batch_idx]
344353
y_prime += [output[batch_idx, :, 0]]
345354
y_gold += [y[batch_idx, :, 0]]
346355
disruptive += [disr[batch_idx]]
@@ -352,11 +361,13 @@ def make_predictions(conf, shot_list, generator, custom_path=None):
352361
break
353362
return y_prime, y_gold, disruptive
354363

355-
#FIXME ALexeys change loader --> generator
356-
def make_predictions_and_evaluate_gpu(conf, shot_list, generator, custom_path=None):
364+
365+
# FIXME ALexeys change loader --> generator
366+
def make_predictions_and_evaluate_gpu(conf, shot_list, generator,
367+
custom_path=None):
357368
y_prime, y_gold, disruptive = make_predictions(
358369
conf, shot_list, generator, custom_path)
359370
analyzer = PerformanceAnalyzer(conf=conf)
360371
roc_area = analyzer.get_roc_area(y_prime, y_gold, disruptive)
361372
loss = get_loss_from_list(y_prime, y_gold, conf['data']['target'])
362-
return y_prime, y_gold, disruptive, roc_area, loss
373+
return y_prime, y_gold, disruptive, roc_area, loss

0 commit comments

Comments
 (0)