Kader Mohideen

In this lesson

  • Lesson 5 — WikiGPT-124M: The Model, Line by Line
    • The blueprint: one config, one diagram
    • RMSNorm: normalize the scale, skip the mean
    • RoPE: position as rotation
    • Attention and SwiGLU: the two workhorses
    • The block: pre-norm residuals
    • Weight tying, initialization, and the parameter count
    • The complete src/model.py
    • Run the smoke test — cost: $0.00
    • 🧪 Your task
    • Key takeaways
    • Coming up

📖 Build Your Own Wikipedia LLM · Lesson 5 — WikiGPT-124M: The Model, Line by Line


Lesson 5 — WikiGPT-124M: The Model, Line by Line

In Lesson 4 you trained a 32,768-symbol BPE tokenizer on your cleaned Wikipedia corpus and reserved the <|user|>, <|assistant|>, and <|end|> chat tokens that will matter much later, in Lessons 9 and 10. You now have data/tokens/ full of uint16 token ids and a tokenizer/tokenizer.json that maps them back to text. What you don’t have yet is the thing that will consume those tokens: the network itself.

This lesson is the heart of the course. We build src/model.py — the complete WikiGPT-124M — in pure PyTorch, with no external model libraries, and we justify every single design decision: why RMSNorm and not LayerNorm, why RoPE and not learned position embeddings, why SwiGLU and not the classic 4× GELU MLP, why no biases anywhere, why the embedding matrix does double duty as the output head. These are not arbitrary choices; they are, almost exactly, the Llama-family recipe scaled down to a GPT-2-small skeleton. Best of all: everything in this lesson runs on your laptop CPU. Your vast.ai wallet stays closed until Lesson 6.

🎯 In this lesson you will: write the complete src/model.py — RMSNorm, RoPE, flash attention via SDPA, SwiGLU, pre-norm blocks, weight tying, the GPT-2 init scheme, a parameter-count breakdown, and a generate() you’ll reuse for the rest of the course — then smoke-test it on CPU by checking that the untrained loss lands at ln(32768) ≈ 10.4.

The blueprint: one config, one diagram

Before any code, hold the whole machine in your head. WikiGPT-124M is a decoder-only transformer: token ids go in, one embedding lookup turns them into vectors, twelve identical blocks refine those vectors, a final norm and a linear head turn them back into a probability distribution over the next token. That’s the entire architecture — everything else is detail inside the blocks.

flowchart TB
    tok["token ids<br/>B × T"] --> emb["token embedding<br/>32768 × 768"]
    emb --> b1["Block 1"]
    b1 --> b2["Block 2"]
    b2 --> dots["⋮<br/>12 blocks total"]
    dots --> b12["Block 12"]
    b12 --> fn["final RMSNorm"]
    fn --> head["LM head 768 → 32768<br/>same matrix as the embedding"]
    head --> logits["logits<br/>B × T × 32768"]
    rope["RoPE cos/sin buffers<br/>1024 × 32, precomputed once"] -.->|"rotates q,k in every block"| b1
    rope -.-> b12

The first five fields below are fixed for the whole course: the tokenizer (Lesson 4), the data packing (Lesson 4), the training config (Lesson 6), and every checkpoint you’ll ever load all assume them. The last two, rope_theta and dropout, are discussed right after the listing:

@dataclass
class GPTConfig:
    vocab_size: int = 32768   # our BPE from Lesson 4, chat tokens already reserved
    block_size: int = 1024    # maximum context length
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768         # so head_dim = 768 / 12 = 64
    rope_theta: float = 10000.0
    dropout: float = 0.0

Two things are worth pausing on. First, dropout: 0.0. Dropout is a regularizer against overfitting, and in Lesson 7 we train for roughly one pass over ~4B fresh tokens. The model never sees the same batch twice, so there is very little to overfit to, and dropout would mostly slow learning. Large single-epoch pretraining runs in the Llama, Qwen, and Mistral style generally make the same choice. Second, rope_theta: 10000.0, the base frequency from the original RoPE paper, also used by Llama 1/2. Models targeting much longer contexts raise it (Llama 3 uses 500,000); at our 1024-token context, 10,000 is a sound default.

And here is how our choices line up against the model this one is named after:

| Design choice | GPT-2-124M (2019) | WikiGPT-124M | Who else does it our way |
|---|---|---|---|
| Positional info | learned absolute embeddings (wpe) | RoPE applied to q/k | Llama 1–3, Mistral, Qwen |
| Normalization | LayerNorm (mean + bias) | RMSNorm | Llama-family, T5 |
| Norm placement | pre-norm | pre-norm | both |
| FFN | 4× GELU MLP | SwiGLU at 8/3× | Llama, PaLM, Mistral |
| Biases | in every Linear + LayerNorm | none anywhere | Llama-family |
| Attention impl | hand-written masked softmax | F.scaled_dot_product_attention (flash) | everything modern |
| Vocab | 50,257 (GPT-2 BPE) | 32,768 (your BPE) | — |
| Weight tying | yes | yes | GPT-2, Gemma |
| Exact params | 124.4M | 110.1M | — |

About that last row: the weight matrices in the twelve transformer blocks are exactly the same size as GPT-2’s (SwiGLU at 8/3× has exactly the same parameter count as a 4× two-matrix MLP, as we’ll verify below). Our stack comes to 85.0M; GPT-2’s is 85.1M only because it also carries about 0.1M of Linear and LayerNorm biases that we drop. The difference is almost entirely the embedding: our lean 32k vocabulary costs 25.2M where GPT-2’s 50k costs 38.6M, and RoPE deletes the 0.8M position-embedding table outright. Had we kept GPT-2’s vocab, we’d land at about 123.6M, within one percent of 124.4M. The name marks the class (this is a GPT-2-small-scale model) and the smaller exact count is pure savings: fewer embedding parameters, same transformer.

RMSNorm: normalize the scale, skip the mean

Every block needs normalization, because stacking twelve residual blocks without it lets activation magnitudes drift: they grow layer over layer until training destabilizes, or shrink until gradients vanish. (bf16 shares fp32’s exponent range, so the usual symptom is loss spikes and divergence rather than a literal overflow.) Classic LayerNorm fixes this by subtracting the mean, dividing by the standard deviation, then applying a learned gain and bias. RMSNorm (Zhang & Sennrich, 2019) observed that the re-centering contributes little: what stabilizes training is mostly controlling the scale of the vector, not its mean. So RMSNorm just divides by the root-mean-square and applies a gain:

\[\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot g\]

Comparable stability, one fewer reduction over the 768-dim vector, one fewer parameter tensor, and a slightly cheaper kernel, which is why the Llama family, T5, and Mistral use it. One implementation subtlety that will matter in Lesson 6: we compute the normalization in float32 even when the model runs in bf16. The sum of 768 squared values loses precision in bf16’s 7-bit mantissa (8 significant bits counting the implicit leading one); doing the reduction in fp32 and casting back is cheap and removes one source of subtle loss spikes.

class RMSNorm(nn.Module):
    """y = x / rms(x) * gain. No mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))   # gain, init to identity

    def forward(self, x):                              # x: (B, T, C)
        norm = x.float() * torch.rsqrt(
            x.float().pow(2).mean(-1, keepdim=True) + self.eps
        )                                              # fp32 for stability
        return norm.type_as(x) * self.weight

(PyTorch ≥ 2.4 ships an equivalent nn.RMSNorm; we write our own because seeing the six lines is the point of this lesson.)
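
To make the "scale, not mean" distinction concrete, here is a minimal plain-Python sketch (no PyTorch, so it runs anywhere) that applies both normalizations to a toy vector. The helper names `rms_norm`, `layer_norm`, and `rms` are illustrative and not part of src/model.py; gains and biases are left at identity.

```python
import math

def rms_norm(x, eps=1e-6):
    """Divide by the root-mean-square; the mean is left untouched."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale for v in x]

def layer_norm(x, eps=1e-6):
    """Subtract the mean, then divide by the standard deviation (no gain/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def rms(x):
    return math.sqrt(sum(v * v for v in x) / len(x))

x = [3.0, 5.0, 4.0, 8.0]                 # toy vector with a clearly non-zero mean
y_rms = rms_norm(x)
y_ln = layer_norm(x)

print("RMS of RMSNorm output  :", round(rms(y_rms), 4))              # scale pinned near 1
print("mean of RMSNorm output :", round(sum(y_rms) / len(x), 4))     # mean survives
print("mean of LayerNorm output:", round(sum(y_ln) / len(x), 4))     # mean removed
```

Both outputs have unit scale; only LayerNorm pays for the extra mean reduction.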

RoPE: position as rotation

Without positional information, a transformer’s attention is permutation-equivariant: shuffle the input tokens and the outputs are shuffled the same way, because the score between any two tokens ignores where they sit. GPT-2 fixed this with a learned table of 1024 position vectors added to the token embeddings. That works, but it has two flaws: position becomes an absolute property (“this is token #517”) when language mostly cares about relative offsets (“the adjective three tokens back”), and the table hard-caps the context at whatever length you trained.

Rotary Position Embedding (RoPE, Su et al. 2021) encodes position where it actually matters: inside the attention dot product. Split each 64-dim query/key head vector into 32 pairs, and rotate pair \(i\) (for \(i = 0, \dots, 31\)) of the vector at position \(m\) by the angle \(m \cdot \theta_i\), where \(\theta_i = 10000^{-2i/64}\). Low-index pairs spin fast (fine-grained, nearby positions), high-index pairs spin slowly (coarse, long-range). The payoff is in what happens to the attention score: the dot product of a query rotated by \(m\theta_i\) and a key rotated by \(n\theta_i\) depends only on the difference \(m-n\). Relative position, for free, with zero learned parameters. This is the mechanism in Llama, Mistral, Qwen, and most open-weight models released since 2022. You will prove the relative-only property yourself in this lesson’s task.

[Figure: the same (q₁, q₂) pair drawn at positions m = 0, m = 5, and m = 10, rotated by m·θᵢ. The angle is the position.]

Implementation: the angles depend only on position and pair index, never on the data, so we precompute cos/sin tables once — shape (1024, 32) — and register them as non-persistent buffers (they follow the model to the GPU but don’t bloat checkpoints, since they’re recomputable from the config).

def precompute_rope(head_dim: int, max_seq_len: int, theta: float):
    """cos/sin tables, shape (max_seq_len, head_dim//2), float32."""
    # theta_i = theta ** (-2i / head_dim), i = 0..head_dim/2-1
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)     # positions 0..T-1
    freqs = torch.outer(t, inv_freq)                       # (T, head_dim//2): m * theta_i
    return torch.cos(freqs), torch.sin(freqs)


def apply_rope(x, cos, sin):
    """Rotate q or k by position-dependent angles.  x: (B, n_head, T, head_dim)."""
    T = x.size(2)
    cos = cos[:T].view(1, 1, T, -1)     # (1, 1, T, 32) — broadcasts over B and heads
    sin = sin[:T].view(1, 1, T, -1)
    x1, x2 = x.float().chunk(2, dim=-1) # pair dim i with dim i + head_dim//2
    y1 = x1 * cos - x2 * sin            # standard 2-D rotation of each pair
    y2 = x1 * sin + x2 * cos
    return torch.cat([y1, y2], dim=-1).type_as(x)

Note we rotate in fp32 and cast back — same reasoning as RMSNorm: trigonometric mixing in bf16 accumulates rounding noise across 12 layers.
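
To get a feel for "fast" and "slow" pairs, the sketch below evaluates the same frequency formula in plain Python and converts each \(\theta_i\) into a wavelength, the number of tokens a pair needs for one full turn. The names `rope_inv_freq` and `wavelengths` are illustrative, not part of src/model.py.

```python
import math

def rope_inv_freq(head_dim=64, theta=10000.0):
    """theta_i = theta ** (-2i / head_dim) for i = 0 .. head_dim/2 - 1."""
    return [theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]

inv_freq = rope_inv_freq()
wavelengths = [2 * math.pi / f for f in inv_freq]    # tokens per full rotation

for i in (0, 8, 16, 24, 31):
    print(f"pair {i:2d}: theta_i = {inv_freq[i]:.6f} rad/token, "
          f"full turn every {wavelengths[i]:,.1f} tokens")
```

Pair 0 turns once every 2π ≈ 6.3 tokens, while the slowest pair needs tens of thousands of tokens per turn, so within a 1024-token context it never wraps around and can still tell far-apart positions apart.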

Attention and SwiGLU: the two workhorses

Attention. Each block’s attention lets every position gather information from every earlier position. We fuse the three q/k/v projections into one nn.Linear(768, 2304) — one matmul instead of three launches — reshape into 12 heads of 64 dims, rotate q and k with RoPE, and then hand everything to F.scaled_dot_product_attention. That one call is the single most important performance decision in the file: with is_causal=True and bf16 inputs on CUDA, PyTorch dispatches to a FlashAttention kernel that never materializes the T×T attention matrix. Memory drops from O(T²) to O(T), and this is worth roughly 2× throughput at our 1024-token context — a large slice of the 45–55k tok/s that keeps Lesson 7’s bill at $8–12 instead of $25. On CPU it silently falls back to the math kernel, which is why our smoke test still works.
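
The O(T²) figure is easy to put numbers on. This back-of-the-envelope sketch computes how many bytes a fully materialized (B, n_head, T, T) score matrix would take in bf16 (2 bytes per element) for one layer. The batch size of 32 is a hypothetical value chosen for illustration, not the course's training setting.

```python
def attn_matrix_bytes(batch, n_head=12, seq_len=1024, bytes_per_el=2):
    """Bytes to store the full (B, n_head, T, T) score matrix of ONE layer."""
    return batch * n_head * seq_len * seq_len * bytes_per_el

per_seq = attn_matrix_bytes(batch=1)
print(f"one sequence, one layer : {per_seq / 2**20:.0f} MiB")

batch = 32                                   # hypothetical micro-batch size
print(f"batch of {batch}, one layer  : {attn_matrix_bytes(batch) / 2**30:.2f} GiB")
```

A hand-written masked softmax would also keep these matrices around for the backward pass in every layer, which is the memory the flash kernel avoids by never writing them out.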

What breaks without is_causal=True? Position \(t\) could attend to position \(t+1\) — whose embedding is the token it’s being asked to predict. Training loss would plummet toward zero within minutes, you’d think you invented a miracle, and the model would emit garbage at generation time when the future genuinely isn’t there. If you ever see suspiciously fast loss curves, check the mask first.
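
The leak is easy to see in miniature. The following is a deliberately naive single-head attention in plain Python (a toy stand-in for illustration, not the SDPA kernel): it perturbs only the last token's key and value, then reports which output positions moved. All function names here are hypothetical.

```python
import math
import random

def attention(q, k, v, causal):
    """Single-head softmax attention on lists of vectors; one output row per position."""
    T, d = len(q), len(q[0])
    out = []
    for t in range(T):
        visible = range(t + 1) if causal else range(T)    # causal: keys 0..t only
        scores = [sum(a * b for a, b in zip(q[t], k[s])) / math.sqrt(d) for s in visible]
        top = max(scores)                                  # subtract max for stability
        weights = [math.exp(s - top) for s in scores]
        z = sum(weights)
        out.append([sum(w / z * v[s][j] for w, s in zip(weights, visible))
                    for j in range(d)])
    return out

def random_rows(rng, T, d):
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(T)]

def changed_positions(causal, T=5, d=4, seed=0):
    """Perturb only the LAST token's key/value; return the positions whose output moved."""
    rng = random.Random(seed)
    q, k, v = (random_rows(rng, T, d) for _ in range(3))
    k2 = k[:-1] + [[x + 1.0 for x in k[-1]]]
    v2 = v[:-1] + [[x + 1.0 for x in v[-1]]]
    a, b = attention(q, k, v, causal), attention(q, k2, v2, causal)
    return [t for t in range(T) if any(abs(x - y) > 1e-9 for x, y in zip(a[t], b[t]))]

print("causal    :", changed_positions(causal=True))
print("non-causal:", changed_positions(causal=False))
```

With the mask, only the last position reacts to a change in the last token; without it, every earlier position does, which is exactly the information a next-token predictor must not see.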

class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head        # 64
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.dropout = config.dropout

    def forward(self, x, cos, sin):
        B, T, C = x.shape                                     # (B, T, 768)
        q, k, v = self.qkv(x).split(C, dim=2)                 # 3 × (B, T, 768)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, 12, T, 64)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, 12, T, 64)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, 12, T, 64)
        q = apply_rope(q, cos, sin)   # position enters HERE, not in the embeddings
        k = apply_rope(k, cos, sin)   # (v is never rotated — only scores need position)
        y = F.scaled_dot_product_attention(                   # flash path on CUDA+bf16
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,           # kernel-internal causal mask, no T×T tensor
        )                                                     # (B, 12, T, 64)
        y = y.transpose(1, 2).contiguous().view(B, T, C)      # re-assemble heads: (B, T, 768)
        return self.proj(y)                                   # (B, T, 768)

SwiGLU. The feed-forward network is where most of the model’s knowledge lives — it’s the per-token “lookup and transform” stage. GPT-2 used Linear(768→3072) → GELU → Linear(3072→768). Shazeer’s 2020 paper “GLU Variants Improve Transformer” showed that a gated variant consistently wins at equal parameter count: compute two parallel projections, pass one through SiLU (the smooth sigmoid-weighted activation), and multiply them elementwise — one path decides what to say, the other how much of it gets through. Because SwiGLU needs three matrices instead of two, the hidden width shrinks to 8/3 of the embedding dim to keep parameters equal: \(8/3 \times 768 = 2048\), and conveniently \(3 \times 768 \times 2048 = 2 \times 768 \times 3072 = 4{,}718{,}592\) — exactly GPT-2’s FFN budget, spent better. Llama, PaLM, and Mistral all made the same trade.

class SwiGLU(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden = int(8 * config.n_embd / 3)         # 2048 — same param count as a 4x GELU MLP
        hidden = 256 * ((hidden + 255) // 256)      # round UP to a multiple of 256: tensor cores
        self.w_gate = nn.Linear(config.n_embd, hidden, bias=False)
        self.w_up   = nn.Linear(config.n_embd, hidden, bias=False)
        self.w_down = nn.Linear(hidden, config.n_embd, bias=False)

    def forward(self, x):                            # (B, T, 768)
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))   # 768 → 2048 → 768

The rounding line is a no-op for us (2048 is already a multiple of 256) but it’s load-bearing at other widths: GPU tensor cores want matmul dimensions divisible by large powers of two, and an odd hidden size like 1365 (what 8/3 × 512 would give) can cost double-digit percent throughput.
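
To see the rounding rule at widths other than ours, this small sketch reproduces the two lines from SwiGLU.__init__ as a standalone function and compares the resulting parameter count with a classic 4× MLP (biases ignored). The function name `swiglu_hidden` and the `multiple_of` parameter are illustrative; the widths other than 768 are just examples.

```python
def swiglu_hidden(n_embd, multiple_of=256):
    """8/3 * n_embd, truncated, then rounded UP to a multiple of `multiple_of`."""
    hidden = int(8 * n_embd / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

for n_embd in (512, 768, 1024, 4096):
    raw = int(8 * n_embd / 3)
    hidden = swiglu_hidden(n_embd)
    swiglu_params = 3 * n_embd * hidden          # gate + up + down
    gelu_params = 2 * n_embd * (4 * n_embd)      # classic 4x two-matrix MLP
    print(f"n_embd={n_embd:5d}  raw={raw:6d}  rounded={hidden:6d}  "
          f"SwiGLU/GELU params = {swiglu_params / gelu_params:.3f}")
```

At 768 the ratio is exactly 1; at other widths, rounding up costs a few percent of extra parameters in exchange for hardware-friendly shapes.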

No biases, anywhere. Every nn.Linear in this file has bias=False. Ablations across the Llama-scale literature show bias terms contribute nothing measurable to quality in pre-normed transformers — the norms’ gains absorb their job — while adding parameters, memory traffic, and kernel epilogue work. Deleting them is free performance.

The block: pre-norm residuals

Now compose. Each block is two residual updates: attention (tokens talk to each other) then the FFN (each token thinks alone). The norm goes before each sub-layer — “pre-norm” — so the residual stream itself is never normalized, only the input to each computation. This is what makes deep stacks trainable without warmup gymnastics: the identity path from embedding to logits is completely clean, and each block just adds a bounded correction onto it. GPT-2 already did this (its one genuinely modern trait), and nobody has gone back.

flowchart TB
    x["input x<br/>B × T × 768"] --> n1["RMSNorm"]
    n1 --> att["causal self-attention<br/>RoPE on q,k · SDPA flash · 12 heads × 64"]
    att --> a1(("+"))
    x --> a1
    a1 --> n2["RMSNorm"]
    n2 --> ffn["SwiGLU FFN<br/>768 → 2048 → 768"]
    ffn --> a2(("+"))
    a1 --> a2
    a2 --> out["output<br/>B × T × 768"]

class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn  = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.n_embd)
        self.mlp   = SwiGLU(config)

    def forward(self, x, cos, sin):
        x = x + self.attn(self.norm1(x), cos, sin)   # pre-norm residual: communicate
        x = x + self.mlp(self.norm2(x))              # pre-norm residual: compute
        return x

Weight tying, initialization, and the parameter count

Three decisions live at the top level of the model.

Weight tying. The embedding maps token → vector; the LM head maps vector → token scores. These are near-transposes of the same semantic relationship, so we make them literally the same 32768×768 matrix: self.lm_head.weight = self.wte.weight. This saves 25.2M parameters, about 19% of the 135.3M an untied model would weigh (or 23% measured against our tied 110.1M), and at our scale it is generally reported to help quality, because the embedding gets gradient signal from every output position, not just from the tokens that appear in the input. GPT-2 tied; Gemma ties; large models (70B+) often untie because the savings become negligible. At 124M-class scale, tying is the clear win.

Initialization. All weights start from \(\mathcal{N}(0, 0.02^2)\) — the GPT-2 value, well-matched to \(d=768\) (it’s close to the Xavier-ish \(1/\sqrt{768} \approx 0.036\), slightly conservative). Then one crucial correction: the two matrices in each block that write into the residual stream — the attention output projection and the FFN down projection — get scaled down by \(1/\sqrt{2 \cdot n_{layer}}\). Why: the stream receives \(2 \times 12 = 24\) additive writes, and independent additions grow the stream’s variance linearly with depth. Dividing each writer’s init by \(\sqrt{24}\) keeps the stream’s magnitude roughly constant from block 1 to block 12. Skip this and you’ll see it immediately in Lesson 7’s W&B dashboard: activation norms climbing layer by layer, grad-norm spikes in the first few hundred steps, and with bad luck a diverged run — $10 of GPU time gone. It’s two lines of code.
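
The variance argument can be checked with a toy Monte Carlo. The sketch below models one coordinate of the residual stream as a sum of independent Gaussian writes. That is a simplifying assumption (real block outputs are neither independent nor exactly Gaussian), and `stream_std` is an illustrative helper, not part of src/model.py.

```python
import math
import random

def stream_std(n_writes, write_std, trials=20000, seed=0):
    """Empirical std of one residual coordinate after n_writes independent additive writes."""
    rng = random.Random(seed)
    finals = [sum(rng.gauss(0.0, write_std) for _ in range(n_writes))
              for _ in range(trials)]
    mean = sum(finals) / trials
    return math.sqrt(sum((f - mean) ** 2 for f in finals) / trials)

n_layer = 12
n_writes = 2 * n_layer                        # one attention + one FFN write per block

unscaled = stream_std(n_writes, write_std=1.0)
scaled = stream_std(n_writes, write_std=1.0 / math.sqrt(n_writes))

print(f"unscaled writes: stream std ~ {unscaled:.2f}  (sqrt(24) = {math.sqrt(24):.2f})")
print(f"scaled writes  : stream std ~ {scaled:.2f}")
```

Under this toy model the unscaled stream ends up about √24 ≈ 4.9 times larger than a single write, while dividing each write by √24 keeps it near 1.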

The count. Here is where every parameter lives:

| Component | Shape | Count |
|---|---|---|
| Token embedding (tied with LM head) | 32768 × 768 | 25,165,824 |
| Attention per layer: fused qkv + out proj | 768×2304 + 768×768 | 2,359,296 |
| SwiGLU per layer: gate + up + down | 2×(768×2048) + 2048×768 | 4,718,592 |
| RMSNorm gains per layer | 2 × 768 | 1,536 |
| One block | | 7,079,424 |
| 12 blocks | | 84,953,088 |
| Final RMSNorm | 768 | 768 |
| Total (tied head counted once) | | 110,119,680 |

The 85.0M transformer stack matches GPT-2-124M’s to within 0.1M (the gap is GPT-2’s biases, which we dropped); the leaner vocab and the deleted position table account for the rest. The print_param_table() method in the file below prints the same numbers grouped by component (attention, FFN, and norm gains each summed over all layers), and you should recognize every row in it.
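
The arithmetic behind the count, and the comparison with GPT-2, can be reproduced without instantiating any model. This is a plain-Python sketch; the function names are illustrative, and the GPT-2 layout used (biases on every Linear, gain plus bias on every LayerNorm, a learned position table) is the standard GPT-2-small structure.

```python
def wikigpt_params(vocab=32768, n_embd=768, n_layer=12, hidden=2048):
    embedding = vocab * n_embd                        # tied with the LM head, counted once
    attn = n_embd * 3 * n_embd + n_embd * n_embd      # fused qkv + output projection
    ffn = 3 * n_embd * hidden                         # gate + up + down
    norms = 2 * n_embd                                # two RMSNorm gains per block
    return embedding + n_layer * (attn + ffn + norms) + n_embd   # + final RMSNorm gain

def gpt2_small_params(vocab=50257, n_embd=768, n_layer=12, block_size=1024):
    wte, wpe = vocab * n_embd, block_size * n_embd    # token + learned position tables
    attn = (n_embd * 3 * n_embd + 3 * n_embd) + (n_embd * n_embd + n_embd)    # weights + biases
    mlp = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
    norms = 2 * (2 * n_embd)                          # two LayerNorms, gain + bias each
    return wte + wpe + n_layer * (attn + mlp + norms) + 2 * n_embd            # + final LayerNorm

ours = wikigpt_params()
gpt2 = gpt2_small_params()
ours_with_gpt2_vocab = wikigpt_params(vocab=50257)

print(f"WikiGPT-124M            : {ours:>12,d}")
print(f"GPT-2 small             : {gpt2:>12,d}")
print(f"WikiGPT with 50,257 vocab: {ours_with_gpt2_vocab:>12,d} "
      f"({100 * (gpt2 - ours_with_gpt2_vocab) / gpt2:.2f}% below GPT-2)")
```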

The complete src/model.py

Everything above, assembled into the file that Lessons 6 through 11 will import unchanged. Read the shape comments in forward — being able to trace (B, T) → (B, T, 768) → (B, T, 32768) in your head is the real deliverable of this lesson.

"""WikiGPT-124M -- a modern decoder-only transformer, from scratch.

Llama-family choices on a GPT-2-small skeleton:
RMSNorm (pre-norm) - RoPE - SwiGLU - flash SDPA - no biases - tied embeddings.
CPU-friendly for testing; bf16 + torch.compile on the GPU from Lesson 6.
"""

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------------------------------------------------------- config ---

@dataclass
class GPTConfig:
    vocab_size: int = 32768   # our BPE from Lesson 4, chat tokens already reserved
    block_size: int = 1024    # maximum context length
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768         # head_dim = 768 / 12 = 64
    rope_theta: float = 10000.0
    dropout: float = 0.0      # one pass over fresh tokens: nothing to overfit to


# ---------------------------------------------------------------- rmsnorm ---

class RMSNorm(nn.Module):
    """y = x / rms(x) * gain. No mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):                              # (B, T, C)
        norm = x.float() * torch.rsqrt(
            x.float().pow(2).mean(-1, keepdim=True) + self.eps
        )                                              # reduce in fp32 for bf16 safety
        return norm.type_as(x) * self.weight


# ------------------------------------------------------------------- rope ---

def precompute_rope(head_dim: int, max_seq_len: int, theta: float):
    """cos/sin tables, shape (max_seq_len, head_dim//2), float32."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)                   # (T, head_dim//2): m * theta_i
    return torch.cos(freqs), torch.sin(freqs)


def apply_rope(x, cos, sin):
    """Rotate q or k by position-dependent angles.  x: (B, n_head, T, head_dim)."""
    T = x.size(2)
    cos = cos[:T].view(1, 1, T, -1)                    # broadcast over batch and heads
    sin = sin[:T].view(1, 1, T, -1)
    x1, x2 = x.float().chunk(2, dim=-1)                # pair dim i with dim i + hd//2
    y1 = x1 * cos - x2 * sin                           # 2-D rotation of each pair
    y2 = x1 * sin + x2 * cos
    return torch.cat([y1, y2], dim=-1).type_as(x)


# -------------------------------------------------------------- attention ---

class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.dropout = config.dropout

    def forward(self, x, cos, sin):
        B, T, C = x.shape                                              # (B, T, 768)
        q, k, v = self.qkv(x).split(C, dim=2)                          # 3 x (B, T, 768)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, 12, T, 64)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, 12, T, 64)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, 12, T, 64)
        q = apply_rope(q, cos, sin)                    # position enters here, not in wte
        k = apply_rope(k, cos, sin)
        y = F.scaled_dot_product_attention(            # flash kernel on CUDA + bf16
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,                            # no T x T mask ever materialized
        )                                              # (B, 12, T, 64)
        y = y.transpose(1, 2).contiguous().view(B, T, C)               # (B, T, 768)
        return self.proj(y)


# ------------------------------------------------------------------ swiglu ---

class SwiGLU(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden = int(8 * config.n_embd / 3)            # 2048: same params as 4x GELU MLP
        hidden = 256 * ((hidden + 255) // 256)         # tensor-core friendly width
        self.w_gate = nn.Linear(config.n_embd, hidden, bias=False)
        self.w_up   = nn.Linear(config.n_embd, hidden, bias=False)
        self.w_down = nn.Linear(hidden, config.n_embd, bias=False)

    def forward(self, x):                              # (B, T, 768)
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))     # 768->2048->768


# ------------------------------------------------------------------- block ---

class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn  = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.n_embd)
        self.mlp   = SwiGLU(config)

    def forward(self, x, cos, sin):
        x = x + self.attn(self.norm1(x), cos, sin)     # pre-norm residual
        x = x + self.mlp(self.norm2(x))
        return x


# --------------------------------------------------------------------- gpt ---

class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.blocks = nn.ModuleList(Block(config) for _ in range(config.n_layer))
        self.norm_f = RMSNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight          # weight tying: one matrix, two jobs

        cos, sin = precompute_rope(
            config.n_embd // config.n_head, config.block_size, config.rope_theta
        )
        self.register_buffer("rope_cos", cos, persistent=False)  # follows .to(device),
        self.register_buffer("rope_sin", sin, persistent=False)  # skipped in checkpoints

        self.apply(self._init_weights)
        # scaled init for the two matrices per block that WRITE the residual stream:
        # 2 * n_layer additive writes -> shrink each by sqrt(2 * n_layer)
        for name, p in self.named_parameters():
            if name.endswith("attn.proj.weight") or name.endswith("mlp.w_down.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= self.config.block_size, f"sequence length {T} > block_size"
        x = self.wte(idx)                              # (B, T) -> (B, T, 768)
        for block in self.blocks:
            x = block(x, self.rope_cos, self.rope_sin) # (B, T, 768), twelve times
        x = self.norm_f(x)                             # (B, T, 768)
        if targets is not None:
            logits = self.lm_head(x)                   # (B, T, 32768)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),      # (B*T, 32768)
                targets.view(-1),                      # (B*T,)
                ignore_index=-1,                       # Lesson 9 masks prompts with -1
            )
            return logits, loss
        logits = self.lm_head(x[:, [-1], :])           # inference: last position only
        return logits, None

    # ------------------------------------------------------------- utils ---

    def num_params(self) -> int:
        return sum(p.numel() for p in self.parameters())  # tied head counted once

    def print_param_table(self):
        c = self.config
        groups = {
            "token embedding (tied w/ lm_head)":
                self.wte.weight.numel(),
            f"attention ({c.n_layer} layers)":
                sum(p.numel() for n, p in self.named_parameters() if ".attn." in n),
            f"SwiGLU FFN ({c.n_layer} layers)":
                sum(p.numel() for n, p in self.named_parameters() if ".mlp." in n),
            "RMSNorm gains":
                sum(p.numel() for n, p in self.named_parameters() if "norm" in n),
        }
        for name, n in groups.items():
            print(f"{name:<38s} {n:>12,d}")
        print(f"{'TOTAL':<38s} {self.num_params():>12,d}")

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, top_p=None):
        """Autoregressive sampling. temperature=0.0 -> greedy decoding.
        Note: recomputes the full prefix each step -- fine at 124M/1k ctx;
        add a KV cache if serve.py (Lesson 11) ever feels slow.
        """
        self.eval()
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size:]        # crop to context window
            logits, _ = self(idx_cond)                          # (B, 1, 32768)
            logits = logits[:, -1, :]                           # (B, 32768)
            if temperature == 0.0:
                next_tok = logits.argmax(dim=-1, keepdim=True)  # greedy
            else:
                logits = logits / temperature                   # flatten or sharpen
                if top_k is not None:                           # keep k best
                    kth = torch.topk(logits, top_k, dim=-1).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth, float("-inf"))
                if top_p is not None:                           # nucleus sampling
                    s_logits, s_idx = torch.sort(logits, descending=True, dim=-1)
                    probs = F.softmax(s_logits, dim=-1)
                    cum = probs.cumsum(dim=-1)
                    mask = cum - probs > top_p                  # keep first past threshold
                    s_logits = s_logits.masked_fill(mask, float("-inf"))
                    logits = torch.full_like(logits, float("-inf")) \
                                  .scatter_(-1, s_idx, s_logits)
                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_tok], dim=1)
        return idx


# -------------------------------------------------------------- smoke test ---

if __name__ == "__main__":
    torch.manual_seed(1337)
    cfg = GPTConfig()
    model = GPT(cfg)
    model.print_param_table()

    B, T = 2, 128
    idx = torch.randint(0, cfg.vocab_size, (B, T))
    targets = torch.randint(0, cfg.vocab_size, (B, T))
    logits, loss = model(idx, targets)
    assert logits.shape == (B, T, cfg.vocab_size)

    expected = math.log(cfg.vocab_size)                # ln(32768) = 10.397
    print(f"initial loss {loss.item():.3f}   ln(vocab_size) {expected:.3f}")
    assert abs(loss.item() - expected) < 0.5, "broken init: loss should start ~uniform"

    out = model.generate(idx[:, :4], max_new_tokens=8, temperature=0.8, top_k=50)
    assert out.shape == (B, 12)
    out = model.generate(idx[:, :4], max_new_tokens=8, temperature=0.0)   # greedy path
    out = model.generate(idx[:, :4], max_new_tokens=8, top_p=0.9)         # nucleus path
    assert out.shape == (B, 12)
    print("smoke test passed")

Two forward-looking details you just planted: ignore_index=-1 in the cross-entropy is inert now, but in Lesson 9 it becomes the mechanism for masking prompt tokens so SFT loss lands only on assistant tokens. And the x[:, [-1], :] slice in the inference branch applies the LM head to the last position only, which skips computing 32,768 scores for up to 1,023 positions you’d throw away: a real speedup inside generate()’s loop.
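
As a preview of how that masking behaves, here is a minimal plain-Python sketch of a mean cross-entropy that skips positions whose target equals ignore_index, mirroring the averaging F.cross_entropy does over non-ignored targets. The function, the toy logits, and the 3-token vocabulary are all made up for illustration.

```python
import math

def cross_entropy(logits, targets, ignore_index=-1):
    """Mean negative log-likelihood over positions whose target != ignore_index."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue                                   # masked position contributes nothing
        top = max(row)
        log_z = top + math.log(sum(math.exp(v - top) for v in row))
        total += log_z - row[target]
        count += 1
    return total / count                               # assumes at least one unmasked target

# Four positions over a 3-token vocabulary; the first two stand in for prompt tokens.
logits = [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [2.0, 1.0, 0.0], [0.0, 0.0, 3.0]]

full = cross_entropy(logits, [2, 2, 0, 2])             # every position scored
masked = cross_entropy(logits, [-1, -1, 0, 2])         # prompt positions ignored

print(f"loss over all positions      : {full:.4f}")
print(f"loss over unmasked positions : {masked:.4f}")
```

The masked loss is the average over the last two positions only, so the badly predicted "prompt" positions no longer pull the number (or, in training, the gradient) around.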

Run the smoke test — cost: $0.00

No vast.ai instance this lesson. The whole point of the smoke test is to catch bugs on your laptop before any of them can cost GPU-hours:

cd wikillm
python src/model.py

Expected output (exact loss varies slightly with PyTorch version, but must sit near 10.4):

token embedding (tied w/ lm_head)        25,165,824
attention (12 layers)                    28,311,552
SwiGLU FFN (12 layers)                   56,623,104
RMSNorm gains                                19,200
TOTAL                                   110,119,680
initial loss 10.482   ln(vocab_size) 10.397
smoke test passed

The loss check is the deepest single line of validation in the file. An untrained model with a sane init knows nothing, so its predictive distribution should be near-uniform over 32,768 tokens, and the cross-entropy of a uniform prediction is exactly \(\ln(32768) = 15\ln 2 \approx 10.397\). It lands slightly above that because the random logits aren’t exactly zero (for roughly Gaussian logits, their variance adds about \(\sigma^2/2\) to the loss). What the loss check catches: an init std that’s 10× too big (loss starts at 13+ and the run diverges) or a shape bug silently broadcasting. What it cannot catch on its own: broken weight tying, which barely moves the loss but shows up as a TOTAL of 135,285,504 in the parameter table, and a missing residual scaling, where the final RMSNorm keeps the loss looking fine while activations balloon by layer 12. When training starts in Lesson 7, watching the loss fall from 10.4 is watching the model climb down from perfect ignorance: every 0.69 drop halves the effective number of tokens it’s choosing between.
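
The \(\sigma^2/2\) rule of thumb can be checked numerically without a model. The sketch below draws i.i.d. Gaussian logits (an idealization of what an untrained network emits) and measures how far the log-partition sits above \(\ln V\). The logit std of 0.5 is a hypothetical value for illustration, not a measurement of WikiGPT, and `mean_log_partition` is an illustrative helper.

```python
import math
import random

V = 32768
baseline = math.log(V)                         # cross-entropy of a uniform prediction

def mean_log_partition(sigma, vocab=V, trials=5, seed=0):
    """Average log-sum-exp of `vocab` i.i.d. N(0, sigma^2) logits."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        logits = [rng.gauss(0.0, sigma) for _ in range(vocab)]
        top = max(logits)
        total += top + math.log(sum(math.exp(z - top) for z in logits))
    return total / trials

sigma = 0.5                                    # hypothetical logit std
# The target logit has mean zero, so the expected loss is the mean log-partition.
excess = mean_log_partition(sigma) - baseline

print(f"ln(V) = {baseline:.3f}")
print(f"excess loss at sigma={sigma}: {excess:.3f}   sigma^2/2 = {sigma ** 2 / 2:.3f}")
print(f"perplexity at ln(V): {math.exp(baseline):,.0f}; "
      f"after a drop of ln 2: {math.exp(baseline - math.log(2)):,.0f}")
```

The last line is the "every 0.69 halves the choices" statement in numbers: perplexity is exp(loss), so subtracting ln 2 ≈ 0.693 divides it by two.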

🧪 Your task

Prove RoPE’s central claim to yourself: the attention score between a query at position \(m\) and a key at position \(n\) depends only on the offset \(m - n\), not on the absolute positions. Write scripts/check_rope.py that takes one random query vector and one random key vector (64 dims, no batch), places them at several different absolute-position pairs sharing the same offset (e.g. offset 7 at positions (7,0), (57,50), (1000,993)), rotates each with the tables from precompute_rope, and asserts the dot products are identical to within 1e-4. Then break it on purpose: show the score does change when the offset changes.

Solution
# scripts/check_rope.py
import sys

import torch

sys.path.append("src")   # assumes you run this from the repo root (wikillm/)
from model import precompute_rope

torch.manual_seed(0)
head_dim = 64
cos, sin = precompute_rope(head_dim, 1024, 10000.0)   # (1024, 32) each

def rope_at(x, pos):
    """Rotate a single (head_dim,) vector as if it sat at absolute position `pos`."""
    x1, x2 = x.chunk(2)                                # (32,), (32,)
    c, s = cos[pos], sin[pos]
    return torch.cat([x1 * c - x2 * s, x1 * s + x2 * c])

q = torch.randn(head_dim)
k = torch.randn(head_dim)

# same offset, wildly different absolute positions -> identical scores
offset = 7
scores = [rope_at(q, m) @ rope_at(k, m - offset) for m in (7, 57, 300, 1000)]
spread = max(scores) - min(scores)
print(f"offset {offset}: scores {[f'{s:.6f}' for s in scores]}  spread {spread:.2e}")
assert spread < 1e-4, "RoPE should depend only on relative offset"

# different offset -> different score (sanity that we're not testing a constant)
other = rope_at(q, 20) @ rope_at(k, 12)                # offset 8
assert abs(other - scores[0]) > 1e-3
print("RoPE is purely relative -- check passed")

Why it works: rotating \(q\) by angle \(m\theta_i\) and \(k\) by \(n\theta_i\) means their per-pair dot product contains \(\cos(m\theta_i - n\theta_i)\) terms — rotation matrices compose by angle subtraction, so only \(m - n\) survives. Absolute position cancels out pair by pair, in all 32 pairs at once.
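
For a single pair \(i\), that argument can be written out explicitly. This is a sketch using notation introduced here: \((q_1, q_2)\) and \((k_1, k_2)\) are the two components of pair \(i\) in the query and key, and \(R(\phi)\) is the 2×2 rotation by angle \(\phi\), matching the rotation applied in rope_at.

\[
\big(R(m\theta_i)\,q\big)^{\top}\big(R(n\theta_i)\,k\big)
= q^{\top} R\big((n-m)\theta_i\big)\,k
= (q_1 k_1 + q_2 k_2)\cos\big((m-n)\theta_i\big) + (q_1 k_2 - q_2 k_1)\sin\big((m-n)\theta_i\big)
\]

The first equality uses \(R(\alpha)^{\top} R(\beta) = R(\beta - \alpha)\). Summing the right-hand side over all 32 pairs gives the full attention score, in which \(m\) and \(n\) appear only through \(m - n\).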

Key takeaways

  • WikiGPT-124M is the Llama recipe at GPT-2-small scale: RMSNorm, RoPE, SwiGLU, pre-norm, no biases, tied embeddings — every choice matches what frontier-lab models converged on, just twelve layers of it.
  • RMSNorm drops LayerNorm’s mean-centering and bias because only scale control stabilizes deep residual stacks; compute the reduction in fp32 even when training in bf16.
  • RoPE encodes position as a rotation of q/k pairs inside the attention dot product, making scores depend only on relative offset — zero learned parameters, precomputed cos/sin buffers, and you proved the relative property yourself.
  • F.scaled_dot_product_attention(is_causal=True) lets PyTorch dispatch to a fused FlashAttention kernel where one is available (a supported GPU and dtype): O(T) memory, roughly 2× throughput, and no hand-rolled T×T mask to get wrong.
  • SwiGLU at 8/3× hidden width spends exactly the same 4.72M parameters per layer as GPT-2’s 4× GELU MLP, and spends them better; the gate/up/down structure is standard across Llama, PaLM, and Mistral.
  • Init is a normal distribution with std 0.02 for every weight matrix, plus \(1/\sqrt{2 \cdot n_{layer}}\) scaling on the two residual-writing projections per block: two lines that prevent activation blow-up across 24 residual writes.
  • The tied embedding counts once: 110,119,680 parameters exactly, with an 85.0M transformer stack essentially the same size as GPT-2-124M’s. The gap to 124M comes almost entirely from our leaner 32k vocab, plus the learned position table that RoPE makes unnecessary.
  • An untrained model must score \(\approx \ln(V) = 10.4\); that one assert catches init, tying, and shape bugs on your laptop for $0.00 before the GPU meter starts.
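
The init takeaway above can be illustrated with a toy simulation. This is a sketch, not the real network: it treats the residual stream as a single scalar that starts at unit variance and receives 24 independent Gaussian writes, and compares full-strength writes against writes scaled by \(1/\sqrt{2 \cdot n_{layer}}\). The function name and sample count are hypothetical.

```python
import math
import random

n_layer = 12
n_writes = 2 * n_layer           # attention and FFN each write to the stream once per block
n_samples = 5_000
rng = random.Random(0)


def stream_variance(write_std):
    """Variance of a toy scalar stream: x0 ~ N(0, 1) plus n_writes independent writes."""
    finals = []
    for _ in range(n_samples):
        x = rng.gauss(0.0, 1.0)
        for _ in range(n_writes):
            x += rng.gauss(0.0, write_std)
        finals.append(x)
    mean = sum(finals) / n_samples
    return sum((v - mean) ** 2 for v in finals) / n_samples


unscaled = stream_variance(1.0)                        # every write at full strength
scaled = stream_variance(1.0 / math.sqrt(n_writes))    # the 1/sqrt(2 * n_layer) factor

print(f"variance after {n_writes} unscaled writes: {unscaled:.1f}")   # theory: 1 + 24
print(f"variance after {n_writes} scaled writes:   {scaled:.1f}")     # theory: 1 + 1
```

Independent variances add, so unscaled writes grow the stream's variance linearly with depth, while the scaled writes together contribute about as much as one unscaled write.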

Coming up

In Lesson 6 we wrap this model in src/train.py — bf16 autocast, torch.compile, cosine LR schedule with warmup, gradient accumulation, W&B logging, and crash-proof resumable checkpoints — the engine that has to survive 20+ unattended hours on a rented 4090.


© Kader Mohideen