Skip to content

Commit 9962727

Browse files
committed
Attempt deduplication of code in mpi_learn.py from mpi_runner.py
Also, call the mpi_runner.print_unique() function instead of print() in mpi_learn.py
1 parent b77942c commit 9962727

2 files changed

Lines changed: 45 additions & 41 deletions

File tree

examples/mpi_learn.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from plasma.models.mpi_runner import (
2-
mpi_train, mpi_make_predictions_and_evaluate
2+
mpi_train, mpi_make_predictions_and_evaluate,
3+
comm, task_index, print_unique
34
)
4-
from mpi4py import MPI
5+
# from mpi4py import MPI
56
from plasma.preprocessor.preprocess import guarantee_preprocessed
67
from plasma.models.loader import Loader
78
from plasma.conf import conf
8-
from pprint import pprint
9+
# from pprint import pprint
910
'''
1011
#########################################################
1112
This file trains a deep learning model to predict
@@ -37,9 +38,8 @@
3738

3839

3940
if conf['model']['shallow']:
40-
print(
41-
"Shallow learning using MPI is not supported yet. ",
42-
"Set conf['model']['shallow'] to False.")
41+
print("Shallow learning using MPI is not supported yet. ",
42+
"Set conf['model']['shallow'] to False.")
4343
exit(1)
4444
if conf['data']['normalizer'] == 'minmax':
4545
from plasma.preprocessor.normalize import MinMaxNormalizer as Normalizer
@@ -57,23 +57,29 @@
5757
print('unkown normalizer. exiting')
5858
exit(1)
5959

60-
comm = MPI.COMM_WORLD
61-
task_index = comm.Get_rank()
62-
num_workers = comm.Get_size()
63-
NUM_GPUS = conf['num_gpus']
64-
MY_GPU = task_index % NUM_GPUS
60+
# TODO(KGF): this part of the code is duplicated in mpi_runner.py
61+
# comm = MPI.COMM_WORLD
62+
# task_index = comm.Get_rank()
63+
# num_workers = comm.Get_size()
6564

65+
# NUM_GPUS = conf['num_gpus']
66+
# MY_GPU = task_index % NUM_GPUS
67+
# backend = conf['model']['backend']
68+
69+
# if task_index == 0:
70+
# pprint(conf)
71+
72+
# TODO(KGF): confirm that this second PRNG seed setting is not needed
73+
# (before normalization; done again before MPI training)
74+
# np.random.seed(task_index)
75+
# random.seed(task_index)
6676

67-
np.random.seed(task_index)
68-
random.seed(task_index)
69-
if task_index == 0:
70-
pprint(conf)
7177

7278
only_predict = len(sys.argv) > 1
7379
custom_path = None
7480
if only_predict:
7581
custom_path = sys.argv[1]
76-
print("predicting using path {}".format(custom_path))
82+
print_unique("predicting using path {}".format(custom_path))
7783

7884

7985
#####################################################
@@ -89,11 +95,11 @@
8995
shot_list_test) = guarantee_preprocessed(conf)
9096

9197

92-
print("normalization", end='')
98+
print_unique("normalization", end='')
9399
normalizer = Normalizer(conf)
94100
normalizer.train()
95101
loader = Loader(conf, normalizer)
96-
print("...done")
102+
print_unique("...done")
97103

98104
# ensure training has a separate random seed for every worker
99105
np.random.seed(task_index)
@@ -104,7 +110,7 @@
104110

105111
# load last model for testing
106112
loader.set_inference_mode(True)
107-
print('saving results')
113+
print_unique('saving results')
108114
y_prime = []
109115
y_gold = []
110116
disruptive = []
@@ -117,12 +123,11 @@
117123
loss_test) = mpi_make_predictions_and_evaluate(conf, shot_list_test,
118124
loader, custom_path)
119125

120-
if task_index == 0:
121-
print('=========Summary========')
122-
print('Train Loss: {:.3e}'.format(loss_train))
123-
print('Train ROC: {:.4f}'.format(roc_train))
124-
print('Test Loss: {:.3e}'.format(loss_test))
125-
print('Test ROC: {:.4f}'.format(roc_test))
126+
print_unique('=========Summary========')
127+
print_unique('Train Loss: {:.3e}'.format(loss_train))
128+
print_unique('Train ROC: {:.4f}'.format(roc_train))
129+
print_unique('Test Loss: {:.3e}'.format(loss_test))
130+
print_unique('Test ROC: {:.4f}'.format(roc_test))
126131

127132

128133
if task_index == 0:
@@ -156,5 +161,4 @@
156161
# requirement for "allow_pickle=True" to savez() calls
157162

158163
sys.stdout.flush()
159-
if task_index == 0:
160-
print('finished.')
164+
print_unique('finished.')

plasma/models/mpi_runner.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,25 +38,25 @@
3838
sys.setrecursionlimit(10000)
3939

4040

41-
def pprint_unique(obj):
42-
from pprint import pprint
43-
if task_index == 0:
44-
pprint(obj)
45-
46-
4741
# import keras sequentially because it otherwise reads from ~/.keras/keras.json
4842
# with too many threads:
4943
# from mpi_launch_tensorflow import get_mpi_task_index
5044
comm = MPI.COMM_WORLD
5145
task_index = comm.Get_rank()
5246
num_workers = comm.Get_size()
5347

54-
5548
NUM_GPUS = conf['num_gpus']
5649
MY_GPU = task_index % NUM_GPUS
57-
5850
backend = conf['model']['backend']
5951

52+
53+
def pprint_unique(obj):
54+
from pprint import pprint
55+
if task_index == 0:
56+
pprint(obj)
57+
58+
59+
# initialization code for mpi_runner.py module:
6060
if backend == 'tf' or backend == 'tensorflow':
6161
if NUM_GPUS > 1:
6262
os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(MY_GPU)
@@ -544,13 +544,13 @@ def train_epoch(self):
544544
# print_unique(self.model.get_weights()[0][0][:4])
545545
loss_averager.add_val(curr_loss)
546546
ave_loss = loss_averager.get_val()
547-
eta = self.estimate_remaining_time(t0 - t_start,
548-
self.num_so_far - self.epoch*num_total,
549-
num_total)
547+
eta = self.estimate_remaining_time(
548+
t0 - t_start, self.num_so_far - self.epoch*num_total,
549+
num_total)
550550
write_str = (
551551
'\r[{}] step: {} [ETA: {:.2f}s] [{:.2f}/{}], '.format(
552-
self.task_index, step, eta,
553-
1.0*self.num_so_far, num_total)
552+
self.task_index, step, eta, 1.0*self.num_so_far,
553+
num_total)
554554
+ 'loss: {:.5f} [{:.5f}] | '.format(ave_loss, curr_loss)
555555
+ 'walltime: {:.4f} | '.format(
556556
time.time() - self.start_time))
@@ -834,7 +834,7 @@ def mpi_train(conf, shot_list_train, shot_list_validate, loader,
834834
best_so_far = np.inf
835835
cmp_fn = min
836836

837-
while e < num_epochs-1:
837+
while e < (num_epochs - 1):
838838
print_unique("begin epoch {}".format(e))
839839
if task_index == 0:
840840
callbacks.on_epoch_begin(int(round(e)))

0 commit comments

Comments
 (0)