Skip to content

Commit 773fa2a

Browse files
committed
Made visdom is optional
1 parent 0879b1d commit 773fa2a

1 file changed

Lines changed: 12 additions & 8 deletions

File tree

train.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
metavar='C', type=str, help='Checkpoint path')
3333
parser.add_argument('--gpu', default=0,
3434
metavar='G', type=int, help='GPU device ID')
35+
parser.add_argument('--visualize', action='store_true',
36+
help='Visualize train and validation loss.')
3537
# Data options
3638
parser.add_argument('--seq-len', type=int, default=100,
3739
help='Sequence length for tbptt')
@@ -134,10 +136,11 @@ def train(model, criterion, optimizer, epoch, train_losses):
134136

135137
avg = total / len(train_loader)
136138
train_losses.append(avg)
137-
vis.line(Y=np.asarray(train_losses),
138-
X=torch.arange(1, 1 + len(train_losses)),
139-
opts=dict(title="Train"),
140-
win='Train loss ' + args.expName)
139+
if args.visualize:
140+
vis.line(Y=np.asarray(train_losses),
141+
X=torch.arange(1, 1 + len(train_losses)),
142+
opts=dict(title="Train"),
143+
win='Train loss ' + args.expName)
141144

142145
logging.info('====> Train set loss: {:.4f}'.format(avg))
143146

@@ -161,10 +164,11 @@ def evaluate(model, criterion, epoch, eval_losses):
161164

162165
avg = total / len(valid_loader)
163166
eval_losses.append(avg)
164-
vis.line(Y=np.asarray(eval_losses),
165-
X=torch.arange(1, 1 + len(eval_losses)),
166-
opts=dict(title="Eval"),
167-
win='Eval loss ' + args.expName)
167+
if args.visualize:
168+
vis.line(Y=np.asarray(eval_losses),
169+
X=torch.arange(1, 1 + len(eval_losses)),
170+
opts=dict(title="Eval"),
171+
win='Eval loss ' + args.expName)
168172

169173
logging.info('====> Test set loss: {:.4f}'.format(avg))
170174
return avg

0 commit comments

Comments
 (0)