Thiago Carvalho

Summary

The provided content discusses techniques for enhancing the visualization of cluster analysis using Python's Matplotlib library, with a focus on scatter plots and annotations to improve interpretability of clustered data.

Abstract

The article "Visualizing Clusters with Python’s Matplotlib" delves into the art of improving cluster visualizations to better understand cluster analysis results. It begins by acknowledging the long history of clustering algorithms, such as k-means, and then proceeds to demonstrate how to use scatter plots to visualize clusters effectively. The author uses a dataset of Pokémon stats to illustrate the process of clustering data and visualizing it using Matplotlib. The article covers the use of color-coding to represent different clusters, the challenges of visualizing multiple dimensions, and the application of 3D scatter plots and interactivity to address these challenges. It also emphasizes the importance of annotations, such as titles, labels, legends, and reference lines, to provide context and aid interpretation. The author further explores advanced visualization techniques, including the use of lines to connect data points to their respective centroids and the implementation of convex hulls to outline clusters, thereby providing a clearer representation of the data's structure. The article concludes by acknowledging the complexity of cluster visualization and encourages experimentation with different visualization methods to effectively communicate findings.

Opinions

  • The author believes that visualizing clusters is as important as the clustering results themselves and that scatter plots are a fundamental tool for this purpose.
  • There is an opinion that bubble charts and 3D scatter plots, while useful, have limitations and can sometimes be misleading.
  • The article suggests that interactivity in visualizations can enhance understanding and should be utilized when possible.
  • The author expresses that adding centroids and reference lines to scatter plots can significantly improve the interpretability of cluster distributions.
  • It is conveyed that convex hulls can be a powerful way to highlight the spread and separation of clusters, even though they should not be taken as precise measurements.
  • The author advocates for a thoughtful and iterative approach to visualization, suggesting that the best method depends on the specific data and insights to be communicated.
  • There is a subtle endorsement of the Matplotlib library for its ability to create both simple and complex visualizations with the help of additional toolkits.

Visualizing Clusters with Python’s Matplotlib

How to improve the visualization of your cluster analysis

Clustering sure isn’t something new. MacQueen developed the k-means algorithm in 1967, and since then, many other implementations and algorithms have been developed to perform the task of grouping data.

Scatter Plots — Image by the author

This article will explore how to improve the visualization of our clusters with scatter plots.

Scatter Plots

Let’s start by loading and preparing our data. I’ll use a dataset of Pokemon stats.

import pandas as pd
df = pd.read_csv('data/Pokemon.csv')
# prepare data
types = df['Type 1'].isin(['Grass', 'Fire', 'Water'])
drop_cols = ['Type 1', 'Type 2', 'Generation', 'Legendary', '#']
df = df[types].drop(columns = drop_cols)
df.head()
Data frame — Image by the author

Since this article isn’t so much about clustering as it is about visualization, I’ll use a simple k-means for the following examples.

We’ll calculate three clusters, get their centroids, and set some colors.

from sklearn.cluster import KMeans
import numpy as np
# k means
kmeans = KMeans(n_clusters=3, random_state=0)
df['cluster'] = kmeans.fit_predict(df[['Attack', 'Defense']])
# get centroids
centroids = kmeans.cluster_centers_
cen_x = [i[0] for i in centroids] 
cen_y = [i[1] for i in centroids]
## add to df
df['cen_x'] = df.cluster.map({0:cen_x[0], 1:cen_x[1], 2:cen_x[2]})
df['cen_y'] = df.cluster.map({0:cen_y[0], 1:cen_y[1], 2:cen_y[2]})
# define and map colors
colors = ['#DF2020', '#81DF20', '#2095DF']
df['c'] = df.cluster.map({0:colors[0], 1:colors[1], 2:colors[2]})

Then we can pass the fields we used to create the cluster to Matplotlib’s scatter and use the ‘c’ column we created to paint the points in our chart according to their cluster.

import matplotlib.pyplot as plt
plt.scatter(df.Attack, df.Defense, c=df.c, alpha = 0.6, s=10)
Scatter Plots— Image by the author

Cool. That’s the basic visualization of a clustered dataset, and even without much information, we can already start to make sense of our clusters and how they are divided.
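
If you don't have the Pokemon file at hand, here is a minimal sketch of the same idea on made-up data. The three blobs are an assumption standing in for Attack and Defense, and instead of mapping dictionaries onto a data frame it indexes a color array with the cluster labels, which should give the same result with less code.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans

# synthetic stand-in for the Attack/Defense columns (assumption, not the real data)
rng = np.random.default_rng(0)
points = np.vstack([
    rng.normal(loc=(50, 50), scale=8, size=(40, 2)),
    rng.normal(loc=(90, 70), scale=8, size=(40, 2)),
    rng.normal(loc=(120, 120), scale=8, size=(40, 2)),
])

# k means
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0)
labels = kmeans.fit_predict(points)

# index a color array with the labels: one color per point
colors = np.array(['#DF2020', '#81DF20', '#2095DF'])
point_colors = colors[labels]
# the same trick gives every point the coordinates of its own centroid
point_centroids = kmeans.cluster_centers_[labels]

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(points[:, 0], points[:, 1], c=point_colors.tolist(), alpha=0.6, s=10)
```

With a data frame, the equivalent would be assigning `colors[df.cluster]` to a column, assuming the cluster labels are the integers 0, 1, 2 as k-means returns them.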

Multiple Dimensions

We often use multiple variables to cluster our data, and scatter plots can only display two variables. There are several options for visualizing more than two variables, but they all have disadvantages that should be considered.

We could use the markers’ size and make it a bubble chart, but that’s not an optimal solution. We couldn’t compare this third variable with the others since they would have different encodings.

For example, we can tell if a record has a higher Attack or Defense by looking at the chart we made earlier. But if we added Speed as the size, we couldn’t compare it with the other two variables.

plt.scatter(df.Attack, df.Defense, c=df.c, s=df.Speed, alpha = 0.6)
Bubble chart — Image by the author

3D plots can also encode a third variable, but they can get confusing and sometimes even misleading, because the impression they give depends on the angle we look at the chart from.

3D Scatter Plot — Image by the author

Still, 3D scatter plots can be helpful, especially if they’re not static.

Depending on your environment, it’s easy to add some interactivity with Matplotlib.

Some IDEs will have this by default; other environments will require extensions and a magic command such as “%matplotlib widget” on JupyterLab (which needs the ipympl package) or “%matplotlib notebook” on classic Jupyter notebooks.

By changing the angle from which we look at the chart, we can examine it more carefully and avoid misinterpreting the data.

from mpl_toolkits.mplot3d import Axes3D
%matplotlib widget
colors = ['#DF2020', '#81DF20', '#2095DF']
kmeans = KMeans(n_clusters=3, random_state=0)
df['cluster'] = kmeans.fit_predict(df[['Attack', 'Defense', 'HP']])
df['c'] = df.cluster.map({0:colors[0], 1:colors[1], 2:colors[2]})
fig = plt.figure(figsize=(26,6))
ax = fig.add_subplot(131, projection='3d')
ax.scatter(df.Attack, df.Defense, df.HP, c=df.c, s=15)
ax.set_xlabel('Attack')
ax.set_ylabel('Defense')
ax.set_zlabel('HP')
plt.show()
3D Scatter Plots— Image by the author
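
When interactivity isn't available (a static report, for example), one workaround is to draw the same 3D scatter from a few fixed angles. The sketch below does that with `view_init`; the random data and the three angles are assumptions picked just for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# random stand-in for Attack, Defense and HP (assumption, not the real data)
rng = np.random.default_rng(1)
xyz = rng.normal(loc=70, scale=20, size=(100, 3))

# (elevation, azimuth) pairs in degrees, chosen arbitrarily
views = [(20, -60), (20, 30), (80, -90)]

fig = plt.figure(figsize=(15, 5))
for pos, (elev, azim) in enumerate(views, start=1):
    ax = fig.add_subplot(1, len(views), pos, projection='3d')
    ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], s=15)
    # rotate the camera without touching the data
    ax.view_init(elev=elev, azim=azim)
    ax.set_title(f'elev={elev}, azim={azim}')
```

Placing the views side by side won't replace rotating the chart yourself, but it makes it harder for a single angle to mislead the viewer.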

Overall, they still are a pretty limited solution.

I think the best approach is to use multiple scatter plots, either in a matrix format or by changing between variables. You can also consider using a dimensionality reduction method such as PCA to consolidate your variables into a smaller number of components.
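
Both ideas can be sketched in a few lines. The example below draws one scatter plot per pair of variables and then reduces the same data to two components with scikit-learn's PCA. The four columns are random numbers that only borrow the stat names, so take it as an outline rather than an analysis.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from itertools import combinations
from sklearn.decomposition import PCA

# random stand-in for four stats (assumption, not the real data)
rng = np.random.default_rng(0)
names = ['Attack', 'Defense', 'HP', 'Speed']
data = rng.normal(size=(150, 4)) * [20, 15, 10, 25] + 70

# one scatter plot for every pair of variables: 4 variables give 6 pairs
pairs = list(combinations(range(len(names)), 2))
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
for ax, (i, j) in zip(axes.ravel(), pairs):
    ax.scatter(data[:, i], data[:, j], alpha=0.6, s=10)
    ax.set_xlabel(names[i])
    ax.set_ylabel(names[j])

# alternatively, squeeze the four variables into two components
pca = PCA(n_components=2)
reduced = pca.fit_transform(data)
kept = pca.explained_variance_ratio_.sum()  # share of variance the 2D view keeps
```

It's worth checking `kept` before trusting the reduced chart: if the two components retain little of the variance, the picture may hide more than it shows.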

Annotations

Now, let’s begin improving on our visualization.

If data visualization is storytelling, annotations are the equivalent of a narrator in our story. They should help the viewer understand and focus on what’s important while not taking too much space on the plot.

We’ll add the basics: a title, axis labels, and a legend.

### BUILD A TWO DIMENSIONS CLUSTER AGAIN ###
# k means
kmeans = KMeans(n_clusters=3, random_state=0)
df['cluster'] = kmeans.fit_predict(df[['Attack', 'Defense']])
# get centroids
centroids = kmeans.cluster_centers_
cen_x = [i[0] for i in centroids] 
cen_y = [i[1] for i in centroids]
## add to df
df['cen_x'] = df.cluster.map({0:cen_x[0], 1:cen_x[1], 2:cen_x[2]})
df['cen_y'] = df.cluster.map({0:cen_y[0], 1:cen_y[1], 2:cen_y[2]})
# define and map colors
colors = ['#DF2020', '#81DF20', '#2095DF']
df['c'] = df.cluster.map({0:colors[0], 1:colors[1], 2:colors[2]})
#####PLOT#####
from matplotlib.lines import Line2D
fig, ax = plt.subplots(1, figsize=(8,8))
# plot data
plt.scatter(df.Attack, df.Defense, c=df.c, alpha = 0.6, s=10)
# create a list of legend elements
## markers / records
legend_elements = [Line2D([0], [0], marker='o', color='w', label='Cluster {}'.format(i+1), 
               markerfacecolor=mcolor, markersize=5) for i, mcolor in enumerate(colors)]
# plot legend
plt.legend(handles=legend_elements, loc='upper right')
# title and labels
plt.title('Pokemon Stats\n', loc='left', fontsize=22)
plt.xlabel('Attack')
plt.ylabel('Defense')
Scatter Plot — Image by the author

Cool, now we can clearly get what this chart is about.

We can also give the viewer some reference points. Displaying the centroids and drawing reference lines to averages or a percentile can help explain our cluster.

from matplotlib.lines import Line2D
fig, ax = plt.subplots(1, figsize=(8,8))
# plot data
plt.scatter(df.Attack, df.Defense, c=df.c, alpha = 0.6, s=10)
# plot centroids
plt.scatter(cen_x, cen_y, marker='^', c=colors, s=70)
# plot Attack mean
plt.plot([df.Attack.mean()]*2, [0,200], color='black', lw=0.5, linestyle='--')
plt.xlim(0,200)
# plot Defense mean
plt.plot([0,200], [df.Defense.mean()]*2, color='black', lw=0.5, linestyle='--')
plt.ylim(0,200)
# create a list of legend elements
## average line
legend_elements = [Line2D([0], [0], color='black', lw=0.5, linestyle='--', label='Average')]
## markers / records
cluster_leg = [Line2D([0], [0], marker='o', color='w', label='Cluster {}'.format(i+1), 
               markerfacecolor=mcolor, markersize=5) for i, mcolor in enumerate(colors)]
## centroids
cent_leg = [Line2D([0], [0], marker='^', color='w', label='Centroid - C{}'.format(i+1), 
            markerfacecolor=mcolor, markersize=10) for i, mcolor in enumerate(colors)]
# add all elements to the same list
legend_elements.extend(cluster_leg)
legend_elements.extend(cent_leg)
# plot legend
plt.legend(handles=legend_elements, loc='upper right', ncol=2)
# title and labels
plt.title('Pokemon Stats\n', loc='left', fontsize=22)
plt.xlabel('Attack')
plt.ylabel('Defense')
Scatter Plot — Image by the author

It’s way easier to tell how the clusters are divided now.

The red cluster groups the highest Attack and Defense values, while the blue has the lowest, and the green group is generally closer to the average.
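
The reference lines above are drawn between fixed endpoints (0 and 200), which works because we also fix the axis limits. As an alternative sketch, Matplotlib's `axvline` and `axhline` span the whole axes whatever the limits are. The data here is random and only stands in for the two stats.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# random stand-in for Attack and Defense (assumption, not the real data)
rng = np.random.default_rng(2)
attack = rng.normal(75, 25, size=100)
defense = rng.normal(70, 25, size=100)

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(attack, defense, alpha=0.6, s=10)
# vertical line at the Attack mean, horizontal line at the Defense mean
vline = ax.axvline(attack.mean(), color='black', lw=0.5, linestyle='--', label='Average')
hline = ax.axhline(defense.mean(), color='black', lw=0.5, linestyle='--')
# only the labeled line shows up in the legend
ax.legend(loc='upper right')
```

The same calls accept any other reference value, such as a median or a percentile computed with `np.percentile`.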

Lines

Illustrating how our clustering works can be as important as its results. In k-means, since we’re working with distances, connecting the points to their respective centroids can help us visualize what the algorithm is actually doing.

fig, ax = plt.subplots(1, figsize=(8,8))
# plot data
plt.scatter(df.Attack, df.Defense, c=df.c, alpha = 0.6, s=10)
# plot centroids
plt.scatter(cen_x, cen_y, marker='^', c=colors, s=70)
# plot lines
for idx, val in df.iterrows():
    x = [val.Attack, val.cen_x,]
    y = [val.Defense, val.cen_y]
    plt.plot(x, y, c=val.c, alpha=0.2)
# legend: one entry per cluster and one per centroid
# (Line2D was imported from matplotlib.lines in the previous example)
legend_elements = [Line2D([0], [0], marker='o', color='w', label='Cluster {}'.format(i+1), 
                   markerfacecolor=mcolor, markersize=5) for i, mcolor in enumerate(colors)]
legend_elements.extend([Line2D([0], [0], marker='^', color='w', label='Centroid - C{}'.format(i+1), 
            markerfacecolor=mcolor, markersize=10) for i, mcolor in enumerate(colors)])
plt.legend(handles=legend_elements, loc='upper right', ncol=2)
# x and y limits
plt.xlim(0,200)
plt.ylim(0,200)
# title and labels
plt.title('Pokemon Stats\n', loc='left', fontsize=22)
plt.xlabel('Attack')
plt.ylabel('Defense')
Connected Scatter Plot — Image by the author

Now the relationship between the clusters and the centroids is totally explicit, and it’s easier to explain how the algorithm works.
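
Looping over every row with `iterrows` is fine for a few hundred records, but it gets slow as the data grows. A possible alternative, sketched below, is to build all the segments at once and draw them with a single `LineCollection`. The points and the three centers are made up, and the nearest-center assignment only mimics the assignment step of k-means.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection

# made-up points and centers (assumption, not the fitted k-means result)
rng = np.random.default_rng(3)
points = rng.uniform(20, 180, size=(90, 2))
centroids = np.array([[50.0, 50.0], [100.0, 100.0], [150.0, 150.0]])
palette = ['#DF2020', '#81DF20', '#2095DF']

# assign each point to its nearest center
dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
labels = dists.argmin(axis=1)

# one segment per point: from the point to its own center, shape (n, 2, 2)
segments = np.stack([points, centroids[labels]], axis=1)
line_colors = [palette[k] for k in labels]

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(points[:, 0], points[:, 1], c=line_colors, alpha=0.6, s=10)
ax.scatter(centroids[:, 0], centroids[:, 1], marker='^', c=palette, s=70)
lines = LineCollection(segments, colors=line_colors, alpha=0.2)
ax.add_collection(lines)
```

The chart should look the same as the looped version; the difference is that Matplotlib handles one artist instead of one per record.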

We can also see how spread out the values in each cluster are.

For example, the red values appear to be farther away from their centroid than blue values. If the groups’ variance is something important to our analysis, a chart like this could be effective.
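
If the spread matters, we can also put a number next to what the lines suggest. The sketch below computes the average distance from each record to its cluster's center on a tiny made-up data frame. It uses the group mean as the center, which is what a converged k-means centroid amounts to, so treat it as an approximation of the fitted centroids.

```python
import numpy as np
import pandas as pd

# tiny made-up data frame (assumption, not the real data)
df = pd.DataFrame({
    'Attack':  [10, 14, 50, 50],
    'Defense': [10, 10, 50, 60],
    'cluster': [0, 0, 1, 1],
})

# center of each cluster, repeated on every row of that cluster
centers = df.groupby('cluster')[['Attack', 'Defense']].transform('mean')

# Euclidean distance from each record to its own center
df['dist'] = np.hypot(df.Attack - centers.Attack, df.Defense - centers.Defense)

# average distance per cluster: larger means more spread out
spread = df.groupby('cluster')['dist'].mean()
print(spread)
```

In this toy example, cluster 0 has an average distance of 2 and cluster 1 of 5, so cluster 1 is the more spread out of the two.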

We should also note that the separation between green and blue wasn’t so evident in the previous visualizations.

Scatter Plot — Image by the author

Even though they have different colors and are connected to different places, those records circled in black are still more similar between themselves than most values in their own cluster.

This visualization makes it harder to perceive that and may give the impression that values from distinct clusters are totally different.

Convex Hull

Another option to help us visualize our clusters’ size or spread is to draw a shape around it or a shadow. Doing so manually would take forever and for sure wouldn’t be worth the effort.

Luckily, there are ways to automate that.

The convex hull is the smallest convex polygon that encloses all of our data points, with its corners on the outermost points, and there are ways to find it systematically. That is to say, we can use SciPy to get the contour of our dataset.

from scipy.spatial import ConvexHull
import numpy as np
fig, ax = plt.subplots(1, figsize=(8,8))
# plot data
plt.scatter(df.Attack, df.Defense, c=df.c, alpha = 0.6, s=10)
# plot centers
plt.scatter(cen_x, cen_y, marker='^', c=colors, s=70)
# draw enclosure
for i in df.cluster.unique():
    points = df[df.cluster == i][['Attack', 'Defense']].values
    # get convex hull
    hull = ConvexHull(points)
    # get x and y coordinates
    # repeat last point to close the polygon
    x_hull = np.append(points[hull.vertices,0],
                       points[hull.vertices,0][0])
    y_hull = np.append(points[hull.vertices,1],
                       points[hull.vertices,1][0])
    # plot shape
    plt.fill(x_hull, y_hull, alpha=0.3, c=colors[i])
    
plt.xlim(0,200)
plt.ylim(0,200)
Highlighted scatter Plot — Image by the author
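
To see what `ConvexHull` returns, it helps to try it on a shape we already know. In this small sketch, a unit square with one extra point in the middle (made-up coordinates), the hull should keep the four corners and ignore the inner point.

```python
import numpy as np
from scipy.spatial import ConvexHull

# four corners of a unit square plus one point inside it
points = np.array([[0, 0], [1, 0], [1, 1], [0, 1], [0.5, 0.5]], dtype=float)

hull = ConvexHull(points)

# hull.vertices holds the indices of the points on the outline;
# the inner point (index 4) is not among them
outline = points[hull.vertices]

# repeat the first vertex to close the polygon before plotting it
closed = np.vstack([outline, outline[:1]])

# for 2D input, `volume` is the enclosed area and `area` is the perimeter
enclosed_area = hull.volume
perimeter = hull.area
```

The naming of the last two attributes is easy to mix up: they are defined for any number of dimensions, so in 2D the "volume" is what we would normally call the area.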

Great. We can even interpolate the lines of our polygon to make a smoother shape around our data.

from scipy import interpolate
fig, ax = plt.subplots(1, figsize=(8,8))
plt.scatter(df.Attack, df.Defense, c=df.c, alpha = 0.6, s=10)
plt.scatter(cen_x, cen_y, marker='^', c=colors, s=70)
    
for i in df.cluster.unique():
    # get the convex hull
    points = df[df.cluster == i][['Attack', 'Defense']].values
    hull = ConvexHull(points)
    x_hull = np.append(points[hull.vertices,0],
                       points[hull.vertices,0][0])
    y_hull = np.append(points[hull.vertices,1],
                       points[hull.vertices,1][0])
    
    # interpolate
    dist = np.sqrt((x_hull[:-1] - x_hull[1:])**2 + (y_hull[:-1] - y_hull[1:])**2)
    dist_along = np.concatenate(([0], dist.cumsum()))
    spline, u = interpolate.splprep([x_hull, y_hull], 
                                    u=dist_along, s=0, per=1)
    interp_d = np.linspace(dist_along[0], dist_along[-1], 50)
    interp_x, interp_y = interpolate.splev(interp_d, spline)
    # plot shape
    plt.fill(interp_x, interp_y, '--', c=colors[i], alpha=0.2)
    
plt.xlim(0,200)
plt.ylim(0,200)

*The interpolation method was based on replies from this thread.
**I’ve added the argument per=1 to the splprep function as pointed out by Dragan Vidovic in the comments.

Highlighted Scatter Plot — Image by the author

We shouldn’t take those contours so seriously since they are not an actual measurement. But they still do a great job highlighting the clusters so no viewer can miss them.

Overall, there’s no simple solution to visualize clusters. Each case is unique, and we should experiment a lot before deciding what to display to our audience.

It’s also important to mention that the examples I used were clean. I used a simple dataset and only two variables for clustering. In real cases, it won’t always look like this. Often, drawing the connections to the centroid can make our chart too polluted and almost unreadable.

It would be best to start with something simple: visualize every combination of variables and identify the more meaningful ones or the ones where you can naturally demonstrate your insights. Then you can experiment with other visualizations and techniques for highlighting what you found.

I hope you learned something and enjoyed my article. Thanks!
