Skip to content

Commit f4fb93c

Browse files
author
Saurav Agarwal
committed
update model saving
1 parent 793633e commit f4fb93c

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

python/coverage_control/nn/trainers/trainer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def train(self) -> None:
139139
if val_loss < best_val_loss:
140140
best_val_loss = val_loss
141141
best_model_state_dict = deepcopy(self.model.state_dict())
142+
best_model_data = {"epoch": epoch, "optimizer_state_dict": deepcopy(self.optimizer.state_dict()), "loss": val_loss}
142143
# torch.save(self.model.state_dict(), self.model_dir + "/model.pt")
143144
# torch.save(self.optimizer.state_dict(), self.model_dir + "/optimizer.pt")
144145
print(f"Epoch: {epoch + 1}/{self.num_epochs} ",
@@ -148,18 +149,18 @@ def train(self) -> None:
148149
if train_loss < best_train_loss:
149150
best_train_loss = train_loss
150151
best_train_model_state_dict = deepcopy(self.model.state_dict())
151-
# torch.save(self.model.state_dict(), self.model_dir + "/model_curr.pt")
152-
# torch.save(self.optimizer.state_dict(), self.model_dir + "/optimizer_curr.pt")
152+
best_train_model_data = {"epoch": epoch, "optimizer_state_dict": deepcopy(self.optimizer.state_dict()), "loss": train_loss}
153153

154-
if epoch % 5 == 0:
154+
if (epoch + 1) % 5 == 0:
155155
model_state_dict = self.model.state_dict()
156-
model_state_dict["epoch"] = epoch
157-
model_state_dict["optimizer_state_dict"] = self.optimizer.state_dict()
158-
model_state_dict["loss"] = train_loss
159156
torch.save(model_state_dict, self.model_dir + "/model_epoch" + str(epoch) + ".pt")
157+
model_data = {"epoch": epoch, "optimizer_state_dict": deepcopy(self.optimizer.state_dict()), "loss": train_loss}
158+
torch.save(model_data, self.model_dir + "/model_data_epoch" + str(epoch) + ".pt")
160159

161160
torch.save(best_model_state_dict, self.model_dir + "/model.pt")
161+
torch.save(best_model_data, self.model_dir + "/model_data.pt")
162162
torch.save(best_train_model_state_dict, self.model_dir + "/model_train.pt")
163+
torch.save(best_train_model_data, self.model_dir + "/model_train_data.pt")
163164
elapsed_time = time.time() - start_time
164165
# Print elapsed time in minutes
165166
print(f"Elapsed time: {elapsed_time / 60:.2f} minutes")

utils/scripts/run.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
SCRIPT_DIR="${CoverageControl_ws}/src/CoverageControl/python"
55

66
# Set the parameters directory based on the environment variable
7-
PARAMS_DIR="${CoverageControl_ws}/lpac/params/"
7+
PARAMS_DIR="${CoverageControl_ws}/lpac_512/params/"
88

99
# Define the parameter file names
1010
DATA_PARAMS_FILE="data_params.toml"
1111
# DATA_GEN_ALGORITHM="--algorithm CentralizedCVT"
1212
LEARNING_PARAMS_FILE="learning_params.toml"
13-
EVAL_PARAMS_FILE="eval.toml"
13+
EVAL_PARAMS_FILE="eval_multi.toml"
1414

1515
# Function to print messages in red
1616
print_error() {
@@ -58,10 +58,10 @@ fi
5858
# Edit and execute process_data.sh
5959

6060
# Running the data generation script
61-
run_command "python ${SCRIPT_DIR}/data_generation/data_generation.py ${PARAMS_DIR}/${DATA_PARAMS_FILE} ${DATA_GEN_ALGORITHM} --split True" "Data Generation"
61+
# run_command "python ${SCRIPT_DIR}/data_generation/data_generation.py ${PARAMS_DIR}/${DATA_PARAMS_FILE} ${DATA_GEN_ALGORITHM} --split True" "Data Generation"
6262

6363
# Running the training script
64-
run_command "python ${SCRIPT_DIR}/training/train_lpac.py ${PARAMS_DIR}/${LEARNING_PARAMS_FILE} 1024" "Model Training"
64+
# run_command "python ${SCRIPT_DIR}/training/train_lpac.py ${PARAMS_DIR}/${LEARNING_PARAMS_FILE} 512" "Model Training"
6565

6666
# Running the evaluation script
6767
run_command "python ${SCRIPT_DIR}/evaluators/eval.py ${PARAMS_DIR}/${EVAL_PARAMS_FILE}" "Model Evaluation"

0 commit comments

Comments
 (0)