Saving and Loading Your Model to Resume Training in PyTorch
I just finished training a deep learning model to create embeddings for song lyrics and ran into multiple problems while trying to resume training from a particular state. Hence, I decided to write a blog post and share what I learned with the community. In this post, we will talk about how to save your model in the form of checkpoints and how to load them back to resume training.
Why though?
Well, the first question you might ask is: why would you even need to resume training a model? Aren't models supposed to be saved only when we are done training them? Imagine you have been training your model for hours and suddenly the machine crashes, or you lose the connection to the remote GPU you were training on. Disaster, right? Or consider another case: you trained your model for a certain number of epochs, which already took a considerable amount of time, but you are not satisfied with the performance and wish you had trained it for more epochs.
Enter Checkpointing and resuming training!
It is clear that we need to save intermediate model states and have a mechanism to resume training from them. We call these intermediate model states checkpoints (as you might have already guessed). Before getting into the details of how to save them, let's first see what they are made of!
What are the contents of a Checkpoint?
Since a checkpoint is an intermediate model state from which we need to resume training, we need to make sure that we save all the information the model uses during training.
- Model Parameters: Since the model parameters are what training optimizes, we save them so that when we resume, they are already optimized up to a particular step and training can continue from there.
- Number of Epochs: It is good practice to save the number of epochs completed, both for logging purposes and for keeping track of how many epochs we have run in total.
- Optimizer Parameters: Yes. You read it right. You need to save the optimizer's state, especially when you are using Adam as your optimizer. Adam is an adaptive learning rate method, which means it keeps running estimates of the gradient statistics for each parameter and uses them to scale that parameter's step size. You need those estimates if you would like to continue your training from where you left off!
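To see why the optimizer matters, it can help to peek at what Adam actually stores. The snippet below is a minimal sketch, assuming a toy `nn.Linear` model and random input as stand-ins for a real model and dataset. After a single training step, the optimizer's `state_dict()` holds running estimates for every parameter.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins (assumptions for illustration): a tiny model and random input.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Run one training step so Adam populates its per-parameter state.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

opt_state = optimizer.state_dict()

# The state dict has two top-level entries: "state" and "param_groups".
# "state" is keyed by parameter index; index 0 is the Linear layer's weight.
first_param_state = opt_state["state"][0]
state_keys = sorted(first_param_state.keys())

print(state_keys)
print(first_param_state["exp_avg"].shape)
```

The `exp_avg` and `exp_avg_sq` entries are Adam's running averages of the gradient and the squared gradient. If you restore only the model weights, these start again from zero, and the first updates after resuming will not match what an uninterrupted run would have done.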
Saving a Checkpoint
Now that we know the contents, let's save the checkpoint. PyTorch makes it very easy to save checkpoints.
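Here is a minimal sketch of saving a checkpoint that bundles the three pieces listed above. The toy model, the epoch value, the dictionary keys, and the file name `checkpoint.pt` are all assumptions chosen for illustration. Only `state_dict()` and `torch.save` are PyTorch's own API.

```python
import os
import tempfile

import torch
from torch import nn

# Toy stand-ins (assumptions for illustration); swap in your own model and optimizer.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epoch = 5  # hypothetical: number of epochs completed so far

# Bundle everything needed to resume training into one dictionary.
# The key names are a convention chosen here, not something PyTorch requires.
checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}

# Write to a temporary directory so the sketch runs anywhere;
# in practice you would pick a persistent path.
checkpoint_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(checkpoint, checkpoint_path)

print(f"Saved checkpoint to {checkpoint_path}")
```

Saving the `state_dict()` of the model and the optimizer, rather than the objects themselves, keeps the checkpoint independent of how your classes are defined in code. A common pattern is to call this at the end of every epoch, or every few epochs, so that a crash costs you only a little of the work.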





