```mermaid
flowchart LR
  subgraph Layer ℓ — update node v
    A[neighbor h_u] --> M[AGGREGATE<br/>sum / mean / max / attention]
    B[neighbor h_w] --> M
    C[neighbor h_x] --> M
    M --> U[UPDATE<br/>combine with h_v, apply W + σ]
    V[own h_v] --> U
    U --> H[new h_v]
  end
```
Chapter 40 — 🔗 Graph Machine Learning
Most of the models in this encyclopedia assume your data arrives in a tidy shape: a grid of pixels, a sequence of tokens, a table of rows. But a huge amount of the world’s interesting data is relational — molecules are atoms bonded together, social networks are people who follow people, knowledge bases are facts linking entities, road systems are intersections joined by streets. These are graphs, and their structure is the signal. This chapter is about learning directly on that structure: how to embed nodes, how to pass messages along edges, and how to predict missing links.
🧭 In context: representation learning on relational data · node/edge/graph prediction for molecules, social networks, knowledge bases, recommenders · learn from connectivity, not just features — by aggregating information along edges.
💡 Remember this: A graph neural network learns by repeatedly passing each node a permutation-invariant summary of its neighbors and updating itself — so a node’s vector ends up encoding its local connectivity, not just its own features.
40.1 Why graphs need their own models
A graph \(G = (V, E)\) is a set of nodes \(V\) and edges \(E\) connecting them. We usually summarize the edges with an adjacency matrix \(A \in \{0,1\}^{n \times n}\), where \(A_{ij}=1\) if there is an edge from \(i\) to \(j\), and attach a feature matrix \(X \in \mathbb{R}^{n \times d}\) giving each node a \(d\)-dimensional feature vector.
The trouble is that the grid and sequence models from earlier chapters (CNNs, RNNs/Transformers) lean on a regularity graphs do not have:
- No fixed neighbor count. An interior pixel always has 8 neighbors; a token always has exactly one predecessor (except the first). A graph node may have 1 neighbor or 10,000, so a fixed-size convolution kernel has no regular grid to slide over.
- No canonical ordering. Pixels have rows and columns; tokens have positions. Graph nodes have no inherent order. If you flatten \(A\) into a vector and feed it to an MLP, then relabeling the same graph’s nodes produces a totally different input, yet it is the same graph. We want permutation invariance (for graph-level outputs) and permutation equivariance (for node-level outputs): relabel the nodes, and node-level outputs are relabeled the same way while graph-level outputs do not change at all.
Think of it like a seating chart at a dinner party. A photo of the table cares where each guest sits (grid: position matters). But the question “who is sitting next to whom?” does not change if you walk around and renumber the chairs — the relationships are the same. Graph models are built to answer the second kind of question and to ignore the renumbering.
Tasks on graphs come in three flavors, and the right architecture depends on which you face:
| Level | Predict | Examples |
|---|---|---|
| Node | a label/value per node | fraud account vs. legit, document topic, traffic speed at a sensor |
| Edge (link) | whether an edge exists / its type | friend recommendation, knowledge-graph completion, drug–target binding |
| Graph | one label/value per whole graph | molecule toxicity, solubility, is-this-program-malware |
The mental shift: in a CNN you ask “what is around this pixel in space?” On a graph you ask “what is around this node in the connectivity?” The neighborhood is defined by edges, not by coordinates — and that is the only thing that changes, but it changes everything about the architecture.
A first formula: permutation equivariance, made precise
We keep saying “relabel the nodes and the answer should relabel too.” Here is that promise in symbols. A permutation is just a relabeling, captured by a permutation matrix \(P\) (one \(1\) per row and column, zeros elsewhere). A node-level GNN function \(f\) is permutation equivariant when
\[f(PX,\ PAP^\top) = P\,f(X, A).\]
In words: if you shuffle the rows of the features and shuffle the adjacency the same way before running the model, you get the same answer you’d have gotten by running the model first and shuffling its output afterward — the model never depends on how you numbered the nodes. Also written: for a graph-level readout \(g\) that returns one vector for the whole graph, the matching property is plain invariance, \(g(PX, PAP^\top) = g(X, A)\) — the permutation disappears entirely.
This one equation is the design constraint behind every architecture in this chapter. Everything else is engineering inside that box.
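Both properties are easy to check numerically. The sketch below uses a toy layer chosen only for illustration, \(f(X, A) = \mathrm{ReLU}(A X W)\) (sum over neighbors, then a shared weight matrix), and its sum readout \(g\); for a random permutation matrix \(P\) it tests the two equations above.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 5, 3, 4
X = rng.normal(size=(n, d))                      # node features
A = (rng.random((n, n)) < 0.4).astype(float)
A = np.triu(A, 1)
A = A + A.T                                      # random undirected adjacency, no self-loops
W = rng.normal(size=(d, k))                      # weights shared by every node

def f(X, A):
    """Toy node-level layer: sum neighbor features, apply W, then ReLU."""
    return np.maximum(A @ X @ W, 0.0)

def g(X, A):
    """Toy graph-level readout: sum the node outputs into one vector."""
    return f(X, A).sum(axis=0)

P = np.eye(n)[rng.permutation(n)]                # random permutation matrix

equivariant = np.allclose(f(P @ X, P @ A @ P.T), P @ f(X, A))
invariant = np.allclose(g(P @ X, P @ A @ P.T), g(X, A))
print(equivariant, invariant)                    # True True
```

The check works because \(P^\top P = I\), so \(PAP^\top PX = PAX\), and the elementwise ReLU commutes with reordering rows.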
40.2 Node embeddings from random walks: DeepWalk and node2vec
Before end-to-end graph networks took over, a beautifully simple idea dominated: treat random walks on the graph like sentences, and apply word2vec. If you stand on a node and wander to random neighbors, the sequence of nodes you visit is like a sentence whose “words” are nodes. Nodes that co-occur in many walks are “used in similar contexts,” so — exactly as in the skip-gram model — they should get similar embeddings.
DeepWalk does precisely this. For each node, sample a few fixed-length uniform random walks; collect all walks as a corpus; run skip-gram to learn a vector \(z_v \in \mathbb{R}^k\) per node. Skip-gram maximizes the probability of seeing a node’s walk-neighbors given the node:
\[\max_{Z}\ \sum_{v \in V}\ \sum_{u \in N_{\text{walk}}(v)} \log \frac{\exp(z_u^\top z_v)}{\sum_{w \in V}\exp(z_w^\top z_v)}\]
In words: find a vector for every node so that nodes which keep showing up near each other on random walks have vectors pointing in similar directions, and everyone else’s point away. Also written: since \(\frac{\exp(z_u^\top z_v)}{\sum_w \exp(z_w^\top z_v)} = \operatorname{softmax}_u(Z z_v)\), the inner term is just \(\log P(u \mid v)\) under a softmax over all nodes — maximize the log-likelihood of each observed walk-neighbor.
The softmax denominator over all nodes is too expensive, so in practice it is approximated with negative sampling (Chapter 20): push \(z_u^\top z_v\) up for real walk-neighbors, push it down for a handful of random “negative” nodes.
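Here is a minimal NumPy sketch of that negative-sampling objective for a single (node, walk-neighbor) pair: it rewards a high \(\sigma(z_u^\top z_v)\) for the real neighbor and a high \(\sigma(-z_n^\top z_v)\) for each negative \(n\). It assumes negatives drawn uniformly at random (word2vec itself samples from a smoothed unigram distribution) and, for brevity, updates only the center vector \(z_v\).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sgns_loss_and_grad(z_v, z_u, z_neg):
    """Skip-gram negative-sampling loss for one (node v, walk-neighbor u) pair.

    z_v: (k,) center embedding; z_u: (k,) real walk-neighbor; z_neg: (m, k) negatives.
    Returns the loss and its gradient with respect to z_v.
    """
    s_pos = z_u @ z_v                       # score of the real walk-neighbor
    s_neg = z_neg @ z_v                     # scores of the m negative nodes
    loss = -np.log(sigmoid(s_pos)) - np.log(sigmoid(-s_neg)).sum()
    grad = -(1.0 - sigmoid(s_pos)) * z_u + (sigmoid(s_neg)[:, None] * z_neg).sum(axis=0)
    return loss, grad

rng = np.random.default_rng(0)
k, m = 8, 5
z_v = rng.normal(scale=0.5, size=k)
z_u = rng.normal(scale=0.5, size=k)
z_neg = rng.normal(scale=0.5, size=(m, k))  # negatives drawn uniformly here (an assumption)

loss0, grad = sgns_loss_and_grad(z_v, z_u, z_neg)
z_v = z_v - 0.1 * grad                      # one SGD step on the center vector
loss1, _ = sgns_loss_and_grad(z_v, z_u, z_neg)
print(f"loss before {loss0:.3f}, after {loss1:.3f}")
```

Only \(1 + m\) dot products are needed per pair instead of one per node in the graph, which is the whole point of the approximation.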
The doodle makes the “walk = sentence” idea concrete: a token slides from node to node, and the path it traces is exactly the sequence handed to skip-gram.
node2vec adds one clever knob. A pure uniform walk cannot tell you what kind of similarity you care about. node2vec biases the walk with two parameters:
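- Return parameter \(p\): controls how likely the walk is to step straight back to the node it just came from. That move gets unnormalized weight \(1/p\), so a small \(p\) makes backtracking likely.
- In–out parameter \(q\): controls whether the walk stays local (BFS-like, capturing structural roles) or ventures outward (DFS-like, capturing homophily / communities). A step to a node that is not adjacent to the previous node gets weight \(1/q\); a step to a node adjacent to it gets weight \(1\).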
With \(q > 1\) the walk stays close (breadth-first flavor); with \(q < 1\) it explores outward (depth-first flavor). This single biasing lets the same algorithm produce embeddings tuned for either community structure or role similarity.
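The second-order bias can be written out directly. This is a minimal sketch, assuming an unweighted adjacency list (the same small triangle-with-tail graph used in the DeepWalk snippet that follows) and the node2vec weights described above: \(1/p\) to return to the previous node, \(1\) to a neighbor of the previous node, \(1/q\) to a node farther away. The original implementation precomputes alias tables so each step is cheap; this sketch simply recomputes the weights at every step.

```python
import numpy as np

graph = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4], 4: [3]}

def transition_probs(prev, cur, p, q):
    """Biased probabilities for the step out of `cur`, having arrived from `prev`."""
    weights = []
    for x in graph[cur]:
        if x == prev:              # distance 0 from prev: step back
            weights.append(1.0 / p)
        elif x in graph[prev]:     # distance 1 from prev: stay close
            weights.append(1.0)
        else:                      # distance 2 from prev: move outward
            weights.append(1.0 / q)
    w = np.array(weights)
    return w / w.sum()

def node2vec_walk(start, length, p, q, rng):
    walk = [start, int(rng.choice(graph[start]))]   # the first step is uniform
    while len(walk) < length:
        prev, cur = walk[-2], walk[-1]
        walk.append(int(rng.choice(graph[cur], p=transition_probs(prev, cur, p, q))))
    return walk

probs = transition_probs(prev=0, cur=2, p=1.0, q=0.5)
print(probs)   # neighbors [0, 1, 3] of node 2 -> 0.25, 0.25, 0.5: q < 1 favors moving out to 3
print(node2vec_walk(0, 6, p=1.0, q=0.5, rng=np.random.default_rng(0)))
```

With \(p = q = 1\) every weight is \(1\) and the walk reduces to DeepWalk’s uniform walk.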
```python
import numpy as np

# Tiny DeepWalk-style walk generator on an adjacency list
graph = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4], 4: [3]}  # a triangle with a tail (lollipop-shaped)

def random_walk(start, length, rng):
    walk = [start]
    for _ in range(length - 1):
        nbrs = graph[walk[-1]]
        walk.append(int(rng.choice(nbrs)))  # uniform step to a neighbor
    return walk

rng = np.random.default_rng(0)
corpus = [random_walk(v, 5, rng) for v in graph for _ in range(3)]  # 3 walks/node
for w in corpus[:4]:
    print(w)
# These walks become "sentences" fed to skip-gram -> one vector per node.
```

In real projects you do not hand-roll skip-gram on top of these walks; you let a library run the whole pipeline. With node2vec (which wraps gensim’s Word2Vec):
```python
# pip install node2vec networkx
import networkx as nx
from node2vec import Node2Vec

G = nx.from_dict_of_lists({0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4], 4: [3]})
# walk_length / num_walks shape the "corpus"; p, q bias BFS vs DFS exploration
n2v = Node2Vec(G, dimensions=16, walk_length=10, num_walks=20, p=1, q=0.5, workers=1)
model = n2v.fit(window=5, min_count=1)  # this is gensim skip-gram under the hood
print(model.wv["2"][:5])                # first 5 of node 2's 16 dimensions (node ids become strings)
print(model.wv.most_similar("2"))       # nearest nodes in embedding space
```

The output is a lookup table of vectors you can drop into any downstream classifier. The catch: these embeddings are transductive. They are learned for the exact nodes present at training time. Add a new node and you must retrain — there is no function that maps “features + neighborhood” to an embedding. That limitation is exactly what message-passing networks fix.
40.3 The message-passing framework
Here is the unifying idea behind nearly every modern graph neural network (GNN). To compute a node’s new representation, gather the representations of its neighbors, combine them into one message, and use that message to update the node. Repeat this for \(L\) rounds (layers), and each node’s vector comes to summarize its \(L\)-hop neighborhood.
An analogy: gossip on a network. Each round, every person tells their immediate friends everything they currently know, and everyone updates their own picture by blending in what they just heard. After one round you know about your friends; after two rounds you know about your friends’ friends; after \(L\) rounds, news has rippled out \(L\) handshakes. A GNN layer is exactly one round of this gossip, with learned rules for how to listen and how to update.
Formally, layer \(\ell\) updates every node \(v\) with hidden state \(h_v\) as:
\[m_v^{(\ell)} = \text{AGGREGATE}^{(\ell)}\Big(\{\, h_u^{(\ell-1)} : u \in N(v) \,\}\Big), \qquad h_v^{(\ell)} = \text{UPDATE}^{(\ell)}\Big(h_v^{(\ell-1)},\ m_v^{(\ell)}\Big)\]
In words: first squash all of a node’s neighbors into one summary message, then mix that message with the node’s own current state to get its next state — and do this for every node, every layer. Also written: as a two-step composition per node, \(h_v^{(\ell)} = \phi\big(h_v^{(\ell-1)},\ \bigoplus_{u\in N(v)} \psi(h_v^{(\ell-1)}, h_u^{(\ell-1)})\big)\), where \(\bigoplus\) is a permutation-invariant pooling operator and \(\phi, \psi\) are learnable functions.
The AGGREGATE must be permutation-invariant (sum, mean, max, attention-weighted sum) — neighbors have no order, so the function reading them must not care about order. That single constraint is what makes the whole family respect graph symmetry. Different choices of AGGREGATE and UPDATE give you the specific architectures in the next sections.
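A minimal NumPy sketch of one generic layer makes the two steps concrete. It assumes an adjacency-list input, a choice of sum, mean, or max for AGGREGATE, and an illustrative UPDATE of the form \(\mathrm{ReLU}(W_{\text{self}} h_v + W_{\text{nbr}} m_v)\) (written in row-vector form in the code); that UPDATE is one common choice, not the definition. With identity weights, sum aggregation, and one-hot features, column \(j\) of a node’s state is nonzero exactly when information from node \(j\) has reached it, so the loop also shows the receptive field growing one hop per layer.

```python
import numpy as np

def message_passing_layer(H, nbrs, W_self, W_nbr, aggregate="mean"):
    """One generic message-passing layer (a sketch, not a specific published layer).

    H: (n, d) node states, one row per node; nbrs: dict node -> list of neighbor ids
    (every node is assumed to have at least one neighbor).
    AGGREGATE: permutation-invariant pool over N(v). UPDATE: ReLU(h_v W_self + m_v W_nbr).
    """
    pool = {"sum": np.sum, "mean": np.mean, "max": np.max}[aggregate]
    H_new = np.zeros((H.shape[0], W_self.shape[1]))
    for v in range(H.shape[0]):
        m_v = pool(H[nbrs[v]], axis=0)                           # AGGREGATE
        H_new[v] = np.maximum(H[v] @ W_self + m_v @ W_nbr, 0.0)  # UPDATE
    return H_new

graph = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4], 4: [3]}
I = np.eye(5)
H = I.copy()        # one-hot features: column j means "information from node j"
reach = []
for layer in range(1, 4):
    H = message_passing_layer(H, graph, I, I, aggregate="sum")
    reach.append(np.nonzero(H[0])[0].tolist())
    print(f"after layer {layer}, node 0 has heard from nodes {reach[-1]}")
# layer 1 -> [0, 1, 2]; layer 2 -> [0, 1, 2, 3]; layer 3 -> [0, 1, 2, 3, 4]
```

Swapping `aggregate="sum"` for `"mean"` or `"max"` changes the numbers but not which nodes are reached: the receptive field depends only on depth and connectivity.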
The animation shows one round of “gossip”: messages flow inward along the edges to the center node, which then lights up as it updates.
After \(L\) layers you have a vector per node that has “seen” everything within \(L\) hops. Use it directly for node tasks; combine two endpoints’ vectors for edge tasks; pool all node vectors for graph tasks (§40.7).
The receptive field grows one hop per layer. This is the graph analog of a CNN’s growing receptive field, and it is the single most useful picture to keep in mind. The doodle below shows information reaching node \(v\) from farther and farther out as layers stack.
40.4 Graph Convolutional Networks (GCN)
The GCN (Kipf & Welling, 2017) is the simplest useful message-passing layer: aggregate by a degree-normalized sum over the node’s neighbors and the node itself, then apply a linear map and a nonlinearity. In matrix form the whole layer is one line:
\[H^{(\ell+1)} = \sigma\!\left(\tilde{D}^{-1/2}\,\tilde{A}\,\tilde{D}^{-1/2}\,H^{(\ell)}\,W^{(\ell)}\right)\]
In words: smooth each node’s features over itself and its neighbors (the \(\tilde D^{-1/2}\tilde A\tilde D^{-1/2}\) part), then linearly transform and apply a nonlinearity — a blur-then-learn step on the graph. Also written: per node, \(h_v^{(\ell+1)} = \sigma\!\Big(W^{(\ell)\top}\sum_{u\in N(v)\cup\{v\}} \frac{1}{\sqrt{\tilde d_v \tilde d_u}}\, h_u^{(\ell)}\Big)\) — the matrix line is just this sum written for all nodes at once.
where \(\tilde{A} = A + I\) adds self-loops (so a node keeps its own information), \(\tilde{D}\) is the diagonal degree matrix of \(\tilde{A}\), \(W^{(\ell)}\) is a learnable weight matrix, and \(\sigma\) is e.g. ReLU. The symmetric normalization \(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}\) down-weights edges to high-degree nodes, so a hub neighbor does not drown out everyone else.
Let us do one layer by hand on a 3-node path graph \(0 - 1 - 2\), with scalar features \(x = [1, 0, 0]\), \(W = 1\), and \(\sigma =\) identity:
\[A=\begin{bmatrix}0&1&0\\1&0&1\\0&1&0\end{bmatrix},\quad \tilde{A}=A+I=\begin{bmatrix}1&1&0\\1&1&1\\0&1&1\end{bmatrix},\quad \tilde{D}=\text{diag}(2,3,2)\]
Node 1 (the middle) aggregates itself and both neighbors. Its normalized output is
\[h_1 = \frac{\tilde{A}_{10}\,x_0}{\sqrt{\tilde d_1 \tilde d_0}} + \frac{\tilde{A}_{11}\,x_1}{\sqrt{\tilde d_1 \tilde d_1}} + \frac{\tilde{A}_{12}\,x_2}{\sqrt{\tilde d_1 \tilde d_2}} = \frac{1\cdot 1}{\sqrt{3\cdot2}} + \frac{1\cdot 0}{3} + \frac{1\cdot 0}{\sqrt{3\cdot2}} = \frac{1}{\sqrt6}\approx 0.408\]
```python
import numpy as np

A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], float)
At = A + np.eye(3)                   # add self-loops
d = At.sum(1)                        # tilde-D diagonal: [2, 3, 2]
Dinv = np.diag(d ** -0.5)
S = Dinv @ At @ Dinv                 # symmetric-normalized propagation
x = np.array([[1.0], [0.0], [0.0]])  # feature lives on node 0
print((S @ x).ravel().round(3))      # approximately [0.5, 0.408, 0.0]: signal spread to node 1
```

Reading the output [0.5, 0.408, 0.0]: node 0 kept some of its own signal (\(\tfrac{1}{\sqrt{2\cdot2}}=0.5\) from its self-loop), node 1 picked up \(0.408\) from node 0 across the edge, and node 2 is still \(0\) because it is two hops away — the signal has not reached it yet. The signal that started only on node 0 has flowed to its neighbor. Stack a second layer and it reaches node 2 — each layer extends the reach by one hop. GCN is fast and a strong baseline, but its aggregation weights are fixed by degree: it cannot learn that some neighbors matter more than others. That is what GAT addresses (§40.6).
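The second layer can be checked the same way. This sketch reuses the 3-node path graph with \(W = 1\) and \(\sigma\) = identity, as in the hand calculation, so two layers are simply two multiplications by the propagation matrix.

```python
import numpy as np

A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], float)
At = A + np.eye(3)                      # self-loops
Dinv = np.diag(At.sum(1) ** -0.5)
S = Dinv @ At @ Dinv                    # same propagation matrix as above
x = np.array([1.0, 0.0, 0.0])           # signal starts on node 0

h1 = S @ x                              # one layer
h2 = S @ h1                             # a second layer
print(h1.round(3), h2.round(3))         # node 2 goes from 0 to 1/6 (about 0.167)
```

Node 2’s new value is \(\tfrac{1}{\sqrt{2\cdot 3}} \cdot \tfrac{1}{\sqrt 6} = \tfrac16\): the signal travelled \(0 \to 1\) in layer one and \(1 \to 2\) in layer two.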
In practice you would reach for PyTorch Geometric (PyG), one of the most widely used GNN libraries, rather than rebuilding propagation by hand. A complete two-layer GCN classifier is just:
```python
# pip install torch torch_geometric
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_dim, hid, n_classes):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid)      # one message-passing layer
        self.conv2 = GCNConv(hid, n_classes)   # ...exactly the matrix line above
    def forward(self, x, edge_index):          # edge_index: [2, E] list of edges
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)       # logits per node

# Toy inputs so the snippet runs; on Cora, x would be [2708, 1433] with 7 classes.
x = torch.randn(4, 1433)                        # x: [N, in_dim] node features
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]]) # undirected path 0-1-2-3 (COO format)
model = GCN(in_dim=1433, hid=64, n_classes=7)   # sized for the Cora citation graph
logits = model(x, edge_index)                   # one forward pass over the whole graph
```

40.5 GraphSAGE: sampling neighbors for scale
GCN as written multiplies by the full adjacency matrix — it needs the whole graph in memory at once. On a social network with a billion edges that is hopeless. GraphSAGE (Hamilton et al., 2017) makes two changes that turn GNNs into something you can train at industrial scale.
First, sample a fixed number of neighbors rather than using all of them. For each node, draw (say) 25 neighbors at the first hop and 10 of each of those at the second hop. Now the computation for a node depends on a bounded-size neighborhood, regardless of whether the node has 5 or 5 million neighbors — so you can train in mini-batches and never materialize the full graph.
Second, learn an UPDATE that concatenates self and aggregated-neighbor, rather than blending them:
\[h_v^{(\ell)} = \sigma\!\left(W^{(\ell)} \cdot \text{CONCAT}\Big(h_v^{(\ell-1)},\ \text{AGG}\big(\{h_u^{(\ell-1)}: u \in \mathcal{S}(v)\}\big)\Big)\right)\]
In words: summarize a sampled handful of neighbors, glue that summary onto the node’s own vector side-by-side, and pass the pair through one learned layer — “here’s me, and here’s my surroundings, kept separate.” Also written: with \(a = \text{AGG}(\{h_u^{(\ell-1)}\})\), this is \(h_v^{(\ell)} = \sigma\!\big(W_{\text{self}} h_v^{(\ell-1)} + W_{\text{nbr}}\, a\big)\) after splitting \(W\) into the two halves that multiply the concatenated blocks.
where \(\mathcal{S}(v)\) is the sampled neighbor set and AGG is mean, max-pool, or an LSTM-pool. Keeping the node’s own vector in a separate slot (concatenation) preserves “who I am” distinctly from “what’s around me.”
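A NumPy sketch of one GraphSAGE layer with the mean aggregator follows, assuming up to `num_samples` neighbors drawn per node without replacement (sampling details differ between implementations, so treat this as a simplification). The last lines check the “Also written” identity: multiplying \(W\) by the concatenation equals applying its two halves separately.

```python
import numpy as np

def sage_mean_layer(H, nbrs, W, num_samples, rng):
    """GraphSAGE-style layer: sigma(W . CONCAT(h_v, mean of sampled neighbors)).

    H: (n, d) node states; nbrs: dict node -> list of neighbors; W: (k, 2d).
    """
    n = H.shape[0]
    out = np.zeros((n, W.shape[0]))
    for v in range(n):
        sampled = rng.choice(nbrs[v], size=min(num_samples, len(nbrs[v])), replace=False)
        a = H[sampled].mean(axis=0)                              # AGG over S(v)
        out[v] = np.maximum(W @ np.concatenate([H[v], a]), 0.0)  # ReLU as sigma
    return out

rng = np.random.default_rng(0)
graph = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4], 4: [3]}
H = rng.normal(size=(5, 3))
W = rng.normal(size=(4, 6))
out = sage_mean_layer(H, graph, W, num_samples=2, rng=rng)
print(out.shape)   # (5, 4): one 4-d embedding per node

# The "Also written" form: W . concat(h, a) == W_self h + W_nbr a
h, a = H[2], H[[0, 3]].mean(axis=0)
W_self, W_nbr = W[:, :3], W[:, 3:]
print(np.allclose(W @ np.concatenate([h, a]), W_self @ h + W_nbr @ a))  # True
```

Because the learned object is `W` rather than a per-node table, the same function can be applied to a node that did not exist during training, which is what makes the method inductive.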
The deeper payoff is that GraphSAGE is inductive: it learns the aggregator weights \(W^{(\ell)}\), not one fixed embedding per node. Show it a brand-new node with features and a neighborhood it has never seen, and it produces an embedding by running the same learned functions — no retraining. This is the direct cure for DeepWalk/node2vec’s transductive limitation (§40.2).
```mermaid
flowchart LR
  T["target node"] -->|sample 2 of K hop-1| N1[nbr a]
  T -->|sample| N2[nbr b]
  N1 -->|sample 2 hop-2| L1[nbr c]
  N1 --> L2[nbr d]
  N2 --> L3[nbr e]
  N2 --> L4[nbr f]
  L1 & L2 & L3 & L4 -->|AGG| AGG2[hop-2 reps]
  AGG2 --> AGG1[hop-1 reps]
  AGG1 --> EMB["target embedding"]
```
In PyG, the sampling and the layer are two separate pieces — a NeighborLoader that draws bounded neighborhoods into mini-batches, and SAGEConv for the layer itself:
```python
import torch
import torch.nn.functional as F
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv

# `data` is assumed to be a PyG Data object with x, edge_index, y and a boolean
# train_mask (for example, Planetoid(root, "Cora")[0]).
# Draw 25 neighbors at hop 1 and 10 at hop 2: bounded compute regardless of degree.
loader = NeighborLoader(data, num_neighbors=[25, 10], batch_size=512,
                        input_nodes=data.train_mask, shuffle=True)

class SAGE(torch.nn.Module):
    def __init__(self, in_dim, hid, out):
        super().__init__()
        self.c1 = SAGEConv(in_dim, hid)  # self term + mean of sampled neighbors (split-W form)
        self.c2 = SAGEConv(hid, out)
    def forward(self, x, edge_index):
        x = F.relu(self.c1(x, edge_index))
        return self.c2(x, edge_index)

model = SAGE(data.num_features, 64, int(data.y.max()) + 1)
opt = torch.optim.Adam(model.parameters(), lr=0.01)
for batch in loader:                     # each batch is a small sampled subgraph
    opt.zero_grad()
    out = model(batch.x, batch.edge_index)
    # only the first batch.batch_size nodes are the seed nodes being trained on
    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    opt.step()                           # mini-batch training, never the full graph
```

Neighbor sampling adds variance: the same node gets slightly different embeddings on different forward passes because a different neighbor sample is drawn. This is usually fine (it acts like a regularizer), but for inference you may want to average several samples or sample more neighbors to stabilize predictions.
40.6 Graph Attention Networks (GAT)
GCN weights every neighbor by a fixed function of degree. But intuitively, not all neighbors are equally relevant — in a citation network, one cited paper might be central to yours and another only tangential. GAT (Veličković et al., 2018) lets the model learn how much to weight each neighbor, by importing the attention mechanism from Transformers onto the edges.
For each edge \((v, u)\) the network computes an attention coefficient, then normalizes across \(v\)’s neighbors with a softmax:
\[e_{vu} = \text{LeakyReLU}\big(a^\top [\,W h_v \,\Vert\, W h_u\,]\big), \qquad \alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{k \in N(v)} \exp(e_{vk})}\]
In words: score how relevant each neighbor \(u\) is to node \(v\) from their (transformed) features, then turn those raw scores into percentages that add to 1 across the neighborhood — a learned “how much should I listen to each neighbor?” Also written: \(\alpha_{vu} = \operatorname{softmax}_{u \in N(v)}\big(e_{vu}\big)\), identical in spirit to Transformer attention weights but with the neighbor set \(N(v)\) playing the role of the context window.
The new node state is the attention-weighted sum of transformed neighbors, \(h_v' = \sigma\!\big(\sum_{u \in N(v)} \alpha_{vu} W h_u\big)\). As in Transformers, you run several attention heads in parallel and concatenate them, so the node can attend to neighbors in several different ways at once.
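A dense NumPy sketch of one attention head, following the two equations above literally. It assumes row-vector features (so \(W h\) appears as `H @ W`), a LeakyReLU negative slope of 0.2 (the value used in the GAT paper), and tanh as an arbitrary stand-in for \(\sigma\). It attends only over the listed neighbors; many implementations, including PyG’s `GATConv` by default, also add a self-loop so that \(v\) attends to itself.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_layer(H, nbrs, W, a):
    """Single-head GAT layer.

    H: (n, d) features; nbrs: dict node -> list of neighbors; W: (d, k); a: (2k,).
    Returns new states (n, k) and alpha[v], the attention weights over nbrs[v].
    """
    Z = H @ W                                   # W h for every node (row form)
    k = Z.shape[1]
    H_new, alpha = np.zeros_like(Z), {}
    for v, nb in nbrs.items():
        # e_vu = LeakyReLU(a^T [W h_v || W h_u]) for each neighbor u
        e = leaky_relu(Z[v] @ a[:k] + Z[nb] @ a[k:])
        w = np.exp(e - e.max())                 # numerically stable softmax over N(v)
        alpha[v] = w / w.sum()
        H_new[v] = np.tanh(alpha[v] @ Z[nb])    # attention-weighted sum, then sigma
    return H_new, alpha

rng = np.random.default_rng(0)
graph = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4], 4: [3]}
H = rng.normal(size=(5, 3))
H_new, alpha = gat_layer(H, graph, rng.normal(size=(3, 4)), rng.normal(size=8))
print(H_new.shape, alpha[2].round(3))           # node 2 attends over neighbors [0, 1, 3]
```

Splitting `a` into two halves is just the concatenation \([W h_v \Vert W h_u]\) dotted with \(a\), computed without building the concatenated vector.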
Worked example of the weights. Suppose node \(v\) has three neighbors with raw scores \(e = [2.0, 1.0, 0.1]\). The softmax gives attention
\[\alpha = \frac{[e^{2.0}, e^{1.0}, e^{0.1}]}{e^{2.0}+e^{1.0}+e^{0.1}} = \frac{[7.39, 2.72, 1.11]}{11.22} \approx [0.659,\ 0.242,\ 0.099].\]
```python
import numpy as np

e = np.array([2.0, 1.0, 0.1])         # learned edge scores for v's 3 neighbors
alpha = np.exp(e) / np.exp(e).sum()   # softmax over the neighborhood
print(alpha.round(3))                 # [0.659 0.242 0.099] -> first neighbor dominates
```

The bars are the attention weights \(\alpha\) above: the strongest neighbor glows brightest, and the widths sum to one whole row.
In PyG, multi-head GAT is a drop-in replacement for the GCN layer — heads=8 runs eight independent attention patterns and concatenates them:
```python
from torch_geometric.nn import GATConv

in_dim, hid, n_classes = 1433, 8, 7   # Cora-sized, as in the GCN example
# 8 heads on layer 1 (outputs hid * 8 features), one averaged head on the output layer
conv1 = GATConv(in_dim, hid, heads=8, dropout=0.6)
conv2 = GATConv(hid * 8, n_classes, heads=1, concat=False, dropout=0.6)
```

In the worked example, the first neighbor gets two-thirds of the weight: the model decided it matters most. Note that the coefficients are computed per node from features, so GAT is naturally inductive like GraphSAGE, and it is more expressive than GCN whenever neighbor importance is uneven. The cost is more parameters and compute per edge, and attention can occasionally be unstable to train (multiple heads help).
Which message-passing layer should I use?
```mermaid
flowchart TD
  Q1{Graph fits in<br/>memory at once?} -- no --> SAGE[GraphSAGE<br/>sample neighbors, inductive, scales]
  Q1 -- yes --> Q2{Do neighbors differ<br/>in importance?}
  Q2 -- yes --> GAT[GAT<br/>learned attention per edge]
  Q2 -- no --> Q3{Graph-level task<br/>needing max expressiveness?}
  Q3 -- yes --> GIN["GIN (§40.8)<br/>sum aggregator, 1-WL power"]
  Q3 -- no --> GCN[GCN<br/>fast normalized-mean baseline]
```
40.7 Readout/pooling and the limits of depth
Readout for graph-level tasks. Message passing gives you a vector per node. For a graph-level prediction (is this molecule toxic?) you must collapse all node vectors into one graph vector with a permutation-invariant readout:
\[h_G = \text{READOUT}\big(\{h_v^{(L)} : v \in V\}\big), \qquad \text{READOUT} \in \{\text{sum},\ \text{mean},\ \text{max}\}\]
In words: mash every node’s final vector into a single graph-level vector with an operation that ignores node order — usually just add them up, average them, or take the per-feature maximum. Also written: \(h_G = \bigoplus_{v\in V} h_v^{(L)}\) where \(\bigoplus\) is the chosen order-independent pool (e.g. \(\sum_v h_v^{(L)}\) for sum-pooling).
then feed \(h_G\) to an MLP head. Sum-pooling preserves graph size (a 50-atom and a 5-atom molecule get different totals), mean-pooling is size-invariant, and max-pooling captures whether any node strongly expresses a feature. More elaborate options exist (hierarchical/differentiable pooling that coarsens the graph in stages), but sum/mean/max cover most needs.
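A small NumPy sketch of the three readouts over a mini-batch, assuming a `batch` vector that assigns each node to its graph (the same convention PyG uses). Two toy graphs with identical node features but different sizes show the trade-off just described: sum separates them, mean and max do not.

```python
import numpy as np

def readout(h, batch, num_graphs, how="sum"):
    """Permutation-invariant readout: pool the node rows of each graph into one vector."""
    pool = {"sum": np.sum, "mean": np.mean, "max": np.max}[how]
    return np.stack([pool(h[batch == g], axis=0) for g in range(num_graphs)])

# Graph 0 has 2 nodes, graph 1 has 4 nodes; every node has the same feature [1, 0].
h = np.array([[1.0, 0.0]] * 6)
batch = np.array([0, 0, 1, 1, 1, 1])
for how in ("sum", "mean", "max"):
    print(how, readout(h, batch, 2, how).tolist())
# sum  -> [[2.0, 0.0], [4.0, 0.0]]  (sizes differ)
# mean -> [[1.0, 0.0], [1.0, 0.0]]  (identical)
# max  -> [[1.0, 0.0], [1.0, 0.0]]  (identical)

# Shuffling node order (and the batch ids with it) leaves every readout unchanged.
rng = np.random.default_rng(0)
h_rand = rng.normal(size=(6, 3))
perm = rng.permutation(6)
print(all(np.allclose(readout(h_rand[perm], batch[perm], 2, how), readout(h_rand, batch, 2, how))
          for how in ("sum", "mean", "max")))   # True
```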
In PyG these are one-liners that take node embeddings plus a batch vector (which node belongs to which graph in the mini-batch):
```python
import torch
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

h = torch.randn(5, 16)                  # h: [total_nodes, hid] node embeddings (toy values)
batch = torch.tensor([0, 0, 0, 1, 1])   # batch: [total_nodes] graph id per node (2 graphs)
hG = global_add_pool(h, batch)          # [2, 16]: one row per graph in the batch -> MLP head
```

Over-smoothing: why you can’t just stack layers. With CNNs, deeper is usually better. With GNNs there is a sharp trap. Each message-passing layer mixes a node with its neighbors, so after many layers every node’s representation is an average over an ever-larger neighborhood, and in the limit all nodes converge to the same vector (for GCN’s symmetric normalization, vectors pointing in the same direction and differing only by a degree-dependent scale). This is over-smoothing, and it means stacking 20 GNN layers typically makes accuracy collapse rather than improve.
The everyday intuition: if everyone in a town keeps averaging their opinion with their neighbors’, day after day, eventually the whole town holds the identical opinion and you can no longer tell anyone apart. Useful node distinctions get blurred away by too much mixing.
In practice most GNNs are shallow (2–4 layers). When you genuinely need long-range information, the fixes mirror deep-net tricks from earlier chapters: residual/skip connections, jumping-knowledge (concatenate every layer’s output, not just the last), normalization tuned for graphs, and architectural choices that widen receptive field without naive depth.
If a deeper GNN performs worse than a 2-layer one, suspect over-smoothing before you suspect a bug. The signature is node representations becoming nearly indistinguishable as depth grows: after normalizing each node’s vector, their pairwise distances shrink toward zero (equivalently, their cosine similarities approach 1). Add skip connections or simply use fewer layers.
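A minimal NumPy sketch of that signature, assuming a linear GCN (no weights, no nonlinearity) on the triangle-with-tail graph from the earlier random-walk snippet: repeatedly applying \(\tilde D^{-1/2}\tilde A\tilde D^{-1/2}\) drives every node vector toward the same direction, so the smallest pairwise cosine similarity climbs toward 1. Real GNNs add weights and nonlinearities, so treat this as an illustration of the trend, not a model of training.

```python
import numpy as np

A = np.array([[0, 1, 1, 0, 0], [1, 0, 1, 0, 0], [1, 1, 0, 1, 0],
              [0, 0, 1, 0, 1], [0, 0, 0, 1, 0]], float)   # triangle with a tail
At = A + np.eye(5)
Dinv = np.diag(At.sum(1) ** -0.5)
S = Dinv @ At @ Dinv                                     # GCN propagation matrix

def min_cosine(H):
    """Smallest pairwise cosine similarity between node vectors (1.0 = all parallel)."""
    U = H / np.linalg.norm(H, axis=1, keepdims=True)
    return (U @ U.T).min()

rng = np.random.default_rng(0)
H = rng.normal(size=(5, 4))                              # random initial node features
sims = {}
for depth in (1, 2, 4, 8, 16, 64):
    Hk = np.linalg.matrix_power(S, depth) @ H            # `depth` linear GCN layers
    sims[depth] = min_cosine(Hk)
    print(depth, round(sims[depth], 4))
```

The raw vectors do not become equal (they stay scaled by \(\sqrt{\tilde d_v}\)), which is why the check normalizes them first.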
Over-smoothing’s quieter cousin: over-squashing. Picture a message from a faraway node that must reach yours. Each hop it travels, it merges with more and more other messages — and the number of messages grows explosively with distance. By the time they arrive, they have all been squeezed into one fixed-size node vector, often passing through a single narrow “bridge” edge on the way. The faraway signal is still in there, but blurred past recognition. This is over-squashing: it is why a GNN struggles with long-range reasoning even before over-smoothing sets in. The fixes either widen the pipe (rewire the graph with shortcut edges) or skip it entirely (graph-transformer-style global attention, so distant nodes talk directly — §40.12).
40.8 GIN and the Weisfeiler–Lehman expressiveness test
A deeper question: how powerful is message passing, fundamentally? Two non-isomorphic graphs that “look the same” to a GNN will always get the same prediction — a hard ceiling on what any such model can distinguish. The yardstick for this is the classic Weisfeiler–Lehman (1-WL) graph isomorphism test, a 1968 algorithm that iteratively refines node “colors”: each round, replace a node’s color with a hash of its own color plus the multiset of its neighbors’ colors, and repeat. If two graphs end with different color histograms, they are provably non-isomorphic.
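Here is a short sketch of 1-WL color refinement. Instead of a hash function it assigns compact integer ids through a dictionary shared by all graphs within a round (a common implementation choice, not the only one), so final histograms can be compared across graphs. The first pair is the 6-cycle versus two disjoint triangles mentioned at the end of this section; the second pair is a 3-node path versus a triangle.

```python
from collections import Counter

def wl_histograms(graphs, rounds=3):
    """1-WL color refinement on several graphs; returns one color histogram per graph.

    graphs: list of adjacency dicts (node -> list of neighbors). Every node starts
    with color 0. Each round, a node's new color is an integer id for the pair
    (own color, sorted multiset of neighbor colors).
    """
    colorings = [{v: 0 for v in g} for g in graphs]
    for _ in range(rounds):
        palette = {}                       # shared across graphs, so ids are comparable
        new = []
        for g, col in zip(graphs, colorings):
            sig = {v: (col[v], tuple(sorted(col[u] for u in g[v]))) for v in g}
            new.append({v: palette.setdefault(s, len(palette)) for v, s in sig.items()})
        colorings = new
    return [Counter(c.values()) for c in colorings]

cycle6 = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
path3 = {0: [1], 1: [0, 2], 2: [1]}
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}

h_a, h_b = wl_histograms([cycle6, two_triangles])
print(h_a == h_b)   # True: 1-WL cannot tell these apart
h_c, h_d = wl_histograms([path3, triangle])
print(h_c == h_d)   # False: different histograms prove the graphs are non-isomorphic
```

In the first pair every node has degree 2, so every node keeps the same color forever; that is exactly the blind spot a 1-WL-bounded GNN inherits.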
The striking result (Xu et al., 2019): standard message-passing GNNs are at most as powerful as the 1-WL test. They can never distinguish two graphs that 1-WL cannot. And a GNN reaches that maximal power only if its AGGREGATE and READOUT are injective — they map distinct neighbor multisets to distinct outputs.
Here is the catch that motivates GIN (Graph Isomorphism Network): mean- and max-pooling are not injective. Mean-pooling cannot tell a node with neighbors \(\{A\}\) from one with \(\{A, A\}\) — the average is identical. Sum-pooling can: \(A \neq A + A\). So GIN uses a sum aggregator wrapped in an MLP (a universal function approximator), which together can be injective:
\[h_v^{(\ell)} = \text{MLP}^{(\ell)}\!\left( \big(1 + \epsilon^{(\ell)}\big)\, h_v^{(\ell-1)} + \sum_{u \in N(v)} h_u^{(\ell-1)} \right)\]
In words: add up the neighbors’ vectors, add a slightly-weighted copy of the node’s own vector so it stays distinguishable, then push the result through a small neural net — a sum keeps “how many of each kind of neighbor” information that an average would erase. Also written: with \(\epsilon = 0\) this is simply \(\text{MLP}\big(\sum_{u \in N(v)\cup\{v\}} h_u^{(\ell-1)}\big)\) — a sum over the closed neighborhood, then an MLP.
The \((1+\epsilon)\) term keeps the node’s own contribution distinguishable from the neighbor sum. This makes GIN provably as discriminative as 1-WL — the most expressive a standard message-passing GNN can be.
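A NumPy sketch of the GIN update for all nodes at once, assuming a dense adjacency matrix and a two-layer ReLU MLP with random weights (sizes chosen arbitrarily). The final check confirms the “Also written” form: with \(\epsilon = 0\), the update is an MLP applied to the closed-neighborhood sum \((A + I)H\).

```python
import numpy as np

def gin_layer(H, A, eps, W1, W2):
    """GIN update: MLP((1 + eps) * h_v + sum of neighbor states), for every node."""
    pre = (1.0 + eps) * H + A @ H                    # sum aggregator, no normalization
    return np.maximum(pre @ W1, 0.0) @ W2            # two-layer MLP with ReLU in between

rng = np.random.default_rng(0)
A = np.array([[0, 1, 1, 0, 0], [1, 0, 1, 0, 0], [1, 1, 0, 1, 0],
              [0, 0, 1, 0, 1], [0, 0, 0, 1, 0]], float)   # triangle with a tail
H = rng.normal(size=(5, 3))
W1, W2 = rng.normal(size=(3, 8)), rng.normal(size=(8, 4))

# With eps = 0 the update is an MLP over the closed-neighborhood sum (A + I) H.
closed_sum = np.maximum(((A + np.eye(5)) @ H) @ W1, 0.0) @ W2
print(np.allclose(gin_layer(H, A, 0.0, W1, W2), closed_sum))  # True
```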
Worked example of why sum beats mean. Two nodes, one with neighbor features \(\{[1,0]\}\) and one with \(\{[1,0],[1,0]\}\):
```python
import numpy as np

nbrs_A = np.array([[1, 0]])          # one neighbor
nbrs_B = np.array([[1, 0], [1, 0]])  # two identical neighbors
print("mean:", nbrs_A.mean(0), nbrs_B.mean(0))  # [1. 0.] [1. 0.] -> indistinguishable!
print("sum :", nbrs_A.sum(0), nbrs_B.sum(0))    # [1 0] [2 0] -> distinguishable
```

Mean collapses the two cases; sum keeps them apart. That single distinction is the whole reason GIN exists — and a reminder that the choice of aggregator caps what your model can ever learn.
In PyG, GINConv wraps any MLP you hand it and does the sum aggregation internally:
```python
import torch
from torch_geometric.nn import GINConv

in_dim, hid = 16, 64   # illustrative sizes
mlp = torch.nn.Sequential(torch.nn.Linear(in_dim, hid), torch.nn.ReLU(),
                          torch.nn.Linear(hid, hid))
conv = GINConv(mlp, train_eps=True)  # sum over neighbors -> mlp; epsilon is learned
```

Even GIN is bounded by 1-WL, which famously cannot count triangles or distinguish certain regular graphs (e.g. a 6-cycle from two triangles). If your task hinges on such substructures — counting rings in a molecule, detecting specific motifs — plain message passing is provably insufficient. You need higher-order GNNs, subgraph-based methods, or explicit structural features added to \(X\).
40.9 Knowledge-graph embeddings for link prediction
A knowledge graph stores facts as triples \((h, r, t)\) — head entity, relation, tail entity — like (Paris, capital_of, France). These graphs are vast and radically incomplete; the central task is link prediction: score whether a candidate triple is true, to fill in missing facts. The dominant approach learns a vector for every entity and every relation so that true triples score high and false ones score low.
TransE models a relation as a translation in embedding space: if \((h, r, t)\) holds, then \(h + r \approx t\). Its score is the negative distance \(f(h,r,t) = -\lVert h + r - t \rVert\).
In words: a relation is an arrow you add to the head to (almost) land on the tail; a fact scores well when adding the relation’s arrow to the head gets you close to the tail. Also written: equivalently, score is high when \(h + r - t \approx \mathbf{0}\), i.e. the residual vector \(h + r - t\) has small norm.
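This is elegant (Paris + capital_of ≈ France), and it captures composition, because following two relations in a row simply adds their vectors. But it structurally cannot model symmetric relations: if both \(h+r=t\) and \(t+r=h\) held, adding them gives \(2r=0\), forcing \(r=0\). Nor does it handle one-to-many relations cleanly: if \(h+r \approx t_1\) and \(h+r \approx t_2\), then \(t_1 \approx t_2\), so all the valid tails get squeezed together.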
DistMult scores via a multiplicative interaction, \(f(h,r,t) = \sum_i h_i \, r_i \, t_i\) (a weighted dot product with \(r\) as diagonal weights).
In words: line up head, relation, and tail dimension-by-dimension, multiply the three together in each dimension, and add up — agreement across all three drives the score. Also written: \(f(h,r,t) = h^\top \operatorname{diag}(r)\, t\), the bilinear form with relation \(r\) on the diagonal.
It is simple and strong, but because multiplication is commutative in \(h\) and \(t\), it scores \((h,r,t)\) and \((t,r,h)\) identically — so it can only model symmetric relations and gets antisymmetric ones (like parent_of) wrong.
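A short NumPy check of both forms of the DistMult score and of the symmetry problem just described, using random vectors:

```python
import numpy as np

def distmult(h, r, t):
    """DistMult score: sum_i h_i r_i t_i, i.e. h^T diag(r) t."""
    return np.sum(h * r * t)

rng = np.random.default_rng(0)
h, r, t = rng.normal(size=(3, 6))   # three random 6-d vectors
print(np.isclose(distmult(h, r, t), h @ np.diag(r) @ t))  # True: the two forms agree
print(np.isclose(distmult(h, r, t), distmult(t, r, h)))   # True: swapping h and t changes nothing
```

No choice of \(r\) can break the tie between \((h,r,t)\) and \((t,r,h)\), which is why antisymmetric relations are out of reach.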
RotatE fixes both by moving to complex space and modeling each relation as a rotation: \(t \approx h \circ r\) where \(\circ\) is elementwise complex multiplication and each \(r_i\) has modulus 1 (a pure rotation by some angle).
In words: instead of sliding the head to the tail (TransE), spin it; each relation rotates each of the head’s complex coordinates by its own fixed angle to reach the tail. Also written: with \(r_j = e^{\mathrm{i}\theta_j}\) (where \(\mathrm{i}\) is the imaginary unit and \(j\) indexes dimensions), the rotation acts per dimension as \(t_j \approx h_j\, e^{\mathrm{i}\theta_j}\), and the score is \(-\lVert h \circ r - t\rVert\).
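Rotations compose (rotating by \(\theta_1\) and then \(\theta_2\) equals rotating by \(\theta_1+\theta_2\)), and crucially a rotation by \(\theta\) has an inverse rotation by \(-\theta\). So RotatE can represent symmetry (every \(\theta_j \in \{0, \pi\}\), so applying \(r\) twice is the identity), antisymmetry, inversion, and composition, the main relation patterns.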
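Each of these four patterns can be checked directly with complex NumPy arrays. This is a minimal sketch, assuming a random toy head vector and hand-picked angles (nothing is trained); `rotate_score` is an illustrative helper, not a library function.

```python
import numpy as np

def rotate_score(h, theta, t):
    """RotatE score -||h o r - t|| with r_j = exp(i * theta_j), a unit-modulus rotation."""
    r = np.exp(1j * theta)
    return float(-np.linalg.norm(h * r - t))

rng = np.random.default_rng(0)
dim = 3
h = rng.normal(size=dim) + 1j * rng.normal(size=dim)  # toy complex entity vector

# Symmetry: theta = pi in every dimension, so applying r twice is the identity.
theta_sym = np.full(dim, np.pi)
t_sym = h * np.exp(1j * theta_sym)
sym_fwd, sym_bwd = rotate_score(h, theta_sym, t_sym), rotate_score(t_sym, theta_sym, h)

# Antisymmetry: theta = pi/2 fits (h, r, t) but not the reversed triple.
theta_anti = np.full(dim, np.pi / 2)
t_anti = h * np.exp(1j * theta_anti)
anti_fwd, anti_bwd = rotate_score(h, theta_anti, t_anti), rotate_score(t_anti, theta_anti, h)

# Inversion: the inverse relation is the opposite rotation, -theta.
inv = rotate_score(t_anti, -theta_anti, h)

# Composition: rotating by theta1 then theta2 equals rotating by theta1 + theta2.
theta1, theta2 = rng.uniform(0, 2 * np.pi, size=(2, dim))
t_comp = h * np.exp(1j * theta1) * np.exp(1j * theta2)
comp = rotate_score(h, theta1 + theta2, t_comp)

print(sym_fwd, sym_bwd, anti_fwd, anti_bwd, inv, comp)  # all ~0 except anti_bwd
```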
| Model | Score \(f(h,r,t)\) | Captures | Misses |
|---|---|---|---|
| TransE | \(-\lVert h + r - t\rVert\) | composition, inversion, antisymmetry | symmetry, 1-to-many |
| DistMult | \(\sum_i h_i r_i t_i\) | symmetry | antisymmetry, inversion, composition (scores \((h,r,t)\) and \((t,r,h)\) equally) |
| RotatE | \(-\lVert h \circ r - t\rVert\) (complex) | symmetry, antisymmetry, inversion, composition | very complex hierarchies |
```python
import numpy as np

# TransE sanity check in 2D: set r so that h + r = t for (Paris, capital_of, France)
h = np.array([0.2, 0.9])  # Paris
t = np.array([0.5, 0.4])  # France
r = t - h                 # the ideal "capital_of" translation
print("score true :", -np.linalg.norm(h + r - t))        # zero distance: the best possible score
wrong_t = np.array([0.1, 0.1])
print("score false:", -np.linalg.norm(h + r - wrong_t))  # about -0.5, so ranked lower
```

Models are trained by negative sampling: take a true triple, corrupt its head or tail with a random entity, and use a margin or logistic loss to score the true triple above the corrupted one. At inference, to answer “(Paris, capital_of, ?)” you score every entity as the tail and rank them.
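The negative-sampling recipe fits in a few lines of NumPy. This is a minimal TransE trainer under stated assumptions: a hypothetical 4-entity toy graph, tail-only corruption, the margin (hinge) loss \(\max(0, \gamma + d_{\text{pos}} - d_{\text{neg}})\) with L2 distance, hand-derived gradients, and entity vectors renormalized to unit length after each epoch as TransE does. Real libraries batch this, corrupt heads as well, and filter corruptions that happen to be true facts.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy KG: 4 entities, 1 relation, two true facts (0 -> 1) and (2 -> 3).
triples = [(0, 0, 1), (2, 0, 3)]
n_ent, n_rel, dim = 4, 1, 8
margin, lr, epochs = 1.0, 0.05, 300

E = rng.normal(scale=0.1, size=(n_ent, dim))  # entity embeddings
R = rng.normal(scale=0.1, size=(n_rel, dim))  # relation embeddings

for _ in range(epochs):
    for h, r, t in triples:
        t_neg = int(rng.integers(n_ent))  # corrupt the tail with a random entity
        if t_neg == t:
            continue                      # skip: the "corruption" is the true triple
        pos = E[h] + R[r] - E[t]
        neg = E[h] + R[r] - E[t_neg]
        d_pos, d_neg = np.linalg.norm(pos), np.linalg.norm(neg)
        if margin + d_pos - d_neg <= 0:
            continue                      # hinge inactive: zero gradient
        g_pos = pos / (d_pos + 1e-9)      # gradient of ||pos|| with respect to pos
        g_neg = neg / (d_neg + 1e-9)
        E[h] -= lr * (g_pos - g_neg)      # pull the positive together ...
        R[r] -= lr * (g_pos - g_neg)
        E[t] += lr * g_pos
        E[t_neg] -= lr * g_neg            # ... and push the negative apart
    E /= np.linalg.norm(E, axis=1, keepdims=True)  # keep entities on the unit sphere

def rank_tails(E, R, h, r):
    """Answer (h, r, ?): score every entity as the tail, best first."""
    scores = -np.linalg.norm(E[h] + R[r] - E, axis=1)
    return np.argsort(-scores)

print(rank_tails(E, R, 0, 0))  # entity ids ordered by TransE score for (0, r, ?)
```

Evaluation metrics such as MRR and Hits@k are computed from exactly this kind of ranking: where the true tail lands among all candidates.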
Choose the scorer by the relations in your data, not by leaderboard rank. If your graph is full of symmetric relations (spouse_of, sibling_of), DistMult or RotatE will fit them; TransE structurally can’t. If antisymmetry matters (parent_of, capital_of), avoid DistMult. When in doubt, RotatE handles the most patterns — start there.
In practice, libraries like PyKEEN train and evaluate these models end-to-end (negative sampling, ranking metrics, and dozens of models) in a few lines:
```python
# pip install pykeen
from pykeen.pipeline import pipeline

result = pipeline(
    dataset="Nations",
    model="RotatE",  # try "TransE", "DistMult" too
    training_kwargs=dict(num_epochs=50),
)
result.plot_losses()  # training-loss curve
# rank-based metrics (MRR, Hits@k) are in result.metric_results
```

40.10 Relational deep learning over databases
Most enterprise data does not live in a single graph file — it lives in a relational database: a customer table, an orders table, a products table, joined by foreign keys. The traditional ML recipe is to flatten this by hand: write SQL to compute features like “number of orders in last 30 days,” “average basket value,” and dump them into one wide table for an XGBoost model. This feature engineering is where most of the project time goes, and every hand-crafted aggregate throws away structure.
Relational deep learning reframes the database itself as a graph and learns on it directly. The mapping is natural: each row becomes a node, columns become its features, and a foreign-key reference becomes an edge between two rows. A customer row links to its order rows, which link to product rows — exactly the heterogeneous graph a GNN can consume.
```mermaid
flowchart LR
subgraph DB[Relational database]
C[(Customers)] -.FK.-> O[(Orders)]
O -.FK.-> P[(Products)]
end
DB ==> G
subgraph G[Schema graph]
cu[customer node] --> or1[order node] --> pr[product node]
cu --> or2[order node] --> pr
end
G ==> GNN[GNN learns aggregates<br/>no manual SQL features]
```
Concrete example. To predict churn for one customer the old way, an analyst writes SQL for, say, “orders in last 30 days,” “days since last login,” and “average rating given” — three columns, each a guess about what matters. The graph version skips the guessing: it links that customer row to its order rows to their product rows, and the GNN learns whichever aggregates actually predict churn, including ones no analyst would write by hand.
A GNN over this graph learns the aggregations a human would otherwise hand-write — and can discover predictive multi-hop patterns (“customers who bought from a category whose other buyers churned”) that nobody thought to engineer. Because the tables are different (customers vs. orders vs. products), this uses a heterogeneous GNN with per-edge-type and per-node-type weights, and predictions are usually time-aware (only message along edges that existed before the prediction time, to avoid leakage). This is an active frontier — promising for churn, fraud, and demand prediction — but newer and less battle-tested than gradient-boosted trees on flat features, which remain a very strong baseline.
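The row-to-node, foreign-key-to-edge mapping with the time filter can be sketched in plain Python. The toy customers/orders tables, the `fk_edges` helper, and the integer timestamps are all hypothetical; the output is a two-row edge list in the [source, destination] layout that PyG uses for `edge_index`.

```python
# Hypothetical toy tables: every row will become a node.
customers = [{"id": 10}, {"id": 11}]
orders = [
    {"id": 100, "customer_id": 10, "ts": 5},  # ts: integer timestamp (illustrative)
    {"id": 101, "customer_id": 10, "ts": 12},
    {"id": 102, "customer_id": 11, "ts": 3},
]

def fk_edges(child_rows, fk, parent_rows, cutoff):
    """Turn a foreign-key column into a [src, dst] edge list (parent -> child).

    Only child rows with ts < cutoff are linked, so the graph used to predict
    at time `cutoff` contains no information from the future (no leakage).
    """
    parent_pos = {row["id"]: i for i, row in enumerate(parent_rows)}  # primary key -> node index
    src, dst = [], []
    for j, row in enumerate(child_rows):
        if row["ts"] < cutoff:
            src.append(parent_pos[row[fk]])
            dst.append(j)
    return [src, dst]

# Edge type ('customer', 'places', 'order') as of time 10: order 101 is not visible yet.
edge_index = fk_edges(orders, "customer_id", customers, cutoff=10)
print(edge_index)  # [[0, 1], [0, 2]]
```

Repeating this for every foreign key (orders to products, and so on) yields one edge list per edge type, which is exactly the input a heterogeneous GNN expects.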
In PyG, a heterogeneous GNN is built by wrapping a homogeneous module with to_hetero, which automatically gives each edge type its own message-passing weights:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero

class Net(torch.nn.Module):
    def __init__(self, hid, out):
        super().__init__()
        self.c1 = SAGEConv((-1, -1), hid)  # (-1, -1): input dims inferred lazily per edge type
        self.c2 = SAGEConv((-1, -1), out)

    def forward(self, x, edge_index):
        x = F.relu(self.c1(x, edge_index))
        return self.c2(x, edge_index)

# data: a HeteroData object built from the database tables (construction not shown).
# data.metadata() = (node_types, edge_types), e.g. edge type ('customer', 'places', 'order').
# Add reverse edges (e.g. transforms.ToUndirected) so every node type receives messages.
model = to_hetero(Net(64, 1), data.metadata(), aggr="sum")
out = model(data.x_dict, data.edge_index_dict)  # dict of per-node-type predictions
```

40.11 Applications and limitations
Where graph ML genuinely earns its keep:
- Chemistry & drug discovery — molecules are graphs (atoms/bonds); GNNs predict toxicity, solubility, binding, and screen huge compound libraries. This is the flagship success story.
- Recommenders — the user–item interaction graph is enormous and sparse; GNNs (e.g. PinSAGE-style models, building on the recommender systems of Chapter 26) power web-scale recommendation by aggregating over the bipartite user–item graph.
- Knowledge-graph completion — search, question answering, and enterprise knowledge bases use the embedding models of §40.9 to infer missing facts.
- Traffic & logistics — road networks are graphs; spatio-temporal GNNs forecast travel times (this powers real ETA systems).
- Fraud & security — fraud rings show up as suspicious subgraph structure that node-feature models miss entirely.
And the honest limitations:
- Scale and latency. Real graphs have billions of edges; sampling (§40.5) helps, but serving low-latency GNN predictions is an engineering project, not a model.predict() call.
- Depth and over-smoothing (§40.7). Over-smoothing caps how many layers you can usefully stack, so truly long-range dependencies remain awkward.
- Expressiveness ceiling (§40.8). Plain message passing can’t see past 1-WL — count motifs at your peril.
- Strong simple baselines. On many tabular-ish or feature-rich problems, an MLP or gradient-boosted trees on good features can match a GNN. Always check whether the graph structure actually adds signal before reaching for a GNN; sometimes the edges carry little information that the node features don’t already capture.
Decision heuristic: reach for a GNN when (a) the connectivity itself is predictive — remove the edges and a strong model on node features alone clearly degrades — and (b) you have enough labeled data to train one. If a logistic regression on node features is already 95% of the way there, the graph was decoration, not signal.
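Test (a) can be sketched in NumPy: fit the same simple classifier on node features alone and on node features plus a neighbor-mean aggregate, then compare held-out accuracy. Everything here is illustrative. The graph and labels are synthetic (two planted communities with a weak, noisy node feature), and a closed-form ridge classifier stands in for logistic regression to keep the sketch dependency-free.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic homophilous graph (illustrative data): two communities of 100 nodes each.
n = 200
y = np.repeat([0, 1], n // 2)
same = y[:, None] == y[None, :]
A = (rng.random((n, n)) < np.where(same, 0.2, 0.01)).astype(float)
A = np.triu(A, 1)
A = A + A.T                                       # undirected, no self-loops
X = (y + rng.normal(scale=2.0, size=n))[:, None]  # weak, noisy node feature

def neighbor_mean(X, A):
    """One round of mean aggregation: the average feature of each node's neighbors."""
    deg = A.sum(axis=1, keepdims=True)
    return (A @ X) / np.maximum(deg, 1.0)

def ridge_accuracy(F, y, train, test, lam=1e-2):
    """Closed-form ridge classifier (stand-in for logistic regression)."""
    F1 = np.hstack([F, np.ones((len(F), 1))])  # add a bias column
    G, target = F1[train], 2.0 * y[train] - 1.0  # labels mapped to {-1, +1}
    w = np.linalg.solve(G.T @ G + lam * np.eye(F1.shape[1]), G.T @ target)
    return float(np.mean((F1[test] @ w > 0).astype(int) == y[test]))

perm = rng.permutation(n)
train, test = perm[: n // 2], perm[n // 2:]
acc_features = ridge_accuracy(X, y, train, test)  # edges removed
acc_graph = ridge_accuracy(np.hstack([X, neighbor_mean(X, A)]), y, train, test)
print(f"node features only: {acc_features:.2f}   + neighbor mean: {acc_graph:.2f}")
```

On real data, a small gap between the two numbers is the signal that the graph may be decoration; a large gap justifies the extra machinery of a GNN.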
40.12 Graph Transformers and the global-attention alternative
Message passing’s two recurring headaches, over-smoothing and over-squashing (§40.7), both come from the same root cause: information can only travel along edges, one hop per layer. Graph Transformers take a different bet, borrowed wholesale from attention and Transformers (Chapter 17): let every node attend to every other node directly, the way a Transformer token attends to all tokens in a sequence. Two faraway nodes can then exchange information in a single layer, with no bottleneck edge in between.
The intuition: a message-passing GNN is like a rumor that has to physically pass person-to-person down a chain; a graph transformer is like a group video call where anyone can speak to anyone instantly. The catch is that a video call with \(n\) people costs \(O(n^2)\) attention — fine for a small molecule, ruinous for a million-node social graph.
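The quadratic cost is easy to see in code. Below is a minimal single-head global self-attention over node features in NumPy; the projection matrices are random placeholders rather than trained weights, and the optional `bias` argument is a hook of this sketch for adding structural terms. The score matrix has one entry per node pair, so memory and compute grow as \(n^2\).

```python
import numpy as np

def global_attention(H, Wq, Wk, Wv, bias=None):
    """Single-head self-attention where every node attends to every node."""
    Q, K, V = H @ Wq, H @ Wk, H @ Wv
    scores = Q @ K.T / np.sqrt(Q.shape[1])                # (n, n): one score per node pair
    if bias is not None:
        scores = scores + bias                            # optional additive bias per pair
    scores = scores - scores.max(axis=1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)         # row-wise softmax
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d = 6, 4
H = rng.normal(size=(n, d))                                 # toy node features
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))    # placeholder projections
out, weights = global_attention(H, Wq, Wk, Wv)
print(out.shape, weights.shape)  # (6, 4) (6, 6)
```

Notice that no adjacency matrix appears anywhere: the layer is blind to the graph unless structure is fed back in.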
But if you throw away the edges and let everyone attend to everyone, you also throw away the structure — the very signal that made it a graph. The fix is to feed structure back in as features:
- Positional / structural encodings. Just as a sequence Transformer adds positional encodings so it knows token order, a graph transformer adds encodings that describe where a node sits in the graph — most commonly the eigenvectors of the graph Laplacian (a spectral “address” for each node) or random-walk landing probabilities.
- Attention biases from the graph. Models like Graphormer add a learned bias to each attention score based on the shortest-path distance between the two nodes, plus edge-feature encodings — so nearby and bonded nodes still get a structural boost even though attention is global.
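Both ingredients can be sketched in NumPy. This assumes the symmetric-normalized Laplacian (one common choice) for the positional encodings and hop distances from breadth-first search for the attention bias. The per-distance bias values are hypothetical placeholders; in Graphormer-style models they are learned.

```python
import numpy as np
from collections import deque

def laplacian_pe(A, k):
    """k smallest non-trivial eigenvectors of the symmetric-normalized Laplacian."""
    deg = A.sum(axis=1)
    d_inv_sqrt = np.zeros_like(deg)
    d_inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5
    L = np.eye(len(A)) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
    _, eigvecs = np.linalg.eigh(L)  # eigenvalues in ascending order
    # Drop the first (trivial) eigenvector. Eigenvector signs are arbitrary, so
    # training pipelines often flip them randomly as augmentation.
    return eigvecs[:, 1:k + 1]

def shortest_path_distances(A):
    """All-pairs hop distances by BFS from every node (inf if unreachable)."""
    n = len(A)
    D = np.full((n, n), np.inf)
    for s in range(n):
        D[s, s] = 0
        queue = deque([s])
        while queue:
            u = queue.popleft()
            for v in np.flatnonzero(A[u]):
                if np.isinf(D[s, v]):
                    D[s, v] = D[s, u] + 1
                    queue.append(v)
    return D

# Toy path graph 0-1-2-3-4 (illustrative).
A = np.zeros((5, 5))
for i in range(4):
    A[i, i + 1] = A[i + 1, i] = 1.0

pe = laplacian_pe(A, k=2)  # append to node features: X_aug = [X, pe]
D = shortest_path_distances(A)

# Graphormer-style bias: one scalar per distance bucket (hypothetical placeholder
# values here; learned in a real model). Far or unreachable pairs share the last bucket.
spd_bias_table = np.array([0.0, 1.0, 0.5, 0.0, -0.5])
last = len(spd_bias_table) - 1
bucket = np.where(np.isinf(D), last, np.minimum(D, last)).astype(int)
attn_bias = spd_bias_table[bucket]  # (n, n), added to attention scores before softmax
print(D[0], pe.shape, attn_bias.shape)
```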
The payoff: graph transformers sidestep over-smoothing and over-squashing, and they have topped several molecular-property leaderboards (Graphormer, for instance, won the PCQM4M track of the 2021 OGB Large-Scale Challenge). The price: quadratic cost, a hunger for more data, and reliance on good structural encodings. The pragmatic middle ground, increasingly common, is a hybrid: interleave local message-passing layers (cheap, structure-aware) with occasional global-attention layers (long-range), so you pay the quadratic cost sparingly while still letting distant nodes talk.
```python
# A graph-transformer-style layer in PyG: attention over neighbors with edge features.
# TransformerConv is local (over edges) but multi-head + edge-aware, a common building block.
import torch
from torch_geometric.nn import TransformerConv
in_dim, hid, edge_feat_dim = 16, 32, 4                   # illustrative sizes
x = torch.randn(5, in_dim)                               # 5 toy nodes; Laplacian PEs can be appended here
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # 4 directed edges
edge_attr = torch.randn(edge_index.size(1), edge_feat_dim)
conv = TransformerConv(in_dim, hid, heads=4, edge_dim=edge_feat_dim)
h = conv(x, edge_index, edge_attr)  # shape [5, 4 * hid]: the 4 heads are concatenated
```

When to consider a graph transformer: small-to-medium graphs (molecules, small circuits, scene graphs) where long-range interactions matter and you have ample training data. For web-scale node-classification graphs, sampled message passing (§40.5) usually remains the practical choice; full global attention over millions of nodes is rarely affordable.
Quick reference
| Term / method | Meaning in one line | When / why |
|---|---|---|
| Permutation equivariance | \(f(PX, PAP^\top)=P f(X,A)\) — relabel nodes, outputs relabel too | The design constraint behind every GNN; node-level outputs |
| Permutation invariance | \(g(PX, PAP^\top)=g(X,A)\) — relabel nodes, output unchanged | Graph-level readout: the label can’t depend on numbering |
| DeepWalk / node2vec | Random walks as “sentences” → skip-gram embeddings | Quick node embeddings; transductive (no new nodes) |
| Message passing | AGGREGATE neighbors → UPDATE node, repeat \(L\) layers | The unifying recipe; receptive field grows 1 hop/layer |
| GCN | \(H'=\sigma(\tilde D^{-1/2}\tilde A\tilde D^{-1/2}HW)\), a degree-normalized mean | Fast strong baseline; whole graph in memory, neighbor weights fixed by degree |
| GraphSAGE | Sample \(k\) neighbors, CONCAT(self, agg) | Scales to huge graphs; inductive (new nodes work) |
| GAT | Attention weights \(\alpha_{vu}=\mathrm{softmax}(e_{vu})\) over neighbors | When some neighbors matter more than others; inductive |
| GIN | \(\mathrm{MLP}((1+\epsilon)h_v+\sum_u h_u)\) — sum aggregator | Max expressiveness (1-WL); graph-level tasks needing it |
| Readout | sum / mean / max pool of node vectors → graph vector | Collapse to one vector for graph-level prediction |
| Over-smoothing | Deep layers drive all node vectors to one value | Why GNNs stay 2–4 layers; use skip / jumping-knowledge |
| Over-squashing | Long-range messages crushed through bottleneck edges | Why long-range reasoning is hard; rewire or go global |
| 1-WL test | Color-refinement isomorphism check; bounds GNN power | The ceiling: message passing can’t beat it (e.g. can’t count triangles) |
| TransE | \(-\lVert h+r-t\rVert\) — relation as translation | KG link prediction; misses symmetry, 1-to-many |
| DistMult | \(\sum_i h_i r_i t_i\) — bilinear diagonal | KG with symmetric relations only |
| RotatE | \(-\lVert h\circ r-t\rVert\) in complex space — rotation | KG capturing symmetry/antisymmetry/inversion/composition |
| Relational deep learning | DB rows = nodes, foreign keys = edges; learn aggregates | Skip hand-engineered SQL features; heterogeneous GNN |
| Graph transformer | Global attention + structural/positional encodings | Small/medium graphs, long-range; quadratic cost |
Key takeaways
- Graphs break the assumptions of grid/sequence models — variable degree, no canonical order — so we need architectures that are permutation-invariant/equivariant by construction.
- Tasks come at node, edge, and graph level; the architecture’s body is shared, only the head (per-node, per-edge-pair, or pooled) changes.
- DeepWalk/node2vec turn random walks into “sentences” and run skip-gram — simple and effective but transductive (no new nodes without retraining).
- Message passing (AGGREGATE neighbors → UPDATE node) unifies modern GNNs. GCN = normalized mean; GraphSAGE = sampling + concat for inductive, scalable learning; GAT = learned attention weights over neighbors.
- Readout (sum/mean/max) collapses node vectors for graph-level tasks. Over-smoothing punishes naive depth — most GNNs are 2–4 layers, with skip/jumping-knowledge connections when more reach is needed; over-squashing separately caps long-range reasoning.
- GIN + the 1-WL test pin down expressiveness: message passing is at most as powerful as 1-WL, and reaching that bound requires an injective aggregator such as sum; mean and max throw away multiset information.
- KG embeddings score triples for link prediction: TransE (translation, misses symmetry), DistMult (only symmetric), RotatE (rotation, captures the main patterns).
- Relational deep learning treats a database as a graph (rows = nodes, foreign keys = edges) and learns the aggregates you used to hand-engineer — promising but newer than GBM-on-flat-features.
- Graph transformers swap edge-bound message passing for global attention plus structural encodings, dodging over-smoothing/over-squashing at quadratic cost — strong on small/medium graphs, often used in hybrid form.
- Always sanity-check that the graph structure adds real signal over node features alone before committing to a GNN.
See also
- Chapter 20 (Natural Language Processing) — skip-gram and negative sampling, the engine behind DeepWalk/node2vec.
- Chapter 17 (Attention & Transformers) — the attention mechanism GAT imports onto edges, and the global attention graph transformers (§40.12) generalize; Transformers are essentially GNNs on a fully connected graph.
- Chapter 15 (Convolutional Neural Networks) — the grid-convolution intuition GCN generalizes to irregular neighborhoods.
- Chapter 10 (Ensemble Methods) — the strong tabular baseline (gradient-boosted trees) relational deep learning competes with.
- Chapter 26 (Recommender Systems) — the user–item graphs where web-scale GNNs (PinSAGE) shine.
- Chapter 39 (Frontier & Emerging Directions) — higher-order and geometric deep learning that push past the 1-WL expressiveness ceiling.
↪ The thread continues → Chapter 41 · 🤖 Robotics & Autonomy
Graphs let models reason over structure; embodiment lets them act in the physical world. Robotics adds perception, state estimation, planning, and control to learning.