flowchart TD
A["get_batch('train')<br/>x, y : (B, T)"] --> B["set LR for this step<br/>warmup + cosine"]
B --> C["forward: logits, loss = model(x, y)"]
C --> D["optimizer.zero_grad()<br/>loss.backward()"]
D --> E["clip_grad_norm_(params, 1.0)"]
E --> F["optimizer.step()"]
F --> G{"iter % eval_interval == 0?"}
G -- no --> A
G -- yes --> H["estimate_loss()<br/>mean over 200 val batches"]
H --> I{"best val loss so far?"}
I -- yes --> J["save checkpoint<br/>model + optimizer + iter"]
I -- no --> A
J --> A
⚡ Building Transformers from Scratch with PyTorch · Day 8 — Training the Tiny GPT
Day 8 — Training the Tiny GPT
Yesterday you assembled the full model: token embeddings, positional embeddings, a stack of transformer blocks, a final LayerNorm, and the language-modeling head, all wired into a single GPT class that maps (B, T) token indices to (B, T, vocab_size) logits and a cross-entropy loss. It compiles, it runs a forward pass, and right now it produces pure noise — its weights are random. Today we fix that. We write the training loop that turns those random weights into a model that actually speaks (a caricature of) Shakespeare: a properly configured AdamW optimizer with weight-decay grouping, a learning-rate schedule with linear warmup and cosine decay written from scratch, gradient clipping, periodic validation-loss estimation that doesn’t lie to you, and checkpointing so a crash doesn’t cost you an afternoon. None of this is glamorous, and all of it is where real transformer projects live or die. The architecture from Days 4–7 is maybe 30% of what makes GPT work; the training recipe is the rest.
🎯 Today you will: configure AdamW with correct betas and weight-decay parameter grouping, implement warmup + cosine LR decay from scratch, add gradient clipping, write a full training script with low-variance val-loss estimation and checkpointing, and train your tiny GPT on Tiny Shakespeare down to a val loss around 1.5
The training recipe at a glance
Before writing code, look at the shape of the whole loop. Every modern LLM training run — from our 10M-parameter toy to a frontier model — is structurally this diagram. The differences are scale, data, and the sophistication of each box, not the boxes themselves.
One pass through the top of this loop — batch, forward, backward, clip, step — is an iteration. We won’t think in epochs at all: get_batch (from Day 2) samples random windows from the corpus, so there’s no natural notion of “one pass over the data.” We just count iterations and watch the loss. This is exactly how the big labs think about it too — they count tokens seen, not epochs.
Let’s set the stage with the configuration we’ll train. This assumes the GPT and GPTConfig classes from Day 7 and the get_batch/encode/decode machinery from Day 2 are importable or defined above.
import math
import os
import torch
torch.manual_seed(1337)
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
# --- model config (Day 7's GPTConfig) ---
config = GPTConfig(
    vocab_size=65,    # char-level Tiny Shakespeare
    block_size=256,   # context length
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.2,      # small dataset -> we need regularization
)
# --- training config ---
batch_size = 64
max_iters = 5000
eval_interval = 250 # how often we estimate val loss
eval_iters = 200 # how many batches per estimate
max_lr = 3e-4
min_lr = 3e-5 # max_lr / 10, a standard choice
warmup_iters = 100
weight_decay = 0.1
grad_clip = 1.0
model = GPT(config).to(device)
print(f"{sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters")
10.79M parameters
Two config choices deserve a word. dropout=0.2 is unusually high by modern-LLM standards: big models train on so much data that they rarely see a token twice, and they use dropout 0.0. We have ~1MB of Shakespeare and will revisit it many times over (5000 iterations × 64 windows × 256 tokens is about 82M tokens drawn from roughly 1M characters, so each character is seen on the order of 80 times). Overfitting is therefore our main enemy and dropout is our main defense. And max_lr=3e-4 is the classic “safe” Adam learning rate for a model this size; we’ll see below why we don’t just set it and forget it.
AdamW, configured properly
You could write torch.optim.AdamW(model.parameters(), lr=3e-4) and it would work. But GPT training has two conventions that matter, and both are one line away from correct.
First, the betas. Adam keeps two exponential moving averages per parameter: the mean of gradients (first moment, decayed by \(\beta_1\)) and the mean of squared gradients (second moment, decayed by \(\beta_2\)). The update is roughly
\[ \theta_{t+1} = \theta_t - \eta \cdot \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} \]
PyTorch’s default is betas=(0.9, 0.999). GPT-style training uses (0.9, 0.95). Why lower \(\beta_2\)? Language-model gradients are noisy and occasionally spiky — a rare token, a weird batch. With \(\beta_2 = 0.999\) the second-moment estimate has an effective memory of ~1000 steps, so a sudden increase in gradient scale takes ages to show up in \(\hat v_t\), and until it does, the denominator is too small and the update is too large. \(\beta_2 = 0.95\) (memory ~20 steps) adapts faster and makes training visibly more stable at scale. For our tiny model both work; we use the GPT convention because the whole point of this course is to build the real recipe.
Second, weight-decay grouping. AdamW’s decoupled weight decay pulls every parameter toward zero each step (that’s the regularizer). But it only makes sense for parameters that act as matrices multiplying activations: the attention projections, the MLP weights, the embedding table. It actively hurts for biases and LayerNorm gains. Those are 1-D parameters whose job is to shift and rescale; a LayerNorm weight decayed toward zero would squash activations toward nothing for no regularization benefit. The standard convention: decay everything with dim >= 2, don’t decay anything with dim < 2.
def configure_optimizer(model, weight_decay, lr, betas, device):
    # all trainable params
    params = {name: p for name, p in model.named_parameters() if p.requires_grad}
    # ndim >= 2: weight matrices & embeddings -> decay
    # ndim < 2: biases & LayerNorm params -> no decay
    decay_params = [p for p in params.values() if p.dim() >= 2]
    nodecay_params = [p for p in params.values() if p.dim() < 2]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    n_decay = sum(p.numel() for p in decay_params)
    n_nodecay = sum(p.numel() for p in nodecay_params)
    print(f"decayed: {len(decay_params)} tensors, {n_decay:,} params")
    print(f"non-decayed: {len(nodecay_params)} tensors, {n_nodecay:,} params")
    # fused AdamW is one CUDA kernel for the whole update: faster, same math
    use_fused = device == "cuda"
    return torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=use_fused)

optimizer = configure_optimizer(
    model, weight_decay, max_lr, betas=(0.9, 0.95), device=device
)
decayed: 26 tensors, 10,768,896 params
non-decayed: 51 tensors, 24,192 params
Read those numbers: 99.8% of parameters are in 2-D matrices and get decayed; the long tail of 51 tiny 1-D tensors (every bias, every LayerNorm weight) is exempt. The dim >= 2 test is a beautifully lazy heuristic — no name-matching against "ln" or "bias" strings, which breaks the moment someone renames a module. Shape tells you everything.
Note we passed lr=max_lr to the optimizer, but that’s just a placeholder — the schedule below will overwrite it every single step.
The learning-rate schedule: warmup + cosine, from scratch
A constant learning rate is wrong at both ends of training. At the start, Adam’s moment estimates \(\hat m_t, \hat v_t\) are built from a handful of samples and are garbage; taking full-size steps on garbage statistics can blow the model into a bad region it never recovers from (you’ll see the loss spike to 8+ and plateau). At the end, the model is fine-tuning itself into a narrow minimum, and full-size steps just bounce it around the basin instead of settling in.
The fix used by essentially every transformer since GPT-2/GPT-3: linear warmup from 0 to max_lr over the first few hundred steps, then cosine decay from max_lr down to min_lr over the rest of training. During decay, the LR at iteration \(t\) is
\[ \eta_t = \eta_{\min} + \tfrac{1}{2}\left(1 + \cos\left(\pi \cdot \text{progress}\right)\right)(\eta_{\max} - \eta_{\min}) \]
where progress runs from 0 to 1 between the end of warmup and max_iters. The cosine gives you a slow leave from max_lr, a fast drop through the middle, and a gentle landing at min_lr — empirically nicer than linear or step decay.
def get_lr(it):
    # 1) linear warmup: 0 -> max_lr over warmup_iters
    if it < warmup_iters:
        return max_lr * (it + 1) / warmup_iters
    # 2) past the end of the schedule: hold at min_lr
    if it >= max_iters:
        return min_lr
    # 3) cosine decay from max_lr -> min_lr in between
    progress = (it - warmup_iters) / (max_iters - warmup_iters)  # in [0, 1)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))          # 1 -> 0
    return min_lr + cosine * (max_lr - min_lr)
Ten lines, no torch.optim.lr_scheduler needed. PyTorch’s built-in schedulers are fine, but chaining LinearLR into CosineAnnealingLR via SequentialLR is more code and more opaque than writing the function you actually want. This is also essentially how nanoGPT does it. Over our 5000 iterations the schedule ramps linearly up to max_lr during the first 100 steps, then follows a half-cosine down to min_lr at iteration 5000.
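To see the shape of the schedule as numbers, the sketch below re-implements the same three branches as a standalone function and prints the learning rate at a few iterations. The name lr_at and its keyword defaults are assumptions made so the snippet runs on its own; the values mirror the training config above.

```python
import math

def lr_at(it, max_lr=3e-4, min_lr=3e-5, warmup_iters=100, max_iters=5000):
    """Warmup + cosine schedule with the same three branches as get_lr."""
    if it < warmup_iters:                      # linear warmup
        return max_lr * (it + 1) / warmup_iters
    if it >= max_iters:                        # past the end: hold at min_lr
        return min_lr
    progress = (it - warmup_iters) / (max_iters - warmup_iters)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine * (max_lr - min_lr)

# Sample the warmup, the peak, the cosine midpoint, and the landing.
for it in (0, 50, 99, 100, 1000, 2550, 4000, 5000):
    print(f"iter {it:5d} -> lr {lr_at(it):.2e}")
```

Iteration 2550 is the midpoint of the decay phase (progress = 0.5), so the learning rate there sits halfway between max_lr and min_lr.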
To apply it, we set the LR on every parameter group at the top of each iteration — remember we have two groups now (decay and no-decay), so we loop:
lr = get_lr(iter_num)
for group in optimizer.param_groups:
    group["lr"] = lr
If you forget the loop and only set optimizer.param_groups[0]["lr"], your biases and LayerNorms silently keep training at the learning rate the optimizer was built with, frozen at max_lr. It is a classic quiet bug that costs you a few percent and is nearly impossible to spot from the loss curve alone.
Gradient clipping: the seatbelt
Even with warmup, an occasional batch produces a gradient far larger than usual — a rare character sequence, an unlucky dropout mask. One oversized step can undo hundreds of good ones (you’d see it as a sudden spike in the loss curve). Global-norm gradient clipping caps the total gradient size: compute the L2 norm over all parameters’ gradients concatenated, and if it exceeds a threshold, rescale every gradient by the same factor so the norm equals the threshold.
\[ g \leftarrow g \cdot \min\left(1, \frac{c}{\lVert g \rVert_2}\right) \]
Crucially this preserves the gradient’s direction — it’s the same step, just shorter. PyTorch has it built in, and it must run after backward() (grads exist) and before optimizer.step() (before they’re used):
loss.backward()
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # returns pre-clip norm
optimizer.step()
grad_clip = 1.0 is the near-universal default (GPT-3, LLaMA, nanoGPT all use it). The returned norm is worth logging: in healthy training it starts around 1–3, drops, and hovers well below the clip threshold, with occasional spikes that clipping absorbs. If the norm is persistently above 1.0 so that every step gets clipped, your learning rate is too high. Clipping is a seatbelt, not a steering wheel.
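The clipping rule itself is only a few lines. Here is a minimal pure-Python sketch of global-norm clipping on nested lists of floats standing in for parameter gradients. The function clip_global_norm is a hypothetical illustration, not PyTorch’s implementation (which, among other things, adds a tiny epsilon to the denominator).

```python
import math

def clip_global_norm(grads, max_norm):
    """Rescale all gradients by one shared factor so the global L2 norm <= max_norm.

    `grads` is a list of lists of floats, one inner list per "parameter tensor".
    Returns (pre_clip_norm, clipped_grads); the input is left unmodified.
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / total_norm) if total_norm > 0 else 1.0
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return total_norm, clipped

big = [[3.0, 0.0], [4.0]]            # global norm = sqrt(9 + 16) = 5
norm_big, clipped = clip_global_norm(big, max_norm=1.0)
print(norm_big, clipped)

small = [[0.3], [0.4]]               # global norm = 0.5, below the threshold
norm_small, untouched = clip_global_norm(small, max_norm=1.0)
print(norm_small, untouched)
```

Because every entry is multiplied by the same factor, the ratio between any two gradient components is unchanged. That is the sense in which clipping preserves direction.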
Estimating validation loss without lying to yourself
The training loss printed each iteration is computed on one batch of 64 random windows — it’s an extremely noisy estimate, easily bouncing ±0.1 between iterations. Worse, with dropout active, it’s a pessimistic estimate of the model’s true ability. If you make decisions (“is it still improving? is it overfitting?”) off single-batch numbers, you’ll chase noise.
The fix is a dedicated estimator that (a) switches the model to eval mode so dropout is off, (b) disables gradient tracking so it’s fast and memory-light, and (c) averages over many batches so the standard error of the estimate shrinks as \(1/\sqrt{N}\):
@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()                     # dropout OFF -> deterministic forward
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)  # Day 2's batcher
            _, loss = model(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()                    # back to training mode
    return out
Every line here has a failure mode if omitted. Drop @torch.no_grad() and PyTorch records an autograd graph and holds activations for each of the 400 forward passes (200 per split), spending memory and time on a backward pass that never happens. Drop model.eval() and dropout stays on during evaluation, inflating both losses by a roughly constant amount and hiding the true train/val gap. Drop the final model.train() and the rest of training silently runs without dropout: the model overfits and you have no idea why. Averaging over eval_iters=200 batches means our estimates are stable to about ±0.005, which is precise enough to compare checkpoints honestly.
Note we estimate the train loss with the same procedure too, not just val. Comparing “clean” train loss vs val loss (both eval-mode, both 200-batch averages) is the only apples-to-apples way to measure the generalization gap.
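The benefit of averaging can be checked with a small simulation. The sketch below models each single-batch loss as a true value plus Gaussian noise, then compares the spread of single-batch readings against the spread of 200-batch averages. The true loss of 1.50 and the noise level of 0.10 are made-up numbers for illustration, and real batch losses are not exactly Gaussian.

```python
import random
import statistics

random.seed(0)

TRUE_LOSS = 1.50     # hypothetical "true" loss of the model
BATCH_NOISE = 0.10   # hypothetical std of a single-batch loss reading

def single_batch_loss():
    return random.gauss(TRUE_LOSS, BATCH_NOISE)

def averaged_loss(n_batches):
    return sum(single_batch_loss() for _ in range(n_batches)) / n_batches

# Repeat each measurement 500 times to see how much it wobbles.
singles = [single_batch_loss() for _ in range(500)]
averages = [averaged_loss(200) for _ in range(500)]

spread_single = statistics.stdev(singles)
spread_avg = statistics.stdev(averages)
print(f"single-batch spread: {spread_single:.4f}")
print(f"200-batch spread:    {spread_avg:.4f}")
print(f"ratio: {spread_single / spread_avg:.1f} (theory: sqrt(200) = {200 ** 0.5:.1f})")
```

The ratio lands near \(\sqrt{200} \approx 14\), which is why a 200-batch average is precise enough to compare checkpoints while a single batch is not.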
The training script
Now assemble everything. This is the complete loop — with the LR schedule, clipping, periodic evaluation, and checkpointing of the best model seen so far:
import time
best_val_loss = float("inf")
os.makedirs("checkpoints", exist_ok=True)
model.train()
t0 = time.time()
for iter_num in range(max_iters + 1):
    # --- periodic evaluation & checkpointing ---
    if iter_num % eval_interval == 0:
        losses = estimate_loss(model)
        dt = time.time() - t0
        print(f"iter {iter_num:5d} | train {losses['train']:.4f} "
              f"| val {losses['val']:.4f} | lr {get_lr(iter_num):.2e} | {dt:.0f}s")
        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "config": config,
                "iter_num": iter_num,
                "val_loss": best_val_loss,
            }
            torch.save(checkpoint, "checkpoints/best.pt")
    # --- set the LR for this iteration ---
    lr = get_lr(iter_num)
    for group in optimizer.param_groups:
        group["lr"] = lr
    # --- one optimization step ---
    x, y = get_batch("train")            # (B, T), (B, T)
    logits, loss = model(x, y)           # logits: (B, T, vocab_size)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
A few methodology notes on the loop body:
- zero_grad(set_to_none=True) frees gradient tensors instead of filling them with zeros: slightly faster and lower-memory, and the default behavior in modern PyTorch. Gradients accumulate across backward() calls by design (that’s a feature; today’s exercise uses it), so forgetting to zero them means every step applies the sum of all gradients ever computed. The loss doesn’t diverge immediately, it just… never gets good, which makes this bug maddening.
- Evaluation happens before the step, at iterations 0, 250, 500, …, so the very first printout shows the untrained model, a useful sanity anchor (see below).
- We checkpoint on best val loss, not last iteration. For a small corpus the model will eventually overfit; the best-val checkpoint is the model you actually want for Day 9’s generation. We save the optimizer state and iteration number too, so training can be resumed exactly — Adam’s moment estimates are part of the training state, and resuming without them causes a visible loss bump.
Here’s a real run of this exact configuration on Tiny Shakespeare (times from a single consumer GPU; on CPU expect the same losses, much slower):
iter     0 | train 4.2861 | val 4.2846 | lr 3.00e-06 | 4s
iter   250 | train 2.4477 | val 2.4633 | lr 2.99e-04 | 42s
iter   500 | train 2.0410 | val 2.1064 | lr 2.96e-04 | 80s
iter  1000 | train 1.6295 | val 1.7896 | lr 2.78e-04 | 156s
iter  1500 | train 1.4692 | val 1.6612 | lr 2.49e-04 | 233s
iter  2000 | train 1.3818 | val 1.5994 | lr 2.12e-04 | 310s
iter  2500 | train 1.3155 | val 1.5566 | lr 1.69e-04 | 387s
iter  3000 | train 1.2637 | val 1.5288 | lr 1.27e-04 | 464s
iter  3500 | train 1.2210 | val 1.5066 | lr 8.78e-05 | 541s
iter  4000 | train 1.1885 | val 1.4931 | lr 5.68e-05 | 618s
iter  4500 | train 1.1650 | val 1.4879 | lr 3.69e-05 | 695s
iter  5000 | train 1.1511 | val 1.4863 | lr 3.00e-05 | 772s
Your numbers will differ by a few hundredths depending on hardware and seed, but the shape should match closely.
Reading the loss curve: what healthy looks like
Loss numbers are only useful if you know what they should look like. Three checkpoints on this curve carry real diagnostic meaning:
The starting loss is predictable. An untrained model outputs (approximately) uniform logits over the vocabulary, so its cross-entropy should be \(-\ln(1/65) = \ln 65 \approx 4.174\). Our iteration-0 print shows 4.286 — close to, and slightly above, the theoretical value (random init isn’t perfectly uniform). Always do this arithmetic before training anything. If your starting loss is 10.4 when \(\ln(\text{vocab})\) says 4.17, something is broken before step one — wrong vocab size in the config, a busted loss reduction, mismatched targets. This one check catches an entire family of bugs for free.
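The starting-loss check is easy to automate. The sketch below computes the cross-entropy of a perfectly uniform prediction directly from equal logits and wraps the comparison in a small helper. The helper start_loss_ok and its tolerance of 0.3 nats are assumptions chosen for illustration, not a standard threshold.

```python
import math

def uniform_cross_entropy(vocab_size):
    """Cross-entropy in nats when every logit is equal (a uniform prediction)."""
    logits = [0.0] * vocab_size
    log_norm = math.log(sum(math.exp(z) for z in logits))  # log-sum-exp
    target = 0                                             # any index gives the same loss
    return -(logits[target] - log_norm)

def start_loss_ok(observed, vocab_size, tol=0.3):
    """True if the iteration-0 loss is within `tol` nats of ln(vocab_size)."""
    return abs(observed - math.log(vocab_size)) <= tol

print(f"expected start loss: {uniform_cross_entropy(65):.3f}")
print("4.2861 plausible?", start_loss_ok(4.2861, 65))  # iteration-0 train loss from the run
print("10.4 plausible?  ", start_loss_ok(10.4, 65))    # the broken example from the text
```

Running a check like this before the first optimizer step costs nothing and flags wiring errors that would otherwise surface only after a wasted run.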
The early drop is steep, then it grinds. Train loss went from 4.29 to 2.45 in the first 250 iterations: the model learned character frequencies and common bigrams, the cheap wins. Going from 1.63 (iteration 1000) to 1.15 took four thousand more. That decelerating curve is normal and healthy; a curve that goes flat early usually means the LR is too low or the schedule is misconfigured.
The train/val gap is the overfitting dial. At iteration 500 the gap is 0.07; by iteration 5000 it’s 0.34 and growing. That’s expected on a 1MB corpus: the model is starting to memorize Shakespeare rather than model him.
[Figure: loss curves of healthy, unhealthy, and broken runs]
The decision rules, in order of what you’ll actually encounter:
| Symptom | Diagnosis | Fix |
|---|---|---|
| Loss spikes early then plateaus high | LR too high or no warmup | Lower max_lr, check warmup runs |
| Train and val both flat near start | LR too low, or grads not flowing | Raise LR; check loss.backward() actually reached all params |
| Train falls, val rises | Overfitting | More dropout, more weight decay, stop earlier — you already checkpoint the best val, so nothing is lost |
| Train/val identical to 3 decimals forever | Val split leaks into train | Recheck Day 2’s split |
| Start loss ≠ ln(vocab_size) | Broken before training | Fix config/loss wiring first |
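The gap figures quoted earlier in this section can be recomputed from the logged losses. The sketch below copies the train and val columns from the run log, computes the gap at each evaluation, and checks whether val loss is still improving, which is the signal that separates “widening gap” from “overfitting that should stop training”.

```python
# (iteration, train loss, val loss) copied from the run log above
log = [
    (0, 4.2861, 4.2846), (250, 2.4477, 2.4633), (500, 2.0410, 2.1064),
    (1000, 1.6295, 1.7896), (1500, 1.4692, 1.6612), (2000, 1.3818, 1.5994),
    (2500, 1.3155, 1.5566), (3000, 1.2637, 1.5288), (3500, 1.2210, 1.5066),
    (4000, 1.1885, 1.4931), (4500, 1.1650, 1.4879), (5000, 1.1511, 1.4863),
]

# Generalization gap at every evaluation point.
gaps = {it: val - train for it, train, val in log}

# Best checkpoint = lowest val loss; also check that val never went back up.
best_iter, _, best_val = min(log, key=lambda row: row[2])
val_still_falling = all(later[2] < earlier[2] for earlier, later in zip(log, log[1:]))

for it, train, val in log:
    print(f"iter {it:5d} | gap {gaps[it]:+.4f}")
print(f"best val {best_val:.4f} at iter {best_iter}; val still falling: {val_still_falling}")
```

In this run the gap widens steadily while val loss keeps falling, so the “train falls, val rises” row of the table has not been reached yet.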
Since we saved the best checkpoint (val 1.4863 at iteration 5000 here — the schedule’s landing was timed well), loading it back for Day 9 is symmetric:
checkpoint = torch.load("checkpoints/best.pt", map_location=device, weights_only=False)
model = GPT(checkpoint["config"]).to(device)
model.load_state_dict(checkpoint["model"])
model.eval()
print(f"loaded iter {checkpoint['iter_num']}, val loss {checkpoint['val_loss']:.4f}")
(weights_only=False because our checkpoint contains a GPTConfig object, not just tensors. That is fine for checkpoints you created, never for ones downloaded from the internet.)
One number to carry into tomorrow: val loss 1.49 in nats-per-character means the model assigns the true next character a (geometric) average probability of \(e^{-1.49} \approx 0.23\), nearly one-in-four on a 65-way choice, against a 1-in-65 random baseline. That’s already enough structure to generate text that scans like Shakespeare, which is exactly what we’ll do on Day 9.
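The conversion from loss to probability is one line of arithmetic. This short sketch turns a loss in nats into the per-character probability and the equivalent perplexity, then compares both against the uniform baseline. Because the loss is a mean of negative log-probabilities, exponentiating it gives a geometric mean. The function name per_char_probability is an illustrative assumption.

```python
import math

def per_char_probability(loss_nats):
    """Geometric-mean probability assigned to the correct next character."""
    return math.exp(-loss_nats)

val_loss = 1.49      # final val loss from the run, rounded
vocab_size = 65

p_model = per_char_probability(val_loss)
p_random = 1.0 / vocab_size

print(f"model:  p = {p_model:.3f}  (perplexity {math.exp(val_loss):.2f})")
print(f"random: p = {p_random:.3f}  (perplexity {vocab_size})")
print(f"improvement over random: {p_model / p_random:.1f}x")
```

Perplexity is the same information read the other way around: the model is about as uncertain as a uniform choice among exp(loss) characters.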
🧪 Your task
Your GPU (or patience) may not fit batch_size=64 at once. Implement gradient accumulation: modify the training loop so each optimizer step is computed from accum_steps = 4 micro-batches of size 16, producing gradients mathematically equivalent to one batch of 64. The loss curve of your modified script should match the original run closely (not exactly — batch composition differs).
Hint: gradients accumulate across backward() calls automatically as long as you don’t call zero_grad() between them. The subtlety is scaling: each micro-batch’s loss is already a mean over its own 16×256 tokens, so summing four unscaled backward() calls gives you 4× the gradient you want. Divide each micro-loss by accum_steps before calling backward(). Clip once, after all four.
Solution
batch_size = 16    # micro-batch
accum_steps = 4    # effective batch = 16 * 4 = 64
for iter_num in range(max_iters + 1):
    if iter_num % eval_interval == 0:
        losses = estimate_loss(model)
        print(f"iter {iter_num:5d} | train {losses['train']:.4f} "
              f"| val {losses['val']:.4f}")
        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            torch.save({
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "config": config,
                "iter_num": iter_num,
                "val_loss": best_val_loss,
            }, "checkpoints/best.pt")
    lr = get_lr(iter_num)
    for group in optimizer.param_groups:
        group["lr"] = lr
    optimizer.zero_grad(set_to_none=True)
    for micro in range(accum_steps):
        x, y = get_batch("train")          # (16, 256)
        _, loss = model(x, y)
        # loss is a mean over this micro-batch's tokens;
        # scale so the SUM over micro-batches equals the mean
        # over the full effective batch
        (loss / accum_steps).backward()    # grads ACCUMULATE across calls
    # clip ONCE, on the fully accumulated gradient
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
Why this is equivalent: cross-entropy averaged over batch \(B\) splits as \(\frac{1}{4}\sum_{i=1}^{4} L_i\) where each \(L_i\) is the mean over micro-batch \(i\) (all micro-batches have the same token count, so the means combine exactly). Gradients are linear, so accumulating \(\nabla(L_i/4)\) over four backward() calls yields precisely \(\nabla\) of the full-batch mean loss. Two classic mistakes: forgetting the / accum_steps (gradients 4× too large; Adam’s normalization absorbs part of that, but the clip threshold now bites four times sooner and the run no longer matches the batch-64 baseline), and clipping inside the micro-batch loop (clips partial gradients against the full threshold, changing the math). This is exactly how models with “batch size 4M tokens” are trained on GPUs that fit a few thousand tokens at a time.
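The equivalence argument can be checked numerically without a GPU. The sketch below swaps the GPT and its cross-entropy for a one-parameter linear model with a mean squared-error loss (a deliberate simplification, chosen so the gradient has a closed form) and compares the full-batch gradient with the accumulated micro-batch gradients, both scaled and unscaled.

```python
import random

random.seed(0)

def grad_mean_sq_loss(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given examples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

# Toy stand-in for one batch of 64 examples (an assumption, not course data).
w = 0.5
xs = [random.uniform(-1, 1) for _ in range(64)]
ys = [random.uniform(-1, 1) for _ in range(64)]

full = grad_mean_sq_loss(w, xs, ys)          # gradient of the batch-64 mean loss

accum_steps = 4
micro = len(xs) // accum_steps               # 16 examples per micro-batch
scaled = 0.0                                 # accumulates grad(L_i / accum_steps)
unscaled = 0.0                               # accumulates grad(L_i): the classic bug
for i in range(accum_steps):
    sl = slice(i * micro, (i + 1) * micro)
    g = grad_mean_sq_loss(w, xs[sl], ys[sl])
    scaled += g / accum_steps
    unscaled += g

print(f"full batch:          {full:+.6f}")
print(f"accumulated, scaled: {scaled:+.6f}")
print(f"accumulated, raw:    {unscaled:+.6f}  (4x too large)")
```

The scaled accumulation matches the full-batch gradient up to floating-point rounding, and the unscaled one is exactly accum_steps times larger, which is the first of the two mistakes described above.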
Key takeaways
- AdamW for GPTs: betas=(0.9, 0.95) for faster adaptation to gradient-scale shifts, and weight decay only on dim >= 2 parameters, never on biases or LayerNorms. The shape test beats name matching.
- LR warmup exists because Adam’s moment estimates are garbage for the first ~100 steps; cosine decay exists because a settling model needs shrinking steps. Ten lines of math.cos, no scheduler class needed.
- Set the LR on every param group, every iteration: you have two groups now.
- clip_grad_norm_ between backward() and step() preserves direction, caps magnitude; log the returned norm, and if it’s always above the threshold your LR is wrong.
- Never trust single-batch loss: estimate with model.eval() + @torch.no_grad() + a 200-batch average, and always restore model.train().
- Sanity-check iteration 0 against \(\ln(\text{vocab size})\): it catches broken wiring before you waste a run.
- Checkpoint best-val (with optimizer state and iteration), not last-iteration; on small data the final model is not guaranteed to be the best one.
Tomorrow: the model can predict the next character — Day 9 turns that into actual text generation, from greedy decoding through temperature, top-k, and top-p sampling, and why each knob changes the Shakespeare that comes out.