Ben Hui

Summary

The web content provides a guide on how to visualize Decision Trees and Random Forests using the pybaobabdt and pygraphviz packages in Python, with examples and code snippets for enhancing the interpretability of these models.

Abstract

The article introduces two Python packages, pybaobabdt and pygraphviz, which are essential for visualizing Decision Trees and Random Forests. It explains that while scikit-learn offers built-in functions to display a tree's decision path, these third-party packages provide more intuitive and visually appealing representations. The author demonstrates the use of pybaobabdt to visualize a Decision Tree, including how to manage the tree's size, depth, and use colormaps to highlight specific classes. Additionally, the article covers the visualization of Random Forests by iterating over each tree within the ensemble and plotting them individually, adjusting parameters such as depth for clarity. The guide aims to assist readers in creating comprehensive reports by making complex tree structures more accessible and understandable.

Opinions

  • The author suggests that visualizing decision paths is more intuitive and helpful for creating fancy reports.
  • Visualizing the entire Decision Tree might result in an image that is too large to be useful, suggesting the need to limit the depth of the tree for better visibility.
  • The use of colormaps is recommended to highlight specific classes within the Decision Tree, enhancing the visual distinction between different classes.
  • The process of visualizing Random Forests is described as similar to that of a single Decision Tree, with the added step of looping through each tree in the forest.
  • The article implies that the ability to visualize these complex models is crucial for understanding and interpreting their predictions.

How to Visualize Decision Trees and Random Forests

scikit-learn provides built-in functions to display the decision path of a tree model (https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html).
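
As a point of comparison, here is a minimal sketch of what scikit-learn offers on its own. It assumes the bundled iris dataset and a shallow tree purely for illustration (neither comes from this article), and uses `export_text` and `decision_path` to print the learned rules and the nodes that one sample passes through.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Illustrative data only: the iris dataset ships with scikit-learn.
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

# Text rendering of the learned split rules.
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)

# Node ids visited by the first sample on its way from the root to a leaf.
node_indicator = clf.decision_path(iris.data[:1])
path_nodes = node_indicator.indices.tolist()
print("Nodes on the decision path:", path_nodes)
```

The output is accurate but text-only, which is the gap the plotting packages in this article aim to fill.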

Drawing the tree as a picture makes that path much easier to follow, and it is handy when you need to present the model in a report.

In this article, I introduce two packages that make this possible: pybaobabdt, which draws the trees, and pygraphviz, which is installed alongside it.

Installation:

pip install pybaobabdt
pip install pygraphviz

(1) Decision Tree Visualization

Dataset: https://github.com/lpfgarcia/ucipp/blob/master/uci/wine-quality-red.arff

import pybaobabdt
import pandas as pd
from scipy.io import arff
from sklearn.tree import DecisionTreeClassifier 

data = arff.loadarff('wine-quality-red.arff') # Import dataset
df = pd.DataFrame(data[0])

y = list(df['Class'])
features = list(df.columns)
features.remove('Class')
X = df.loc[:, features]

clf = DecisionTreeClassifier().fit(X, y)
ax = pybaobabdt.drawTree(clf, size=10, dpi=300, features=features)  # Visualize the tree
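
One detail of the loading step is worth knowing: `scipy.io.arff.loadarff` returns nominal attributes as byte strings, so a column such as `Class` holds values like `b'5'` rather than `'5'`. The sketch below shows one way to decode them. The helper name `decode_bytes_columns` and the sample values are hypothetical, chosen only for illustration.

```python
import pandas as pd


def decode_bytes_columns(df: pd.DataFrame, encoding: str = "utf-8") -> pd.DataFrame:
    """Return a copy of df with bytes values in object columns decoded to str."""
    out = df.copy()
    for col in out.columns:
        if out[col].dtype == object:
            out[col] = out[col].map(
                lambda v: v.decode(encoding) if isinstance(v, bytes) else v
            )
    return out


# Stand-in for pd.DataFrame(arff.loadarff(...)[0]); the values are made up.
raw = pd.DataFrame({"alcohol": [9.4, 9.8, 10.5], "Class": [b"5", b"5", b"6"]})
clean = decode_bytes_columns(raw)
print(clean["Class"].tolist())
```

Decoding is optional for training, but it keeps class labels readable wherever they are printed.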

The full tree is too large to read, right? We can limit the depth that is drawn to make it more legible:

ax = pybaobabdt.drawTree(
    clf,
    size=10,
    dpi=100,
    maxdepth=6,  # Maximum depth of the tree to draw
    features=features)
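
To pick a sensible value for `maxdepth`, it can help to check how deep the fitted tree actually is. The sketch below uses scikit-learn's `get_depth()` and `get_n_leaves()` on a tree trained on synthetic data. The data and the cap of 6 are assumptions for illustration, so substitute your own `X` and `y`.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data; replace with the X and y from your own dataset.
X_demo, y_demo = make_classification(
    n_samples=300, n_features=8, n_informative=5, random_state=0
)
tree = DecisionTreeClassifier(random_state=0).fit(X_demo, y_demo)

full_depth = tree.get_depth()   # Depth of the fully grown tree
n_leaves = tree.get_n_leaves()  # Number of leaf nodes
print(f"depth={full_depth}, leaves={n_leaves}")

# Cap the drawn depth at 6, or at the tree's own depth if it is shallower.
draw_depth = min(6, full_depth)
print("maxdepth to draw:", draw_depth)
```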

Using a colormap, we can also highlight specific classes:

from matplotlib.colors import ListedColormap

ax = pybaobabdt.drawTree(
    clf,
    size=10,
    dpi=600,
    maxdepth=6,
    colormap=ListedColormap(["#01a2d9", "gray", "#d5695d",
                             "gray"]),  # Highlight Class 3 and 5
    features=features)
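
Writing the colour list by hand gets tedious as the number of classes grows. The sketch below builds it programmatically. It assumes that colours are matched to classes by position, in sorted class order, which is consistent with the example above but is not confirmed here. The helper `highlight_colors` is hypothetical.

```python
from itertools import cycle


def highlight_colors(classes, highlighted, palette=("#01a2d9", "#d5695d"), muted="gray"):
    """One colour per class: palette colours for highlighted classes, `muted` otherwise."""
    accents = cycle(palette)
    return [next(accents) if c in highlighted else muted for c in classes]


# Hypothetical class labels; with a fitted model you would pass clf.classes_.
colors = highlight_colors(classes=[3, 4, 5, 6], highlighted={3, 5})
print(colors)
# The list can then be wrapped as ListedColormap(colors) and passed as colormap=...
```

For these inputs the helper reproduces the four-colour list used in the call above.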

(2) Random Forest Visualization

This works much like visualizing a single decision tree: we loop over the trees in the forest and draw each one.

Dataset: https://github.com/renatopp/arff-datasets/blob/master/classification/vehicle.arff

import pybaobabdt
import pandas as pd
from scipy.io import arff
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier

data = arff.loadarff('vehicle.arff')  # Import dataset

df = pd.DataFrame(data[0])
y = list(df['Class'])
features = list(df.columns)
features.remove('Class')
X = df.loc[:, features]

clf = RandomForestClassifier(n_estimators=20, n_jobs=-1, random_state=0)
clf.fit(X, y)

size = (15, 15)
plt.rcParams['figure.figsize'] = size
fig = plt.figure(figsize=size, dpi=300)

for idx, tree in enumerate(clf.estimators_):
    ax1 = fig.add_subplot(5, 4, idx + 1)
    pybaobabdt.drawTree(tree,
                        model=clf,
                        size=15,
                        dpi=300,
                        maxdepth=8,  # Set the depth for each tree
                        features=features,
                        ax=ax1)
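
The 5 x 4 subplot grid in the loop is hard-coded for 20 trees. If `n_estimators` changes, the grid can be derived instead. This is a small sketch, and the helper name `grid_shape` and its default of four columns are assumptions.

```python
import math


def grid_shape(n_plots: int, ncols: int = 4) -> tuple:
    """Return (nrows, ncols) with enough cells to hold n_plots subplots."""
    nrows = math.ceil(n_plots / ncols)
    return nrows, ncols


rows, cols = grid_shape(20)
print(rows, cols)
```

In the loop, `fig.add_subplot(rows, cols, idx + 1)` would then replace the fixed `5, 4`, with `rows, cols = grid_shape(len(clf.estimators_))`.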

Thank you for reading.
