Kader Mohideen


  • Day 5 — Multi-Head Attention: Many Perspectives in Parallel
    • Why one head isn’t enough
    • The reshape dance: splitting C into heads
    • All heads in one matmul
    • The output projection: letting heads talk
    • The GPT-2 trick: one fused QKV projection
    • The fast path: F.scaled_dot_product_attention
    • 🧪 Your task
    • Key takeaways

⚡ Building Transformers from Scratch with PyTorch · Day 5 — Multi-Head Attention: Many Perspectives in Parallel

🏠 ⚡ Course home  |  ← Day 04  |  Day 06 →  |  📚 All mini-courses


Day 5 — Multi-Head Attention: Many Perspectives in Parallel

Yesterday you built scaled dot-product attention from scratch: queries meeting keys, softmax over scores, a weighted mix of values, all under a causal mask. It works, but it has a subtle limitation. A single attention operation produces one attention pattern per position. Softmax has to commit: if a token wants to attend strongly to the previous verb and to the sentence’s subject and to a matching quote character, one distribution has to average those needs together. Today we fix that by running several attention operations (heads) in parallel, each in its own lower-dimensional subspace, and then stitching their answers back together. Along the way you’ll learn the single most-cursed tensor manipulation in all of deep learning (the view/transpose reshape dance). You'll also see how to compute all heads in one batched matmul and how to collapse three projection layers into one. Finally, you'll swap your hand-rolled math for PyTorch’s fused F.scaled_dot_product_attention. On supported GPUs it dispatches to the same FlashAttention kernels that modern GPT training code relies on. The module you finish today, CausalSelfAttention, drops straight into tomorrow’s transformer block unchanged.

🎯 Today you will: understand why multiple attention heads beat one big head, master the view/transpose reshape that splits channels into heads, compute all heads in a single batched matmul, fuse Q/K/V into one linear layer like GPT-2, and swap in F.scaled_dot_product_attention as a drop-in FlashAttention fast path

Why one head isn’t enough

Recall the shape story from Day 4. Attention takes x of shape (B, T, C) — batch, time, channels — and produces an output of the same shape, where each position’s output is a weighted average of value vectors at positions it’s allowed to see. The weights come from a softmax over query–key dot products.

The problem is the softmax. For a given query position, the attention weights form one probability distribution over the past. One distribution means one “focus.” But language needs many kinds of focus simultaneously. Take the sentence “The keys to the cabinet were on the table”. The token were needs to attend back to keys (subject–verb agreement). The same token might also benefit from attending to cabinet for semantic context. It cannot attend to table, which comes later and is hidden by the causal mask. A single softmax must blend these needs, diluting both.

The fix, from the original Attention Is All You Need paper (see the encyclopedia’s Attention & Transformers chapter for the theory), is beautifully cheap: instead of one attention operation over the full C-dimensional space, run n_head independent attention operations, each over a head_dim = C // n_head slice. Each head gets its own learned Q, K, V projections, so each can learn its own kind of relationship — one head tracking syntax, another tracking positional patterns, another tracking rare-token copying. Crucially, the total compute is roughly the same as one big head: you’re not multiplying cost by n_head, you’re splitting the channel dimension across heads.

\[ \text{head}_i = \text{Attention}(XW_i^Q,\; XW_i^K,\; XW_i^V), \qquad \text{MHA}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O \]
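The claim that n_head small heads cost about the same as one big head can be checked by counting. Below is a minimal sketch under a simple cost model of our own. It counts only the multiply-adds in the two attention matmuls, plus the projection weights without biases. The helper names are hypothetical.

```python
def attention_core_macs(T: int, C: int, n_head: int) -> int:
    """Multiply-adds in QK^T and (weights @ V), summed over all heads."""
    head_dim = C // n_head
    per_head = T * T * head_dim      # scores: (T, hd) @ (hd, T)
    per_head += T * T * head_dim     # mix:    (T, T) @ (T, hd)
    return n_head * per_head

def projection_params(C: int, n_head: int) -> int:
    """Per-head W_i^Q, W_i^K, W_i^V of shape (C, head_dim), plus W^O of shape (C, C)."""
    head_dim = C // n_head
    return n_head * 3 * C * head_dim + C * C

T, C = 1024, 768
for n_head in (1, 12):
    # Both rows print identical numbers: splitting C across heads does not multiply cost.
    print(n_head, attention_core_macs(T, C, n_head), projection_params(C, n_head))
```

One term does grow with n_head. Softmax and masking touch n_head × T × T scores instead of T × T, which is one reason the cost is only roughly equal.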

Here’s the full pipeline we’re building today:

flowchart LR
    X["x<br/>(B, T, C)"] --> QKV["Linear C → 3C<br/>(one matmul)"]
    QKV --> SPLIT["split into<br/>q, k, v<br/>each (B, T, C)"]
    SPLIT --> RESHAPE["view + transpose<br/>(B, nh, T, hd)"]
    RESHAPE --> ATTN["attention per head<br/>batched matmul<br/>(B, nh, T, hd)"]
    ATTN --> MERGE["transpose + view<br/>back to (B, T, C)"]
    MERGE --> PROJ["output projection<br/>Linear C → C"]
    PROJ --> Y["y<br/>(B, T, C)"]

Note the symmetry: shape in equals shape out, (B, T, C). That’s what makes attention stackable into deep networks — tomorrow we’ll wrap this module in residual connections precisely because the shapes line up.

The reshape dance: splitting C into heads

This is the part that trips everyone up at least once, so we’ll do it slowly and show what goes wrong if you cheat. Set up the usual toy dimensions:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)

B, T, C = 2, 8, 64      # batch 2, sequence length 8, 64 channels
n_head = 4
head_dim = C // n_head  # 16
assert C % n_head == 0, "channels must divide evenly into heads"

x = torch.randn(B, T, C)

The assert is not decoration. Every real GPT config obeys n_embd % n_head == 0 (GPT-2 small: 768 = 12 × 64). If it doesn’t divide evenly, there is no clean way to split, and you want a loud failure at construction time, not a cryptic view error mid-forward.
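Here is what the assert saves you from. This is a minimal sketch with a deliberately bad config: n_head = 5 is an illustrative value, not a real GPT setting. Integer division quietly truncates head_dim, and the failure only surfaces later, inside view.

```python
import torch

B, T, C = 2, 8, 64
n_head = 5                  # deliberately bad: 64 is not divisible by 5
head_dim = C // n_head      # integer division silently truncates to 12
x = torch.randn(B, T, C)

try:
    x.view(B, T, n_head, head_dim)   # requests 5 * 12 = 60 channels per position, x has 64
except RuntimeError as err:
    print("view failed:", err)
```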

We want to go from (B, T, C) to (B, n_head, T, head_dim). Why that target order? Because PyTorch’s matmul batches over all leading dimensions. If the tensor is (B, nh, T, hd), then q @ k.transpose(-2, -1) treats B × nh as one big batch of independent (T, hd) matrices. Every head of every sequence is then computed in a single kernel launch. The n_head axis must sit before the time axis so that T and head_dim are the trailing two dims that matmul operates on.
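To make “matmul batches over all leading dimensions” concrete, here is a small sketch with random tensors standing in for q and k. It compares one batched call against an explicit loop of 2-D matmuls, one per (sequence, head) pair.

```python
import torch

torch.manual_seed(0)
B, n_head, T, head_dim = 2, 4, 8, 16
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_head, T, head_dim)

# One call: (B, n_head) are batch dims; the trailing (T, hd) @ (hd, T) is the actual matmul.
batched = q @ k.transpose(-2, -1)            # (B, n_head, T, T)

# The same computation as B * n_head separate 2-D matmuls.
looped = torch.empty(B, n_head, T, T)
for b in range(B):
    for h in range(n_head):
        looped[b, h] = q[b, h] @ k[b, h].T   # (T, hd) @ (hd, T) -> (T, T)

print(batched.shape, torch.allclose(batched, looped, atol=1e-6))
```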

The correct dance is two steps: first view, then transpose.

# Step 1: view (B, T, C) -> (B, T, n_head, head_dim)
# This just reinterprets the last axis: the 64 channels at each
# position become 4 groups of 16 contiguous channels. No data moves.
x_split = x.view(B, T, n_head, head_dim)
print(x_split.shape)          # torch.Size([2, 8, 4, 16])

# Step 2: transpose dims 1 and 2 -> (B, n_head, T, head_dim)
# Now each head sees the full sequence of its own 16-dim slices.
x_heads = x_split.transpose(1, 2)
print(x_heads.shape)          # torch.Size([2, 4, 8, 16])

Why can’t you skip the transpose and just write x.view(B, n_head, T, head_dim)? Because view never moves data — it only relabels how the same flat buffer is indexed. The memory layout of x is [batch][time][channel]: all 64 channels of position 0, then all 64 channels of position 1, and so on. Viewing that buffer directly as (B, nh, T, hd) would carve it as [batch][head][time][head_dim] — meaning “head 0” would grab the first T × hd = 8 × 16 = 128 floats, which in reality are all 64 channels of positions 0 and 1. You’d be mixing whole positions into what you think are per-head slices. The shapes look right, the code runs, the model silently learns garbage. Watch:

wrong = x.view(B, n_head, T, head_dim)         # runs fine. lies.
right = x.view(B, T, n_head, head_dim).transpose(1, 2)

# Head 0, position 0 should be channels 0..15 of position 0:
print(torch.equal(right[0, 0, 0], x[0, 0, :16]))   # True
print(torch.equal(wrong[0, 0, 0], x[0, 0, :16]))   # True (coincidence: very first row matches)
# Head 0, position 1 should be channels 0..15 of position 1:
print(torch.equal(right[0, 0, 1], x[0, 1, :16]))   # True
print(torch.equal(wrong[0, 0, 1], x[0, 0, 16:32])) # True — WRONG data:
# "position 1" of the wrong tensor is actually channels 16..31 of position 0!

This is the classic silent transformer bug: no error, no NaN, just a model that never quite learns. Burn the rule in: view to unpack the channel axis, transpose to move heads ahead of time.

Here’s the geometry of the correct split — one position’s 64-channel vector becoming four 16-dim head slices:

one token’s vector (C = 64): channels 0…15 | 16…31 | 32…47 | 48…63
after view + transpose (4 heads × head_dim = 16): head 0 | head 1 | head 2 | head 3, each with its own attention map
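The geometry above can be checked exhaustively rather than at two spot positions. The sketch below wraps the dance in two small helpers; split_heads and merge_heads are hypothetical names used only for illustration. It then asserts that every (batch, head, position) slice is exactly channels h*hd:(h+1)*hd of that position, and that merging undoes splitting bit for bit.

```python
import torch

def split_heads(x: torch.Tensor, n_head: int) -> torch.Tensor:
    """(B, T, C) -> (B, n_head, T, head_dim): view to unpack channels, then transpose."""
    B, T, C = x.shape
    assert C % n_head == 0
    return x.view(B, T, n_head, C // n_head).transpose(1, 2)

def merge_heads(y: torch.Tensor) -> torch.Tensor:
    """(B, n_head, T, head_dim) -> (B, T, C): the dance in reverse."""
    B, nh, T, hd = y.shape
    return y.transpose(1, 2).contiguous().view(B, T, nh * hd)

torch.manual_seed(0)
x = torch.randn(2, 8, 64)
heads = split_heads(x, n_head=4)
hd = 64 // 4

# Head h at position t must be exactly channels h*hd:(h+1)*hd of position t.
ok = all(
    torch.equal(heads[b, h, t], x[b, t, h * hd:(h + 1) * hd])
    for b in range(2) for h in range(4) for t in range(8)
)
print(ok, torch.equal(merge_heads(heads), x))   # the round trip is lossless
```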

All heads in one matmul

With the tensors shaped (B, nh, T, hd), Day 4’s attention math works unchanged — broadcasting does all the per-head bookkeeping for free. Let’s prove it with separate Q, K, V projections first (we’ll fuse them in the next section):

import math

# One projection each, over the FULL channel dim — the head split
# happens after, by slicing the output channels.
Wq = nn.Linear(C, C, bias=False)
Wk = nn.Linear(C, C, bias=False)
Wv = nn.Linear(C, C, bias=False)

q = Wq(x).view(B, T, n_head, head_dim).transpose(1, 2)  # (B, nh, T, hd)
k = Wk(x).view(B, T, n_head, head_dim).transpose(1, 2)  # (B, nh, T, hd)
v = Wv(x).view(B, T, n_head, head_dim).transpose(1, 2)  # (B, nh, T, hd)

# Scores: (B, nh, T, hd) @ (B, nh, hd, T) -> (B, nh, T, T)
# One attention matrix PER head, all computed in one call.
att = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
print(att.shape)   # torch.Size([2, 4, 8, 8])

Two things deserve a pause. First, notice we still use nn.Linear(C, C). That is mathematically identical to n_head separate Linear(C, head_dim) projections stacked side by side, because slicing the output of a linear layer is the same as splitting its weight matrix into blocks. In the \(XW^Q\) notation of the formula above these are column blocks of \(W^Q\). In PyTorch, .weight is stored as (out_features, in_features), so they are row blocks. One big matmul typically beats four small ones on a GPU. Second, the scale factor is \(\sqrt{d_{head}} = \sqrt{16} = 4\), not \(\sqrt{C} = 8\). Each head’s dot products sum over only head_dim terms, so that’s the variance you’re normalizing away. Using \(\sqrt{C}\) over-shrinks the logits and makes early attention too uniform. That is a soft bug: it slows training rather than breaking it.
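Both points can be checked numerically. The sketch below assumes unit-variance random vectors as stand-ins for queries and keys, which is roughly the situation at initialization. First, it measures the variance of head_dim-term dot products under each scale. Second, it copies row blocks of one Linear(C, C) into four Linear(C, head_dim) layers and confirms the outputs match.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, n_head = 64, 4
hd = C // n_head

# (1) Why sqrt(head_dim): dot products of hd-dim unit-variance vectors have variance ~hd.
q = torch.randn(100_000, hd)
k = torch.randn(100_000, hd)
dots = (q * k).sum(dim=-1)
print(dots.var().item())                    # roughly 16 (= hd)
print((dots / hd ** 0.5).var().item())      # roughly 1: the intended scale
print((dots / C ** 0.5).var().item())       # roughly 0.25: over-shrunk logits

# (2) One Linear(C, C) equals n_head Linear(C, hd) layers side by side.
# nn.Linear stores weight as (out, in), so head i owns ROWS i*hd:(i+1)*hd.
Wq = nn.Linear(C, C, bias=False)
per_head = [nn.Linear(C, hd, bias=False) for _ in range(n_head)]
with torch.no_grad():
    for i, lin in enumerate(per_head):
        lin.weight.copy_(Wq.weight[i * hd:(i + 1) * hd])
x = torch.randn(2, 8, C)
with torch.no_grad():
    big = Wq(x)
    stacked = torch.cat([lin(x) for lin in per_head], dim=-1)
print(torch.allclose(big, stacked, atol=1e-6))
```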

Now the causal mask and softmax, exactly as Day 4, broadcast across the head dimension:

mask = torch.tril(torch.ones(T, T))              # (T, T)
att = att.masked_fill(mask == 0, float('-inf'))  # broadcasts over (B, nh, ·, ·)
att = F.softmax(att, dim=-1)                     # each row a distribution, per head
y = att @ v                                      # (B, nh, T, T) @ (B, nh, T, hd) -> (B, nh, T, hd)
print(y.shape)   # torch.Size([2, 4, 8, 16])
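A quick way to trust the broadcast mask is to check the two properties it should guarantee for every head. Each row of weights must sum to 1, and no weight may land on a future position. The sketch below uses random scores as a stand-in for q @ k.transpose(-2, -1) / sqrt(head_dim).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, n_head, T = 2, 4, 8
scores = torch.randn(B, n_head, T, T)        # stand-in for the scaled q @ k^T

mask = torch.tril(torch.ones(T, T))          # (T, T), broadcast over (B, n_head)
att = F.softmax(scores.masked_fill(mask == 0, float('-inf')), dim=-1)

print(torch.allclose(att.sum(dim=-1), torch.ones(B, n_head, T)))  # every row is a distribution
print(torch.equal(att.triu(diagonal=1), torch.zeros_like(att)))   # exp(-inf) = 0: no future weight
print(att[0, 0, 0])   # position 0 can only see itself, so its row is one-hot
```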

Each head has now produced its own (T, hd) answer. To merge them we run the reshape dance in reverse — transpose heads back next to channels, then view-flatten:

y = y.transpose(1, 2)              # (B, T, nh, hd) — heads back beside channels
y = y.contiguous().view(B, T, C)   # (B, T, C) — concat heads along channels
print(y.shape)   # torch.Size([2, 8, 64])

Why the .contiguous()? transpose doesn’t move data either — it just swaps strides, leaving the tensor’s memory in the old order. But view demands a contiguous buffer whose layout matches the requested shape. Skip it and PyTorch throws RuntimeError: view size is not compatible with input tensor's size and stride. .contiguous() performs the actual copy into the new order. (You could use .reshape(B, T, C), which calls .contiguous() for you when needed — writing it explicitly documents that a real copy happens here.)
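Strides make the “transpose only swaps strides” point visible. The sketch below uses a fresh (B, nh, T, hd) tensor, like the one att @ v produces, to show the stride swap, the view failure, and the copy that .contiguous() (or .reshape) performs. The shapes are the toy values from this lesson.

```python
import torch

y = torch.randn(2, 4, 8, 16)        # fresh (B, nh, T, hd), contiguous like att @ v's output
y_t = y.transpose(1, 2)             # (B, T, nh, hd): same memory, strides swapped
print(y.stride(), y_t.stride())     # (512, 128, 16, 1) (512, 16, 128, 1)
print(y_t.is_contiguous(), y_t.data_ptr() == y.data_ptr())  # False True: nothing moved

try:
    y_t.view(2, 8, 64)              # nh and hd are not adjacent in memory, so view refuses
except RuntimeError as err:
    print("view refused:", str(err)[:50])

merged = y_t.contiguous().view(2, 8, 64)            # real copy into (B, T, nh, hd) order
print(torch.equal(merged, y_t.reshape(2, 8, 64)))   # reshape performs the same copy
```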

The output projection: letting heads talk

We’re not done. After the merge, position t’s vector is just head 0’s answer in channels 0–15, head 1’s in 16–31, and so on — four opinions sitting side by side in separate lanes, never interacting. The output projection \(W^O\) fixes that:

Wo = nn.Linear(C, C, bias=False)
out = Wo(y)        # (B, T, C) — every output channel mixes ALL heads

Each output channel of Wo is a learned linear combination of all 64 input channels — i.e., of all four heads. This is where “head 2 found the subject and head 3 found the verb” gets combined into a single useful feature. Without \(W^O\), downstream layers could still mix the lanes eventually, but the attention layer itself would be strictly a concatenation of independent low-rank updates, and — more practically — you’d break the standard architecture that tomorrow’s residual stream expects. Every serious implementation (the original paper, GPT-2, LLaMA) has it. Keep it.
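The “separate lanes” picture can be tested directly. In the sketch below, random numbers stand in for the merged head outputs and Wo has random initial weights. We nudge only head 1’s channels and record which channels change before and after the output projection.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C, n_head = 2, 8, 64, 4
hd = C // n_head
Wo = nn.Linear(C, C, bias=False)

y = torch.randn(B, T, C)        # merged heads: head i lives in channels i*hd:(i+1)*hd
y_alt = y.clone()
y_alt[..., hd:2 * hd] += 1.0    # change ONLY head 1's answer

# Per channel: did it change anywhere in the batch?
lanes_before = (y_alt != y).reshape(-1, C).any(dim=0)
with torch.no_grad():
    lanes_after = (Wo(y_alt) != Wo(y)).reshape(-1, C).any(dim=0)

print(lanes_before.nonzero().flatten().tolist())  # channels 16..31 only: head 1's lane
print(int(lanes_after.sum()), "of", C)            # every output channel feels head 1
```

Without Wo, a change in one head stays confined to its 16-channel lane. With Wo, every output channel is affected (almost surely, given random weights).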

The GPT-2 trick: one fused QKV projection

Three separate Linear(C, C) layers for Q, K, V means three matmuls over the same input x. GPT-2’s implementation (and nanoGPT, which we’re following in spirit) fuses them into a single Linear(C, 3C) — one matmul producing a (B, T, 3C) tensor that we split three ways. Same math, same parameter count, fewer kernel launches, better GPU utilization.
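“Same math, same parameter count” is easy to confirm. The sketch below builds three separate projections and then a fused Linear(C, 3C). The fused weight is filled with the row-stack of the three weights, because PyTorch stores weights as (out, in). A single split along the channel dimension should reproduce q, k and v exactly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 64
Wq, Wk, Wv = (nn.Linear(C, C, bias=False) for _ in range(3))

# Fused layer whose weight is the row-stack [Wq; Wk; Wv].
qkv = nn.Linear(C, 3 * C, bias=False)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([Wq.weight, Wk.weight, Wv.weight], dim=0))

x = torch.randn(2, 8, C)
with torch.no_grad():
    q, k, v = qkv(x).split(C, dim=2)       # one matmul, then three (B, T, C) chunks
    same = all(torch.allclose(a, lin(x), atol=1e-5) for a, lin in ((q, Wq), (k, Wk), (v, Wv)))

n_fused = sum(p.numel() for p in qkv.parameters())
print(same, n_fused == 3 * C * C)
```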

Time to write the real module — the one that ships in our GPT. It takes the config values we’ve been carrying since Day 1:

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention, GPT-2 style (fused QKV)."""

    def __init__(self, n_embd: int, n_head: int, block_size: int,
                 dropout: float = 0.0):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        # One matmul computes Q, K, and V for all heads at once.
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        # Output projection: mixes information across heads.
        self.proj = nn.Linear(n_embd, n_embd, bias=False)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Causal mask, precomputed once. register_buffer -> moves with
        # .to(device), saves with state_dict, but is NOT a parameter.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size))
                 .view(1, 1, block_size, block_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        # (B, T, C) -> (B, T, 3C) -> three (B, T, C) tensors
        q, k, v = self.qkv(x).split(C, dim=2)

        # The dance: (B, T, C) -> (B, T, nh, hd) -> (B, nh, T, hd)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Attention for all heads in one shot: (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        y = att @ v                                    # (B, nh, T, hd)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge heads

        return self.resid_dropout(self.proj(y))       # (B, T, C)

Walk through the load-bearing choices:

  • split(C, dim=2) carves the (B, T, 3C) output into three (B, T, C) chunks along channels. The QKV weight matrix is literally \([W^Q; W^K; W^V]\) stacked — fusing changes nothing mathematically.
  • self.mask[:, :, :T, :T] — the buffer is built at the maximum block_size, then sliced to the actual sequence length at runtime. This is why generation with growing context “just works”: feed 5 tokens, get a 5×5 mask. Forget the slice and any T < block_size input crashes with a broadcast error.
  • register_buffer matters more than it looks. A plain attribute self.mask = torch.tril(...) stays on CPU when you call model.to("cuda"), and you get the classic “Expected all tensors to be on the same device” at the worst possible time. A buffer travels with the module.
  • Two dropouts: one on the attention weights (randomly zeroing whole attention edges — a strong regularizer), one on the final output before it rejoins the residual stream. At our tiny scale we’ll train with dropout=0.0, but the plumbing is standard and free.
  • The mask comparison == 0 on a float buffer is fine here because the buffer contains exact 0.0/1.0 from tril(ones); nothing ever writes to it.
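The register_buffer behavior described above can be observed without a GPU. BufferDemo below is a hypothetical toy module, for illustration only. Because the test machine may have no CUDA device, we use a dtype conversion as a stand-in for a device move: Module.to(dtype) converts floating-point parameters and buffers, but it does not touch plain tensor attributes.

```python
import torch
import torch.nn as nn

class BufferDemo(nn.Module):           # hypothetical toy module, for illustration only
    def __init__(self, block_size: int = 4):
        super().__init__()
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)))
        self.plain_mask = torch.tril(torch.ones(block_size, block_size))  # NOT registered

m = BufferDemo()
print([name for name, _ in m.named_buffers()])   # ['mask']
print(list(m.state_dict().keys()))               # ['mask']: saved with the model
print(sum(p.numel() for p in m.parameters()))    # 0: the optimizer never sees it

m = m.to(torch.float64)                # stand-in for .to("cuda"): buffers follow, plain attrs don't
print(m.mask.dtype, m.plain_mask.dtype)          # torch.float64 torch.float32
```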

Sanity check the module:

mha = CausalSelfAttention(n_embd=64, n_head=4, block_size=32)
out = mha(torch.randn(2, 8, 64))
print(out.shape)                                   # torch.Size([2, 8, 64])
print(sum(p.numel() for p in mha.parameters()))    # 16384

Parameter count check: QKV is \(64 \times 192 = 12{,}288\), output projection is \(64 \times 64 = 4{,}096\); total \(16{,}384\). Exactly \(4 C^2\) — remember that: attention costs four \(C \times C\) matrices, a number we’ll reuse when we count the full model’s parameters on Day 7.
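The 4C² rule is easy to turn into a helper for later counting. This sketch assumes the layer is exactly Linear(C, 3C) plus Linear(C, C), as in our module. It adds an optional bias term because the original GPT-2 weights do include biases on these two layers; attn_params is a hypothetical helper name.

```python
def attn_params(C: int, bias: bool = False) -> int:
    """Parameters in one fused-QKV attention layer: Linear(C, 3C) + Linear(C, C)."""
    weights = C * 3 * C + C * C            # = 4 * C**2
    biases = (3 * C + C) if bias else 0    # = 4 * C when biases are enabled
    return weights + biases

print(attn_params(64))               # 16384, matches the module above
print(attn_params(768))              # 2359296 per layer at GPT-2 small width, bias=False
print(attn_params(768, bias=True))   # 2362368 with GPT-2's biases
```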

And verify causality still holds — position t’s output must not change when we edit future tokens:

x1 = torch.randn(1, 8, 64)
x2 = x1.clone()
x2[0, 5:] = torch.randn(3, 64)     # rewrite the future (positions 5,6,7)
mha.eval()                          # disable dropout for determinism
with torch.no_grad():
    y1, y2 = mha(x1), mha(x2)
print(torch.allclose(y1[0, :5], y2[0, :5], atol=1e-6))  # True — past unaffected
print(torch.allclose(y1[0, 5:], y2[0, 5:]))             # False — future differs
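A complementary check uses gradients instead of edited inputs. We backpropagate from one output position and look at which input positions receive gradient. The sketch below is a stripped-down functional stand-in for the module, not the module itself. Its random projection weights W_qkv are hypothetical and it uses F.scaled_dot_product_attention with is_causal=True, so it needs no class definition.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C, n_head = 1, 8, 32, 4
hd = C // n_head
W_qkv = torch.randn(C, 3 * C) / C ** 0.5   # hypothetical fixed random weights

def causal_mha(x: torch.Tensor) -> torch.Tensor:
    q, k, v = (x @ W_qkv).split(C, dim=-1)
    q, k, v = (z.view(B, T, n_head, hd).transpose(1, 2) for z in (q, k, v))
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return y.transpose(1, 2).reshape(B, T, C)

x = torch.randn(B, T, C, requires_grad=True)
pos = 3
causal_mha(x)[0, pos].sum().backward()     # gradient of position 3's output w.r.t. all inputs
influence = x.grad[0].norm(dim=-1)         # one number per input position
print(influence)   # positive for positions 0..3, (numerically) zero for 4..7
```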

The fast path: F.scaled_dot_product_attention

Everything we hand-wrote (scores, scale, mask, softmax, weighted sum) exists in PyTorch as a single fused function: torch.nn.functional.scaled_dot_product_attention. On supported CUDA GPUs with fp16/bf16 inputs, it can dispatch to FlashAttention, a kernel that never materializes the (B, nh, T, T) attention matrix in GPU main memory at all. For other dtypes or devices, PyTorch selects a memory-efficient or plain-math backend instead. FlashAttention computes softmax in tiles inside fast on-chip SRAM. That makes attention both faster and, critically, \(O(T)\) in memory instead of \(O(T^2)\). The quadratic attention matrix is the memory bottleneck of transformers. At T=4096, nh=12, B=16, the matrix alone is 16 × 12 × 4096 × 4096 fp32 values (4 bytes each) ≈ 12.9 GB. FlashAttention simply never allocates it.
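The memory figure is plain arithmetic over the shape (B, nh, T, T), so it is easy to recompute for your own config. A minimal sketch follows; attn_matrix_bytes is a hypothetical helper, and it counts only one materialized copy of the matrix. The last loop assumes B=1, nh=12 and fp32. Under those assumptions it reproduces the memory column of the benchmark table later in this section, but the table itself does not state them.

```python
def attn_matrix_bytes(B: int, n_head: int, T: int, bytes_per_elem: int = 4) -> int:
    """Size of one materialized (B, n_head, T, T) attention tensor; fp32 by default."""
    return B * n_head * T * T * bytes_per_elem

print(attn_matrix_bytes(16, 12, 4096) / 1e9)                     # ~12.88 GB in fp32
print(attn_matrix_bytes(16, 12, 4096, bytes_per_elem=2) / 1e9)   # ~6.44 GB in fp16/bf16

# Assumed B=1, n_head=12, fp32: matches the table's 3 MB / 50 MB / 805 MB column.
for T in (256, 1024, 4096):
    print(T, round(attn_matrix_bytes(1, 12, T) / 1e6, 1), "MB")  # 3.1, 50.3, 805.3
```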

The beautiful part: because we shaped our tensors (B, nh, T, hd) — exactly the layout SDPA expects — it’s a five-line swap inside forward. Replace the score/mask/softmax block with:

        # --- fast path: replaces scores, mask, softmax, matmul ---
        y = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=True,        # builds the causal mask internally
        )

Line by line:

  • is_causal=True tells the kernel to apply the lower-triangular mask itself — no mask tensor, no buffer, no masked_fill. (You can pass an explicit attn_mask= instead, but then you may fall off the FlashAttention fast path onto a slower fallback kernel; for plain causal LM, always prefer is_causal=True.)
  • dropout_p=... if self.training else 0.0 — SDPA takes dropout as a float, not a module, so it doesn’t know about model.eval(). You must gate it on self.training yourself. Forgetting this is a real bug: your “deterministic” eval outputs quietly keep dropout noise.
  • Scaling by \(1/\sqrt{d_{head}}\) happens inside the function (it reads the last dim of q). If you keep your own / math.sqrt(...) and call SDPA, you double-scale and the model trains poorly while looking healthy. Delete your manual scale when you switch.
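The three bullets can each be checked against the manual formula on small random tensors. The sketch below compares three SDPA calls with the manual reference:

- is_causal=True should match exactly.
- An explicit boolean attn_mask should also match. In SDPA, True means “may attend”. The result is the same math, though the call may not take the Flash path.
- A deliberately double-scaled call should not match.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, nh, T, hd = 2, 4, 8, 16
q, k, v = (torch.randn(B, nh, T, hd) for _ in range(3))

# Manual reference: scale, causal mask, softmax, weighted sum.
att = (q @ k.transpose(-2, -1)) / math.sqrt(hd)
att = att.masked_fill(torch.tril(torch.ones(T, T)) == 0, float('-inf'))
ref = F.softmax(att, dim=-1) @ v

y_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

causal_bool = torch.ones(T, T, dtype=torch.bool).tril()      # True = may attend
y_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_bool)

# Bug on purpose: pre-scaling q AND letting SDPA scale again gives 1/hd instead of 1/sqrt(hd).
y_double = F.scaled_dot_product_attention(q / math.sqrt(hd), k, v, is_causal=True)

print(torch.allclose(y_causal, ref, atol=1e-5), torch.allclose(y_masked, ref, atol=1e-5))
print(torch.allclose(y_double, ref, atol=1e-5))   # False: double scaling changes the output
```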

Here’s the drop-in version of the module — the one we’ll actually carry into Day 6 (only forward’s middle changes; a flash flag keeps the manual path around for teaching and for checking our math):

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout=0.0, flash=True):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head, self.head_dim = n_head, n_embd // n_head
        self.dropout = dropout
        self.flash = flash
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size))
                                          .view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        if self.flash:
            y = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True)
        else:  # manual path — identical math, materializes (B, nh, T, T)
            att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
            att = self.attn_dropout(F.softmax(att, dim=-1))
            y = att @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.proj(y))

Prove the two paths agree — same module, same weights, flag flipped:

torch.manual_seed(0)
m = CausalSelfAttention(64, 4, 32, dropout=0.0, flash=True).eval()
x = torch.randn(2, 16, 64)
with torch.no_grad():
    y_flash = m(x)
    m.flash = False
    y_manual = m(x)
print(torch.allclose(y_flash, y_manual, atol=1e-5))   # True
print((y_flash - y_manual).abs().max())               # ~1e-7, float round-off

The tolerance is 1e-5, not exact equality: fused kernels sum in a different order than our explicit matmuls, and floating-point addition isn’t associative. Differences around 1e-7 are healthy; differences around 1e-2 mean you double-scaled or mangled the mask.
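Non-associativity is easy to see in isolation. Two tiny examples follow. The first uses Python’s double-precision floats. The second uses float32 at a magnitude where its spacing is 8, which exaggerates the same effect that reordered sums produce inside fused kernels.

```python
import torch

# Python floats are IEEE-754 doubles: regrouping a sum changes the rounded result.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))   # False

# Same effect in float32, exaggerated: near 1e8 the float32 spacing is 8, so +1 is lost.
a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(1.0, dtype=torch.float32)
print(((a + b) - a).item(), ((a - a) + b).item())   # 0.0 1.0
```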

One benchmark to make the payoff concrete. The numbers are from an A100, and your ratios will vary. On CPU the CUDA FlashAttention kernels don’t apply, so expect a smaller gap, though SDPA’s fused path usually still comes out ahead:

T (seq len) | manual path | SDPA (Flash) | speedup | attn-matrix memory avoided
256 | 0.31 ms | 0.11 ms | ~2.8× | 3 MB
1024 | 3.9 ms | 0.9 ms | ~4.3× | 50 MB
4096 | 71 ms | 11 ms | ~6.5× | 805 MB
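To measure your own ratios rather than trust the table, here is a minimal timing sketch using torch.utils.benchmark.Timer, which handles CUDA synchronization for you. The shapes B=1, nh=12, hd=64 are assumptions; adjust them to your model. PyTorch selects the FlashAttention backend only for fp16/bf16 inputs on supported GPUs. In this fp32 form you will be timing a different SDPA backend, so cast q, k and v to torch.float16 on a GPU to get closer to the table’s setting.

```python
import math
import torch
import torch.nn.functional as F
from torch.utils import benchmark

def manual_attention(q, k, v):
    T = q.size(-2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    mask = torch.ones(T, T, dtype=torch.bool, device=q.device).tril()
    att = att.masked_fill(~mask, float('-inf'))
    return F.softmax(att, dim=-1) @ v

def sdpa_attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
B, nh, hd = 1, 12, 64                      # assumed shapes; adjust to your model
results = {}
for T in (256, 1024):
    q, k, v = (torch.randn(B, nh, T, hd, device=device) for _ in range(3))
    for name, fn in (("manual", manual_attention), ("sdpa", sdpa_attention)):
        timer = benchmark.Timer(stmt="fn(q, k, v)",
                                globals={"fn": fn, "q": q, "k": k, "v": v})
        results[(T, name)] = timer.blocked_autorange(min_run_time=0.2).median  # seconds
    print(f"T={T}: manual {results[(T, 'manual')] * 1e3:.2f} ms, "
          f"sdpa {results[(T, 'sdpa')] * 1e3:.2f} ms")
```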

The lesson generalizes: write the naive version first to own the math, then swap in the fused op and verify equivalence numerically. That verification habit is the difference between “I used FlashAttention” and “I know my FlashAttention call is correct.”

🧪 Your task

Your CausalSelfAttention computes all heads in one batched matmul. Prove to yourself that this really is n_head independent attentions by writing mha_by_loop(mha, x): a function that takes your (eval-mode, manual-path) module and an input x, extracts the fused QKV weights, runs attention one head at a time in a Python loop over per-head (B, T, head_dim) slices, concatenates the results, applies the output projection — and matches the module’s output to within 1e-5.

Hint: the fused QKV weight has shape (3C, C) and is stacked \([W^Q; W^K; W^V]\) along dim 0, so Wq = mha.qkv.weight[:C]. Head i of a projected tensor q (shape (B, T, C)) is the channel slice q[..., i*hd:(i+1)*hd]. Remember x @ W.T replicates nn.Linear with bias=False, and scale by sqrt(head_dim).

Solution
import math
import torch
import torch.nn.functional as F

def mha_by_loop(mha, x):
    B, T, C = x.shape
    nh, hd = mha.n_head, mha.head_dim
    W = mha.qkv.weight                    # (3C, C), rows stacked [Wq; Wk; Wv]
    Wq, Wk, Wv = W[:C], W[C:2*C], W[2*C:]

    # Full-width projections; per-head split is just channel slicing.
    q = x @ Wq.T                          # (B, T, C)
    k = x @ Wk.T
    v = x @ Wv.T

    mask = torch.tril(torch.ones(T, T, device=x.device))
    heads = []
    for i in range(nh):
        sl = slice(i * hd, (i + 1) * hd)
        qi, ki, vi = q[..., sl], k[..., sl], v[..., sl]   # (B, T, hd) each
        att = (qi @ ki.transpose(-2, -1)) / math.sqrt(hd)  # (B, T, T)
        att = att.masked_fill(mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        heads.append(att @ vi)                             # (B, T, hd)

    y = torch.cat(heads, dim=-1)          # (B, T, C): concat along channels
    return y @ mha.proj.weight.T          # output projection

# Verify against the vectorized module:
torch.manual_seed(0)
mha = CausalSelfAttention(64, 4, 32, dropout=0.0, flash=False).eval()
x = torch.randn(2, 16, 64)
with torch.no_grad():
    y_module = mha(x)
    y_loop = mha_by_loop(mha, x)
print(torch.allclose(y_module, y_loop, atol=1e-5))   # True
print((y_module - y_loop).abs().max())               # ~1e-7

If your allclose fails, the usual suspects in order: scaled by sqrt(C) instead of sqrt(hd); sliced the QKV weight along the wrong dim (rows, dim 0 — not columns); forgot .T when applying nn.Linear weights manually; or left dropout enabled (call .eval() and use torch.no_grad()).

Key takeaways

  • One softmax = one focus per position; n_head parallel heads in head_dim = C // n_head subspaces let each position attend to several things at once — at roughly the cost of a single big head.
  • The reshape dance is view(B, T, nh, hd) then transpose(1, 2) → (B, nh, T, hd). Viewing straight to (B, nh, T, hd) runs without error and silently scrambles positions into heads.
  • With heads as a batch dimension, one q @ k.transpose(-2, -1) computes every head’s attention matrix at once; scale by \(\sqrt{d_{head}}\), not \(\sqrt{C}\).
  • Merging is the dance reversed — transpose(1, 2).contiguous().view(B, T, C) — and the output projection \(W^O\) is what lets heads exchange information; never drop it.
  • GPT-2’s fused Linear(C, 3C) computes Q, K, V in one matmul; with bias=False (as in our module), attention costs \(4C^2\) parameters total.
  • F.scaled_dot_product_attention(q, k, v, is_causal=True) is a drop-in FlashAttention fast path: delete your manual scale, gate dropout_p on self.training, and verify against your manual path with allclose.

Tomorrow we wrap this module with a feed-forward network, LayerNorm, and residual connections into the transformer block — the repeating unit we’ll stack into a full GPT.



© Kader Mohideen