-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
23 lines (18 loc) · 799 Bytes
/
train.py
File metadata and controls
23 lines (18 loc) · 799 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import argparse
from src.trainer import Trainer
def parse_args(argv=None):
    """Parse command-line options for a training run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. The default ``None`` makes argparse read
            ``sys.argv[1:]``, so existing ``parse_args()`` calls behave
            exactly as before. Passing an explicit list lets tests and
            notebooks build a configuration without touching ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``lr``, ``batch_size``,
        ``weight_decay``, ``num_epochs``, ``num_workers``, ``csv_path``,
        ``val_csv_path``, ``pretrained_weight`` and ``memo``.
        ``pretrained_weight`` and ``memo`` have no default and are ``None``
        unless given on the command line.

    Raises:
        SystemExit: raised by argparse for invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="Training entry point.",
        # Appends "(default: ...)" to every option that has help text.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--lr", type=float, default=0.0003,
                        help="Learning rate.")
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch size.")
    parser.add_argument("--weight_decay", type=float, default=0.0001,
                        help="Weight decay.")
    parser.add_argument("--num_epochs", type=int, default=100,
                        help="Number of training epochs.")
    parser.add_argument("--num_workers", type=int, default=3,
                        help="Number of data-loading workers.")
    parser.add_argument("--csv_path", type=str, default="train.csv",
                        help="Path to the training CSV file.")
    # NOTE(review): the validation set defaults to test.csv. Confirm the test
    # split is not also used for final evaluation, or results will be biased.
    parser.add_argument("--val_csv_path", type=str, default="test.csv",
                        help="Path to the validation CSV file.")
    parser.add_argument("--pretrained_weight", type=str,
                        help="Optional path to pretrained weights (passed to Trainer).")
    parser.add_argument("--memo", type=str,
                        help="Optional free-form note (passed to Trainer).")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse CLI options, hand them to the Trainer,
    # then run training.
    Trainer(parse_args()).train()