88from plasma .utils .performance import PerformanceAnalyzer
99from plasma .utils .evaluation import get_loss_from_list
1010from plasma .models .torch_runner import (
11- #make_predictions_and_evaluate_gpu,
12- #make_predictions,
11+ # make_predictions_and_evaluate_gpu,
12+ # make_predictions,
1313 get_signal_dimensions ,
14+ calculate_conv_output_size ,
1415)
1516
1617from functools import partial
1718import os
1819import numpy as np
1920import logging
2021import random
21- import tqdm
22+ # import tqdm
2223
2324model_filename = "torch_model.pt"
2425LOGGER = logging .getLogger ("plasma.transformer.runner" )
3233# else:
3334# device = torch.device("cpu")
3435
36+
def set_seed(seed):
    """Seed every RNG used in training so runs are reproducible.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and all
    CUDA devices), and pins cuDNN to its deterministic kernels.

    Args:
        seed: integer seed applied to all RNGs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Without these flags cuDNN may pick nondeterministic convolution
    # algorithms, so results can differ run-to-run despite fixed seeds.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE(review): setting PYTHONHASHSEED here only affects child
    # processes -- hash randomisation of the *current* interpreter is
    # fixed at startup. Kept for parity with the original behaviour.
    os.environ["PYTHONHASHSEED"] = "0"
43+
4144
4245class TransformerNet (nn .Module ):
4346 def __init__ (
@@ -103,14 +106,16 @@ def __init__(
103106 )
104107 self .layers .append (nn .MaxPool1d (kernel_size = self .pooling_size ))
105108 self .conv_output_size = calculate_conv_output_size (
106- self .conv_output_size , 0 , 1 , self .pooling_size , self .pooling_size
109+ self .conv_output_size , 0 , 1 , self .pooling_size ,
110+ self .pooling_size
107111 )
108112 self .layers .append (nn .Dropout2d (dropout ))
109113 self .net = nn .Sequential (* self .layers )
110114 self .conv_output_size = self .conv_output_size * layer_sizes [- 1 ]
111115 self .linear_layers = []
112116
113- print ("Final feature size = {}" .format (self .n_scalars + self .conv_output_size ))
117+ print ("Final feature size = {}" .format (self .n_scalars
118+ + self .conv_output_size ))
114119 self .linear_layers .append (
115120 nn .Linear (self .conv_output_size + self .n_scalars , linear_size )
116121 )
@@ -128,7 +133,7 @@ def forward(self, x):
128133 x_profiles = x
129134 else :
130135 x_scalars = x [:, : self .n_scalars ]
131- x_profiles = x [:, self .n_scalars :]
136+ x_profiles = x [:, self .n_scalars :]
132137 x_profiles = x_profiles .contiguous ().view (
133138 x .size (0 ), self .n_profiles , self .profile_size
134139 )
@@ -170,7 +175,8 @@ def __init__(
170175 self .__max_seq_length = max_seq_length
171176 self .__d_model = d_model
172177 # FIXME
173- self .__positional_encodings = nn .Embedding (max_seq_length , d_model ).float ()
178+ self .__positional_encodings = nn .Embedding (
179+ max_seq_length , d_model ).float ()
174180
175181 def forward (self , x ):
176182 """
@@ -180,17 +186,19 @@ def forward(self, x):
180186 mask = (
181187 torch .arange (x .shape [1 ], device = device )
182188 .unsqueeze (0 )
183- .lt (torch .tensor ([self .__max_seq_length ], device = device ).unsqueeze (- 1 ))
189+ .lt (torch .tensor ([self .__max_seq_length ],
190+ device = device ).unsqueeze (- 1 ))
184191 )
185192 transformer_input = x * mask .unsqueeze (- 1 ).float () # B x max_len x D
186193
187194 positional_encodings = self .__positional_encodings (
188195 torch .arange (x .shape [1 ], dtype = torch .int64 , device = device )
189196 ).unsqueeze (0 )
190- transformer_input = transformer_input + positional_encodings # B x max_len x D
197+ transformer_input = (transformer_input
198+ + positional_encodings ) # B x max_len x D
191199
192200 out = self .__transformer_encoder (
193- transformer_input # .transpose(0, 1), src_key_padding_mask=~mask
201+ transformer_input # .transpose(0, 1), src_key_padding_mask=~mask
194202 )
195203 return out
196204
@@ -199,11 +207,10 @@ def build_torch_model(conf):
199207
200208 dropout = conf ["model" ]["dropout_prob" ]
201209 n_scalars , n_profiles , profile_size = get_signal_dimensions (conf )
202-
203- output_size = 1
204- layer_sizes_spatial = [6 , 3 , 3 ]
210+ # output_size = 1
211+ layer_sizes_spatial = [6 , 3 , 3 ]
205212 kernel_size_spatial = 3
206- linear_size = 5 # FIXME Alexeys there will be no linear layers
213+ linear_size = 5 # FIXME Alexeys there will be no linear layers
207214
208215 model = TransformerNet (
209216 n_scalars ,
@@ -233,7 +240,7 @@ def train_epoch(model, data_gen, optimizer, scheduler, loss_fn):
233240 step = 0
234241 while True :
235242 x_ , y_ , num_so_far , num_total , _ = next (data_gen )
236-
243+
237244 x = torch .from_numpy (x_ ).float ().to (device )
238245 y = torch .from_numpy (y_ ).float ().to (device )
239246
@@ -247,75 +254,77 @@ def train_epoch(model, data_gen, optimizer, scheduler, loss_fn):
247254 scheduler .step ()
248255 step += 1
249256
250- LOGGER .info (
251- f"[{ step } ] [{ num_so_far } /{ num_total } ] loss: { loss .item ()} , ave_loss: { total_loss / step } "
252- )
257+ LOGGER .info (f"[{ step } ] [{ num_so_far } /{ num_total } ] loss: { loss .item ()} , ave_loss: { total_loss / step } " ) # noqa
253258 if num_so_far >= num_total :
254259 break
255260
256- return step , loss .item (), total_loss , num_so_far , 1.0 * num_so_far / num_total
261+ return (step , loss .item (), total_loss , num_so_far ,
262+ 1.0 * num_so_far / num_total )
257263
258264
def train(conf, shot_list_train, shot_list_validate, loader):
    """Train the transformer model described by ``conf``.

    Args:
        conf: configuration dict; reads ``training`` (num_epochs) and
            ``model`` (lr, lr_decay) sections, plus whatever
            ``build_torch_model`` consumes.
        shot_list_train: shots used for optimisation.
        shot_list_validate: shots reserved for validation (currently
            unused -- the validation pass is disabled, see FIXME below).
        loader: data source providing ``simple_batch_generator`` and
            ``set_inference_mode``.

    Side effects: saves model weights to ``get_model_path(conf)`` after
    every epoch.
    """
    # Fix all RNG seeds up front so runs are reproducible.
    set_seed(0)
    num_epochs = conf["training"]["num_epochs"]
    # patience = conf["callbacks"]["patience"]
    lr_decay = conf["model"]["lr_decay"]
    # batch_size = conf['training']['batch_size']
    lr = conf["model"]["lr"]
    # clipnorm = conf['model']['clipnorm']
    e = 0

    loader.set_inference_mode(False)
    train_data_gen = partial(
        loader.simple_batch_generator,
        shot_list=shot_list_train,
    )()
    # Kept alive for the (currently disabled) validation pass below.
    valid_data_generator = partial(  # noqa
        loader.simple_batch_generator,
        shot_list=shot_list_validate,
        inference=True
    )()
    LOGGER.info(f"validate: {len(shot_list_validate)} shots, {shot_list_validate.num_disruptive()} disruptive")  # noqa
    LOGGER.info(f"training: {len(shot_list_train)} shots, {shot_list_train.num_disruptive()} disruptive")  # noqa

    # `size_average` has been deprecated since torch 0.4;
    # reduction="mean" is the exact modern equivalent.
    loss_fn = nn.MSELoss(reduction="mean")
    train_model = build_torch_model(conf)

    optimizer = opt.Adam(train_model.parameters(), lr=lr)
    scheduler = opt.lr_scheduler.ExponentialLR(optimizer, lr_decay)

    model_path = get_model_path(conf)
    makedirs_process_safe(os.path.dirname(model_path))

    train_model.train()
    LOGGER.info(f"{num_epochs - 1 - e} epochs left to go")
    while e < num_epochs - 1:
        LOGGER.info(f"Epoch {e}/{num_epochs}")
        # NOTE(review): train_epoch returns (step, last-batch loss,
        # total loss, num_so_far, fractional epochs); `ave_loss` is
        # really the *last batch* loss. Harmless today because it is
        # only used by the commented-out summary below.
        (step, ave_loss, curr_loss, num_so_far,
         effective_epochs) = train_epoch(
            train_model, train_data_gen, optimizer, scheduler, loss_fn
        )

        # The epoch counter advances by the fraction of data seen.
        e = effective_epochs
        # Checkpoint after every pass so progress survives a crash.
        torch.save(train_model.state_dict(), model_path)
        # FIXME no validation for now as OOM
        # _, _, _, roc_area, loss = make_predictions_and_evaluate_gpu(
        #     conf, shot_list_validate, valid_data_generator
        # )

        # # stop_training = False
        # print("=========Summary======== for epoch{}".format(step))
        # print("Training Loss numpy: {:.3e}".format(ave_loss))
        # print("Validation Loss: {:.3e}".format(loss))
        # print("Validation ROC: {:.4f}".format(roc_area))
312320
def apply_model_to_np(model, x):
    """Run ``model`` on a NumPy array and return the output as NumPy.

    Args:
        model: a callable (e.g. ``nn.Module``) mapping a float tensor to
            a tensor.
        x: NumPy array of inputs; converted to float32 before the call.

    Returns:
        NumPy array with the model's output.
    """
    # no_grad avoids building an autograd graph for pure inference;
    # .detach().cpu() replaces the deprecated `.data` access and also
    # works when the model returns a CUDA tensor.
    with torch.no_grad():
        out = model(torch.from_numpy(x).float())
    return out.detach().cpu().numpy()
315323
316- #FIXME Alexeys change
324+
325+ # FIXME Alexeys change
317326def make_predictions (conf , shot_list , generator , custom_path = None ):
318- #generator = loader.inference_batch_generator_full_shot(shot_list)
327+ # generator = loader.inference_batch_generator_full_shot(shot_list)
319328 inference_model = build_torch_model (conf )
320329
321330 if custom_path is None :
@@ -336,11 +345,11 @@ def make_predictions(conf, shot_list, generator, custom_path=None):
336345
337346 x = torch .from_numpy (x_ ).float ().to (device )
338347 y = torch .from_numpy (y_ ).float ().to (device )
339- #output = apply_model_to_np(inference_model, x)
348+ # output = apply_model_to_np(inference_model, x)
340349 output = inference_model (x )
341350
342351 for batch_idx in range (x .shape [0 ]):
343- #curr_length = lengths[batch_idx]
352+ # curr_length = lengths[batch_idx]
344353 y_prime += [output [batch_idx , :, 0 ]]
345354 y_gold += [y [batch_idx , :, 0 ]]
346355 disruptive += [disr [batch_idx ]]
@@ -352,11 +361,13 @@ def make_predictions(conf, shot_list, generator, custom_path=None):
352361 break
353362 return y_prime , y_gold , disruptive
354363
# FIXME ALexeys change loader --> generator
def make_predictions_and_evaluate_gpu(conf, shot_list, generator,
                                      custom_path=None):
    """Run inference over ``shot_list`` and score the predictions.

    Delegates inference to ``make_predictions`` and then computes the
    ROC area and the configured loss over the results.

    Args:
        conf: configuration dict (reads ``conf['data']['target']``).
        shot_list: shots to evaluate.
        generator: batch generator feeding the model.
        custom_path: optional override for the saved-model path.

    Returns:
        Tuple ``(y_prime, y_gold, disruptive, roc_area, loss)``.
    """
    predicted, truth, disruptive_flags = make_predictions(
        conf, shot_list, generator, custom_path
    )
    analyzer = PerformanceAnalyzer(conf=conf)
    roc_area = analyzer.get_roc_area(predicted, truth, disruptive_flags)
    loss = get_loss_from_list(predicted, truth, conf["data"]["target"])
    return predicted, truth, disruptive_flags, roc_area, loss
0 commit comments