-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
57 lines (43 loc) · 2 KB
/
train.py
File metadata and controls
57 lines (43 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from lightning_fabric.utilities.seed import seed_everything, reset_seed
# from imp_cfm.utils.arg_utils import construct_experiment_subdir
from train_argparser import build_argparser, construct_experiment_subdir
def _append_callback(callbacks, callback):
    """Return a new list containing ``callbacks`` followed by ``callback``.

    ``callbacks`` may be ``None``, a single callback object, or a list/tuple of
    callbacks (the forms ``pl.Trainer(callbacks=...)`` accepts). A fresh list is
    always returned so the configured callbacks are never mutated in place.
    """
    if callbacks is None:
        return [callback]
    if isinstance(callbacks, (list, tuple)):
        return [*callbacks, callback]
    # A single callback object was configured rather than a list.
    return [callbacks, callback]


def run(hparams, parser=None):
    """Instantiate data, model and trainer from ``hparams`` and fit the model.

    Args:
        hparams: Parsed (not yet instantiated) configuration namespace, as
            returned by ``parser.parse_args()``. It is read for
            ``seed_everything``, ``trainer`` (``default_root_dir`` and
            ``as_dict()``), ``add_checkpoint_callback``, and, after class
            instantiation, ``data`` and ``model``.
        parser: The parser that produced ``hparams``; used to instantiate the
            configured classes. If omitted, the module-level ``parser`` set by
            the script entry point is used when present, otherwise a new one
            is built with ``build_argparser()``. This lets ``run`` be imported
            and called from other modules.
    """
    if parser is None:
        # Backward compatibility: the ``__main__`` entry point stores the
        # parser in a module-level global; fall back to building a fresh one
        # when that global does not exist (e.g. ``run`` was imported).
        parser = globals().get('parser')
        if parser is None:
            parser = build_argparser()

    # Set random seed.
    # NOTE: this must be done before any class initialisation,
    # hence also before the call to parser.instantiate_classes().
    seed_everything(hparams.seed_everything, workers=True)

    # Construct the experiment directory from the un-instantiated config.
    experiment_subdir = construct_experiment_subdir(hparams)
    root_dir = hparams.trainer.default_root_dir
    if root_dir is None:
        experiment_dir = f'./{experiment_subdir}'
    else:
        experiment_dir = f'{root_dir}/{experiment_subdir}'

    # Instantiate dynamic object classes (data module, model, callbacks, ...).
    hparams = parser.instantiate_classes(hparams)
    datamodule = hparams.data
    model = hparams.model

    # Build the trainer; the experiment directory overrides default_root_dir.
    trainer_args = {**hparams.trainer.as_dict(), 'default_root_dir': experiment_dir}
    if hparams.add_checkpoint_callback:
        checkpoint_callback = ModelCheckpoint(save_top_k=1,
                                              save_last=True,
                                              monitor="loss/val")
        # ``.get`` tolerates configs without a 'callbacks' entry.
        trainer_args['callbacks'] = _append_callback(trainer_args.get('callbacks'),
                                                     checkpoint_callback)
    trainer = pl.Trainer(**trainer_args)

    # Instantiation may consume a model-dependent amount of randomness, so
    # re-apply the seed set above so Lightning's setup always starts from the
    # same RNG state.
    reset_seed()

    # Begin fitting.
    trainer.fit(model=model, datamodule=datamodule)
if __name__ == '__main__':
    # Kept at module scope on purpose: `run` looks up this global `parser`
    # when instantiating the configured classes.
    parser = build_argparser()
    # Read the command-line / config-file arguments, then train.
    hparams = parser.parse_args()
    run(hparams)