forked from zshi0616/DeepCell
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_pm.py
More file actions
37 lines (28 loc) · 1.16 KB
/
train_pm.py
File metadata and controls
37 lines (28 loc) · 1.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import deepcell
import torch
import os
from config import get_parse_args
# Force synchronous CUDA kernel launches so that a failing kernel is reported
# at its call site; this is a debugging aid and slows training down.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Directory passed to deepcell.NpzParser_Pair as the dataset location.
DATA_DIR = './data/lcm'
# Optional checkpoint path; unused while the model.load(...) call is disabled.
# checkpoint = './ckpt/pm_dg2.pth'
if __name__ == '__main__':
    # Script entry point: parse the command-line options, build the paired
    # train/validation datasets, create the model and trainer, then run
    # Stage 1 training.
    args = get_parse_args()
    # Only consumed by the (currently disabled) Stage 2 schedule below.
    num_epochs = args.num_epochs

    print('[INFO] Parse Dataset')
    dataset = deepcell.NpzParser_Pair(args, DATA_DIR)
    train_dataset, val_dataset = dataset.get_dataset()

    print('[INFO] Create Model and Trainer')
    model = deepcell.Model(aggr=args.pm_aggr)
    # model.load(checkpoint)
    trainer = deepcell.Trainer(
        args,
        model,
        distributed=args.distributed,
        device=args.device,
        training_id=args.exp_id,
    )
    if args.resume:
        # Presumably restores previously saved trainer state -- confirm in
        # deepcell.Trainer.resume.
        trainer.resume()

    # Stage 1: only the first loss term carries weight (1.0, 0.0, 0.0).
    trainer.set_training_args(loss_weight=[1.0, 0.0, 0.0], lr=1e-4, lr_step=80)
    print('[INFO] Stage 1 Training ...')
    # Stage 1 runs for a fixed 40 epochs, independent of args.num_epochs.
    trainer.train(40, train_dataset, val_dataset)

    # Stage 2 (disabled): re-weights all three loss terms and trains for
    # num_epochs epochs.
    # trainer.set_training_args(loss_weight=[3.0, 1.0, 0.5], lr=1e-4, lr_step=80)
    # print('[INFO] Stage 2 Training ...')
    # trainer.train(num_epochs, train_dataset, val_dataset)