Skip to content

Commit 6977665

Browse files
Introduced constants for easier change of epochs and number of images
1 parent ec5f0c8 commit 6977665

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

train.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ def train_model(model, train_dataloader, epochs, display_every=100):
3232
from dataset import make_dataloaders
3333
from utils import create_loss_meters, update_losses, log_results, visualize
3434

35+
EPOCHS = 100
36+
TRAINING_IMAGES = 15_000
37+
TOTAL_IMAGES = 16_000
38+
3539
# Parse command-line arguments
3640
parser = argparse.ArgumentParser(description='Start training a model')
3741
parser.add_argument('dataset_path', type=str, help='Path to the dataset folder')
@@ -51,10 +55,10 @@ def train_model(model, train_dataloader, epochs, display_every=100):
5155

5256
# Forming datasets
5357
np.random.seed(500)
54-
paths_subset = np.random.choice(paths, 20_000, replace=False)
55-
rand_indexes = np.random.permutation(20_000)
56-
train_indexes = rand_indexes[:18000]
57-
validate_indexes = rand_indexes[18000:]
58+
paths_subset = np.random.choice(paths, TOTAL_IMAGES, replace=False)
59+
rand_indexes = np.random.permutation(TOTAL_IMAGES)
60+
train_indexes = rand_indexes[:TRAINING_IMAGES]
61+
validate_indexes = rand_indexes[TRAINING_IMAGES:]
5862
train_paths = paths_subset[train_indexes]
5963
val_paths = paths_subset[validate_indexes]
6064
print(f"Images withdrawn with their paths")
@@ -72,7 +76,7 @@ def train_model(model, train_dataloader, epochs, display_every=100):
7276

7377
main_model = MainModel()
7478
print(f"Training commenced")
75-
train_model(main_model, train_dl, 100)
79+
train_model(main_model, train_dl, EPOCHS)
7680

7781
# saving weights for later
7882
torch.save(main_model, "model.pt")

0 commit comments

Comments
 (0)