3232 metavar = 'C' , type = str , help = 'Checkpoint path' )
3333parser .add_argument ('--gpu' , default = 0 ,
3434 metavar = 'G' , type = int , help = 'GPU device ID' )
35+ parser .add_argument ('--visualize' , action = 'store_true' ,
36+ help = 'Visualize train and validation loss.' )
3537# Data options
3638parser .add_argument ('--seq-len' , type = int , default = 100 ,
3739 help = 'Sequence length for tbptt' )
@@ -134,10 +136,11 @@ def train(model, criterion, optimizer, epoch, train_losses):
134136
135137 avg = total / len (train_loader )
136138 train_losses .append (avg )
137- vis .line (Y = np .asarray (train_losses ),
138- X = torch .arange (1 , 1 + len (train_losses )),
139- opts = dict (title = "Train" ),
140- win = 'Train loss ' + args .expName )
139+ if args .visualize :
140+ vis .line (Y = np .asarray (train_losses ),
141+ X = torch .arange (1 , 1 + len (train_losses )),
142+ opts = dict (title = "Train" ),
143+ win = 'Train loss ' + args .expName )
141144
142145 logging .info ('====> Train set loss: {:.4f}' .format (avg ))
143146
@@ -161,10 +164,11 @@ def evaluate(model, criterion, epoch, eval_losses):
161164
162165 avg = total / len (valid_loader )
163166 eval_losses .append (avg )
164- vis .line (Y = np .asarray (eval_losses ),
165- X = torch .arange (1 , 1 + len (eval_losses )),
166- opts = dict (title = "Eval" ),
167- win = 'Eval loss ' + args .expName )
167+ if args .visualize :
168+ vis .line (Y = np .asarray (eval_losses ),
169+ X = torch .arange (1 , 1 + len (eval_losses )),
170+ opts = dict (title = "Eval" ),
171+ win = 'Eval loss ' + args .expName )
168172
169173 logging .info ('====> Test set loss: {:.4f}' .format (avg ))
170174 return avg
0 commit comments