Skip to content

Commit 55639ce

Browse files
committed
Prepare to call print_unique() or analog in normalize.py
1 parent 052b345 commit 55639ce

File tree

2 files changed

+11
-41
lines changed

2 files changed

+11
-41
lines changed

examples/mpi_learn.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
mpi_train, mpi_make_predictions_and_evaluate,
33
comm, task_index, print_unique
44
)
5-
# from mpi4py import MPI
65
from plasma.preprocessor.preprocess import guarantee_preprocessed
76
from plasma.models.loader import Loader
87
from plasma.conf import conf
9-
# from pprint import pprint
108
'''
119
#########################################################
1210
This file trains a deep learning model to predict
@@ -57,24 +55,11 @@
5755
print('unkown normalizer. exiting')
5856
exit(1)
5957

60-
# TODO(KGF): this part of the code is duplicated in mpi_runner.py
61-
# comm = MPI.COMM_WORLD
62-
# task_index = comm.Get_rank()
63-
# num_workers = comm.Get_size()
64-
65-
# NUM_GPUS = conf['num_gpus']
66-
# MY_GPU = task_index % NUM_GPUS
67-
# backend = conf['model']['backend']
68-
69-
# if task_index == 0:
70-
# pprint(conf)
71-
7258
# TODO(KGF): confirm that this second PRNG seed setting is not needed
7359
# (before normalization; done again before MPI training)
7460
# np.random.seed(task_index)
7561
# random.seed(task_index)
7662

77-
7863
only_predict = len(sys.argv) > 1
7964
custom_path = None
8065
if only_predict:
@@ -94,8 +79,7 @@
9479
(shot_list_train, shot_list_validate,
9580
shot_list_test) = guarantee_preprocessed(conf)
9681

97-
98-
print_unique("begin normalization...") # , end='')
82+
print_unique("begin normalization...")
9983
normalizer = Normalizer(conf)
10084
normalizer.train()
10185
loader = Loader(conf, normalizer)

plasma/preprocessor/normalize.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ def save_stats(self):
7373
def load_stats(self):
7474
pass
7575

76-
@abc.abstractmethod
7776
def print_summary(self, action='loaded'):
78-
pass
77+
print('{} normalization data from {} shots ( {} disruptive )'.format(
78+
action, self.num_processed, self.num_disruptive))
7979

8080
def set_inference_mode(self, val):
8181
self.inference_mode = val
@@ -132,8 +132,7 @@ def train_on_files(self, shot_files, use_shots, all_machines):
132132
start_time = time.time()
133133

134134
for (i, stats) in enumerate(pool.imap_unordered(
135-
self.train_on_single_shot,
136-
shot_list_picked)):
135+
self.train_on_single_shot, shot_list_picked)):
137136
# for (i,stats) in
138137
# enumerate(map(self.train_on_single_shot,shot_list_picked)):
139138
if stats.machine in machines_to_compute:
@@ -308,10 +307,6 @@ def load_stats(self):
308307
print('Machine {}:'.format(machine))
309308
self.print_summary()
310309

311-
def print_summary(self, action='loaded'):
312-
print('{} normalization data from {} shots ( {} disruptive )'.format(
313-
action, self.num_processed, self.num_disruptive))
314-
315310

316311
class VarNormalizer(MeanVarNormalizer):
317312
def apply(self, shot):
@@ -341,7 +336,6 @@ def __str__(self):
341336

342337

343338
class AveragingVarNormalizer(VarNormalizer):
344-
345339
def apply(self, shot):
346340
apply_positivity(shot)
347341
super(AveragingVarNormalizer, self).apply(shot)
@@ -444,17 +438,10 @@ def save_stats(self):
444438
# num_processed = dat['num_processed']
445439
# num_disruptive = dat['num_disruptive']
446440
self.ensure_save_directory()
447-
np.savez(
448-
self.path,
449-
minimums=self.minimums,
450-
maximums=self.maximums,
451-
num_processed=self.num_processed,
452-
num_disruptive=self.num_disruptive,
453-
machines=self.machines)
454-
print(
455-
'saved normalization data from {} shots ( {} disruptive )'.format(
456-
self.num_processed,
457-
self.num_disruptive))
441+
np.savez(self.path, minimums=self.minimums, maximums=self.maximums,
442+
num_processed=self.num_processed,
443+
num_disruptive=self.num_disruptive, machines=self.machines)
444+
self.print_summary(action='saved')
458445

459446
def load_stats(self):
460447
assert(self.previously_saved_stats()[0])
@@ -464,10 +451,9 @@ def load_stats(self):
464451
self.num_processed = dat['num_processed'][()]
465452
self.num_disruptive = dat['num_disruptive'][()]
466453
self.machines = dat['machines'][()]
467-
print(
468-
'loaded normalization data from {} shots ( {} disruptive )'.format(
469-
self.num_processed,
470-
self.num_disruptive))
454+
for machine in self.means:
455+
print('Machine {}:'.format(machine))
456+
self.print_summary()
471457

472458

473459
def get_individual_shot_file(prepath, shot_num, ext='.txt'):

0 commit comments

Comments
 (0)