Skip to content

Commit 1d5cf00

Browse files
author
Julian Kates-Harbeck
committed
supporting logging of val and test roc values for various T_min_warn times
1 parent 76a566e commit 1d5cf00

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

examples/mpi_learn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
np.random.seed(task_index)
9191
random.seed(task_index)
9292
if not only_predict:
93-
mpi_train(conf,shot_list_train,shot_list_validate,loader)
93+
mpi_train(conf,shot_list_train,shot_list_validate,loader,shot_list_test=shot_list_test)
9494

9595
#load last model for testing
9696
loader.set_inference_mode(True)

plasma/models/mpi_runner.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import random
2424

2525
from functools import partial
26+
from copy import deepcopy
2627
import socket
2728
sys.setrecursionlimit(10000)
2829
import getpass
@@ -642,7 +643,7 @@ def mpi_make_predictions_and_evaluate(conf,shot_list,loader,custom_path=None):
642643
return y_prime,y_gold,disruptive,roc_area,loss
643644

644645

645-
def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=None):
646+
def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=None,shot_list_test=None):
646647

647648
loader.set_inference_mode(False)
648649
conf['num_workers'] = comm.Get_size()
@@ -730,6 +731,17 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
730731
mpi_model.batch_iterator_func.__exit__()
731732
mpi_model.num_so_far_accum = mpi_model.num_so_far_indiv
732733
mpi_model.set_batch_iterator_func()
734+
if 'monitor_test' in conf['callbacks'].keys() and conf['callbacks']['monitor_test']:
735+
conf_curr = deepcopy(conf)
736+
T_min_warn_orig = conf['data']['T_min_warn']
737+
for T_min_curr in conf_curr['callbacks']['monitor_times']:
738+
conf_curr['data']['T_min_warn'] = T_min_curr
739+
assert(conf['data']['T_min_warn'] == T_min_warn_orig)
740+
if shot_list_test is not None:
741+
_,_,_,roc_area_t,_ = mpi_make_predictions_and_evaluate(conf_curr,shot_list_test,loader)
742+
epoch_logs['test_roc_{}'.format(T_min_curr)] = roc_area_t
743+
_,_,_,roc_area_v,_ = mpi_make_predictions_and_evaluate(conf_curr,shot_list_validate,loader)
744+
epoch_logs['val_roc_{}'.format(T_min_curr)] = roc_area_v
733745
epoch_logs['val_roc'] = roc_area
734746
epoch_logs['val_loss'] = loss
735747
epoch_logs['train_loss'] = ave_loss

0 commit comments

Comments
 (0)