|
23 | 23 | Train a model using pytorch |
24 | 24 | """ |
25 | 25 |
|
26 | | -import time |
27 | 26 | from copy import deepcopy |
28 | 27 |
|
29 | 28 | import torch |
30 | 29 |
|
| 30 | +from rich.progress import ( |
| 31 | + Progress, |
| 32 | + BarColumn, |
| 33 | + TextColumn, |
| 34 | + TaskProgressColumn, |
| 35 | + TimeRemainingColumn, |
| 36 | + TimeElapsedColumn, |
| 37 | + MofNCompleteColumn, |
| 38 | +) |
| 39 | + |
31 | 40 | __all__ = ["TrainModel"] |
32 | 41 |
|
33 | 42 |
|
@@ -70,7 +79,6 @@ def __init__( |
70 | 79 | self.num_epochs = num_epochs |
71 | 80 | self.device = device |
72 | 81 | self.model_dir = model_dir |
73 | | - self.start_time = time.time() |
74 | 82 |
|
75 | 83 | def load_saved_model_dict(self, model_file: str) -> None: |
76 | 84 | """ |
def train(self) -> None:
    """Run the full training loop with a ``rich`` progress display.

    Trains for ``self.num_epochs`` epochs via :meth:`train_epoch`, and when
    ``self.val_loader`` is set, validates each epoch via
    :meth:`validate_epoch`.  Tracks the best-by-training-loss and
    best-by-validation-loss model states and writes them, the loss
    histories, and a checkpoint every 5 epochs under ``self.model_dir``.

    Files written:
      - ``train_loss.pt`` / ``val_loss.pt``: per-epoch loss histories.
      - ``model.pt`` / ``model_data.pt``: best validation-loss state
        (only when a validation loader is configured).
      - ``model_train.pt`` / ``model_train_data.pt``: best training-loss state.
      - ``model_epoch{N}.pt`` / ``model_data_epoch{N}.pt``: periodic checkpoints.
    """
    # Per-epoch loss histories, persisted after the loop.
    train_loss_history = []
    val_loss_history = []

    # Best-so-far tracking.  Everything is initialized up front so the
    # first epoch always improves on it and so the final torch.save calls
    # below can never hit an unbound name (e.g. when val_loader is None or
    # no epoch improves on the starting loss).
    best_train_loss = float("inf")
    best_val_loss = float("inf")
    best_val_loss_epoch = -1
    best_model_state_dict = None
    best_model_data = None
    # Seed the best-training state with the initial weights so the
    # model_train.pt artifact always exists.
    best_train_model_state_dict = deepcopy(self.model.state_dict())
    best_train_model_data = None

    # Placeholder losses shown in the progress bar before epoch 1 finishes.
    train_loss = float("inf")
    val_loss = float("inf")

    columns = [
        BarColumn(),
        TaskProgressColumn(),
        TextColumn("[progress.description]{task.description}"),
        MofNCompleteColumn(),
        TextColumn("[bold]Loss ", justify="right"),
        TextColumn("[bold blue]T:[/] {task.fields[train_loss]:>.2e}"),
        TextColumn("[bold blue]V:[/] {task.fields[val_loss]:>.2e}"),
        TextColumn("[bold blue]B:[/] {task.fields[best_val_loss]:>.2e}"),
        TextColumn("[bold blue]@[/] {task.fields[best_epoch]:<3.0f}"),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
    ]

    # auto_refresh is a Progress constructor option; passing it to
    # add_task() would silently become an unused task field.  We disable
    # it here and refresh manually once per epoch.
    with Progress(*columns, auto_refresh=False) as progress:
        epoch_task = progress.add_task(
            "[bold blue]Training",
            total=self.num_epochs,
            train_loss=train_loss,
            val_loss=val_loss,
            best_val_loss=best_val_loss,
            best_epoch=best_val_loss_epoch,
        )

        for epoch in range(self.num_epochs):
            # --- Training ---
            train_loss = self.train_epoch()
            train_loss_history.append(train_loss)

            if train_loss < best_train_loss:
                best_train_loss = train_loss
                best_train_model_state_dict = deepcopy(self.model.state_dict())
                best_train_model_data = {
                    "epoch": epoch,
                    "optimizer_state_dict": deepcopy(self.optimizer.state_dict()),
                    "loss": train_loss,
                }

            # --- Validation (optional) ---
            if self.val_loader is not None:
                val_loss = self.validate_epoch(self.val_loader)
                val_loss_history.append(val_loss)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_val_loss_epoch = epoch
                    best_model_state_dict = deepcopy(self.model.state_dict())
                    best_model_data = {
                        "epoch": epoch,
                        "optimizer_state_dict": deepcopy(self.optimizer.state_dict()),
                        "loss": val_loss,
                    }

            # Periodic checkpoint every 5 epochs.
            if (epoch + 1) % 5 == 0:
                torch.save(
                    self.model.state_dict(),
                    self.model_dir + "/model_epoch" + str(epoch) + ".pt",
                )
                torch.save(
                    {
                        "epoch": epoch,
                        "optimizer_state_dict": deepcopy(self.optimizer.state_dict()),
                        "loss": train_loss,
                    },
                    self.model_dir + "/model_data_epoch" + str(epoch) + ".pt",
                )

            progress.update(
                epoch_task,
                advance=1,
                train_loss=train_loss,
                val_loss=val_loss,
                best_val_loss=best_val_loss,
                best_epoch=best_val_loss_epoch,
            )
            progress.refresh()

    # Persist histories and best states.  The validation-best pair is only
    # written when validation actually ran and improved at least once;
    # previously these saves raised NameError when val_loader was None.
    torch.save(train_loss_history, self.model_dir + "/train_loss.pt")
    if self.val_loader is not None:
        torch.save(val_loss_history, self.model_dir + "/val_loss.pt")
    if best_model_state_dict is not None:
        torch.save(best_model_state_dict, self.model_dir + "/model.pt")
        torch.save(best_model_data, self.model_dir + "/model_data.pt")
    torch.save(best_train_model_state_dict, self.model_dir + "/model_train.pt")
    if best_train_model_data is not None:
        torch.save(best_train_model_data, self.model_dir + "/model_train_data.pt")
|
168 | 199 | # Train the model in batches |
169 | 200 | def train_epoch(self) -> float: |
@@ -198,11 +229,6 @@ def train_epoch(self) -> float: |
198 | 229 | # Calculate the loss |
199 | 230 | loss = self.criterion(output, target) |
200 | 231 |
|
201 | | - # Print batch number and loss |
202 | | - |
203 | | - if batch_idx % 10 == 0: |
204 | | - print(f"Batch: {batch_idx}, Loss: {loss:.3e} ") |
205 | | - |
206 | 232 | # Backward propagation |
207 | 233 | loss.backward() |
208 | 234 |
|
|
0 commit comments