Kader Mohideen

In this lesson

  • Lesson 4 — Scaled Dot-Product Attention from Scratch
    • The idea: every token asks a question
    • Q, K, V as three nn.Linear layers
    • Scores: \(QK^\top\), and why we divide by \(\sqrt{d_k}\)
    • The causal mask: no peeking at the future
    • Softmax, then blend the values
    • Packaging it: a single Head module with dropout
    • Verify it by hand: a worked micro-example
    • 🧪 Your task
    • Key takeaways

⚡ Building Transformers from Scratch with PyTorch · Lesson 4 — Scaled Dot-Product Attention from Scratch



Lesson 4 — Scaled Dot-Product Attention from Scratch

In the previous lesson you gave every token an identity: a token embedding that says what it is, plus a positional embedding that says where it sits. But each token’s vector is still a hermit — it knows nothing about its neighbours. "bank" in river bank and bank account comes out of Lesson 3’s pipeline as the exact same vector. In this lesson we fix that. We build scaled dot-product attention, the mechanism that lets every token look back over the sequence, decide which earlier tokens matter to it, and pull in a weighted blend of their information. This is the single most important component in the entire course — everything from Lesson 5 onward is just packaging around what you build today. (For the theory-first treatment, see the Attention & Transformers chapter in the encyclopedia; here we build it line by line and verify it by hand.)

🎯 In this lesson you will: project embeddings into queries, keys and values with nn.Linear, compute attention scores as \(QK^\top/\sqrt{d_k}\), build the causal mask with tril + masked_fill, wrap it all into a reusable single-head module with dropout, and verify the whole thing against a hand-computed example

The idea: every token asks a question

Forget the math for a second. Attention is a soft lookup. Each token in the sequence emits three vectors, all derived from its embedding:

  • a query (\(q\)): “here is what I’m looking for”
  • a key (\(k\)): “here is what I contain, in case anyone is looking”
  • a value (\(v\)): “here is what I’ll actually hand over if you pick me”

Token \(i\) compares its query against every token’s key with a dot product. A big dot product means “your key matches my query — I want your information.” Those raw match scores get turned into a probability distribution with softmax, and the output for token \(i\) is the weighted average of everyone’s values under that distribution.

The key insight that trips people up: queries, keys, and values are all computed from the same input, just through three different learned linear maps. Nothing about the input tokens is intrinsically a “query” — the model learns projections that make useful questions and useful answers emerge. That’s why this is called self-attention: the sequence attends to itself.
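To make the "soft lookup" concrete before any tensors appear, here is a minimal plain-Python sketch of one query scoring a handful of keys and blending their values. The toy keys, values and query are made up for illustration, and `soft_lookup` is a hypothetical helper name, not something from the course code.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def soft_lookup(query, keys, values):
    """Score one query against every key, then blend the values by those scores."""
    scores = [dot(query, k) for k in keys]      # one match score per key
    weights = softmax(scores)                   # scores -> probability distribution
    dim = len(values[0])
    blended = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return weights, blended

# Hypothetical toy data: three tokens with 2-dim keys and values.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
query = [2.0, 0.0]   # "I'm looking for the first feature"

weights, blended = soft_lookup(query, keys, values)
print([round(w, 3) for w in weights])   # keys 0 and 2 match the query best
print([round(b, 3) for b in blended])   # weighted average of the three values
```

Unlike a dictionary lookup, no key is picked outright: every value contributes, in proportion to how well its key matched.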

flowchart LR
    X["x<br/>(B, T, C)"] --> LQ["Linear W_q"]
    X --> LK["Linear W_k"]
    X --> LV["Linear W_v"]
    LQ --> Q["Q (B, T, hs)"]
    LK --> K["K (B, T, hs)"]
    LV --> V["V (B, T, hs)"]
    Q --> MM1["Q @ K^T"]
    K --> MM1
    MM1 --> SC["÷ √hs"]
    SC --> MASK["causal mask<br/>(tril, -inf)"]
    MASK --> SM["softmax<br/>(dim=-1)"]
    SM --> DO["dropout"]
    DO --> MM2["weights @ V"]
    V --> MM2
    MM2 --> OUT["out (B, T, hs)"]

One formula summarises the whole diagram:

\[ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V \]

where \(M\) is the causal mask (zeros on and below the diagonal, \(-\infty\) strictly above it). We’ll now build each piece of that formula, checking shapes at every step.

Q, K, V as three nn.Linear layers

Set up a tiny, fully reproducible playground. We keep the (B, T, C) naming convention from the earlier lessons: Batch, Time (sequence length), Channels (embedding dimension).

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)

B, T, C = 4, 8, 32     # batch of 4 sequences, 8 tokens each, 32-dim embeddings
head_size = 16         # dimension of the attention head's internal space

x = torch.randn(B, T, C)   # stand-in for Lesson 3's output: token emb + pos emb

x is exactly what your Lesson 3 pipeline produces: a batch of sequences where each token is already a C-dimensional vector carrying identity + position. Now the three projections:

query = nn.Linear(C, head_size, bias=False)
key   = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

q = query(x)   # (B, T, head_size)
k = key(x)     # (B, T, head_size)
v = value(x)   # (B, T, head_size)

print(q.shape, k.shape, v.shape)
torch.Size([4, 8, 16]) torch.Size([4, 8, 16]) torch.Size([4, 8, 16])

Three things worth pausing on:

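  1. Why bias=False? Attention scores are relative: we only care how token \(i\)’s query compares to token \(j\)’s key versus token \(j'\)’s key. A constant bias \(b\) added to every key adds the same term \(q_i \cdot b\) to every score in row \(i\), and softmax is invariant to a per-row shift, so a key bias cancels exactly. Dropping the biases saves parameters and is a common choice in compact from-scratch implementations. It’s a convention, not a law: the model works with biases too, and GPT-2’s released weights do include biases on these projections.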
  2. nn.Linear on a 3-D tensor applies the same weight matrix to the last dimension of every (b, t) position independently. No loops, no reshaping — PyTorch broadcasts the matmul over all leading dimensions. If you’ve ever written x.view(B*T, C) before a linear layer, you can stop: it’s unnecessary.
  3. These are three separate layers with three separate weight matrices. A common bug is reusing one projection for all three roles — the model still runs (shapes match) but it can never learn asymmetric relationships like “adjectives should look for the nouns they modify”, because query-space and key-space are forced to be identical.
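The bias argument in point 1 can be checked with a few lines of plain Python. This is a sketch with made-up numbers: it adds one shared bias vector to every key and confirms that each score in the row moves by the same constant, which softmax then ignores. It only covers a key bias; a bias on the query changes scores differently per column and does not cancel this way.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Hypothetical toy numbers: one query, three keys, one shared key bias.
q = [0.5, -1.0]
keys = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5]]
key_bias = [3.0, 1.0]

plain = [dot(q, k) for k in keys]
biased = [dot(q, [ki + bi for ki, bi in zip(k, key_bias)]) for k in keys]

shift = dot(q, key_bias)      # the same constant lands on every score in the row
w_plain = softmax(plain)
w_biased = softmax(biased)

print(plain, biased, shift)
print(w_plain)
print(w_biased)               # same weights: the shift cancels in softmax
```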

Each token has now been mapped from its C=32-dim embedding into a head_size=16-dim space where the dot-product comparisons will happen.

Scores: \(QK^\top\), and why we divide by \(\sqrt{d_k}\)

Token \(i\)’s affinity for token \(j\) is the dot product \(q_i \cdot k_j\). We want all pairs at once, and that’s exactly a matrix multiply of q against k transposed:

scores = q @ k.transpose(-2, -1)   # (B, T, hs) @ (B, hs, T) -> (B, T, T)
print(scores.shape)
torch.Size([4, 8, 8])

Read the shape carefully — it’s the heart of today. scores[b, i, j] is how much token i (the row, the one asking) cares about token j (the column, the one being looked at) in batch element b. Note transpose(-2, -1), not .T: we must swap only the last two dims and leave the batch dim alone. .T on a 3-D tensor would scramble the batch dimension into the matmul — shapes might even happen to line up, and you’d get silently wrong numbers.

Tensor   | Shape      | Meaning
x        | (B, T, C)  | embedded input from Lesson 3
q, k, v  | (B, T, hs) | per-token projections
scores   | (B, T, T)  | row i, col j = affinity of token i for token j
weights  | (B, T, T)  | scores after mask + softmax; each row sums to 1
out      | (B, T, hs) | weighted blend of values, one vector per token

Now the scaling. If the components of \(q_i\) and \(k_j\) are roughly independent with mean 0 and variance 1, their dot product is a sum of \(d_k\) such products, so its variance is \(d_k\) and its typical magnitude is \(\sqrt{d_k}\). With head_size = 64 (typical in real models), raw scores routinely reach ±8 or more. Feed logits like that into softmax and it saturates — one weight ≈ 1, the rest ≈ 0 — and a saturated softmax has near-zero gradients, so the attention pattern stops learning almost immediately. Dividing by \(\sqrt{d_k}\) renormalises the scores back to unit variance:

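scores = scores * head_size**-0.5   # QK^T / sqrt(d_k)
print(f"score std ≈ {scores.std():.3f}")   # ≈ 1 when q and k have unit-variance components

The printed value depends on the scale of q and k. With PyTorch’s default nn.Linear initialisation (weights uniform in ±1/√C) and unit-variance inputs, each component of q and k has variance of about 1/3, so at initialisation expect a standard deviation well below 1 here (on the order of 0.3, consistent with the score magnitudes printed in the masking section). The scaling matters once training pushes q and k towards larger magnitudes.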
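The variance argument can be checked empirically without PyTorch. The sketch below draws q and k with independent unit-variance Gaussian components (the assumption stated in the paragraph above, not what a trained model necessarily produces) and measures the standard deviation of their dot product with and without the \(1/\sqrt{d_k}\) factor. `score_std` is a hypothetical helper name.

```python
import random
import statistics

random.seed(0)

def score_std(d_k, scale, n_samples=2000):
    """Empirical std of scale * (q . k) for unit-variance Gaussian q and k."""
    samples = []
    for _ in range(n_samples):
        q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        samples.append(scale * sum(a * b for a, b in zip(q, k)))
    return statistics.pstdev(samples)

for d_k in (4, 16, 64):
    raw = score_std(d_k, scale=1.0)             # grows like sqrt(d_k)
    scaled = score_std(d_k, scale=d_k ** -0.5)  # stays near 1
    print(f"d_k={d_k:3d}  raw std ≈ {raw:.2f}  scaled std ≈ {scaled:.2f}")
```

The raw column should track \(\sqrt{d_k}\) (about 2, 4 and 8), while the scaled column should hover around 1 for every head size.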

You can see the saturation directly:

demo = torch.tensor([1.0, 2.0, 3.0])
print(F.softmax(demo, dim=-1))       # tensor([0.0900, 0.2447, 0.6652]) - soft
print(F.softmax(demo * 8, dim=-1))   # tensor([1.1e-07, 3.4e-04, 9.997e-01]) - saturated

Same relative preferences, but the unscaled version is effectively an argmax with dead gradients. That one * head_size**-0.5 is the “scaled” in scaled dot-product attention: a single multiplication that fixes a training-killing pathology.
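One way to put a number on "dead gradients" is to look at the softmax Jacobian, whose entries are \(\partial p_i / \partial z_j = p_i(\delta_{ij} - p_j)\). The plain-Python sketch below compares the largest Jacobian entry for the soft logits and for the same logits multiplied by 8. It measures gradients with respect to the logits fed into softmax, and the helper names are illustrative.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_jacobian(xs):
    """J[i][j] = d softmax_i / d x_j = p_i * (delta_ij - p_j)."""
    p = softmax(xs)
    n = len(p)
    return [[p[i] * ((1.0 if i == j else 0.0) - p[j]) for j in range(n)]
            for i in range(n)]

def max_abs_entry(matrix):
    return max(abs(v) for row in matrix for v in row)

soft_logits = [1.0, 2.0, 3.0]
sharp_logits = [8.0 * x for x in soft_logits]

soft_grad = max_abs_entry(softmax_jacobian(soft_logits))
sharp_grad = max_abs_entry(softmax_jacobian(sharp_logits))
print(f"largest Jacobian entry, soft logits:      {soft_grad:.2e}")
print(f"largest Jacobian entry, saturated logits: {sharp_grad:.2e}")
```

The saturated case comes out orders of magnitude smaller, which is the sense in which the attention pattern stops receiving a useful learning signal.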

The causal mask: no peeking at the future

We are building a decoder-only model that generates text left to right. At training time the model predicts token \(t{+}1\) from tokens \(1..t\) — so token \(i\) must never read information from tokens \(j > i\). If it could, training becomes a trivial copy task (“the best predictor of the next token is… the next token, which I can see”) and the model is useless at generation time, when the future genuinely doesn’t exist yet.

The trick: set every future-looking score to \(-\infty\) before softmax. Since \(e^{-\infty} = 0\), those positions get exactly zero weight, and the remaining weights still sum to 1 over the allowed positions.

tril = torch.tril(torch.ones(T, T))          # lower-triangular matrix of 1s
scores = scores.masked_fill(tril == 0, float('-inf'))
print(scores[0, :4, :4])                     # top-left corner of batch item 0
tensor([[ 0.1885,    -inf,    -inf,    -inf],
        [ 0.2472, -0.1938,    -inf,    -inf],
        [-0.3053,  0.5747, -0.0472,    -inf],
        [ 0.0742,  0.1181, -0.4232,  0.2818]])

torch.tril keeps the lower triangle (including the diagonal) and zeroes the rest; masked_fill(mask, val) writes val wherever the boolean mask is True. The (T, T) mask broadcasts across the batch dimension of the (B, T, T) scores automatically.

Two classic mistakes to avoid here:

  • Masking after softmax, for example by multiplying the weights by the 0/1 mask. Then rows no longer sum to 1, because the probability mass that softmax assigned to future tokens is simply thrown away. Mask the logits with -inf, then softmax; order matters.
  • Swapping float('-inf') for -1e9 out of caution. -inf is exact and safe here because every row has at least one unmasked entry (the diagonal: a token can always attend to itself). If an entire row were masked, softmax would produce NaNs; that never happens with a causal mask.
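The ordering mistake is easy to see on a single row. The sketch below uses hypothetical logits for one query token in a 4-token sequence and compares "mask, then softmax" with "softmax, then mask" in plain Python.

```python
import math

NEG_INF = float("-inf")

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]   # math.exp(-inf) is exactly 0.0
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for token i = 1: columns 0 and 1 are the allowed past,
# columns 2 and 3 are the future.
row = [0.2, -0.1, 1.5, 0.7]
allowed = [True, True, False, False]

# Correct order: mask the logits, then softmax.
masked_logits = [s if ok else NEG_INF for s, ok in zip(row, allowed)]
mask_then_softmax = softmax(masked_logits)

# Wrong order: softmax over everything, then zero out the future.
softmax_then_mask = [w if ok else 0.0 for w, ok in zip(softmax(row), allowed)]

print(mask_then_softmax, sum(mask_then_softmax))   # a proper distribution
print(softmax_then_mask, sum(softmax_then_mask))   # mass is missing
```

In the wrong-order version the future logits still took part in the normalisation, so the surviving weights are too small and the row no longer sums to 1.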

Here is the resulting attention-weight matrix as a picture — the single most useful mental image in the whole transformer. Each row is one token’s attention budget; darker cells got more of it; the upper triangle is structurally forbidden:

[Figure: attention-weight matrix (T×T), with token labels from “the cat sat on the mat”. Row i is token i’s attention budget over its own past (columns ≤ i) and each row sums to 1; row 0 puts 1.00 on itself. The diagonal is always allowed, since a token may attend to itself. The upper triangle is masked: −inf before softmax, 0 after.]

Notice row 0: the first token has no past, so 100% of its attention lands on itself. Every later row spreads its unit budget over strictly more options. When you train the real model in Lesson 8 and plot these matrices, you’ll see structure emerge: heads that lock onto the previous token, heads that find matching brackets, heads that track the subject of a sentence.

Softmax, then blend the values

Two lines finish the computation:

weights = F.softmax(scores, dim=-1)   # (B, T, T), rows sum to 1
out = weights @ v                     # (B, T, T) @ (B, T, hs) -> (B, T, hs)
print(out.shape)
print(weights[0].sum(dim=-1))         # sanity: every row is a distribution
torch.Size([4, 8, 16])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])

dim=-1 is critical: we normalise across columns within each row, i.e. over the tokens being attended to. Softmax over dim=-2 also runs without error and also produces numbers between 0 and 1 — it normalises each column instead, meaning “how much of token \(j\)’s outgoing influence goes to each reader”, which is not a valid attention distribution per query. This is the classic silent-wrong-answer bug of the lesson: the model trains, the loss even goes down a bit, and generation is garbage.
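Here is a small plain-Python sketch of the two normalisation directions on a hypothetical 3×3 causally masked score matrix. Lists of lists stand in for tensors: normalising each inner list corresponds to dim=-1, and normalising each column corresponds to dim=-2.

```python
import math

NEG_INF = float("-inf")

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical masked scores (row = query token, column = key token).
scores = [
    [0.5, NEG_INF, NEG_INF],
    [0.1, 0.9, NEG_INF],
    [-0.3, 0.4, 1.2],
]

# dim=-1: normalise each row.
row_wise = [softmax(row) for row in scores]

# dim=-2: normalise each column, then transpose back to (row, col) layout.
columns = [list(col) for col in zip(*scores)]
col_wise = [list(row) for row in zip(*[softmax(col) for col in columns])]

print([round(sum(r), 3) for r in row_wise])   # every row is a distribution
print([round(sum(r), 3) for r in col_wise])   # rows are not distributions
```

Note also that `col_wise[0][0]` is computed from the scores in rows 1 and 2, so with the wrong dim the first token’s weight depends on queries from later tokens, which is exactly the leak the causal mask was meant to prevent.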

The final matmul weights @ v is where information actually moves. Row \(i\) of out is \(\sum_j w_{ij} v_j\) — a convex combination of the value vectors of token \(i\)’s past. Token \(i\)’s output is no longer just about token \(i\); it’s a context-aware summary. That’s it. That’s attention: two matmuls, a scale, a mask, and a softmax.

Packaging it: a single Head module with dropout

Now we wrap the whole computation in an nn.Module you’ll reuse verbatim inside Lesson 5’s multi-head attention. Two production details get added: dropout on the attention weights (randomly zeroing some connections so the model can’t over-rely on any single token-to-token edge), and the mask stored as a buffer.

class Head(nn.Module):
    """One head of causal (masked) self-attention."""

    def __init__(self, n_embd: int, head_size: int, block_size: int, dropout: float = 0.1):
        super().__init__()
        self.key   = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # not a parameter: no gradients, but moves with .to(device) and saves with state_dict
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q = self.query(x)                                  # (B, T, hs)
        k = self.key(x)                                    # (B, T, hs)
        v = self.value(x)                                  # (B, T, hs)

        scores = q @ k.transpose(-2, -1) * k.size(-1)**-0.5      # (B, T, T)
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)                # (B, T, T)
        weights = self.dropout(weights)
        return weights @ v                                 # (B, T, hs)

Line-by-line, the parts that are new:

  • register_buffer("tril", ...) — the mask is a fixed tensor, not a learnable parameter. If you stored it as self.tril = torch.tril(...) (a plain attribute), it would not follow the module in model.to("cuda"), and you’d hit a device-mismatch error the moment you train on GPU. register_buffer gives you device handling and checkpointing for free, with no gradients.
  • self.tril[:T, :T] — we allocate the mask once at the maximum context length (block_size, from Lesson 2) and slice it to the actual sequence length at runtime. Sequences shorter than block_size (common in generation, where the prompt starts tiny) just use the top-left corner. Rebuilding the mask every forward pass works too, but allocates a fresh tensor on every call for no reason.
  • * k.size(-1)**-0.5 — same scaling as before, but read off the tensor itself so the module can’t fall out of sync with its config.
  • self.dropout(weights): dropout applied to the attention weights, the same place GPT-2-style implementations apply their attention dropout. During training, random query→key edges are severed and the survivors are rescaled by \(1/(1-p)\), so rows temporarily won’t sum to 1; that’s expected. In model.eval() mode dropout is the identity and exactness returns.
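The \(1/(1-p)\) rescaling in the dropout bullet can be sketched in plain Python. This is a simplified stand-in for what nn.Dropout does in training mode, applied to one hypothetical attention row; the `dropout` function here is illustrative, not PyTorch’s implementation.

```python
import random

def dropout(weights, p, rng):
    """Inverted dropout: zero each entry with probability p, scale survivors by 1/(1-p)."""
    keep = 1.0 - p
    return [w / keep if rng.random() < keep else 0.0 for w in weights]

rng = random.Random(0)
row = [0.10, 0.18, 0.48, 0.24]   # hypothetical attention row, sums to 1
p = 0.1

one_draw = dropout(row, p, rng)
print(one_draw, round(sum(one_draw), 3))   # a single training-mode draw

# Averaged over many draws, each entry returns to its original value.
n = 20000
mean_row = [0.0] * len(row)
for _ in range(n):
    for i, w in enumerate(dropout(row, p, rng)):
        mean_row[i] += w / n
print([round(m, 2) for m in mean_row])
```

Any single draw is a distorted row, but the rescaling keeps the expected value of each weight unchanged, which is why the eval-mode (no dropout) output stays on the same scale as training.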

Quick smoke test:

head = Head(n_embd=C, head_size=head_size, block_size=T)
out = head(x)
print(out.shape)          # torch.Size([4, 8, 16])
out = head(x[:, :3, :])   # shorter sequence: mask slicing at work
print(out.shape)          # torch.Size([4, 3, 16])

Verify it by hand: a worked micro-example

Never trust attention code you haven’t checked against numbers you can compute on paper. We’ll force the three projections to be the identity (so \(Q = K = V = x\)) and use a 3-token, 2-dim input small enough to do with a pocket calculator.

head = Head(n_embd=2, head_size=2, block_size=3, dropout=0.0)
head.eval()   # belt and braces: disables dropout even if p > 0

with torch.no_grad():   # overwrite the random init with identity matrices
    for lin in (head.query, head.key, head.value):
        lin.weight.copy_(torch.eye(2))

x = torch.tensor([[[1., 0.],
                   [0., 1.],
                   [1., 1.]]])        # (B=1, T=3, C=2)
out = head(x)
print(out)

Paper version. With identity projections, \(\text{scores} = x x^\top / \sqrt{2}\):

\[ xx^\top = \begin{pmatrix} 1 & 0 & 1 \\ 0 & 1 & 1 \\ 1 & 1 & 2 \end{pmatrix} \;\xrightarrow{\;/\sqrt{2}\;}\; \begin{pmatrix} 0.707 & 0 & 0.707 \\ 0 & 0.707 & 0.707 \\ 0.707 & 0.707 & 1.414 \end{pmatrix} \]

Apply the causal mask (upper triangle → \(-\infty\)), then softmax each row:

  • Row 1: only itself visible → weights \((1, 0, 0)\).
  • Row 2: logits \((0, 0.707)\) → \(e^0 = 1\), \(e^{0.707} = 2.028\) → weights \((0.330,\ 0.670,\ 0)\).
  • Row 3: logits \((0.707, 0.707, 1.414)\) → \((2.028, 2.028, 4.113)\), sum \(8.169\) → weights \((0.248,\ 0.248,\ 0.503)\).

Blend the values (\(V = x\)):

  • Row 1: \(1 \cdot (1,0) = (1.000,\ 0.000)\)
  • Row 2: \(0.330(1,0) + 0.670(0,1) = (0.330,\ 0.670)\)
  • Row 3: \(0.248(1,0) + 0.248(0,1) + 0.503(1,1) \approx (0.752,\ 0.752)\), carrying the weights at full precision before rounding

And PyTorch agrees:

tensor([[[1.0000, 0.0000],
         [0.3302, 0.6698],
         [0.7517, 0.7517]]])

Read the result like a story. Token 1 could only look at itself, so its output is its own value: under a causal mask the first token’s output is always just its value vector. Token 2 preferred itself (0.67) but pulled in a third of token 1. Token 3, whose vector \((1,1)\) overlaps both predecessors equally, kept about half of its budget for itself and split the remainder evenly between them. Information flowed strictly from earlier tokens to later ones, weighted by similarity (here fixed by the identity projections rather than learned). When training starts in Lesson 8, gradients will reshape \(W_q, W_k, W_v\) so these weights land on useful tokens instead of merely similar ones.
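As a second opinion that does not depend on PyTorch at all, the same micro-example can be recomputed with explicit loops. The sketch below hard-codes the identity projections (so Q = K = V = x) and also runs a simple causality check: changing the last token should leave every earlier output untouched. `causal_attention` is a hypothetical helper written for this check.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(x):
    """Single-head causal self-attention with identity projections (Q = K = V = x)."""
    d_k = len(x[0])
    out = []
    for i, q in enumerate(x):
        # Token i may only look at positions 0..i, so future positions never enter.
        logits = [sum(a * b for a, b in zip(q, x[j])) / math.sqrt(d_k)
                  for j in range(i + 1)]
        w = softmax(logits)
        out.append([sum(w[j] * x[j][d] for j in range(i + 1)) for d in range(d_k)])
    return out

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = causal_attention(x)
for row in out:
    print([round(v, 4) for v in row])
# [1.0, 0.0]
# [0.3302, 0.6698]
# [0.7517, 0.7517]

# Causality check: perturbing the last token must not change earlier outputs.
x_changed = [x[0], x[1], [5.0, -3.0]]
out_changed = causal_attention(x_changed)
assert out_changed[:2] == out[:2]
```

The causality check is worth porting to the real Head module: if an edit to the masking code ever lets information flow backwards, this is the assertion that fails.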

🧪 Your task

PyTorch (2.0 and later) ships a fused, memory-efficient implementation of this lesson’s exact computation: F.scaled_dot_product_attention (on supported GPUs and dtypes it can dispatch to FlashAttention-style kernels). Prove that your hand-rolled Head computes the same thing. Write a function check_against_pytorch() that builds a Head with dropout 0.0, runs a random (2, 6, 32) input through it, then reproduces the output using F.scaled_dot_product_attention(q, k, v, is_causal=True) fed with the same projections, and asserts the two outputs match with torch.allclose(..., atol=1e-6).

Hint: the built-in takes already-projected q, k, v — reuse head.query(x), head.key(x), head.value(x) so both paths share identical weights. is_causal=True replaces your tril/masked_fill step, and the function applies the \(1/\sqrt{d_k}\) scaling internally, so don’t scale twice.

Solution
import torch
import torch.nn.functional as F

def check_against_pytorch():
    torch.manual_seed(0)
    B, T, C, hs = 2, 6, 32, 16

    head = Head(n_embd=C, head_size=hs, block_size=T, dropout=0.0)
    head.eval()

    x = torch.randn(B, T, C)

    # path 1: our hand-rolled head
    ours = head(x)

    # path 2: PyTorch's fused kernel, fed the SAME projections
    q = head.query(x)
    k = head.key(x)
    v = head.value(x)
    theirs = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    assert ours.shape == theirs.shape == (B, T, hs)
    assert torch.allclose(ours, theirs, atol=1e-6), \
        f"max diff {(ours - theirs).abs().max().item():.2e}"
    print("OK — hand-rolled attention matches F.scaled_dot_product_attention")

check_against_pytorch()
OK — hand-rolled attention matches F.scaled_dot_product_attention

If the assert fires, the usual suspects in order: dropout not disabled (nonzero p without .eval()), scaling applied twice (once in your head, once inside the built-in), or softmax over the wrong dim in your Head. In the real course model we keep the hand-rolled version — understanding it is the point — but now you know the one-line production replacement.

Key takeaways

  • Self-attention = two matmuls, a scale, a mask, a softmax: softmax(Q Kᵀ / √d_k + mask) @ V.
  • Q, K, V are three separate nn.Linear(C, head_size, bias=False) maps of the same input; the asymmetry between query-space and key-space is what lets the model learn directional relationships.
  • Shapes to burn in: x (B,T,C) → q,k,v (B,T,hs) → scores/weights (B,T,T) → out (B,T,hs). weights[b,i,j] = how much token i reads from token j.
  • Divide the scores by \(\sqrt{d_k}\), or softmax saturates and gradients die; the fix costs one multiplication.
  • Causal mask = tril + masked_fill(-inf) on the logits, before softmax; slice tril[:T,:T] for shorter sequences; store it with register_buffer so it follows the model to the GPU.
  • Softmax over dim=-1, never -2 — the wrong dim runs fine and is silently, completely wrong.
  • Dropout on attention weights regularises which edges carry information; rows won’t sum to 1 during training, and that’s by design.

In the next lesson: one head is one “conversation channel” — Lesson 5 runs many heads in parallel and concatenates them, so the model can track syntax, coreference, and position all at once.



© Kader Mohideen