Skip to content

Commit a74e559

Browse files
author
Saurav Agarwal
committed
add progress bar to trainer
1 parent 0902c2e commit a74e559

File tree

2 files changed

+82
-53
lines changed

2 files changed

+82
-53
lines changed

python/coverage_control/nn/trainers/trainer.py

Lines changed: 77 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,20 @@
2323
Train a model using pytorch
2424
"""
2525

26-
import time
2726
from copy import deepcopy
2827

2928
import torch
3029

30+
from rich.progress import (
31+
Progress,
32+
BarColumn,
33+
TextColumn,
34+
TaskProgressColumn,
35+
TimeRemainingColumn,
36+
TimeElapsedColumn,
37+
MofNCompleteColumn,
38+
)
39+
3140
__all__ = ["TrainModel"]
3241

3342

@@ -70,7 +79,6 @@ def __init__(
7079
self.num_epochs = num_epochs
7180
self.device = device
7281
self.model_dir = model_dir
73-
self.start_time = time.time()
7482

7583
def load_saved_model_dict(self, model_file: str) -> None:
7684
"""
@@ -111,59 +119,82 @@ def train(self) -> None:
111119
# Initialize the loss history
112120
train_loss_history = []
113121
val_loss_history = []
114-
start_time = time.time()
115122

116123
best_model_state_dict = None
117124
best_train_model_state_dict = None
118125

119-
# Train the model
120-
121-
for epoch in range(self.num_epochs):
122-
# Training
123-
train_loss = self.train_epoch()
124-
train_loss_history.append(train_loss)
126+
columns = [
127+
BarColumn(),
128+
TaskProgressColumn(),
129+
TextColumn("[progress.description]{task.description}"),
130+
MofNCompleteColumn(),
131+
TextColumn("[bold]Loss ", justify="right"),
132+
TextColumn("[bold blue]T:[/] {task.fields[train_loss]:>.2e}"),
133+
TextColumn("[bold blue]V:[/] {task.fields[val_loss]:>.2e}"),
134+
TextColumn("[bold blue]B:[/] {task.fields[best_val_loss]:>.2e}"),
135+
TextColumn("[bold blue]@[/] {task.fields[best_epoch]:<3.0f}"),
136+
TimeRemainingColumn(),
137+
TimeElapsedColumn(),
138+
]
139+
140+
val_loss = float("Inf")
141+
train_loss = float("Inf")
142+
best_val_loss_epoch = -1
143+
144+
with Progress(*columns) as progress:
145+
epoch_task = progress.add_task(
146+
"[bold blue]Training",
147+
total=self.num_epochs,
148+
auto_refresh=False,
149+
train_loss=train_loss,
150+
val_loss=val_loss,
151+
best_val_loss=best_val_loss,
152+
best_epoch=best_val_loss_epoch,
153+
)
154+
155+
best_train_model_state_dict = deepcopy(self.model.state_dict())
156+
for epoch in range(self.num_epochs):
157+
# Training
158+
train_loss = self.train_epoch()
159+
train_loss_history.append(train_loss)
160+
161+
if train_loss < best_train_loss:
162+
best_train_loss = train_loss
163+
best_train_model_state_dict = deepcopy(self.model.state_dict())
164+
best_train_model_data = {"epoch": epoch, "optimizer_state_dict": deepcopy(self.optimizer.state_dict()), "loss": train_loss}
165+
166+
if self.val_loader is not None:
167+
val_loss = self.validate_epoch(self.val_loader)
168+
val_loss_history.append(val_loss)
169+
170+
if val_loss < best_val_loss:
171+
best_val_loss = val_loss
172+
best_model_state_dict = deepcopy(self.model.state_dict())
173+
best_model_data = {"epoch": epoch, "optimizer_state_dict": deepcopy(self.optimizer.state_dict()), "loss": val_loss}
174+
best_val_loss_epoch = epoch
175+
176+
if (epoch + 1) % 5 == 0:
177+
model_state_dict = self.model.state_dict()
178+
torch.save(model_state_dict, self.model_dir + "/model_epoch" + str(epoch) + ".pt")
179+
model_data = {"epoch": epoch, "optimizer_state_dict": deepcopy(self.optimizer.state_dict()), "loss": train_loss}
180+
torch.save(model_data, self.model_dir + "/model_data_epoch" + str(epoch) + ".pt")
181+
182+
progress.update(
183+
epoch_task,
184+
advance=1,
185+
train_loss=train_loss,
186+
val_loss=val_loss,
187+
best_val_loss=best_val_loss,
188+
best_epoch=best_val_loss_epoch,
189+
)
190+
progress.refresh()
191+
192+
torch.save(val_loss_history, self.model_dir + "/val_loss.pt")
125193
torch.save(train_loss_history, self.model_dir + "/train_loss.pt")
126-
# Print the loss
127-
print(f"Epoch: {epoch + 1}/{self.num_epochs} ",
128-
f"Training Loss: {train_loss:.3e} ")
129-
130-
# Validation
131-
132-
if self.val_loader is not None:
133-
val_loss = self.validate_epoch(self.val_loader)
134-
val_loss_history.append(val_loss)
135-
torch.save(val_loss_history, self.model_dir + "/val_loss.pt")
136-
137-
# Save the best model
138-
139-
if val_loss < best_val_loss:
140-
best_val_loss = val_loss
141-
best_model_state_dict = deepcopy(self.model.state_dict())
142-
best_model_data = {"epoch": epoch, "optimizer_state_dict": deepcopy(self.optimizer.state_dict()), "loss": val_loss}
143-
# torch.save(self.model.state_dict(), self.model_dir + "/model.pt")
144-
# torch.save(self.optimizer.state_dict(), self.model_dir + "/optimizer.pt")
145-
print(f"Epoch: {epoch + 1}/{self.num_epochs} ",
146-
f"Validation Loss: {val_loss:.3e} ",
147-
f"Best Validation Loss: {best_val_loss:.3e}")
148-
149-
if train_loss < best_train_loss:
150-
best_train_loss = train_loss
151-
best_train_model_state_dict = deepcopy(self.model.state_dict())
152-
best_train_model_data = {"epoch": epoch, "optimizer_state_dict": deepcopy(self.optimizer.state_dict()), "loss": train_loss}
153-
154-
if (epoch + 1) % 5 == 0:
155-
model_state_dict = self.model.state_dict()
156-
torch.save(model_state_dict, self.model_dir + "/model_epoch" + str(epoch) + ".pt")
157-
model_data = {"epoch": epoch, "optimizer_state_dict": deepcopy(self.optimizer.state_dict()), "loss": train_loss}
158-
torch.save(model_data, self.model_dir + "/model_data_epoch" + str(epoch) + ".pt")
159-
160194
torch.save(best_model_state_dict, self.model_dir + "/model.pt")
161195
torch.save(best_model_data, self.model_dir + "/model_data.pt")
162196
torch.save(best_train_model_state_dict, self.model_dir + "/model_train.pt")
163197
torch.save(best_train_model_data, self.model_dir + "/model_train_data.pt")
164-
elapsed_time = time.time() - start_time
165-
# Print elapsed time in minutes
166-
print(f"Elapsed time: {elapsed_time / 60:.2f} minutes")
167198

168199
# Train the model in batches
169200
def train_epoch(self) -> float:
@@ -198,11 +229,6 @@ def train_epoch(self) -> float:
198229
# Calculate the loss
199230
loss = self.criterion(output, target)
200231

201-
# Print batch number and loss
202-
203-
if batch_idx % 10 == 0:
204-
print(f"Batch: {batch_idx}, Loss: {loss:.3e} ")
205-
206232
# Backward propagation
207233
loss.backward()
208234

utils/scripts/run.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ SCRIPT_DIR="${CoverageControl_ws}/src/CoverageControl/python"
66
# Set the parameters directory based on the environment variable
77
PARAMS_DIR="${CoverageControl_ws}/lpac/params/"
88

9+
# Set the environment size passed to the training script
10+
ENV_SIZE=1024
11+
912
# Define the parameter file names
1013
DATA_PARAMS_FILE="data_params.toml"
1114
# DATA_GEN_ALGORITHM="--algorithm CentralizedCVT"
@@ -58,10 +61,10 @@ fi
5861
# Edit and execute process_data.sh
5962

6063
# Running the data generation script
61-
run_command "python ${SCRIPT_DIR}/data_generation/data_generation.py ${PARAMS_DIR}/${DATA_PARAMS_FILE} ${DATA_GEN_ALGORITHM} --split True" "Data Generation"
64+
# run_command "python ${SCRIPT_DIR}/data_generation/data_generation.py ${PARAMS_DIR}/${DATA_PARAMS_FILE} ${DATA_GEN_ALGORITHM} --split True" "Data Generation"
6265

6366
# Running the training script
64-
run_command "python ${SCRIPT_DIR}/training/train_lpac.py ${PARAMS_DIR}/${LEARNING_PARAMS_FILE} 1024" "Model Training"
67+
run_command "python ${SCRIPT_DIR}/training/train_lpac.py ${PARAMS_DIR}/${LEARNING_PARAMS_FILE} ${ENV_SIZE}" "Model Training"
6568

6669
# Running the evaluation script
6770
run_command "python ${SCRIPT_DIR}/evaluators/eval.py ${PARAMS_DIR}/${EVAL_PARAMS_FILE}" "Model Evaluation"

0 commit comments

Comments
 (0)