This article provides a practical example of how to save and load a model in PyTorch, allowing for continued training and inference.
Abstract
The goal of this article is to demonstrate how to save a model and load it to continue training after previous epochs and make predictions. The article assumes familiarity with deep learning and PyTorch. It uses Fashion_MNIST_data as the dataset and provides a complete flow from importing data to making predictions. The article explains the process of setting up, importing libraries and creating helper functions, importing the dataset and creating a data loader, defining and creating a model, training the network and saving the model, and finally, loading the model.
Opinions
The article is aimed at those familiar with deep learning and PyTorch.
The author emphasizes the importance of being able to save models, especially when using free cloud services such as Kaggle and Google Colab, which have idle timeouts and limit the time for training.
The author recommends saving both the latest checkpoint and the best checkpoint for flexibility.
The article uses a simple network from [1] and Adam optimizer with cross-entropy loss.
The author recommends calling model.eval() to set dropout and batch normalization layers to evaluation mode before running inference.
The author provides references for further reading.
How To Save and Load Model In PyTorch With A Complete Example
A practical example of how to save and load a model in PyTorch. We are going to look at how to continue training and how to load the model for inference.
The goal of this article is to show you how to save a model and load it to continue training after previous epochs and to make predictions. If you are reading this article, I assume you are familiar with the basics of deep learning and PyTorch.
Have you experienced a situation where you spend hours or days training your model and then it stops in the middle? Or you are not satisfied with your model performance and want to train the model again? There are multiple reasons why we might need a flexible way to save and load our model.
Most free cloud services, such as Kaggle and Google Colab, have idle timeouts that disconnect your notebook, and the notebook is also disconnected or interrupted once it reaches its time limit. Unless you train for a small number of epochs on a GPU, the process takes time. Being able to save the model gives you a huge advantage and can save the day. To be flexible, I am going to save both the latest checkpoint and the best checkpoint.
Fashion_MNIST_data will be used as our dataset, and we’ll write a complete flow from importing the data to making predictions. In this exercise, I am going to use a Kaggle notebook.
Step 1: Setting up
By default in Kaggle, the notebook you are working on is called __notebook__.ipyn
Create two directories to store checkpoint and best model:
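A minimal sketch of this step. The directory names checkpoint and best_model are assumptions chosen for illustration, and make_dirs is a hypothetical helper, not part of the original notebook:

```python
import os

# Hypothetical directory names, chosen for illustration only.
CHECKPOINT_DIR = "checkpoint"
BEST_MODEL_DIR = "best_model"


def make_dirs(base_dir="."):
    """Create the two output directories under base_dir and return their paths."""
    paths = [os.path.join(base_dir, name) for name in (CHECKPOINT_DIR, BEST_MODEL_DIR)]
    for path in paths:
        # exist_ok=True makes the call safe to re-run in a notebook cell
        os.makedirs(path, exist_ok=True)
    return paths


checkpoint_dir, best_model_dir = make_dirs()
```

Using exist_ok=True means re-running the cell after a notebook restart does not raise an error when the directories are already there.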
Step 2: Importing libraries and creating helper functions
Importing libraries
Saving function
save_ckp is created to save a checkpoint: always the latest one and, when it is also the best so far, the best one. This gives flexibility: you can use either the state of the latest checkpoint or the state of the best checkpoint.
In our case, we want to save a checkpoint that contains everything needed to continue training the model. Here is the information needed:
epoch: the number of complete passes over the training data so far; each pass uses every training sample once to update the weights.
valid_loss_min: the minimum validation loss so far. This is needed so that when we continue training, we can start from this value rather than from np.Inf.
state_dict: the model's learned parameters and buffers (for example, the weight matrices and biases of each layer), keyed by layer name. It does not store the architecture itself, so the model class must be defined before the state is loaded.
optimizer: the optimizer's state. You need to save it especially when you are using Adam as your optimizer. Adam is an adaptive learning rate method: it computes individual learning rates for different parameters, and we need that state if we want to continue training from where we left off [2].
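The four items above can be packed into a dictionary and written with torch.save. The following is a sketch of what a saving function along these lines might look like, not necessarily the original implementation: the signature, the file names, and the tiny nn.Linear stand-in model are assumptions, and storing epoch + 1 is a choice made here so that the loaded value can be used directly as the next start epoch.

```python
import os
import shutil
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """Save the latest checkpoint; copy it to best_model_path if it is the best so far."""
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, best_model_path)


# Tiny stand-in model and optimizer, for illustration only.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.001)

checkpoint = {
    "epoch": 3 + 1,                       # assumption: store the epoch to resume from
    "valid_loss_min": 0.000040,           # plain Python float
    "state_dict": model.state_dict(),     # model parameters and buffers
    "optimizer": optimizer.state_dict(),  # optimizer state (e.g. Adam moment estimates)
}

# Hypothetical file names inside a temporary directory.
out_dir = tempfile.mkdtemp()
checkpoint_path = os.path.join(out_dir, "current_checkpoint.pt")
best_model_path = os.path.join(out_dir, "best_model.pt")

save_ckp(checkpoint, True, checkpoint_path, best_model_path)
```

Copying the file instead of saving twice keeps the best checkpoint byte-for-byte identical to the latest one at the moment it was declared best.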
Loading function
load_ckp is created for loading the model. It takes:
the location of the saved checkpoint
the model instance that you want to load the state into
the optimizer
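A sketch of a loading function with these three inputs. The checkpoint keys and the return order are assumptions that mirror the items saved earlier, and the small nn.Linear model only stands in for the real network:

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def load_ckp(checkpoint_fpath, model, optimizer):
    """Restore model and optimizer state; return them with the next epoch and the min validation loss."""
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return model, optimizer, checkpoint["epoch"], checkpoint["valid_loss_min"]


# --- illustration with a stand-in model ---
trained = nn.Linear(4, 2)
trained_opt = optim.Adam(trained.parameters(), lr=0.001)
trained(torch.randn(8, 4)).sum().backward()
trained_opt.step()  # one step so that Adam has state worth saving

ckp_path = os.path.join(tempfile.mkdtemp(), "current_checkpoint.pt")
torch.save({
    "epoch": 4,
    "valid_loss_min": 0.000040,
    "state_dict": trained.state_dict(),
    "optimizer": trained_opt.state_dict(),
}, ckp_path)

# A fresh model and optimizer must be created first; the checkpoint only fills in their state.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.001)
model, optimizer, start_epoch, valid_loss_min = load_ckp(ckp_path, model, optimizer)
```

Note that the model and optimizer are constructed before loading, because the checkpoint holds state only and not the architecture.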
Step 3: Importing dataset Fashion_MNIST_data and creating data loader
The train function gives us the ability to set the number of epochs, the learning rate, and other parameters.
Define loss function and optimizer
Below, we use an Adam optimizer and cross-entropy loss, since the network outputs class scores. We calculate the loss and perform back-propagation.
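As a rough illustration of this setup, the sketch below defines a loss function and optimizer and runs a single training step. The network here is a hypothetical small classifier, not the one from [1], and the learning rate of 0.001 and the random batch are assumptions:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical small classifier; the article's network from [1] is not reproduced here.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 10),  # one raw score (logit) per class
)

criterion = nn.CrossEntropyLoss()  # expects raw class scores and integer class labels
optimizer = optim.Adam(model.parameters(), lr=0.001)  # assumed learning rate

# One training step on a random stand-in batch of 28x28 grayscale images.
images = torch.randn(16, 1, 28, 28)
labels = torch.randint(0, 10, (16,))

optimizer.zero_grad()             # clear gradients from any previous step
output = model(images)            # forward pass: class scores
loss = criterion(output, labels)  # calculate the loss
loss.backward()                   # back-propagation
optimizer.step()                  # update the weights
```

nn.CrossEntropyLoss applies log-softmax internally, so the final layer returns raw scores without a softmax.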
Define train method
Train the model
Output:
Epoch: 1 Training Loss: 0.000010 Validation Loss: 0.000044
Validation loss decreased (inf --> 0.000044). Saving model ...
Epoch: 2 Training Loss: 0.000007 Validation Loss: 0.000040
Validation loss decreased (0.000044 --> 0.000040). Saving model ...
Epoch: 3 Training Loss: 0.000007 Validation Loss: 0.000040
Validation loss decreased (0.000040 --> 0.000040). Saving model ...
Let’s focus on a few parameters we used above:
start_epoch: the epoch number at which training starts
n_epochs: the epoch number at which training ends
valid_loss_min_input: the initial minimum validation loss; np.Inf when training from scratch
checkpoint_path: full path where the state of the latest checkpoint is saved
best_model_path: full path where the state of the best checkpoint so far is saved
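To make these parameters concrete, here is a sketch of a train function that accepts them. It is an assumption-laden illustration rather than the original code: the loaders dictionary, the checkpoint keys, the loss averaging, and the random stand-in data are all hypothetical, so the printed loss values will not match the output shown above.

```python
import os
import shutil
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def train(start_epoch, n_epochs, valid_loss_min_input, loaders, model,
          optimizer, criterion, checkpoint_path, best_model_path):
    """Train from start_epoch to n_epochs (inclusive), saving the latest and best checkpoints."""
    valid_loss_min = valid_loss_min_input
    for epoch in range(start_epoch, n_epochs + 1):
        train_loss, train_seen = 0.0, 0
        valid_loss, valid_seen = 0.0, 0

        model.train()  # enable dropout / batch norm updates
        for data, target in loaders["train"]:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * data.size(0)
            train_seen += data.size(0)

        model.eval()  # evaluation mode for validation
        with torch.no_grad():
            for data, target in loaders["valid"]:
                loss = criterion(model(data), target)
                valid_loss += loss.item() * data.size(0)
                valid_seen += data.size(0)

        # Average loss per sample (an assumption; other averaging schemes are possible).
        train_loss /= train_seen
        valid_loss /= valid_seen
        print(f"Epoch: {epoch} Training Loss: {train_loss:.6f} Validation Loss: {valid_loss:.6f}")

        is_best = valid_loss < valid_loss_min
        if is_best:
            print(f"Validation loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f}). Saving model ...")
            valid_loss_min = valid_loss

        # Always save the latest checkpoint; copy it when it is also the best.
        torch.save({
            "epoch": epoch + 1,  # epoch to resume from
            "valid_loss_min": valid_loss_min,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }, checkpoint_path)
        if is_best:
            shutil.copyfile(checkpoint_path, best_model_path)
    return model


def random_loader(n):
    """Random stand-in data shaped like 28x28 grayscale images with 10 classes."""
    data = TensorDataset(torch.randn(n, 1, 28, 28), torch.randint(0, 10, (n,)))
    return DataLoader(data, batch_size=16)


loaders = {"train": random_loader(64), "valid": random_loader(32)}
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

out_dir = tempfile.mkdtemp()
checkpoint_path = os.path.join(out_dir, "current_checkpoint.pt")
best_model_path = os.path.join(out_dir, "best_model.pt")

trained_model = train(1, 3, float("inf"), loaders, model, optimizer,
                      criterion, checkpoint_path, best_model_path)
```

Passing float("inf") as valid_loss_min_input plays the same role as np.Inf for a fresh run: the first epoch always counts as an improvement.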
After we load all the information we need, we can continue training with start_epoch = 4, since we previously trained the model for epochs 1 to 3.
Step 7: Continue Training and/or Inference
Continue training
We can continue training our model by calling the train function with the checkpoint values we got from the load_ckp function above.
Output:
Epoch: 4 Training Loss: 0.000006 Validation Loss: 0.000040
Epoch: 5 Training Loss: 0.000006 Validation Loss: 0.000037
Validation loss decreased (0.000040 --> 0.000037). Saving model ...
Epoch: 6 Training Loss: 0.000006 Validation Loss: 0.000036
Validation loss decreased (0.000037 --> 0.000036). Saving model ...
Notice that the epochs now run from 4 to 6 (start_epoch = 4).
The validation loss continues from the last training checkpoint:
at epoch 3, the minimum validation loss was 0.000040
here, the minimum validation loss starts at 0.000040 and not at inf
Inference
Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results [3].
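A sketch of an inference loop that follows this advice. The helper name compute_accuracy, the small network with a dropout layer, and the random test data are assumptions for illustration, so the accuracy it prints is meaningless beyond showing the mechanics:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def compute_accuracy(model, loader):
    """Return the classification accuracy (in percent) of model over loader."""
    model.eval()  # dropout off, batch norm uses running statistics
    correct, total = 0, 0
    with torch.no_grad():  # gradients are not needed for inference
        for images, labels in loader:
            predicted = model(images).argmax(dim=1)  # class with the highest score
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total


# Hypothetical network with dropout, plus random stand-in test data.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(64, 10),
)
test_data = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))
test_loader = DataLoader(test_data, batch_size=25)

accuracy = compute_accuracy(model, test_loader)
print(f"Accuracy of the network on {len(test_data)} test images: {accuracy:.2f}%")
```

With the model in evaluation mode, the same input gives the same output on every call; in training mode, the dropout layer would randomly zero activations and the predictions could change between calls.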
Output:
Accuracy of the network on 10000 test images: 86.58%
Where to find output/saved files in Kaggle notebook
In your Kaggle notebook, scroll down to the bottom of the page. The files saved in the previous steps are listed there.