@@ -171,9 +171,10 @@ def get_val(self):
171171
172172
173173class MPIModel ():
174- def __init__ (self ,model ,optimizer ,comm ,batch_iterator ,batch_size ,num_replicas = None ,warmup_steps = 1000 ,lr = 0.01 ,num_batches_minimum = 100 ):
174+ def __init__ (self ,model ,optimizer ,comm ,batch_iterator ,batch_size ,num_replicas = None ,warmup_steps = 1000 ,lr = 0.01 ,num_batches_minimum = 100 , conf = None ):
175175 random .seed (task_index )
176176 np .random .seed (task_index )
177+ self .conf = conf
177178 self .start_time = time .time ()
178179 self .epoch = 0
179180 self .num_so_far = 0
@@ -200,10 +201,14 @@ def __init__(self,model,optimizer,comm,batch_iterator,batch_size,num_replicas=No
200201
201202
202203 def set_batch_iterator_func (self ):
203- self .batch_iterator_func = ProcessGenerator (self .batch_iterator ())
204+ if self .conf is not None and 'use_process_generator' in conf ['training' ] and conf ['training' ]['use_process_generator' ]:
205+ self .batch_iterator_func = ProcessGenerator (self .batch_iterator ())
206+ else :
207+ self .batch_iterator_func = self .batch_iterator ()
204208
205209 def close (self ):
206- self .batch_iterator_func .__exit__ ()
210+ if hasattr (self .batch_iterator_func ,'__exit__' ):
211+ self .batch_iterator_func .__exit__ ()
207212
208213 def set_lr (self ,lr ):
209214 self .lr = lr
@@ -642,6 +647,24 @@ def mpi_make_predictions_and_evaluate(conf,shot_list,loader,custom_path=None):
642647 loss = get_loss_from_list (y_prime ,y_gold ,conf ['data' ]['target' ])
643648 return y_prime ,y_gold ,disruptive ,roc_area ,loss
644649
def mpi_make_predictions_and_evaluate_multiple_times(conf, shot_list, loader, times, custom_path=None):
    """Run predictions once and score the ROC area at several alarm thresholds.

    Predictions are computed a single time via ``mpi_make_predictions``;
    each entry of ``times`` is then evaluated as an alternative
    ``data.T_min_warn`` value on a deep copy of ``conf`` so the caller's
    config is never mutated.

    Args:
        conf: experiment configuration dict (must contain
            ``conf['data']['T_min_warn']`` and ``conf['data']['target']``).
        shot_list: shots to predict on.
        loader: data loader passed through to ``mpi_make_predictions``.
        times: iterable of T_min_warn values to evaluate.
        custom_path: optional model-weights path forwarded to
            ``mpi_make_predictions``.

    Returns:
        (areas, losses): parallel lists, one entry per value in ``times``.
        The loss does not depend on T_min_warn, so ``losses`` repeats the
        same value (kept per-time for interface compatibility).
    """
    y_prime, y_gold, disruptive = mpi_make_predictions(conf, shot_list, loader, custom_path)
    # Loop-invariant work hoisted out of the loop: the loss is independent
    # of T_min_warn, and the original threshold only needs to be read once.
    loss = get_loss_from_list(y_prime, y_gold, conf['data']['target'])
    T_min_warn_orig = conf['data']['T_min_warn']
    areas = []
    losses = []
    for T_min_curr in times:
        conf_curr = deepcopy(conf)
        conf_curr['data']['T_min_warn'] = T_min_curr
        # Sanity guard: mutating the copy must never leak into the shared conf.
        assert conf['data']['T_min_warn'] == T_min_warn_orig
        analyzer = PerformanceAnalyzer(conf=conf_curr)
        areas.append(analyzer.get_roc_area(y_prime, y_gold, disruptive))
        losses.append(loss)
    return areas, losses
667+
645668
646669def mpi_train (conf ,shot_list_train ,shot_list_validate ,loader , callbacks_list = None ,shot_list_test = None ):
647670
@@ -680,7 +703,7 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
680703 #{}batch_generator = partial(loader.training_batch_generator_process,shot_list=shot_list_train)
681704
682705 print ("warmup {}" .format (warmup_steps ))
683- mpi_model = MPIModel (train_model ,optimizer ,comm ,batch_generator ,batch_size ,lr = lr ,warmup_steps = warmup_steps ,num_batches_minimum = num_batches_minimum )
706+ mpi_model = MPIModel (train_model ,optimizer ,comm ,batch_generator ,batch_size ,lr = lr ,warmup_steps = warmup_steps ,num_batches_minimum = num_batches_minimum , conf = conf )
684707 mpi_model .compile (conf ['model' ]['optimizer' ],clipnorm ,conf ['data' ]['target' ].loss )
685708
686709 tensorboard = None
@@ -709,6 +732,7 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
709732 cmp_fn = min
710733
711734 while e < num_epochs - 1 :
735+ print_unique ("begin epoch {} 0" .format (e ))
712736 if task_index == 0 :
713737 callbacks .on_epoch_begin (int (round (e )))
714738 mpi_model .set_lr (lr * lr_decay ** e )
@@ -733,18 +757,15 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
733757 mpi_model .set_batch_iterator_func ()
734758
735759 if 'monitor_test' in conf ['callbacks' ].keys () and conf ['callbacks' ]['monitor_test' ]:
736- conf_curr = deepcopy (conf )
737- T_min_warn_orig = conf ['data' ]['T_min_warn' ]
738- for T_min_curr in conf_curr ['callbacks' ]['monitor_times' ]:
739- conf_curr ['data' ]['T_min_warn' ] = T_min_curr
740- assert (conf ['data' ]['T_min_warn' ] == T_min_warn_orig )
741- if shot_list_test is not None :
742- _ ,_ ,_ ,roc_area_t ,_ = mpi_make_predictions_and_evaluate (conf_curr ,shot_list_test ,loader )
743- print_unique ('epoch {}, test_roc_{} = {}' .format (int (round (e )),T_min_curr ,roc_area_t ))
744- #epoch_logs['test_roc_{}'.format(T_min_curr)] = roc_area_t
745- _ ,_ ,_ ,roc_area_v ,_ = mpi_make_predictions_and_evaluate (conf_curr ,shot_list_validate ,loader )
746- print_unique ('epoch {}, val_roc_{} = {}' .format (int (round (e )),T_min_curr ,roc_area_v ))
747- #epoch_logs['val_roc_{}'.format(T_min_curr)] = roc_area_v
760+ times = conf ['callbacks' ]['monitor_times' ]
761+ roc_areas ,losses = mpi_make_predictions_and_evaluate_multiple_times (conf ,shot_list_validate ,loader ,times )
762+ for roc ,t in zip (roc_areas ,times ):
763+ print_unique ('epoch {}, val_roc_{} = {}' .format (int (round (e )),t ,roc ))
764+ if shot_list_test is not None :
765+ roc_areas ,losses = mpi_make_predictions_and_evaluate_multiple_times (conf ,shot_list_test ,loader ,times )
766+ for roc ,t in zip (roc_areas ,times ):
767+ print_unique ('epoch {}, test_roc_{} = {}' .format (int (round (e )),t ,roc ))
768+
748769 epoch_logs ['val_roc' ] = roc_area
749770 epoch_logs ['val_loss' ] = loss
750771 epoch_logs ['train_loss' ] = ave_loss
@@ -764,16 +785,22 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
764785 if hasattr (mpi_model .model ,'stop_training' ):
765786 stop_training = mpi_model .model .stop_training
766787 if best_so_far != epoch_logs [conf ['callbacks' ]['monitor' ]]: #only save model weights if quantity we are tracking is improving
767- print ("Not saving model weights" )
768- specific_builder .delete_model_weights (train_model ,int (round (e )))
788+ if 'monitor_test' in conf ['callbacks' ].keys () and conf ['callbacks' ]['monitor_test' ]:
789+
790+ print ("No improvement, saving model weights anyways" )
791+ else :
792+ print ("Not saving model weights" )
793+ specific_builder .delete_model_weights (train_model ,int (round (e )))
769794
770795 #tensorboard
771796 if backend != 'theano' :
772797 val_generator = partial (loader .training_batch_generator ,shot_list = shot_list_validate )()
773798 val_steps = 1
774799 tensorboard .on_epoch_end (val_generator ,val_steps ,int (round (e )),epoch_logs )
775800
801+ print_unique ("end epoch {} 0" .format (e ))
776802 stop_training = comm .bcast (stop_training ,root = 0 )
803+ print_unique ("end epoch {} 1" .format (e ))
777804 if stop_training :
778805 print ("Stopping training due to early stopping" )
779806 break
0 commit comments