
Graph Neural Networks: Merging Deep Learning With Graphs (Part I)
When It Comes to Node Classification
Recently, Graph Neural Networks (GNNs) have received a lot of attention. From marketing to social science to biology, they have been widely promoted as the new way of learning “smartly” from data. It’s more than a trend, though: many research papers have shown that they can lead to more accurate and robust models.
What explains this? Most likely, it is their ability to combine graphical representation learning (which is used today for an ever wider variety of use cases) with the predictive power of deep learning models.
Objective
This article is the first part of a three-part series that aims to provide a comprehensive overview of the most common applications of GNN models to real-world problems.
While this first part focuses on node classification, the other two tackle link prediction and graph classification, respectively.
After reading this article, you will understand:
- What is graphical representation learning all about?
- What are the main mechanisms behind GNN models?
- How can they be applied to real-world classification problems?
The experimentations described in the article were carried out using the libraries PyTorch Geometric, NetworkX, igraph, and Plotly.
You can find the code here on GitHub.
1. Graphical Representation Learning: What Is It All About?
1.1. Motivation
The first question you are likely to ask yourself is: why and when should you consider graphical representation learning to solve your use case?
Graphs provide a simple yet powerful tool to describe complex systems. In simple terms, a graph represents a problem as a set of objects (nodes) along with a set of interactions (edges) between pairs of these objects.
Many applications naturally represent their data in the form of graphs. Here are some examples:
- Recommender Systems: In e-commerce, you can represent interactions between users and products graphically and use this knowledge to make more relevant personalized recommendations.
- Social Networks: By using a graph to describe the relationship between users, you can train a model to detect fake accounts.
- Transport: Transport networks can be represented as graphs and used as inputs by models to accurately forecast traffic speed, volume, or density.
- Chemistry: Molecules are usually modeled as graphs. By using this representation, you can predict their bioactivity for drug discovery purposes.
1.2. Benefits
The second question you might consider is: What do graph-based models bring to the table compared to “traditional” approaches?
The key advantage of a graphical representation of a problem lies in its ability to capture both information about the nodes and the relationships between them.
To put it more concretely, let’s consider the case where you would like to classify products sold in a store. You would probably gather information about the products (description, price, brand, etc.) and use it as input to train a model. But what if this information is non-existent or too poor to build a robust model?
In this context, you can leverage the graphical aspect of the problem. Each product can be represented as a node and each pair of products frequently bought together can be linked. A graph-based model is then likely to perform better than a “traditional” machine learning algorithm, as it would learn not only from information about products but also from the relationships between them. In fact, instead of considering each product independently, it would leverage this additional information to detect valuable patterns.
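To make the idea concrete, here is a minimal sketch of how such a co-purchase graph could be built from purchase baskets. The basket contents, the `build_copurchase_graph` helper, and the `min_count` threshold for "frequently bought together" are all hypothetical choices made for illustration; they are not taken from the experiments in this article.

```python
from collections import Counter
from itertools import combinations

# Hypothetical purchase baskets; product names are illustrative only.
baskets = [
    ["pasta", "tomato sauce", "parmesan"],
    ["pasta", "tomato sauce"],
    ["pasta", "parmesan"],
    ["shampoo", "conditioner"],
    ["shampoo", "conditioner", "pasta"],
]


def build_copurchase_graph(baskets, min_count=2):
    """Link two products when they share at least `min_count` baskets."""
    pair_counts = Counter()
    for basket in baskets:
        # sorted(set(...)) gives every unordered pair a single canonical key.
        for a, b in combinations(sorted(set(basket)), 2):
            pair_counts[(a, b)] += 1

    # Every product becomes a node, even if it ends up with no edge.
    graph = {product: set() for basket in baskets for product in basket}
    for (a, b), count in pair_counts.items():
        if count >= min_count:
            graph[a].add(b)
            graph[b].add(a)
    return graph


graph = build_copurchase_graph(baskets)
```

With these baskets, "pasta" ends up linked to "tomato sauce" and "parmesan", while pairs that were bought together only once stay unconnected. The resulting adjacency structure is the relational signal a graph-based model could exploit in addition to product attributes.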
1.3. Prediction Tasks on Graphs
Before going further, it is important to distinguish between the three main types of tasks for which graph-based models can be used:
Node-level tasks: Node classification and regression
- Goal: Predict a label, type, category, or attribute of a node.
- Example: Given a large social network with millions of users, detect fake accounts.
Edge-level tasks: Link prediction
- Goal: Given a set of nodes and an incomplete set of edges between these nodes, infer the missing edges.
- Example: Predict biological interactions between proteins.
Graph-level tasks: Graph classification, regression, and clustering
- Goal: Carry out a classification, regression, or clustering task over entire graphs.
- Example: Given a graph representing the structure of a molecule, predict the molecule’s toxicity.
In the rest of the article, I will focus on node classification.
2. Node Classification With GNN: What Performance Should You Expect?
2.1. Description of the Use Case
Imagine that you run a large online knowledge-sharing platform such as Wikipedia. Every day, thousands of scientific articles are published.
To help your readers easily navigate the platform and find the content they are interested in, you need to make sure that each article is classified into the right category quickly after its publication.
In this context, the problem can be modeled as a graph where each node represents an article and has as an attribute an embedding of the content. Two articles are linked if one of them contains a link to the other. The goal is to predict the category of new articles.
This is thus a typical node classification task!

The dataset I will be using for experiments is Wiki-CS from the paper Wiki-CS: A Wikipedia-Based Benchmark for Graph Neural Networks. It consists of nodes corresponding to computer science articles, with edges based on hyperlinks and 10 classes that represent different branches of the field. You can download it using PyTorch Geometric datasets here.


Challenges
The nature of graph data poses a real challenge to existing deep learning models. Why?
- Non-Euclidean data. The usual deep learning toolbox does not apply directly to graph data. For instance, convolutional neural networks (CNNs) need grid-structured inputs such as images, while recurrent neural networks (RNNs) require sequences such as text.
- Variable shapes. Graphs are by nature irregular: They have different numbers of nodes, and nodes may have different numbers of neighbors. This makes operations that are easily computed in the other domains more difficult to apply in the graph domain.
- Permutation invariance. Operations applied to graph data must be permutation-invariant, i.e., independent of the order of neighbor nodes, as there is no specific way to order them.
- Internal dependence. One of the core assumptions of existing ML models is that instances are independent of each other. However, for graph data, this assumption is no longer valid as each instance (node) is related to others by links of various types, such as citations, friendships, and interactions.
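The permutation-invariance requirement can be illustrated with a small sketch. It compares a mean aggregator, which ignores the order of neighbors, with a concatenation, which does not. Both helper functions and the toy feature vectors are assumptions introduced for illustration only.

```python
def mean_aggregate(neighbor_features):
    """Element-wise mean of the neighbors' feature vectors."""
    n = len(neighbor_features)
    return [sum(values) / n for values in zip(*neighbor_features)]


def concat_aggregate(neighbor_features):
    """Concatenation: the result depends on how the neighbors are ordered."""
    return [value for features in neighbor_features for value in features]


# Three neighbors with 2-dimensional features, listed in two different orders.
neighbors = [[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]]
shuffled = [neighbors[2], neighbors[0], neighbors[1]]

mean_is_invariant = mean_aggregate(neighbors) == mean_aggregate(shuffled)
concat_is_invariant = concat_aggregate(neighbors) == concat_aggregate(shuffled)
```

Here `mean_is_invariant` is `True` and `concat_is_invariant` is `False`, which is why aggregators such as sums and means are natural building blocks for graph models.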
These challenges motivate the need to introduce a new kind of deep learning architecture to apply deep learning methods over graphs.
It is also worth mentioning that the ‘traditional’ way for ML models to learn from graph data is to add features that characterize each instance within the graph, such as a node’s degree or its centrality. You can read my article on fraud detection in which I applied such an approach. However, the feasibility and success of such approaches depend heavily on the use case (e.g., in some cases, the class of a node may have no correlation with how central it is in the graph).
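As a rough sketch of this feature-engineering approach, the snippet below derives two structural features per node: its degree and its degree centrality, taken here as the degree divided by the number of other nodes. The `structural_features` helper and the toy graph are hypothetical; in practice, a library such as NetworkX provides these measures.

```python
def structural_features(graph):
    """Return {node: (degree, degree_centrality)} for an undirected graph."""
    n = len(graph)
    features = {}
    for node, neighbors in graph.items():
        degree = len(neighbors)
        # Degree centrality: fraction of the other nodes this node touches.
        centrality = degree / (n - 1) if n > 1 else 0.0
        features[node] = (degree, centrality)
    return features


# Toy undirected graph stored as an adjacency dictionary.
toy_graph = {
    "a": {"b", "c", "d"},
    "b": {"a"},
    "c": {"a", "d"},
    "d": {"a", "c"},
}
features = structural_features(toy_graph)
```

These values could then be appended to the usual feature columns of a tabular classifier, with the limitation noted above that they only help when the label actually correlates with graph position.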
2.2. Graph Convolutional Networks (GCN)
What are the key concepts of the model?
The goal of a Graph Neural Network (GNN) model is to use all the information about the graph, namely nodes’ features and the connection between them, to learn a new representation for each of the nodes called node embeddings.
These node embeddings are low-dimensional vectors that summarize nodes’ positions in the graph and the structure of their local graph neighborhood. The embeddings can then be directly used to classify nodes.
To do so, GNNs rely on a message-passing framework. At each iteration, every node aggregates information from its local neighborhood.
- So after the first iteration (k = 1), every node embedding contains information from its 1-hop neighborhood, i.e., its immediate graph neighbors.
- After the second iteration (k = 2), every node embedding contains information from its 2-hop neighborhood, i.e. nodes that can be reached by a path of length 2 in the graph.
- etc.
As these iterations progress, each node embedding contains more and more information from further reaches of the graph.
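To see how this receptive field grows, the following sketch computes the set of nodes that can influence a node's embedding after k iterations, i.e., the nodes within at most k hops. The `k_hop_neighborhood` helper and the path graph are illustrative assumptions, not part of the original experiments.

```python
from collections import deque


def k_hop_neighborhood(graph, source, k):
    """Nodes reachable from `source` in at most `k` hops (source excluded)."""
    distances = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if distances[node] == k:
            continue  # do not expand beyond k hops
        for neighbor in graph[node]:
            if neighbor not in distances:
                distances[neighbor] = distances[node] + 1
                queue.append(neighbor)
    return {node for node in distances if node != source}


# Path graph: a - b - c - d - e
path_graph = {
    "a": ["b"],
    "b": ["a", "c"],
    "c": ["b", "d"],
    "d": ["c", "e"],
    "e": ["d"],
}
one_hop = k_hop_neighborhood(path_graph, "a", 1)
two_hop = k_hop_neighborhood(path_graph, "a", 2)
```

For node "a", one iteration only reaches "b", while two iterations reach both "b" and "c". Each additional message-passing layer extends the reach by one hop.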

What kind of “information” does a node embedding actually encode?
- Structural information about the graph (e.g., degrees of all the nodes in their k-hop neighborhood).
- Feature-based information about the nodes’ k-hop neighborhood.
One common issue with GNNs is over-smoothing: After multiple iterations of message passing, the representations for all the nodes in the graph can become very similar to one another. Natural ways to avoid this issue include maintaining a low number of layers or using skip connections.
What is the message passing framework about?
What is a message-passing layer composed of? How are embeddings updated at each iteration? During each message-passing iteration in a GNN, a hidden embedding h_u corresponding to each node u is updated according to information aggregated from u’s graph neighborhood N(u). The figure below illustrates the first iteration.

The message-passing update is expressed as follows:
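In the notation used here, a common way to write this update is the following, where k indexes the iteration. This is the standard generic message-passing form and is assumed to match the formulation the article relies on; the initial embedding h_u^{(0)} is taken to be the input features x_u of node u.

```latex
h_u^{(k+1)} = \mathrm{UPDATE}^{(k)}\left( h_u^{(k)},\ \mathrm{AGGREGATE}^{(k)}\left( \left\{ h_v^{(k)} : v \in N(u) \right\} \right) \right),
\qquad h_u^{(0)} = x_u
```

The AGGREGATE function takes a set as input, which is what makes the update independent of the order of the neighbors.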

The UPDATE and AGGREGATE functions vary depending on the model. For instance, GCN models rely on a symmetric-normalized aggregation as well as a self-loop update approach as shown in the previous figure.
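As a hedged illustration, one common way to write the GCN update combines the self-loop and the symmetric normalization in a single sum. Here, σ is an element-wise non-linearity such as ReLU, W^{(k)} is a trainable weight matrix, and |N(u)| is the number of neighbors of u. These symbols are assumptions introduced for this sketch, and some formulations count the self-loop in the degrees as well.

```latex
h_u^{(k+1)} = \sigma\left( W^{(k)} \sum_{v \in N(u) \cup \{u\}} \frac{h_v^{(k)}}{\sqrt{|N(u)|\,|N(v)|}} \right)
```

The sum over N(u) ∪ {u} is the self-loop: the node's own embedding is aggregated together with those of its neighbors instead of being handled by a separate UPDATE step.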
What is the intuition behind this type of normalization? It gives more weight to neighbors that are not very connected (low degree). This is relevant for use cases such as page classification, because highly connected nodes tend to discuss very broad and general topics. Thus, they provide little information that is truly useful for classifying the nodes to which they are linked.
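A small numeric sketch can make this intuition tangible. It computes the symmetric-normalization weight 1 / sqrt(deg(u) * deg(v)) for each neighbor v of a node u. The `symmetric_norm_weights` helper and the toy graph, in which an article links to one niche page and one hub page, are hypothetical.

```python
import math


def symmetric_norm_weights(graph, node):
    """Weight 1 / sqrt(deg(node) * deg(neighbor)) for each neighbor of `node`."""
    deg = {u: len(neighbors) for u, neighbors in graph.items()}
    return {v: 1.0 / math.sqrt(deg[node] * deg[v]) for v in graph[node]}


# "article" links to a niche page (degree 1) and to a hub page (degree 4).
toy = {
    "article": ["niche", "hub"],
    "niche": ["article"],
    "hub": ["article", "p1", "p2", "p3"],
    "p1": ["hub"],
    "p2": ["hub"],
    "p3": ["hub"],
}
weights = symmetric_norm_weights(toy, "article")
```

In this toy graph, the niche page receives a weight of about 0.71 and the hub page about 0.35, so the message from the low-degree neighbor counts twice as much as the one from the highly connected neighbor.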
The GCN approach has proved to be one of the most popular and effective baselines for GNN architectures. How does it relate to the concept of convolution? Just as CNNs aggregate feature information from spatially-defined patches in an image, GNNs aggregate information based on local graph neighborhoods. The figure below illustrates the analogy.

How does the model perform compared to a “traditional” classifier?
To compare the performance of this approach to traditional machine learning models, I implemented a GCN-based model and compared it to a random forest classifier.
As shown in the code below, the GCN is composed of two graph convolutional layers with a non-linear transformation between them and a final softmax layer for multi-class classification.
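For readers who want to see the architecture spelled out, here is a dependency-free sketch of the forward pass of such a two-layer GCN. It is not the PyTorch Geometric implementation used for the experiments (where the `GCNConv` layer would play the role of each graph convolution). It uses a dense normalized adjacency matrix with self-loops counted in the degrees, and the weights are random rather than trained. All function names, sizes, and the toy graph are assumptions.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def normalized_adjacency(edges, num_nodes):
    """Dense D^{-1/2} (A + I) D^{-1/2} for an undirected edge list."""
    adj = [[0.0] * num_nodes for _ in range(num_nodes)]
    for u, v in edges:
        adj[u][v] = adj[v][u] = 1.0
    for i in range(num_nodes):
        adj[i][i] = 1.0  # self-loop
    deg = [sum(row) for row in adj]  # degrees include the self-loop
    return [[adj[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(num_nodes)]
            for i in range(num_nodes)]


def relu(matrix):
    return [[max(0.0, value) for value in row] for row in matrix]


def softmax(matrix):
    out = []
    for row in matrix:
        shift = max(row)  # subtract the max for numerical stability
        exps = [math.exp(value - shift) for value in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def gcn_forward(a_hat, x, w1, w2):
    """Two graph convolutions with a ReLU in between, then a row-wise softmax."""
    hidden = relu(matmul(matmul(a_hat, x), w1))
    logits = matmul(matmul(a_hat, hidden), w2)
    return softmax(logits)


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
edges = [(0, 1), (1, 2), (2, 3)]   # toy path graph with 4 nodes
x = random_matrix(4, 5, rng)       # 5 input features per node
w1 = random_matrix(5, 8, rng)      # first layer: 5 -> 8 hidden units
w2 = random_matrix(8, 3, rng)      # second layer: 8 -> 3 classes
a_hat = normalized_adjacency(edges, 4)
probs = gcn_forward(a_hat, x, w1, w2)
```

Each row of `probs` is a probability distribution over the classes for one node. Training would then consist of adjusting `w1` and `w2` to minimize a cross-entropy loss on the labeled nodes, which an autodiff framework handles in practice.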










