|
1 | 1 | from plasma.models.mpi_runner import ( |
2 | | - mpi_train, mpi_make_predictions_and_evaluate |
| 2 | + mpi_train, mpi_make_predictions_and_evaluate, |
| 3 | + comm, task_index, print_unique |
3 | 4 | ) |
4 | | -from mpi4py import MPI |
| 5 | +# from mpi4py import MPI |
5 | 6 | from plasma.preprocessor.preprocess import guarantee_preprocessed |
6 | 7 | from plasma.models.loader import Loader |
7 | 8 | from plasma.conf import conf |
8 | | -from pprint import pprint |
| 9 | +# from pprint import pprint |
9 | 10 | ''' |
10 | 11 | ######################################################### |
11 | 12 | This file trains a deep learning model to predict |
|
37 | 38 |
|
38 | 39 |
|
39 | 40 | if conf['model']['shallow']: |
40 | | - print( |
41 | | - "Shallow learning using MPI is not supported yet. ", |
42 | | - "Set conf['model']['shallow'] to False.") |
| 41 | + print("Shallow learning using MPI is not supported yet. ", |
| 42 | + "Set conf['model']['shallow'] to False.") |
43 | 43 | exit(1) |
44 | 44 | if conf['data']['normalizer'] == 'minmax': |
45 | 45 | from plasma.preprocessor.normalize import MinMaxNormalizer as Normalizer |
|
57 | 57 | print('unknown normalizer. exiting')
58 | 58 | exit(1) |
59 | 59 |
|
60 | | -comm = MPI.COMM_WORLD |
61 | | -task_index = comm.Get_rank() |
62 | | -num_workers = comm.Get_size() |
63 | | -NUM_GPUS = conf['num_gpus'] |
64 | | -MY_GPU = task_index % NUM_GPUS |
| 60 | +# TODO(KGF): this part of the code is duplicated in mpi_runner.py |
| 61 | +# comm = MPI.COMM_WORLD |
| 62 | +# task_index = comm.Get_rank() |
| 63 | +# num_workers = comm.Get_size() |
65 | 64 |
|
| 65 | +# NUM_GPUS = conf['num_gpus'] |
| 66 | +# MY_GPU = task_index % NUM_GPUS |
| 67 | +# backend = conf['model']['backend'] |
| 68 | + |
| 69 | +# if task_index == 0: |
| 70 | +# pprint(conf) |
| 71 | + |
| 72 | +# TODO(KGF): confirm that this second PRNG seed setting is not needed |
| 73 | +# (before normalization; done again before MPI training) |
| 74 | +# np.random.seed(task_index) |
| 75 | +# random.seed(task_index) |
66 | 76 |
|
67 | | -np.random.seed(task_index) |
68 | | -random.seed(task_index) |
69 | | -if task_index == 0: |
70 | | - pprint(conf) |
71 | 77 |
|
72 | 78 | only_predict = len(sys.argv) > 1 |
73 | 79 | custom_path = None |
74 | 80 | if only_predict: |
75 | 81 | custom_path = sys.argv[1] |
76 | | - print("predicting using path {}".format(custom_path)) |
| 82 | + print_unique("predicting using path {}".format(custom_path)) |
77 | 83 |
|
78 | 84 |
|
79 | 85 | ##################################################### |
|
89 | 95 | shot_list_test) = guarantee_preprocessed(conf) |
90 | 96 |
|
91 | 97 |
|
92 | | -print("normalization", end='') |
| 98 | +print_unique("normalization", end='') |
93 | 99 | normalizer = Normalizer(conf) |
94 | 100 | normalizer.train() |
95 | 101 | loader = Loader(conf, normalizer) |
96 | | -print("...done") |
| 102 | +print_unique("...done") |
97 | 103 |
|
98 | 104 | # ensure training has a separate random seed for every worker |
99 | 105 | np.random.seed(task_index) |
|
104 | 110 |
|
105 | 111 | # load last model for testing |
106 | 112 | loader.set_inference_mode(True) |
107 | | -print('saving results') |
| 113 | +print_unique('saving results') |
108 | 114 | y_prime = [] |
109 | 115 | y_gold = [] |
110 | 116 | disruptive = [] |
|
117 | 123 | loss_test) = mpi_make_predictions_and_evaluate(conf, shot_list_test, |
118 | 124 | loader, custom_path) |
119 | 125 |
|
120 | | -if task_index == 0: |
121 | | - print('=========Summary========') |
122 | | - print('Train Loss: {:.3e}'.format(loss_train)) |
123 | | - print('Train ROC: {:.4f}'.format(roc_train)) |
124 | | - print('Test Loss: {:.3e}'.format(loss_test)) |
125 | | - print('Test ROC: {:.4f}'.format(roc_test)) |
| 126 | +print_unique('=========Summary========') |
| 127 | +print_unique('Train Loss: {:.3e}'.format(loss_train)) |
| 128 | +print_unique('Train ROC: {:.4f}'.format(roc_train)) |
| 129 | +print_unique('Test Loss: {:.3e}'.format(loss_test)) |
| 130 | +print_unique('Test ROC: {:.4f}'.format(roc_test)) |
126 | 131 |
|
127 | 132 |
|
128 | 133 | if task_index == 0: |
|
156 | 161 | # requirement for "allow_pickle=True" to savez() calls |
157 | 162 |
|
158 | 163 | sys.stdout.flush() |
159 | | -if task_index == 0: |
160 | | - print('finished.') |
| 164 | +print_unique('finished.') |
0 commit comments