Skip to content

Commit 7ced758

Browse files
committed
Use two different fns for controlling MPI printing to stdout
1 parent 6583004 commit 7ced758

File tree

2 files changed

+40
-18
lines changed

2 files changed

+40
-18
lines changed

examples/mpi_learn.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@
129129
print_unique('Test Loss: {:.3e}'.format(loss_test))
130130
print_unique('Test ROC: {:.4f}'.format(roc_test))
131131

132-
133132
if task_index == 0:
134133
disruptive_train = np.array(disruptive_train)
135134
disruptive_test = np.array(disruptive_test)

plasma/models/mpi_runner.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -554,16 +554,16 @@ def train_epoch(self):
554554
+ 'loss: {:.5f} [{:.5f}] | '.format(ave_loss, curr_loss)
555555
+ 'walltime: {:.4f} | '.format(
556556
time.time() - self.start_time))
557-
print_unique(write_str + write_str_0)
557+
write_unique(write_str + write_str_0)
558558
step += 1
559559
else:
560-
print_unique('\r[{}] warmup phase, num so far: {}'.format(
560+
write_unique('\r[{}] warmup phase, num so far: {}'.format(
561561
self.task_index, self.num_so_far))
562562

563563
effective_epochs = 1.0*self.num_so_far/num_total
564564
epoch_previous = self.epoch
565565
self.epoch = effective_epochs
566-
print_unique('\nEpoch {:.2f} finished ({:.2f} epochs passed)'.format(
566+
write_unique('\nEpoch {:.2f} finished ({:.2f} epochs passed)'.format(
567567
1.0 * self.epoch, self.epoch - epoch_previous)
568568
+ ' in {:.2f} seconds.\n'.format(t2 - t_start))
569569
return (step, ave_loss, curr_loss, self.num_so_far, effective_epochs)
@@ -576,7 +576,7 @@ def estimate_remaining_time(self, time_so_far, work_so_far, work_total):
576576
def get_effective_lr(self, num_replicas):
577577
effective_lr = self.lr * num_replicas
578578
if effective_lr > self.max_lr:
579-
print_unique('Warning: effective learning rate set to {}, '.format(
579+
write_unique('Warning: effective learning rate set to {}, '.format(
580580
effective_lr) + 'larger than maximum {}. Clipping.'.format(
581581
self.max_lr))
582582
effective_lr = self.max_lr
@@ -604,18 +604,41 @@ def calculate_speed(self, t0, t_after_deltas, t_after_update, num_replicas,
604604
effective_batch_size, self.batch_size, num_replicas,
605605
self.get_effective_lr(num_replicas), self.lr, num_replicas)
606606
if verbose:
607-
print_unique(print_str)
607+
write_unique(print_str)
608608
return print_str
609609

610610

611-
def print_unique(print_str):
611+
def print_unique(print_output, end='\n', flush=False):
612+
"""
613+
Only master MPI rank 0 calls print().
614+
615+
Trivial wrapper function to print()
616+
"""
617+
if task_index == 0:
618+
print(print_output, end=end, flush=flush)
619+
620+
621+
def write_unique(write_str):
622+
"""
623+
Only master MPI rank 0 writes to and flushes stdout.
624+
625+
A specialized case of print_unique(). Unlike print(), sys.stdout.write():
626+
- Must pass a string; will not cast argument
627+
- end='\n' kwarg of print() is not available
628+
(often the argument here is prepended with \r=carriage return in order to
629+
simulate a terminal output that overwrites itself)
630+
"""
631+
# TODO(KGF): \r carriage returns appear as ^M in Unix-encoded .out files
632+
# from non-interactive Slurm batch jobs. Convert these to true Unix
633+
# line feeds / newlines (^J, \n) when we can detect such a stdout
612634
if task_index == 0:
613-
sys.stdout.write(print_str)
635+
sys.stdout.write(write_str)
614636
sys.stdout.flush()
615637

616638

617-
def print_all(print_str):
618-
sys.stdout.write('[{}] '.format(task_index) + print_str)
639+
def write_all(write_str):
640+
'''All MPI ranks write to stdout, appending [rank]'''
641+
sys.stdout.write('[{}] '.format(task_index) + write_str)
619642
sys.stdout.flush()
620643

621644

@@ -798,12 +821,12 @@ def mpi_train(conf, shot_list_train, shot_list_validate, loader,
798821
print("Optimizer not implemented yet")
799822
exit(1)
800823

801-
print_unique('{} epochs left to go'.format(num_epochs - 1 - e))
824+
write_unique('{} epochs left to go'.format(num_epochs - 1 - e))
802825

803826
batch_generator = partial(loader.training_batch_generator_partial_reset,
804827
shot_list=shot_list_train)
805828

806-
print_unique("warmup steps = {}".format(warmup_steps))
829+
write_unique("warmup steps = {}".format(warmup_steps))
807830
mpi_model = MPIModel(train_model, optimizer, comm, batch_generator,
808831
batch_size, lr=lr, warmup_steps=warmup_steps,
809832
num_batches_minimum=num_batches_minimum, conf=conf)
@@ -835,11 +858,11 @@ def mpi_train(conf, shot_list_train, shot_list_validate, loader,
835858
cmp_fn = min
836859

837860
while e < (num_epochs - 1):
838-
print_unique("begin epoch {}".format(e))
861+
write_unique("begin epoch {}".format(e))
839862
if task_index == 0:
840863
callbacks.on_epoch_begin(int(round(e)))
841864
mpi_model.set_lr(lr*lr_decay**e)
842-
print_unique('\nEpoch {}/{}'.format(e, num_epochs))
865+
write_unique('\nEpoch {}/{}'.format(e, num_epochs))
843866

844867
(step, ave_loss, curr_loss, num_so_far,
845868
effective_epochs) = mpi_model.train_epoch()
@@ -871,13 +894,13 @@ def mpi_train(conf, shot_list_train, shot_list_validate, loader,
871894
areas, _ = mpi_make_predictions_and_evaluate_multiple_times(
872895
conf, shot_list_validate, loader, times)
873896
for roc, t in zip(areas, times):
874-
print_unique('epoch {}, val_roc_{} = {}'.format(
897+
write_unique('epoch {}, val_roc_{} = {}'.format(
875898
int(round(e)), t, roc))
876899
if shot_list_test is not None:
877900
areas, _ = mpi_make_predictions_and_evaluate_multiple_times(
878901
conf, shot_list_test, loader, times)
879902
for roc, t in zip(areas, times):
880-
print_unique('epoch {}, test_roc_{} = {}'.format(
903+
write_unique('epoch {}, test_roc_{} = {}'.format(
881904
int(round(e)), t, roc))
882905

883906
epoch_logs['val_roc'] = roc_area
@@ -917,9 +940,9 @@ def mpi_train(conf, shot_list_train, shot_list_validate, loader,
917940
int(round(e)), epoch_logs)
918941

919942
stop_training = comm.bcast(stop_training, root=0)
920-
print_unique("end epoch {}".format(e))
943+
write_unique("end epoch {}".format(e))
921944
if stop_training:
922-
print_unique("Stopping training due to early stopping")
945+
write_unique("Stopping training due to early stopping")
923946
break
924947

925948
if task_index == 0:

0 commit comments

Comments (0)