@@ -45,7 +45,7 @@ def __init__(self,n_scalars,n_profiles,profile_size,layer_sizes_spatial,
4545 num_channels_tcn ,kernel_size_temporal ,dropout = 0.1 ):
4646 super (FTCN , self ).__init__ ()
4747 self .lin = InputBlock (n_scalars , n_profiles ,profile_size , layer_sizes_spatial , kernel_size_spatial , linear_size , dropout )
48- self .input_layer = TimeDistributed (lin ,batch_first = True )
48+ self .input_layer = TimeDistributed (self . lin ,batch_first = True )
4949 self .tcn = TCN (linear_size , output_size , num_channels_tcn , kernel_size_temporal , dropout )
5050 self .model = nn .Sequential (self .input_layer ,self .tcn )
5151
@@ -102,8 +102,8 @@ def forward(self, x):
102102 if self .n_scalars == 0 :
103103 x_profiles = x
104104 else :
105- x_scalars = x [:,:n_scalars ]
106- x_profiles = x [:,n_scalars :]
105+ x_scalars = x [:,:self . n_scalars ]
106+ x_profiles = x [:,self . n_scalars :]
107107 x_profiles = x_profiles .contiguous ().view (x .size (0 ),self .n_profiles ,self .profile_size )
108108 profile_features = self .net (x_profiles ).view (x .size (0 ),- 1 )
109109 if self .n_scalars == 0 :
@@ -271,18 +271,18 @@ def build_torch_model(conf):
271271# dim = 10
272272
273273 # lin = nn.Linear(input_size,intermediate_dim)
274- n_scalars , n_profile , profile_size = get_signal_dimensions (conf )
274+ n_scalars , n_profiles , profile_size = get_signal_dimensions (conf )
275275 dim = n_scalars + n_profiles * profile_size
276276 input_size = dim
277277 output_size = 1
278278 # intermediate_dim = 15
279279
280- layer_sizes_spatial = [40 ,20 ,20 ]
280+ layer_sizes_spatial = [6 , 3 , 3 ] #[ 40,20,20]
281281 kernel_size_spatial = 3
282- linear_size = 10
282+ linear_size = 5
283283
284- num_channels_tcn = [3 ]* 5
285- kernel_size_temporal = 3
284+ num_channels_tcn = [10 , 5 , 3 , 3 ] #[ 3]*5
285+ kernel_size_temporal = 3 #3
286286 model = FTCN (n_scalars ,n_profiles ,profile_size ,layer_sizes_spatial ,
287287 kernel_size_spatial ,linear_size ,output_size ,num_channels_tcn ,
288288 kernel_size_temporal ,dropout )
@@ -300,16 +300,68 @@ def get_signal_dimensions(conf):
300300 num_channels = sig .num_channels
301301 if num_channels > 1 :
302302 profile_size = num_channels
303- num_1D += 1
303+ n_profiles += 1
304304 is_1D_region = True
305305 else :
306306 assert (not is_1D_region ), "make sure all use_signals are ordered such that 1D signals come last!"
307307 assert (num_channels == 1 )
308- num_0D += 1
308+ n_scalars += 1
309309 is_1D_region = False
310310 return n_scalars ,n_profiles ,profile_size
311311
312- def train_epoch (model ,data_gen ,loss_fn ):
def apply_model_to_np(model, x):
    """Feed a numpy array through ``model`` and return the output as numpy.

    The input is converted to a float tensor (wrapped in a ``Variable`` for
    the legacy PyTorch API this file uses) and the raw output tensor is
    unwrapped back to a numpy array.
    """
    tensor_in = torch.from_numpy(x).float()
    prediction = model(Variable(tensor_in))
    return prediction.data.numpy()
315+
316+
317+
def make_predictions(conf, shot_list, loader, custom_path=None):
    """Run the saved model over every shot in ``shot_list``.

    Builds the model from ``conf``, loads the weights from ``custom_path``
    (or the default path from :func:`get_model_path`), and streams batches
    from the loader's inference generator until one prediction per shot has
    been collected.

    Returns:
        (y_prime, y_gold, disruptive): lists, one entry per shot, of the
        predicted time-series, the target time-series, and the per-shot
        disruptivity labels.
    """
    generator = loader.inference_batch_generator_full_shot(shot_list)
    inference_model = build_torch_model(conf)

    if custom_path is None:  # `is None`, not `== None` (idiomatic, avoids __eq__ surprises)
        model_path = get_model_path(conf)
    else:
        model_path = custom_path
    inference_model.load_state_dict(torch.load(model_path))
    # Bug fix: switch to eval mode so dropout layers are disabled at inference.
    inference_model.eval()

    y_prime = []
    y_gold = []
    disruptive = []
    num_shots = len(shot_list)

    pbar = Progbar(num_shots)
    while True:
        x, y, mask, disr, lengths, num_so_far, num_total = next(generator)
        output = apply_model_to_np(inference_model, x)
        for batch_idx in range(x.shape[0]):
            curr_length = lengths[batch_idx]
            # Trim the padded time dimension back to the shot's true length.
            y_prime += [output[batch_idx, :curr_length, 0]]
            y_gold += [y[batch_idx, :curr_length, 0]]
            disruptive += [disr[batch_idx]]
            pbar.add(1.0)
        if len(disruptive) >= num_shots:
            # The generator may overshoot on the final batch; truncate to size.
            y_prime = y_prime[:num_shots]
            y_gold = y_gold[:num_shots]
            disruptive = disruptive[:num_shots]
            break
    return y_prime, y_gold, disruptive
351+
def make_predictions_and_evaluate_gpu(conf, shot_list, loader, custom_path=None):
    """Predict on ``shot_list`` and score the predictions.

    Delegates to :func:`make_predictions`, then computes the ROC area via
    ``PerformanceAnalyzer`` and the loss via ``get_loss_from_list`` using the
    target specified in ``conf['data']['target']``.

    Returns:
        (y_prime, y_gold, disruptive, roc_area, loss)
    """
    predictions, targets, disr_labels = make_predictions(conf, shot_list, loader, custom_path)
    analyzer = PerformanceAnalyzer(conf=conf)
    roc_area = analyzer.get_roc_area(predictions, targets, disr_labels)
    loss = get_loss_from_list(predictions, targets, conf['data']['target'])
    return predictions, targets, disr_labels, roc_area, loss
358+
359+
def get_model_path(conf):
    """Return the file path where the torch model weights are saved/loaded.

    Uses ``os.path.join`` instead of string concatenation so the result is
    correct whether or not ``conf['paths']['model_save_path']`` carries a
    trailing separator (the old ``+ 'torch/'`` form silently produced a
    broken path without one).
    """
    # model_filename is a module-level constant defined elsewhere in this file.
    return os.path.join(conf['paths']['model_save_path'], 'torch', model_filename)
362+
363+
364+ def train_epoch (model ,data_gen ,optimizer ,loss_fn ):
313365 loss = 0
314366 total_loss = 0
315367 num_so_far = 0
@@ -335,17 +387,19 @@ def train_epoch(model,data_gen,loss_fn):
335387 loss .backward ()
336388 optimizer .step ()
337389 step += 1
390+ print ("[{}] [{}/{}] loss: {:.3f}, ave_loss: {:.3f}" .format (step ,num_so_far - num_so_far_start ,num_total ,loss .data [0 ],total_loss / step ))
338391 if num_so_far - num_so_far_start >= num_total :
339392 break
340- x_ ,y_ ,mask_ ,num_so_far_start ,num_total = next (data_gen )
341- return step ,loss ,total_loss ,num_so_far ,1.0 * num_so_far / num_total
393+ x_ ,y_ ,mask_ ,num_so_far ,num_total = next (data_gen )
394+ return step ,loss . data [ 0 ] ,total_loss ,num_so_far ,1.0 * num_so_far / num_total
342395
343396
344397def train (conf ,shot_list_train ,shot_list_validate ,loader ):
345398
346399 np .random .seed (1 )
347400
348- data_gen = ProcessGenerator (partial (loader .training_batch_generator_full_shot_partial_reset ,shot_list = shot_list_train ))
401+ #data_gen = ProcessGenerator(partial(loader.training_batch_generator_full_shot_partial_reset,shot_list=shot_list_train)())
402+ data_gen = partial (loader .training_batch_generator_full_shot_partial_reset ,shot_list = shot_list_train )()
349403
350404 print ('validate: {} shots, {} disruptive' .format (len (shot_list_validate ),shot_list_validate .num_disruptive ()))
351405 print ('training: {} shots, {} disruptive' .format (len (shot_list_train ),shot_list_train .num_disruptive ()))
@@ -358,6 +412,7 @@ def train(conf,shot_list_train,shot_list_validate,loader):
358412 # e = specific_builder.load_model_weights(train_model)
359413
360414 num_epochs = conf ['training' ]['num_epochs' ]
415+ patience = conf ['callbacks' ]['patience' ]
361416 lr_decay = conf ['model' ]['lr_decay' ]
362417 batch_size = conf ['training' ]['batch_size' ]
363418 lr = conf ['model' ]['lr' ]
@@ -385,23 +440,25 @@ def train(conf,shot_list_train,shot_list_validate,loader):
385440 else :
386441 best_so_far = np .inf
387442 cmp_fn = min
388- optimizer = opt .Adam (model .parameters (),lr = lr )
389- model .train ()
443+ optimizer = opt .Adam (train_model .parameters (),lr = lr )
444+ scheduler = opt .lr_scheduler .ExponentialLR (optimizer ,lr_decay )
445+ train_model .train ()
390446 not_updated = 0
391447 total_loss = 0
392448 count = 0
393- loss_fn = nn .MSELoss (size_average = False )
394- model_path = conf [ 'paths' ][ 'model_save_path' ] + model_filename #save_prepath + model_filename
395- makedirs_process_safe (conf [ 'paths' ][ 'model_save_path' ] )
449+ loss_fn = nn .MSELoss (size_average = True )
450+ model_path = get_model_path ( conf )
451+ makedirs_process_safe (os . path . dirname ( model_path ) )
396452 while e < num_epochs - 1 :
397- print_unique ('\n Epoch {}/{}' .format (e ,num_epochs ))
398- (step ,ave_loss ,curr_loss ,num_so_far ,effective_epochs ) = train_epoch (model ,data_gen ,loss_fn )
453+ scheduler .step ()
454+ print ('\n Epoch {}/{}' .format (e ,num_epochs ))
455+ (step ,ave_loss ,curr_loss ,num_so_far ,effective_epochs ) = train_epoch (train_model ,data_gen ,optimizer ,loss_fn )
399456 e = effective_epochs
400457 loader .verbose = False #True during the first iteration
401458 # if task_index == 0:
402459 # specific_builder.save_model_weights(train_model,int(round(e)))
403- model . save_state_dict ( model_path )
404- _ ,_ ,_ ,roc_area ,loss = mpi_make_predictions_and_evaluate (conf ,shot_list_validate ,loader )
460+ torch . save ( train_model . state_dict (), model_path )
461+ _ ,_ ,_ ,roc_area ,loss = make_predictions_and_evaluate_gpu (conf ,shot_list_validate ,loader )
405462
406463 best_so_far = cmp_fn (roc_area ,best_so_far )
407464
@@ -411,54 +468,13 @@ def train(conf,shot_list_train,shot_list_validate,loader):
411468 print ('Validation Loss: {:.3e}' .format (loss ))
412469 print ('Validation ROC: {:.4f}' .format (roc_area ))
413470
414- if best_so_far != epoch_logs [ conf [ 'callbacks' ][ 'monitor' ]] : #only save model weights if quantity we are tracking is improving
471+ if best_so_far != roc_area : #only save model weights if quantity we are tracking is improving
415472 print ("No improvement, still saving model" )
416473 not_updated += 1
417474 else :
418475 print ("Saving model" )
419- model .save_state_dict (model_path )
420476 # specific_builder.delete_model_weights(train_model,int(round(e)))
421477 if not_updated > patience :
422478 print ("Stopping training due to early stopping" )
423479 break
424480
425- def make_predictions (conf ,shot_list ,loader ,custom_path = None ):
426- generator = loader .inference_batch_generator_full_shot (shot_list )
427- inference_model = build_torch_model (conf )
428-
429- if custom_path == None :
430- model_path = conf ['paths' ]['model_save_path' ] + model_filename #save_prepath + model_filename
431- else :
432- model_path = custom_path
433- inference_model .load_state_dict (model_path )
434- #shot_list = shot_list.random_sublist(10)
435-
436- y_prime = []
437- y_gold = []
438- disruptive = []
439- num_shots = len (shot_list )
440-
441- pbar = Progbar (num_shots )
442- while True :
443- x_ ,y_ ,mask_ ,disr_ ,num_so_far ,num_total = next (generator )
444- x , y , mask = Variable (torch .from_numpy (x_ ).float ()), Variable (torch .from_numpy (y_ ).float ()),Variable (torch .from_numpy (mask_ ).byte ())
445- output = model (x )
446- for batch_idx in range (x .shape [0 ])
447- y_prime [batch_idx ] + = [output [batch_idx ,:,:]]
448- y_gold += [y_ [batch_idx ,:,:]]
449- disruptive += [disr [batch_idx ]]
450- pbar .add (1.0 )
451- if len (disruptive ) >= num_shots :
452- y_prime = y_prime [:num_shots ]
453- y_gold = y_gold [:num_shots ]
454- disruptive = disruptive [:num_shots ]
455- break
456- return y_prime ,y_gold ,disruptive
457-
458- def make_predictions_and_evaluate_gpu (conf ,shot_list ,loader ,custom_path = None ):
459- y_prime ,y_gold ,disruptive = make_predictions (conf ,shot_list ,loader ,custom_path )
460- analyzer = PerformanceAnalyzer (conf = conf )
461- roc_area = analyzer .get_roc_area (y_prime ,y_gold ,disruptive )
462- loss = get_loss_from_list (y_prime ,y_gold ,conf ['data' ]['target' ])
463- return y_prime ,y_gold ,disruptive ,roc_area ,loss
464-
0 commit comments