2. Saving and Loading Models

Model persistence matters because a training run is only useful if you can resume, evaluate, or ship the result. Saving and loading is the boundary between an experiment in memory and a reproducible artifact you can reuse.

[1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

def train(dataloader, model, criterion, optimizer, scheduler, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()

        running_loss = 0.0
        running_corrects = 0
        n = 0

        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
            n += len(labels)

        # step the scheduler once per epoch, after the optimizer updates
        scheduler.step()

        epoch_loss = running_loss / float(n)
        epoch_acc = running_corrects.double() / float(n)

        print(f'epoch {epoch}/{num_epochs} : {epoch_loss:.5f}, {epoch_acc:.5f}')

np.random.seed(37)
torch.manual_seed(37)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
pretrained = True
num_classes = 3
num_epochs = 20

transform = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])
image_folder = datasets.ImageFolder('./shapes/train', transform=transform)
dataloader = DataLoader(image_folder, batch_size=4, shuffle=True, num_workers=4)

model = models.resnet18(pretrained=pretrained)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Rprop(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

train(dataloader, model, criterion, optimizer, scheduler, num_epochs=num_epochs)
epoch 0/20 : 1.18291, 0.66667
epoch 1/20 : 1.89373, 0.56667
epoch 2/20 : 0.41106, 0.80000
epoch 3/20 : 0.09141, 0.96667
epoch 4/20 : 0.09910, 0.96667
epoch 5/20 : 0.08258, 0.96667
epoch 6/20 : 0.06175, 0.96667
epoch 7/20 : 0.34240, 0.86667
epoch 8/20 : 0.03592, 1.00000
epoch 9/20 : 0.15507, 0.93333
epoch 10/20 : 0.40221, 0.96667
epoch 11/20 : 0.07072, 0.96667
epoch 12/20 : 0.44840, 0.93333
epoch 13/20 : 0.01021, 1.00000
epoch 14/20 : 0.00262, 1.00000
epoch 15/20 : 0.00727, 1.00000
epoch 16/20 : 0.00639, 1.00000
epoch 17/20 : 0.05421, 0.96667
epoch 18/20 : 0.03431, 1.00000
epoch 19/20 : 0.00771, 1.00000

2.1. Saving

2.1.1. Saving just the model weights

[2]:
torch.save(model.state_dict(), './output/resnet18-model.pt')
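
The state_dict holds only the learned parameters and buffers, so loading it later requires rebuilding the same architecture first (as done in section 2.2.1). torch.save can also pickle the entire model object; a minimal sketch of that alternative, with an illustrative output path:

[ ]:
# Alternative: pickle the whole model object. This couples the file to the
# exact class definitions and module paths present at save time, which is
# why the state_dict approach above is generally preferred.
torch.save(model, './output/resnet18-full.pt')

# Loading the pickled object does not require reconstructing the architecture.
model_full = torch.load('./output/resnet18-full.pt', map_location=device)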

2.1.2. Saving for later training

[3]:
torch.save({
    'model_state_dict': model.state_dict(),
    'criterion_state_dict': criterion.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict()
}, './output/resnet18-checkpoint.pt')
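
PyTorch does not prescribe the keys in this dictionary; it is just a pickled dict. A common extension is to store bookkeeping such as the epoch counter so a resumed run knows where it left off. A sketch, with the extra key and output path being illustrative:

[ ]:
# Illustrative: add bookkeeping fields alongside the state dicts so a
# resumed run can restore its epoch counter.
torch.save({
    'epoch': num_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict()
}, './output/resnet18-checkpoint-extended.pt')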

2.1.3. Saving to ONNX

[4]:
args = torch.randn(4, 3, 224, 224, device=device)
f = './output/resnet18.onnx'

torch.onnx.export(model, args, f, verbose=False)
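
The exported file can be sanity-checked with the onnx package, assuming it is installed; the checker raises an exception if the graph is structurally invalid:

[ ]:
import onnx

# Load the exported graph and run ONNX's structural validator;
# this raises an exception if the model is malformed.
onnx_model = onnx.load(f)
onnx.checker.check_model(onnx_model)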

2.2. Loading

2.2.1. Loading just the model weights

[5]:
model = models.resnet18(pretrained=pretrained)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

model.load_state_dict(torch.load('./output/resnet18-model.pt', map_location=device))
[5]:
<All keys matched successfully>
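
After loading, switch the model to evaluation mode before inference so batch-norm layers use their running statistics. A minimal sketch with a random tensor standing in for a real image batch:

[ ]:
model.eval()  # use running batch-norm statistics

with torch.no_grad():  # gradients are not needed for inference
    x = torch.randn(1, 3, 224, 224, device=device)  # dummy input
    logits = model(x)
    preds = logits.argmax(dim=1)  # predicted class indices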

2.2.2. Loading for training continuation

[6]:
checkpoint = torch.load('./output/resnet18-checkpoint.pt', map_location=device)

model = models.resnet18(pretrained=pretrained)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Rprop(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

model.load_state_dict(checkpoint['model_state_dict'])
criterion.load_state_dict(checkpoint['criterion_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

train(dataloader, model, criterion, optimizer, scheduler, num_epochs=num_epochs)
epoch 0/20 : 0.14301, 0.96667
epoch 1/20 : 0.00260, 1.00000
epoch 2/20 : 2.75971, 0.76667
epoch 3/20 : 0.19595, 0.96667
epoch 4/20 : 0.11255, 0.96667
epoch 5/20 : 0.24430, 0.96667
epoch 6/20 : 0.49671, 0.93333
epoch 7/20 : 0.49788, 0.90000
epoch 8/20 : 0.44765, 0.86667
epoch 9/20 : 0.03913, 0.96667
epoch 10/20 : 0.01076, 1.00000
epoch 11/20 : 10.49290, 0.83333
epoch 12/20 : 0.03003, 0.96667
epoch 13/20 : 0.22657, 0.96667
epoch 14/20 : 0.00002, 1.00000
epoch 15/20 : 0.00087, 1.00000
epoch 16/20 : 0.20941, 0.96667
epoch 17/20 : 0.00000, 1.00000
epoch 18/20 : 12.09500, 0.76667
epoch 19/20 : 0.02187, 1.00000
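
One portability note: if the checkpoint was written on a GPU machine and must be restored on a CPU-only host, map_location remaps every stored tensor at load time. A minimal sketch:

[ ]:
# Restore a GPU-trained checkpoint on a CPU-only machine.
cpu_checkpoint = torch.load('./output/resnet18-checkpoint.pt',
                            map_location=torch.device('cpu'))

cpu_model = models.resnet18(pretrained=pretrained)
cpu_model.fc = nn.Linear(cpu_model.fc.in_features, num_classes)
cpu_model.load_state_dict(cpu_checkpoint['model_state_dict'])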