Kader Mohideen

In this lesson

  • Lesson 8 — Training the Tiny GPT
    • The training recipe at a glance
    • AdamW, configured properly
    • The learning-rate schedule: warmup + cosine, from scratch
    • Gradient clipping: the seatbelt
    • Estimating validation loss without lying to yourself
    • The training script
    • Reading the loss curve: what healthy looks like
    • 🧪 Your task
    • Key takeaways

⚡ Building Transformers from Scratch with PyTorch · Lesson 8 — Training the Tiny GPT

🏠 ⚡ Course home  |  ← Lesson 07  |  Lesson 09 →  |  📚 All mini-courses


Lesson 8 — Training the Tiny GPT

In the previous lesson you assembled the full model: token embeddings, positional embeddings, a stack of transformer blocks, a final LayerNorm, and the language-modeling head, all wired into a single GPT class that maps (B, T) token indices to (B, T, vocab_size) logits and a cross-entropy loss. It compiles, it runs a forward pass, and right now it produces pure noise — its weights are random. In this lesson we fix that. We write the training loop that turns those random weights into a model that actually speaks (a caricature of) Shakespeare: a properly configured AdamW optimizer with weight-decay grouping, a learning-rate schedule with linear warmup and cosine decay written from scratch, gradient clipping, periodic validation-loss estimation that doesn’t lie to you, and checkpointing so a crash doesn’t cost you an afternoon. None of this is glamorous, and all of it is where real transformer projects live or die. The architecture from Lessons 4–7 is maybe 30% of what makes GPT work; the training recipe is the rest.

🎯 In this lesson you will: configure AdamW with correct betas and weight-decay parameter grouping, implement warmup + cosine LR decay from scratch, add gradient clipping, write a full training script with low-variance val-loss estimation and checkpointing, and train your tiny GPT on Tiny Shakespeare down to a val loss of about 1.5.

The training recipe at a glance

Before writing code, look at the shape of the whole loop. Every modern LLM training run — from our 10M-parameter toy to a frontier model — is structurally this diagram. The differences are scale, data, and the sophistication of each box, not the boxes themselves.

flowchart TD
    A["get_batch('train')<br/>x, y : (B, T)"] --> B["set LR for this step<br/>warmup + cosine"]
    B --> C["forward: logits, loss = model(x, y)"]
    C --> D["optimizer.zero_grad()<br/>loss.backward()"]
    D --> E["clip_grad_norm_(params, 1.0)"]
    E --> F["optimizer.step()"]
    F --> G{"iter % eval_interval == 0?"}
    G -- no --> A
    G -- yes --> H["estimate_loss()<br/>mean over 200 val batches"]
    H --> I{"best val loss so far?"}
    I -- yes --> J["save checkpoint<br/>model + optimizer + iter"]
    I -- no --> A
    J --> A

One pass through the top of this loop — batch, forward, backward, clip, step — is an iteration. We won’t think in epochs at all: get_batch (from Lesson 2) samples random windows from the corpus, so there’s no natural notion of “one pass over the data.” We just count iterations and watch the loss. This is exactly how the big labs think about it too — they count tokens seen, not epochs.

Let’s set the stage with the configuration we’ll train. This assumes the GPT and GPTConfig classes from Lesson 7 and the get_batch/encode/decode machinery from Lesson 2 are importable or defined above.

import math
import os
import torch

torch.manual_seed(1337)

device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# --- model config (Lesson 7's GPTConfig) ---
config = GPTConfig(
    vocab_size=65,      # char-level Tiny Shakespeare
    block_size=256,     # context length
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.2,        # small dataset -> we need regularization
)

# --- training config ---
batch_size    = 64
max_iters     = 5000
eval_interval = 250     # how often we estimate val loss
eval_iters    = 200     # how many batches per estimate
max_lr        = 3e-4
min_lr        = 3e-5    # max_lr / 10, a standard choice
warmup_iters  = 100
weight_decay  = 0.1
grad_clip     = 1.0

model = GPT(config).to(device)
print(f"{sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters")
10.79M parameters

Two config choices deserve a word. dropout=0.2 is unusually high by modern-LLM standards: big models train on so much data that they rarely see a token twice, and they typically use dropout 0.0. We have ~1MB of Shakespeare, and 5000 iterations of 64 windows × 256 characters sample about 82M characters from it, on the order of 80 passes' worth, so overfitting is our main enemy and dropout is our main defense. And max_lr=3e-4 is the classic “safe” Adam learning rate for a model this size; we’ll see below why we don’t just set it and forget it.

AdamW, configured properly

You could write torch.optim.AdamW(model.parameters(), lr=3e-4) and it would work. But GPT training has two conventions that matter, and both are one line away from correct.

First, the betas. Adam keeps two exponential moving averages per parameter: the mean of gradients (first moment, decayed by \(\beta_1\)) and the mean of squared gradients (second moment, decayed by \(\beta_2\)). The update is roughly

\[ \theta_{t+1} = \theta_t - \eta \cdot \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} \]
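The hats in that formula mark bias-corrected estimates, \(\hat m_t = m_t/(1-\beta_1^t)\) and \(\hat v_t = v_t/(1-\beta_2^t)\), which undo the zero initialization of the moving averages. To make the update concrete, here is a minimal from-scratch AdamW step (including the decoupled weight decay discussed below, which shrinks the weights by a factor \(1 - \eta\lambda\) before the Adam step), checked against torch.optim.AdamW on a toy float64 tensor. adamw_step is a hypothetical helper written for illustration, not a library function.

```python
import torch

def adamw_step(theta, grad, m, v, t, lr, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1):
    """One AdamW update on plain tensors; t is the 1-based step count."""
    beta1, beta2 = betas
    theta = theta * (1 - lr * weight_decay)       # decoupled decay: shrink toward 0
    m = beta1 * m + (1 - beta1) * grad            # EMA of gradients (first moment)
    v = beta2 * v + (1 - beta2) * grad * grad     # EMA of squared gradients (second moment)
    m_hat = m / (1 - beta1 ** t)                  # bias correction for the zero init
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (v_hat.sqrt() + eps)
    return theta, m, v

# Compare against PyTorch's AdamW for three steps on a small float64 tensor.
torch.manual_seed(0)
p = torch.nn.Parameter(torch.randn(5, dtype=torch.float64))
opt = torch.optim.AdamW([p], lr=1e-2, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1)

theta = p.detach().clone()
m = torch.zeros_like(theta)
v = torch.zeros_like(theta)
for t in range(1, 4):
    g = torch.randn(5, dtype=torch.float64)       # a fake gradient for this step
    p.grad = g.clone()
    opt.step()
    theta, m, v = adamw_step(theta, g, m, v, t, lr=1e-2)

print(torch.allclose(p.detach(), theta, atol=1e-10))  # the two updates agree
```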

PyTorch’s default is betas=(0.9, 0.999). GPT-style training uses (0.9, 0.95). Why lower \(\beta_2\)? Language-model gradients are noisy and occasionally spiky — a rare token, a weird batch. With \(\beta_2 = 0.999\) the second-moment estimate has an effective memory of ~1000 steps, so a sudden increase in gradient scale takes ages to show up in \(\hat v_t\), and until it does, the denominator is too small and the update is too large. \(\beta_2 = 0.95\) (memory ~20 steps) adapts faster and makes training visibly more stable at scale. For our tiny model both work; we use the GPT convention because the whole point of this course is to build the real recipe.
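A small simulation makes the “effective memory” argument concrete. Assume, purely for illustration, that one parameter's gradient has sat at 1.0 for a long time (so \(v \approx 1\)) and then jumps to 10.0. The sketch counts how many steps the second-moment EMA needs to reach half of the new squared gradient; until then, the step size \(g/\sqrt{v}\) is several times its steady-state value of about 1. Bias correction is ignored because the EMA is assumed to be long past its start. steps_to_adapt is a hypothetical helper.

```python
def steps_to_adapt(beta2, old_g=1.0, new_g=10.0, fraction=0.5):
    """Steps for the EMA of g^2 to climb from old_g^2 to fraction * new_g^2."""
    v = old_g ** 2                      # steady state before the jump
    target = fraction * new_g ** 2
    steps = 0
    while v < target:
        v = beta2 * v + (1 - beta2) * new_g ** 2
        steps += 1
    return steps

for beta2 in (0.95, 0.999):
    print(f"beta2={beta2}: {steps_to_adapt(beta2)} steps "
          f"(effective memory ~{1 / (1 - beta2):.0f} steps)")
```

With \(\beta_2 = 0.95\) the estimate catches up within about 14 steps; with 0.999 it takes hundreds, and every one of those steps is oversized.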

Second, weight-decay grouping. AdamW’s decoupled weight decay pulls every parameter toward zero each step; that’s the regularizer. But it only makes sense for parameters that act as matrices multiplying activations: the attention projections, the MLP weights, the embedding table. It is counterproductive for biases and LayerNorm gains. Those are 1-D parameters whose job is to shift and rescale, and a LayerNorm weight decayed toward zero would squash activations toward nothing with no regularization benefit. The standard convention (nanoGPT follows it): decay everything with dim >= 2, don’t decay anything with dim < 2.

def configure_optimizer(model, weight_decay, lr, betas, device):
    # all trainable params
    params = {name: p for name, p in model.named_parameters() if p.requires_grad}

    # ndim >= 2: weight matrices & embeddings -> decay
    # ndim <  2: biases & LayerNorm params    -> no decay
    decay_params   = [p for p in params.values() if p.dim() >= 2]
    nodecay_params = [p for p in params.values() if p.dim() < 2]

    optim_groups = [
        {"params": decay_params,   "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]

    n_decay   = sum(p.numel() for p in decay_params)
    n_nodecay = sum(p.numel() for p in nodecay_params)
    print(f"decayed:     {len(decay_params)} tensors, {n_decay:,} params")
    print(f"non-decayed: {len(nodecay_params)} tensors, {n_nodecay:,} params")

    # fused AdamW is one CUDA kernel for the whole update — faster, same math
    use_fused = device == "cuda"
    return torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=use_fused)

optimizer = configure_optimizer(
    model, weight_decay, max_lr, betas=(0.9, 0.95), device=device
)
decayed:     26 tensors, 10,768,896 params
non-decayed: 51 tensors, 24,192 params

Read those numbers: 99.8% of parameters are in 2-D matrices and get decayed; the long tail of 51 tiny 1-D tensors (every bias, every LayerNorm weight) is exempt. The dim >= 2 test is a beautifully lazy heuristic — no name-matching against "ln" or "bias" strings, which breaks the moment someone renames a module. Shape tells you everything.
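To see the shape test in isolation, here is the same split applied to a hypothetical toy module (an embedding, a LayerNorm, and a linear head standing in for the real GPT from Lesson 7). No parameter name is inspected; the grouping falls out of the shapes alone.

```python
import torch.nn as nn

toy = nn.Sequential(
    nn.Embedding(65, 16),   # "0.weight": (65, 16) -> decay
    nn.LayerNorm(16),       # "1.weight", "1.bias": (16,) -> no decay
    nn.Linear(16, 65),      # "2.weight": (65, 16) -> decay; "2.bias": (65,) -> no decay
)

decay   = [n for n, p in toy.named_parameters() if p.dim() >= 2]
nodecay = [n for n, p in toy.named_parameters() if p.dim() < 2]
print("decay:   ", decay)
print("no decay:", nodecay)
```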

Note we passed lr=max_lr to the optimizer, but that’s just a placeholder — the schedule below will overwrite it every single step.

The learning-rate schedule: warmup + cosine, from scratch

A constant learning rate is wrong at both ends of training. At the start, Adam’s moment estimates \(\hat m_t, \hat v_t\) are built from a handful of samples and are garbage; taking full-size steps on garbage statistics can blow the model into a bad region it never recovers from (you’ll see the loss spike to 8+ and plateau). At the end, the model is fine-tuning itself into a narrow minimum, and full-size steps just bounce it around the basin instead of settling in.

The fix used by essentially every transformer since GPT-2/GPT-3: linear warmup from 0 to max_lr over the first few hundred steps, then cosine decay from max_lr down to min_lr over the rest of training. During decay, the LR at iteration \(t\) is

\[ \eta_t = \eta_{\min} + \tfrac{1}{2}\left(1 + \cos\left(\pi \cdot \text{progress}\right)\right)(\eta_{\max} - \eta_{\min}) \]

where progress runs from 0 to 1 between the end of warmup and max_iters. The cosine gives you a slow departure from max_lr, a fast drop through the middle, and a gentle landing at min_lr; in practice it tends to work at least as well as linear or step decay, which is why it became the common default.
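As a worked example with our config: the decay phase runs from iteration 100 to 5000, so its midpoint is iteration 2550, where the cosine factor is exactly one half and the LR sits halfway between max_lr and min_lr:

\[ \text{progress} = \frac{2550 - 100}{5000 - 100} = 0.5, \qquad \eta_{2550} = 3\times10^{-5} + \tfrac{1}{2}\left(1 + \cos\tfrac{\pi}{2}\right)\left(3\times10^{-4} - 3\times10^{-5}\right) = 1.65\times10^{-4} \]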

def get_lr(it):
    # 1) linear warmup: 0 -> max_lr over warmup_iters
    if it < warmup_iters:
        return max_lr * (it + 1) / warmup_iters
    # 2) past the end of the schedule: hold at min_lr
    if it >= max_iters:
        return min_lr
    # 3) cosine decay from max_lr -> min_lr in between
    progress = (it - warmup_iters) / (max_iters - warmup_iters)  # in [0, 1)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))          # 1 -> 0
    return min_lr + cosine * (max_lr - min_lr)

Ten lines, no torch.optim.lr_scheduler needed. PyTorch’s built-in schedulers are fine, but chaining LinearLR into CosineAnnealingLR via SequentialLR is more code and more opaque than writing the function you actually want. nanoGPT uses essentially the same hand-written function. Here’s what the schedule looks like over our 5000 iterations:

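[Figure: learning rate vs. iteration over 5000 iterations. Linear warmup from 0 to 3e-4 over the first 100 iterations, then cosine decay to min_lr = 3e-5 at iteration 5000.]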

To apply it, we set the LR on every parameter group at the top of each iteration — remember we have two groups now (decay and no-decay), so we loop:

lr = get_lr(iter_num)
for group in optimizer.param_groups:
    group["lr"] = lr

If you forget the loop and only set optimizer.param_groups[0]["lr"], your biases and LayerNorms silently keep training at the initial max_lr for the whole run: a classic quiet bug that is nearly impossible to spot from the loss curve alone.
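If you would rather let PyTorch do the per-group bookkeeping, one alternative (not what this course's script uses) is to wrap the same function in torch.optim.lr_scheduler.LambdaLR, which sets each group's LR to its initial LR times the returned factor, so every group is updated at once. This sketch assumes both groups start at lr=max_lr, uses two stand-in parameters instead of the real model, and checks that the scheduled LR matches get_lr at every iteration for both groups.

```python
import math
import torch

max_lr, min_lr = 3e-4, 3e-5
warmup_iters, max_iters = 100, 5000

def get_lr(it):
    if it < warmup_iters:                          # linear warmup
        return max_lr * (it + 1) / warmup_iters
    if it >= max_iters:                            # hold at the floor
        return min_lr
    progress = (it - warmup_iters) / (max_iters - warmup_iters)
    return min_lr + 0.5 * (1.0 + math.cos(math.pi * progress)) * (max_lr - min_lr)

# Stand-ins for the real model: one "decay" and one "no-decay" group, both at max_lr.
w = torch.nn.Parameter(torch.zeros(4, 4))
b = torch.nn.Parameter(torch.zeros(4))
w.grad, b.grad = torch.zeros_like(w), torch.zeros_like(b)   # zero grads -> zero updates
optimizer = torch.optim.AdamW(
    [{"params": [w], "weight_decay": 0.1}, {"params": [b], "weight_decay": 0.0}],
    lr=max_lr, betas=(0.9, 0.95),
)

# LambdaLR sets lr = initial_lr * factor(step) for every param group.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda it: get_lr(it) / max_lr)

mismatches = 0
for it in range(max_iters + 1):
    for group in optimizer.param_groups:
        if not math.isclose(group["lr"], get_lr(it), rel_tol=1e-9):
            mismatches += 1
    optimizer.step()      # would be the real update in a training loop
    scheduler.step()      # advance the schedule to it + 1
print("mismatches:", mismatches)
```

In a real loop you would call scheduler.step() once per iteration, right after optimizer.step(), and drop the manual for-group loop. The trade-off is the one described above: less bookkeeping, but the schedule now lives behind a scheduler object instead of in plain sight.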

Gradient clipping: the seatbelt

Even with warmup, an occasional batch produces a gradient far larger than usual — a rare character sequence, an unlucky dropout mask. One oversized step can undo hundreds of good ones (you’d see it as a sudden spike in the loss curve). Global-norm gradient clipping caps the total gradient size: compute the L2 norm over all parameters’ gradients concatenated, and if it exceeds a threshold, rescale every gradient by the same factor so the norm equals the threshold.

\[ g \leftarrow g \cdot \min\left(1, \frac{c}{\lVert g \rVert_2}\right) \]

Crucially this preserves the gradient’s direction — it’s the same step, just shorter. PyTorch has it built in, and it must run after backward() (grads exist) and before optimizer.step() (before they’re used):

loss.backward()
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # returns pre-clip norm
optimizer.step()

grad_clip = 1.0 is the near-universal default (GPT-3, LLaMA, and nanoGPT all use it). The returned norm is worth logging: in healthy training it typically starts around 1–3, drops, and hovers below the clip threshold, with occasional spikes that clipping absorbs. If the norm stays above 1.0 so that nearly every step gets clipped, that is a strong sign your learning rate is too high: clipping is a seatbelt, not a steering wheel.
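For intuition, here is the clipping rule written out by hand and checked against clip_grad_norm_. clip_by_global_norm is a hypothetical helper; the small eps guards against division by zero, and the toy gradients are chosen so the global norm is exactly 13.

```python
import torch

def clip_by_global_norm(params, max_norm, eps=1e-6):
    """Scale all grads by one shared factor so their combined L2 norm is <= max_norm."""
    grads = [p.grad for p in params if p.grad is not None]
    # Global norm = L2 norm of all gradients concatenated into one long vector.
    total_norm = torch.sqrt(sum((g.double() ** 2).sum() for g in grads))
    scale = min(1.0, max_norm / (total_norm.item() + eps))
    for g in grads:
        g.mul_(scale)                  # same factor for every tensor -> direction preserved
    return total_norm.item()           # pre-clip norm, like clip_grad_norm_

def make_params():
    a = torch.nn.Parameter(torch.zeros(3))
    b = torch.nn.Parameter(torch.zeros(1))
    a.grad = torch.tensor([3.0, 4.0, 0.0])
    b.grad = torch.tensor([12.0])      # global norm = sqrt(9 + 16 + 144) = 13
    return [a, b]

mine, ref = make_params(), make_params()
n_mine = clip_by_global_norm(mine, max_norm=1.0)
n_ref = torch.nn.utils.clip_grad_norm_(ref, max_norm=1.0).item()
print(n_mine, n_ref)                   # both report the pre-clip norm
print(mine[0].grad, ref[0].grad)       # both scaled by about 1/13
```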

Estimating validation loss without lying to yourself

The training loss printed each iteration is computed on one batch of 64 random windows — it’s an extremely noisy estimate, easily bouncing ±0.1 between iterations. Worse, with dropout active, it’s a pessimistic estimate of the model’s true ability. If you make decisions (“is it still improving? is it overfitting?”) off single-batch numbers, you’ll chase noise.

The fix is a dedicated estimator that (a) switches the model to eval mode so dropout is off, (b) disables gradient tracking so it’s fast and memory-light, and (c) averages over many batches so the estimate’s standard error shrinks like \(1/\sqrt{N}\) (its variance like \(1/N\)):

@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()                      # dropout OFF -> deterministic forward
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)   # Lesson 2's batcher
            _, loss = model(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()                     # back to training mode
    return out

Every line here has a failure mode if omitted. Drop @torch.no_grad() and each of the 400 forward passes builds an autograd graph and keeps activations alive for a backward pass that never happens, so evaluation is slower and needs as much activation memory as a training step for no benefit. Drop model.eval() and dropout stays on during evaluation, inflating both losses by a roughly constant amount and hiding the true train/val gap. Drop the final model.train() and the rest of training silently runs without dropout; the model overfits and you have no idea why. Averaging over eval_iters=200 batches makes our estimates stable to about ±0.005, which is precise enough to compare checkpoints honestly.

Note we estimate the train loss with the same procedure too, not just val. Comparing “clean” train loss vs val loss (both eval-mode, both 200-batch averages) is the only apples-to-apples way to measure the generalization gap.
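To check how stable your own estimates are, you can report a standard error next to the mean. mean_and_stderr is a hypothetical helper that would take the losses tensor inside estimate_loss; the four numbers below are made up purely to show the arithmetic.

```python
import math
import torch

def mean_and_stderr(losses):
    """Mean of per-batch losses and its standard error, std / sqrt(N)."""
    n = losses.numel()
    mean = losses.mean().item()
    stderr = losses.std().item() / math.sqrt(n)   # unbiased std across batches
    return mean, stderr

losses = torch.tensor([1.0, 2.0, 3.0, 4.0])       # made-up per-batch losses
mean, se = mean_and_stderr(losses)
print(f"{mean:.3f} +/- {se:.3f}")
```

With 200 batches, a per-batch spread \(\sigma\) gives a standard error of \(\sigma/\sqrt{200} \approx \sigma/14\), so the ±0.005 quoted above corresponds to a per-batch standard deviation of roughly 0.07.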

The training script

Now assemble everything. This is the complete loop — with the LR schedule, clipping, periodic evaluation, and checkpointing of the best model seen so far:

import time

best_val_loss = float("inf")
os.makedirs("checkpoints", exist_ok=True)
model.train()
t0 = time.time()

for iter_num in range(max_iters + 1):

    # --- periodic evaluation & checkpointing ---
    if iter_num % eval_interval == 0:
        losses = estimate_loss(model)
        dt = time.time() - t0
        print(f"iter {iter_num:5d} | train {losses['train']:.4f} "
              f"| val {losses['val']:.4f} | lr {get_lr(iter_num):.2e} | {dt:.0f}s")
        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            checkpoint = {
                "model":      model.state_dict(),
                "optimizer":  optimizer.state_dict(),
                "config":     config,
                "iter_num":   iter_num,
                "val_loss":   best_val_loss,
            }
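            torch.save(checkpoint, "checkpoints/best.pt")

    # the last pass (iter_num == max_iters) only evaluates; a step taken
    # after it would never be evaluated or saved, so stop here
    if iter_num == max_iters:
        break
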
    # --- set the LR for this iteration ---
    lr = get_lr(iter_num)
    for group in optimizer.param_groups:
        group["lr"] = lr

    # --- one optimization step ---
    x, y = get_batch("train")                 # (B, T), (B, T)
    logits, loss = model(x, y)                # logits: (B, T, vocab_size)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()

A few methodology notes on the loop body:

  • zero_grad(set_to_none=True) frees gradient tensors instead of filling them with zeros, which is slightly faster and lower-memory; it is also the default behavior in modern PyTorch. Gradients accumulate across backward() calls by design (a feature this lesson’s exercise relies on), so forgetting to zero them means every step applies the sum of all gradients ever computed. The loss doesn’t diverge immediately; it just… never gets good, which makes this bug maddening.
  • Evaluation happens before the step, at iterations 0, 250, 500, …, so the very first printout shows the untrained model — a useful sanity anchor (see below).
  • We checkpoint on best val loss, not last iteration. For a small corpus the model will eventually overfit; the best-val checkpoint is the model you actually want for Lesson 9’s generation. We save the optimizer state and iteration number too, so training can be resumed exactly — Adam’s moment estimates are part of the training state, and resuming without them causes a visible loss bump.
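Because the checkpoint holds the optimizer state and iteration number, resuming is mostly the reverse of saving. Below is a minimal sketch that uses a stand-in nn.Linear model instead of the real GPT and a temporary file instead of checkpoints/best.pt; the val_loss stored in it is a placeholder. One detail follows from the loop above: the checkpoint is written before the step at iter_num, so a resumed loop should start at iter_num itself (for iter_num in range(start_iter, max_iters + 1)) and restore best_val_loss so it does not overwrite a better checkpoint.

```python
import os
import tempfile
import torch
import torch.nn as nn

def make_model_and_opt():
    model = nn.Linear(8, 4)                       # stand-in for GPT(config)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95))
    return model, opt

torch.manual_seed(0)
model, optimizer = make_model_and_opt()
for _ in range(3):                                # a few steps so Adam has moment state
    loss = model(torch.randn(16, 8)).pow(2).mean()
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "best.pt")
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict(),
            "iter_num": 3, "val_loss": 1.23}, path)   # 1.23 is a placeholder value

# --- resume: rebuild, then restore weights, Adam moments, and counters ---
ckpt = torch.load(path, map_location="cpu", weights_only=False)
model2, optimizer2 = make_model_and_opt()
model2.load_state_dict(ckpt["model"])
optimizer2.load_state_dict(ckpt["optimizer"])
start_iter = ckpt["iter_num"]                     # the loop resumes here
best_val_loss = ckpt["val_loss"]                  # keeps "best so far" honest

state1 = optimizer.state_dict()["state"][0]["exp_avg"]
state2 = optimizer2.state_dict()["state"][0]["exp_avg"]
print(torch.equal(state1, state2), start_iter, best_val_loss)
```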

Here’s a real run of this exact configuration on Tiny Shakespeare (the evaluations at 750, 1250, and so on up to 4750 are omitted for space; times are from a single consumer GPU, and on CPU expect similar losses, much more slowly):

iter     0 | train 4.2861 | val 4.2846 | lr 3.00e-06 | 4s
iter   250 | train 2.4477 | val 2.4633 | lr 2.99e-04 | 42s
iter   500 | train 2.0410 | val 2.1064 | lr 2.96e-04 | 80s
iter  1000 | train 1.6295 | val 1.7896 | lr 2.78e-04 | 156s
iter  1500 | train 1.4692 | val 1.6612 | lr 2.49e-04 | 233s
iter  2000 | train 1.3818 | val 1.5994 | lr 2.12e-04 | 310s
iter  2500 | train 1.3155 | val 1.5566 | lr 1.69e-04 | 387s
iter  3000 | train 1.2637 | val 1.5288 | lr 1.27e-04 | 464s
iter  3500 | train 1.2210 | val 1.5066 | lr 8.78e-05 | 541s
iter  4000 | train 1.1885 | val 1.4931 | lr 5.68e-05 | 618s
iter  4500 | train 1.1650 | val 1.4879 | lr 3.69e-05 | 695s
iter  5000 | train 1.1511 | val 1.4863 | lr 3.00e-05 | 772s

Your numbers will differ by a few hundredths depending on hardware and seed, but the shape should match closely.

Reading the loss curve: what healthy looks like

Loss numbers are only useful if you know what they should look like. Three checkpoints on this curve carry real diagnostic meaning:

The starting loss is predictable. An untrained model’s logits are nearly equal, so its predicted distribution is approximately uniform over the vocabulary and its cross-entropy should be \(-\ln(1/65) = \ln 65 \approx 4.174\). Our iteration-0 print shows 4.286: close to, and slightly above, the theoretical value (random init isn’t perfectly uniform). Always do this arithmetic before training anything. If your starting loss is 10.4 when \(\ln(\text{vocab})\) says 4.17, something is broken before step one: wrong vocab size in the config, a busted loss reduction, mismatched targets. This one check catches an entire family of bugs for free.
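The expected starting loss is one line of arithmetic, and you can confirm it against cross_entropy directly: with all-equal logits, every character gets probability 1/65 whatever the targets are. The toy batch shape below is arbitrary.

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 65
expected = math.log(vocab_size)                    # -ln(1/65) = ln 65

B, T = 4, 8                                        # arbitrary toy batch
logits = torch.zeros(B * T, vocab_size)            # perfectly uniform predictions
targets = torch.randint(0, vocab_size, (B * T,))
loss = F.cross_entropy(logits, targets)

print(f"ln({vocab_size}) = {expected:.4f}, uniform-logit loss = {loss.item():.4f}")
```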

The early drop is steep, then it grinds. Val loss fell from 4.28 to 2.46 in the first 250 iterations, as the model picked up character frequencies and common bigrams, the cheap wins. Getting from 1.79 (iteration 1000) to 1.49 (iteration 5000) took four thousand more. That decelerating curve is normal and healthy; a curve that goes flat early usually means the LR is too low or the schedule is misconfigured.

The train/val gap is the overfitting dial. At iteration 500 the gap is 0.07; by iteration 5000 it’s 0.34 and growing. That’s expected on a 1MB corpus — the model is starting to memorize Shakespeare rather than model him. Healthy, unhealthy, and broken runs look like this:

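[Figure: loss vs. iteration. A horizontal line at ln(65) ≈ 4.17 marks the untrained baseline; healthy train and val curves fall toward ~1.5, with val plateauing; a val curve that turns upward marks overfitting; a curve that shoots upward marks divergence from an LR that is too high or missing warmup.]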

The decision rules, in order of what you’ll actually encounter:

| Symptom | Diagnosis | Fix |
|---|---|---|
| Loss spikes early, then plateaus high | LR too high or no warmup | Lower max_lr; check that warmup actually runs |
| Train and val both flat near start | LR too low, or grads not flowing | Raise LR; check that loss.backward() actually reaches all params |
| Train falls, val rises | Overfitting | More dropout, more weight decay, stop earlier (you already checkpoint the best val, so nothing is lost) |
| Train/val identical to 3 decimals forever | Val split leaks into train | Recheck Lesson 2’s split |
| Start loss ≠ ln(vocab_size) | Broken before training | Fix config/loss wiring first |
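For the “grads not flowing” row, one quick check after a single loss.backward() is to list the parameters whose gradient is missing or all zeros. params_without_grad is a hypothetical helper, shown here on a toy module where one layer is deliberately left out of the forward pass.

```python
import torch
import torch.nn as nn

def params_without_grad(model):
    """Names of trainable params whose .grad is None or exactly zero after backward()."""
    bad = []
    for name, p in model.named_parameters():
        if p.requires_grad and (p.grad is None or not p.grad.any()):
            bad.append(name)
    return bad

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)   # never called in forward: a wiring bug

    def forward(self, x):
        return self.used(x)

torch.manual_seed(0)
model = Toy()
model(torch.randn(2, 4)).sum().backward()
print(params_without_grad(model))
```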

Since we saved the best checkpoint (val 1.4863 at iteration 5000 here — the schedule’s landing was timed well), loading it back for Lesson 9 is symmetric:

checkpoint = torch.load("checkpoints/best.pt", map_location=device, weights_only=False)
model = GPT(checkpoint["config"]).to(device)
model.load_state_dict(checkpoint["model"])
model.eval()
print(f"loaded iter {checkpoint['iter_num']}, val loss {checkpoint['val_loss']:.4f}")

(weights_only=False because our checkpoint contains a GPTConfig object, not just tensors — fine for checkpoints you created, never for ones downloaded from the internet.)

One number to carry into the next lesson: val loss 1.49 in nats per character means the model assigns the true next character a geometric-mean probability of \(e^{-1.49} \approx 0.23\), nearly one-in-four on a 65-way choice, against a 1-in-65 random baseline. That’s already enough structure to generate text that scans like Shakespeare, which is exactly what we’ll do in Lesson 9.
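Because the loss is an average of \(-\ln p\) over characters, \(e^{-\text{loss}}\) is the geometric mean of the probabilities given to the correct characters, and \(e^{\text{loss}}\) is the perplexity. A two-character toy example with made-up probabilities shows how the geometric mean differs from a plain average:

```python
import math

probs = [0.5, 0.125]                                  # made-up p(correct char)
nll = sum(-math.log(p) for p in probs) / len(probs)   # what cross-entropy reports

geo_mean = math.exp(-nll)                             # equals (0.5 * 0.125) ** 0.5
arith_mean = sum(probs) / len(probs)
print(f"loss {nll:.4f} -> geometric mean {geo_mean:.4f}, arithmetic mean {arith_mean:.4f}")

val_loss = 1.49                                       # the run above
print(f"e^-{val_loss} = {math.exp(-val_loss):.3f}, perplexity = {math.exp(val_loss):.2f}")
```

A perplexity of about 4.4 means the model is, on average, about as uncertain as a uniform pick among 4.4 characters, versus 65 for the untrained model.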

🧪 Your task

Your GPU (or patience) may not fit batch_size=64 at once. Implement gradient accumulation: modify the training loop so each optimizer step is computed from accum_steps = 4 micro-batches of size 16, producing gradients mathematically equivalent to one batch of 64. The loss curve of your modified script should match the original run closely (not exactly — batch composition differs).

Hint: gradients accumulate across backward() calls automatically as long as you don’t call zero_grad() between them. The subtlety is scaling: each micro-batch’s loss is already a mean over its own 16×256 tokens, so summing four unscaled backward() calls gives you 4× the gradient you want. Divide each micro-loss by accum_steps before calling backward(). Clip once, after all four.

Solution
batch_size  = 16    # micro-batch
accum_steps = 4     # effective batch = 16 * 4 = 64

for iter_num in range(max_iters + 1):

    if iter_num % eval_interval == 0:
        losses = estimate_loss(model)
        print(f"iter {iter_num:5d} | train {losses['train']:.4f} "
              f"| val {losses['val']:.4f}")
        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            torch.save({
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "config": config,
                "iter_num": iter_num,
                "val_loss": best_val_loss,
            }, "checkpoints/best.pt")

    lr = get_lr(iter_num)
    for group in optimizer.param_groups:
        group["lr"] = lr

    optimizer.zero_grad(set_to_none=True)
    for micro in range(accum_steps):
        x, y = get_batch("train")            # (16, 256)
        _, loss = model(x, y)
        # loss is a mean over this micro-batch's tokens;
        # scale so the SUM over micro-batches equals the mean
        # over the full effective batch
        (loss / accum_steps).backward()      # grads ACCUMULATE across calls

    # clip ONCE, on the fully accumulated gradient
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()

Why this is equivalent: the cross-entropy averaged over the full batch of 64 sequences splits as \(\frac{1}{4}\sum_{i=1}^{4} L_i\), where each \(L_i\) is the mean over micro-batch \(i\) (all micro-batches have the same token count, so the means combine exactly). Gradients are linear, so accumulating \(\nabla(L_i/4)\) over four backward() calls yields precisely \(\nabla\) of the full-batch mean loss. Two classic mistakes: forgetting the / accum_steps, and clipping inside the micro-batch loop. The first makes the accumulated gradient 4× too large; Adam’s per-parameter normalization largely cancels a constant scale, but the norm you clip against is 4× too big, so clipping fires far more often than intended (with plain SGD it would act like a 4× LR increase). The second clips partial gradients against the full threshold, which changes the math. Together with data parallelism across many GPUs, this is how models with “batch size 4M tokens” are trained when each GPU holds only a small slice of that batch at a time.
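The equivalence can be checked numerically. This sketch uses a stand-in linear classifier in float64 (no dropout, so both paths compute identical functions), compares the full-batch gradient of a 64-example batch with four accumulated micro-batches of 16, and also shows what happens without the 1/accum_steps scaling.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(8, 65).double()                  # stand-in for GPT, no dropout
x = torch.randn(64, 8, dtype=torch.float64)
y = torch.randint(0, 65, (64,))

def grads():
    return [p.grad.clone() for p in model.parameters()]

# (a) one full batch of 64
model.zero_grad(set_to_none=True)
F.cross_entropy(model(x), y).backward()
full = grads()

# (b) four micro-batches of 16, each loss scaled by 1/accum_steps
accum_steps = 4
model.zero_grad(set_to_none=True)
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    (F.cross_entropy(model(xb), yb) / accum_steps).backward()
accum = grads()

# (c) the same loop, but forgetting the scaling
model.zero_grad(set_to_none=True)
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    F.cross_entropy(model(xb), yb).backward()
unscaled = grads()

print(all(torch.allclose(a, b) for a, b in zip(full, accum)))         # True
print(all(torch.allclose(4 * a, b) for a, b in zip(full, unscaled)))  # True: 4x too big
```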

Key takeaways

  • AdamW for GPTs: betas=(0.9, 0.95) for faster adaptation to gradient-scale shifts, and weight decay only on dim >= 2 parameters — never on biases or LayerNorms. The shape test beats name matching.
  • LR warmup exists because Adam’s moment estimates are garbage for the first ~100 steps; cosine decay exists because a settling model needs shrinking steps. Ten lines of math.cos, no scheduler class needed.
  • Set the LR on every param group, every iteration — you have two groups now.
  • clip_grad_norm_ between backward() and step() preserves direction, caps magnitude; log the returned norm, and if it’s always above the threshold your LR is probably too high.
  • Never trust single-batch loss: estimate with model.eval() + @torch.no_grad() + a 200-batch average, and always restore model.train().
  • Sanity-check iteration 0 against \(\ln(\text{vocab size})\) — it catches broken wiring before you waste a run.
  • Checkpoint best-val (with optimizer state and iteration), not last-iteration; on small data the best model is rarely the final one.

In the next lesson: the model can predict the next character — Lesson 9 turns that into actual text generation, from greedy decoding through temperature, top-k, and top-p sampling, and why each knob changes the Shakespeare that comes out.



© Kader Mohideen