Estimating an Optimal Learning Rate For a Deep Neural Network
The learning rate is one of the most important hyperparameters to tune when training deep neural networks.
In this post, I describe a simple yet powerful way to find a reasonable learning rate, which I learned from the fast.ai Deep Learning course. I’m taking the new version of the course in person at the University of San Francisco. It isn’t available to the general public yet, but it will be at the end of the year at course.fast.ai (which currently hosts last year’s version).
How does learning rate impact training?
Deep learning models are typically trained by a stochastic gradient descent optimizer. There are many variations of stochastic gradient descent: Adam, RMSProp, Adagrad, etc. All of them let you set the learning rate. This parameter tells the optimizer how far to move the weights in the direction opposite to the gradient of the loss computed on a mini-batch.
If the learning rate is low, then training is more reliable, but optimization will take a lot of time because steps towards the minimum of the loss function are tiny.
If the learning rate is high, then training may fail to converge or may even diverge. Weight changes can be so large that the optimizer overshoots the minimum and makes the loss worse.
Training should start with a relatively large learning rate because, in the beginning, the random weights are far from optimal; the learning rate can then be decreased during training to allow more fine-grained weight updates.
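To make the trade-off concrete, here is a minimal sketch of plain gradient descent on a hypothetical one-parameter loss f(w) = w², whose gradient is 2w. Each step computes w ← w − lr·2w = (1 − 2·lr)·w, so a small learning rate shrinks the loss slowly, a moderate one converges quickly, and any learning rate above 1 overshoots the minimum on every step and makes the loss grow. The toy loss and the three learning rates are illustrative assumptions, not values from the course.

```python
def grad(w):
    # Gradient of the toy loss f(w) = w**2.
    return 2.0 * w


def run_gd(lr, w0=1.0, steps=20):
    """Plain gradient descent on f(w) = w**2; returns the loss after `steps` updates."""
    w = w0
    for _ in range(steps):
        # Move the weight opposite to the gradient, scaled by the learning rate.
        w = w - lr * grad(w)
    return w * w


for lr in (0.01, 0.4, 1.1):
    print(f"lr={lr}: loss after 20 steps = {run_gd(lr):.3g}")
```

With lr=0.01 the loss is still well above zero after 20 steps, with lr=0.4 it is essentially zero, and with lr=1.1 it is far larger than the starting loss of 1.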
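One common way to decrease the learning rate during training is step decay: multiply the rate by a fixed factor every few epochs. The sketch below assumes hypothetical values (drop by 10× every 10 epochs); it only illustrates the idea of a shrinking learning rate and is not the schedule used in the course.

```python
def step_decay(initial_lr, epoch, drop=0.1, epochs_per_drop=10):
    """Hypothetical step-decay schedule: multiply the rate by `drop` every `epochs_per_drop` epochs."""
    return initial_lr * drop ** (epoch // epochs_per_drop)


for epoch in (0, 9, 10, 25):
    print(epoch, step_decay(0.1, epoch))
```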
There are multiple ways to select a good starting point for the learning rate. A naive approach is to try a few different values and see which one gives you the best loss without sacrificing training speed. We might start with a large value like 0.1, then try exponentially lower values: 0.01, 0.001, etc. When we start training with a large learning rate, the loss doesn’t improve and probably even grows during the first few iterations. When training with a smaller learning rate, at some point the loss starts decreasing in the first few iterations. This learning rate is the maximum we can use; any higher value doesn’t let the training converge. Even this value is too high: it won’t be good enough for training over multiple epochs, because over time the network will require more fine-grained weight updates. Therefore, a reasonable learning rate to start training from is probably 1–2 orders of magnitude lower.
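The naive search can be sketched as follows, assuming a hypothetical toy loss f(w) = 15·w² in place of a real network. Each candidate gets a short trial of a few gradient steps; the largest candidate whose loss goes down is taken as the maximum usable rate, and the starting rate backs off by one order of magnitude (the conservative end of the 1–2 orders suggested above).

```python
def short_trial(lr, steps=5, curvature=15.0, w0=1.0):
    """Run a few gradient-descent steps on the hypothetical loss f(w) = curvature * w**2.

    Returns True if the loss after the trial is lower than at the start.
    """
    w = w0
    start_loss = curvature * w * w
    for _ in range(steps):
        # Gradient of curvature * w**2 is 2 * curvature * w.
        w -= lr * 2.0 * curvature * w
    return curvature * w * w < start_loss


candidates = [0.1, 0.01, 0.001, 0.0001]  # exponentially lower values
working = [lr for lr in candidates if short_trial(lr)]
max_lr = max(working)   # largest rate whose loss still decreases
start_lr = max_lr / 10  # back off by one order of magnitude
print(max_lr, start_lr)
```

On this toy loss, lr=0.1 makes each step overshoot (the weight is multiplied by −2), so the largest working rate is 0.01 and the suggested starting rate is 0.001.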
There must be a smarter way
Leslie N. Smith describes a powerful technique to select a range of learning rates for a neural network in section 3.3 of the 2015 paper “Cyclical Learning Rates for Training Neural Networks”.
The trick is to train a network starting from a low learning rate and to increase the learning rate exponentially with every batch.
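An exponential increase means multiplying the learning rate by the same factor after every batch, so batch i (counting from 0) uses lr_min · (lr_max / lr_min)^(i / (N − 1)) for N batches. A minimal sketch of that schedule, where lr_min, lr_max and the batch count are parameters you would choose yourself:

```python
def exponential_lr_schedule(lr_min, lr_max, num_batches):
    """Per-batch learning rates growing geometrically from lr_min to lr_max."""
    if num_batches < 2:
        raise ValueError("need at least two batches")
    # Multiplying by this factor once per batch reaches lr_max on the last batch.
    factor = (lr_max / lr_min) ** (1.0 / (num_batches - 1))
    return [lr_min * factor ** i for i in range(num_batches)]


schedule = exponential_lr_schedule(1e-5, 1.0, 6)
print([f"{lr:.0e}" for lr in schedule])
```

With six batches between 1e-5 and 1.0 the factor is 10, so each batch uses a learning rate one order of magnitude higher than the previous one.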

Record the learning rate and the training loss for every batch. Then plot the loss against the learning rate. Typically, it looks like this:
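The whole range test fits in a short loop. The sketch below is a toy version, not the fast.ai implementation: it assumes a hypothetical one-weight linear regression on noisy samples of y = 3x, trains it with mini-batch SGD, multiplies the learning rate by a constant factor after every batch, and records one (learning rate, loss) pair per batch. The cutoff that stops the test once the loss exceeds four times the best loss seen so far is an arbitrary choice for this sketch.

```python
import random


def lr_range_test(lr_min=1e-4, lr_max=10.0, num_batches=100, batch_size=16, seed=0):
    """Toy learning rate range test on a hypothetical 1-D regression (y = 3x + noise)."""
    rng = random.Random(seed)
    # Constant per-batch multiplier that takes lr_min to lr_max in num_batches - 1 steps.
    factor = (lr_max / lr_min) ** (1.0 / (num_batches - 1))
    w, lr = 0.0, lr_min
    lrs, losses = [], []
    for _ in range(num_batches):
        # Draw a fresh mini-batch of noisy samples.
        xs = [rng.uniform(-1.0, 1.0) for _ in range(batch_size)]
        ys = [3.0 * x + rng.gauss(0.0, 0.1) for x in xs]
        # Mean squared error of the model y_hat = w * x and its gradient with respect to w.
        errors = [w * x - y for x, y in zip(xs, ys)]
        loss = sum(e * e for e in errors) / batch_size
        grad = sum(2.0 * e * x for e, x in zip(errors, xs)) / batch_size
        # Record the learning rate and the loss for this batch.
        lrs.append(lr)
        losses.append(loss)
        if loss > 4.0 * min(losses):
            break  # the loss is clearly diverging; stop the test
        w -= lr * grad
        lr *= factor
    return lrs, losses


lrs, losses = lr_range_test()
print(f"recorded {len(lrs)} batches, last learning rate {lrs[-1]:.3g}")
```

Plotting `losses` against `lrs` on a logarithmic x-axis gives the kind of curve described here: flat at first, then falling, then shooting up.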

At first, with low learning rates, the loss improves slowly; then training accelerates until the learning rate becomes too large and the loss goes up: the training process diverges.
We need to select the point on the graph where the loss decreases fastest. In this example, the loss function decreases quickly when the learning rate is between 0.001 and 0.01.
Another way to look at these numbers is to calculate the rate of change of the loss (the derivative of the loss function with respect to the iteration number) and then plot the rate of change on the y-axis and the learning rate on the x-axis.
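Because the losses are recorded once per iteration, the derivative with respect to the iteration number can be estimated with a first-order finite difference: the loss at iteration i minus the loss at iteration i − 1. A minimal sketch, using hypothetical recorded values:

```python
def loss_change_rate(losses):
    """Finite-difference estimate of d(loss)/d(iteration): losses[i] - losses[i - 1]."""
    return [curr - prev for prev, curr in zip(losses, losses[1:])]


lrs = [1e-4, 1e-3, 1e-2, 1e-1]      # hypothetical recorded learning rates
losses = [2.30, 2.25, 1.90, 2.60]   # hypothetical recorded losses
rates = loss_change_rate(losses)
# Pair each change with the learning rate of the iteration where it was observed.
points = list(zip(lrs[1:], rates))
print(points)
```

The difference list is one element shorter than the loss list, so it is paired with the learning rates starting from the second iteration.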

This looks too noisy, so let’s smooth it out using a simple moving average.
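A simple moving average replaces each point with the mean of a fixed-size window of consecutive points. The sketch below keeps a running sum so each new average costs one addition and one subtraction; the window size is a free parameter you would tune by eye.

```python
def simple_moving_average(values, window):
    """Mean of each run of `window` consecutive values.

    Returns len(values) - window + 1 averages; average i covers values[i:i + window].
    """
    if window < 1 or window > len(values):
        raise ValueError("window must be between 1 and len(values)")
    total = sum(values[:window])
    averages = [total / window]
    for i in range(window, len(values)):
        # Slide the window one step: add the newest value, drop the oldest.
        total += values[i] - values[i - window]
        averages.append(total / window)
    return averages


print(simple_moving_average([1, 2, 3, 4, 5], 3))
```

A wider window removes more noise but also shifts and flattens the minimum, so it is worth checking the smoothed curve against the raw one.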

This looks better. On this graph, we need to find the minimum. It is close to lr=0.01.
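Finding that minimum programmatically combines the two steps above: take finite differences of the loss, smooth them, and look up the learning rate where the smoothed rate of change is most negative. In this sketch the reported learning rate is the one in the middle of the winning window, which is an arbitrary alignment choice; the example values are hypothetical.

```python
def suggest_lr(lrs, losses, window=3):
    """Learning rate where the smoothed loss change rate is lowest (fastest decrease)."""
    # rates[j] = losses[j + 1] - losses[j], observed at lrs[j + 1].
    rates = [curr - prev for prev, curr in zip(losses, losses[1:])]
    # Simple moving average over `window` consecutive rates.
    smoothed = [sum(rates[i:i + window]) / window for i in range(len(rates) - window + 1)]
    best = min(range(len(smoothed)), key=smoothed.__getitem__)
    # smoothed[best] covers lrs[best + 1 : best + 1 + window]; report the middle one.
    return lrs[best + 1 + window // 2]


example_lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0]
example_losses = [2.30, 2.29, 2.20, 1.60, 1.00, 1.20, 5.00]
print(suggest_lr(example_lrs, example_losses))
```

For these made-up values, the steepest smoothed decrease is centered on lr=0.01.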
Implementation
Jeremy Howard and his team at the USF Data Institute developed fast.ai, a deep learning library that provides a high-level abstraction on top of PyTorch. It’s an easy-to-use yet powerful toolset for training state-of-the-art deep learning models. Jeremy uses the library in the latest version of the Deep Learning course (fast.ai).
The library provides an implementation of the learning rate finder. You need just two lines of code to plot the loss over learning rates for your model:





