|
23 | 23 | import random |
24 | 24 |
|
25 | 25 | from functools import partial |
| 26 | +from copy import deepcopy |
26 | 27 | import socket |
27 | 28 | sys.setrecursionlimit(10000) |
28 | 29 | import getpass |
@@ -642,7 +643,7 @@ def mpi_make_predictions_and_evaluate(conf,shot_list,loader,custom_path=None): |
642 | 643 | return y_prime,y_gold,disruptive,roc_area,loss |
643 | 644 |
|
644 | 645 |
|
645 | | -def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=None): |
| 646 | +def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=None,shot_list_test=None): |
646 | 647 |
|
647 | 648 | loader.set_inference_mode(False) |
648 | 649 | conf['num_workers'] = comm.Get_size() |
@@ -730,6 +731,17 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non |
730 | 731 | mpi_model.batch_iterator_func.__exit__() |
731 | 732 | mpi_model.num_so_far_accum = mpi_model.num_so_far_indiv |
732 | 733 | mpi_model.set_batch_iterator_func() |
| 734 | + if 'monitor_test' in conf['callbacks'].keys() and conf['callbacks']['monitor_test']: |
| 735 | + conf_curr = deepcopy(conf) |
| 736 | + T_min_warn_orig = conf['data']['T_min_warn'] |
| 737 | + for T_min_curr in conf_curr['callbacks']['monitor_times']: |
| 738 | + conf_curr['data']['T_min_warn'] = T_min_curr |
| 739 | + assert(conf['data']['T_min_warn'] == T_min_warn_orig) |
| 740 | + if shot_list_test is not None: |
| 741 | + _,_,_,roc_area_t,_ = mpi_make_predictions_and_evaluate(conf_curr,shot_list_test,loader) |
| 742 | + epoch_logs['test_roc_{}'.format(T_min_curr)] = roc_area_t |
| 743 | + _,_,_,roc_area_v,_ = mpi_make_predictions_and_evaluate(conf_curr,shot_list_validate,loader) |
| 744 | + epoch_logs['val_roc_{}'.format(T_min_curr)] = roc_area_v |
733 | 745 | epoch_logs['val_roc'] = roc_area |
734 | 746 | epoch_logs['val_loss'] = loss |
735 | 747 | epoch_logs['train_loss'] = ave_loss |
|
0 commit comments