Skip to content

Commit df41970

Browse files
committed
Suppress duplicated stdout output from normalize.py
1 parent 1680750 commit df41970

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

examples/mpi_learn.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,26 @@
8080
(shot_list_train, shot_list_validate,
8181
shot_list_test) = guarantee_preprocessed(conf, verbose=True)
8282

83+
# TODO(KGF): shouldn't normalizer.train() be called like guarantee_preprocessed()
84+
# above? I.e. if Normalizer.previously_saved_stats() does not load a computed
85+
# normalizer for all machines ("loaded normalization data from {d3d: 3449, jet: 2918} # noqa
86+
# shots ( {d3d: 810, jet: 74} disruptive )" ), then only the master MPI rank
87+
# calls normalizer.train() ???
8388
g.print_unique("begin normalization...")
8489
normalizer = Normalizer(conf)
8590
normalizer.train()
8691
loader = Loader(conf, normalizer)
8792
g.print_unique("...done")
8893

94+
# TODO(KGF): note, "python examples/guarantee_preprocessed.py" does NOT train
95+
# the normalizer. Try deleting the previously-computed file, e.g.
96+
# normalization/normalization_signal_group_250640798211266795112500621861190558178.npz # noqa
97+
# or set conf['data']['recompute_normalization'] = True to see example stdout
98+
99+
# TODO(KGF): both preprocess.py and normalize.py are littered with print()
100+
# calls that should probably be replaced with print_unique() when they are not
101+
# purely loading previously-computed quantities from file
102+
89103
#####################################################
90104
# TRAINING #
91105
#####################################################

plasma/models/mpi_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -782,12 +782,12 @@ def mpi_train(conf, shot_list_train, shot_list_validate, loader,
782782
print("Optimizer not implemented yet")
783783
exit(1)
784784

785-
g.write_unique('{} epochs left to go'.format(num_epochs - 1 - e))
785+
g.print_unique('{} epochs left to go'.format(num_epochs - 1 - e))
786786

787787
batch_generator = partial(loader.training_batch_generator_partial_reset,
788788
shot_list=shot_list_train)
789789

790-
g.write_unique("warmup steps = {}".format(warmup_steps))
790+
g.print_unique("warmup steps = {}".format(warmup_steps))
791791
mpi_model = MPIModel(train_model, optimizer, g.comm, batch_generator,
792792
batch_size, lr=lr, warmup_steps=warmup_steps,
793793
num_batches_minimum=num_batches_minimum, conf=conf)

plasma/preprocessor/normalize.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
'''
1010

1111
from __future__ import print_function
12+
import plasma.global_vars as g
1213
import os
1314
import time
1415
import sys
@@ -74,7 +75,7 @@ def load_stats(self):
7475
pass
7576

7677
def print_summary(self, action='loaded'):
77-
print('{} normalization data from {} shots ( {} disruptive )'.format(
78+
g.print_unique('{} normalization data from {} shots ( {} disruptive )'.format(
7879
action, self.num_processed, self.num_disruptive))
7980

8081
def set_inference_mode(self, val):
@@ -149,7 +150,7 @@ def train_on_files(self, shot_files, use_shots, all_machines):
149150
self.save_stats()
150151
else:
151152
self.load_stats()
152-
print(self)
153+
g.print_unique(self)
153154

154155
def cut_end_of_shot(self, shot):
155156
cut_shot_ends = self.conf['data']['cut_shot_ends']
@@ -222,7 +223,7 @@ def __str__(self):
222223
for machine in self.means:
223224
means = np.median(self.means[machine], axis=0)
224225
stds = np.median(self.stds[machine], axis=0)
225-
s += 'Machine: {}:\nMean Var Normalizer.\n'.format(machine)
226+
s += 'Machine = {}:\nMean Var Normalizer.\n'.format(machine)
226227
s += 'means: {}\nstds: {}'.format(means, stds)
227228
return s
228229

@@ -304,8 +305,8 @@ def load_stats(self):
304305
self.num_disruptive = dat['num_disruptive'][()]
305306
self.machines = dat['machines'][()]
306307
for machine in self.means:
307-
print('Machine {}:'.format(machine))
308-
self.print_summary()
308+
g.print_unique('Machine = {}:'.format(machine))
309+
self.print_summary()
309310

310311

311312
class VarNormalizer(MeanVarNormalizer):
@@ -452,7 +453,7 @@ def load_stats(self):
452453
self.num_disruptive = dat['num_disruptive'][()]
453454
self.machines = dat['machines'][()]
454455
for machine in self.means:
455-
print('Machine {}:'.format(machine))
456+
g.print_unique('Machine {}:'.format(machine))
456457
self.print_summary()
457458

458459

0 commit comments

Comments
 (0)