Interpretable Machine Learning
From Linear Regression to Shapley — A Journey into Model Interpretability
A Comprehensive Guide to Interpretability in Machine Learning
Machine learning has made remarkable strides in recent years, enabling us to solve complex problems and make predictions with unprecedented accuracy. However, as models grow in complexity, understanding why they make specific decisions becomes increasingly challenging. This is where interpretability in machine learning comes into play. Interpretability refers to the ability to explain and understand how a machine learning model arrives at its predictions.
In this post, we’ll delve deep into the world of interpretability in machine learning. We’ll start by exploring the inherent explainability of Linear Regression, a simple yet powerful algorithm. Then, we’ll discuss feature importance, a crucial aspect of model interpretability, in Decision Trees, Random Forests (including Permutation Feature Importance), and XGBoost. Finally, we’ll explore more advanced methods, LIME and Shapley values, along with the mathematics behind each method.
Inherent Explainability in Linear Regression
Linear Regression is one of the simplest and most interpretable machine learning algorithms. Its inherent explainability comes from its transparent mathematical model. In Linear Regression, we aim to find a linear relationship between the independent variables (features) and the dependent variable (target). The model’s prediction is a weighted sum of the features, each multiplied by a coefficient, plus an intercept term.
Mathematically, the prediction ŷ for a given set of features X is given by:
ŷ = b₀ + b₁X₁ + b₂X₂ + … + bₖXₖ
where:
- ŷ is the predicted output.
- b₀ is the intercept term.
- b₁, b₂, …, bₖ are the coefficients associated with each feature.
- X₁, X₂, …, Xₖ are the feature values.
The model learns the values of the coefficients (b₀, b₁, …, bₖ) that minimize the mean squared error between the predicted values and the actual target values. These coefficients are interpretable because each one tells us how a feature affects the prediction: bⱼ is the change in ŷ for a one-unit increase in Xⱼ, holding the other features fixed. A positive coefficient means that increasing the feature increases the prediction, while a negative coefficient means the opposite. Because each coefficient is expressed in the units of its feature, magnitudes are only directly comparable across features when the features are on a common scale.
For instance, in a Linear Regression model predicting house prices, a positive coefficient for the “number of bedrooms” feature suggests that more bedrooms typically lead to higher house prices.
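To make the coefficient reading concrete, here is a minimal sketch of a two-feature price model. The intercept, the coefficients, and the example house are made-up numbers chosen for illustration, not values fitted to real data.

```python
def predict(intercept, coefficients, features):
    """Weighted sum of the features plus the intercept: b0 + b1*X1 + ... + bk*Xk."""
    return intercept + sum(b * x for b, x in zip(coefficients, features))

# Hypothetical coefficients: price per extra bedroom, price per square foot.
intercept = 50_000.0
coefficients = [30_000.0, 120.0]

base_price = predict(intercept, coefficients, [3, 1500])       # 3 bedrooms
one_more_bedroom = predict(intercept, coefficients, [4, 1500])  # 4 bedrooms

# Adding one bedroom, with area held fixed, moves the prediction
# by exactly the bedroom coefficient.
delta = one_more_bedroom - base_price
print(base_price, one_more_bedroom, delta)
```

The difference between the two predictions equals the bedroom coefficient, which is exactly what "holding the other features fixed" means in practice.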
In summary, Linear Regression’s inherent explainability stems from its simple, linear model, where the coefficients directly indicate the feature’s impact on predictions. This makes it easy to understand and interpret the model’s decisions.
Feature Importance in Decision Trees
Decision Trees are versatile and interpretable models used for both classification and regression tasks. They make decisions by recursively splitting the data into subsets based on the features, ultimately reaching leaf nodes that provide predictions. Understanding how features contribute to these decisions is crucial for interpretability.
Gini Impurity for Splitting
In Decision Trees, feature importance is often calculated based on the reduction in impurity achieved by splitting on a particular feature. The most common measure of impurity is the Gini Impurity.
Mathematically, the Gini Impurity for a node is defined as:
$$G = 1 - \sum_{i=1}^{C} p_i^2$$
where:
- C is the number of classes (for classification tasks).
- pᵢ is the proportion of instances belonging to class i in the node.
To calculate feature importance, we take the reduction in Gini Impurity at every node where the feature is used for splitting, weight each reduction by the fraction of training samples that reach that node, and add the results together. Larger total reductions imply greater importance.
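The impurity calculation can be sketched in a few lines of plain Python. The function names and the tiny label lists below are illustrative assumptions; a real tree library computes the same quantity internally at every candidate split.

```python
from collections import Counter

def gini(labels):
    """Gini Impurity of a node: 1 minus the sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

def gini_decrease(parent, left, right):
    """Impurity of the parent minus the size-weighted impurity of its children."""
    n = len(parent)
    weighted_children = (len(left) / n) * gini(left) + (len(right) / n) * gini(right)
    return gini(parent) - weighted_children

parent = ["a", "a", "b", "b"]
perfect_split = gini_decrease(parent, ["a", "a"], ["b", "b"])
useless_split = gini_decrease(parent, ["a", "b"], ["a", "b"])
print(gini(parent), perfect_split, useless_split)
```

A split that separates the classes completely removes all of the parent's impurity, while a split that leaves both children as mixed as the parent removes none.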
Information Gain for Splitting
Another measure commonly used for feature importance in Decision Trees is Information Gain. Information Gain is based on the concept of entropy and measures the reduction in uncertainty achieved by splitting on a feature.
Mathematically, the Information Gain of a split is calculated as:
$$IG = H(\text{parent}) - \sum_{i=1}^{m} \frac{N_i}{N} \, H(\text{child}_i)$$
where:
- H(parent) is the entropy of the parent node.
- m is the number of child nodes after the split.
- N is the total number of instances in the parent node.
- Nᵢ is the number of instances in child node i.
- H(child_i) is the entropy of child node i.
Features that lead to the highest Information Gain are considered the most important because they contribute the most to reducing uncertainty in the decision.
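As a rough sketch of the same idea in code, the helper below computes entropy in bits and the Information Gain of a split into any number of children. The function names and sample labels are assumptions made for this example.

```python
import math
from collections import Counter

def entropy(labels):
    """Shannon entropy (in bits) of the class labels in a node."""
    n = len(labels)
    if n == 0:
        return 0.0
    proportions = [count / n for count in Counter(labels).values()]
    return sum(-p * math.log2(p) for p in proportions)

def information_gain(parent, children):
    """Parent entropy minus the size-weighted entropy of the child nodes."""
    n = len(parent)
    weighted_children = sum((len(child) / n) * entropy(child) for child in children)
    return entropy(parent) - weighted_children

parent = ["yes", "yes", "no", "no"]
gain_perfect = information_gain(parent, [["yes", "yes"], ["no", "no"]])
gain_useless = information_gain(parent, [["yes", "no"], ["yes", "no"]])
print(entropy(parent), gain_perfect, gain_useless)
```

With two balanced classes the parent holds one bit of uncertainty; the perfect split recovers all of it and the uninformative split recovers none.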
Feature Importance in Random Forest
Random Forest is an ensemble learning method that combines multiple Decision Trees to improve predictive accuracy and reduce overfitting. Because a forest cannot be read the way a single tree can, feature importance becomes the main tool for interpreting it. It can be calculated by applying the previously mentioned measures (Gini Impurity or Information Gain) within each tree and then aggregating the results across all trees.
Mean Decrease in Gini Impurity (MDI)
MDI calculates feature importance in a Random Forest by averaging the Gini Impurity reduction achieved by each feature across all trees. Features that consistently reduce impurity during splits across different trees are considered important.
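Mathematically, MDI for a feature is calculated as:
$$\text{MDI}(\text{feature}) = \frac{1}{N_{ts}} \sum_{i=1}^{N_{ts}} \Delta\text{Gini}_i(\text{feature})$$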
where:
- Nₜₛ is the number of trees in the Random Forest.
- ΔGiniᵢ(feature) is the Gini Impurity reduction for a feature in tree i.
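The averaging step is simple enough to sketch directly. The per-tree impurity reductions below are invented numbers, and the sketch assumes that a feature a tree never splits on contributes zero for that tree.

```python
def mean_decrease_impurity(per_tree_decreases):
    """Average each feature's Gini reduction over all trees in the forest."""
    n_trees = len(per_tree_decreases)
    features = set().union(*per_tree_decreases)
    return {
        feature: sum(tree.get(feature, 0.0) for tree in per_tree_decreases) / n_trees
        for feature in features
    }

# Hypothetical total Gini reduction per feature in each of three trees.
per_tree_decreases = [
    {"bedrooms": 0.30, "area": 0.10},
    {"bedrooms": 0.20, "area": 0.15},
    {"bedrooms": 0.25},  # this tree never split on "area"
]

mdi = mean_decrease_impurity(per_tree_decreases)
print(mdi)
```

Here "bedrooms" ends up with the higher score because it reduces impurity in every tree, not just in one.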
Mean Decrease in Accuracy (MDA)
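MDA is another way to measure feature importance in Random Forest, based on Permutation Feature Importance. It measures the decrease in accuracy (or increase in error) caused by randomly permuting the values of a feature, which breaks the relationship between that feature and the target, and averages this decrease across all trees.
Mathematically, MDA for a feature is calculated as:
$$\text{MDA}(\text{feature}) = \frac{1}{N_{ts}} \sum_{i=1}^{N_{ts}} \left( \text{Accuracy}_{\text{original}} - \text{Accuracy}_{\text{permuted}}^{(i)} \right)$$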
where:
- Nₜₛ is the number of trees in the Random Forest.
- Accuracy_original is the accuracy of the model on the original data.
- Accuracy_permuted is the accuracy of the model when the values of the feature are randomly permuted in tree i.
Features that, when permuted, cause a significant drop in accuracy are considered important, as they provide valuable information for the model’s predictions.
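The permutation idea can be sketched without any library. Note that this is a simplified, model-level version: it shuffles one feature for a single model and averages over repeated shuffles, rather than averaging over the trees of a forest. The toy model, the data, and the function names are all assumptions made for this example.

```python
import random

def accuracy(model, X, y):
    """Fraction of rows where the model's prediction matches the label."""
    return sum(model(row) == label for row, label in zip(X, y)) / len(y)

def permutation_importance(model, X, y, feature, repeats=10, seed=0):
    """Mean drop in accuracy after shuffling one feature column."""
    rng = random.Random(seed)
    baseline = accuracy(model, X, y)
    drops = []
    for _ in range(repeats):
        column = [row[feature] for row in X]
        rng.shuffle(column)
        # Rebuild each row with the shuffled value in place of the original.
        X_permuted = [
            row[:feature] + [value] + row[feature + 1:]
            for row, value in zip(X, column)
        ]
        drops.append(baseline - accuracy(model, X_permuted, y))
    return sum(drops) / repeats

def toy_model(row):
    """Hypothetical classifier that looks only at feature 0."""
    return 1 if row[0] >= 10 else 0

X = [[i, (i * 7) % 5] for i in range(20)]
y = [1 if i >= 10 else 0 for i in range(20)]

importance_used = permutation_importance(toy_model, X, y, feature=0)
importance_ignored = permutation_importance(toy_model, X, y, feature=1)
print(importance_used, importance_ignored)
```

Shuffling the feature the model relies on hurts accuracy, while shuffling the feature it ignores changes nothing, so the ignored feature gets an importance of zero.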
Feature Importance in XGBoost
XGBoost (Extreme Gradient Boosting) is a powerful ensemble learning algorithm known for its efficiency and accuracy. It also offers a built-in method for calculating feature importance.
Gain-based Feature Importance
XGBoost measures feature importance based on the improvement in model performance (typically measured by the reduction in the loss function) that a feature brings when used in splits.
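Mathematically, Gain-based Feature Importance for a feature is calculated as:
$$\text{Gain}(\text{feature}) = \frac{\sum_{\text{splits on feature}} \text{Gain}_{\text{split}}}{\text{number of times the feature is used in splits}}$$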
where:
- Gain_split is the improvement in the loss function achieved by a specific split.
- The number of times a feature is used in splits counts how often the feature is selected for splitting.
Features with higher Gain values are considered more important because they contribute more to reducing the overall loss.
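A small sketch of this averaging is shown below. It works on a hand-written list of (feature, gain) pairs rather than on a trained booster, so the split records and the function name are assumptions made for illustration. In the XGBoost Python package, the comparable built-in score is exposed through `Booster.get_score` with `importance_type="gain"`.

```python
from collections import defaultdict

def gain_importance(splits):
    """Average loss improvement per split, grouped by the feature used."""
    total_gain = defaultdict(float)
    split_count = defaultdict(int)
    for feature, gain in splits:
        total_gain[feature] += gain
        split_count[feature] += 1
    return {feature: total_gain[feature] / split_count[feature] for feature in total_gain}

# Hypothetical splits collected from all trees: (feature used, gain of that split).
splits = [
    ("area", 4.0),
    ("bedrooms", 1.0),
    ("area", 2.0),
    ("bedrooms", 3.0),
    ("age", 0.5),
]

importance = gain_importance(splits)
print(importance)
```

Because the score is an average per split, a feature that is used rarely but with large gains can outrank one that is used often with small gains.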
Advanced Methods of Feature Importance
While the methods discussed so far are widely used for calculating feature importance, there are more advanced techniques that offer deeper insights into model interpretability. Let’s explore LIME and Shapley plots.
Local Interpretable Model-Agnostic Explanations (LIME)
LIME is a model-agnostic technique that focuses on explaining individual predictions rather than the entire model. It works by creating a locally faithful, interpretable model around a specific prediction. The idea is to approximate the complex model’s behaviour with a simpler, interpretable one.
Mathematically, LIME fits a simple surrogate, typically a linear model or a shallow decision tree, to the complex model’s outputs on perturbed samples drawn around the instance being explained. The surrogate’s parameters are chosen to minimize the difference between the complex model’s predictions and the surrogate’s predictions on those nearby samples.
LIME formulates the problem as follows:
- Given a prediction function f(x) and an instance x₀, find a locally faithful model g(x) that approximates f(x) in the vicinity of x₀.
LIME optimises the following objective function:
$$g^{*} = \arg\min_{g} \; \mathbb{E}_{x' \sim \pi_{x_0}} \left[ L(f, g, x') \right] + \Omega(g)$$
where:
- g(x) is the interpretable model.
- πₓ₀ is a distribution that samples instances around x₀.
- L(f, g, x’) is a loss function measuring the difference between f(x’) and g(x’) for a sampled instance x’.
- Ω(g) is a regularization term to ensure model simplicity.
LIME allows you to understand why a model made a particular prediction for a given input by examining the coefficients of the interpretable model. It’s especially useful for black-box models where direct interpretation is challenging.
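The core of the idea fits in a short sketch: perturb the input around x₀, weight the perturbed points by how close they are to x₀, and fit a weighted straight line. This is not the `lime` package itself. It assumes a one-feature black box, a fixed symmetric grid of offsets in place of random sampling, an exponential proximity kernel, and no regularization term Ω(g).

```python
import math

def fit_local_line(f, x0, offsets, kernel_width):
    """Weighted least-squares line approximating f near x0."""
    xs = [x0 + d for d in offsets]
    ys = [f(x) for x in xs]
    # Closer perturbations get larger weights (exponential proximity kernel).
    ws = [math.exp(-((x - x0) ** 2) / kernel_width ** 2) for x in xs]

    w_sum = sum(ws)
    x_bar = sum(w * x for w, x in zip(ws, xs)) / w_sum
    y_bar = sum(w * y for w, y in zip(ws, ys)) / w_sum
    cov = sum(w * (x - x_bar) * (y - y_bar) for w, x, y in zip(ws, xs, ys))
    var = sum(w * (x - x_bar) ** 2 for w, x in zip(ws, xs))

    slope = cov / var
    intercept = y_bar - slope * x_bar
    return slope, intercept

def black_box(x):
    """Stand-in for a complex model: nonlinear in its single feature."""
    return x ** 2

offsets = [i / 10 for i in range(-10, 11)]  # symmetric perturbations around x0
slope, intercept = fit_local_line(black_box, x0=3.0, offsets=offsets, kernel_width=0.5)
print(slope, intercept)
```

The fitted slope is the local explanation: near x₀ = 3 the black box behaves like a line with slope close to 6, even though globally it is a curve.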
Shapley Values and Shapley Plots
Shapley values are a powerful concept borrowed from cooperative game theory. They assign a value to each feature, representing its contribution to a prediction. Shapley values provide a holistic view of feature importance across all possible combinations of features.
Mathematically, the Shapley value for a feature is calculated by considering its marginal contribution in all possible feature subsets and averaging them over all permutations of the features.
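The Shapley value for a feature i is defined as:
$$\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|! \, (|N| - |S| - 1)!}{|N|!} \left[ f(S \cup \{i\}) - f(S) \right]$$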
where:
- N is the set of all features.
- S is a subset of features, excluding feature i.
- f(S) represents the model’s prediction when considering the features in set S.
Shapley values provide a fair distribution of contributions among features by considering all possible permutations of feature combinations.
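For a handful of features the definition can be evaluated exactly by enumerating every subset, as in the sketch below. The value function `toy_value` is a made-up stand-in for f(S), with an interaction between two features and a third feature that contributes nothing. Libraries such as SHAP rely on approximations because this enumeration grows exponentially with the number of features.

```python
from itertools import combinations
from math import factorial

def shapley_values(features, value):
    """Exact Shapley values by enumerating all subsets S that exclude feature i."""
    n = len(features)
    phi = {}
    for i in features:
        others = [f for f in features if f != i]
        total = 0.0
        for size in range(n):
            # Weight |S|! (|N| - |S| - 1)! / |N|! for subsets of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                total += weight * (value(s | {i}) - value(s))
        phi[i] = total
    return phi

def toy_value(subset):
    """Hypothetical f(S): additive effects plus an interaction between a and b."""
    result = 0.0
    if "a" in subset:
        result += 10.0
    if "b" in subset:
        result += 20.0
    if "a" in subset and "b" in subset:
        result += 6.0
    return result

phi = shapley_values(["a", "b", "c"], toy_value)
print(phi)
```

The interaction bonus of 6 is split evenly between a and b, the unused feature c receives zero, and the values sum to f(N) − f(∅), which is the fairness property the text describes.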
Shapley Plots visualize Shapley values for individual predictions. They show how each feature contributed to the prediction compared to a reference (usually the mean prediction). Positive contributions increase the prediction, while negative contributions decrease it.
Shapley values offer a deep understanding of feature interactions and their impact on predictions. They are widely used for explaining complex models like neural networks.
Conclusion
Interpretability is crucial for building trust in machine learning models, especially as they become more complex. Inherent explainability in Linear Regression provides a solid foundation for understanding model decisions. Feature importance in Decision Trees, Random Forests, and XGBoost helps us analyze how features contribute to predictions.
Additionally, advanced techniques like LIME and Shapley values offer deeper insights into model interpretability. Understanding the mathematical foundations behind these methods empowers data scientists and machine learning practitioners to make informed decisions and gain trust in their models.
By combining these interpretability tools with your machine learning workflows, you can not only build accurate models but also provide transparent and understandable explanations for their predictions, making AI more accessible and trustworthy.
References
- LIME — https://youtu.be/sgJOOcvT04w?feature=shared&t=283
- Shap Values — https://www.youtube.com/watch?v=-taOhqkiuIo&t=131s
- The mathematics behind Shapley Values — https://www.youtube.com/watch?v=UJeu29wq7d0
If you have read this far, a big thank you for reading! I hope you find this article helpful. If you’d like, add me on LinkedIn!
Good luck this week, Pratyush