Skip to content

Commit 9322202

Browse files
author
Julian Kates-Harbeck
committed
Added functionality to compute multiple ROC values for different T_min values at once. Added an option to use ProcessGenerator or not.
1 parent bbfa51b commit 9322202

File tree

1 file changed

+45
-18
lines changed

1 file changed

+45
-18
lines changed

plasma/models/mpi_runner.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,10 @@ def get_val(self):
171171

172172

173173
class MPIModel():
174-
def __init__(self,model,optimizer,comm,batch_iterator,batch_size,num_replicas=None,warmup_steps=1000,lr=0.01,num_batches_minimum=100):
174+
def __init__(self,model,optimizer,comm,batch_iterator,batch_size,num_replicas=None,warmup_steps=1000,lr=0.01,num_batches_minimum=100,conf=None):
175175
random.seed(task_index)
176176
np.random.seed(task_index)
177+
self.conf = conf
177178
self.start_time = time.time()
178179
self.epoch = 0
179180
self.num_so_far = 0
@@ -200,10 +201,14 @@ def __init__(self,model,optimizer,comm,batch_iterator,batch_size,num_replicas=No
200201

201202

202203
def set_batch_iterator_func(self):
    """Create the batch iterator used during training.

    Wraps the raw batch generator in a ProcessGenerator (background
    process) only when the configuration explicitly enables it via
    conf['training']['use_process_generator']; otherwise the plain
    in-process generator is used.
    """
    # BUG FIX: the original condition read the module-level name `conf`
    # instead of `self.conf`, raising NameError (or silently consulting
    # the wrong config) whenever self.conf was not None.
    use_process_generator = (
        self.conf is not None
        and self.conf.get('training', {}).get('use_process_generator', False)
    )
    if use_process_generator:
        self.batch_iterator_func = ProcessGenerator(self.batch_iterator())
    else:
        self.batch_iterator_func = self.batch_iterator()
204208

205209
def close(self):
    """Release resources held by the batch iterator, if any.

    A ProcessGenerator-backed iterator exposes ``__exit__`` to shut down
    its worker process; a plain generator has no such hook and needs no
    cleanup, so the call is guarded.
    """
    iterator = self.batch_iterator_func
    if hasattr(iterator, '__exit__'):
        iterator.__exit__()
207212

208213
def set_lr(self,lr):
209214
self.lr = lr
@@ -642,6 +647,24 @@ def mpi_make_predictions_and_evaluate(conf,shot_list,loader,custom_path=None):
642647
loss = get_loss_from_list(y_prime,y_gold,conf['data']['target'])
643648
return y_prime,y_gold,disruptive,roc_area,loss
644649

650+
def mpi_make_predictions_and_evaluate_multiple_times(conf,shot_list,loader,times,custom_path=None):
    """Run predictions once and evaluate the ROC area for several warning times.

    Parameters
    ----------
    conf : dict
        Global configuration. conf['data']['T_min_warn'] is overridden on a
        private copy for each entry of `times`; the caller's conf is untouched.
    shot_list :
        Shots to predict on (forwarded to mpi_make_predictions).
    loader :
        Data loader (forwarded to mpi_make_predictions).
    times : iterable
        Candidate T_min_warn values to evaluate the ROC area at.
    custom_path : str, optional
        Alternate weights path forwarded to mpi_make_predictions.

    Returns
    -------
    (areas, losses) : (list, list)
        ROC area and loss for each entry of `times`, in order.
    """
    # Predictions do not depend on T_min_warn, so compute them exactly once.
    y_prime,y_gold,disruptive = mpi_make_predictions(conf,shot_list,loader,custom_path)
    # The loss is also independent of T_min_warn; hoist the loop-invariant
    # computation out of the loop (the original recomputed it per iteration).
    loss = get_loss_from_list(y_prime,y_gold,conf['data']['target'])
    # Deep-copy the config once instead of once per iteration; only the copy
    # is mutated below, so the caller's conf is never modified.
    conf_curr = deepcopy(conf)
    T_min_warn_orig = conf['data']['T_min_warn']
    areas = []
    losses = []
    for T_min_curr in times:
        conf_curr['data']['T_min_warn'] = T_min_curr
        # sanity check: the original configuration must remain untouched
        assert conf['data']['T_min_warn'] == T_min_warn_orig
        analyzer = PerformanceAnalyzer(conf=conf_curr)
        roc_area = analyzer.get_roc_area(y_prime,y_gold,disruptive)
        areas.append(roc_area)
        losses.append(loss)
    return areas,losses
667+
645668

646669
def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=None,shot_list_test=None):
647670

@@ -680,7 +703,7 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
680703
#{}batch_generator = partial(loader.training_batch_generator_process,shot_list=shot_list_train)
681704

682705
print("warmup {}".format(warmup_steps))
683-
mpi_model = MPIModel(train_model,optimizer,comm,batch_generator,batch_size,lr=lr,warmup_steps = warmup_steps,num_batches_minimum=num_batches_minimum)
706+
mpi_model = MPIModel(train_model,optimizer,comm,batch_generator,batch_size,lr=lr,warmup_steps = warmup_steps,num_batches_minimum=num_batches_minimum,conf=conf)
684707
mpi_model.compile(conf['model']['optimizer'],clipnorm,conf['data']['target'].loss)
685708

686709
tensorboard = None
@@ -709,6 +732,7 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
709732
cmp_fn = min
710733

711734
while e < num_epochs-1:
735+
print_unique("begin epoch {} 0".format(e))
712736
if task_index == 0:
713737
callbacks.on_epoch_begin(int(round(e)))
714738
mpi_model.set_lr(lr*lr_decay**e)
@@ -733,18 +757,15 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
733757
mpi_model.set_batch_iterator_func()
734758

735759
if 'monitor_test' in conf['callbacks'].keys() and conf['callbacks']['monitor_test']:
736-
conf_curr = deepcopy(conf)
737-
T_min_warn_orig = conf['data']['T_min_warn']
738-
for T_min_curr in conf_curr['callbacks']['monitor_times']:
739-
conf_curr['data']['T_min_warn'] = T_min_curr
740-
assert(conf['data']['T_min_warn'] == T_min_warn_orig)
741-
if shot_list_test is not None:
742-
_,_,_,roc_area_t,_ = mpi_make_predictions_and_evaluate(conf_curr,shot_list_test,loader)
743-
print_unique('epoch {}, test_roc_{} = {}'.format(int(round(e)),T_min_curr,roc_area_t))
744-
#epoch_logs['test_roc_{}'.format(T_min_curr)] = roc_area_t
745-
_,_,_,roc_area_v,_ = mpi_make_predictions_and_evaluate(conf_curr,shot_list_validate,loader)
746-
print_unique('epoch {}, val_roc_{} = {}'.format(int(round(e)),T_min_curr,roc_area_v))
747-
#epoch_logs['val_roc_{}'.format(T_min_curr)] = roc_area_v
760+
times = conf['callbacks']['monitor_times']
761+
roc_areas,losses = mpi_make_predictions_and_evaluate_multiple_times(conf,shot_list_validate,loader,times)
762+
for roc,t in zip(roc_areas,times):
763+
print_unique('epoch {}, val_roc_{} = {}'.format(int(round(e)),t,roc))
764+
if shot_list_test is not None:
765+
roc_areas,losses = mpi_make_predictions_and_evaluate_multiple_times(conf,shot_list_test,loader,times)
766+
for roc,t in zip(roc_areas,times):
767+
print_unique('epoch {}, test_roc_{} = {}'.format(int(round(e)),t,roc))
768+
748769
epoch_logs['val_roc'] = roc_area
749770
epoch_logs['val_loss'] = loss
750771
epoch_logs['train_loss'] = ave_loss
@@ -764,16 +785,22 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
764785
if hasattr(mpi_model.model,'stop_training'):
765786
stop_training = mpi_model.model.stop_training
766787
if best_so_far != epoch_logs[conf['callbacks']['monitor']]: #only save model weights if quantity we are tracking is improving
767-
print("Not saving model weights")
768-
specific_builder.delete_model_weights(train_model,int(round(e)))
788+
if 'monitor_test' in conf['callbacks'].keys() and conf['callbacks']['monitor_test']:
789+
790+
print("No improvement, saving model weights anyways")
791+
else:
792+
print("Not saving model weights")
793+
specific_builder.delete_model_weights(train_model,int(round(e)))
769794

770795
#tensorboard
771796
if backend != 'theano':
772797
val_generator = partial(loader.training_batch_generator,shot_list=shot_list_validate)()
773798
val_steps = 1
774799
tensorboard.on_epoch_end(val_generator,val_steps,int(round(e)),epoch_logs)
775800

801+
print_unique("end epoch {} 0".format(e))
776802
stop_training = comm.bcast(stop_training,root=0)
803+
print_unique("end epoch {} 1".format(e))
777804
if stop_training:
778805
print("Stopping training due to early stopping")
779806
break

0 commit comments

Comments
 (0)