flowchart LR
X[Input X] --> H1[Head 1: prev-token]
X --> H2[Head 2: subject↔verb]
X --> H3[Head 3: coreference]
X --> H4[Head ...]
H1 --> C[Concat]
H2 --> C
H3 --> C
H4 --> C
C --> WO[Linear W_O] --> Out[Output]
Chapter 17 — ⚡ Attention & Transformers
Almost every state-of-the-art model you have heard of — GPT, BERT, Stable Diffusion’s text encoder, AlphaFold’s Evoformer — runs on one idea: attention. Attention lets a model decide, for each piece of its input, which other pieces matter and pull in exactly that information. The Transformer is the architecture built entirely from attention (plus a couple of ordinary layers), and it displaced the recurrent networks that dominated sequence modeling because it trains faster and reasons across longer distances. This chapter sits at the heart of modern Deep Learning: it is the bridge from the sequence models of Chapter 16 to the Large Language Models of Chapter 23.
🧭 In context: Deep Learning architectures · used to model sequences, language, vision, and almost everything else · the one key idea is attention as a learned, dynamic weighted lookup over the whole input at once.
💡 Remember this: Attention is a similarity-weighted average of values — a query scores every key, softmax turns those scores into weights, and the output blends the values; everything else (heads, masking, position, the KV cache) is engineering around that one operation.
17.1 — The attention mechanism
The cleanest way to understand attention is as a soft dictionary lookup. A normal Python dictionary is a hard lookup: you hand it a key, it returns the one value stored under the exact matching key. Attention softens this. You hand it a query, it compares that query against every key, and returns a blend of all the values — weighted by how well each key matched. Nothing is all-or-nothing; everything contributes a little, the best matches contribute a lot.
Three roles, all just vectors:
- Query (q) — what the current position is looking for.
- Key (k) — what each position advertises about itself, used for matching.
- Value (v) — the actual content each position will hand over if attended to.
The match score between a query and a key is their dot product \(q \cdot k\): large and positive when the vectors point the same way, near zero when orthogonal. We score the query against all keys, turn those scores into a probability distribution with softmax (so the weights are positive and sum to 1), and use that distribution to take a weighted average of the values.
For a whole sequence we stack the queries into a matrix \(Q\) (one row per position), and likewise \(K\) and \(V\). The entire operation is one formula — scaled dot-product attention:
\[ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]
In words: score every query against every key (the \(QK^\top\) table), shrink the scores so softmax stays well-behaved, turn each row of scores into weights that sum to 1, and use those weights to mix the values. Also written: per output row \(i\), \(o_i = \sum_j \text{softmax}_j\!\left(\frac{q_i \cdot k_j}{\sqrt{d_k}}\right) v_j\) — a weighted average of value vectors.
Here \(d_k\) is the dimension of the key/query vectors. \(QK^\top\) is the matrix of all pairwise scores; the softmax is applied row by row; the result multiplies \(V\).
Why divide by \(\sqrt{d_k}\)? This scaling is the part people forget, and it matters. If the entries of \(q\) and \(k\) are roughly independent with zero mean and unit variance, their dot product \(q\cdot k = \sum_{i=1}^{d_k} q_i k_i\) has variance \(\approx d_k\), so its typical magnitude grows like \(\sqrt{d_k}\). For \(d_k = 64\), raw scores have a standard deviation of about 8, so values of \(\pm 8\) or more are routine. Feed scores that large into softmax and it saturates: one weight goes to ~1, the rest to ~0, and the gradient through softmax becomes nearly flat (vanishing gradients, so little learning). Dividing by \(\sqrt{d_k}\) rescales the scores back to roughly unit variance, keeping softmax in its responsive region.
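A quick numerical check of this argument, as a sketch: it assumes queries and keys with independent standard-normal entries (the idealized case above) and uses random data, so the exact numbers depend on the seed. It measures the variance of raw and scaled dot products, then compares the largest softmax weight one query assigns to ten keys with and without the \(1/\sqrt{d_k}\) factor.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x):
    # numerically stable softmax over the last axis
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

d_k = 64
# 1000 independent (q, k) pairs with unit-variance entries, as assumed in the text
q = rng.standard_normal((1000, d_k))
k = rng.standard_normal((1000, d_k))
raw = (q * k).sum(-1)            # 1000 raw dot products q·k
scaled = raw / np.sqrt(d_k)      # the same scores divided by sqrt(d_k)
print(f"var(raw)    = {raw.var():.1f}")     # close to d_k = 64
print(f"var(scaled) = {scaled.var():.2f}")  # close to 1

# One query against 10 keys: how peaked is the softmax with and without scaling?
query = rng.standard_normal(d_k)
keys = rng.standard_normal((10, d_k))
w_raw = softmax(keys @ query)
w_scaled = softmax(keys @ query / np.sqrt(d_k))
print(f"largest weight, raw:    {w_raw.max():.3f}")
print(f"largest weight, scaled: {w_scaled.max():.3f}")
```

The raw largest weight is never smaller than the scaled one (sharpening the logits can only concentrate the distribution), and at \(d_k = 64\) it is typically much closer to 1.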
And here is the same idea in motion — watch a query sweep across the keys, the matching one brightening as its softmax weight swells:
Worked example. Take \(d_k = 2\), one query, three key/value pairs.
\[ q = [1, 0], \quad K = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix}, \quad V = \begin{bmatrix} 10 \\ 20 \\ 30 \end{bmatrix} \]
Raw scores \(qK^\top = [1,\ 0,\ 1]\). Scale by \(\sqrt{2}\approx1.414\): \([0.71,\ 0,\ 0.71]\). Softmax: \(e^{0.71}{=}2.03,\ e^{0}{=}1,\ e^{0.71}{=}2.03\), sum \(=5.06\), giving weights \([0.40,\ 0.20,\ 0.40]\). Output \(=0.40(10)+0.20(20)+0.40(30)=20.0\). The query matched key 1 and key 3 equally well (and key 2 less), so the output blends their values accordingly.
```python
import numpy as np

def attention(Q, K, V):
    dk = K.shape[-1]
    scores = Q @ K.T / np.sqrt(dk)                      # pairwise match, scaled
    w = np.exp(scores - scores.max(-1, keepdims=True))  # subtract row max for stability
    w /= w.sum(-1, keepdims=True)                       # softmax rows -> sum to 1
    return w @ V                                        # weighted blend of values

Q = np.array([[1., 0.]]); K = np.array([[1., 0.], [0., 1.], [1., 1.]])
V = np.array([[10.], [20.], [30.]])
print(attention(Q, K, V))  # ~[[20.0]]
```

And the same operation in PyTorch. This single built-in call is what production code typically uses; depending on hardware, dtype, and inputs, it can dispatch to a fused, memory-efficient kernel under the hood:
```python
import torch
import torch.nn.functional as F

Q = torch.randn(1, 8, 10, 64)  # (batch, heads, seq, d_k)
K = torch.randn(1, 8, 10, 64)
V = torch.randn(1, 8, 10, 64)
# scaled dot-product attention, fused (PyTorch 2.0+); pass is_causal=True for a decoder
out = F.scaled_dot_product_attention(Q, K, V)  # (1, 8, 10, 64)
print(out.shape)
```

Mantra: attention is a similarity-weighted average of values. Q vs K decides how much, V decides of what. Everything else (scaling, masking, heads) is engineering around that core.
17.2 — Self-attention, multi-head attention & cross-attention
In the previous section \(Q\), \(K\), and \(V\) arrived ready-made. Where do they come from? In self-attention, all three are learned linear projections of the same input sequence \(X\). Each token produces its own query, key, and value:
\[ Q = XW_Q, \quad K = XW_K, \quad V = XW_V \]
In words: the same input \(X\) is run through three different learned “lenses” — one that asks (Q), one that advertises (K), one that carries content (V). Also written: row-wise, \(q_t = W_Q^\top x_t\), \(k_t = W_K^\top x_t\), \(v_t = W_V^\top x_t\) for each token \(x_t\).
where \(W_Q, W_K, W_V\) are weight matrices learned by gradient descent. “Self” means every position attends to every position in the same sequence — so the word “it” can reach back and pull meaning from “the animal” earlier in the sentence. This is how a Transformer builds context: each token’s new representation is a blend of all tokens, weighted by relevance.
The picture below traces one token’s query fanning out to every key in its own sequence, then collapsing back into one context-rich output vector.
Multi-head attention. A single attention pass produces one set of weights — one “opinion” about what is relevant. But a token may relate to others in several different ways at once: syntactically, semantically, positionally. Think of it like a committee reading the same sentence: one member tracks grammar, another tracks meaning, another tracks who-refers-to-whom — then they pool their notes. Multi-head attention runs \(h\) attention operations in parallel, each with its own projections \(W_Q^{(i)}, W_K^{(i)}, W_V^{(i)}\) into a smaller subspace of dimension \(d_k = d_{\text{model}}/h\). Each head learns to focus on a different kind of relationship; their outputs are concatenated and passed through a final projection \(W_O\):
\[ \text{MultiHead}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\,W_O, \qquad \text{head}_i = \text{Attention}(XW_Q^{(i)}, XW_K^{(i)}, XW_V^{(i)}) \]
In words: run \(h\) separate small attention computations, glue their outputs side by side, then mix them with one final linear layer. Also written: \(\text{MultiHead}(X) = \sum_{i=1}^{h} \text{head}_i\, W_O^{(i)}\), where \(W_O\) is split row-wise into per-head blocks \(W_O^{(i)}\) — concatenation-then-project is the same as project-each-then-sum.
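The "also written" identity is easy to confirm numerically. A minimal sketch with toy sizes, where random matrices stand in for the head outputs and for \(W_O\) (nothing here comes from a trained model):

```python
import numpy as np

rng = np.random.default_rng(1)
n, h, d_k, d_model = 5, 4, 3, 12   # toy sizes; concatenated width is h * d_k = 12
heads = [rng.standard_normal((n, d_k)) for _ in range(h)]  # stand-ins for head_i
W_O = rng.standard_normal((h * d_k, d_model))

# Form 1: concatenate the heads side by side, then project once.
concat_then_project = np.concatenate(heads, axis=-1) @ W_O

# Form 2: split W_O row-wise into h blocks of d_k rows; block i multiplies head i.
W_O_blocks = np.split(W_O, h, axis=0)
project_then_sum = sum(head @ W for head, W in zip(heads, W_O_blocks))

print(np.allclose(concat_then_project, project_then_sum))  # True
```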
With \(d_{\text{model}}=512\) and \(h=8\), each head works in 64 dimensions, so multi-head attention costs about the same as one full-width head but buys eight independent views. In practice, some trained heads specialize in interpretable ways: one tracks the previous token, another links verbs to their subjects, another follows coreference.
```python
# multi-head attention, looping over heads (reuses attention() from 17.1)
def multihead(X, Wq, Wk, Wv, Wo, h):
    _, d = X.shape
    dk = d // h                                   # per-head width
    Q, K, V = X @ Wq, X @ Wk, X @ Wv              # (n, d) each
    out = []
    for i in range(h):                            # slice each head's subspace
        s = slice(i * dk, (i + 1) * dk)
        out.append(attention(Q[:, s], K[:, s], V[:, s]))
    return np.concatenate(out, -1) @ Wo           # join heads, then mix
```

In real code you rarely hand-roll this; PyTorch ships the whole block:
```python
import torch, torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)            # (batch, seq, d_model)
out, attn_weights = mha(x, x, x)       # query=key=value=x -> self-attention
print(out.shape, attn_weights.shape)   # (2,10,512) (2,10,10); weights averaged over heads by default
```

Cross-attention. Self-attention takes \(Q\), \(K\), \(V\) from one sequence. Cross-attention takes the queries from one sequence and the keys/values from another. This is how a decoder consults an encoder: in translation, each output position (query) looks over the entire source sentence (keys/values) to decide what to attend to. Same formula, two sources: it is the mechanism that lets one modality or sequence read from another (and reappears in Multimodal AI, Chapter 24). In nn.MultiheadAttention above, you get cross-attention simply by passing a different sequence as the query: mha(decoder_x, encoder_x, encoder_x).
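To make the shapes concrete, here is a single-head cross-attention sketch in NumPy. The sequence lengths, the width, and the random projection matrices are illustrative assumptions; the point is that the weight matrix is (positions of A) × (positions of B), so the two sequences need not have the same length.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def cross_attention(A, B, Wq, Wk, Wv):
    # queries come from sequence A; keys and values come from sequence B
    Q, K, V = A @ Wq, B @ Wk, B @ Wv
    weights = softmax(Q @ K.T / np.sqrt(K.shape[-1]))  # (len_A, len_B)
    return weights @ V, weights

rng = np.random.default_rng(2)
d = 16
A = rng.standard_normal((4, d))   # e.g. 4 decoder positions
B = rng.standard_normal((9, d))   # e.g. 9 encoder positions
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
out, weights = cross_attention(A, B, Wq, Wk, Wv)
print(out.shape, weights.shape)   # (4, 16) (4, 9)
```

Each of the 4 query rows gets its own distribution over the 9 encoder positions, and the output has one row per query.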
| Variant | Queries from | Keys / values from | Typical use |
|---|---|---|---|
| Self-attention | sequence \(X\) | sequence \(X\) | building context within one sequence |
| Cross-attention | sequence \(A\) | sequence \(B\) | decoder reading encoder; text reading image |
| Multi-head | same as above, \(\times h\) subspaces | same | several relationship types at once |
A common confusion: heads do not see each other during attention — they run independently and only meet at the concat + \(W_O\) step. If you accidentally share projection weights across heads, every head computes the same thing and multi-head collapses to single-head. Keep the per-head projections distinct.
17.3 — The Transformer block
Attention mixes information across positions, but it is entirely linear in the values and has no per-token nonlinearity. The Transformer block wraps attention with the rest of what a deep network needs. One block, in order:
- Multi-head self-attention — mixes information across tokens.
- Residual connection + layer normalization — add the input back, then normalize.
- Feed-forward network (FFN) — a small MLP applied to each token independently: \(\text{FFN}(x) = \max(0,\ xW_1 + b_1)W_2 + b_2\), usually expanding to \(4\times\) width and back. This is where per-token nonlinear “thinking” happens.
- Residual + layer norm again.
Here is the FFN formula once more, with a plain-language reading and its common variants:
\[\text{FFN}(x) = \max(0,\ xW_1 + b_1)\,W_2 + b_2\]
In words: blow each token’s vector up to a wider space, keep only the positive parts (ReLU), then squeeze it back down — a per-token nonlinear transformation. Also written: \(\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2\); modern models often swap ReLU for GELU or a gated SwiGLU unit, \(\text{FFN}(x) = \big(\text{SiLU}(xW_g) \odot (xW_1)\big)W_2\).
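Both FFN variants fit in a few lines. This is a sketch with random weights at the sizes used in the worked example later in this section (\(d_{\text{model}}=512\), hidden width 2048); reusing \(W_1\) in both functions and dropping biases from the SwiGLU version are simplifications for illustration. Because every operation acts row by row, each token is transformed independently of the others.

```python
import numpy as np

def ffn_relu(x, W1, b1, W2, b2):
    # FFN(x) = max(0, x W1 + b1) W2 + b2, applied to every row (token) independently
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

def silu(z):
    return z / (1.0 + np.exp(-z))    # SiLU(z) = z * sigmoid(z)

def ffn_swiglu(x, Wg, W1, W2):
    # FFN(x) = (SiLU(x Wg) ⊙ (x W1)) W2, biases omitted as in the formula above
    return (silu(x @ Wg) * (x @ W1)) @ W2

rng = np.random.default_rng(3)
n, d, d_ff = 10, 512, 2048
x = rng.standard_normal((n, d))
W1 = rng.standard_normal((d, d_ff)) * 0.02
W2 = rng.standard_normal((d_ff, d)) * 0.02
Wg = rng.standard_normal((d, d_ff)) * 0.02
b1, b2 = np.zeros(d_ff), np.zeros(d)

print(ffn_relu(x, W1, b1, W2, b2).shape)   # (10, 512)
print(ffn_swiglu(x, Wg, W1, W2).shape)     # (10, 512)
```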
The residual connection (\(x + \text{sublayer}(x)\)) gives gradients a clean highway straight through the network, which is what makes stacking dozens of blocks trainable (the same idea as ResNet, Chapter 15). Layer normalization rescales each token’s vector to zero mean and unit variance across its features, then applies a learned per-feature scale and shift; this stabilizes the distribution that each sublayer sees. Stack \(N\) identical blocks and you have the body of a Transformer.
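Layer normalization itself is short enough to write out. A minimal NumPy sketch: `gamma` and `beta` are the learned per-feature scale and shift (set to ones and zeros here, their usual initial values), and `eps=1e-5` with the biased variance mirrors the defaults of PyTorch’s `nn.LayerNorm`.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # normalize each token (row) across its features, then apply learned scale/shift
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)          # biased variance (ddof=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(4)
x = rng.standard_normal((3, 8)) * 5 + 2     # 3 tokens, 8 features, off-center and wide
y = layer_norm(x, gamma=np.ones(8), beta=np.zeros(8))
print(y.mean(-1).round(6))   # ~0 for every token
print(y.std(-1).round(3))    # ~1 for every token
```

Note that the statistics are computed per token, across features, never across the batch or the sequence.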
A note on Pre-LN vs Post-LN. The original Transformer placed layer norm after each residual add (Post-LN). Modern Transformers (GPT-2 onward) put it before each sublayer (Pre-LN: \(x + \text{sublayer}(\text{LN}(x))\)), which keeps the residual path completely clean and makes very deep stacks trainable without a learning-rate warmup crutch. The diagram below shows the common Pre-LN arrangement.
flowchart TB
X[Token embeddings + positional encoding] --> L1[LayerNorm]
L1 --> A[Multi-Head Self-Attention]
X -->|residual| R1((+))
A --> R1
R1 --> L2[LayerNorm]
L2 --> F[Feed-Forward Network]
R1 -->|residual| R2((+))
F --> R2
R2 --> O[Block output → next block]
A Pre-LN block is only a few lines in PyTorch:
```python
import torch, torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d=512, h=8, mult=4):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, h, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d, mult * d), nn.GELU(),
                                 nn.Linear(mult * d, d))

    def forward(self, x, mask=None):
        h = self.ln1(x)                   # Pre-LN: normalize once, reuse for Q, K, V
        a, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        x = x + a                         # residual around attention
        x = x + self.ffn(self.ln2(x))     # residual around FFN
        return x

print(TransformerBlock()(torch.randn(2, 10, 512)).shape)  # (2, 10, 512)
```

Worked example — the shapes. Take \(d_{\text{model}}=512\), FFN hidden \(=2048\), sequence length \(n=10\). The input is \(X \in \mathbb{R}^{10\times512}\). Self-attention keeps the shape at \(10\times512\). The FFN lifts each of the 10 token vectors \(512 \to 2048\), applies the nonlinearity (ReLU in the formula above, GELU in the code), and projects \(2048 \to 512\), back to \(10\times512\). Every sublayer is shape-preserving, which is exactly why blocks stack: the output of block \(k\) is a drop-in input to block \(k{+}1\).
Encoder vs decoder. The original Transformer had two stacks. The encoder reads the whole input at once with unmasked self-attention — every token sees every other token, ideal for understanding tasks (BERT is an encoder). The decoder generates output one token at a time and adds two things: causal masking in its self-attention, and a cross-attention sublayer that reads the encoder’s output. Three families result: encoder-only (BERT, classification/embedding), decoder-only (GPT, generation — Chapter 23), and encoder-decoder (T5, original translation Transformer).
flowchart LR
subgraph ENC[Encoder stack]
E1[Self-attn unmasked] --> E2[FFN]
end
subgraph DEC[Decoder stack]
D1[Masked self-attn] --> D2[Cross-attn] --> D3[FFN]
end
src[Source tokens] --> ENC
ENC -->|K,V| D2
tgt[Target so far] --> DEC
DEC --> out[Next-token logits]
Causal masking. When a decoder predicts token \(t\), it must not peek at tokens \(t{+}1, t{+}2, \dots\) — those are the future it is trying to generate. Causal (autoregressive) masking enforces this by setting the attention scores for future positions to \(-\infty\) before softmax, so their weights become exactly 0:
\[ \text{mask}_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases} \]
In words: a token may attend to itself and everything before it, but the future is blanked out so softmax assigns it zero weight. Also written: the masked score matrix is \(\frac{QK^\top}{\sqrt{d_k}} + M\) where \(M\) is lower-triangular-zero / upper-triangular-\((-\infty)\); equivalently scores.masked_fill(j > i, -inf).
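Here is the mask in action, as a small NumPy sketch with random \(Q\), \(K\), \(V\) (illustrative values only). `np.triu(..., k=1)` builds \(M\) with \(-\infty\) strictly above the diagonal and 0 elsewhere; after softmax those future weights are exactly 0, and the first token, which can only see itself, simply copies its own value.

```python
import numpy as np

def causal_attention(Q, K, V):
    n = Q.shape[0]
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    # M: 0 on and below the diagonal, -inf above it (the future)
    M = np.triu(np.full((n, n), -np.inf), k=1)
    scores = scores + M
    w = np.exp(scores - scores.max(-1, keepdims=True))   # exp(-inf) = 0
    w /= w.sum(-1, keepdims=True)
    return w @ V, w

rng = np.random.default_rng(5)
Q, K, V = (rng.standard_normal((5, 4)) for _ in range(3))
out, w = causal_attention(Q, K, V)
print(np.round(w, 2))   # lower-triangular: every entry above the diagonal is 0
```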
Why it beat RNNs. Two reasons dominate. Parallelism: an RNN processes a sequence step by step — token \(t\) must wait for token \(t{-}1\) — so training cannot be parallelized over the time axis. Self-attention computes all positions’ interactions in one big matrix multiply, fully parallel on a GPU, which is why Transformers scale to billions of parameters. Long-range dependencies: in an RNN, information from token 1 must survive being squeezed through every intermediate hidden state to reach token 1000 — a path of length 1000 where signal decays (the vanishing-gradient problem of Chapter 16). In self-attention, any two tokens are connected by a single step: the path length is \(O(1)\) regardless of distance. The cost is that attention is \(O(n^2)\) in sequence length \(n\) — every pair is compared — which is the central efficiency challenge addressed in Chapter 30.
| Property | RNN / LSTM | Transformer (self-attention) |
|---|---|---|
| Train-time parallelism over sequence | No (sequential) | Yes (one matmul) |
| Path length between distant tokens | \(O(n)\) | \(O(1)\) |
| Compute per layer | \(O(n \cdot d^2)\) | \(O(n^2 \cdot d)\) |
| Handles very long context cheaply | Yes (linear) | No (quadratic) |
| Inherent sense of order | Yes (built in) | No (must be injected) |
That last row is the catch that the next section resolves.
17.4 — Positional encoding
Self-attention has a subtle blind spot: it is permutation-equivariant. Because it computes a weighted average over a set of tokens, shuffling the input tokens just shuffles the outputs identically; the mechanism itself has no notion of “first” or “next.” In “dog bites man” and “man bites dog,” each word would get exactly the same output vector in both sentences, only listed in a different order, so the model could not tell who bit whom. Since order carries meaning in language, code, and time series, order has to be injected explicitly. That is the job of positional encoding: give each position a distinctive vector and combine it with the token embedding (usually by adding) so the model can tell positions apart.
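The blind spot is easy to demonstrate. In this sketch, three random vectors stand in for the embeddings of “dog”, “bites”, “man”, and self-attention is stripped down to \(Q=K=V=X\) with no positional information. Learned projections act on each token separately, so adding them would not change the conclusion.

```python
import numpy as np

def self_attention(X):
    # bare self-attention with Q = K = V = X (no projections, no positions)
    scores = X @ X.T / np.sqrt(X.shape[-1])
    w = np.exp(scores - scores.max(-1, keepdims=True))
    w /= w.sum(-1, keepdims=True)
    return w @ X

rng = np.random.default_rng(6)
X = rng.standard_normal((3, 8))   # stand-ins for "dog", "bites", "man"
perm = [2, 1, 0]                  # reorder to "man", "bites", "dog"
out = self_attention(X)
out_perm = self_attention(X[perm])
print(np.allclose(out_perm, out[perm]))   # True: shuffled inputs give shuffled outputs
```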
Sinusoidal encoding (the original Transformer). Position \(pos\) and feature dimension \(i\) get a fixed value from sines and cosines of geometrically increasing wavelengths:
\[ PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right) \]
In words: each pair of feature slots is a clock hand spinning at its own speed; reading all the hands at once gives a unique time-stamp for every position. Also written: with angular frequency \(\omega_i = 10000^{-2i/d}\), the pair is \(\big(\sin(pos\,\omega_i),\ \cos(pos\,\omega_i)\big)\) — a point on the unit circle at angle \(pos\,\omega_i\).
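Each dimension is a sinusoid of a different frequency, so the full vector is like a binary-clock fingerprint of the position: low dimensions tick fast, high dimensions tick slowly. A neat property: the encoding for position \(pos+k\) is a fixed linear function of the encoding at \(pos\) (a rotation of each sine/cosine pair that depends only on \(k\)), so the model can learn to attend by relative offset. And because it is a deterministic function, it can be computed for any position, including lengths never seen in training; whether the model actually performs well at those lengths is another matter, and in practice quality tends to degrade well past the training length.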
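The "fixed linear function" claim can be checked directly. In this sketch, `shift_matrix` is a helper written only for illustration: it builds a block-diagonal matrix \(T_k\) from 2×2 rotations, one per frequency pair, using the angle-addition identities for sine and cosine. \(T_k\) depends on the offset \(k\) but not on the position. The block repeats the `sinusoidal_pe` function given below so that it runs on its own.

```python
import numpy as np

def sinusoidal_pe(n, d):
    pos = np.arange(n)[:, None]
    i = np.arange(0, d, 2)[None, :]          # the even indices 2i
    angle = pos / (10000 ** (i / d))
    pe = np.zeros((n, d))
    pe[:, 0::2], pe[:, 1::2] = np.sin(angle), np.cos(angle)
    return pe

def shift_matrix(k, d):
    # block-diagonal T_k with PE[pos + k] = T_k @ PE[pos] for every pos
    T = np.zeros((d, d))
    for j in range(d // 2):
        a = k / 10000 ** (2 * j / d)
        # sin(x+a) = cos(a) sin(x) + sin(a) cos(x);  cos(x+a) = -sin(a) sin(x) + cos(a) cos(x)
        T[2*j:2*j+2, 2*j:2*j+2] = [[np.cos(a), np.sin(a)], [-np.sin(a), np.cos(a)]]
    return T

d, k = 16, 7
pe = sinusoidal_pe(200, d)
T = shift_matrix(k, d)
print(np.allclose(pe[k:], pe[:-k] @ T.T))   # True for every position at once
```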
Worked example — telling two positions apart. Take \(d=4\), so two frequency pairs. The denominators are \(10000^{0/4}=1\) and \(10000^{2/4}=100\). Position 0 gives \([\sin 0, \cos 0, \sin 0, \cos 0] = [0, 1, 0, 1]\). Position 1 gives \([\sin 1, \cos 1, \sin(0.01), \cos(0.01)] \approx [0.84, 0.54, 0.01, 1.00]\). The fast pair (dims 0–1) swings a lot between adjacent positions; the slow pair (dims 2–3) barely moves — together they form a unique code per position that the model adds onto each token embedding.
```python
import numpy as np

def sinusoidal_pe(n, d):
    # assumes d is even; returns an (n, d) matrix of position codes
    pos = np.arange(n)[:, None]              # (n, 1)
    i = np.arange(0, d, 2)[None, :]          # even indices 2i, shape (1, d/2)
    angle = pos / (10000 ** (i / d))         # (n, d/2)
    pe = np.zeros((n, d))
    pe[:, 0::2] = np.sin(angle)              # even dims: sine
    pe[:, 1::2] = np.cos(angle)              # odd dims: cosine
    return pe                                # add to token embeddings
```

Learned positional embeddings (BERT, GPT-2). Instead of a fixed formula, keep a trainable embedding vector per position, learned like any other parameter. This is simple and performs about as well in-distribution (the original Transformer paper reported nearly identical results for learned and sinusoidal encodings), but it cannot generalize past the maximum length seen in training: if you only trained to 2048 positions, position 2049 simply has no learned vector. In PyTorch this is literally one nn.Embedding(max_len, d_model) indexed by position.
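A sketch of the learned variant, with a randomly initialized table standing in for trained weights (`pos_table`, `max_len = 2048`, and the helper `add_learned_positions` are hypothetical names for illustration). Looking up a position past the end of the table fails outright, which is the hard length cap described above.

```python
import numpy as np

rng = np.random.default_rng(7)
max_len, d = 2048, 64
# hypothetical learned table: one row per position (random here; training would fit it)
pos_table = rng.standard_normal((max_len, d)) * 0.02

def add_learned_positions(tok_emb):
    # hypothetical helper: add the row for position 0, 1, ..., n-1 to each token
    n = tok_emb.shape[0]
    return tok_emb + pos_table[np.arange(n)]

short = add_learned_positions(rng.standard_normal((10, d)))
print(short.shape)   # (10, 64)

overflowed = False
try:
    add_learned_positions(rng.standard_normal((max_len + 1, d)))
except IndexError:
    overflowed = True   # index 2048 (the 2049th position) has no row in the table
print("length cap hit:", overflowed)
```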
Rotary Position Embedding (RoPE) (LLaMA, GPT-NeoX, most modern LLMs). RoPE takes a different tack: instead of adding a position vector, it rotates the query and key vectors by an angle proportional to their position, in 2D feature pairs. Because the dot product of two rotated vectors depends only on the difference of their rotation angles, the resulting attention score naturally encodes relative position — token distance, not absolute index. This makes RoPE work directly inside the attention dot product, behave well on long contexts, and extrapolate via tricks like frequency scaling. It is the de facto standard in today’s large language models (Chapter 23), and §17.5 develops it in full.
| Scheme | Absolute / relative | Extrapolates? | Where applied | Used by |
|---|---|---|---|---|
| Sinusoidal | Absolute (relative-friendly) | Poorly (defined at any length, but quality degrades) | Added to embeddings | Original Transformer |
| Learned | Absolute | No (capped at max len) | Added to embeddings | BERT, GPT-2 |
| RoPE | Relative | Yes (with scaling) | Rotates Q, K in attention | LLaMA, most modern LLMs |
Forgetting positional information is a silent bug, not a crash: the model trains and runs, it just plateaus at mediocre accuracy because it literally cannot tell word order. If a from-scratch Transformer learns suspiciously poorly on ordered data, check that you actually added positional encoding before the first block.
17.5 — Relative position and Rotary Position Embeddings (RoPE)
Sinusoidal and learned positional encodings (§17.4) share one assumption: they tell the model where each token sits in an absolute sense (token 0, token 1, token 2) by adding a position-specific vector to each embedding. That works, but it often answers the wrong question. When you read “the cat that the dog chased,” what matters for linking “chased” to “dog” is how far apart the two words are, not whether the sentence started at index 5 or index 500. For many patterns like this, what the model needs is the distance between two tokens, not their absolute slots.
This is the idea behind relative position encoding: make attention depend on \((n - m)\), the gap between a query at position \(m\) and a key at position \(n\), rather than on \(m\) and \(n\) separately. A model built this way tends to generalize better across positions: a pattern learned at the start of a sequence still holds halfway through, because only the relative offset enters the computation.
The RoPE idea: rotate, don’t add
Rotary Position Embeddings (RoPE) achieve relative position through a beautifully simple trick. Instead of adding a position vector to the embedding, RoPE rotates the query and key vectors by an angle proportional to their position.
Group the \(d\)-dimensional query and key into \(d/2\) pairs of coordinates, so each pair \((x_1, x_2)\) is a little 2D vector. For a token at position \(m\), rotate each pair by angle \(m\theta_i\), where \(\theta_i\) is a fixed frequency that differs from pair to pair:
\[R(m\theta_i)\begin{bmatrix} x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{bmatrix}\begin{bmatrix} x_1 \\ x_2 \end{bmatrix}\]
In words: take each pair of coordinates as a 2D arrow and spin it by an angle that grows with the token’s position — far-along tokens get more twist. Also written: in complex form, treat \(z = x_1 + i x_2\); then the rotation is just \(z \mapsto z\,e^{i m \theta_i}\) (multiply by a unit complex number).
Low-index pairs use a high frequency (they spin fast as position grows); high-index pairs use a low frequency (they barely turn). The standard choice is \(\theta_i = 10000^{-2i/d}\) — the same geometric spread of wavelengths used by sinusoidal encoding, but applied as rotation rather than addition.
A single token’s vector is rotated by its own absolute angle \(m\theta_i\). The relative term only shows up later, in the dot product between a query and a key — not in how any one vector is turned.
The magic: the dot product becomes relative
Here is the plain-English version. Think of the query and key as two clock hands. RoPE turns the query’s hand by an amount set by its position, and the key’s hand by an amount set by its position. The attention score is the dot product, which depends on the angle between the two hands. Turn both hands by the same starting offset and the gap between them does not change — so only the difference of the two positions survives. Where they each started cancels out.
In symbols: rotate the query by \(m\theta_i\) and the key by \(n\theta_i\), and within each 2D pair the dot product depends only on the difference of the two angles:
\[\big(R(m\theta_i)\,q\big) \cdot \big(R(n\theta_i)\,k\big) = q^\top R\big((n-m)\theta_i\big)\,k\]
In words: rotating both vectors and then taking their dot product is the same as rotating just one of them by the gap between the two positions — the absolute positions cancel out. Also written: since rotations satisfy \(R(a)^\top R(b) = R(b - a)\), we have \((R(m)q)^\top(R(n)k) = q^\top R(m)^\top R(n)\,k = q^\top R(n-m)\,k\).
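Both forms of the identity can be checked for a single 2D pair. In this sketch the frequency \(\theta\), the positions \(m, n\), and the vectors are arbitrary illustrative values; the complex check relies on the fact that for 2D vectors \(a \cdot b = \operatorname{Re}(z_a \bar{z}_b)\).

```python
import numpy as np

def R(a):
    # 2x2 rotation matrix by angle a
    return np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]])

rng = np.random.default_rng(8)
q, k = rng.standard_normal(2), rng.standard_normal(2)
theta, m, n = 0.3, 4, 11

lhs = (R(m * theta) @ q) @ (R(n * theta) @ k)   # rotate both, then dot
rhs = q @ R((n - m) * theta) @ k                # rotate by the gap only
print(np.isclose(lhs, rhs))   # True

# Complex form: rotation is z -> z * exp(i * pos * theta); dot product is Re(z_q * conj(z_k)).
zq, zk = complex(*q), complex(*k)
lhs_c = (zq * np.exp(1j * m * theta) * np.conj(zk * np.exp(1j * n * theta))).real
print(np.isclose(lhs, lhs_c))  # True
```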
The absolute positions \(m\) and \(n\) cancel; only \((n - m)\) survives. So even though RoPE applies an absolute rotation to each token independently (cheap, no cross-token bookkeeping), the attention it produces is automatically relative. You get relative-position behavior for free, without ever computing a relative offset.
Why RoPE is the modern default
RoPE is the position method in Llama, Mistral, Qwen, DeepSeek, Gemma, and most current open models. The reasons are practical:
- Zero extra parameters — the rotation angles are fixed by the formula, nothing to learn.
- Applied where it belongs — rotation happens on \(q\) and \(k\) inside each attention head, leaving the value vectors and the residual stream untouched.
- Relative by construction — the distance-only property above is exactly what language modeling wants.
- Long-context friendly — because behavior keys off relative distance, the same model handles much longer sequences gracefully, and the rotation can be re-scaled to stretch even further.
Stretching RoPE past its training length
A model trained at, say, 4K tokens has only ever seen rotation angles up to \(4096\,\theta_i\). Feed it 32K tokens and the angles run off into territory it never saw, and quality collapses. A family of lightweight rescaling tricks — Position Interpolation, NTK-aware scaling, and YaRN — fixes this by squashing or stretching the rotation angles back into the trained range, letting a RoPE model reach far beyond its original length with little or no retraining.
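A minimal sketch of the simplest of these tricks, Position Interpolation, assuming it is implemented as a linear rescale of positions by `train_len / n` whenever the sequence is longer than the training length. NTK-aware scaling and YaRN instead adjust the rotation frequencies and are not shown. The helper names and sizes are illustrative assumptions.

```python
import numpy as np

def rope_angles(positions, d, base=10000.0):
    # angle m * theta_i for every position m (rows) and frequency pair i (columns)
    theta = base ** (-2.0 * np.arange(d // 2) / d)
    return np.asarray(positions, dtype=float)[:, None] * theta[None, :]

def interpolated_positions(n, train_len):
    # Position Interpolation sketch: squeeze positions 0..n-1 into [0, train_len)
    # by the constant factor train_len / n (no change if n fits already)
    scale = min(1.0, train_len / n)
    return np.arange(n) * scale

d, train_len, target_len = 128, 4096, 32768
seen = rope_angles(np.arange(train_len), d)                               # trained range
plain = rope_angles(np.arange(target_len), d)                             # naive long context
pi = rope_angles(interpolated_positions(target_len, train_len), d)        # rescaled

# Column 0 is the fastest-spinning pair (theta_0 = 1), so its angle equals the position.
print(plain[:, 0].max() / seen[:, 0].max())   # ~8x beyond anything seen in training
print(pi[:, 0].max() < train_len)             # True: back below the training length
```

The trade-off is that neighboring tokens now differ by a fraction of a position, which is why interpolated models are usually given a short fine-tune at the new length.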
Rotation in a few lines
```python
import numpy as np

def rope_rotate(vec, pos, base=10000.0):
    # vec: 1D array of even length d; rotate each (x1, x2) pair by pos * theta_i
    d = len(vec)
    out = np.empty_like(vec)
    for i in range(d // 2):
        theta = base ** (-2.0 * i / d)
        a = pos * theta
        x1, x2 = vec[2*i], vec[2*i + 1]
        out[2*i] = x1 * np.cos(a) - x2 * np.sin(a)
        out[2*i + 1] = x1 * np.sin(a) + x2 * np.cos(a)
    return out

# The payoff: score depends only on the gap (n - m), not absolute m, n.
q = np.random.randn(8); k = np.random.randn(8)
s1 = rope_rotate(q, 2) @ rope_rotate(k, 5)      # gap = 3
s2 = rope_rotate(q, 100) @ rope_rotate(k, 103)  # gap = 3
assert np.allclose(s1, s2)                      # same gap -> same attention score
```

The two scores match because both pairs sit three positions apart: the absolute indices (2,5) versus (100,103) make no difference. That is relative position, delivered purely through geometry.
17.6 — ALiBi and other position methods (brief)
RoPE injects position by rotating \(q\) and \(k\). ALiBi (Attention with Linear Biases) takes an even more direct route: don’t touch the embeddings at all — just subtract a penalty from each attention score that grows with the distance between the two tokens. Before the softmax, the score for a query attending \(j\) positions back gets a bias of \(-s \cdot j\), where \(s\) is a fixed per-head slope (different heads use different slopes, so some look near, others far). Tokens far in the past are gently discouraged; nearby tokens dominate.
The biased score is a one-line change:
\[\text{score}_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}} - s\,(i - j)\]
In words: start from the usual match score and shave off a fixed amount for every step the key is further in the past — the farther back, the bigger the penalty. Also written: add a distance bias matrix to the scores, \(S' = \frac{QK^\top}{\sqrt{d_k}} + B\) with \(B_{ij} = -s\,(i-j)\) for \(j \le i\) (and \(-\infty\) for \(j > i\) to keep it causal).
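A sketch of the bias matrix \(B\) in NumPy. The slope schedule follows the one described in the ALiBi paper for a power-of-two number of heads \(h\) (a geometric sequence starting at \(2^{-8/h}\), so 8 heads get \(1/2, 1/4, \dots, 1/256\)); other head counts need an extra rule that is not shown. In a model, each head's \(B\) would be added to its scaled scores before softmax, as in the formula above.

```python
import numpy as np

def alibi_slopes(n_heads):
    # geometric slopes 2^(-8/h), 2^(-16/h), ..., for a power-of-two head count h
    return np.array([2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)])

def alibi_bias(n, slope):
    i = np.arange(n)[:, None]              # query positions
    j = np.arange(n)[None, :]              # key positions
    B = -slope * (i - j).astype(float)     # -s * distance for keys in the past
    B[j > i] = -np.inf                     # causal: future keys are blocked
    return B

print(alibi_slopes(8))       # 0.5, 0.25, ..., 1/256
print(alibi_bias(4, 0.5))    # 0 on the diagonal, -0.5 per step back, -inf above
```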
Because the bias is a simple linear function of distance, it keeps working beyond the training length: a model trained at 1K tokens can be run on considerably longer sequences with little degradation, which was ALiBi’s headline result. Two other methods round out the landscape. T5 relative bias also adds a bias to the scores, but a learned scalar per relative-distance bucket rather than a fixed slope: flexible, but the buckets cap how far it extrapolates. And NoPE (No Positional Encoding) is the surprising observation that a decoder-only model can, in some settings, infer order on its own: the causal mask already breaks the symmetry (token \(i\) sees only tokens \(\le i\)), which can give the model enough signal to reconstruct position without any explicit encoding. It is mainly a small-scale, decoder-only result, though, not a reliable default.
| Method | How it injects position | Relative? | Extrapolation to longer context |
|---|---|---|---|
| Sinusoidal / learned (§17.4) | Add a position vector to the input embedding | No (absolute) | Poor |
| RoPE | Rotate \(q\), \(k\) by a position-dependent angle | Yes (emergent from rotation) | Good, strong with PI / NTK / YaRN |
| ALiBi | Subtract a linear distance penalty from scores | Yes | Strong, built for it |
| T5 relative bias | Add a learned bias per distance bucket | Yes | Limited by bucket range |
| NoPE | Nothing explicit — relies on the causal mask | Implicit | Surprisingly decent at small scale |
Which position method should I use?
flowchart TD
A{New decoder-only LLM?} -->|Yes| B{Need long-context<br/>extrapolation?}
A -->|No, encoder for<br/>understanding| E[Learned or sinusoidal<br/>e.g. BERT-style]
B -->|Yes, mostly via<br/>relative distance| C[RoPE + PI / NTK / YaRN<br/>the default for Llama/Mistral/Qwen]
B -->|Yes, want train-short<br/>infer-long, simplest| D[ALiBi<br/>linear distance bias]
B -->|No, fixed length is fine| C
17.7 — Inference: the KV cache and attention variants that shrink it
So far we have treated attention as a one-shot computation over a fixed sequence. But a decoder-only model generates text one token at a time, and that streaming setting exposes a different cost — one that dominates the economics of serving LLMs.
The intuition. Imagine writing an essay where, to add each new word, you re-read every word you have written so far from the beginning. That is exactly what a naive decoder does: to produce token 1000 it recomputes the keys and values for tokens 1 through 999, even though those never change. The fix is the obvious one — write each word’s notes down once and keep them. That notebook is the KV cache.
The animation below shows that notebook filling: each step appends one fresh key/value tile and the new query reads back over everything stored so far — nothing earlier is ever recomputed.
What gets cached. At generation step \(t\), the new token’s query \(q_t\) must attend over the keys and values of all previous tokens. But \(k_1,\dots,k_{t-1}\) and \(v_1,\dots,v_{t-1}\) were already computed at earlier steps and do not depend on \(t\). So we store them and, each step, only compute the new \(k_t, v_t\), append them, and run one query against the growing cache:
\[ o_t = \text{softmax}\!\left(\frac{q_t [k_1,\dots,k_t]^\top}{\sqrt{d_k}}\right)[v_1,\dots,v_t] \]
In words: at each step you append one new key/value to a stored list and attend the single new query over the whole list — no past key or value is ever recomputed. Also written: with cache \(\mathcal{C}_t = \mathcal{C}_{t-1} \cup \{(k_t, v_t)\}\), step cost is \(O(t\,d)\) instead of the \(O(t^2 d)\) of recomputing from scratch.
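The equation above can be checked directly. The sketch below uses a single attention head with random weights and random token embeddings, all assumed for illustration. It decodes token by token with a growing K/V cache and confirms that the result matches full causal attention recomputed from scratch.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

rng = np.random.default_rng(0)
T, d = 6, 4
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
X = rng.standard_normal((T, d))          # one embedding per generated position

# Incremental decoding: compute k_t, v_t once, append, attend q_t over the cache
K_cache, V_cache, outs = [], [], []
for t in range(T):
    x = X[t]
    q = x @ Wq
    K_cache.append(x @ Wk)
    V_cache.append(x @ Wv)
    K, V = np.stack(K_cache), np.stack(V_cache)    # (t+1, d): never recomputed
    outs.append(softmax(q @ K.T / np.sqrt(d)) @ V)
cached = np.stack(outs)

# Reference: full causal attention recomputed from scratch over all T tokens
Q, K, V = X @ Wq, X @ Wk, X @ Wv
scores = Q @ K.T / np.sqrt(d)
scores[np.triu_indices(T, k=1)] = -np.inf       # mask future positions
full = softmax(scores) @ V
print(np.allclose(cached, full))  # True
```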
Caching turns per-step work from quadratic-in-position back down to linear, and it is why every production LLM serving stack (vLLM, TGI, llama.cpp) is built around the KV cache. The price is memory: for each sequence, the cache holds \(2 \times n_{\text{layers}} \times n_{\text{heads}} \times n \times d_k\) numbers (keys and values, for every layer and head) and grows with every generated token. For a long conversation it can dwarf the model weights themselves: at long context, the KV cache, not the parameters, is usually what fills GPU memory.
# Hugging Face does KV-caching for you; use_cache=True is the default at generation.
from transformers import AutoModelForCausalLM, AutoTokenizer
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
ids = tok("Attention is all you", return_tensors="pt").input_ids
out = model.generate(ids, max_new_tokens=20, use_cache=True) # cache on
print(tok.decode(out[0]))
Shrinking the cache. Because the cache is the bottleneck, modern architectures cut its size by sharing keys and values across query heads:
- Multi-Query Attention (MQA) — all query heads share a single key/value head. The cache shrinks by a factor of \(h\), but quality can dip.
- Grouped-Query Attention (GQA) — the middle ground used by Llama-2-70B, Llama 3, Mistral, and most current models: query heads are split into a few groups, and each group shares one K/V head. A handful of K/V heads recovers nearly full quality at a fraction of the cache; Mistral 7B, for example, splits 32 query heads into 8 groups of 4, so it stores 8 K/V heads instead of 32.
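Mechanically, GQA just reuses each stored K/V head for several query heads. The NumPy sketch below shows this under simplifying assumptions: one sequence, no causal mask, and a hypothetical function name. Query heads are assigned to K/V heads in consecutive blocks. The repeat happens only at compute time, so the cache itself still stores just `n_kv_heads` heads. With one K/V head the same function is MQA, and with as many K/V heads as query heads it is ordinary multi-head attention.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def grouped_query_attention(Q, K, V):
    # Q: (n_q_heads, n, d_k); K, V: (n_kv_heads, n, d_k); n_q_heads % n_kv_heads == 0
    n_q, n_kv = Q.shape[0], K.shape[0]
    group = n_q // n_kv
    # Each K/V head serves `group` consecutive query heads
    K = np.repeat(K, group, axis=0)
    V = np.repeat(V, group, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    return softmax(scores) @ V            # (n_q_heads, n, d_k)

rng = np.random.default_rng(0)
n, d_k = 5, 4
Q = rng.standard_normal((8, n, d_k))
K2 = rng.standard_normal((2, n, d_k))     # GQA: 8 query heads share 2 K/V heads
V2 = rng.standard_normal((2, n, d_k))
out = grouped_query_attention(Q, K2, V2)
print(out.shape)  # (8, 5, 4)
```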
flowchart TB
subgraph MHA["Multi-Head: 8 Q, 8 KV"]
direction LR
q1[Q×8] --- kv1[KV×8]
end
subgraph GQA["Grouped-Query: 8 Q, 2 KV"]
direction LR
q2[Q×8] --- kv2[KV×2]
end
subgraph MQA["Multi-Query: 8 Q, 1 KV"]
direction LR
q3[Q×8] --- kv3[KV×1]
end
| Scheme | KV heads (for 8 query heads) | KV cache size | Quality | Used by |
|---|---|---|---|---|
| Multi-head (MHA) | 8 | \(1\times\) (full) | Best | original Transformer, GPT-2 |
| Grouped-query (GQA) | 2–4 | \(\tfrac14\)–\(\tfrac12\) | ≈ MHA | Llama-2-70B, Llama 3, Mistral |
| Multi-query (MQA) | 1 | \(\tfrac18\) | slight drop | PaLM, Falcon |
A useful rule of thumb when sizing a deployment: at long context the KV cache, not the weights, is what you run out of. GQA and MQA exist almost entirely to make that cache fit — they are inference-economics decisions baked into the architecture.
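To put numbers on that rule of thumb, the per-sequence cache formula from §17.7 can be turned into a small calculator. This is a sketch with stated assumptions. The model shape (32 layers, 32 heads of dimension 128, roughly Llama-2-7B) and 2-byte fp16/bf16 storage are assumed, and the function name is hypothetical.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch=1, bytes_per_value=2):
    # Factor 2 for keys and values; bytes_per_value=2 assumes fp16/bf16 storage
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_value

GiB = 1024 ** 3
# Assumed Llama-2-7B-like shape: 32 layers, 32 heads of dimension 128
mha = kv_cache_bytes(32, 32, 128, seq_len=4096)
print(mha / GiB)   # 2.0 GiB for a single 4,096-token sequence
# Same shape if the 32 query heads shared 8 K/V heads (GQA)
gqa = kv_cache_bytes(32, 8, 128, seq_len=4096)
print(gqa / GiB)   # 0.5
```

At batch size 8 the MHA figure becomes 16 GiB. That is more than the roughly 14 GB of fp16 weights for a 7-billion-parameter model, which is exactly the situation the rule of thumb describes.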
17.8 — Positional encoding meets vision: the Vision Transformer
Attention is not just for text. The intuition that transfers: if you can chop anything into a sequence of pieces and embed each piece as a vector, a Transformer can model it. Images are the cleanest example, and the Vision Transformer (ViT) showed that the very same encoder block from §17.3 — no convolutions at all — can match or beat CNNs on image classification given enough data.
The trick is “patchify.” A \(224\times224\) image is sliced into a grid of, say, \(16\times16\)-pixel patches — \(14\times14 = 196\) patches. Each patch is flattened (\(16\times16\times3 = 768\) numbers) and run through one linear layer to become a token embedding, exactly like a word embedding. Add a positional encoding (so the model knows which patch sat where), prepend a learnable [CLS] token whose final representation feeds the classifier, and pour the sequence into a standard Transformer encoder.
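Patchify is a reshape followed by a matrix multiply. The NumPy sketch below assumes a channels-last \(H \times W \times C\) image, row-major patch order, and a random stand-in embedding matrix; the function name is hypothetical. Real ViT code often implements the patch embedding as a convolution whose kernel size and stride both equal the patch size, which is equivalent to this per-patch linear layer. Flattening order also varies between implementations.

```python
import numpy as np

def patchify(img, p=16):
    # img: (H, W, C) with H and W divisible by p -> (num_patches, p*p*C)
    H, W, C = img.shape
    gh, gw = H // p, W // p
    x = img.reshape(gh, p, gw, p, C)      # split rows and columns into blocks
    x = x.transpose(0, 2, 1, 3, 4)        # (gh, gw, p, p, C): group pixels by patch
    return x.reshape(gh * gw, p * p * C)  # row-major order over the patch grid

img = np.random.default_rng(0).random((224, 224, 3))
patches = patchify(img)
print(patches.shape)  # (196, 768)

# One linear layer turns each flattened patch into a d_model-dim token (random here)
d_model = 768
W_embed = np.random.default_rng(1).standard_normal((768, d_model)) * 0.02
tokens = patches @ W_embed                # (196, d_model)
```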
flowchart LR IMG[224×224 image] --> P[split into 16×16 patches] P --> E[linear embed each patch → tokens] E --> POS[+ positional encoding, prepend CLS] POS --> T[Transformer encoder stack] T --> CLS[CLS token → classifier]
# A pretrained Vision Transformer in a few lines with Hugging Face.
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
proc = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
inputs = proc(images=Image.open("cat.jpg"), return_tensors="pt")
logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
Why this matters beyond images. The patchify idea generalizes: audio becomes a sequence of spectrogram patches, video adds a time axis of frame-patches, and AlphaFold treats residues as tokens. The lesson is architectural: attention is a general-purpose set-mixing operator, and “what is a token?” is a modeling choice. Positional encoding is what re-injects the structure (2D grid position for images, time for audio) that flattening into a set threw away. This is the doorway into Multimodal AI (Chapter 24), where cross-attention (§17.2) lets a text query read directly from a sequence of image patches.
17.9 — Beyond Quadratic Attention: Efficient Attention and State-Space Models
Self-attention is wonderful because every token can look at every other token directly. That is also its curse. The moment you write down the attention scores, you have committed to a table with one row and one column for every token in the sequence. For a 1,000-token document that is a million entries; for a 100,000-token book it is ten billion. The dream of feeding a model an entire codebase, a legal contract, or an hour of audio keeps running into the same wall: attention costs grow with the square of the sequence length. This section is about how we climb over that wall — first by approximating attention so it scales linearly, and then by abandoning the attention formula entirely in favor of state-space models that process sequences the way a running tally processes numbers.
The shape of the bottleneck
Recall the attention computation for a sequence of \(n\) tokens with model dimension \(d\). We form queries, keys, and values \(Q, K, V \in \mathbb{R}^{n \times d}\), then compute
\[ \text{Attention}(Q, K, V) = \underbrace{\text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)}_{S \,\in\, \mathbb{R}^{n \times n}} V . \]
The trouble lives in \(S = QK^\top\). That matrix is \(n \times n\). Building it costs \(O(n^2 d)\) arithmetic, and — usually the real killer — storing it costs \(O(n^2)\) memory. Double the sequence length and you quadruple both. Everything else in a transformer (the feed-forward layers, the projections) scales linearly in \(n\), so for long sequences this one matrix dominates the entire budget.
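The memory side of the wall is plain arithmetic. This sketch assumes fp32 scores (4 bytes each) and counts a single \(n \times n\) matrix, meaning one head of one layer for one sequence; the helper name is hypothetical.

```python
def score_matrix_bytes(n, bytes_per_value=4, n_heads=1, batch=1):
    # One n x n score matrix per head per sequence; fp32 (4 bytes) assumed
    return n * n * bytes_per_value * n_heads * batch

for n in (1_000, 10_000, 100_000):
    gb = score_matrix_bytes(n) / 1e9
    print(f"n = {n:>7,}: {gb:,.3f} GB per head")
```

At 100,000 tokens one head already needs 40 GB just for its scores, before multiplying by heads, layers, and batch size.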
Here is the asymmetry that motivates everything below. A recurrent network reads a sequence in \(O(n)\) time and remembers the past in a fixed-size state — cheap, but it must compress all history into that one state and it cannot be parallelized across time. Attention refuses to compress: it keeps every token available, which is why it is so expressive, but it pays \(O(n^2)\) for the privilege. The research program of this section is to get attention’s quality at a recurrence’s price.
| Approach | Time | Memory | Parallel over sequence? | Keeps full history? |
|---|---|---|---|---|
| RNN / LSTM | \(O(n)\) | \(O(1)\) state | No (sequential) | No (compressed) |
| Full attention | \(O(n^2 d)\) | \(O(n^2)\) | Yes | Yes |
| Linear attention | \(O(n d^2)\) | \(O(d^2)\) state | Yes (or recurrent) | Compressed |
| State-space (S4/Mamba) | \(O(n d)\)–\(O(n \log n)\) | \(O(d)\) state | Yes (convolution/scan) | Compressed |
The quadratic cost is not in “the attention idea” — it is specifically in materializing the \(n \times n\) score matrix and applying softmax to it. Almost every efficient method is a different answer to the question: can we get the effect of \(S V\) without ever building \(S\)?
flowchart LR
A["n tokens"] --> B["Q Kᵀ<br/>n×n scores"]
B --> C["softmax"]
C --> D["× V<br/>output"]
B -.->|"this is the<br/>O(n²) wall"| W["memory blows up<br/>as n grows"]
style B fill:#fdd,stroke:#c33
style W fill:#fee,stroke:#c33
17.10 — Linear and Efficient Attention: Getting Rid of the Square
The first family of fixes keeps the attention recipe but reorganizes the arithmetic so the \(n \times n\) matrix never appears. The cleanest version is linear attention, and the trick is a piece of associativity you already know from multiplying matrices in a smart order.
Reordering the multiplication
Softmax couples the two matrix products: you must finish \(QK^\top\) before you can normalize, and only then multiply by \(V\). Suppose instead we replace the exponential similarity with a simpler one that factorizes. Write the similarity between query \(i\) and key \(j\) as \(\phi(q_i)^\top \phi(k_j)\) for some feature map \(\phi\) applied to each vector independently. Then the output for token \(i\) becomes
\[ o_i = \frac{\sum_j \big(\phi(q_i)^\top \phi(k_j)\big)\, v_j}{\sum_j \phi(q_i)^\top \phi(k_j)} = \frac{\phi(q_i)^\top \Big(\sum_j \phi(k_j) v_j^\top\Big)}{\phi(q_i)^\top \sum_j \phi(k_j)} . \]
In words: because the similarity splits into a query-part times a key-part, you can sum up all the keys-and-values once into a small running summary, then let each query read from that summary instead of touching every key. Also written: the numerator for the whole layer is \(\phi(Q)\big(\phi(K)^\top V\big)\), and the denominator is computed the same way from \(\phi(K)^\top \mathbf{1}\). Compute the small \(d'\times d\) matrix \(\phi(K)^\top V\) first, instead of the big \(n\times n\) matrix \(\phi(Q)\phi(K)^\top\).
Look at what happened. The inner sum \(\sum_j \phi(k_j) v_j^\top\) does not depend on \(i\) at all. We can compute it once — it is a \(d' \times d\) matrix, where \(d'\) is the feature dimension — and then every query just reads from it. We never form an \(n \times n\) object. The cost drops to \(O(n\, d' d)\): linear in sequence length.
The associativity is the whole game. Full attention computes \((QK^\top)V\) — first an \(n \times n\), then times \(V\). Linear attention computes \(Q(K^\top V)\) — first a small \(d \times d\), then times \(Q\). Same parentheses-shuffle that lets you multiply a tall-skinny by a small matrix cheaply.
flowchart TB
subgraph FULL["Full attention: (Q Kᵀ) V"]
direction LR
F1["Q<br/>n×d"] --> F2["Q Kᵀ<br/>n×n 😱"] --> F3["× V<br/>n×d"]
end
subgraph LIN["Linear attention: Q (Kᵀ V)"]
direction LR
L1["Kᵀ V<br/>d×d 🙂"] --> L2["Q × ·<br/>n×d"]
end
style F2 fill:#fdd,stroke:#c33
style L1 fill:#dfd,stroke:#3a3
A tiny worked example
Let us make the saving concrete with \(n = 4\) tokens and \(d = 2\), using the identity feature map \(\phi(x) = x\) (the simplest choice — real methods use \(\phi(x) = \text{elu}(x)+1\) or random features to keep things positive).
import numpy as np
np.random.seed(0)
n, d = 4, 2
Q = np.random.randn(n, d); K = np.random.randn(n, d); V = np.random.randn(n, d)
# Full path: build the n×n matrix (unnormalized, no softmax, for illustration)
S = Q @ K.T # n×n <- the object we want to avoid
out_full = S @ V # n×d
# Linear path: never build n×n
KV = K.T @ V # d×d <- summary of all keys+values
out_lin = Q @ KV # n×d
print(np.allclose(out_full, out_lin)) # True — same answer, different cost
print("n×n entries built:", S.size, "| d×d entries built:", KV.size)
True
n×n entries built: 16 | d×d entries built: 4
The two paths produce identical numbers, but the second never allocated the \(n \times n\) matrix. At \(n = 4\) the difference is trivial (16 entries versus 4); at \(n = 100{,}000\) it is the difference between a model that runs and one that runs out of memory. The catch, of course, is the softmax we dropped: genuine softmax attention does not factorize, so linear attention only approximates it, and the approximation costs some expressive power.
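The example above skipped the normalizer. The version below restores it and uses the positive feature map \(\phi(x) = \text{elu}(x) + 1\) mentioned earlier. It is a minimal NumPy sketch with hypothetical function names. The check compares linear attention against the same \(\phi\)-similarity computed through an explicit \(n \times n\) matrix, not against softmax, since the two are not equal.

```python
import numpy as np

def elu_plus_one(x):
    # phi(x) = elu(x) + 1 > 0, so every similarity phi(q)^T phi(k) is positive
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    fQ, fK = elu_plus_one(Q), elu_plus_one(K)
    KV = fK.T @ V                   # (d, d_v) summary of all keys and values
    z = fK.sum(axis=0)              # (d,) summary for the normalizer
    return (fQ @ KV) / (fQ @ z)[:, None]

def quadratic_reference(Q, K, V):
    # Same similarity, but through the explicit n x n matrix
    A = elu_plus_one(Q) @ elu_plus_one(K).T
    return (A / A.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 3)) for _ in range(3))
print(np.allclose(linear_attention(Q, K, V), quadratic_reference(Q, K, V)))  # True
```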
The recurrent twin
There is a beautiful second view. Because \(\sum_j \phi(k_j) v_j^\top\) is just an accumulating sum, we can build it token by token. Define a running state \(S_t = S_{t-1} + \phi(k_t) v_t^\top\). Then linear attention is a recurrent network with a matrix-valued hidden state:
\[ S_t = S_{t-1} + \phi(k_t)\, v_t^\top, \qquad o_t = \frac{\phi(q_t)^\top S_t}{\phi(q_t)^\top z_t}, \qquad z_t = z_{t-1} + \phi(k_t). \]
In words: keep a running tally of all keys-times-values seen so far; each new token adds its own contribution to the tally, and the output is just the current query read against that tally. Also written: \(S_t = \sum_{j \le t} \phi(k_j) v_j^\top\) — the recurrence is the prefix sum of the parallel form, so both compute the same \(S_t\).
So linear attention has two faces: a parallel form for training (compute all features, do one big matrix product) and a recurrent form for inference (carry a fixed-size state, \(O(1)\) per new token). This duality — parallel to train, recurrent to generate — is exactly the property that the next family, state-space models, takes to its logical conclusion.
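The two faces can be compared directly in the causal setting. The NumPy sketch below assumes \(\phi(x) = \text{elu}(x)+1\) and uses hypothetical function names. Its "parallel" side uses the masked \(n \times n\) form only as a reference, because it is easy to read. The recurrent side carries \(S_t\) and \(z_t\) exactly as in the equations above and produces the same outputs.

```python
import numpy as np

def phi(x):
    return np.where(x > 0, x + 1.0, np.exp(x))   # elu(x) + 1

def causal_linear_attention_parallel(Q, K, V):
    # Reference view: lower-triangular similarity matrix (fine for small n)
    A = np.tril(phi(Q) @ phi(K).T)
    return (A @ V) / A.sum(axis=1, keepdims=True)

def causal_linear_attention_recurrent(Q, K, V):
    # Inference view: fixed-size state S (d x d_v) and z (d,), O(1) memory in t
    d, d_v = Q.shape[1], V.shape[1]
    S, z = np.zeros((d, d_v)), np.zeros(d)
    outs = []
    for q, k, v in zip(phi(Q), phi(K), V):
        S += np.outer(k, v)     # S_t = S_{t-1} + phi(k_t) v_t^T
        z += k                  # z_t = z_{t-1} + phi(k_t)
        outs.append(q @ S / (q @ z))
    return np.array(outs)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((7, 3)) for _ in range(3))
print(np.allclose(causal_linear_attention_parallel(Q, K, V),
                  causal_linear_attention_recurrent(Q, K, V)))  # True
```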
Other efficient-attention strategies
Linear attention is one branch of a larger family. The main ideas, briefly:
| Family | Core idea | Examples | Tradeoff |
|---|---|---|---|
| Kernel / linear | Factorize similarity via \(\phi\) | Linear Transformer, Performer | Approximation error vs. softmax |
| Sparse / local | Each token attends to a few others (window, strided, global tokens) | Longformer, BigBird, Sparse Transformer | Picks patterns by hand; can miss long-range links |
| Low-rank | Project \(K, V\) to length \(k \ll n\) | Linformer | Assumes attention is low-rank |
| IO-aware exact | Keep full attention, tile it to avoid storing \(S\) | FlashAttention | Still \(O(n^2)\) compute, but \(O(n)\) memory |
FlashAttention is often lumped in with “efficient attention,” but it is a different kind of win: it computes exact softmax attention and is still \(O(n^2)\) in arithmetic. What it removes is the \(O(n^2)\) memory traffic by never writing the full score matrix to slow GPU memory. Do not expect it to make a million-token context cheap — it makes the contexts you already use faster and lighter, but the quadratic compute wall is still there.
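The idea that lets exact attention skip storing \(S\) is often called online softmax. The approach is to walk over the keys block by block while keeping a running maximum, a running denominator, and a running weighted sum, and to rescale the old statistics whenever the maximum grows. The NumPy sketch below shows only that arithmetic for a single query, and the function name is hypothetical. A real FlashAttention kernel also tiles the queries and keeps the blocks in on-chip GPU memory.

```python
import numpy as np

def tiled_attention_row(q, K, V, block=4):
    # Exact softmax(q K^T / sqrt(d)) V for one query, visiting K/V one block at a time.
    d = q.shape[0]
    m, l = -np.inf, 0.0              # running max and running softmax denominator
    acc = np.zeros(V.shape[1])       # running (unnormalized) weighted sum of values
    for start in range(0, K.shape[0], block):
        s = K[start:start + block] @ q / np.sqrt(d)   # scores for this block only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)                     # rescale old stats to the new max
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ V[start:start + block]
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
n, d = 10, 4
q, K, V = rng.standard_normal(d), rng.standard_normal((n, d)), rng.standard_normal((n, d))

s = K @ q / np.sqrt(d)
w = np.exp(s - s.max()); w /= w.sum()
print(np.allclose(tiled_attention_row(q, K, V), w @ V))  # True
```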
17.11 — State-Space Models: S4 and Mamba
Sparse and linear attention restructure or approximate the attention computation itself; state-space models replace it with a different algorithm. The starting point is an idea borrowed from control theory and signal processing: model a sequence as a continuous system with a hidden state that evolves over time. You feed in a signal \(u(t)\), an internal state \(x(t)\) integrates it, and you read out \(y(t)\).
The state-space recurrence
A linear state-space model is defined by four matrices \(A, B, C, D\):
\[ x'(t) = A\,x(t) + B\,u(t), \qquad y(t) = C\,x(t) + D\,u(t). \]
In words: a fixed-size memory (\(x\)) drifts and decays on its own (\(A\)), absorbs each new input (\(B\)), and the output is read off the memory (\(C\)), plus an optional direct passthrough of the input (\(D\)). Also written: in discrete time this becomes \(x_t = \bar A x_{t-1} + \bar B u_t,\ y_t = C x_t\) — an RNN whose update has no nonlinearity inside it.
The state \(x(t)\) is a fixed-size vector — say 64 numbers — that summarizes everything the model has seen so far. \(A\) says how the state decays and mixes on its own; \(B\) says how new input enters; \(C\) says how to read an output off the state. To use this on a sequence of discrete tokens, we discretize it with step size \(\Delta\), turning the differential equation into a plain recurrence:
\[ x_t = \bar{A}\, x_{t-1} + \bar{B}\, u_t, \qquad y_t = C\, x_t, \]
where \(\bar{A}, \bar{B}\) come from \(A, B, \Delta\) via a fixed formula. This is just an RNN with a linear update — no nonlinearity inside the recurrence. That linearity is the secret to everything that follows.
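One common choice for that fixed formula is zero-order hold, written here for a diagonal state matrix so each channel is a scalar. It is the rule described in the Mamba paper; S4 originally used a bilinear transform instead. The sketch below assumes a decaying \(a < 0\), and the function name is hypothetical. For a small step, \(\bar a \approx 1 + \Delta a\) and \(\bar b \approx \Delta b\), which is just an Euler step.

```python
import numpy as np

def discretize_zoh(a, b, delta):
    # Zero-order hold for a scalar (diagonal) state:
    #   A_bar = exp(delta * a),  B_bar = (exp(delta * a) - 1) / a * b
    a_bar = np.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

a, b = -0.5, 1.0           # a < 0: the continuous state decays on its own
for delta in (0.01, 0.1, 1.0):
    a_bar, b_bar = discretize_zoh(a, b, delta)
    print(delta, round(a_bar, 4), round(b_bar, 4))
```

A larger \(\Delta\) makes \(\bar a\) smaller, so the state forgets faster and each input is written in more strongly. Mamba later exploits this by making \(\Delta\) depend on the token.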
The recurrence–convolution duality
Because the update is linear, you can unroll it. Starting from a zero state before the first input, \(x_{-1} = 0\):
\[ y_t = \sum_{j=0}^{t} C\,\bar{A}^{\,t-j}\,\bar{B}\; u_j = \sum_{j=0}^{t} \bar{K}_{t-j}\, u_j, \qquad \bar{K}_k = C\,\bar{A}^k\,\bar{B}. \]
In words: unrolling the linear recurrence shows each output is a weighted sum of all past inputs, where the weights \(\bar K_k\) are a single fixed filter — so running the SSM is the same as sliding one big kernel over the input. Also written: \(y = \bar K * u\) — a convolution of the input with kernel \(\bar K = (\bar K_0, \bar K_1, \dots)\), computable in \(O(n\log n)\) with an FFT.
That sum is a convolution. The whole output sequence is the input convolved with a single long kernel \(\bar{K} = (\bar{K}_0, \bar{K}_1, \bar{K}_2, \dots)\). So an SSM has the same dual nature as linear attention, but more dramatically:
- As a recurrence (\(x_t = \bar{A} x_{t-1} + \bar{B} u_t\)): \(O(1)\) work and memory per token — perfect for generating one token at a time.
- As a convolution (\(y = \bar{K} * u\)): compute the kernel once, then apply it to the whole sequence in parallel with an FFT in \(O(n \log n)\) — perfect for training on long sequences at once.
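The two modes can be checked against each other with a state larger than one number. The NumPy sketch below makes several assumptions: a diagonal \(\bar A\) with entries in \((0, 1)\), random \(\bar B\), \(C\), and input, and plain FFT convolution with zero-padding. S4's actual kernel computation is more elaborate. The output is computed once by stepping the recurrence and once by building \(\bar K_k = C \bar A^k \bar B\) and convolving.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 64                                 # state size, sequence length
a_bar = rng.uniform(0.5, 0.99, N)            # diagonal A_bar, stable (|a| < 1)
b_bar = rng.standard_normal(N)
c = rng.standard_normal(N)
u = rng.standard_normal(L)

# Recurrent form: O(N) work per step, state carried across time
x, y_rec = np.zeros(N), []
for ut in u:
    x = a_bar * x + b_bar * ut
    y_rec.append(c @ x)
y_rec = np.array(y_rec)

# Convolution form: kernel K_k = C A_bar^k B_bar, then one FFT convolution
k = np.arange(L)
K = (c * b_bar) @ (a_bar[:, None] ** k)      # (L,) kernel
n_fft = 2 * L                                # zero-pad so circular conv = linear conv
y_fft = np.fft.irfft(np.fft.rfft(u, n_fft) * np.fft.rfft(K, n_fft), n_fft)[:L]

print(np.allclose(y_rec, y_fft))  # True
```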
flowchart LR
subgraph TRAIN["Training view: convolution"]
direction LR
T1["whole input<br/>u₀…uₙ"] --> T2["FFT convolve<br/>with kernel K̄"] --> T3["all outputs<br/>O(n log n), parallel"]
end
subgraph GEN["Inference view: recurrence"]
direction LR
G1["uₜ + state xₜ₋₁"] --> G2["xₜ = Ā xₜ₋₁ + B̄ uₜ"] --> G3["yₜ<br/>O(1) per token"]
end
style T2 fill:#dff,stroke:#39c
style G2 fill:#ffd,stroke:#ca3
You train in the fast parallel mode and deploy in the cheap streaming mode — the same weights, two algorithms. This is why an SSM can have a million-token context: nothing ever scales with \(n^2\).
What S4 added, and a tiny example
A naive linear RNN forgets the distant past: powers \(\bar{A}^k\) either explode, making training unstable, or decay to zero, so \(\bar{K}_k\) vanishes and information from far back is lost. S4 (Structured State Space) made this work by choosing \(A\) with a special structure (the HiPPO matrix) designed so the state provably retains a compressed, multi-resolution memory of the entire history, plus a clever way to compute the long convolution kernel stably. The payoff was a model that handled sequences tens of thousands of steps long and posted large gains on the Long Range Arena benchmark, including being the first to solve its length-16,384 Path-X task, where earlier transformers performed at chance.
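The HiPPO matrix has a compact closed form. The sketch below builds the HiPPO-LegS variant as written in the S4 paper, with a hypothetical function name, so you can see its shape. It is lower triangular with diagonal \(-(n+1)\), so its eigenvalues are \(-1, -2, \dots\) and the continuous dynamics decay stably. S4 does not use this dense matrix directly. Its contribution is computing the convolution kernel efficiently for matrices with this structure.

```python
import numpy as np

def hippo_legs(N):
    # HiPPO-LegS, as stated in the S4 paper:
    #   A[n, k] = -sqrt(2n+1) * sqrt(2k+1)   if n > k
    #           = -(n + 1)                   if n == k
    #           = 0                          if n < k
    #   B[n]    = sqrt(2n+1)
    n = np.arange(N)
    r = np.sqrt(2 * n + 1)
    A = -np.tril(np.outer(r, r), k=-1) - np.diag(n + 1)
    B = r.copy()
    return A, B

A, B = hippo_legs(4)
print(np.round(A, 3))
print(np.round(B, 3))
```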
Here is the recurrence in its barest form, scalar state per channel, to feel the dynamics:
import numpy as np
# one channel, scalar SSM: x_t = a x_{t-1} + b u_t ; y_t = c x_t
a, b, c = 0.9, 0.5, 1.0 # a<1 -> state remembers but slowly forgets
u = np.array([1, 0, 0, 0, 0, 0], float) # a single impulse at t=0
# recurrent form
x, ys = 0.0, []
for ut in u:
x = a*x + b*ut
ys.append(c*x)
print("recurrent:", np.round(ys, 3))
# convolution form: kernel K_k = c * a^k * b
K = np.array([c * a**k * b for k in range(len(u))])
ys_conv = np.convolve(u, K)[:len(u)]
print("convolved:", np.round(ys_conv, 3))
recurrent: [0.5 0.45 0.405 0.365 0.328 0.295]
convolved: [0.5 0.45 0.405 0.365 0.328 0.295]
The impulse response is the kernel itself: a decaying echo \(0.5, 0.45, 0.405, \dots\). The recurrence and the convolution are literally the same function computed two ways. Notice how \(a = 0.9\) controls memory: closer to 1 means the echo lingers longer (more long-range memory); smaller means it fades fast.
Mamba: making the state-space selective
S4 had one weakness that attention does not: its \(A, B, C\) matrices are fixed — the same filter is applied to every token regardless of content. Attention, by contrast, decides at runtime what to attend to based on the actual tokens. A fixed convolution cannot say “this token is important, remember it; that one is filler, skip it.”
Mamba fixes this by making the SSM selective: it lets \(B\), \(C\), and the step size \(\Delta\) be functions of the current input \(u_t\). Now the model can, on the fly, decide how much of each token to write into the state and how fast the state should forget. A delimiter token can trigger a reset; a key fact can be written in boldly; padding can be ignored. This input-dependence is what closes much of the quality gap with attention.
But it breaks the magic trick. Once \(\bar{A}, \bar{B}\) depend on the token, the model is no longer a fixed convolution — there is no single kernel to FFT. Mamba recovers parallel training with a hardware-aware parallel scan: a prefix-scan (the same associative-scan idea behind a parallel cumulative sum) computes the time-varying recurrence across the sequence in \(O(n)\) work and \(O(\log n)\) depth, fused into fast GPU memory so the large states never spill to slow memory. The result is linear-time training and linear-time, constant-memory generation, with selectivity that rivals attention on language.
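The scan works because composing two affine updates \(x \mapsto a_1 x + b_1\) and then \(x \mapsto a_2 x + b_2\) gives another affine update, \((a_2 a_1,\ a_2 b_1 + b_2)\), and this combine rule is associative. The NumPy sketch below makes several assumptions. \(\Delta_t\), \(B_t\), and \(C_t\) are random stand-ins for Mamba's learned projections of the token, \(\bar B_t u_t\) uses the simplified form \(\Delta_t B_t u_t\), and the function names are hypothetical. It uses a simple doubling scan with \(O(L \log L)\) work. The work-efficient \(O(n)\) scan described above uses the same combine rule with a different schedule.

```python
import numpy as np

def selective_scan_sequential(a, bu):
    # x_t = a_t * x_{t-1} + bu_t, with a_t and bu_t varying per token (diagonal state)
    x, xs = np.zeros(a.shape[1]), []
    for a_t, bu_t in zip(a, bu):
        x = a_t * x + bu_t
        xs.append(x)
    return np.array(xs)

def combine(left, right):
    # Apply x -> a1 x + b1, then x -> a2 x + b2: result is x -> (a2 a1) x + (a2 b1 + b2)
    a1, b1 = left
    a2, b2 = right
    return a2 * a1, a2 * b1 + b2

def scan_by_doubling(a, bu):
    # Hillis-Steele inclusive scan: log2(L) rounds, each fully parallel across t
    A, Bv = a.copy(), bu.copy()
    L, step = a.shape[0], 1
    while step < L:
        A_prev, B_prev = A.copy(), Bv.copy()
        A[step:], Bv[step:] = combine((A_prev[:-step], B_prev[:-step]),
                                      (A_prev[step:], B_prev[step:]))
        step *= 2
    return Bv        # Bv[t] = state x_t, starting from x_{-1} = 0

rng = np.random.default_rng(0)
L, N = 10, 3
u = rng.standard_normal(L)
delta = np.log1p(np.exp(rng.standard_normal(L)))   # softplus: input-dependent step > 0
A_cont = -np.arange(1, N + 1, dtype=float)          # fixed negative diagonal A
B_t = rng.standard_normal((L, N))                   # input-dependent B (random here)
C_t = rng.standard_normal((L, N))                   # input-dependent C (random here)
a = np.exp(delta[:, None] * A_cont)                 # A_bar_t = exp(delta_t * A)
bu = delta[:, None] * B_t * u[:, None]              # simplified B_bar_t u_t
xs = scan_by_doubling(a, bu)
y = (C_t * xs).sum(axis=1)                          # readout y_t = C_t . x_t
print(np.allclose(selective_scan_sequential(a, bu), xs))  # True
```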
| | S4 | Mamba (selective SSM) |
|---|---|---|
| \(A, B, C, \Delta\) | Fixed for all tokens | \(B, C, \Delta\) depend on the input token |
| Can ignore/focus per token? | No — same filter everywhere | Yes — content-based gating |
| Parallel training | Global convolution (FFT) | Associative parallel scan |
| Generation cost | \(O(1)\) per token | \(O(1)\) per token |
| Long-context scaling | Linear | Linear |
State-space models are not a free lunch that dominates transformers. Their fixed-size state is a hard memory budget: tasks that need exact recall of a specific earlier token (“what was the 4,000th word?”, precise copying, in-context lookup tables) favor attention, which keeps every token addressable. SSMs win on long, streaming, throughput-bound sequences where a compressed running summary is enough. In practice many strong long-context models are hybrids, interleaving a few attention layers (for sharp recall) with many SSM or linear-attention layers (for cheap long-range mixing), so they pay attention's quadratic cost only in the layers where exact recall is worth it.
The throughline across this whole section is one idea seen from three angles. Attention keeps all of history and pays \(O(n^2)\). A recurrence compresses history into a fixed state and pays \(O(n)\) but cannot parallelize. Linear attention and state-space models thread the needle: a linear update that can be unrolled into a parallel form for training and rolled up into a recurrent form for inference. Whether you arrive there by factorizing the softmax or by discretizing a differential equation, the destination is the same — subquadratic sequence modeling that finally lets the context window grow.
17.12 — Assembling and training the full Transformer
We have built every piece — embeddings, positional encoding, the block, the stack. But a stack of Transformer blocks is just a sequence-to-sequence function \(\mathbb{R}^{n\times d} \to \mathbb{R}^{n\times d}\); it does not yet predict anything. This section closes the loop: how the pieces snap together into a model that turns token IDs into a probability distribution over the next token, and what objective trains it.
The intuition. Think of the full model as a sandwich. The bottom slice turns discrete token IDs into vectors (the embedding layer). The thick filling is the stack of blocks, each one re-mixing those vectors with context. The top slice turns the final vectors back into token IDs — or rather, into a score for every possible next token (the output head, or “unembedding”). Training just nudges all three so the top slice’s guesses match the real next tokens.
Here is a token of data flowing up through that sandwich, from ID to a next-token guess:
The output head and softmax. After the last block, each token’s \(d\)-dimensional vector is projected by a matrix \(W_U \in \mathbb{R}^{d \times |V|}\) (vocabulary size \(|V|\)) into a vector of logits — one raw score per vocabulary word. Softmax turns those logits into a probability distribution:
\[ P(\text{token} = w \mid \text{context}) = \frac{\exp(z_w)}{\sum_{w'} \exp(z_{w'})}, \qquad z = h_{\text{last}}\, W_U \]
In words: project the final hidden vector onto every word in the vocabulary to get a raw score each, then exponentiate-and-normalize so the scores become probabilities summing to 1. Also written: \(P = \text{softmax}(h_{\text{last}} W_U)\); many models tie weights, setting \(W_U = W_E^\top\) (the unembedding is the transpose of the input embedding), which saves parameters and often helps.
The training objective: next-token cross-entropy. A decoder-only language model is trained by self-supervision: the data is its own label. For every position, the target is simply the actual next token in the text, and the loss is the negative log-probability the model assigned to it. Averaged over the \(n\) predicted positions of a training sequence \(x_1, \dots, x_{n+1}\):
\[ \mathcal{L} = -\frac{1}{n}\sum_{t=1}^{n} \log P(x_{t+1} \mid x_1, \dots, x_t) \]
In words: at each position, look up how much probability the model gave to the token that actually came next, take its log, negate it, and average — punishing the model in proportion to how surprised it was by the truth. Also written: \(\mathcal{L} = \frac{1}{n}\sum_t \text{CrossEntropy}(\text{logits}_t,\ x_{t+1})\); exponentiating it gives perplexity \(= e^{\mathcal{L}}\), the effective number of words the model is choosing between.
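The loss and perplexity take only a few lines of NumPy. The sketch below uses hypothetical helper names and random targets, and it checks two sanity anchors. A model with all-equal logits has loss \(\log |V|\) and perplexity exactly \(|V|\), which means it is choosing uniformly among all words. A model that puts a large logit on the true token has loss near zero and perplexity near 1.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)      # stabilize before exponentiating
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def next_token_loss(logits, targets):
    # logits: (n, |V|) scores at positions 1..n; targets: (n,) the true next tokens
    logp = log_softmax(logits)
    return -logp[np.arange(len(targets)), targets].mean()

vocab, n = 1000, 8
rng = np.random.default_rng(0)
targets = rng.integers(0, vocab, n)

# A model that knows nothing (all logits equal) is maximally unsure
uniform_loss = next_token_loss(np.zeros((n, vocab)), targets)
print(uniform_loss, np.exp(uniform_loss))   # log(1000) ~ 6.908, perplexity ~ 1000
```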
Because of causal masking (§17.3), all \(n\) positions can be trained in parallel in a single forward pass — every position predicts its own next token, and none can cheat by looking ahead. This is the trick that makes pretraining on trillions of tokens feasible: one pass over a sequence yields \(n\) supervised examples at once.
flowchart TB ID["token IDs<br/>[The, cat, sat]"] --> EMB["Embedding lookup W_E"] EMB --> PE["+ positional encoding"] PE --> BLK["N × Transformer blocks<br/>(causal)"] BLK --> LN["final LayerNorm"] LN --> HEAD["Unembed W_U → logits |V|"] HEAD --> SM["softmax → next-token P"] SM --> LOSS["cross-entropy vs. actual next token"]
A complete (tiny) decoder-only model and its training step in PyTorch, reusing the TransformerBlock from §17.3:
import torch, torch.nn as nn, torch.nn.functional as F
class GPTLite(nn.Module):
def __init__(self, vocab, d=256, n_layers=4, n_heads=8, max_len=512):
super().__init__()
self.tok = nn.Embedding(vocab, d)
self.pos = nn.Embedding(max_len, d) # learned positions
self.blocks = nn.ModuleList(TransformerBlock(d, n_heads) for _ in range(n_layers))
self.ln_f = nn.LayerNorm(d)
self.head = nn.Linear(d, vocab, bias=False)
self.head.weight = self.tok.weight # weight tying
def forward(self, idx):
n = idx.size(1)
x = self.tok(idx) + self.pos(torch.arange(n, device=idx.device))
        mask = torch.triu(torch.full((n, n), float('-inf'), device=idx.device), diagonal=1)  # causal: block future positions
for blk in self.blocks:
x = blk(x, mask=mask)
return self.head(self.ln_f(x)) # (batch, n, vocab) logits
model = GPTLite(vocab=1000)
idx = torch.randint(0, 1000, (2, 64)) # (batch, seq)
logits = model(idx[:, :-1]) # predict next token
loss = F.cross_entropy(logits.reshape(-1, 1000), idx[:, 1:].reshape(-1))
loss.backward() # one training step
print(loss.item())
That cross_entropy(logits, targets) call, with targets being the input shifted left by one, is language-model pretraining in miniature. Scale this up (more layers, more heads, RoPE instead of learned positions, GQA for the KV cache, trillions of tokens) and you have the recipe behind every model in Chapter 23.
17.13 — Reading attention: interpretability and what the weights do (and don’t) mean
Attention is unusually inviting to interpret. Every layer hands you a clean \(n \times n\) matrix of weights saying “token \(i\) paid this much attention to token \(j\)” — it looks like a built-in explanation of the model’s reasoning. This section is about what you can and cannot read off those maps, plus a few robust, surprising regularities everyone working with Transformers should know.
The intuition. An attention map is like a heat-map of eye gaze: it tells you where the model looked, which is genuinely informative. But where you look is not the same as why you decided. A heavily weighted token can carry a small value vector that contributes little, information also moves through the residual stream and through mixing in earlier layers, and the FFN does heavy lifting the attention map never shows. So treat attention maps as a useful lens, not a confession.
What attention maps reliably reveal. Probing studies of trained Transformers find heads that specialize in human-legible ways — a strong confirmation of the §17.2 committee picture:
- Positional heads — attend to the previous token, or a fixed offset, implementing local n-gram-like mixing.
- Syntactic heads — link a verb to its subject, a noun to its determiner, a word to its dependency-parse parent.
- Coreference / “induction” heads — a token attends back to the earlier occurrence of the same or a related token. Induction heads (attend to the token that followed the previous occurrence of the current token, then copy it forward as the prediction) are thought to be a core mechanism of in-context learning: they let a model continue a pattern like A B … A → B that it has only ever seen in the prompt.
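The induction rule can be written as a crude lookup. The sketch below is a hard, rule-based caricature and the function name is hypothetical. Real induction heads implement the same idea softly, with a previous-token head in an earlier layer feeding an attention pattern in a later one.

```python
def induction_predict(tokens):
    # Find the most recent earlier occurrence of the current (last) token
    # and predict the token that followed it.
    current = tokens[-1]
    for i in range(len(tokens) - 2, -1, -1):   # scan backwards, skipping the last token
        if tokens[i] == current:
            return tokens[i + 1]
    return None                                 # pattern never seen in the prompt

prompt = ["Ada", "Lovelace", "wrote", "notes", ".", "Ada"]
print(induction_predict(prompt))  # Lovelace
```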
Robust regularities worth knowing. Two empirical phenomena reshape how you read and serve attention:
- Attention sinks. In almost every trained LLM, a huge fraction of attention from every query piles onto the first token (often a [BOS] or the literal first word), even though it carries no relevant content. The model uses it as a no-op: a place to dump attention weight when a head has nothing useful to attend to (softmax forces the row to sum to 1, so the weight must go somewhere). This matters in practice: streaming inference (keeping a sliding window of recent tokens) breaks unless you keep the first few sink tokens in the KV cache, and if you drop them quality collapses. “StreamingLLM” is built entirely on this observation.
- Massive activations / outlier dimensions. A few hidden dimensions carry enormous magnitudes; they interact with attention sinks and are the main headache for low-bit quantization (Chapter 30), which must handle those outliers specially.
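The cache policy behind this fix is simple to state: keep the first few sink tokens, plus a sliding window of recent ones. The sketch below only computes which positions survive. The defaults `n_sink=4` and `window=8` are illustrative assumptions, and the function name is hypothetical. It is not StreamingLLM's implementation, which also has to decide how positions are assigned inside the trimmed cache.

```python
def streaming_cache_positions(seq_len, n_sink=4, window=8):
    # Token positions kept in the KV cache after seq_len tokens:
    # the first n_sink "attention sink" tokens plus the most recent `window` tokens.
    sinks = list(range(min(n_sink, seq_len)))
    recent_start = max(n_sink, seq_len - window)
    return sinks + list(range(recent_start, seq_len))

print(streaming_cache_positions(20))  # sinks 0-3, then positions 12 through 19
print(streaming_cache_positions(6))   # nothing evicted yet: positions 0 through 5
```

The cache size is capped at `n_sink + window` no matter how long the stream runs, while the sinks stay available for heads that need somewhere to park their weight.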
# Pull real attention weights out of a Hugging Face model and inspect them.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")
ids = tok("The cat sat on the mat", return_tensors="pt").input_ids
with torch.no_grad():
out = model(ids, output_attentions=True) # ask for the maps
attn = out.attentions[0][0] # layer 0: (heads, n, n)
print(attn.shape)  # torch.Size([12, 6, 6]): 12 heads, 6 prompt tokens
print("attention onto token 0 (sink), avg over heads/queries:",
      attn[:, :, 0].mean().item())  # often large; note query 0 can only see token 0
“Attention is not explanation.” A well-known line of work showed you can often alter a model’s attention weights substantially without changing its prediction, and that different attention patterns can yield the same output. The lesson: attention maps are evidence about information routing, not a faithful account of why the model decided. For trustworthy explanations, pair attention inspection with gradient attributions, activation patching, and the broader mechanistic-interpretability toolkit (Chapter 36). Use attention maps to form hypotheses, then test them; do not quote them as proof.
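When reading a sink score like the one printed by the GPT-2 snippet, a baseline helps. Under causal masking, query \(i\) sees only tokens \(0..i\), so even a head with no preference at all would give token 0 a share of \(1/(i+1)\). The sketch below, with a hypothetical helper name, computes that uniform-causal baseline averaged over queries. A genuine sink shows up as weight well above it.

```python
import numpy as np

def uniform_causal_share_on_first(n):
    # If every query i spread its weight uniformly over visible keys 0..i,
    # token 0 would get 1/(i+1) from query i. Average over all n queries:
    return np.mean(1.0 / np.arange(1, n + 1))

print(round(uniform_causal_share_on_first(6), 3))   # 0.408 for a 6-token prompt
```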
17.14 — Quick reference
| Term / formula | Meaning in one line | When / why it matters |
|---|---|---|
| Scaled dot-product attention \(\text{softmax}(QK^\top/\sqrt{d_k})V\) | Similarity-weighted average of values | The single operation every Transformer is built from |
| \(\sqrt{d_k}\) scaling | Rescales scores to unit variance before softmax | Without it softmax saturates and gradients vanish at large \(d_k\) |
| Self-attention | \(Q,K,V\) all projected from the same sequence | Builds context: every token blends in all others |
| Cross-attention | Queries from one sequence, keys/values from another | Decoder reads encoder; text reads image (Ch 24) |
| Multi-head | \(h\) attention ops in parallel subspaces, concat + \(W_O\) | Each head specializes (syntax, coreference, position) |
| FFN \(\max(0,xW_1{+}b_1)W_2{+}b_2\) | Per-token MLP, usually \(4\times\) width then back | Adds the per-token nonlinear feature processing that attention’s weighted averaging of values lacks |
| Residual + LayerNorm | \(x + \text{sublayer}(\text{LN}(x))\) (Pre-LN) | Clean gradient highway; makes deep stacks trainable |
| Causal mask | Set future scores to \(-\infty\) before softmax | Stops a decoder peeking at tokens it must predict |
| Positional encoding | Inject order into a permutation-equivariant set | Mandatory — without it “dog bites man” = “man bites dog” |
| RoPE | Rotate \(q,k\) by angle \(\propto\) position | Attention score depends only on relative distance; today’s default |
| ALiBi | Add a linear distance penalty \(-s\,(i{-}j)\) to the scores | Cheap relative position, strong length extrapolation |
| KV cache | Store past \(k,v\); recompute only the new token’s | Turns generation from \(O(t^2)\) to \(O(t)\) per step |
| GQA / MQA | Query heads share fewer K/V heads | Shrinks the KV cache, the real long-context memory wall |
| Linear attention | \(Q(K^\top V)\) instead of \((QK^\top)V\) | \(O(n)\) via associativity; approximates softmax |
| State-space (S4 / Mamba) | Linear recurrence = parallel conv (train) + scan (infer) | Subquadratic long context; trades exact recall for cheap scaling |
| Next-token cross-entropy \(-\frac1n\sum_t\log P(x_{t+1}\mid x_{\le t})\) | Negative log-prob of the true next token | The self-supervised objective that pretrains every LLM |
| Attention sink | First token soaks up “leftover” attention weight | Keep sink tokens in cache or streaming inference collapses |
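The scaled dot-product, causal-mask, and KV-cache rows fit together in a few lines of NumPy. This is a minimal single-head sketch under simplifying assumptions (random weights, no batching, no multi-head split; `attention`, `Wq`, `Wk`, `Wv` are illustrative names). It checks that decoding one token at a time from a KV cache reproduces full causal attention.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, causal=False):
    """softmax(Q K^T / sqrt(d_k)) V for one head; q: (m, d_k), k: (n, d_k), v: (n, d_v)."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)
    if causal:
        m, n = scores.shape
        # queries are aligned to the last m keys; mask keys strictly in each query's future
        future = np.triu(np.ones((m, n), dtype=bool), k=n - m + 1)
        scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
n, d = 5, 8
x = rng.normal(size=(n, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))

# Training-style pass: all n positions at once, future hidden by the causal mask
full = attention(x @ Wq, x @ Wk, x @ Wv, causal=True)

# Generation-style pass: project only the new token, append its k, v to the cache
k_cache, v_cache, steps = [], [], []
for t in range(n):
    xt = x[t:t + 1]                            # shape (1, d)
    k_cache.append(xt @ Wk)
    v_cache.append(xt @ Wv)
    # the cache holds only past and current keys, so no mask is needed here
    steps.append(attention(xt @ Wq, np.vstack(k_cache), np.vstack(v_cache)))
incremental = np.vstack(steps)

print(np.allclose(full, incremental))  # True
```

Each cached step does \(O(t)\) work for the one new query, whereas recomputing the full masked matrix at step \(t\) costs \(O(t^2)\), which is the saving the KV-cache row refers to.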
17.15 — Key takeaways
- Attention is a soft, learned dictionary lookup: a query is compared against all keys, softmax turns the match scores into weights, and the output is a weighted average of the values — \(\text{softmax}(QK^\top/\sqrt{d_k})V\).
- The \(\sqrt{d_k}\) scaling keeps dot-product magnitudes from growing with dimension and pushing softmax into saturated, no-gradient regions.
- Self-attention projects \(Q,K,V\) from the same sequence so every token gathers context from all others; cross-attention draws queries and keys/values from different sequences.
- Multi-head attention runs several attention operations in parallel subspaces, letting each head specialize in a different relationship, then concatenates them.
- A Transformer block = multi-head attention + feed-forward network, each wrapped in a residual connection and layer norm; stack \(N\) of them. Encoders see all tokens; decoders add causal masking and cross-attention.
- Transformers beat RNNs through full sequence parallelism and \(O(1)\) path length between any two tokens — at the price of \(O(n^2)\) attention cost.
- Order must be injected: self-attention is permutation-equivariant, so positional encoding (sinusoidal, learned, RoPE, or ALiBi) is mandatory; RoPE’s rotation makes attention depend only on relative distance and is today’s default.
- Inference is dominated by the KV cache: caching past keys/values turns generation from quadratic to linear per step, and GQA/MQA shrink that cache to make long context fit.
- Attention is a general set-mixer: patchify an image (ViT), spectrogram, or video and the same Transformer block applies — “what is a token?” is a modeling choice.
- Subquadratic alternatives (linear attention, S4, Mamba) buy long context by compressing history into a fixed state that trains in parallel and generates recurrently — trading exact recall for cheap scaling.
- The full model sandwiches the block stack between an embedding layer and an unembedding head, and trains by next-token cross-entropy — one causal forward pass yields \(n\) supervised predictions at once.
- Attention maps are a lens, not a confession: heads specialize legibly (positional, syntactic, induction), and quirks like attention sinks matter for serving — but “attention is not explanation,” so test hypotheses rather than quoting weights as proof.
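The RoPE takeaway (rotation makes the score depend only on relative distance) can be checked numerically. This is a minimal sketch assuming one common layout in which adjacent dimension pairs \((x_{2i}, x_{2i+1})\) are rotated by angle \(\text{pos}\cdot\theta_i\) with \(\theta_i = 10000^{-2i/d}\). Some libraries instead pair dimension \(i\) with \(i + d/2\), which is an equivalent permutation. The function name `rope` is illustrative.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate each adjacent pair of x's dimensions by angle pos * theta_i."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)   # one frequency per 2-D pair
    ang = pos * theta
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin             # standard 2-D rotation of (x1, x2)
    out[1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.normal(size=8), rng.normal(size=8)

s1 = rope(q, 5) @ rope(k, 2)       # offset i - j = 3
s2 = rope(q, 105) @ rope(k, 102)   # same offset, very different absolute positions
s3 = rope(q, 6) @ rope(k, 2)       # offset 4: in general a different score
print(np.isclose(s1, s2))          # True
```

The reason is that rotations compose: \(R(i\theta)^\top R(j\theta) = R((j-i)\theta)\), so the absolute positions cancel and only the offset survives in the dot product.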
17.16 — See also
- Chapter 16 — Recurrent & Sequence Models: the RNN/LSTM approach Transformers replaced, and the vanishing-gradient problem attention sidesteps.
- Chapter 14 — Neural Networks (Core): residual connections, layer normalization, and the feed-forward MLP inside each block.
- Chapter 23 — Large Language Models: decoder-only Transformers, RoPE, KV-caching, and scaling attention to billions of parameters.
- Chapter 24 — Multimodal AI: cross-attention as the mechanism for letting one modality read from another, and patch-tokens for images.
- Chapter 30 — AI Infrastructure & Efficient Inference: taming the \(O(n^2)\) cost with FlashAttention, KV-caching, and long-context methods.
- Chapter 15 — Convolutional Neural Networks: residual connections (ResNet) and the Vision Transformer’s reuse of attention for images.
- Chapter 36 — Explainable AI & Interpretability: induction heads, activation patching, and the case against reading attention maps as explanations.
↪ The thread continues → Chapter 18 · 🎨 Generative Models
Transformers excel at understanding and predicting; aim that power at creating new data — images, audio, molecules — and you arrive at generative models.