@@ -554,16 +554,16 @@ def train_epoch(self):
554554 + 'loss: {:.5f} [{:.5f}] | ' .format (ave_loss , curr_loss )
555555 + 'walltime: {:.4f} | ' .format (
556556 time .time () - self .start_time ))
557- print_unique (write_str + write_str_0 )
557+ write_unique (write_str + write_str_0 )
558558 step += 1
559559 else :
560- print_unique ('\r [{}] warmup phase, num so far: {}' .format (
560+ write_unique ('\r [{}] warmup phase, num so far: {}' .format (
561561 self .task_index , self .num_so_far ))
562562
563563 effective_epochs = 1.0 * self .num_so_far / num_total
564564 epoch_previous = self .epoch
565565 self .epoch = effective_epochs
566- print_unique ('\n Epoch {:.2f} finished ({:.2f} epochs passed)' .format (
566+ write_unique ('\n Epoch {:.2f} finished ({:.2f} epochs passed)' .format (
567567 1.0 * self .epoch , self .epoch - epoch_previous )
568568 + ' in {:.2f} seconds.\n ' .format (t2 - t_start ))
569569 return (step , ave_loss , curr_loss , self .num_so_far , effective_epochs )
@@ -576,7 +576,7 @@ def estimate_remaining_time(self, time_so_far, work_so_far, work_total):
576576 def get_effective_lr (self , num_replicas ):
577577 effective_lr = self .lr * num_replicas
578578 if effective_lr > self .max_lr :
579- print_unique ('Warning: effective learning rate set to {}, ' .format (
579+ write_unique ('Warning: effective learning rate set to {}, ' .format (
580580 effective_lr ) + 'larger than maximum {}. Clipping.' .format (
581581 self .max_lr ))
582582 effective_lr = self .max_lr
@@ -604,18 +604,41 @@ def calculate_speed(self, t0, t_after_deltas, t_after_update, num_replicas,
604604 effective_batch_size , self .batch_size , num_replicas ,
605605 self .get_effective_lr (num_replicas ), self .lr , num_replicas )
606606 if verbose :
607- print_unique (print_str )
607+ write_unique (print_str )
608608 return print_str
609609
610610
611- def print_unique (print_str ):
611+ def print_unique (print_output , end = '\n ' , flush = False ):
612+ """
613+ Only master MPI rank 0 calls print().
614+
615+ Trivial wrapper function to print()
616+ """
617+ if task_index == 0 :
618+ print (print_output , end = end , flush = flush )
619+
620+
621+ def write_unique (write_str ):
622+ """
623+ Only master MPI rank 0 writes to and flushes stdout.
624+
625+ A specialized case of print_unique(). Unlike print(), sys.stdout.write():
626+ - Must pass a string; will not cast argument
627+ - end='\n ' kwarg of print() is not available
628+ (often the argument here is prepended with \r =carriage return in order to
629+ simulate a terminal output that overwrites itself)
630+ """
631+ # TODO(KGF): \r carriage returns appear as ^M in Unix-encoded .out files
632+ # from non-interactive Slurm batch jobs. Convert these to true Unix
633+ # line feeds / newlines (^J, \n) when we can detect such a stdout
612634 if task_index == 0 :
613- sys .stdout .write (print_str )
635+ sys .stdout .write (write_str )
614636 sys .stdout .flush ()
615637
616638
617- def print_all (print_str ):
618- sys .stdout .write ('[{}] ' .format (task_index ) + print_str )
639+ def write_all (write_str ):
640+ '''All MPI ranks write to stdout, appending [rank]'''
641+ sys .stdout .write ('[{}] ' .format (task_index ) + write_str )
619642 sys .stdout .flush ()
620643
621644
@@ -798,12 +821,12 @@ def mpi_train(conf, shot_list_train, shot_list_validate, loader,
798821 print ("Optimizer not implemented yet" )
799822 exit (1 )
800823
801- print_unique ('{} epochs left to go' .format (num_epochs - 1 - e ))
824+ write_unique ('{} epochs left to go' .format (num_epochs - 1 - e ))
802825
803826 batch_generator = partial (loader .training_batch_generator_partial_reset ,
804827 shot_list = shot_list_train )
805828
806- print_unique ("warmup steps = {}" .format (warmup_steps ))
829+ write_unique ("warmup steps = {}" .format (warmup_steps ))
807830 mpi_model = MPIModel (train_model , optimizer , comm , batch_generator ,
808831 batch_size , lr = lr , warmup_steps = warmup_steps ,
809832 num_batches_minimum = num_batches_minimum , conf = conf )
@@ -835,11 +858,11 @@ def mpi_train(conf, shot_list_train, shot_list_validate, loader,
835858 cmp_fn = min
836859
837860 while e < (num_epochs - 1 ):
838- print_unique ("begin epoch {}" .format (e ))
861+ write_unique ("begin epoch {}" .format (e ))
839862 if task_index == 0 :
840863 callbacks .on_epoch_begin (int (round (e )))
841864 mpi_model .set_lr (lr * lr_decay ** e )
842- print_unique ('\n Epoch {}/{}' .format (e , num_epochs ))
865+ write_unique ('\n Epoch {}/{}' .format (e , num_epochs ))
843866
844867 (step , ave_loss , curr_loss , num_so_far ,
845868 effective_epochs ) = mpi_model .train_epoch ()
@@ -871,13 +894,13 @@ def mpi_train(conf, shot_list_train, shot_list_validate, loader,
871894 areas , _ = mpi_make_predictions_and_evaluate_multiple_times (
872895 conf , shot_list_validate , loader , times )
873896 for roc , t in zip (areas , times ):
874- print_unique ('epoch {}, val_roc_{} = {}' .format (
897+ write_unique ('epoch {}, val_roc_{} = {}' .format (
875898 int (round (e )), t , roc ))
876899 if shot_list_test is not None :
877900 areas , _ = mpi_make_predictions_and_evaluate_multiple_times (
878901 conf , shot_list_test , loader , times )
879902 for roc , t in zip (areas , times ):
880- print_unique ('epoch {}, test_roc_{} = {}' .format (
903+ write_unique ('epoch {}, test_roc_{} = {}' .format (
881904 int (round (e )), t , roc ))
882905
883906 epoch_logs ['val_roc' ] = roc_area
@@ -917,9 +940,9 @@ def mpi_train(conf, shot_list_train, shot_list_validate, loader,
917940 int (round (e )), epoch_logs )
918941
919942 stop_training = comm .bcast (stop_training , root = 0 )
920- print_unique ("end epoch {}" .format (e ))
943+ write_unique ("end epoch {}" .format (e ))
921944 if stop_training :
922- print_unique ("Stopping training due to early stopping" )
945+ write_unique ("Stopping training due to early stopping" )
923946 break
924947
925948 if task_index == 0 :
0 commit comments