flowchart TB
tok["token ids<br/>B × T"] --> emb["token embedding<br/>32768 × 768"]
emb --> b1["Block 1"]
b1 --> b2["Block 2"]
b2 --> dots["⋮<br/>12 blocks total"]
dots --> b12["Block 12"]
b12 --> fn["final RMSNorm"]
fn --> head["LM head 768 → 32768<br/>same matrix as the embedding"]
head --> logits["logits<br/>B × T × 32768"]
rope["RoPE cos/sin buffers<br/>1024 × 32, precomputed once"] -.->|"rotates q,k in every block"| b1
rope -.-> b12
📖 Build Your Own Wikipedia LLM · Lesson 5 — WikiGPT-124M: The Model, Line by Line
Lesson 5 — WikiGPT-124M: The Model, Line by Line
In Lesson 4 you trained a 32,768-symbol BPE tokenizer on your cleaned Wikipedia corpus and reserved the <|user|>, <|assistant|>, and <|end|> chat tokens that will matter much later, in Lessons 9 and 10. You now have data/tokens/ full of uint16 token ids and a tokenizer/tokenizer.json that maps them back to text. What you don’t have yet is the thing that will consume those tokens: the network itself.
This lesson is the heart of the course. We build src/model.py — the complete WikiGPT-124M — in pure PyTorch, with no external model libraries, and we justify every single design decision: why RMSNorm and not LayerNorm, why RoPE and not learned position embeddings, why SwiGLU and not the classic 4× GELU MLP, why no biases anywhere, why the embedding matrix does double duty as the output head. These are not arbitrary choices; they are, almost exactly, the Llama-family recipe scaled down to a GPT-2-small skeleton. Best of all: everything in this lesson runs on your laptop CPU. Your vast.ai wallet stays closed until Lesson 6.
🎯 In this lesson you will: write the complete src/model.py — RMSNorm, RoPE, flash attention via SDPA, SwiGLU, pre-norm blocks, weight tying, the GPT-2 init scheme, a parameter-count breakdown, and a generate() you’ll reuse for the rest of the course — then smoke-test it on CPU by checking that the untrained loss lands at ln(32768) ≈ 10.4.
The blueprint: one config, one diagram
Before any code, hold the whole machine in your head. WikiGPT-124M is a decoder-only transformer: token ids go in, one embedding lookup turns them into vectors, twelve identical blocks refine those vectors, a final norm and a linear head turn them back into a probability distribution over the next token. That’s the entire architecture — everything else is detail inside the blocks.
The five numbers below are fixed for the whole course — the tokenizer (Lesson 4), the data packing (Lesson 4), the training config (Lesson 6), and every checkpoint you’ll ever load all assume them:
@dataclass
class GPTConfig:
    vocab_size: int = 32768      # our BPE from Lesson 4, chat tokens already reserved
    block_size: int = 1024       # maximum context length
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768            # so head_dim = 768 / 12 = 64
    rope_theta: float = 10000.0
    dropout: float = 0.0
Two things are worth pausing on. First, dropout: 0.0. Dropout is a regularizer against overfitting, and in Lesson 7 we train for roughly one pass over ~4B fresh tokens. The model never sees the same batch twice; there is nothing to overfit to, and dropout would only slow learning. Most modern pretraining runs (Llama, Qwen, Mistral) do the same. Second, rope_theta: 10000.0. This is the classic base frequency from the original RoPE paper and Llama 1/2. Models targeting 100k+ contexts raise it (Llama 3 uses 500,000); at our 1024-token context, 10,000 is a well-tested default.
And here is how our choices line up against the model this one is named after:
| Design choice | GPT-2-124M (2019) | WikiGPT-124M | Who else does it our way |
|---|---|---|---|
| Positional info | learned absolute embeddings (wpe) | RoPE applied to q/k | Llama 1–3, Mistral, Qwen |
| Normalization | LayerNorm (mean + bias) | RMSNorm | Llama-family, T5 |
| Norm placement | pre-norm | pre-norm | both |
| FFN | 4× GELU MLP | SwiGLU at 8/3× | Llama, PaLM, Mistral |
| Biases | in every Linear + LayerNorm | none anywhere | Llama-family |
| Attention impl | hand-written masked softmax | F.scaled_dot_product_attention (flash) | everything modern |
| Vocab | 50,257 (GPT-2 BPE) | 32,768 (your BPE) | — |
| Weight tying | yes | yes | GPT-2, Gemma |
| Exact params | 124.4M | 110.1M | — |
About that last row: the weight matrices of the twelve transformer blocks are exactly the same size as GPT-2’s, 85.0M parameters (SwiGLU at 8/3× has exactly the same parameter count as a 4× two-matrix MLP, as we’ll verify below). GPT-2 additionally carries about 0.1M of Linear biases and LayerNorm shifts that we drop. Nearly all of the remaining difference is the embedding: our lean 32k vocabulary costs 25.2M where GPT-2’s 50k costs 38.6M, and RoPE deletes the 0.8M position-embedding table outright. Had we kept GPT-2’s vocab, we’d land at about 123.6M, within one percent of 124.4M. The name marks the class (this is a GPT-2-small-scale model), and the smaller exact count is pure savings: fewer embedding parameters, same transformer.
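The arithmetic behind that comparison can be checked without PyTorch. The sketch below is plain Python; the WikiGPT shapes come from this lesson, while the GPT-2 layout (biases on every Linear, gain plus shift in every LayerNorm, a 1024 × 768 position table) is the standard GPT-2-small layout and is stated here as an assumption rather than taken from the lesson's code. The function names are illustrative.

```python
D, N_LAYER, BLOCK_SIZE = 768, 12, 1024

def wikigpt_params(vocab: int) -> int:
    attn = D * 3 * D + D * D              # fused qkv + output projection, no biases
    ffn = 3 * D * 2048                    # SwiGLU: gate, up, down
    norms = 2 * D                         # two RMSNorm gains per block
    return vocab * D + N_LAYER * (attn + ffn + norms) + D   # + final RMSNorm gain

def gpt2_params(vocab: int = 50257) -> int:
    # Assumed GPT-2-small layout: weights plus biases everywhere.
    attn = (D * 3 * D + 3 * D) + (D * D + D)
    ffn = (D * 4 * D + 4 * D) + (4 * D * D + D)     # 4x GELU MLP
    norms = 2 * 2 * D                               # two LayerNorms: gain + shift
    final_norm = 2 * D
    return vocab * D + BLOCK_SIZE * D + N_LAYER * (attn + ffn + norms) + final_norm

ours = wikigpt_params(32768)
ours_big_vocab = wikigpt_params(50257)
gpt2 = gpt2_params()
print(f"WikiGPT, 32k vocab   : {ours:,}")
print(f"WikiGPT, GPT-2 vocab : {ours_big_vocab:,}")
print(f"GPT-2 small          : {gpt2:,}")
print(f"gap with equal vocab : {(gpt2 - ours_big_vocab) / gpt2:.2%}")
```

The last line shows how much of GPT-2's count comes from the position table and the bias terms once the vocabulary is held equal.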
RMSNorm: normalize the scale, skip the mean
Every block needs normalization, because stacking twelve residual blocks without it lets activation magnitudes drift — grow layer over layer until bf16 gradients overflow, or shrink until they vanish. Classic LayerNorm fixes this by subtracting the mean, dividing by the standard deviation, then applying a learned gain and bias. RMSNorm (Zhang & Sennrich, 2019) observed that the re-centering is dead weight: what stabilizes training is controlling the scale of the vector, not its mean. So RMSNorm just divides by the root-mean-square and applies a gain:
\[\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot g\]
Same stability, one fewer reduction over the 768-dim vector, one fewer parameter tensor, and measurably faster, which is why the entire Llama family, T5, and Mistral use it. One implementation subtlety that will matter in Lesson 6: we compute the normalization in float32 even when the model runs in bf16. The sum of 768 squared values loses precision at bf16’s 8 bits of significand (7 stored mantissa bits plus the implicit leading one); doing the reduction in fp32 and casting back costs almost nothing and prevents subtle loss spikes.
class RMSNorm(nn.Module):
    """y = x / rms(x) * gain. No mean subtraction, no bias."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))   # gain, init to identity
    def forward(self, x):                              # x: (B, T, C)
        norm = x.float() * torch.rsqrt(
            x.float().pow(2).mean(-1, keepdim=True) + self.eps
        )                                              # fp32 for stability
        return norm.type_as(x) * self.weight
(PyTorch ≥ 2.4 ships an equivalent nn.RMSNorm; we write our own because seeing the six lines is the point of this lesson.)
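To see the two properties claimed above (scale is controlled, the mean is left alone) without any tensors, here is a minimal plain-Python sketch of the same formula applied to a single vector. The helper name `rms_norm` and the toy input are illustrative, not part of src/model.py.

```python
import math

def rms_norm(x, gain=None, eps=1e-6):
    """Plain-Python RMSNorm over one vector: x / sqrt(mean(x^2) + eps) * gain."""
    gain = gain or [1.0] * len(x)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * g for v, g in zip(x, gain)]

x = [3.0, 4.0, 5.0, 6.0]                  # mean is 4.5, deliberately not centered
y = rms_norm(x)
y_scaled = rms_norm([10 * v for v in x])  # same direction, 10x the magnitude

rms_out = math.sqrt(sum(v * v for v in y) / len(y))
mean_out = sum(y) / len(y)
max_diff = max(abs(a - b) for a, b in zip(y, y_scaled))
print(f"rms(y)  = {rms_out:.6f}")         # pinned near 1
print(f"mean(y) = {mean_out:.6f}")        # not removed, unlike LayerNorm
print(f"max |y - y_scaled| = {max_diff:.2e}")   # input scale is forgotten
```

LayerNorm would additionally force the mean of the output to zero; here it stays clearly positive, which is the re-centering step RMSNorm skips.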
RoPE: position as rotation
A transformer’s attention has no built-in sense of order: shuffle the input tokens and, without positional information, the score between any pair of tokens doesn’t change (the operation is permutation-equivariant). GPT-2 fixed this with a learned table of 1024 position vectors added to the token embeddings. That works, but it has two flaws: position becomes an absolute property (“this is token #517”) when language mostly cares about relative offsets (“the adjective three tokens back”), and the table hard-caps the context at whatever length you trained.
Rotary Position Embedding (RoPE, Su et al. 2021) encodes position where it actually matters — inside the attention dot product. Split each 64-dim query/key head vector into 32 pairs, and rotate pair \(i\) of the vector at position \(m\) by the angle \(m \cdot \theta_i\), where \(\theta_i = 10000^{-2i/64}\). Low-index pairs spin fast (fine-grained, nearby positions), high-index pairs spin slowly (coarse, long-range). The magic is in what happens to the attention score: the dot product of a query rotated by \(m\theta\) and a key rotated by \(n\theta\) depends only on the difference \(m-n\). Relative position, for free, with zero learned parameters — this is the mechanism in Llama, Mistral, Qwen, and essentially every model trained since 2022. You will prove the relative-only property yourself in this lesson’s task.
Implementation: the angles depend only on position and pair index, never on the data, so we precompute cos/sin tables once — shape (1024, 32) — and register them as non-persistent buffers (they follow the model to the GPU but don’t bloat checkpoints, since they’re recomputable from the config).
def precompute_rope(head_dim: int, max_seq_len: int, theta: float):
    """cos/sin tables, shape (max_seq_len, head_dim//2), float32."""
    # theta_i = theta ** (-2i / head_dim), i = 0..head_dim/2-1
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)   # positions 0..T-1
    freqs = torch.outer(t, inv_freq)                     # (T, head_dim//2): m * theta_i
    return torch.cos(freqs), torch.sin(freqs)
def apply_rope(x, cos, sin):
    """Rotate q or k by position-dependent angles. x: (B, n_head, T, head_dim)."""
    T = x.size(2)
    cos = cos[:T].view(1, 1, T, -1)        # (1, 1, T, 32); broadcasts over B and heads
    sin = sin[:T].view(1, 1, T, -1)
    x1, x2 = x.float().chunk(2, dim=-1)    # pair dim i with dim i + head_dim//2
    y1 = x1 * cos - x2 * sin               # standard 2-D rotation of each pair
    y2 = x1 * sin + x2 * cos
    return torch.cat([y1, y2], dim=-1).type_as(x)
Note we rotate in fp32 and cast back, for the same reason as in RMSNorm: trigonometric mixing in bf16 accumulates rounding noise across 12 layers.
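The "fast pairs versus slow pairs" picture is easy to make concrete. The sketch below, in plain Python, computes how many tokens each of the 32 pairs needs for one full turn, using the same \(\theta_i\) formula as precompute_rope; the helper name `rope_wavelengths` is illustrative.

```python
import math

def rope_wavelengths(head_dim: int = 64, theta: float = 10000.0):
    """Tokens per full 2*pi turn for each rotated pair: 2*pi / theta_i."""
    inv_freq = [theta ** (-2 * i / head_dim) for i in range(head_dim // 2)]
    return [2 * math.pi / f for f in inv_freq]

waves = rope_wavelengths()
for i in (0, 8, 16, 24, 31):
    print(f"pair {i:2d}: one full turn every {waves[i]:10.1f} tokens")

# Pairs whose wavelength fits inside the 1024-token context complete at least one turn.
n_fast = sum(w < 1024 for w in waves)
print(f"{n_fast} of {len(waves)} pairs complete a full turn within 1024 tokens")
```

Pair 0 turns once every \(2\pi\) tokens, while the slowest pair needs tens of thousands of tokens per turn, far beyond our context, so it acts as a coarse, nearly monotonic position signal.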
Attention and SwiGLU: the two workhorses
Attention. Each block’s attention lets every position gather information from every earlier position. We fuse the three q/k/v projections into one nn.Linear(768, 2304) — one matmul instead of three launches — reshape into 12 heads of 64 dims, rotate q and k with RoPE, and then hand everything to F.scaled_dot_product_attention. That one call is the single most important performance decision in the file: with is_causal=True and bf16 inputs on CUDA, PyTorch dispatches to a FlashAttention kernel that never materializes the T×T attention matrix. Memory drops from O(T²) to O(T), and this is worth roughly 2× throughput at our 1024-token context — a large slice of the 45–55k tok/s that keeps Lesson 7’s bill at $8–12 instead of $25. On CPU it silently falls back to the math kernel, which is why our smoke test still works.
What breaks without is_causal=True? Position \(t\) could attend to position \(t+1\) — whose embedding is the token it’s being asked to predict. Training loss would plummet toward zero within minutes, you’d think you invented a miracle, and the model would emit garbage at generation time when the future genuinely isn’t there. If you ever see suspiciously fast loss curves, check the mask first.
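To make the leak tangible, here is a tiny plain-Python sketch of what the causal mask does to attention weights. It is a conceptual stand-in only: the real kernel never builds this matrix, and the function name `attention_weights` and the 3 × 3 scores are made up for illustration.

```python
import math

def attention_weights(scores, causal=True):
    """Row-wise softmax over a T x T score matrix; causal=True hides columns j > i."""
    out = []
    for i, row in enumerate(scores):
        visible = [s if (not causal or j <= i) else float("-inf")
                   for j, s in enumerate(row)]
        m = max(visible)                              # j == i is always visible
        exps = [math.exp(s - m) for s in visible]     # exp(-inf) evaluates to 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

scores = [[0.1, 2.0, 0.3],      # position 0 "likes" position 1, which is in its future
          [0.5, 0.2, 3.0],
          [1.0, 0.4, 0.6]]
causal = attention_weights(scores, causal=True)
leaky = attention_weights(scores, causal=False)
for name, w in (("causal", causal), ("leaky ", leaky)):
    print(name, [[round(v, 3) for v in row] for row in w])
```

With the mask, position 0 can only attend to itself; without it, most of its weight goes to position 1, which is exactly the shortcut that produces a too-good-to-be-true training loss.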
class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head   # 64
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.dropout = config.dropout
    def forward(self, x, cos, sin):
        B, T, C = x.shape                                 # (B, T, 768)
        q, k, v = self.qkv(x).split(C, dim=2)             # 3 × (B, T, 768)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)   # (B, 12, T, 64)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)   # (B, 12, T, 64)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)   # (B, 12, T, 64)
        q = apply_rope(q, cos, sin)    # position enters HERE, not in the embeddings
        k = apply_rope(k, cos, sin)    # (v is never rotated; only scores need position)
        y = F.scaled_dot_product_attention(               # flash path on CUDA+bf16
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,            # kernel-internal causal mask, no T×T tensor
        )                                                 # (B, 12, T, 64)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble heads: (B, T, 768)
        return self.proj(y)                               # (B, T, 768)
SwiGLU. The feed-forward network is where most of the model’s knowledge lives; it’s the per-token “lookup and transform” stage. GPT-2 used Linear(768→3072) → GELU → Linear(3072→768). Shazeer’s 2020 paper “GLU Variants Improve Transformer” showed that a gated variant consistently wins at equal parameter count: compute two parallel projections, pass one through SiLU (the smooth sigmoid-weighted activation), and multiply them elementwise, so one path decides what to say and the other how much of it gets through. Because SwiGLU needs three matrices instead of two, the hidden width shrinks to 8/3 of the embedding dim to keep parameters equal: \(8/3 \times 768 = 2048\), and conveniently \(3 \times 768 \times 2048 = 2 \times 768 \times 3072 = 4{,}718{,}592\), exactly GPT-2’s FFN budget, spent better. Llama, PaLM, and Mistral all made the same trade.
class SwiGLU(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden = int(8 * config.n_embd / 3)       # 2048, same param count as a 4x GELU MLP
        hidden = 256 * ((hidden + 255) // 256)    # round UP to a multiple of 256: tensor cores
        self.w_gate = nn.Linear(config.n_embd, hidden, bias=False)
        self.w_up = nn.Linear(config.n_embd, hidden, bias=False)
        self.w_down = nn.Linear(hidden, config.n_embd, bias=False)
    def forward(self, x):                          # (B, T, 768)
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))   # 768 → 2048 → 768
The rounding line is a no-op for us (2048 is already a multiple of 256) but it’s load-bearing at other widths: GPU tensor cores want matmul dimensions divisible by large powers of two, and an odd hidden size like 1365 (what 8/3 × 512 would give) can cost double-digit percent throughput.
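The width rule is pure integer arithmetic, so it can be tried at other model sizes without building anything. This sketch repeats the two lines from SwiGLU.__init__ in a standalone helper (the name `swiglu_hidden` is illustrative) and compares the resulting parameter count with a 4× two-matrix MLP.

```python
def swiglu_hidden(n_embd: int, multiple: int = 256) -> int:
    """8/3 of the embedding width, rounded up to the next multiple of `multiple`."""
    hidden = int(8 * n_embd / 3)
    return multiple * ((hidden + multiple - 1) // multiple)

for n_embd in (512, 768, 1024):
    raw = int(8 * n_embd / 3)
    hidden = swiglu_hidden(n_embd)
    swiglu = 3 * n_embd * hidden            # gate + up + down
    gelu_mlp = 2 * n_embd * 4 * n_embd      # up + down at 4x width
    print(f"n_embd {n_embd:4d}: raw {raw:4d} -> hidden {hidden:4d} | "
          f"SwiGLU {swiglu:,} vs 4x MLP {gelu_mlp:,}")
```

At 768 the two budgets match exactly; at widths where the rounding actually moves the hidden size, the SwiGLU layer ends up slightly larger than the 4× MLP, which is the price of tensor-core-friendly shapes.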
No biases, anywhere. Every nn.Linear in this file has bias=False. Ablations across the Llama-scale literature show bias terms contribute nothing measurable to quality in pre-normed transformers — the norms’ gains absorb their job — while adding parameters, memory traffic, and kernel epilogue work. Deleting them is free performance.
The block: pre-norm residuals
Now compose. Each block is two residual updates: attention (tokens talk to each other) then the FFN (each token thinks alone). The norm goes before each sub-layer — “pre-norm” — so the residual stream itself is never normalized, only the input to each computation. This is what makes deep stacks trainable without warmup gymnastics: the identity path from embedding to logits is completely clean, and each block just adds a bounded correction onto it. GPT-2 already did this (its one genuinely modern trait), and nobody has gone back.
flowchart TB
x["input x<br/>B × T × 768"] --> n1["RMSNorm"]
n1 --> att["causal self-attention<br/>RoPE on q,k · SDPA flash · 12 heads × 64"]
att --> a1(("+"))
x --> a1
a1 --> n2["RMSNorm"]
n2 --> ffn["SwiGLU FFN<br/>768 → 2048 → 768"]
ffn --> a2(("+"))
a1 --> a2
a2 --> out["output<br/>B × T × 768"]
class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.n_embd)
        self.mlp = SwiGLU(config)
    def forward(self, x, cos, sin):
        x = x + self.attn(self.norm1(x), cos, sin)   # pre-norm residual: communicate
        x = x + self.mlp(self.norm2(x))              # pre-norm residual: compute
        return x
Weight tying, initialization, and the parameter count
Three decisions live at the top level of the model.
Weight tying. The embedding maps token → vector; the LM head maps vector → token scores. These are near-transposes of the same semantic relationship, so we make them literally the same 32768×768 matrix: self.lm_head.weight = self.wte.weight. This saves 25.2M parameters, about 19% of the 135.3M an untied model would carry, and at our scale it measurably helps quality, because the embedding gets gradient signal from every output position, not just from the tokens that appear in the input. GPT-2 tied; Gemma ties; large models (70B+) often untie because the savings become negligible. At 124M-class scale, tying is the clear win.
Initialization. All weights start from \(\mathcal{N}(0, 0.02^2)\) — the GPT-2 value, well-matched to \(d=768\) (it’s close to the Xavier-ish \(1/\sqrt{768} \approx 0.036\), slightly conservative). Then one crucial correction: the two matrices in each block that write into the residual stream — the attention output projection and the FFN down projection — get scaled down by \(1/\sqrt{2 \cdot n_{layer}}\). Why: the stream receives \(2 \times 12 = 24\) additive writes, and independent additions grow the stream’s variance linearly with depth. Dividing each writer’s init by \(\sqrt{24}\) keeps the stream’s magnitude roughly constant from block 1 to block 12. Skip this and you’ll see it immediately in Lesson 7’s W&B dashboard: activation norms climbing layer by layer, grad-norm spikes in the first few hundred steps, and with bad luck a diverged run — $10 of GPU time gone. It’s two lines of code.
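The variance argument can be simulated in a few lines. The sketch below is a toy model, not the network: it treats each of the 24 residual writes as an independent Gaussian added to a single coordinate of the stream, and assumes a write's size scales linearly with the init std of the matrix that produces it. The function name and sample count are illustrative.

```python
import math
import random

def stream_variance(write_std: float, n_writes: int = 24,
                    n_samples: int = 5000, seed: int = 0) -> float:
    """Variance of one residual coordinate after n_writes independent additions."""
    rng = random.Random(seed)
    finals = [sum(rng.gauss(0.0, write_std) for _ in range(n_writes))
              for _ in range(n_samples)]
    mean = sum(finals) / n_samples
    return sum((v - mean) ** 2 for v in finals) / n_samples

plain = stream_variance(1.0)                     # every write at unit scale
scaled = stream_variance(1.0 / math.sqrt(24))    # each write shrunk by sqrt(2 * n_layer)
print(f"unscaled writes: variance after 24 additions = {plain:.2f}")
print(f"scaled writes  : variance after 24 additions = {scaled:.2f}")
```

Unscaled, the variance grows to roughly 24 times that of a single write; with the \(1/\sqrt{24}\) factor it stays near 1, which is the behavior the scaled init is meant to approximate in the real model.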
The count. Here is where every parameter lives:
| Component | Shape | Count |
|---|---|---|
| Token embedding (tied with LM head) | 32768 × 768 | 25,165,824 |
| Attention per layer: fused qkv + out proj | 768×2304 + 768×768 | 2,359,296 |
| SwiGLU per layer: gate + up + down | 2×(768×2048) + 2048×768 | 4,718,592 |
| RMSNorm gains per layer | 2 × 768 | 1,536 |
| One block | attention + SwiGLU + norm gains | 7,079,424 |
| 12 blocks | 12 × 7,079,424 | 84,953,088 |
| Final RMSNorm | 768 | 768 |
| Total (tied head counted once) | embedding + 12 blocks + final norm | 110,119,680 |
The 85.0M transformer stack matches the weight matrices of GPT-2-124M’s stack exactly (GPT-2 adds roughly 0.1M of bias terms on top); the leaner vocab and the deleted position table account for nearly all of the rest. The print_param_table() method in the file below prints exactly this breakdown, and you should recognize every row in it.
The complete src/model.py
Everything above, assembled into the file that Lessons 6 through 11 will import unchanged. Read the shape comments in forward — being able to trace (B, T) → (B, T, 768) → (B, T, 32768) in your head is the real deliverable of this lesson.
"""WikiGPT-124M -- a modern decoder-only transformer, from scratch.
Llama-family choices on a GPT-2-small skeleton:
RMSNorm (pre-norm) - RoPE - SwiGLU - flash SDPA - no biases - tied embeddings.
CPU-friendly for testing; bf16 + torch.compile on the GPU from Lesson 6.
"""
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------------------------------------------- config ---
@dataclass
class GPTConfig:
    vocab_size: int = 32768      # our BPE from Lesson 4, chat tokens already reserved
    block_size: int = 1024       # maximum context length
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768            # head_dim = 768 / 12 = 64
    rope_theta: float = 10000.0
    dropout: float = 0.0         # one pass over fresh tokens: nothing to overfit to

# ---------------------------------------------------------------- rmsnorm ---
class RMSNorm(nn.Module):
    """y = x / rms(x) * gain. No mean subtraction, no bias."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):                                   # (B, T, C)
        norm = x.float() * torch.rsqrt(
            x.float().pow(2).mean(-1, keepdim=True) + self.eps
        )                                                   # reduce in fp32 for bf16 safety
        return norm.type_as(x) * self.weight

# ------------------------------------------------------------------- rope ---
def precompute_rope(head_dim: int, max_seq_len: int, theta: float):
    """cos/sin tables, shape (max_seq_len, head_dim//2), float32."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)                        # (T, head_dim//2): m * theta_i
    return torch.cos(freqs), torch.sin(freqs)

def apply_rope(x, cos, sin):
    """Rotate q or k by position-dependent angles. x: (B, n_head, T, head_dim)."""
    T = x.size(2)
    cos = cos[:T].view(1, 1, T, -1)                         # broadcast over batch and heads
    sin = sin[:T].view(1, 1, T, -1)
    x1, x2 = x.float().chunk(2, dim=-1)                     # pair dim i with dim i + hd//2
    y1 = x1 * cos - x2 * sin                                # 2-D rotation of each pair
    y2 = x1 * sin + x2 * cos
    return torch.cat([y1, y2], dim=-1).type_as(x)
# -------------------------------------------------------------- attention ---
class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.dropout = config.dropout

    def forward(self, x, cos, sin):
        B, T, C = x.shape                                     # (B, T, 768)
        q, k, v = self.qkv(x).split(C, dim=2)                 # 3 x (B, T, 768)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)   # (B, 12, T, 64)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)   # (B, 12, T, 64)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)   # (B, 12, T, 64)
        q = apply_rope(q, cos, sin)                           # position enters here, not in wte
        k = apply_rope(k, cos, sin)
        y = F.scaled_dot_product_attention(                   # flash kernel on CUDA + bf16
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,                                   # no T x T mask ever materialized
        )                                                     # (B, 12, T, 64)
        y = y.transpose(1, 2).contiguous().view(B, T, C)      # (B, T, 768)
        return self.proj(y)

# ------------------------------------------------------------------ swiglu ---
class SwiGLU(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden = int(8 * config.n_embd / 3)        # 2048: same params as 4x GELU MLP
        hidden = 256 * ((hidden + 255) // 256)     # tensor-core friendly width
        self.w_gate = nn.Linear(config.n_embd, hidden, bias=False)
        self.w_up = nn.Linear(config.n_embd, hidden, bias=False)
        self.w_down = nn.Linear(hidden, config.n_embd, bias=False)

    def forward(self, x):                           # (B, T, 768)
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))   # 768->2048->768

# ------------------------------------------------------------------- block ---
class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.n_embd)
        self.mlp = SwiGLU(config)

    def forward(self, x, cos, sin):
        x = x + self.attn(self.norm1(x), cos, sin)  # pre-norm residual
        x = x + self.mlp(self.norm2(x))
        return x
# --------------------------------------------------------------------- gpt ---
class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.blocks = nn.ModuleList(Block(config) for _ in range(config.n_layer))
        self.norm_f = RMSNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight   # weight tying: one matrix, two jobs
        cos, sin = precompute_rope(
            config.n_embd // config.n_head, config.block_size, config.rope_theta
        )
        self.register_buffer("rope_cos", cos, persistent=False)   # follows .to(device),
        self.register_buffer("rope_sin", sin, persistent=False)   # skipped in checkpoints
        self.apply(self._init_weights)
        # scaled init for the two matrices per block that WRITE the residual stream:
        # 2 * n_layer additive writes -> shrink each by sqrt(2 * n_layer)
        for name, p in self.named_parameters():
            if name.endswith("attn.proj.weight") or name.endswith("mlp.w_down.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= self.config.block_size, f"sequence length {T} > block_size"
        x = self.wte(idx)                                  # (B, T) -> (B, T, 768)
        for block in self.blocks:
            x = block(x, self.rope_cos, self.rope_sin)     # (B, T, 768), twelve times
        x = self.norm_f(x)                                 # (B, T, 768)
        if targets is not None:
            logits = self.lm_head(x)                       # (B, T, 32768)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),          # (B*T, 32768)
                targets.view(-1),                          # (B*T,)
                ignore_index=-1,                           # Lesson 9 masks prompts with -1
            )
            return logits, loss
        logits = self.lm_head(x[:, [-1], :])               # inference: last position only
        return logits, None

    # ------------------------------------------------------------- utils ---
    def num_params(self) -> int:
        return sum(p.numel() for p in self.parameters())   # tied head counted once

    def print_param_table(self):
        c = self.config
        groups = {
            "token embedding (tied w/ lm_head)":
                self.wte.weight.numel(),
            f"attention ({c.n_layer} layers)":
                sum(p.numel() for n, p in self.named_parameters() if ".attn." in n),
            f"SwiGLU FFN ({c.n_layer} layers)":
                sum(p.numel() for n, p in self.named_parameters() if ".mlp." in n),
            "RMSNorm gains":
                sum(p.numel() for n, p in self.named_parameters() if "norm" in n),
        }
        for name, n in groups.items():
            print(f"{name:<38s} {n:>12,d}")
        print(f"{'TOTAL':<38s} {self.num_params():>12,d}")

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, top_p=None):
        """Autoregressive sampling. temperature=0.0 -> greedy decoding.

        Note: recomputes the full prefix each step -- fine at 124M/1k ctx;
        add a KV cache if serve.py (Lesson 11) ever feels slow.
        """
        self.eval()
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size:]    # crop to context window
            logits, _ = self(idx_cond)                     # (B, 1, 32768)
            logits = logits[:, -1, :]                      # (B, 32768)
            if temperature == 0.0:
                next_tok = logits.argmax(dim=-1, keepdim=True)     # greedy
            else:
                logits = logits / temperature              # flatten or sharpen
                if top_k is not None:                      # keep k best
                    kth = torch.topk(logits, top_k, dim=-1).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth, float("-inf"))
                if top_p is not None:                      # nucleus sampling
                    s_logits, s_idx = torch.sort(logits, descending=True, dim=-1)
                    probs = F.softmax(s_logits, dim=-1)
                    cum = probs.cumsum(dim=-1)
                    mask = cum - probs > top_p             # keep first past threshold
                    s_logits = s_logits.masked_fill(mask, float("-inf"))
                    logits = torch.full_like(logits, float("-inf")) \
                        .scatter_(-1, s_idx, s_logits)
                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_tok], dim=1)
        return idx
# -------------------------------------------------------------- smoke test ---
if __name__ == "__main__":
    torch.manual_seed(1337)
    cfg = GPTConfig()
    model = GPT(cfg)
    model.print_param_table()
    B, T = 2, 128
    idx = torch.randint(0, cfg.vocab_size, (B, T))
    targets = torch.randint(0, cfg.vocab_size, (B, T))
    logits, loss = model(idx, targets)
    assert logits.shape == (B, T, cfg.vocab_size)
    expected = math.log(cfg.vocab_size)                    # ln(32768) = 10.397
    print(f"initial loss {loss.item():.3f} ln(vocab_size) {expected:.3f}")
    assert abs(loss.item() - expected) < 0.5, "broken init: loss should start ~uniform"
    out = model.generate(idx[:, :4], max_new_tokens=8, temperature=0.8, top_k=50)
    assert out.shape == (B, 12)
    out = model.generate(idx[:, :4], max_new_tokens=8, temperature=0.0)   # greedy path
    out = model.generate(idx[:, :4], max_new_tokens=8, top_p=0.9)         # nucleus path
    assert out.shape == (B, 12)
    print("smoke test passed")
Two forward-looking details you just planted: ignore_index=-1 in the cross-entropy is inert now, but in Lesson 9 it becomes the mechanism for masking prompt tokens so SFT loss lands only on assistant tokens. And the x[:, [-1], :] inference shortcut in forward skips computing 32,768 scores for each of up to 1,023 positions you’d throw away, a real speedup inside generate()’s loop.
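What ignore_index=-1 will do in Lesson 9 can be previewed in plain Python. The sketch below mimics the mean-reduced cross-entropy, which averages only over positions whose target is not the ignore value; the function name `masked_cross_entropy` and the three-token example are invented for illustration and are not part of the course files.

```python
import math

def masked_cross_entropy(logits, targets, ignore_index=-1):
    """Mean negative log-likelihood over positions whose target != ignore_index."""
    total, kept = 0.0, 0
    for row, tgt in zip(logits, targets):
        if tgt == ignore_index:
            continue                                   # masked position: contributes nothing
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))   # stable logsumexp
        total += log_z - row[tgt]
        kept += 1
    return total / kept                                # assumes at least one kept position

logits = [[2.0, 0.0, 0.0],      # prompt position
          [0.0, 2.0, 0.0],      # prompt position
          [0.0, 0.0, 2.0]]      # assistant position
full = masked_cross_entropy(logits, [1, 0, 2])      # every position is scored
sft = masked_cross_entropy(logits, [-1, -1, 2])     # only the assistant token is scored
print(f"loss over all positions     : {full:.4f}")
print(f"loss over assistant position: {sft:.4f}")
```

The two prompt positions are predicted badly in this toy example, so including them inflates the loss; masking them with -1 leaves only the assistant token's contribution.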
Run the smoke test — cost: $0.00
No vast.ai instance this lesson. The whole point of the smoke test is to catch bugs on your laptop before any of them can cost GPU-hours:
cd wikillm
python src/model.py
Expected output (exact loss varies slightly with PyTorch version, but must sit near 10.4):
token embedding (tied w/ lm_head)        25,165,824
attention (12 layers)                    28,311,552
SwiGLU FFN (12 layers)                   56,623,104
RMSNorm gains                                19,200
TOTAL                                   110,119,680
initial loss 10.482 ln(vocab_size) 10.397
smoke test passed
The loss check is the deepest single line of validation in the file. An untrained model with a sane init knows nothing, so its predictive distribution should be near-uniform over 32,768 tokens, and cross-entropy against a uniform distribution is exactly \(\ln(32768) = 15\ln 2 \approx 10.397\). It lands slightly above that because the random logits aren’t exactly zero (their variance adds roughly \(\sigma^2/2\) to the loss). What the check catches: an init std that’s 10× too big (loss starts at 13+ and the run diverges), broken weight tying, a shape bug silently broadcasting. What it cannot see is a missing residual scaling: there the loss looks fine while activations balloon by layer 12, so confirm that one by reading the init loop. When training starts in Lesson 7, watching the loss fall from 10.4 is watching the model climb down from perfect ignorance: every 0.69 drop halves the effective number of tokens it’s choosing between.
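The \(\ln V + \sigma^2/2\) rule of thumb can be checked numerically without a model. The sketch below draws one vector of Gaussian logits and computes the cross-entropy averaged over every possible target; the logit standard deviations tried here are illustrative values, not measurements from WikiGPT, and the function name is made up.

```python
import math
import random

def expected_initial_loss(vocab: int, logit_std: float, seed: int = 0) -> float:
    """Cross-entropy of random Gaussian logits, averaged over every possible target."""
    rng = random.Random(seed)
    z = [rng.gauss(0.0, logit_std) for _ in range(vocab)]
    m = max(z)
    log_z = m + math.log(sum(math.exp(v - m) for v in z))   # stable logsumexp
    return log_z - sum(z) / vocab                            # mean over targets of -log p

V = 32768
losses = {}
for std in (0.0, 0.4, 0.55):
    losses[std] = expected_initial_loss(V, std)
    approx = math.log(V) + std ** 2 / 2
    print(f"logit std {std:.2f}: loss {losses[std]:.3f} | ln(V) + std^2/2 = {approx:.3f}")

# exp(loss) is the effective number of equally likely choices; ln 2 ~ 0.69 halves it.
print(f"choices at ln(V): {math.exp(math.log(V)):.0f}, "
      f"after a 0.69 drop: {math.exp(math.log(V) - math.log(2)):.0f}")
```

With zero-variance logits the loss is exactly \(\ln V\); as the logit spread grows, the loss rises by about half the variance, which is why a healthy init sits a little above 10.397 rather than on it.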
🧪 Your task
Prove RoPE’s central claim to yourself: the attention score between a query at position \(m\) and a key at position \(n\) depends only on the offset \(m - n\), not on the absolute positions. Write scripts/check_rope.py that takes one random query vector and one random key vector (64 dims, no batch), places them at several different absolute-position pairs sharing the same offset (e.g. offset 7 at positions (7,0), (57,50), (1000,993)), rotates each with the tables from precompute_rope, and asserts the dot products are identical to within 1e-4. Then break it on purpose: show the score does change when the offset changes.
Solution
# scripts/check_rope.py
import sys, torch
sys.path.append("src")
from model import precompute_rope
torch.manual_seed(0)
head_dim = 64
cos, sin = precompute_rope(head_dim, 1024, 10000.0) # (1024, 32) each
def rope_at(x, pos):
    """Rotate a single (head_dim,) vector as if it sat at absolute position `pos`."""
    x1, x2 = x.chunk(2)                      # (32,), (32,)
    c, s = cos[pos], sin[pos]
    return torch.cat([x1 * c - x2 * s, x1 * s + x2 * c])
q = torch.randn(head_dim)
k = torch.randn(head_dim)
# same offset, wildly different absolute positions -> identical scores
offset = 7
scores = [rope_at(q, m) @ rope_at(k, m - offset) for m in (7, 57, 300, 1000)]
spread = max(scores) - min(scores)
print(f"offset {offset}: scores {[f'{s:.6f}' for s in scores]} spread {spread:.2e}")
assert spread < 1e-4, "RoPE should depend only on relative offset"
# different offset -> different score (sanity that we're not testing a constant)
other = rope_at(q, 20) @ rope_at(k, 12) # offset 8
assert abs(other - scores[0]) > 1e-3
print("RoPE is purely relative -- check passed")
Why it works: rotating \(q\) by angle \(m\theta_i\) and \(k\) by \(n\theta_i\) means their per-pair dot product contains \(\cos(m\theta_i - n\theta_i)\) terms. Rotation matrices compose by angle subtraction, so only \(m - n\) survives. Absolute position cancels out pair by pair, in all 32 pairs at once.
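For readers who want the algebra behind that argument, here is a short derivation for a single pair. It treats the pair \((x_1, x_2)\) as the complex number \(x_1 + i x_2\), a notational convenience that the lesson's code does not use; under that convention a rotation by angle \(\alpha\) is multiplication by \(e^{i\alpha}\), and the 2-D dot product of \(u\) and \(v\) is \(\operatorname{Re}(u\,\bar{v})\).

\[
\begin{aligned}
\langle R_{m\theta_i}\, q,\; R_{n\theta_i}\, k \rangle
&= \operatorname{Re}\!\left( q\, e^{i m \theta_i}\; \overline{k\, e^{i n \theta_i}} \right) \\
&= \operatorname{Re}\!\left( q\, \bar{k}\; e^{i m \theta_i}\, e^{-i n \theta_i} \right) \\
&= \operatorname{Re}\!\left( q\, \bar{k}\; e^{i (m - n) \theta_i} \right)
\end{aligned}
\]

The positions \(m\) and \(n\) appear only through \(m - n\). The full attention score is the sum of this expression over the 32 pairs, each with its own \(\theta_i\), so it inherits the same property.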
Key takeaways
- WikiGPT-124M is the Llama recipe at GPT-2-small scale: RMSNorm, RoPE, SwiGLU, pre-norm, no biases, tied embeddings — every choice matches what frontier-lab models converged on, just twelve layers of it.
- RMSNorm drops LayerNorm’s mean-centering and bias because only scale control stabilizes deep residual stacks; compute the reduction in fp32 even when training in bf16.
- RoPE encodes position as a rotation of q/k pairs inside the attention dot product, making scores depend only on relative offset — zero learned parameters, precomputed cos/sin buffers, and you proved the relative property yourself.
- F.scaled_dot_product_attention(is_causal=True) buys the FlashAttention kernel for free: O(T) memory, roughly 2× throughput, and no hand-rolled T×T mask to get wrong.
- SwiGLU at 8/3× hidden width spends exactly the same 4.72M parameters per layer as GPT-2’s 4× GELU MLP, and spends them better; the gate/up/down structure is standard across Llama, PaLM, and Mistral.
- Init is 0.02 everywhere plus \(1/\sqrt{2 \cdot n_{layer}}\) scaling on the two residual-writing projections per block — two lines that prevent activation blow-up across 24 residual writes.
- The tied embedding counts once: 110,119,680 parameters exactly, with an 85.0M transformer stack whose weight matrices match GPT-2-124M’s in size; the difference is almost entirely our leaner 32k vocab.
- An untrained model must score \(\approx \ln(V) = 10.4\); that one assert catches init, tying, and shape bugs on your laptop for $0.00 before the GPU meter starts.
Coming up
In Lesson 6 we wrap this model in src/train.py — bf16 autocast, torch.compile, cosine LR schedule with warmup, gradient accumulation, W&B logging, and crash-proof resumable checkpoints — the engine that has to survive 20+ unattended hours on a rented 4090.