Cameron R. Wolfe, Ph.D.

Basics of Reinforcement Learning for LLMs

Understanding the problem formulation and basic algorithms for RL

(Photo by Ricardo Gomez Angel on Unsplash)

Recent AI research has revealed that reinforcement learning — more specifically, reinforcement learning from human feedback (RLHF) — is a key component of training a state-of-the-art large language model (LLM). Despite this fact, most open-source research on language models heavily emphasizes supervised learning strategies, such as supervised fine-tuning (SFT). This lack of emphasis on reinforcement learning can be attributed to several factors, including the need to curate human preference data and the amount of data needed to perform high-quality RLHF. However, one undeniable factor that likely underlies skepticism towards reinforcement learning is the simple fact that it is not as commonly used as supervised learning. As a result, AI practitioners (including myself!) avoid reinforcement learning due to a simple lack of understanding — we tend to stick with the approaches that we know best.

“Many among us expressed a preference for supervised annotation, attracted by its denser signal… However, reinforcement learning proved highly effective, particularly given its cost and time effectiveness.” — from [8]

This series. In the next few overviews, we will aim to eliminate this problem by building a working understanding of reinforcement learning from the ground up. We will start with basic definitions and approaches — covered in this overview — and work our way towards modern algorithms (e.g., PPO) that are used to finetune language models with RLHF. Throughout this process, we will explore example implementations of these ideas, aiming to demystify and normalize the use of reinforcement learning in the language modeling domain. As we will see, these ideas are easy to use in practice if we take the time to understand how they work!

What is Reinforcement Learning?

Comparison of supervised and reinforcement learning (created by author)

At the highest level, reinforcement learning (RL) is just another way of training a machine learning model. In prior overviews, we have seen a variety of techniques for training neural networks, but the two most commonly-used techniques for language models are supervised and self-supervised learning.

(Self-)Supervised Learning. In supervised learning, we have a dataset of inputs (i.e., sequences of text) with corresponding labels (e.g., a classification or completion of the input text), and we want to train our model to accurately predict those labels from the input. For example, maybe we want to finetune a language model (e.g., BERT) to classify sentences that contain explicit language. In this case, we can obtain a dataset of sentences with binary labels indicating whether each sentence contains explicit language or not. Then, we can train our language model to classify this data correctly by iteratively performing the steps below (a minimal code sketch follows the list):

  1. Sampling a mini-batch of data from the dataset.
  2. Predicting the labels with the model.
  3. Computing the loss (e.g., CrossEntropy).
  4. Backpropagating the gradient through the model.
  5. Performing a weight update.
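
To make this concrete, below is a minimal PyTorch-style sketch of the loop above for the explicit-language classifier. The placeholder model, random data, and hyperparameters are stand-ins for illustration, not code from the article.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: `model` would normally be a pretrained LM (e.g., BERT)
# with a binary classification head, and `dataloader` would yield real text data.
model = nn.Linear(768, 2)  # placeholder classifier over 768-dim sentence embeddings
dataloader = [(torch.randn(16, 768), torch.randint(0, 2, (16,))) for _ in range(10)]

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for inputs, labels in dataloader:   # 1. sample a mini-batch of data
    logits = model(inputs)          # 2. predict the labels with the model
    loss = loss_fn(logits, labels)  # 3. compute the loss (CrossEntropy)
    optimizer.zero_grad()
    loss.backward()                 # 4. backpropagate the gradient
    optimizer.step()                # 5. perform a weight update
```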

Self-supervised learning is similar to the setup explained above, but there are no explicit labels within our dataset. Rather, the “labels” that we use are already present within the input data. For example, language models are pretrained with a self-supervised language modeling objective that trains the model to predict the next token given prior tokens as input. Here, the next token is already present within the data (assuming that we have access to the full textual sequence).

(from [2])

When is RL useful? Although RL is just another way of training a neural network, the training setup is different compared to supervised learning. Similarly to how humans learn, RL trains neural networks through trial and error. More specifically, the neural network will produce an output, receive some feedback about this output, then learn from the feedback. For example, when finetuning a language model with reinforcement learning from human feedback (RLHF), the language model produces some text then receives a score/reward from a human annotator that captures the quality of that text; see above. Then, we use RL to finetune the language model to generate outputs with high scores.

The environment is not differentiable within RL (created by author)

In this case, we cannot apply a loss function that trains the language model to maximize human preferences with supervised learning. Why? Well, the score that we get from the human is a bit of a black box. There’s no easy way for us to explain this score or connect it mathematically to the output of the neural network. In other words, we cannot backpropagate a loss applied to this score through the rest of the neural network. This would require that we are able to differentiate (i.e., compute the gradient of) the system that generates the score, which is a human that subjectively evaluates the generated text; see above.

Big picture. The above discussion starts to provide us with insight as to why RL is such a beautiful and promising learning algorithm for neural networks. RL allows us to learn from signals that are non-differentiable and, therefore, not compatible with supervised learning. Put simply, this means that we can learn from arbitrary feedback on a neural network’s output! In the case of RLHF, we can score the outputs generated by a language model according to any principle that we have in mind. Then, we can use RL to learn from these scores, no matter how we choose to define them! In this way, we can teach a language model to be helpful, harmless, honest, more capable (e.g., by using tools), and much more.

A Formal Framework for RL

The agent acts and receives rewards (and new states) from the environment (created by author)

Problems that are solved via RL tend to be structured in a similar format. Namely, we have an agent that is interacting with an environment; see above. The agent has a state in the environment and produces actions, which can modify the current state, as output. As the agent interacts with the environment, it can receive both positive and negative rewards for its actions. The agent’s goal is to maximize the rewards that it receives, but there is not a reward associated with every action taken by the agent! Rather, rewards may have a long horizon, meaning that it takes several correct, consecutive actions to generate any positive reward.

Markov Decision Process (MDP)

To make things more formal and mathematically sound, we can formulate the system described above as a Markov Decision Process (MDP). Within an MDP, we have states, actions, rewards, transitions, and a policy; see below.

Components of an MDP (created by author)

States and actions take discrete values, while rewards are real numbers. In an MDP, we define two types of functions: transition and policy functions. The policy takes a state as input, then outputs a probability distribution over possible actions. Given this output, we can choose the action to take from the current state; the transition function then outputs the next state based upon the prior state and chosen action. Using these components, the agent can interact with the environment in an iterative fashion; see below.

Structure of an MDP (created by author)

One thing we might be wondering here is: What is the difference between the agent and the policy? The distinction is a bit nuanced. However, we can think of the agent as implementing the policy within its environment. The policy describes how the agent chooses its next action given the current state. The agent follows this strategy as it interacts with the environment, and our goal is to learn a policy that maximizes the reward that the agent receives from the environment.

As the agent interacts with the environment, we form a “trajectory” of states and actions that are chosen throughout this process. Then, given the reward associated with each of these states, we get a total return given by the equation below, where γ is the discount factor (more explanation coming soon). This return is the summed reward across the agent’s full trajectory, but rewards achieved at later time steps are exponentially discounted by the factor γ; see below.

Trajectory and the return (created by author)
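
Written out in standard notation, the return of a trajectory τ = (s_0, a_0, s_1, a_1, …) with per-step rewards r_t is:

R(\tau) = r_0 + \gamma r_1 + \gamma^2 r_2 + \dots = \sum_{t=0}^{T} \gamma^t \, r_t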

The goal of RL is to train an agent that maximizes this return. As shown by the equation below, we can characterize this as finding a policy that maximizes the return over trajectories that are sampled from the final policy.

Objective being solved by RL (created by author)
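
In the same notation, this objective amounts to finding the policy whose sampled trajectories achieve the highest expected return:

\pi^{*} = \arg\max_{\pi} \; \mathbb{E}_{\tau \sim \pi}\left[ R(\tau) \right]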

Example application. As a simplified example of the setup described above, let’s consider training a neural network to navigate a 2 X 3 grid from some initial state to some final state; see below. Here, we see in the grid that the agent will receive a reward of +10 for reaching the desired final state and a reward of -10 whenever it visits the red square.

A simplistic RL environment (created by author)

Our environment is the 2 X 3 grid and the state is given by the current position within this grid—we can represent this as a one-hot vector. We can implement our policy with a feed-forward neural network that takes the current one-hot position as input and predicts a probability distribution over potential actions (i.e., move up, move down, move left, move right). For each chosen action, the transition function simply moves the agent to the corresponding next position on the grid, preventing the agent from moving out of bounds. The optimal agent learns to reach the final state without passing through the red square; see below.

The optimal (largest return) solution path (created by author)

Like many problems that are solved with RL, this setup has an environment that is not differentiable (i.e., we can’t compute a gradient and train the model in a supervised fashion) and contains long-term dependencies, meaning that we might have to learn how to perform several sequential actions to get any reward.
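
As a rough illustration (not code from the article), the sketch below implements this grid as a tiny environment and the policy as a small feed-forward network over one-hot states. The positions of the goal and red square, the network size, and all names are assumptions made for this example.

```python
import torch
import torch.nn as nn

N_STATES, N_ACTIONS = 6, 4  # 2 x 3 grid; actions: up, down, left, right

class GridEnvironment:
    """Assumed layout: states 0..5, goal at state 5 (+10), red square at state 4 (-10)."""
    def __init__(self):
        self.state = 0

    def step(self, action):
        row, col = divmod(self.state, 3)
        if action == 0: row = max(row - 1, 0)  # up
        if action == 1: row = min(row + 1, 1)  # down
        if action == 2: col = max(col - 1, 0)  # left
        if action == 3: col = min(col + 1, 2)  # right
        self.state = row * 3 + col             # transition: move within bounds
        reward = 10 if self.state == 5 else (-10 if self.state == 4 else 0)
        done = self.state == 5
        return self.state, reward, done

# policy: one-hot state -> probability distribution over the four actions
policy = nn.Sequential(nn.Linear(N_STATES, 32), nn.ReLU(), nn.Linear(32, N_ACTIONS))

def action_probs(state):
    one_hot = torch.zeros(N_STATES)
    one_hot[state] = 1.0
    return torch.softmax(policy(one_hot), dim=-1)
```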

Great… but how does this apply to language models? The application described above is a traditional use case for RL, including an agent/policy that learns to interact with an external (potentially simulated) environment. There are numerous examples of such successful applications of RL; e.g., Atari [3], Go, autonomous driving [4] and more. However, RL has recently been leveraged for finetuning language models. Although this is a drastically different use case, the components discussed above can be easily translated to language modeling!

Next token prediction with a language model (created by author)

As has been discussed extensively in prior overviews, language models specialize in performing next token prediction; see above. In other words, our language model takes several tokens as input (i.e., a prefix) and predicts the next token based on this context. When generating text at inference time, this is done autoregressively, meaning that the language model continually:

  1. Predicts the next token.
  2. Adds the next token to the current input sequence.
  3. Repeats.

To view this setup from the lens of RL, we can consider our language model to be the policy. Our state is just the current textual sequence. Given this state as input, the language model can produce an action — the next token — that modifies the current state to produce the next state — the textual sequence with an added token. Once a full textual sequence has been produced, we can obtain a reward by rating the quality of the language model’s output, either with a human or a reward model that has been trained over human preferences.
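
As a rough sketch of this correspondence (the `language_model` and `reward_model` interfaces below are hypothetical placeholders, not code from the article): the state is the current token sequence, each action appends a token, and the reward arrives only once the full sequence has been generated.

```python
import torch

def generate_episode(language_model, reward_model, prompt_tokens, max_new_tokens=32):
    """Treat the LM as a policy: state = token sequence, action = next token."""
    state = list(prompt_tokens)
    for _ in range(max_new_tokens):
        # assumed interface: the LM maps a [batch, seq] token tensor to [batch, seq, vocab] logits
        logits = language_model(torch.tensor([state]))[0, -1]
        probs = torch.softmax(logits, dim=-1)
        action = torch.multinomial(probs, num_samples=1).item()  # sample the next token (action)
        state.append(action)                                     # transition: append token to state
    reward = reward_model(state)  # reward assigned only to the full generated sequence
    return state, reward
```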

Although this setup is quite different from learning to navigate a simple grid (i.e., the model, data modality, environment, and reward function are all completely different!), we begin to see that the problem formulation used for RL is quite generic. There are many different problems that we can solve using this approach!

Important Terms and Definitions

Now that we understand the basic setup that is used for RL, we should overview some of the common terms we might see when studying this area of research. We have outlined a few notable terms and definitions below.

Trajectory: A trajectory is simply a sequence of states and actions that describe the path taken by an agent through an environment.

Episode: Sometimes, the environment we are exploring has a well-defined end state; e.g., reaching the final destination in our 2 X 3 grid. In these cases, we refer to the trajectory of actions and states from start to end state as an episode.

Discounting rewards when computing the return (created by author)

Return: Return is just the reward summed over an entire trajectory. As shown above, however, this sum typically includes a discount factor. Intuitively, this means that current rewards are more valuable than later rewards, due to both uncertainty and the simple fact that waiting to receive a reward is less desirable.

Discount factor: The concept of discounting goes beyond RL (e.g., discounting is a core concept in finance) and refers to the basic idea of determining the current value of a future reward. As shown by the equation above, we handle discounting in RL via an exponential discount factor. Although the intuitive explanation of the discount factor is easy to understand, the exact formulation we see above is rooted in mathematics and is actually a complex topic of discussion; see here.

On vs. Off-Policy: In RL, we have a target policy that describes the policy our agent is aiming to learn. Additionally, we have a behavior policy that is used by the agent to select actions as it interacts with the environment. The distinction between on and off-policy learning is subtle: if the behavior policy used to select actions while the agent navigates the environment is the same as the target policy that we are trying to evaluate and improve, learning is on-policy; if the two policies differ, learning is off-policy.

ε-Greedy Policy: RL trains a neural network via interaction with an environment. The policy that this neural network implements takes a current state as input and produces a probability distribution over potential actions as output. But, how do we choose which action to actually execute? One of the most common approaches is an ε-greedy policy, which selects the action with the highest expected return most of the time (i.e., with probability 1 - ε) and a random action otherwise. Such an approach balances exploration and exploitation by allowing the agent to explore new actions in addition to those that it knows to work well.
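
For example, ε-greedy action selection can be sketched in a few lines of Python (assuming `q_values` holds the estimated value of each action at the current state):

```python
import random

def epsilon_greedy(q_values, epsilon=0.1):
    """Pick a random action with probability epsilon, else the highest-value action."""
    if random.random() < epsilon:
        return random.randrange(len(q_values))                    # explore
    return max(range(len(q_values)), key=lambda a: q_values[a])   # exploit
```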

Q-Learning: A Simple Introduction to RL

Now that we understand the framework that is typically used to solve problems with RL, let’s take a look at our first RL algorithm. This algorithm, called Q-Learning, is simple to understand and a great introduction to the topic. Once we understand Q-Learning, we will also study our first deep RL algorithm (i.e., a system that trains a deep neural network with RL), called Deep Q-Learning.

Q-Learning: Modeling Q Values with a Lookup Table

Q-Learning is a model-free RL algorithm, meaning that we don’t have to learn a model for the environment with which the agent interacts. Concretely, this means that we don’t have to train a model to estimate the transition or reward functions — these are just given to us as the agent interacts with the environment. The goal of Q-Learning is to learn the value of any action at a particular state. We do this through learning a Q function, which defines the value of a state-action pair as the expected return of taking that action at the current state under a certain policy and continuing afterwards according to the same policy.

Storing Q values for state-action pairs in a lookup table (created by author)

To learn this Q function, we create a lookup table for state-action pairs. Each row in this table represents a unique state, and each column represents a unique action; see above. Each entry of the lookup table stores the Q value (i.e., the output of the Q function) for a particular state-action pair. These Q values are initialized to zero and updated — using the Bellman equation — as the agent interacts with the environment until they become optimal.

High-level depiction of the Q-learning algorithm (created by author)

The algorithm. The first step of Q-learning is to initialize our Q values as zero and pick an initial state with which to start the learning process. Then, we iterate over the following steps (depicted above):

  1. Pick an action to execute from the current state (using an ε-Greedy Policy).
  2. Get a reward and next state from the (model-free) environment.
  3. Update the Q value based on the Bellman equation.

As shown in the figure above, our update to the Q value considers the reward of the current action, the Q value of the current state, and the Q value of the next state. However, given that our agent might execute several actions within the next state, it is unclear which Q value we should use for the next state when performing our update. In Q-learning, we choose to use the maximum Q value, as shown below.

Q-learning update rule based on Bellman equation (created by author)
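
Putting these steps together, a minimal tabular Q-learning loop might look like the sketch below. The environment interface (`reset`/`step`) and the hyperparameters are illustrative assumptions rather than details from the article.

```python
import numpy as np

def q_learning(env, n_states, n_actions, episodes=500, alpha=0.1, gamma=0.9, epsilon=0.1):
    Q = np.zeros((n_states, n_actions))  # lookup table of Q values, initialized to zero
    for _ in range(episodes):
        state, done = env.reset(), False  # assumed interface: reset() returns the initial state
        while not done:
            # 1. pick an action with an epsilon-greedy policy
            if np.random.rand() < epsilon:
                action = np.random.randint(n_actions)
            else:
                action = int(np.argmax(Q[state]))
            # 2. get the reward and next state from the (model-free) environment
            next_state, reward, done = env.step(action)
            # 3. Bellman update: move Q(s, a) toward r + gamma * max_a' Q(s', a')
            target = reward + gamma * np.max(Q[next_state]) * (not done)
            Q[state, action] += alpha * (target - Q[state, action])
            state = next_state
    return Q
```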

Interestingly, Q-learning utilizes an ε-greedy policy when selecting actions, allowing new states and actions to be explored with a certain probability. When computing Q value updates, however, we always consider the next action with the maximum Q value, which may or may not be executed from the next state. In other words, Q-learning estimates the return for state-action pairs by assuming a greedy policy that just selects the highest-return action at the next state, even though we don’t follow such an approach when actually selecting an action. For this reason, Q-learning is an off-policy learning algorithm; see here for more details.

Brief mathematical note. The update rule used for Q-learning is mathematically guaranteed to find an optimal policy for any (finite) MDP — meaning that we will get a policy that maximizes our objective given sufficient iterations of the above process. An approachable and (almost) self-contained proof of this result is provided in [5].

Deep Q-Learning

The foundation of Deep Q-learning (DQL) lies in the vanilla Q-learning algorithm described above. DQL is just an extension of Q-learning for deep reinforcement learning, meaning that we use an approach similar to Q-learning to train a deep neural network. Given that we are now using a more powerful model (rather than a lookup table), Deep Q-Learning can actually be leveraged in interesting (but still relatively simple) practical applications. Let’s take a look at this algorithm and a few related applications that might be of interest.

Q-learning models a Q function with a lookup table, while Deep Q-learning models a Q function with a deep neural network (created by author)

The problem with Q-Learning. The size of the lookup table that we define for Q-learning is dependent upon the total number of states and actions that exist within an environment. In complex environments (e.g., high-resolution video games or real life), maintaining such a lookup table is intractable — we need a more scalable approach. DQL solves this problem by modeling the Q function with a neural network instead of a lookup table; see above. This neural network takes the current state as input and predicts the Q values of all possible actions from that state as output. DQL eliminates the need to store a massive lookup table! We just store the parameters of our neural network and use it to predict Q values.

Schematic depiction of DQL (created by author)

The algorithm. In DQL, we have two neural networks: the Q network and the target network. These networks are identical, but the exact architecture they use depends upon the problem being solved. To train these networks, we first gather data by interacting with the environment. This data is gathered using the current Q network with an ε-greedy policy and stored so that it can be reused for training. This mechanism of storing and reusing interaction data is referred to as experience replay; see above.

From here, we use data that has been collected to train the Q network. During each training iteration, we sample a batch of data and pass it through both the Q network and the target network. The Q network takes the current state as input and predicts the Q value of the action that is taken (i.e., predicted Q value), while the target network takes the next state as input and predicts the Q value of the best action that can be taken from that state (i.e., target Q value).

Loss function for DQL (created by author)

From here, we use the predicted Q value, the target Q value, and the observed reward to train the Q network with an MSE loss; see above. The target network is held fixed. Every several iterations, the weights of the Q network are copied to the target network, allowing this model to be updated as well. Then, we just repeat this process until the Q network converges. Notably, the dataset we obtain from experience replay is cumulative, meaning that we maintain all of the data we have observed from the environment throughout all iterations.
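
Written out in standard notation, this loss is the squared difference between the target and predicted Q values, where θ denotes the Q network's parameters and θ⁻ the (temporarily fixed) target network's parameters:

L(\theta) = \big( r + \gamma \max_{a'} Q_{\theta^{-}}(s', a') - Q_{\theta}(s, a) \big)^{2}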

Why do we need the target network? The vanilla Q-learning framework leverages two Q values in its update rule: a (predicted) Q value for the current state-action pair and the (target) Q value of the best state-action pair for the next state. In DQL, we similarly have to generate both of these Q values. In theory, we could do this with a single neural network by making multiple passes through the Q network — one for the predicted Q value and one for the target Q value. However, the Q network’s weights are being updated at every training iteration, which would cause the target Q value to constantly fluctuate as the model is updated. To avoid this issue, we keep the target network separate and fixed, only updating its weights every several iterations to avoid creating a “moving target”.
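
Under these conventions, a rough PyTorch-style sketch of a single DQL training step is shown below. The network sizes, replay buffer format, and hyperparameters are illustrative assumptions, not details from the article or from [3].

```python
import copy
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

N_STATES, N_ACTIONS = 6, 4  # placeholder sizes (e.g., the small grid environment above)

q_net = nn.Sequential(nn.Linear(N_STATES, 64), nn.ReLU(), nn.Linear(64, N_ACTIONS))
target_net = copy.deepcopy(q_net)  # identical network that provides a fixed training target
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)

# Cumulative experience replay buffer of (state, action, reward, next_state, done)
# tuples, where each element is stored as a tensor (done stored as 0.0 or 1.0).
replay_buffer = []

def train_step(batch_size=32, gamma=0.99):
    batch = random.sample(replay_buffer, batch_size)
    s, a, r, s_next, done = (torch.stack(x) for x in zip(*batch))
    pred_q = q_net(s).gather(1, a.long().unsqueeze(1)).squeeze(1)  # Q value of the taken action
    with torch.no_grad():  # target network is held fixed during the update
        target_q = r + gamma * target_net(s_next).max(dim=1).values * (1.0 - done)
    loss = F.mse_loss(pred_q, target_q)  # MSE between predicted and target Q values
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

def sync_target_network():
    """Every several iterations, copy the Q network's weights into the target network."""
    target_net.load_state_dict(q_net.state_dict())
```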

(from [7])

This idea of using a separate network to produce a training target for another network — referred to as knowledge distillation [6] — is heavily utilized within deep learning. Furthermore, the idea of avoiding too much fluctuation in the weights of the teacher/target model has been addressed in this domain. For example, the mean teacher approach [7] updates the weights of the teacher model as an exponential moving average of the student network’s weights; see above. In this way, we can ensure a stable target is provided by the teacher during training.
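
As a small illustration (a hypothetical helper, not code from [7]), a mean-teacher-style exponential moving average update over two sets of torch parameters can be written as:

```python
def ema_update(teacher_params, student_params, alpha=0.99):
    """Teacher weights track an exponential moving average of the student's weights."""
    for t, s in zip(teacher_params, student_params):        # iterables of torch parameters
        t.data.mul_(alpha).add_(s.data, alpha=1.0 - alpha)  # t = alpha * t + (1 - alpha) * s
```

A soft update of this form could, in principle, stand in for the periodic hard copy used for DQL's target network, though that is a design variation rather than something discussed above.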

Practical applications. DQL is a deep RL framework that has been used for several interesting practical applications. One early and notable demonstration of DQL was playing Atari Breakout. In [3], authors from Google DeepMind show that Deep Q-Learning is a useful approach for training agents that can successfully beat simple video games. For a more hands-on tutorial, check out the article below that explores a similar approach for Space Invaders.

Final Remarks

We now have a basic understanding of RL, the associated problem formulation, and how such problems can be solved by algorithms like (deep) Q-learning. Although reinforcement learning is a complex topic, the algorithms and formulations we have studied so far are quite simple. Over the course of coming overviews, we will slowly build upon these concepts, eventually arriving at the algorithms that are used today for finetuning language models. As we will see, RL is an incredibly powerful learning approach that can be used to improve a variety of practical applications from LLMs to recommendation systems. Gaining a deep understanding of this concept, its capabilities, and how it can be implemented unlocks an entire realm of new possibilities beyond supervised learning.

Connect with me!

Thanks so much for reading this article. I am Cameron R. Wolfe, Director of AI at Rebuy. I study the empirical and theoretical foundations of deep learning. If you liked this overview, subscribe to my Deep (Learning) Focus newsletter, where I help readers understand AI research via overviews of relevant topics from the ground up. You can also follow me on X and LinkedIn, or check out my other writings on medium!

Bibliography

[1] Devlin, Jacob, et al. “Bert: Pre-training of deep bidirectional transformers for language understanding.” arXiv preprint arXiv:1810.04805 (2018).

[2] Bai, Yuntao, et al. “Training a helpful and harmless assistant with reinforcement learning from human feedback.” arXiv preprint arXiv:2204.05862 (2022).

[3] Mnih, Volodymyr, et al. “Playing atari with deep reinforcement learning.” arXiv preprint arXiv:1312.5602 (2013).

[4] Kiran, B. Ravi, et al. “Deep reinforcement learning for autonomous driving: A survey.” IEEE Transactions on Intelligent Transportation Systems 23.6 (2021): 4909–4926.

[5] Regehr, Matthew T., and Alex Ayoub. “An Elementary Proof that Q-learning Converges Almost Surely.” arXiv preprint arXiv:2108.02827 (2021).

[6] Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. “Distilling the knowledge in a neural network.” arXiv preprint arXiv:1503.02531 (2015).

[7] Tarvainen, Antti, and Harri Valpola. “Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results.” Advances in neural information processing systems 30 (2017).

[8] Touvron, Hugo, et al. “Llama 2: Open foundation and fine-tuned chat models.” arXiv preprint arXiv:2307.09288 (2023).
