@@ -139,6 +139,7 @@ def train(self) -> None:
139139 if val_loss < best_val_loss :
140140 best_val_loss = val_loss
141141 best_model_state_dict = deepcopy (self .model .state_dict ())
142+ best_model_data = {"epoch" : epoch , "optimizer_state_dict" : deepcopy (self .optimizer .state_dict ()), "loss" : val_loss }
142143 # torch.save(self.model.state_dict(), self.model_dir + "/model.pt")
143144 # torch.save(self.optimizer.state_dict(), self.model_dir + "/optimizer.pt")
144145 print (f"Epoch: { epoch + 1 } /{ self .num_epochs } " ,
@@ -148,18 +149,18 @@ def train(self) -> None:
148149 if train_loss < best_train_loss :
149150 best_train_loss = train_loss
150151 best_train_model_state_dict = deepcopy (self .model .state_dict ())
151- # torch.save(self.model.state_dict(), self.model_dir + "/model_curr.pt")
152- # torch.save(self.optimizer.state_dict(), self.model_dir + "/optimizer_curr.pt")
152+ best_train_model_data = {"epoch" : epoch , "optimizer_state_dict" : deepcopy (self .optimizer .state_dict ()), "loss" : train_loss }
153153
154- if epoch % 5 == 0 :
154+ if ( epoch + 1 ) % 5 == 0 :
155155 model_state_dict = self .model .state_dict ()
156- model_state_dict ["epoch" ] = epoch
157- model_state_dict ["optimizer_state_dict" ] = self .optimizer .state_dict ()
158- model_state_dict ["loss" ] = train_loss
159156 torch .save (model_state_dict , self .model_dir + "/model_epoch" + str (epoch ) + ".pt" )
157+ model_data = {"epoch" : epoch , "optimizer_state_dict" : deepcopy (self .optimizer .state_dict ()), "loss" : train_loss }
158+ torch .save (model_data , self .model_dir + "/model_data_epoch" + str (epoch ) + ".pt" )
160159
161160 torch .save (best_model_state_dict , self .model_dir + "/model.pt" )
161+ torch .save (best_model_data , self .model_dir + "/model_data.pt" )
162162 torch .save (best_train_model_state_dict , self .model_dir + "/model_train.pt" )
163+ torch .save (best_train_model_data , self .model_dir + "/model_train_data.pt" )
163164 elapsed_time = time .time () - start_time
164165 # Print elapsed time in minutes
165166 print (f"Elapsed time: { elapsed_time / 60 :.2f} minutes" )
0 commit comments