@@ -32,6 +32,10 @@ def train_model(model, train_dataloader, epochs, display_every=100):
3232 from dataset import make_dataloaders
3333 from utils import create_loss_meters , update_losses , log_results , visualize
3434
35+ EPOCHS = 100
36+ TRAINING_IMAGES = 15_000
37+ TOTAL_IMAGES = 16_000
38+
3539 # Parse command-line arguments
3640 parser = argparse .ArgumentParser (description = 'Start training a model' )
3741 parser .add_argument ('dataset_path' , type = str , help = 'Path to the dataset folder' )
@@ -51,10 +55,10 @@ def train_model(model, train_dataloader, epochs, display_every=100):
5155
5256 # Forming datasets
5357 np .random .seed (500 )
54- paths_subset = np .random .choice (paths , 20_000 , replace = False )
55- rand_indexes = np .random .permutation (20_000 )
56- train_indexes = rand_indexes [:18000 ]
57- validate_indexes = rand_indexes [18000 :]
58+ paths_subset = np .random .choice (paths , TOTAL_IMAGES , replace = False )
59+ rand_indexes = np .random .permutation (TOTAL_IMAGES )
60+ train_indexes = rand_indexes [:TRAINING_IMAGES ]
61+ validate_indexes = rand_indexes [TRAINING_IMAGES :]
5862 train_paths = paths_subset [train_indexes ]
5963 val_paths = paths_subset [validate_indexes ]
6064 print (f"Images withdrawn with their paths" )
@@ -72,7 +76,7 @@ def train_model(model, train_dataloader, epochs, display_every=100):
7276
7377 main_model = MainModel ()
7478 print (f"Training commenced" )
75- train_model (main_model , train_dl , 100 )
79+ train_model (main_model , train_dl , EPOCHS )
7680
7781 # saving weights for later
7882 torch .save (main_model , "model.pt" )
0 commit comments