Rachit Jain

Summary

The article discusses the process of saving and loading model checkpoints in PyTorch to resume training from an intermediate state, covering the importance, contents, and methods of checkpointing.

Abstract

The blog post addresses the need to save and load model checkpoints in PyTorch when training deep learning models. It emphasizes preserving model states so that progress is not lost to interruptions such as system crashes, and so that training can later be extended with additional epochs. The author explains the essential components of a checkpoint: the model parameters, the number of epochs, and the optimizer parameters, which matter especially for adaptive learning rate optimizers like Adam. The article provides code examples for saving and loading checkpoints, showing how to resume training seamlessly so that the model continues from the exact point where training was interrupted.

Opinions

  • The author believes that saving intermediate model states is crucial for preventing potential disasters, such as losing hours of training due to unforeseen circumstances.
  • The author considers saving the number of epochs useful for logging and for tracking the total training duration.
  • The author suggests that saving optimizer parameters is often overlooked but is vital for resuming training, particularly with optimizers that adjust learning rates.
  • The article implies that maintaining a separate record of the best-performing model during training is a good practice.
  • The author notes that .pt and .pth are the common, recommended file extensions for PyTorch checkpoints.
  • The author encourages community engagement and learning by inviting readers to correct any mistakes in the comments section.

Saving and Loading Your Model to Resume Training in PyTorch


I just finished training a deep learning model to create embeddings for song lyrics and ran into multiple problems while trying to resume training from a particular state. So I decided to write a blog post and share what I learned with the community. In this post, we will talk about how to save your model in the form of checkpoints and how to load them back to resume training.

Why though?

Well, the first question you might ask is: why would you even need to resume training a model?! Aren't models supposed to be saved only once we are done training them?! Imagine you have been training your model for hours and suddenly the machine crashes, or you lose the connection to the remote GPU you have been training on. Disaster, right? Or consider a case where you trained your model for a certain number of epochs, which already took a considerable amount of time, but you are not satisfied with the performance and wish you had trained it for more epochs.

Enter Checkpointing and resuming training!

Clearly, we need to save intermediate model states and have a mechanism to resume training from them. We call these intermediate model states checkpoints (as you might have already guessed). Before getting into the details of how to save them, let's first see what they are made of!

What are the contents of a Checkpoint?

Since a checkpoint is the intermediate state from which we will resume training, we need to make sure it contains all the information that is used during training.

  1. Model Parameters: The model parameters are what training optimizes, so we save them. When we resume, we start from parameters that have already been optimized up to a particular step, and training continues from there.
  2. Number of Epochs: It is good practice to save the number of epochs, both for logging and for keeping track of how many epochs we have run in total (and therefore which epoch to resume from).
  3. Optimizer Parameters: Yes, you read that right. You need to save the optimizer's state, especially when you are using Adam. Adam is an adaptive learning rate method: it keeps running averages of each parameter's gradients and squared gradients and uses them to compute an individual step size for every parameter. A freshly created optimizer starts those averages from zero, so you need the saved state if you want to continue training exactly where you left off. The same applies to other optimizers that keep internal state, such as SGD with momentum.
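To see what "optimizer parameters" means in practice, here is a small sketch that takes a single Adam step on a toy nn.Linear model and then inspects optimizer.state_dict(). The model, data and learning rate are arbitrary assumptions chosen only for illustration. The "state" entry holds, for every model parameter, Adam's running averages (exp_avg, exp_avg_sq) and a step counter. This is exactly the information that would be lost if only the model were saved.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy model and Adam optimizer (arbitrary choices, only used for inspection).
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One optimization step on random data so Adam creates its per-parameter state.
x, y = torch.randn(8, 4), torch.randn(8, 2)
loss = F.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

opt_state = optimizer.state_dict()
# "state": one entry per parameter (weight -> 0, bias -> 1) with Adam's running statistics.
for param_id, buffers in opt_state["state"].items():
    print(param_id, sorted(buffers.keys()))
# "param_groups": hyperparameters such as lr and betas.
print(opt_state["param_groups"][0]["lr"], opt_state["param_groups"][0]["betas"])
```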

Saving a Checkpoint

Now that we know what a checkpoint contains, let's save one. PyTorch makes this very easy: torch.save serializes a Python dictionary, including tensors and state dictionaries, to a file.

Note that .pt or .pth are common and recommended file extensions for saving files using PyTorch.

The saving function stores the checkpoint state in the specified checkpoint directory. In addition, if the model is the best one so far, the same checkpoint is also copied to a separate directory, so that we always keep track of the best model.
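The saving function described above could look roughly like the minimal sketch below. The name save_ckp, its arguments, the checkpoint keys (epoch, state_dict, optimizer) and the file names checkpoint.pt and best_model.pt are my own assumptions, not necessarily what the original code used. It relies only on torch.save and shutil.copyfile, and the demo at the bottom writes into a temporary directory.

```python
import os
import shutil
import tempfile

import torch
import torch.nn as nn


def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    """Save `state` to checkpoint_dir and copy it to best_model_dir if is_best.

    Hypothetical helper: the name, arguments and file names are illustrative.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    f_path = os.path.join(checkpoint_dir, "checkpoint.pt")
    torch.save(state, f_path)
    if is_best:
        os.makedirs(best_model_dir, exist_ok=True)
        best_fpath = os.path.join(best_model_dir, "best_model.pt")
        # Copy the file we just wrote instead of serializing the state twice.
        shutil.copyfile(f_path, best_fpath)
    return f_path


# Demo: save a checkpoint for a tiny model into temporary directories.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
checkpoint = {
    "epoch": 1,                           # epochs completed, i.e. where to resume
    "state_dict": model.state_dict(),     # model parameters
    "optimizer": optimizer.state_dict(),  # optimizer state (e.g. Adam's moments)
}
root = tempfile.mkdtemp()
saved = save_ckp(checkpoint, True,
                 os.path.join(root, "checkpoint"), os.path.join(root, "best_model"))
print("saved to", saved)
```

Copying the finished file (rather than calling torch.save a second time) guarantees that the "best" copy is byte-for-byte identical to the checkpoint that was just written.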

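To use this function, call it inside the training loop (for example, at the end of every epoch), passing the current epoch, the model and optimizer state dictionaries, and whether the current model is the best so far.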

This saves the checkpoint in the desired location, from where it can be read later using the loading function from the next section.
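Here is a hedged sketch of calling such a helper from a training loop. It assumes a toy full-batch regression problem (one optimizer step per "epoch"), tracks the best validation loss to decide is_best, and stores epoch + 1 so that a resumed run starts at the next epoch. The helper is redefined in compact form so the snippet runs on its own; all names, paths and data are hypothetical.

```python
import os
import shutil
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F


def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    """Hypothetical helper: save the latest checkpoint, copy it if it is the best so far."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    f_path = os.path.join(checkpoint_dir, "checkpoint.pt")
    torch.save(state, f_path)
    if is_best:
        os.makedirs(best_model_dir, exist_ok=True)
        shutil.copyfile(f_path, os.path.join(best_model_dir, "best_model.pt"))


torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
x_train, y_train = torch.randn(64, 4), torch.randn(64, 1)
x_val, y_val = torch.randn(16, 4), torch.randn(16, 1)

root = tempfile.mkdtemp()
checkpoint_dir = os.path.join(root, "checkpoint")
best_model_dir = os.path.join(root, "best_model")
best_val_loss = float("inf")

for epoch in range(5):
    # Training step (full batch, for brevity).
    model.train()
    loss = F.mse_loss(model(x_train), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation decides whether this is the best model so far.
    model.eval()
    with torch.no_grad():
        val_loss = F.mse_loss(model(x_val), y_val).item()
    is_best = val_loss < best_val_loss
    best_val_loss = min(val_loss, best_val_loss)

    checkpoint = {
        "epoch": epoch + 1,  # epochs completed so far, i.e. where to resume
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_ckp(checkpoint, is_best, checkpoint_dir, best_model_dir)
```

Because the latest checkpoint is overwritten every epoch, this scheme keeps only two files on disk: the most recent state and the best one.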

Loading a checkpoint

Just like saving, loading a checkpoint is made easy by functions built into PyTorch: torch.load reads the file back, and load_state_dict restores the saved values.

The loading function reads the checkpoint file and loads the previously saved model state and optimizer state into instances of the model and the optimizer. Loading a state essentially means setting the model (or optimizer) parameters to the values stored in the checkpoint.
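A minimal sketch of such a loading function, assuming the same hypothetical checkpoint layout (a dictionary with epoch, state_dict and optimizer keys). torch.load deserializes the dictionary, and load_state_dict on the model and on the optimizer copies the saved values into the existing instances. The name load_ckp and the demo data are illustrative assumptions.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_ckp(checkpoint_fpath, model, optimizer):
    """Restore model and optimizer state from a checkpoint file (hypothetical helper).

    Assumes the checkpoint dict has 'state_dict', 'optimizer' and 'epoch' keys.
    Returns the model, the optimizer and the saved epoch.
    """
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
    # load_state_dict then copies the values onto the parameters' own device.
    checkpoint = torch.load(checkpoint_fpath, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return model, optimizer, checkpoint["epoch"]


# Demo: write a checkpoint from one model, then load it into a fresh instance.
src_model = nn.Linear(4, 2)
src_opt = torch.optim.Adam(src_model.parameters(), lr=1e-3)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({"epoch": 2, "state_dict": src_model.state_dict(),
            "optimizer": src_opt.state_dict()}, path)

model = nn.Linear(4, 2)  # freshly initialized, so its weights differ at first
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, start_epoch = load_ckp(path, model, optimizer)
print("resume from epoch", start_epoch)
```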

To resume training, load the checkpoint before starting the training loop.

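In other words, you first initialize your model and optimizer (with the same architecture as before, so that the saved parameter names and shapes match) and then update their state dictionaries using the checkpoint-loading function.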
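Putting it together, resuming could look like the sketch below. To keep it runnable on its own, it first simulates an earlier run that trained for 3 epochs and saved a checkpoint. The resume step then rebuilds the model and optimizer, overwrites their states from the checkpoint, and continues from the saved epoch. The toy model, data, file name and epoch counts are all assumptions.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x, y = torch.randn(32, 4), torch.randn(32, 1)


def make_model_and_optimizer():
    # Same architecture (and optimizer type) as the run that produced the checkpoint.
    model = nn.Linear(4, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    return model, optimizer


def train_one_epoch(model, optimizer):
    loss = F.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Earlier run: train 3 epochs, then save a checkpoint (stands in for the interrupted job).
model, optimizer = make_model_and_optimizer()
for epoch in range(3):
    train_one_epoch(model, optimizer)
ckp_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({"epoch": 3, "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict()}, ckp_path)

# Resume: initialize first, then overwrite both states from the checkpoint.
model, optimizer = make_model_and_optimizer()
checkpoint = torch.load(ckp_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"]

model.train()  # make sure layers such as dropout/batch norm are in training mode
num_epochs = 6
for epoch in range(start_epoch, num_epochs):
    loss = train_one_epoch(model, optimizer)
    print(f"epoch {epoch}: loss {loss:.4f}")
```

Starting the loop at range(start_epoch, num_epochs) keeps the epoch numbering continuous, which also keeps logs and any epoch-based schedule consistent with the interrupted run.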

Now you can simply pass this model and optimizer to your training loop, and the model resumes training from where it left off. You can confirm this by looking at the loss values after each epoch: they continue from the values observed before training stopped instead of starting over.
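If you want a check that is stricter than eyeballing the loss values, one option (my own suggestion, not from the original post) is to compare against an uninterrupted run on the same data. In this sketch, resuming with the optimizer state restored should reproduce the reference weights, while resuming with a fresh Adam instance generally does not. The toy model, data and epoch counts are assumptions, and the checkpoint is written to an in-memory buffer (io.BytesIO) instead of a file.

```python
import copy
import io

import torch
import torch.nn as nn
import torch.nn.functional as F


def run(model, optimizer, x, y, epochs):
    # Full-batch training for a fixed number of epochs.
    for _ in range(epochs):
        loss = F.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


torch.manual_seed(0)
x, y = torch.randn(32, 4), torch.randn(32, 1)
init = nn.Linear(4, 1)

# Reference: 6 uninterrupted epochs.
ref = copy.deepcopy(init)
ref_opt = torch.optim.Adam(ref.parameters(), lr=1e-2)
run(ref, ref_opt, x, y, 6)

# Interrupted run: 3 epochs, then checkpoint to an in-memory buffer.
a = copy.deepcopy(init)
a_opt = torch.optim.Adam(a.parameters(), lr=1e-2)
run(a, a_opt, x, y, 3)
buf = io.BytesIO()
torch.save({"state_dict": a.state_dict(), "optimizer": a_opt.state_dict()}, buf)
buf.seek(0)
ckpt = torch.load(buf)

# Resume with both model and optimizer state restored.
b = nn.Linear(4, 1)
b_opt = torch.optim.Adam(b.parameters(), lr=1e-2)
b.load_state_dict(ckpt["state_dict"])
b_opt.load_state_dict(ckpt["optimizer"])
run(b, b_opt, x, y, 3)

# Resume with only the model state restored (fresh Adam, moments start at zero).
c = nn.Linear(4, 1)
c.load_state_dict(ckpt["state_dict"])
c_opt = torch.optim.Adam(c.parameters(), lr=1e-2)
run(c, c_opt, x, y, 3)

resumed_matches = torch.allclose(b.weight, ref.weight) and torch.allclose(b.bias, ref.bias)
fresh_opt_matches = torch.allclose(c.weight, ref.weight) and torch.allclose(c.bias, ref.bias)
print("with optimizer state:", resumed_matches)
print("without optimizer state:", fresh_opt_matches)
```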

That’s about it. Thanks for reading. I hope this helps. Please don’t hesitate to correct any mistakes in the comments section. I would really like to learn and improve.
