Saving and Loading Models

In this tutorial, we will look at how to save and load models in PyTorch. This is an essential practice: it lets us preserve progress during training, reuse trained models, and share them with others.

Why Save and Load Models?

During training, the process may be interrupted, or the model may begin overfitting after a certain number of epochs. To guard against this, we save the model at regular intervals; these saved snapshots are known as checkpoints. Furthermore, once the model is trained, we can save it for future use without having to train it again.

Saving Models in PyTorch

In PyTorch, we primarily save the model's state_dict (short for state dictionary). The state_dict is a Python dictionary object that maps each layer's parameter names to its learnable parameter tensors (weights and biases). Let's look at how to do this.
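To see what a state_dict actually contains, we can print one out. This is a minimal sketch using a small, illustrative network; the layer names and sizes are arbitrary.

```python
import torch
import torch.nn as nn

# A tiny example network; the architecture here is purely illustrative.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Each key pairs a layer's parameter name with its tensor of weights or biases.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
```

Note that only layers with learnable parameters appear; the ReLU at index 1 contributes no entries.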

import torch

# Assuming model is an instance of a PyTorch neural network (nn.Module)
torch.save(model.state_dict(), 'model.pth')

Here, 'model.pth' is the name of the saved file. You can choose any name you like, though the .pt and .pth extensions are conventional.
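Putting this together, here is a minimal end-to-end sketch of saving a model's parameters. The network is a hypothetical stand-in for whatever model you have trained.

```python
import torch
import torch.nn as nn

# Hypothetical two-layer network standing in for your trained model.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))

# Persist only the learned parameters, not the full Python object.
torch.save(model.state_dict(), 'model.pth')
```

Saving the state_dict rather than the whole model object keeps the file decoupled from your class definitions and directory layout, which makes it more portable.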

Loading Models in PyTorch

To load a model, we first need to initialize an instance of the same network structure. Then we load the state_dict into this instance.

# Assuming model is an instance of the same network structure
model.load_state_dict(torch.load('model.pth'))

Remember, the model used for loading must have the same architecture as the one used for saving. If you are loading the model for inference, call model.eval() afterwards so that dropout and batch-normalization layers switch to evaluation mode.
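The load step can be sketched as a full round trip. The helper function and network below are illustrative; the key point is that saving and loading both reference the same architecture.

```python
import torch
import torch.nn as nn

# Recreate the same architecture, then restore the saved weights into it.
def make_model():
    return nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))

model = make_model()
torch.save(model.state_dict(), 'model.pth')

restored = make_model()
restored.load_state_dict(torch.load('model.pth'))
restored.eval()  # inference mode (matters for dropout/batch-norm layers)

# The restored parameters produce exactly the same outputs.
x = torch.randn(3, 10)
print(torch.equal(model(x), restored(x)))
```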

Saving & Loading a General Checkpoint for Inference and/or Resuming Training

For more flexibility, we can save more than just the model's state_dict. It is often useful to also save the optimizer's state_dict, the current epoch, the latest loss, and other metrics, so that training can resume later from exactly the same point.

# Assuming model and optimizer are existing instances
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # Other items could include a learning-rate scheduler's state_dict, metrics, etc.
}
torch.save(checkpoint, 'checkpoint.pth')

To load:

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
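The checkpoint round trip can be sketched end to end. The model, optimizer, and epoch number below are all illustrative placeholders.

```python
import torch
import torch.nn as nn

# Illustrative model and optimizer; one training step gives them state to save.
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = ((model(torch.randn(8, 3)) - torch.randn(8, 1)) ** 2).mean()
loss.backward()
optimizer.step()

torch.save({
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}, 'checkpoint.pth')

# Rebuild fresh instances and restore everything needed to resume training.
model2 = nn.Linear(3, 1)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.1)
checkpoint = torch.load('checkpoint.pth')
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
print('resuming from epoch', checkpoint['epoch'])
```

Restoring the optimizer's state_dict matters for optimizers with internal state (momentum buffers, Adam's running averages); without it, resumed training starts with a cold optimizer.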

Saving & Loading Model across Devices

When loading a model on a different device than the one it was trained on, you need to remap the storage locations of the model's parameters using the map_location argument of torch.load.

For example, to load a model trained on GPU to CPU:

device = torch.device('cpu')
model = Model() # Initialize model
checkpoint = torch.load('checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

To load a model trained on CPU to GPU:

device = torch.device("cuda")
model = Model() # Initialize model
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device) # Move model parameters to GPU
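Both directions can be handled with one device-agnostic pattern: let map_location place the tensors, then move the model to the target device. This is a sketch; the network is illustrative, and it falls back to CPU when no GPU is available.

```python
import torch
import torch.nn as nn

# Pick whichever device is available; falls back to CPU without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(4, 2)
torch.save(model.state_dict(), 'model.pth')

# map_location remaps the saved tensors onto the target device as they load.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load('model.pth', map_location=device))
restored.to(device)

print(next(restored.parameters()).device)
```

Remember that input tensors must live on the same device as the model, so move your data with tensor.to(device) before calling the model.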

And that's it! You should now be comfortable saving and loading models in PyTorch. This practice is crucial in deep learning, whether you are running experiments, fine-tuning a model, or deploying it to production. Happy learning!