@@ -73,9 +73,9 @@ def save_stats(self):
7373 def load_stats (self ):
7474 pass
7575
76- @abc .abstractmethod
7776 def print_summary (self , action = 'loaded' ):
78- pass
77+ print ('{} normalization data from {} shots ( {} disruptive )' .format (
78+ action , self .num_processed , self .num_disruptive ))
7979
8080 def set_inference_mode (self , val ):
8181 self .inference_mode = val
@@ -132,8 +132,7 @@ def train_on_files(self, shot_files, use_shots, all_machines):
132132 start_time = time .time ()
133133
134134 for (i , stats ) in enumerate (pool .imap_unordered (
135- self .train_on_single_shot ,
136- shot_list_picked )):
135+ self .train_on_single_shot , shot_list_picked )):
137136 # for (i,stats) in
138137 # enumerate(map(self.train_on_single_shot,shot_list_picked)):
139138 if stats .machine in machines_to_compute :
@@ -308,10 +307,6 @@ def load_stats(self):
308307 print ('Machine {}:' .format (machine ))
309308 self .print_summary ()
310309
311- def print_summary (self , action = 'loaded' ):
312- print ('{} normalization data from {} shots ( {} disruptive )' .format (
313- action , self .num_processed , self .num_disruptive ))
314-
315310
316311class VarNormalizer (MeanVarNormalizer ):
317312 def apply (self , shot ):
@@ -341,7 +336,6 @@ def __str__(self):
341336
342337
343338class AveragingVarNormalizer (VarNormalizer ):
344-
345339 def apply (self , shot ):
346340 apply_positivity (shot )
347341 super (AveragingVarNormalizer , self ).apply (shot )
@@ -444,17 +438,10 @@ def save_stats(self):
444438 # num_processed = dat['num_processed']
445439 # num_disruptive = dat['num_disruptive']
446440 self .ensure_save_directory ()
447- np .savez (
448- self .path ,
449- minimums = self .minimums ,
450- maximums = self .maximums ,
451- num_processed = self .num_processed ,
452- num_disruptive = self .num_disruptive ,
453- machines = self .machines )
454- print (
455- 'saved normalization data from {} shots ( {} disruptive )' .format (
456- self .num_processed ,
457- self .num_disruptive ))
441+ np .savez (self .path , minimums = self .minimums , maximums = self .maximums ,
442+ num_processed = self .num_processed ,
443+ num_disruptive = self .num_disruptive , machines = self .machines )
444+ self .print_summary (action = 'saved' )
458445
459446 def load_stats (self ):
460447 assert (self .previously_saved_stats ()[0 ])
@@ -464,10 +451,9 @@ def load_stats(self):
464451 self .num_processed = dat ['num_processed' ][()]
465452 self .num_disruptive = dat ['num_disruptive' ][()]
466453 self .machines = dat ['machines' ][()]
467- print (
468- 'loaded normalization data from {} shots ( {} disruptive )' .format (
469- self .num_processed ,
470- self .num_disruptive ))
454+ for machine in self .means :
455+ print ('Machine {}:' .format (machine ))
456+ self .print_summary ()
471457
472458
473459def get_individual_shot_file (prepath , shot_num , ext = '.txt' ):
0 commit comments