avatarDagang Wei

Summary

The provided content discusses the significance of the chain rule in machine learning, particularly in optimization algorithms like backpropagation and gradient descent, and illustrates its application in logistic regression.

Abstract

The article "Essential Math for Machine Learning: The Chain Rule" delves into the fundamental role of the chain rule in the optimization processes of machine learning algorithms. It explains the chain rule as a calculus technique for differentiating composite functions, which is crucial for training neural networks through backpropagation. The chain rule enables the calculation of derivatives for complex functions with respect to their parameters, facilitating the adjustment of weights and biases in models like logistic regression. By applying the chain rule, machine learning practitioners can effectively compute gradients, which are essential for minimizing error in models using gradient descent. The article also provides a practical example using Python's SymPy library to demonstrate the mathematical operations involved in gradient calculation for logistic regression, reinforcing the theoretical concepts with tangible code.

Opinions

  • The author emphasizes the importance of the chain rule as a foundational mathematical concept in machine learning, without which optimization algorithms would not function effectively.
  • The article suggests that understanding the chain rule is key to grasping how machine learning models "learn" from data, implying that a deep comprehension of this concept is essential for practitioners in the field.
  • The use of the Python SymPy library in the provided example code serves to illustrate the practical utility of the chain rule in real-world machine learning applications, showcasing the author's view on the relevance of theoretical knowledge in practical scenarios.
  • By presenting the chain rule within the context of both simple (logistic regression) and complex (neural networks) models, the author conveys the universality of the chain rule's application across various machine learning architectures.

Essential Math for Machine Learning: The Chain Rule

The Change Lineage Tracker

Image generated with DALL-E

This article is part of the series Essential Math for Machine Learning.

Introduction

In the complex world of machine learning, mathematics plays the role of a guiding star. It helps algorithms learn and optimize. One vital mathematical concept that powers the optimization process in nearly all machine learning models is the chain rule. Let’s explore what it is, why it’s so important, and how it works its magic.

What is the Chain Rule?

At its core, the chain rule is a calculus technique for calculating the derivative of composite functions. A composite function is just a function within a function. For example, let’s say we have a function:

h(x) = (x² + 1

You can break this down into two “nested” functions:

  • Outer function: f(u) = u³
  • Inner function: g(x) = x² + 1

The chain rule tells us how to find the derivative of h(x), or how a change in the input ‘x’ impacts the final output of our composite function. It states:

dh/dx = df/du * du/dx

In simpler terms:

  • Derivative of outer function (with the inner function plugged in): df/du
  • Derivative of inner function: du/dx

Let’s see how the chain rule works:

h(x) = (x² + 1f(u) = u³   
g(x) = x² + 1

Derivative of outer function: df/du = 3u²

Derivative of inner function: dg/dx = 2x

Apply the chain rule: dh/dx = df/du * du/dx = 3(x² + 1)² * 2x

Why Chain Rule Matters in Machine Learning

Neural networks, the powerhouse behind many machine learning models, are essentially giant webs of composite functions. Understanding how to find the derivatives of these complex functions with respect to their parameters (weights and biases) is the key to making those models “learn.”

Here’s where the chain rule comes in:

  • Backpropagation: The heart of neural network training is an algorithm called backpropagation. It calculates how much each parameter in the network contributes to the overall error. This computation of “error signals” involves repeatedly applying the chain rule to efficiently trace those errors back through layers of the network.
  • Gradient Descent: Once we know how an error is linked to parameters, we want to make adjustments. Gradient descent, an optimization algorithm heavily used in machine learning, takes these error signals (gradients) and updates the parameters in a direction that minimizes the error. The chain rule makes this gradient calculation possible.

Example: Logistic Regression

Let’s explore the chain rule and gradient descent within the context of logistic regression.

The Model

In logistic regression, we aim to model the probability of a binary outcome (e.g., 0 or 1). The core building blocks are:

  • Linear Combination: First, a linear combination of features (with weights and bias) is computed: z = w*x + b
  • Sigmoid Function: The linear combination ‘z’ is passed through the sigmoid function to map it into a probability between 0 and 1: σ(z) = 1 / (1 + exp(-z))
  • Loss Function: Commonly used is the binary cross-entropy loss function, comparing the model’s predicted probability with the true label.

Gradient Calculation with the Chain Rule

For a single data point, let’s say the true label is ‘y’ and the model’s predicted probability is ŷ. Here’s how the chain rule connects the calculation of gradients:

  1. Loss with respect to ŷ: This derivative relates to the binary cross-entropy loss function itself.
  2. ŷ with respect to z: This is the derivative of the sigmoid function.
  3. z with respect to w (and similarly for b): This is just the derivative of a linear function.

Python Code with SymPy

The code is available in this colab notebook.

import sympy as sp

# Variables
x = sp.symbols('x')
w = sp.symbols('w')
b = sp.symbols('b')
y = sp.symbols('y')  # True label

# Linear combination
z = w * x + b

# Sigmoid activation
y_hat = 1 / (1 + sp.exp(-z))

# Binary cross-entropy loss 
loss = -(y * sp.log(y_hat) + (1 - y) * sp.log(1 - y_hat))

# Gradients (chain rule in action!)
gradient_w = sp.diff(loss, w)
gradient_b = sp.diff(loss, b)

# Display
print("Gradient with respect to w:", gradient_w)
print("Gradient with respect to b:", gradient_b)

Output:

Gradient with respect to w: -x*y*exp(-b - w*x)/(exp(-b - w*x) + 1) + x*(1 - y)*exp(-b - w*x)/((1 - 1/(exp(-b - w*x) + 1))*(exp(-b - w*x) + 1)**2)
Gradient with respect to b: -y*exp(-b - w*x)/(exp(-b - w*x) + 1) + (1 - y)*exp(-b - w*x)/((1 - 1/(exp(-b - w*x) + 1))*(exp(-b - w*x) + 1)**2)

The Math Behind the Code

Let’s break down the mathematical formulas represented in the SymPy code for logistic regression.

Notations

  • x: Input feature (or vector of features)
  • w: Weight (or vector of weights)
  • b: Bias term
  • y: True label (0 or 1)
  • ŷ: Predicted probability (output of the model)
  • z: Linear combination of input features and weights (input to the sigmoid)

Formulas

Linear Combination:

z = w * x + b

Sigmoid Function:

ŷ = σ(z) = 1 / (1 + exp(-z))

Binary Cross-Entropy Loss:

loss = -(y * log(ŷ) + (1 - y) * log(1 - ŷ))

Derivative of loss = -(y * log(ŷ) + (1 — y) * log(1 — ŷ)) with respect to ŷ:

∂loss/∂ŷ  =  -(y / ŷ) + ((1 - y) / (1 - ŷ))

Derivative of ŷ = σ(z) = 1 / (1 + exp(-z)) with respect to z:

∂ŷ/∂z = σ(z) * (1 - σ(z))

Derivative of z = w * x + b with respect to w:

∂z/∂w = x

Gradient of loss with respect to w:

∂loss/∂w = ∂loss/∂ŷ * ∂ŷ/∂z * ∂z/∂w

Similar process is for b.

Chain Rule:

  • The gradients are where the chain rule shines. The derivative of the loss function with respect to ‘w’ (or ‘b’) involves “peeling back” the layers of the logistic regression model:
  • From loss to predicted probability (ŷ)
  • From predicted probability (ŷ) to the linear combination (z)
  • From the linear combination (z) to the parameters (w and b)

Conclusion

The Chain Rule is a cornerstone of calculus that finds extensive application in machine learning. It is key to understanding and implementing the optimization algorithms that enable machine learning models to learn from data. By breaking down the process of computing derivatives of composite functions into manageable steps, the Chain Rule provides a clear path for gradient calculation in complex models, making it an indispensable tool for anyone delving into the field of machine learning. Whether you are training a simple linear regression model or a complex deep neural network, a solid grasp of the Chain Rule will greatly enhance your understanding of how these models learn and improve.

Machine Learning
Recommended from ReadMedium