flowchart LR
X["x<br/>(B, T, C)"] --> QKV["Linear C → 3C<br/>(one matmul)"]
QKV --> SPLIT["split into<br/>q, k, v<br/>each (B, T, C)"]
SPLIT --> RESHAPE["view + transpose<br/>(B, nh, T, hd)"]
RESHAPE --> ATTN["attention per head<br/>batched matmul<br/>(B, nh, T, hd)"]
ATTN --> MERGE["transpose + view<br/>back to (B, T, C)"]
MERGE --> PROJ["output projection<br/>Linear C → C"]
PROJ --> Y["y<br/>(B, T, C)"]
⚡ Building Transformers from Scratch with PyTorch · Day 5 — Multi-Head Attention: Many Perspectives in Parallel
Day 5 — Multi-Head Attention: Many Perspectives in Parallel
Yesterday you built scaled dot-product attention from scratch: queries meeting keys, softmax over scores, a weighted mix of values, all under a causal mask. It works, but it has a subtle limitation. A single attention operation produces one attention pattern per position. Softmax has to commit: if a token wants to attend strongly to the previous verb and to the sentence’s subject and to a matching quote character, one distribution has to average those needs together. Today we fix that by running several attention operations (heads) in parallel, each in its own lower-dimensional subspace, then stitching their answers back together. Along the way you’ll learn the single most-cursed tensor manipulation in all of deep learning (the view/transpose reshape dance), see how to compute all heads in one batched matmul, collapse three projection layers into one, and finally swap your hand-rolled math for PyTorch’s fused F.scaled_dot_product_attention, which on supported GPUs dispatches to a FlashAttention kernel, the same kind of fused attention that real GPTs run on. The module you finish today, CausalSelfAttention, drops straight into tomorrow’s transformer block unchanged.
🎯 Today you will: understand why multiple attention heads beat one big head, master the view/transpose reshape that splits channels into heads, compute all heads in a single batched matmul, fuse Q/K/V into one linear layer like GPT-2, and swap in F.scaled_dot_product_attention as a drop-in FlashAttention fast path
Why one head isn’t enough
Recall the shape story from Day 4. Attention takes x of shape (B, T, C) — batch, time, channels — and produces an output of the same shape, where each position’s output is a weighted average of value vectors at positions it’s allowed to see. The weights come from a softmax over query–key dot products.
The problem is the softmax. For a given query position, the attention weights form one probability distribution over the past. One distribution means one “focus.” But language needs many kinds of focus simultaneously. In the sentence “The keys to the cabinet were on the table”, the token were needs to attend back to keys (subject–verb agreement), while the same token might also benefit from attending to the nearer noun cabinet for local context. (It cannot attend to table at all: that word comes later, and the causal mask hides the future.) A single softmax must blend these needs, diluting both.
The fix, from the original Attention Is All You Need paper (see the encyclopedia’s Attention & Transformers chapter for the theory), is beautifully cheap: instead of one attention operation over the full C-dimensional space, run n_head independent attention operations, each over a head_dim = C // n_head slice. Each head gets its own learned Q, K, V projections, so each can learn its own kind of relationship — one head tracking syntax, another tracking positional patterns, another tracking rare-token copying. Crucially, the total compute is roughly the same as one big head: you’re not multiplying cost by n_head, you’re splitting the channel dimension across heads.
\[ \text{head}_i = \text{Attention}(XW_i^Q,\; XW_i^K,\; XW_i^V), \qquad \text{MHA}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O \]
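To make the “roughly the same compute” claim concrete, here is a back-of-envelope count of multiply-accumulates (MACs) for the two attention matmuls only (the scores and the weighted sum of values), ignoring softmax and the projections. attention_matmul_macs is a hypothetical helper written just for this illustration. Because n_head × head_dim = C, the head count cancels out:

```python
def attention_matmul_macs(B: int, T: int, C: int, n_head: int) -> int:
    """MACs for q @ k^T and att @ v summed over all heads.

    Hypothetical helper for illustration only: ignores softmax, masking,
    and the Q/K/V and output projections.
    """
    assert C % n_head == 0
    hd = C // n_head
    scores = B * n_head * T * T * hd   # per head: (T, hd) @ (hd, T)
    mix = B * n_head * T * T * hd      # per head: (T, T) @ (T, hd)
    return scores + mix

one_big_head = attention_matmul_macs(B=2, T=8, C=64, n_head=1)
four_heads = attention_matmul_macs(B=2, T=8, C=64, n_head=4)
print(one_big_head, four_heads)  # 16384 16384
```

Splitting into heads changes how the work is partitioned, not how much of it there is; the projection layers are the same size either way.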
Here’s the full pipeline we’re building today:
Note the symmetry: shape in equals shape out, (B, T, C). That’s what makes attention stackable into deep networks — tomorrow we’ll wrap this module in residual connections precisely because the shapes line up.
The reshape dance: splitting C into heads
This is the part that trips everyone up at least once, so we’ll do it slowly and show what goes wrong if you cheat. Set up the usual toy dimensions:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1337)
B, T, C = 2, 8, 64 # batch 2, sequence length 8, 64 channels
n_head = 4
assert C % n_head == 0, "channels must divide evenly into heads"
head_dim = C // n_head  # 16
x = torch.randn(B, T, C)
The assert is not decoration. Every real GPT config obeys n_embd % n_head == 0 (GPT-2 small: 768 = 12 × 64). If it doesn’t divide evenly, there is no clean way to split, and you want a loud failure at construction time, not a cryptic view error mid-forward.
We want to go from (B, T, C) to (B, n_head, T, head_dim). Why that target order? Because PyTorch’s matmul batches over all leading dimensions. If the tensor is (B, nh, T, hd), then q @ k.transpose(-2, -1) treats B × nh as one big batch of independent (T, hd) matrices — every head of every sequence computed in a single kernel launch. The head dimension must sit before the time dimension so that time and head-features are the trailing two dims that matmul operates on.
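A quick sketch to check the batching claim, using small made-up shapes: one q @ k.transpose(-2, -1) on (B, nh, T, hd) tensors matches a Python loop that multiplies each (T, hd) slice on its own.

```python
import torch

torch.manual_seed(0)
B, nh, T, hd = 2, 4, 8, 16
q = torch.randn(B, nh, T, hd)
k = torch.randn(B, nh, T, hd)

# One call: matmul treats the leading (B, nh) dims as a batch.
batched = q @ k.transpose(-2, -1)  # (B, nh, T, T)

# Same thing, one (T, hd) @ (hd, T) product at a time.
looped = torch.empty(B, nh, T, T)
for b in range(B):
    for h in range(nh):
        looped[b, h] = q[b, h] @ k[b, h].T

print(batched.shape)                               # torch.Size([2, 4, 8, 8])
print(torch.allclose(batched, looped, atol=1e-6))  # True
```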
The correct dance is two steps: first view, then transpose.
# Step 1: view (B, T, C) -> (B, T, n_head, head_dim)
# This just reinterprets the last axis: the 64 channels at each
# position become 4 groups of 16 contiguous channels. No data moves.
x_split = x.view(B, T, n_head, head_dim)
print(x_split.shape) # torch.Size([2, 8, 4, 16])
# Step 2: transpose dims 1 and 2 -> (B, n_head, T, head_dim)
# Now each head sees the full sequence of its own 16-dim slices.
x_heads = x_split.transpose(1, 2)
print(x_heads.shape)  # torch.Size([2, 4, 8, 16])
Why can’t you skip the transpose and just write x.view(B, n_head, T, head_dim)? Because view never moves data; it only relabels how the same flat buffer is indexed. The memory layout of x is [batch][time][channel]: all 64 channels of position 0, then all 64 channels of position 1, and so on. Viewing that buffer directly as (B, nh, T, hd) would carve it as [batch][head][time][head_dim], so “head 0” would grab the first T × hd = 8 × 16 = 128 floats, which in reality are all 64 channels of positions 0 and 1. You’d be mixing whole positions into what you think are per-head slices. The shapes look right, the code runs, and the model silently learns garbage. Watch:
wrong = x.view(B, n_head, T, head_dim) # runs fine. lies.
right = x.view(B, T, n_head, head_dim).transpose(1, 2)
# Head 0, position 0 should be channels 0..15 of position 0:
print(torch.equal(right[0, 0, 0], x[0, 0, :16])) # True
print(torch.equal(wrong[0, 0, 0], x[0, 0, :16])) # True (coincidence: very first row matches)
# Head 0, position 1 should be channels 0..15 of position 1:
print(torch.equal(right[0, 0, 1], x[0, 1, :16]))    # True
print(torch.equal(wrong[0, 0, 1], x[0, 1, :16]))    # False
# ...because "position 1" of the wrong tensor is actually channels 16..31 of position 0:
print(torch.equal(wrong[0, 0, 1], x[0, 0, 16:32]))  # True: WRONG data
This is the classic silent transformer bug: no error, no NaN, just a model that never quite learns. Burn the rule in: view to unpack the channel axis, transpose to move heads ahead of time.
Here’s the geometry of the correct split — one position’s 64-channel vector becoming four 16-dim head slices:
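The same geometry, checked in code with the toy sizes from above (the position index t = 3 is an arbitrary choice): head h of the transposed tensor holds exactly channels h × 16 through h × 16 + 15 of that position’s 64-channel vector.

```python
import torch

torch.manual_seed(1337)
B, T, C, n_head = 2, 8, 64, 4
head_dim = C // n_head
x = torch.randn(B, T, C)
x_heads = x.view(B, T, n_head, head_dim).transpose(1, 2)  # (B, nh, T, hd)

b, t = 0, 3          # pick one position
vec = x[b, t]        # its full 64-channel vector
for h in range(n_head):
    lane = slice(h * head_dim, (h + 1) * head_dim)
    # Head h at position t is exactly one contiguous 16-channel lane.
    assert torch.equal(x_heads[b, h, t], vec[lane])
    print(f"head {h}: channels {lane.start}..{lane.stop - 1}")
# head 0: channels 0..15
# head 1: channels 16..31
# head 2: channels 32..47
# head 3: channels 48..63
```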
All heads in one matmul
With the tensors shaped (B, nh, T, hd), Day 4’s attention math works unchanged — broadcasting does all the per-head bookkeeping for free. Let’s prove it with separate Q, K, V projections first (we’ll fuse them in the next section):
import math
# One projection each, over the FULL channel dim — the head split
# happens after, by slicing the output channels.
Wq = nn.Linear(C, C, bias=False)
Wk = nn.Linear(C, C, bias=False)
Wv = nn.Linear(C, C, bias=False)
q = Wq(x).view(B, T, n_head, head_dim).transpose(1, 2) # (B, nh, T, hd)
k = Wk(x).view(B, T, n_head, head_dim).transpose(1, 2) # (B, nh, T, hd)
v = Wv(x).view(B, T, n_head, head_dim).transpose(1, 2) # (B, nh, T, hd)
# Scores: (B, nh, T, hd) @ (B, nh, hd, T) -> (B, nh, T, T)
# One attention matrix PER head, all computed in one call.
att = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
print(att.shape)  # torch.Size([2, 4, 8, 8])
Two things deserve a pause. First, notice we still use nn.Linear(C, C), which is mathematically identical to n_head separate Linear(C, head_dim) projections stacked side by side: slicing a linear layer’s output channels is the same as splitting its weight into blocks of output features (row blocks of nn.Linear’s (out_features, in_features) weight). One big matmul typically beats four small ones on a GPU. Second, the scale factor is \(\sqrt{d_{head}} = \sqrt{16} = 4\), not \(\sqrt{C} = 8\). Each head’s dot products sum over only head_dim terms, so that’s the variance you’re normalizing away. Using \(\sqrt{C}\) over-shrinks the logits and makes early attention too uniform: a soft bug that slows training rather than breaking it.
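Both claims can be checked numerically. This sketch (toy sizes, random weights) slices one head’s row block out of a full Linear(C, C) weight and compares it against the sliced output, then samples random unit-variance q and k vectors to show that a head_dim-term dot product has variance near head_dim. The 100,000-sample size is an arbitrary choice, so the printed variances are approximate.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
C, n_head = 64, 4
head_dim = C // n_head
x = torch.randn(2, 8, C)

# Claim 1: one Linear(C, C) equals n_head stacked Linear(C, head_dim) projections.
Wq = nn.Linear(C, C, bias=False)
q_full = Wq(x)                                    # (B, T, C)
h = 2                                             # any head index works
rows = slice(h * head_dim, (h + 1) * head_dim)
# nn.Linear stores weight as (out_features, in_features): head h owns a ROW block.
q_head = x @ Wq.weight[rows].T                    # (B, T, head_dim)
print(torch.allclose(q_full[..., rows], q_head, atol=1e-6))  # True

# Claim 2: a head_dim-term dot product of unit-variance vectors has variance ~head_dim.
qs, ks = torch.randn(100_000, head_dim), torch.randn(100_000, head_dim)
dots = (qs * ks).sum(dim=-1)
print(dots.var().item())                          # roughly 16
print((dots / math.sqrt(head_dim)).var().item())  # roughly 1
print((dots / math.sqrt(C)).var().item())         # roughly 0.25: over-shrunk
```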
Now the causal mask and softmax, exactly as Day 4, broadcast across the head dimension:
mask = torch.tril(torch.ones(T, T)) # (T, T)
att = att.masked_fill(mask == 0, float('-inf')) # broadcasts over (B, nh, ·, ·)
att = F.softmax(att, dim=-1) # each row a distribution, per head
y = att @ v # (B, nh, T, T) @ (B, nh, T, hd) -> (B, nh, T, hd)
print(y.shape)  # torch.Size([2, 4, 8, 16])
Each head has now produced its own (T, hd) answer. To merge them, we run the reshape dance in reverse, transposing heads back next to channels and then view-flattening:
y = y.transpose(1, 2)             # (B, T, nh, hd): heads back beside channels
y = y.contiguous().view(B, T, C)  # (B, T, C): concat heads along channels
print(y.shape)  # torch.Size([2, 8, 64])
Why the .contiguous()? transpose doesn’t move data either: it just swaps strides, leaving the tensor’s memory in the old order. But view can only relabel a buffer whose strides are compatible with the requested shape, and after the transpose the head_dim chunks belonging to one position sit T × hd elements apart in memory, so they can’t be merged into one C axis for free. Skip it and PyTorch throws RuntimeError: view size is not compatible with input tensor's size and stride. .contiguous() performs the actual copy into the new order. (You could use .reshape(B, T, C), which copies only when needed; writing .contiguous() explicitly documents that a real copy happens here.)
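To see the stride story directly, here is a sketch on a stand-in (B, nh, T, hd) tensor of random values: the transposed tensor shares storage with the original, view refuses to merge it, and both .contiguous().view(...) and .reshape(...) produce the same merged layout.

```python
import torch

B, nh, T, hd = 2, 4, 8, 16
C = nh * hd
y = torch.randn(B, nh, T, hd)

y_t = y.transpose(1, 2)                # (B, T, nh, hd): strides swapped, no copy
print(y_t.is_contiguous())             # False
print(y_t.data_ptr() == y.data_ptr())  # True: same underlying storage

view_failed = False
try:
    y_t.view(B, T, C)                  # the nh and hd chunks aren't adjacent in memory
except RuntimeError as err:
    view_failed = True
    print("view failed:", err)

merged = y_t.contiguous().view(B, T, C)  # real copy, then a free relabel
also = y_t.reshape(B, T, C)              # reshape copies only when it must
print(torch.equal(merged, also))         # True
# Head 2's output for position 3 lands in channels 32..47 of position 3.
print(torch.equal(merged[0, 3, 2 * hd:3 * hd], y[0, 2, 3]))  # True
```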
The output projection: letting heads talk
We’re not done. After the merge, position t’s vector is just head 0’s answer in channels 0–15, head 1’s in 16–31, and so on — four opinions sitting side by side in separate lanes, never interacting. The output projection \(W^O\) fixes that:
Wo = nn.Linear(C, C, bias=False)
out = Wo(y)  # (B, T, C): every output channel mixes ALL heads
Each output channel of Wo is a learned linear combination of all 64 input channels, i.e., of all four heads. This is where “head 2 found the subject and head 3 found the verb” gets combined into a single useful feature. Without \(W^O\), downstream layers could still mix the lanes eventually, but the attention layer itself would be strictly a concatenation of independent low-rank updates, and (more practically) you’d break the standard architecture that tomorrow’s residual stream expects. Every serious implementation (the original paper, GPT-2, LLaMA) has it. Keep it.
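One way to see how \(W^O\) mixes heads: because its input is a concatenation, projecting it equals summing each head’s output through that head’s own block of input columns of Wo.weight. A sketch with random stand-in head outputs (not real attention results):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, n_head, head_dim = 2, 8, 4, 16
C = n_head * head_dim
heads = [torch.randn(B, T, head_dim) for _ in range(n_head)]  # stand-in per-head outputs
Wo = nn.Linear(C, C, bias=False)

# Standard path: concatenate lanes along channels, then project.
out = Wo(torch.cat(heads, dim=-1))  # (B, T, C)

# Equivalent view: head i contributes through COLUMNS i*hd..(i+1)*hd of
# Wo.weight (shape (out, in)), and the contributions are summed.
summed = sum(
    h_out @ Wo.weight[:, i * head_dim:(i + 1) * head_dim].T
    for i, h_out in enumerate(heads)
)
print(torch.allclose(out, summed, atol=1e-5))  # True
```

Every output channel therefore receives a term from every head, which is exactly the interaction the bare concatenation lacks.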
The GPT-2 trick: one fused QKV projection
Three separate Linear(C, C) layers for Q, K, V means three matmuls over the same input x. GPT-2’s implementation (and nanoGPT, which we’re following in spirit) fuses them into a single Linear(C, 3C) — one matmul producing a (B, T, 3C) tensor that we split three ways. Same math, same parameter count, fewer kernel launches, better GPU utilization.
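A sketch of that equivalence with random weights at the toy width C = 64: copy the fused layer’s three row blocks into three separate Linear(C, C) layers, and the outputs and parameter counts agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 64
x = torch.randn(2, 8, C)

qkv = nn.Linear(C, 3 * C, bias=False)  # fused: weight is (3C, C)
Wq, Wk, Wv = (nn.Linear(C, C, bias=False) for _ in range(3))

# Give the separate layers the fused layer's row blocks [Wq; Wk; Wv].
with torch.no_grad():
    Wq.weight.copy_(qkv.weight[:C])
    Wk.weight.copy_(qkv.weight[C:2 * C])
    Wv.weight.copy_(qkv.weight[2 * C:])

q, k, v = qkv(x).split(C, dim=2)       # one matmul, split three ways
same = [torch.allclose(fused, layer(x), atol=1e-6)
        for fused, layer in ((q, Wq), (k, Wk), (v, Wv))]
print(same)                            # [True, True, True]

fused_params = sum(p.numel() for p in qkv.parameters())
separate_params = sum(p.numel() for m in (Wq, Wk, Wv) for p in m.parameters())
print(fused_params, separate_params)   # 12288 12288
```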
Time to write the real module — the one that ships in our GPT. It takes the config values we’ve been carrying since Day 1:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention, GPT-2 style (fused QKV)."""
    def __init__(self, n_embd: int, n_head: int, block_size: int,
                 dropout: float = 0.0):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # One matmul computes Q, K, and V for all heads at once.
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        # Output projection: mixes information across heads.
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        # Causal mask, precomputed once. register_buffer -> moves with
        # .to(device), saves with state_dict, but is NOT a parameter.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size))
                 .view(1, 1, block_size, block_size)
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        # (B, T, C) -> (B, T, 3C) -> three (B, T, C) tensors
        q, k, v = self.qkv(x).split(C, dim=2)
        # The dance: (B, T, C) -> (B, T, nh, hd) -> (B, nh, T, hd)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Attention for all heads in one shot: (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, hd)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge heads
        return self.resid_dropout(self.proj(y))  # (B, T, C)
Walk through the load-bearing choices:
- split(C, dim=2) carves the (B, T, 3C) output into three (B, T, C) chunks along channels. The QKV weight matrix is literally \([W^Q; W^K; W^V]\) stacked; fusing changes nothing mathematically.
- self.mask[:, :, :T, :T]: the buffer is built at the maximum block_size, then sliced to the actual sequence length at runtime. This is why generation with growing context “just works”: feed 5 tokens, get a 5×5 mask. Forget the slice and any T < block_size input crashes with a broadcast error.
- register_buffer matters more than it looks. A plain attribute self.mask = torch.tril(...) stays on CPU when you call model.to("cuda"), and you get the classic “Expected all tensors to be on the same device” at the worst possible time. A buffer travels with the module.
- Two dropouts: one on the attention weights (randomly zeroing whole attention edges, a strong regularizer), one on the final output before it rejoins the residual stream. At our tiny scale we’ll train with dropout=0.0, but the plumbing is standard and free.
- The mask comparison == 0 on a float buffer is fine here because the buffer contains exact 0.0/1.0 from tril(ones); nothing ever writes to it.
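A small sketch of the buffer behaviour described above. MaskHolder is a hypothetical toy module, not part of our GPT. Since a CUDA device can’t be assumed, it uses a dtype conversion (.double()) as a stand-in for .to("cuda"): module-wide conversions reach registered buffers (and parameters) but skip plain tensor attributes.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):
    """Hypothetical toy module, just enough to inspect a mask buffer."""
    def __init__(self, block_size: int):
        super().__init__()
        # Registered buffer: tracked by the module and saved in state_dict.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size))
                 .view(1, 1, block_size, block_size),
        )
        # Plain tensor attribute: the module doesn't track it at all.
        self.plain = torch.tril(torch.ones(block_size, block_size))

m = MaskHolder(block_size=32)
print("mask" in m.state_dict(), "plain" in m.state_dict())  # True False
print(len(list(m.parameters())))                              # 0: nothing trainable

# Module-wide conversions reach the buffer but not the plain attribute.
m = m.double()
print(m.mask.dtype, m.plain.dtype)  # torch.float64 torch.float32

# Runtime slicing: a 5-token input gets a 5x5 causal mask.
T = 5
small = m.mask[:, :, :T, :T]
print(small.shape)                  # torch.Size([1, 1, 5, 5])
```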
Sanity check the module:
mha = CausalSelfAttention(n_embd=64, n_head=4, block_size=32)
out = mha(torch.randn(2, 8, 64))
print(out.shape) # torch.Size([2, 8, 64])
print(sum(p.numel() for p in mha.parameters()))  # 16384
Parameter count check: QKV is \(64 \times 192 = 12{,}288\), output projection is \(64 \times 64 = 4{,}096\); total \(16{,}384\). Exactly \(4 C^2\). Remember that: attention costs four \(C \times C\) matrices, a number we’ll reuse when we count the full model’s parameters on Day 7.
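The same count as a formula, in a hypothetical helper attn_params. The bias=True variant is included only because GPT-2’s own attention layers carry biases; our module uses bias=False, so the first line is the one that matches it.

```python
def attn_params(n_embd: int, bias: bool = False) -> int:
    """Parameters in fused-QKV attention: Linear(C, 3C) + Linear(C, C).

    Hypothetical helper; bias=True adds the 3C + C bias terms.
    """
    weights = n_embd * 3 * n_embd + n_embd * n_embd  # = 4 * C^2
    biases = (3 * n_embd + n_embd) if bias else 0
    return weights + biases

print(attn_params(64))              # 16384 (our toy module)
print(attn_params(768))             # 2359296 (C = 768, GPT-2 small width)
print(attn_params(768, bias=True))  # 2362368
```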
And verify causality still holds — position t’s output must not change when we edit future tokens:
x1 = torch.randn(1, 8, 64)
x2 = x1.clone()
x2[0, 5:] = torch.randn(3, 64) # rewrite the future (positions 5,6,7)
mha.eval() # disable dropout for determinism
with torch.no_grad():
    y1, y2 = mha(x1), mha(x2)
print(torch.allclose(y1[0, :5], y2[0, :5], atol=1e-6))  # True: past unaffected
print(torch.allclose(y1[0, 5:], y2[0, 5:]))             # False: future differs
The fast path: F.scaled_dot_product_attention
Everything we hand-wrote (scores, scale, mask, softmax, weighted sum) exists in PyTorch as a single fused function: torch.nn.functional.scaled_dot_product_attention. On CUDA, when the GPU, dtype (fp16 or bf16), and arguments allow it, it dispatches to FlashAttention, a kernel that never materializes the (B, nh, T, T) attention matrix in GPU main memory at all; otherwise it falls back to another backend, such as a memory-efficient kernel or a plain math implementation. FlashAttention computes softmax in tiles inside fast on-chip SRAM, making attention both faster and, critically, \(O(T)\) in memory instead of \(O(T^2)\). That quadratic attention matrix is the memory bottleneck of transformers: at T=4096, nh=12, B=16, the matrix alone is 16 × 12 × 4096 × 4096 fp32 values ≈ 12.9 GB. FlashAttention simply never allocates it.
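The memory arithmetic as a hypothetical helper, counting a single materialized (B, nh, T, T) tensor; a real manual forward and backward pass may hold more than one such tensor, so treat these as lower bounds. The final loop assumes B = 1 and n_head = 12 in fp32, settings we inferred because they reproduce the memory column of the benchmark table later in this lesson; the table itself doesn’t state them.

```python
def attn_matrix_bytes(B: int, n_head: int, T: int, bytes_per_el: int = 4) -> int:
    """Bytes for ONE materialized (B, nh, T, T) attention tensor (hypothetical helper)."""
    return B * n_head * T * T * bytes_per_el

big = attn_matrix_bytes(B=16, n_head=12, T=4096)             # fp32: 4 bytes each
print(big, f"= {big / 1e9:.1f} GB")                          # 12884901888 = 12.9 GB
print(f"{attn_matrix_bytes(16, 12, 4096, 2) / 1e9:.1f} GB")  # fp16/bf16: 6.4 GB
print(f"{attn_matrix_bytes(16, 12, 8192) / 1e9:.1f} GB")     # 2x T -> 4x memory: 51.5 GB

# Assumed B = 1, n_head = 12, fp32 (our inference from the benchmark table).
for T in (256, 1024, 4096):
    print(T, f"{attn_matrix_bytes(1, 12, T) / 1e6:.0f} MB")  # 3, 50, 805 MB
```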
The beautiful part: because we shaped our tensors (B, nh, T, hd) — exactly the layout SDPA expects — it’s a five-line swap inside forward. Replace the score/mask/softmax block with:
# --- fast path: replaces scores, mask, softmax, matmul ---
y = F.scaled_dot_product_attention(
    q, k, v,
    dropout_p=self.attn_dropout.p if self.training else 0.0,
    is_causal=True,  # builds the causal mask internally
)
Line by line:
- is_causal=True tells the kernel to apply the lower-triangular mask itself: no mask tensor, no buffer, no masked_fill. (You can pass an explicit attn_mask= instead, but then you may fall off the FlashAttention fast path onto a slower fallback kernel; for plain causal LM, always prefer is_causal=True.)
- dropout_p=... if self.training else 0.0: SDPA takes dropout as a float, not a module, so it doesn’t know about model.eval(). You must gate it on self.training yourself. Forgetting this is a real bug: your “deterministic” eval outputs quietly keep dropout noise.
- Scaling by \(1/\sqrt{d_{head}}\) happens inside the function (it reads the last dim of q). If you keep your own / math.sqrt(...) and call SDPA, you double-scale and the model trains poorly while looking healthy. Delete your manual scale when you switch.
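A CPU-friendly sketch of these points on random (B, nh, T, hd) tensors: SDPA with is_causal=True matches the manual computation, an explicit boolean attn_mask (True means “may attend”) gives the same answer, pre-scaling q double-scales and changes the result, and a nonzero dropout_p is applied no matter what any module’s train/eval mode is. The tolerances are our choice.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, nh, T, hd = 2, 4, 16, 16
q, k, v = (torch.randn(B, nh, T, hd) for _ in range(3))

# Manual reference: scale, causal mask, softmax, weighted sum.
att = (q @ k.transpose(-2, -1)) / math.sqrt(hd)
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
att = att.masked_fill(~causal, float("-inf"))
y_manual = F.softmax(att, dim=-1) @ v

y_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(y_manual, y_sdpa, atol=1e-5))    # True

# Explicit boolean mask: same math, but may rule out FlashAttention on GPU.
y_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
print(torch.allclose(y_manual, y_mask, atol=1e-5))    # True

# The double-scaling bug: pre-scaling q AND letting SDPA scale again.
y_double = F.scaled_dot_product_attention(q / math.sqrt(hd), k, v, is_causal=True)
print(torch.allclose(y_manual, y_double, atol=1e-5))  # False

# dropout_p is a plain float: any value > 0 is applied on every call.
y_drop = F.scaled_dot_product_attention(q, k, v, dropout_p=0.5, is_causal=True)
print(torch.allclose(y_drop, y_sdpa))                 # False
```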
Here’s the drop-in version of the module — the one we’ll actually carry into Day 6 (only forward’s middle changes; a flash flag keeps the manual path around for teaching and for checking our math):
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout=0.0, flash=True):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head, self.head_dim = n_head, n_embd // n_head
        self.dropout = dropout
        self.flash = flash
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size))
                                          .view(1, 1, block_size, block_size))
    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        if self.flash:
            y = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True)
        else:  # manual path: identical math, materializes (B, nh, T, T)
            att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
            att = self.attn_dropout(F.softmax(att, dim=-1))
            y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.proj(y))
Prove the two paths agree, with the same module, the same weights, and the flag flipped:
torch.manual_seed(0)
m = CausalSelfAttention(64, 4, 32, dropout=0.0, flash=True).eval()
x = torch.randn(2, 16, 64)
with torch.no_grad():
    y_flash = m(x)
    m.flash = False
    y_manual = m(x)
print(torch.allclose(y_flash, y_manual, atol=1e-5))  # True
print((y_flash - y_manual).abs().max())  # ~1e-7, float round-off
The tolerance is 1e-5, not exact equality: fused kernels sum in a different order than our explicit matmuls, and floating-point addition isn’t associative. Differences around 1e-7 are healthy; differences around 1e-2 mean you double-scaled or mangled the mask.
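The non-associativity is easy to see even with three Python floats. They are float64, so the gap is smaller than in fp32, but the cause is the same:

```python
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c
right = a + (b + c)
print(left == right)      # False
print(left, right)        # 0.6000000000000001 0.6
print(abs(left - right))  # about 1.1e-16: round-off, not a bug
```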
One benchmark to make the payoff concrete (numbers from an A100; your ratios will vary, and on CPU the gap is typically smaller because the CUDA FlashAttention kernels don’t apply there, though SDPA’s fused path usually still comes out ahead):
| T (seq len) | manual path | SDPA (Flash) | speedup | attn-matrix memory avoided |
|---|---|---|---|---|
| 256 | 0.31 ms | 0.11 ms | ~2.8× | 3 MB |
| 1024 | 3.9 ms | 0.9 ms | ~4.3× | 50 MB |
| 4096 | 71 ms | 11 ms | ~6.5× | 805 MB |
The lesson generalizes: write the naive version first to own the math, then swap in the fused op and verify equivalence numerically. That verification habit is the difference between “I used FlashAttention” and “I know my FlashAttention call is correct.”
🧪 Your task
Your CausalSelfAttention computes all heads in one batched matmul. Prove to yourself that this really is n_head independent attentions by writing mha_by_loop(mha, x): a function that takes your (eval-mode, manual-path) module and an input x, extracts the fused QKV weights, runs attention one head at a time in a Python loop over per-head (B, T, head_dim) slices, concatenates the results, applies the output projection — and matches the module’s output to within 1e-5.
Hint: the fused QKV weight has shape (3C, C) and is stacked \([W^Q; W^K; W^V]\) along dim 0, so Wq = mha.qkv.weight[:C]. Head i of a projected tensor q (shape (B, T, C)) is the channel slice q[..., i*hd:(i+1)*hd]. Remember x @ W.T replicates nn.Linear with bias=False, and scale by sqrt(head_dim).
Solution
import math
import torch
import torch.nn.functional as F
def mha_by_loop(mha, x):
    B, T, C = x.shape
    nh, hd = mha.n_head, mha.head_dim
    W = mha.qkv.weight  # (3C, C), rows stacked [Wq; Wk; Wv]
    Wq, Wk, Wv = W[:C], W[C:2*C], W[2*C:]
    # Full-width projections; per-head split is just channel slicing.
    q = x @ Wq.T  # (B, T, C)
    k = x @ Wk.T
    v = x @ Wv.T
    mask = torch.tril(torch.ones(T, T, device=x.device))
    heads = []
    for i in range(nh):
        sl = slice(i * hd, (i + 1) * hd)
        qi, ki, vi = q[..., sl], k[..., sl], v[..., sl]  # (B, T, hd) each
        att = (qi @ ki.transpose(-2, -1)) / math.sqrt(hd)  # (B, T, T)
        att = att.masked_fill(mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        heads.append(att @ vi)  # (B, T, hd)
    y = torch.cat(heads, dim=-1)  # (B, T, C): concat along channels
    return y @ mha.proj.weight.T  # output projection
# Verify against the vectorized module:
torch.manual_seed(0)
mha = CausalSelfAttention(64, 4, 32, dropout=0.0, flash=False).eval()
x = torch.randn(2, 16, 64)
with torch.no_grad():
    y_module = mha(x)
    y_loop = mha_by_loop(mha, x)
print(torch.allclose(y_module, y_loop, atol=1e-5))  # True
print((y_module - y_loop).abs().max())  # ~1e-7
If your allclose fails, the usual suspects in order: scaled by sqrt(C) instead of sqrt(hd); sliced the QKV weight along the wrong dim (it’s rows, dim 0, not columns); forgot .T when applying nn.Linear weights manually; or left dropout enabled (call .eval() and use torch.no_grad()).
Key takeaways
- One softmax = one focus per position; n_head parallel heads in head_dim = C // n_head subspaces let each position attend to several things at once, at roughly the cost of a single big head.
- The reshape dance is view(B, T, nh, hd) then transpose(1, 2) → (B, nh, T, hd). Viewing straight to (B, nh, T, hd) runs without error and silently scrambles positions into heads.
- With heads as a batch dimension, one q @ k.transpose(-2, -1) computes every head’s attention matrix at once; scale by \(\sqrt{d_{head}}\), not \(\sqrt{C}\).
- Merging is the dance reversed, transpose(1, 2).contiguous().view(B, T, C), and the output projection \(W^O\) is what lets heads exchange information; never drop it.
- GPT-2’s fused Linear(C, 3C) computes Q, K, V in one matmul; attention costs \(4C^2\) parameters total.
- F.scaled_dot_product_attention(q, k, v, is_causal=True) is a drop-in FlashAttention fast path: delete your manual scale, gate dropout_p on self.training, and verify against your manual path with allclose.
Tomorrow we wrap this module with a feed-forward network, LayerNorm, and residual connections into the transformer block — the repeating unit we’ll stack into a full GPT.