Kader Mohideen

📖 Build Your Own Wikipedia LLM · Lesson 6 — The Pretraining Engine: train.py, Fast and Restartable



Lesson 6 — The Pretraining Engine: train.py, Fast and Restartable

In Lesson 5 you built src/model.py — a 124M-parameter decoder-only transformer with RMSNorm, RoPE, SwiGLU, and weight-tied embeddings — and in Lesson 4 you left data/tokens/train.bin and data/tokens/val.bin sitting on disk as flat arrays of uint16 token ids. The model can compute a loss; the data is packed. What’s missing is the machine that connects them for twenty-plus hours straight without wasting a single GPU-minute or losing a single step of progress.

That machine is src/train.py, and it’s the most important file in the repo. A naive training loop would work, but only for about four hours: until your spot instance gets preempted, or the process OOMs, or you fat-finger Ctrl-C, and you discover you have no checkpoint, no optimizer state, and about $1.60 of vanished compute. This lesson builds the loop the way production runs are built: memory-mapped data feeding, bf16 autocast, torch.compile, gradient accumulation to a half-million-token batch, cosine LR with warmup, atomic checkpoints with full resume (model + optimizer + schedule + step + RNG), a SIGTERM handler for preemption, and Weights & Biases telemetry so you can watch it from your phone.

🎯 In this lesson you will: write the complete production-grade src/train.py — memmap batch sampler, bf16 autocast, torch.compile, AdamW with parameter groups, warmup+cosine schedule, gradient accumulation to ~0.5M tokens/step, atomic keep-last-k checkpoints with exact resume, SIGTERM-safe shutdown, W&B logging with sample generations, and a ~15-line DDP variant — then verify it end to end with a $0.10 smoke run on vast.ai.

Sizing the run: Chinchilla math and the half-million-token batch

Before any code, three numbers must be justified: how many tokens we train on, how many tokens per optimizer step, and how many steps that implies.

How many tokens? The Chinchilla scaling-law result says a compute-optimal training run uses roughly twenty tokens per parameter. For our 124M-parameter model that’s about 2.5B tokens. We’re going to train on 4B tokens — about 32 tokens per parameter, deliberately past the compute-optimal point. Why over-train? Chinchilla optimizes loss per training FLOP, but you don’t care about training FLOPs — you care about how good the final model is for the ~$10 you’re spending, and small models keep improving well past the 20:1 ratio (this is exactly why the LLaMA family trained small models on trillions of tokens). Conveniently, 4B tokens is also about one full epoch of the cleaned corpus from Lesson 3, so we never repeat data.

How many tokens per step? Gradient quality scales with batch size: too small and the gradient is noise, too large and you’re burning compute for no extra signal. For models in this class, ~0.5M tokens per optimizer step is the well-trodden setting (GPT-3 Small, same 124M/12-layer shape, used exactly 0.5M). A single RTX 4090 cannot hold that batch, so we accumulate gradients:

\[ \underbrace{16}_{\text{micro-batch}} \times \underbrace{1024}_{\text{block\_size}} \times \underbrace{32}_{\text{accum steps}} = 524{,}288 \text{ tokens/step} \]

The micro-batch of 16 sequences × 1024 tokens fits comfortably in 24GB with bf16: the model itself needs only ~2GB (0.5GB fp32 weights + 0.5GB grads + 1GB AdamW moments), leaving plenty of headroom for activations and torch.compile’s workspace.

How many steps?

\[ \frac{4\times10^9 \text{ tokens}}{524{,}288 \text{ tokens/step}} \approx 7{,}630 \text{ steps} \]

That’s the whole run: 7,630 optimizer steps. Every hyperparameter below is defined against that horizon.
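The three numbers, plus the ~2GB static-memory figure, can be re-derived in a few lines. This is a plain-arithmetic sketch using only the figures quoted above; size_run is a hypothetical helper, not part of the course repo, and the memory figure counts only fp32 weights, gradients, and AdamW moments (activations excluded).

```python
import math

# Inputs, as stated in the lesson text.
N_PARAMS = 124e6              # WikiGPT-124M parameter count
CHINCHILLA_RATIO = 20         # ~compute-optimal tokens per parameter
TRAIN_TOKENS = 4e9            # deliberate over-training budget
MICRO_BATCH, BLOCK_SIZE, ACCUM_STEPS = 16, 1024, 32

def size_run():
    """Derive the run's headline numbers from the inputs above."""
    optimal_tokens = N_PARAMS * CHINCHILLA_RATIO                  # ~2.5B tokens
    tokens_per_param = TRAIN_TOKENS / N_PARAMS                    # ~32
    tokens_per_step = MICRO_BATCH * BLOCK_SIZE * ACCUM_STEPS      # 524,288
    steps = math.ceil(TRAIN_TOKENS / tokens_per_step)             # enough steps to cover 4B tokens
    # fp32 weights + fp32 grads + AdamW's two fp32 moments = 4 tensors x 4 bytes/param
    static_gb = N_PARAMS * 4 * 4 / 1e9
    return optimal_tokens, tokens_per_param, tokens_per_step, steps, static_gb

optimal, per_param, per_step, steps, static_gb = size_run()
print(f"Chinchilla-optimal: {optimal / 1e9:.2f}B tokens; we use {per_param:.1f} tokens/param")
print(f"{per_step:,} tokens/step -> {steps:,} steps; weights+grads+moments ~{static_gb:.1f} GB")
```

Rounding the step count up (ceil) rather than to nearest is a small design choice: 7,630 steps covers slightly more than 4B tokens, never slightly less.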

Feeding the GPU: the memmap batch sampler

The token files from Lesson 4 are gigabytes of raw uint16. Loading them into RAM would work on a big instance, but np.memmap is strictly better: it maps the file into virtual memory and lets the OS page in only what’s touched. Zero startup time, near-zero resident memory, and after a few hundred steps the hot pages live in the OS page cache anyway.

train_data = np.memmap("data/tokens/train.bin", dtype=np.uint16, mode="r")
val_data   = np.memmap("data/tokens/val.bin",   dtype=np.uint16, mode="r")

rng = np.random.default_rng(1337)   # dedicated generator: its state goes in the checkpoint

def get_batch(split):
    data = train_data if split == "train" else val_data
    ix = rng.integers(0, len(data) - T - 1, size=B)          # B random windows
    x = torch.stack([torch.from_numpy(data[i     : i + T    ].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + T].astype(np.int64)) for i in ix])
    return (x.pin_memory().to(device, non_blocking=True),
            y.pin_memory().to(device, non_blocking=True))

Line by line, the decisions that matter:

  • Random windows, not a DataLoader. The corpus is one continuous token stream; a “sample” is any 1024-token window. rng.integers picks B random offsets per batch. No workers, no collate function, no epoch bookkeeping — for packed LM data, a DataLoader is pure overhead.
  • y is x shifted by one. Next-token prediction: position i predicts token i+1. Slicing the same memmap twice costs nothing.
  • astype(np.int64) — embedding lookups need int64 indices; uint16 was only for disk compactness. A uint16 holds ids 0–65,535, so our 32,768-token vocab fits with room to spare; staying under that ceiling was one reason for the vocab size chosen in Lesson 4.
  • pin_memory() + non_blocking=True — page-locked host memory lets the host-to-device copy run on the GPU’s copy (DMA) engine, overlapping with compute. Skip it and every batch copy stalls the GPU for a millisecond or two; over 244,000 micro-batches (7,630 steps × 32) that adds up to several minutes of idle GPU.
  • A dedicated default_rng instead of global np.random — because a resumable run must restore the sampler’s exact position. A Generator object has an extractable, restorable state; the global seed does not compose as cleanly. This is checkpoint groundwork.
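To see the sampler’s contract without a GPU, here is a numpy-only sketch of the same idea on a small synthetic token file. write_fake_tokens and get_batch_np are hypothetical names for illustration; the lesson’s get_batch does the same slicing but also converts to torch tensors, pins memory, and copies to the device.

```python
import os
import tempfile

import numpy as np

def write_fake_tokens(path, n_tokens, vocab_size=32768, seed=0):
    """Write a flat uint16 token stream to disk, standing in for train.bin."""
    toks = np.random.default_rng(seed).integers(0, vocab_size, size=n_tokens, dtype=np.uint16)
    toks.tofile(path)

def get_batch_np(data, rng, B, T):
    """Sample B random windows of length T; y is x shifted left by one token."""
    ix = rng.integers(0, len(data) - T - 1, size=B)
    x = np.stack([data[i : i + T].astype(np.int64) for i in ix])
    y = np.stack([data[i + 1 : i + 1 + T].astype(np.int64) for i in ix])
    return x, y

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train.bin")
    write_fake_tokens(path, 10_000)
    data = np.memmap(path, dtype=np.uint16, mode="r")   # nothing is read until sliced
    x, y = get_batch_np(data, np.random.default_rng(1337), B=4, T=64)
    del data                                             # release the mapping before cleanup

print(x.shape, y.shape, x.dtype)   # each row of y is the matching row of x, shifted by one
```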

bf16 autocast and torch.compile: the two free speedups

Two lines buy you roughly a 3× throughput improvement over naive fp32 eager mode.

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits, loss = model(x, y)

Why bf16, and why no GradScaler? fp16 and bf16 are both 16-bit, but they split their bits differently. fp16 spends 10 bits on mantissa and only 5 on exponent: its smallest normal number is ~6×10⁻⁵, and small gradients silently underflow to zero. That’s why fp16 training requires GradScaler: multiply the loss by a large factor (65,536 by default) before backward so gradients stay representable, un-scale them before the optimizer step, and skip any step whose gradients overflowed to inf/NaN, lowering the scale for next time. bf16 spends 8 bits on exponent, the same dynamic range as fp32, and sacrifices mantissa precision instead. In practice nothing underflows, so there is nothing to scale: no GradScaler, no skipped steps, no scale-factor hyperparameter, one less thing to checkpoint. The lost mantissa precision is tolerable because the parameters themselves stay in fp32 (autocast casts operation inputs inside its context, never the stored weights) and AdamW keeps its moments in fp32; only the matmuls and activations run in bf16. On Ampere and later (the 4090 is Ada), bf16 matmuls run on tensor cores at full speed. This is why the course requires a 30-series-or-newer GPU.
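The exponent-bit argument can be checked numerically. numpy has float16 but no bfloat16, so this sketch computes each format’s smallest normal number from its exponent width and then shows an fp16 underflow and the GradScaler-style rescue directly. The gradient value 1e-8 is an arbitrary example, and min_normal is a hypothetical helper.

```python
import numpy as np

def min_normal(exponent_bits):
    """Smallest positive normal value of an IEEE-style binary float with this exponent width."""
    bias = 2 ** (exponent_bits - 1) - 1
    return 2.0 ** (1 - bias)

fp16_min = min_normal(5)   # 2**-14, ~6.1e-5
bf16_min = min_normal(8)   # 2**-126, ~1.2e-38: the same range as fp32

tiny_grad = 1e-8                                   # an illustrative small gradient value
in_fp16 = np.float16(tiny_grad)                    # rounds to 0.0: below even fp16's smallest subnormal (~6e-8)
scaled = np.float16(tiny_grad * 65536)             # GradScaler-style: scale up before casting
recovered = float(np.float32(scaled) / 65536)      # un-scale in fp32: ~1e-8 survives

print(f"fp16 min normal {fp16_min:.1e}, bf16 min normal {bf16_min:.1e}")
print(f"1e-8 in fp16: {float(in_fp16)}; via loss scaling: {recovered:.3e}")
```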

model = torch.compile(model)

What compile actually does: it captures the Python forward pass as a graph once, then generates fused Triton kernels for it, so chains of small elementwise operations (the RMSNorm math plus the residual add, the SwiGLU gating, the RoPE rotation) run as a few fused kernels instead of many separate launches, and most of the per-step Python-interpreter overhead disappears. On this model it’s worth 40–70% more tokens/s. The cost: the first optimizer step takes 1–3 minutes while kernels compile. Over a 22-hour run that amortizes to nothing, but it’s why the smoke test at the end of this lesson feels slow to start.

torch.set_float32_matmul_precision("high")   # TF32 for any residual fp32 matmuls

One caveat that shapes the code below: torch.compile returns a wrapper whose state_dict() keys carry an _orig_mod. prefix. We keep a raw_model reference to the un-compiled module and use it for checkpointing and generation, so checkpoints stay loadable with or without compile.
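Keeping raw_model avoids the prefix problem at save time. If you ever end up holding a state dict saved from a wrapped model (say, an experiment that called compiled_model.state_dict()), a small defensive helper like this hypothetical strip_wrapper_prefix can clean the keys before load_state_dict. The "module." prefix is the one DistributedDataParallel adds; the parameter names in the demo are illustrative, not Lesson 5’s actual names.

```python
def strip_wrapper_prefix(state_dict, prefix="_orig_mod."):
    """Return a copy of state_dict with a wrapper's key prefix removed.

    torch.compile adds "_orig_mod."; DistributedDataParallel adds "module.".
    Keys without the prefix are kept unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Plain-dict demo; with torch you would pass compiled_model.state_dict() instead.
compiled_keys = {"_orig_mod.tok_emb.weight": 0, "_orig_mod.blocks.0.norm.weight": 1}
clean = strip_wrapper_prefix(compiled_keys)
print(clean)
```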

The optimizer: AdamW with parameter groups

decay   = [p for p in raw_model.parameters() if p.requires_grad and p.dim() >= 2]
nodecay = [p for p in raw_model.parameters() if p.requires_grad and p.dim() <  2]
optimizer = torch.optim.AdamW(
    [{"params": decay,   "weight_decay": 0.1},
     {"params": nodecay, "weight_decay": 0.0}],
    lr=cfg["max_lr"], betas=(0.9, 0.95), fused=True,
)
  • weight_decay=0.1 on matrices is the standard LLM setting — a strong regularizer that also demonstrably improves optimization dynamics at this scale.
  • No decay on 1-D parameters. The dim() < 2 split is a clean rule because of choices made in Lesson 5: the model has no biases, so the only 1-D parameters are the RMSNorm gains. Decaying a norm gain toward zero fights its job (it’s a learned scale, not a feature weight). The token embedding is 2-D and does get decayed — with weight tying it’s also the output head, and decaying it is standard.
  • betas=(0.9, 0.95) — β₂=0.95 instead of the default 0.999 makes the second-moment estimate adapt in ~20 steps instead of ~1000. LLM gradient variance shifts fast early in training; a sluggish β₂ causes loss spikes. This has been the LLM default since GPT-3.
  • fused=True — one CUDA kernel updates all parameters instead of a Python loop over tensors. Free ~2% throughput.
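The ~20 vs ~1000 step figures are the EMA time constant 1/(1 − β₂). This sketch checks them by simulating AdamW’s second-moment update (v ← β₂·v + (1 − β₂)·g²) after a sudden jump in squared-gradient magnitude, counting the steps until v has closed 1 − 1/e (about 63%) of the gap. steps_to_adapt and the jump from 1.0 to 4.0 are hypothetical choices for illustration.

```python
import math

def steps_to_adapt(beta2, old=1.0, new=4.0):
    """Steps for an EMA of g**2 to close 1 - 1/e of a jump from `old` to `new`."""
    target = new - (new - old) / math.e
    v, steps = old, 0
    while v < target:
        v = beta2 * v + (1 - beta2) * new   # second-moment update with g**2 held at `new`
        steps += 1
    return steps

for beta2 in (0.95, 0.999):
    print(f"beta2={beta2}: {steps_to_adapt(beta2)} steps (1/(1-beta2) = {1 / (1 - beta2):.0f})")
```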

The schedule: 2k-step warmup, cosine to 10%

The learning rate is a pure function of the step number — deliberately not a stateful LRScheduler object, because a pure function makes resume trivial: restore the step, and the schedule is automatically correct.

def get_lr(step):
    if step < cfg["warmup_steps"]:                                   # linear warmup
        return cfg["max_lr"] * (step + 1) / cfg["warmup_steps"]
    if step >= cfg["max_steps"]:
        return cfg["min_lr"]
    t = (step - cfg["warmup_steps"]) / (cfg["max_steps"] - cfg["warmup_steps"])
    return cfg["min_lr"] + 0.5 * (1 + math.cos(math.pi * t)) * (cfg["max_lr"] - cfg["min_lr"])
[Figure: learning-rate schedule. Linear warmup to max_lr = 6e-4 at step 2,000 (warmup ends), then cosine decay to min_lr = 6e-5 (10% floor) at step 7,630.]

Why each piece:

  • Warmup (2,000 steps) exists because in the first steps AdamW’s moment estimates are built from only a handful of noisy gradients, while the freshly initialized weights produce large, fast-changing gradients; taking full-size max_lr steps on that basis can blow the loss up permanently. Ramping linearly to max_lr=6e-4 gives the moments time to calibrate before the steps get large. Yes, 2k of 7,630 steps is a generous ~26% of the run; for short runs that generosity is cheap insurance, and the model is still learning rapidly during warmup anyway.
  • Cosine decay to min_lr = max_lr/10 — the smooth annealing squeezes out the last 0.1–0.2 of loss; ending at 10% rather than 0 keeps the model learning to the final step. This exact recipe (6e-4 → 6e-5, cosine) is the consensus setting for 124M-class models.
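Because get_lr is a pure function, it can be evaluated anywhere, including on a laptop. This standalone copy hard-codes the schedule values from configs/pretrain.yaml (the CFG dict stands in for the loaded YAML) and prints the learning rate at the landmarks of the run. The midpoint of the cosine phase, step 4,815, should land exactly halfway between max_lr and min_lr.

```python
import math

# Stand-in for the YAML config: the schedule values from configs/pretrain.yaml.
CFG = {"warmup_steps": 2000, "max_steps": 7630, "max_lr": 6e-4, "min_lr": 6e-5}

def get_lr(step, cfg=CFG):
    if step < cfg["warmup_steps"]:                                   # linear warmup
        return cfg["max_lr"] * (step + 1) / cfg["warmup_steps"]
    if step >= cfg["max_steps"]:                                     # floor after the run
        return cfg["min_lr"]
    t = (step - cfg["warmup_steps"]) / (cfg["max_steps"] - cfg["warmup_steps"])  # 0 -> 1
    return cfg["min_lr"] + 0.5 * (1 + math.cos(math.pi * t)) * (cfg["max_lr"] - cfg["min_lr"])

for s in (0, 999, 1999, 2000, 4815, 7629, 7630):
    print(f"step {s:5d}: lr = {get_lr(s):.2e}")
```

Resuming at any step needs nothing but that step number: the same call returns the same learning rate whether the process has been running for hours or just restarted.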

Crash-proofing: atomic checkpoints, exact resume, SIGTERM

This is the section that separates a script from an engine. The requirement: kill the process at any moment — preemption, OOM, Ctrl-C — restart it with the same command, and the run continues as if nothing happened. That means checkpointing five things, not one:

  1. Model weights — obviously.
  2. Optimizer state — AdamW’s two moment tensors are 1GB of hard-won information. Resume without them and you effectively restart warmup, with a visible loss spike.
  3. Step number — which, because our schedule is a pure function of step, is the scheduler state.
  4. RNG state — the numpy generator that drives batch sampling, plus the torch CPU/CUDA states. The numpy generator is the one that matters: restore it and the resumed run sees the same data in the same order as the uninterrupted run would have. Dropout is off in our model, so the torch states only affect the sampled generations, but saving them is free and keeps those samples on the same random stream.
  5. The config — so sample.py and eval_ppl.py in Lesson 7 can rebuild the model without guessing.
def save_checkpoint(step):
    if not master:
        return
    ckpt = {
        "model": raw_model.state_dict(),         # raw_model: no _orig_mod. prefixes
        "optimizer": optimizer.state_dict(),
        "step": step,
        "config": cfg,
        "rng": {
            "numpy": rng.bit_generator.state,
            "torch_cpu": torch.get_rng_state(),
            "torch_cuda": torch.cuda.get_rng_state_all(),
        },
    }
    path = os.path.join(cfg["out_dir"], f"ckpt_{step:07d}.pt")
    torch.save(ckpt, path + ".tmp")
    os.replace(path + ".tmp", path)              # atomic rename
    for old in sorted(glob.glob(f"{cfg['out_dir']}/ckpt_*.pt"))[:-cfg["keep_last_k"]]:
        os.remove(old)

Three details here are non-negotiable:

  • Atomic save. torch.save directly to the target path is a trap: if the instance dies mid-write (the exact moment preemption tends to strike, since saves happen periodically), you’re left with a truncated, unloadable file — and if it overwrote your only checkpoint, the run is dead. Writing to .tmp and then os.replace (an atomic rename on POSIX) guarantees the checkpoint file is always either the old complete one or the new complete one.
  • Keep-last-k. Each checkpoint is ~1.5GB (weights + optimizer moments). Keeping every one of ~30 saves would eat 45GB of instance disk. We keep the last 3 — enough to fall back one or two saves if the latest somehow loads badly.
  • Auto-resume, not a flag. At startup, the script looks for the newest ckpt_*.pt and resumes from it if present. This matters operationally: when a preempted vast.ai instance restarts, your relaunch command is identical to the launch command. No human decision required at 3am.
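The whole resume contract (atomic write, keep-last-k pruning, auto-resume from the newest file, and RNG restore) can be exercised end to end without a GPU. This sketch is a stand-in under stated assumptions: it pickles only the step and the numpy generator state into hypothetical ckpt_*.pkl files, where the real script uses torch.save on the full checkpoint dict. It also adds f.flush() plus os.fsync() before the rename, an extra durability step the lesson’s code does not take. The check at the end is the one that matters: a run stopped after step 12 and relaunched draws exactly the same sequence as an uninterrupted run.

```python
import glob
import os
import pickle
import tempfile

import numpy as np

def save_ckpt(out_dir, step, rng, keep_last_k=3):
    """Atomically write a tiny checkpoint, then prune all but the newest keep_last_k."""
    path = os.path.join(out_dir, f"ckpt_{step:07d}.pkl")
    with open(path + ".tmp", "wb") as f:
        pickle.dump({"step": step, "rng": rng.bit_generator.state}, f)
        f.flush()
        os.fsync(f.fileno())             # bytes are on disk before the rename
    os.replace(path + ".tmp", path)      # atomic rename: old file or new file, never half
    for old in sorted(glob.glob(os.path.join(out_dir, "ckpt_*.pkl")))[:-keep_last_k]:
        os.remove(old)

def resume(out_dir, rng):
    """Load the newest checkpoint if one exists; return the step to continue from."""
    ckpts = sorted(glob.glob(os.path.join(out_dir, "ckpt_*.pkl")))
    if not ckpts:
        return 0
    with open(ckpts[-1], "rb") as f:
        ckpt = pickle.load(f)
    rng.bit_generator.state = ckpt["rng"]   # sampler continues exactly where it stopped
    return ckpt["step"]

def run(out_dir, max_steps, stop_at=None, ckpt_interval=5):
    """Simulated training: each step 'samples a batch' by drawing one window offset."""
    rng = np.random.default_rng(1337)
    step = resume(out_dir, rng)
    drawn = []
    while step < max_steps:
        drawn.append(int(rng.integers(0, 1_000_000)))
        step += 1
        if step % ckpt_interval == 0 or step == stop_at:
            save_ckpt(out_dir, step, rng)
        if step == stop_at:                 # simulated SIGTERM: save, then exit the loop
            break
    return drawn

with tempfile.TemporaryDirectory() as a, tempfile.TemporaryDirectory() as b:
    uninterrupted = run(a, max_steps=20)
    first_half = run(b, max_steps=20, stop_at=12)   # "preempted" after step 12
    second_half = run(b, max_steps=20)              # identical relaunch resumes at step 12
    remaining = sorted(os.listdir(b))

print(first_half + second_half == uninterrupted, remaining)
```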

Finally, the SIGTERM handler. When vast.ai reclaims a spot/interruptible instance, your process receives SIGTERM shortly before the hard kill. A signal handler must do almost nothing (it runs between bytecodes), so it just sets a flag; the main loop checks the flag once per step, saves, and exits cleanly:

stop = False
def _request_stop(signum, frame):
    global stop
    stop = True
signal.signal(signal.SIGTERM, _request_stop)   # vast.ai preemption warning
signal.signal(signal.SIGINT,  _request_stop)   # Ctrl-C gets the same graceful path

With ckpt_interval=250 (about 45 minutes of compute at ~10.5 s per step), the worst case if the SIGTERM never arrives is losing roughly 45 minutes ≈ $0.30. With it, you lose the current step: about 10 seconds.
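The flag-based handler can be tested without waiting for a real preemption. This sketch, a toy stand-in for the training loop, installs a SIGTERM handler that only flips a flag, raises SIGTERM against its own process with signal.raise_signal (Python 3.8+) partway through, and confirms the loop finishes the current step, "saves", and exits. train_until_signal is a hypothetical name; it must run on the main thread (a requirement of signal.signal), and it restores the previous handler when done.

```python
import signal

def train_until_signal(max_steps, preempt_at):
    """Toy loop: returns (last completed step, step saved on shutdown or None)."""
    state = {"stop": False}

    def _request_stop(signum, frame):
        state["stop"] = True                   # do almost nothing inside the handler

    previous = signal.signal(signal.SIGTERM, _request_stop)
    step, saved_at = 0, None
    try:
        while step < max_steps and not state["stop"]:
            step += 1                          # stand-in for one full optimizer step
            if step == preempt_at:
                signal.raise_signal(signal.SIGTERM)   # simulate the preemption notice
        if state["stop"]:
            saved_at = step                    # stand-in for save_checkpoint(step)
    finally:
        signal.signal(signal.SIGTERM, previous)       # leave the process as we found it
    return step, saved_at

print(train_until_signal(max_steps=10, preempt_at=3))
```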

flowchart TD
    A[process starts] --> B{checkpoint in<br/>checkpoints/ ?}
    B -- no --> C[fresh init<br/>step = 0]
    B -- yes --> D[load model + optimizer<br/>+ step + RNG states]
    C --> E[training loop]
    D --> E
    E --> F{SIGTERM or<br/>Ctrl-C ?}
    F -- no --> G{step % 250 == 0 ?}
    G -- yes --> H[atomic save<br/>tmp then os.replace<br/>prune to last k]
    G -- no --> E
    H --> E
    F -- yes --> I[save checkpoint<br/>exit 0]
    I -.->|instance restarts<br/>same command| A

Watching the run: Weights & Biases

You will not sit in front of an SSH session for 22 hours; the run needs to report to somewhere you can check from anywhere. We use Weights & Biases (free tier is plenty). If you prefer self-hosted, MLflow or TensorBoard are drop-in alternatives — same metric names, different log call — but W&B’s zero-setup mobile dashboard is why it’s the course default, and this is the only place we’ll mention the alternatives.

wandb.init(project="wikillm", name=cfg["run_name"], id=cfg["run_name"],
           resume="allow", config=cfg)

The subtle flags: a fixed id plus resume="allow" means that when the preempted run restarts, it appends to the same W&B run instead of creating a fragment — your loss curve stays one unbroken line across preemptions.

What we log, and why each series earns its place:

metric | logged every | what it tells you
--- | --- | ---
train/loss | 10 steps | the headline. Should fall fast to ~4, then grind toward ~3.0–3.3
train/lr | 10 steps | confirms warmup/cosine are behaving; the x-axis sanity check
train/grad_norm | 10 steps | the early-warning system: sustained values near the 1.0 clip or spikes >5 predict divergence before the loss shows it
perf/tokens_per_s | 10 steps | the money meter: this number sets the cost of the whole run; a sudden drop means thermal throttling or a noisy neighbor
val/loss | 250 steps | the honest number: measured on val.bin, which the model never trains on
samples (table) | 250 steps | qualitative eyes-on: 3 fixed prompts, sampled completions (temperature 0.8, top-k 50), logged as a wandb.Table so you can literally watch the model learn English

The generations table is the most underrated of these. At step 500 the completions are token soup; by step 3,000 they’re grammatical; by step 7,000 they’re recognizably Wikipedia-flavored prose. When a run is silently broken (bad data, LR too high), the samples reveal it hours before the loss curve does.

The full src/train.py

Everything above, assembled. This is the complete file — drop it in src/train.py.

"""Pretraining engine for WikiGPT-124M.

Single GPU:
    python src/train.py configs/pretrain.yaml
Multi-GPU (2-4x GPUs on one vast.ai instance):
    torchrun --standalone --nproc_per_node=4 src/train.py configs/pretrain.yaml
"""
import glob
import math
import os
import signal
import sys
import time

import numpy as np
import torch
import wandb
import yaml
from tokenizers import Tokenizer

from model import GPT, GPTConfig

# ---------------------------------------------------------------- config ----
with open(sys.argv[1]) as f:
    cfg = yaml.safe_load(f)
os.makedirs(cfg["out_dir"], exist_ok=True)

# ------------------------------------------- DDP (inert on a single GPU) ----
ddp = int(os.environ.get("RANK", -1)) != -1        # torchrun sets RANK
if ddp:
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    dist.init_process_group(backend="nccl")
    ddp_rank = dist.get_rank()
    ddp_world_size = dist.get_world_size()
    device = f"cuda:{os.environ['LOCAL_RANK']}"
    torch.cuda.set_device(device)
else:
    ddp_rank, ddp_world_size, device = 0, 1, "cuda"
master = ddp_rank == 0

torch.manual_seed(1337 + ddp_rank)
torch.set_float32_matmul_precision("high")         # TF32 for residual fp32 matmuls

# ------------------------------------------------------------------ data ----
train_data = np.memmap(os.path.join(cfg["data_dir"], "train.bin"), dtype=np.uint16, mode="r")
val_data   = np.memmap(os.path.join(cfg["data_dir"], "val.bin"),   dtype=np.uint16, mode="r")

B, T = cfg["micro_batch_size"], cfg["block_size"]
rng = np.random.default_rng(1337 + ddp_rank)       # per-rank stream; state is checkpointed

def get_batch(split):
    data = train_data if split == "train" else val_data
    ix = rng.integers(0, len(data) - T - 1, size=B)
    x = torch.stack([torch.from_numpy(data[i     : i + T    ].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + T].astype(np.int64)) for i in ix])
    return (x.pin_memory().to(device, non_blocking=True),
            y.pin_memory().to(device, non_blocking=True))

# ----------------------------------------------- model, optimizer, sched ----
model = GPT(GPTConfig())                           # the exact 124M config from Lesson 5
model.to(device)
raw_model = model                                  # clean handle for checkpoints/sampling
if cfg["compile"]:
    model = torch.compile(model)                   # first step compiles: 1-3 min, once
if ddp:
    model = DDP(model, device_ids=[int(os.environ["LOCAL_RANK"])])

decay   = [p for p in raw_model.parameters() if p.requires_grad and p.dim() >= 2]
nodecay = [p for p in raw_model.parameters() if p.requires_grad and p.dim() <  2]
optimizer = torch.optim.AdamW(
    [{"params": decay,   "weight_decay": cfg["weight_decay"]},
     {"params": nodecay, "weight_decay": 0.0}],    # RMSNorm gains: no decay
    lr=cfg["max_lr"], betas=(0.9, 0.95), fused=True,
)

grad_accum = cfg["tokens_per_step"] // (B * T * ddp_world_size)
assert cfg["tokens_per_step"] % (B * T * ddp_world_size) == 0, "tokens_per_step must divide evenly"

def get_lr(step):
    """Schedule as a pure function of step: resuming the step resumes the schedule."""
    if step < cfg["warmup_steps"]:
        return cfg["max_lr"] * (step + 1) / cfg["warmup_steps"]
    if step >= cfg["max_steps"]:
        return cfg["min_lr"]
    t = (step - cfg["warmup_steps"]) / (cfg["max_steps"] - cfg["warmup_steps"])
    return cfg["min_lr"] + 0.5 * (1 + math.cos(math.pi * t)) * (cfg["max_lr"] - cfg["min_lr"])

# -------------------------------------------------- checkpointing/resume ----
def save_checkpoint(step):
    if not master:
        return
    ckpt = {
        "model": raw_model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
        "config": cfg,
        "rng": {
            "numpy": rng.bit_generator.state,
            "torch_cpu": torch.get_rng_state(),
            "torch_cuda": torch.cuda.get_rng_state_all(),
        },
    }
    path = os.path.join(cfg["out_dir"], f"ckpt_{step:07d}.pt")
    torch.save(ckpt, path + ".tmp")
    os.replace(path + ".tmp", path)                # atomic: never a half-written file
    for old in sorted(glob.glob(os.path.join(cfg["out_dir"], "ckpt_*.pt")))[:-cfg["keep_last_k"]]:
        os.remove(old)

start_step = 0
ckpts = sorted(glob.glob(os.path.join(cfg["out_dir"], "ckpt_*.pt")))
if ckpts:                                          # auto-resume: relaunch == launch
    ckpt = torch.load(ckpts[-1], map_location="cpu")   # RNG states must stay CPU ByteTensors
    raw_model.load_state_dict(ckpt["model"])           # copies the weights onto the GPU params
    optimizer.load_state_dict(ckpt["optimizer"])       # moves AdamW state to each param's device
    start_step = ckpt["step"]
    if master:                                         # saved RNG states are rank 0's; other ranks keep 1337 + rank
        rng.bit_generator.state = ckpt["rng"]["numpy"]
        torch.set_rng_state(ckpt["rng"]["torch_cpu"])
        torch.cuda.set_rng_state_all(ckpt["rng"]["torch_cuda"])
        print(f"resumed from {ckpts[-1]} at step {start_step}")

# ---------------------------------------------------------- preemption ----
stop = False
def _request_stop(signum, frame):
    global stop
    stop = True                                    # loop finishes the step, saves, exits
signal.signal(signal.SIGTERM, _request_stop)       # vast.ai sends this before reclaiming
signal.signal(signal.SIGINT,  _request_stop)

# ------------------------------------------------------------------ eval ----
tok = Tokenizer.from_file("tokenizer/tokenizer.json")
PROMPTS = ["The history of the Roman Empire", "Photosynthesis is the process",
           "The Pacific Ocean covers"]

@torch.no_grad()
def estimate_val_loss():
    model.eval()
    losses = torch.zeros(cfg["eval_iters"], device=device)
    for i in range(cfg["eval_iters"]):
        x, y = get_batch("val")
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            _, loss = model(x, y)
        losses[i] = loss
    model.train()
    return losses.mean().item()

@torch.no_grad()
def sample_generations():
    raw_model.eval()
    rows = []
    for p in PROMPTS:
        idx = torch.tensor([tok.encode(p).ids], device=device)
        out = raw_model.generate(idx, max_new_tokens=80, temperature=0.8, top_k=50)
        rows.append([p, tok.decode(out[0].tolist())])
    raw_model.train()
    return rows

# ------------------------------------------------------------------ W&B ----
if master:
    wandb.init(project=cfg["wandb_project"], name=cfg["run_name"],
               id=cfg["run_name"], resume="allow", config=cfg)

# ------------------------------------------------------------- the loop ----
step = start_step
t0 = time.time()
while step < cfg["max_steps"] and not stop:
    lr = get_lr(step)
    for g in optimizer.param_groups:
        g["lr"] = lr

    optimizer.zero_grad(set_to_none=True)          # frees grad memory between steps
    loss_accum = 0.0
    for micro in range(grad_accum):
        x, y = get_batch("train")
        if ddp:                                    # all-reduce only on the last micro-step
            model.require_backward_grad_sync = (micro == grad_accum - 1)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            _, loss = model(x, y)
        loss = loss / grad_accum                   # mean over the effective batch
        loss_accum += loss.detach()
        loss.backward()                            # grads accumulate across micro-steps
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
    optimizer.step()
    step += 1

    if master and step % cfg["log_interval"] == 0:
        torch.cuda.synchronize()                   # honest timing: wait for the GPU
        dt, t0 = time.time() - t0, time.time()
        tok_per_s = cfg["tokens_per_step"] * cfg["log_interval"] / dt
        print(f"step {step:5d} | loss {loss_accum.item():.4f} | lr {lr:.2e} | "
              f"gnorm {grad_norm:.2f} | {tok_per_s/1e3:.1f}k tok/s")
        wandb.log({"train/loss": loss_accum.item(), "train/lr": lr,
                   "train/grad_norm": grad_norm.item(),
                   "perf/tokens_per_s": tok_per_s}, step=step)

    if step % cfg["eval_interval"] == 0 or step == cfg["max_steps"]:
        val_loss = estimate_val_loss()
        if master:
            print(f"step {step:5d} | val loss {val_loss:.4f}")
            table = wandb.Table(columns=["prompt", "generation"], data=sample_generations())
            wandb.log({"val/loss": val_loss, "samples": table}, step=step)
        t0 = time.time()                           # don't bill eval time to tok/s

    if step % cfg["ckpt_interval"] == 0 or step == cfg["max_steps"]:
        save_checkpoint(step)

if stop:
    save_checkpoint(step)
    if master:
        print(f"stop signal received -- checkpoint saved at step {step}, exiting cleanly")

if master:
    wandb.finish()
if ddp:
    dist.destroy_process_group()

Two loop details worth calling out explicitly:

  • loss / grad_accum before backward(). Gradients add across backward calls, so dividing each micro-loss by the accumulation count makes the summed gradient equal the mean-over-524k-tokens gradient — mathematically identical to one giant batch.
  • torch.cuda.synchronize() before timing. CUDA calls are asynchronous; without the sync you’d be timing kernel launches, not kernel execution, and your tokens/s (and cost projections) would be fiction.
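The accumulation identity is easy to check numerically. This numpy sketch uses a least-squares loss as a stand-in for the LM loss (any loss that is a mean over equal-sized micro-batches behaves the same way) and compares the one-big-batch gradient with the sum of micro-batch gradients, each divided by the accumulation count. mse_grad and the shapes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(128, 8))        # 128 examples ("tokens"), 8 features
y = rng.normal(size=128)
w = rng.normal(size=8)

def mse_grad(Xb, yb, w):
    """Gradient of mean((Xb @ w - yb) ** 2) with respect to w."""
    return 2.0 * Xb.T @ (Xb @ w - yb) / len(yb)

full_batch = mse_grad(X, y, w)       # the gradient of one giant batch

accum_steps = 8                      # 8 equal micro-batches of 16
accumulated = np.zeros_like(w)
for Xb, yb in zip(np.split(X, accum_steps), np.split(y, accum_steps)):
    accumulated += mse_grad(Xb, yb, w) / accum_steps   # the loss / grad_accum trick

print(np.allclose(full_batch, accumulated))
```

Drop the division and the accumulated gradient comes out accum_steps times too large, which with a fixed learning rate is equivalent to silently multiplying it by 32.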

And the config, configs/pretrain.yaml:

# configs/pretrain.yaml -- WikiGPT-124M pretraining, 1x RTX 4090
data_dir: data/tokens
out_dir: checkpoints

block_size: 1024
micro_batch_size: 16        # 16 x 1024 = 16,384 tokens per micro-batch, fits 24GB
tokens_per_step: 524288     # ~0.5M-token effective batch -> grad_accum = 32 on 1 GPU

max_steps: 7630             # 7,630 x 524,288 = 4.0B tokens (~32 tok/param)
warmup_steps: 2000
max_lr: 6.0e-4
min_lr: 6.0e-5
weight_decay: 0.1
grad_clip: 1.0

eval_interval: 250          # val loss + sample generations every ~45 min (~10.5 s/step)
eval_iters: 100             # 100 x 16 x 1024 = 1.6M val tokens per estimate
log_interval: 10
ckpt_interval: 250          # worst-case loss without SIGTERM: ~45 min of compute
keep_last_k: 3

compile: true
wandb_project: wikillm
run_name: wikigpt-124m-pretrain

flowchart LR
    A[memmap<br/>train.bin] --> B[sample B=16<br/>random windows]
    B --> C[pinned H2D copy<br/>non_blocking]
    C --> D[bf16 autocast<br/>forward + loss]
    D --> E[loss / 32<br/>backward]
    E -->|repeat x32<br/>micro-steps| B
    E --> F[clip grad norm<br/>to 1.0]
    F --> G[fused AdamW step<br/>lr = get_lr of step]
    G --> H{every 10 steps}
    H --> I[W&B: loss, lr,<br/>grad_norm, tok/s]
    G --> J{every 250 steps}
    J --> K[val loss on val.bin<br/>+ generations table]
    J --> L[atomic checkpoint<br/>keep last 3]

One flag away from multi-GPU: DDP

Everything DDP-related is already in the file above — it’s the if ddp: branches, about 15 lines total. Here’s what each one does, because DDP’s elegance is easy to miss:

  1. Detection — torchrun sets RANK, LOCAL_RANK, WORLD_SIZE in the environment; plain python doesn’t. So one env check switches modes, and the single-GPU path pays zero cost.
  2. init_process_group("nccl") + per-rank device — each of the N processes owns one GPU and runs the identical script.
  3. DDP(model) — wraps the model so that during backward(), gradients are all-reduced (averaged) across GPUs, bucket by bucket, overlapped with the rest of the backward pass.
  4. require_backward_grad_sync = (micro == last) — the crucial accumulation interaction. Without it, DDP would all-reduce on every micro-step — 32 synchronizations per optimizer step instead of 1, throttling you to interconnect speed. We sync only when the accumulated gradient is complete.
  5. grad_accum divides by world_size — 4 GPUs each do 8 micro-steps instead of 32; the effective batch is still 524,288 tokens, so no hyperparameter changes at all.
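  6. master guards — rank 0 alone prints, logs to W&B, generates the samples table, and writes checkpoints; the other ranks just compute (they still run the val-loss batches, but only rank 0’s estimate is reported).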
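Point 5 is the one place where the GPU count can bite: tokens_per_step must divide evenly by micro_batch_size × block_size × world_size, which is exactly what the script’s assert enforces. This sketch (grad_accum_for is a hypothetical helper mirroring that assert) tabulates the accumulation steps for common GPU counts with the lesson’s config.

```python
def grad_accum_for(world_size, tokens_per_step=524_288, micro_batch=16, block_size=1024):
    """Micro-steps per optimizer step on each GPU, mirroring train.py's divisibility assert."""
    tokens_per_micro_step = micro_batch * block_size * world_size   # across all GPUs
    if tokens_per_step % tokens_per_micro_step:
        raise ValueError(f"{world_size} GPUs: {tokens_per_step:,} tokens/step does not divide evenly")
    return tokens_per_step // tokens_per_micro_step

for n in (1, 2, 3, 4, 8):
    try:
        print(f"{n} GPU(s): grad_accum = {grad_accum_for(n)}")
    except ValueError as err:
        print(err)
```

On a 3-GPU box the sensible fix is to change micro_batch_size or tokens_per_step rather than relax the assert, since rounding grad_accum would quietly change the effective batch.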

Launch on a 4× 4090 vast.ai instance:

torchrun --standalone --nproc_per_node=4 src/train.py configs/pretrain.yaml

Expect ~3.3–3.6× scaling (not 4×; the all-reduce isn’t free over PCIe): the 22-hour run becomes ~6–7 hours of wall time at roughly 4× the hourly rate — nearly the same total dollars, much faster feedback. One honest caveat: on resume, only rank 0’s RNG state is restored (ranks re-seed as 1337 + rank), so a resumed multi-GPU run sees a slightly different data order than an uninterrupted one — statistically irrelevant for training, but single-GPU resume is bit-exact.

Throughput math and a $0.10 smoke test on vast.ai

The perf/tokens_per_s metric converts directly into money, so let’s do the arithmetic once, carefully. With bf16 + torch.compile on a 4090, this model sustains 45–55k tokens/s (you’ll read your exact number off W&B ten minutes into the run).

\[ \text{hours} = \frac{4\times10^9 \text{ tokens}}{50{,}000 \text{ tok/s} \times 3600} \approx 22.2\text{h} \]

Add ~5% for evals, checkpoints, and the compile warmup → ~23 hours. At $0.40/hr that’s ~$9.30; across the realistic 45–55k tok/s and $0.35–0.45/hr ranges, the envelope is roughly $7.40–11.70, right around the course’s ~$10 budget line. Every 1k tok/s you gain or lose moves the total bill by about $0.20 (close to half an hour of rental); this is why we bothered with pinned memory and fused AdamW.
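The cost arithmetic fits in one function. This sketch assumes the same 5% overhead factor and the tok/s and $/hr ranges quoted above; run_cost is a hypothetical helper, and the real inputs are your own W&B perf/tokens_per_s reading and your instance’s listed price.

```python
TRAIN_TOKENS = 4e9
OVERHEAD = 1.05        # evals, checkpoints, compile warmup (~5%)

def run_cost(tok_per_s, usd_per_hr, tokens=TRAIN_TOKENS, overhead=OVERHEAD):
    """Return (wall-clock hours, total dollars) for the full pretraining run."""
    hours = tokens / (tok_per_s * 3600) * overhead
    return hours, hours * usd_per_hr

hours, usd = run_cost(50_000, 0.40)
best = run_cost(55_000, 0.35)[1]
worst = run_cost(45_000, 0.45)[1]
per_1k = run_cost(49_000, 0.40)[1] - usd      # cost of losing 1k tok/s
print(f"{hours:.1f} h, ${usd:.2f}; envelope ${best:.2f}-${worst:.2f}; 1k tok/s = ${per_1k:.2f}")
```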

Before Lesson 7’s real launch, spend ten minutes and ~$0.10 proving the whole engine works. Rent the reference instance (or reuse the one from Lesson 2):

# find a 4090 with decent disk and bandwidth
vastai search offers 'gpu_name=RTX_4090 num_gpus=1 disk_space>=100 inet_down>=200' -o 'dph+'

# rent it with the pytorch image (use the offer ID from the search)
vastai create instance <OFFER_ID> --image pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel \
    --disk 100 --ssh

vastai show instances          # grab ssh host/port
ssh -p <PORT> root@<HOST>

tmux new -s train              # everything long-running lives in tmux

From your laptop, sync the repo and (if not already on the instance from Lesson 4) the token files:

rsync -avz --exclude data/raw --exclude data/extracted \
    wikillm/ root@<HOST>:/workspace/wikillm/ -e "ssh -p <PORT>"

On the instance, run a 60-step smoke config that exercises every code path — logging, eval, sampling, checkpointing:

cd /workspace/wikillm
pip install -r requirements.txt
wandb login   # paste your API key

python - <<'EOF'
import yaml

with open("configs/pretrain.yaml") as f:
    cfg = yaml.safe_load(f)
cfg.update(max_steps=60, warmup_steps=20, eval_interval=30,
           ckpt_interval=30, eval_iters=10, run_name="smoke-test")
with open("configs/smoke.yaml", "w") as f:
    yaml.dump(cfg, f)
EOF

python src/train.py configs/smoke.yaml

The first step takes 1–3 minutes (that’s torch.compile working, not a hang). Then steps should tick by at ~10s each, W&B should show all six metric series, checkpoints/ should contain ckpt_0000030.pt and ckpt_0000060.pt, and the samples table should contain enthusiastic gibberish. Total cost: ~15 minutes ≈ $0.10. Delete configs/smoke.yaml and the smoke checkpoints afterward so Lesson 7 starts clean:

rm -f configs/smoke.yaml checkpoints/ckpt_*.pt
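If you’d rather script the “checkpoints/ should contain …” check than eyeball it (run it before the rm above), here is a minimal sketch. It assumes the ckpt_{step:07d}.pt naming that ckpt_0000030.pt implies and one checkpoint every ckpt_interval steps; the helper names are hypothetical:

```python
from pathlib import Path

# Sketch: verify the smoke test's checkpoints exist and find the newest one.
# Assumptions: files are named ckpt_{step:07d}.pt and one is written every
# ckpt_interval steps; these helpers are hypothetical, not part of train.py.


def expected_checkpoints(max_steps, ckpt_interval):
    return [f"ckpt_{s:07d}.pt" for s in range(ckpt_interval, max_steps + 1, ckpt_interval)]


def latest_checkpoint(ckpt_dir):
    # Zero-padding makes lexicographic order equal numeric step order.
    files = sorted(Path(ckpt_dir).glob("ckpt_*.pt"))
    return files[-1] if files else None


def missing_checkpoints(ckpt_dir, max_steps, ckpt_interval):
    present = {p.name for p in Path(ckpt_dir).glob("ckpt_*.pt")}
    return [n for n in expected_checkpoints(max_steps, ckpt_interval) if n not in present]


if __name__ == "__main__":
    ckpt_dir = "checkpoints"
    print("missing:", missing_checkpoints(ckpt_dir, max_steps=60, ckpt_interval=30) or "none")
    print("newest:", latest_checkpoint(ckpt_dir))
```

Zero-padded step numbers are also what let a plain sorted(glob.glob(...)), as in the fire-drill solution, return the newest checkpoint last.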

🧪 Your task

Run a preemption fire drill: prove, before the 22-hour run, that a killed run resumes exactly. On the vast.ai instance: (1) regenerate configs/smoke.yaml with the snippet above (you deleted it after the smoke test), but with max_steps=120 and run_name="fire-drill", and launch it; (2) after step ~60, simulate vast.ai preemption by sending the process SIGTERM from a second tmux window; (3) verify that a checkpoint was written cleanly on shutdown; (4) relaunch with the identical command and confirm it resumes from the saved step; (5) confirm in W&B that the loss curve is one continuous line with no restart spike (a spike at the resume point would mean optimizer state wasn’t restored).

Solution
# window 1: launch (identical to any launch -- that's the point)
tmux new -s train
cd /workspace/wikillm
python src/train.py configs/smoke.yaml     # smoke.yaml with max_steps: 120, run_name: fire-drill

# window 2: wait until the logs pass step 60, then simulate preemption
tmux new-window -n ops                     # run inside the train session (or press Ctrl-b c)
pgrep -f "src/train.py"                    # find the PID
kill -TERM <PID>                           # exactly what vast.ai sends

In window 1 you should see the loop finish its current step and print:

stop signal received -- checkpoint saved at step 61, exiting cleanly
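That message comes from a stop-flag pattern: the SIGTERM handler only records the request, and the training loop checks the flag between optimizer steps, so the checkpoint is never taken mid-step. A minimal, PyTorch-free sketch of the pattern, assuming JSON as a stand-in for torch.save; StopFlag, save_atomic and train_loop are hypothetical names, not the course’s train.py:

```python
import json
import os
import signal
import tempfile

# Sketch of the stop-flag pattern: the handler only sets a flag; the loop checks
# it between optimizer steps, saves atomically, and exits.
# Assumptions: JSON stands in for torch.save so this runs without PyTorch; the
# names StopFlag, save_atomic and train_loop are hypothetical.


class StopFlag:
    def __init__(self):
        self.requested = False

    def __call__(self, signum, frame):
        self.requested = True              # no I/O inside the handler itself


def save_atomic(state, path):
    # Write a temp file in the same directory, then os.replace() it over the target:
    # readers see either the old file or the complete new one, never a partial write.
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path) or ".", suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)


def train_loop(max_steps, ckpt_dir, do_step):
    stop = StopFlag()
    previous = signal.signal(signal.SIGTERM, stop)
    try:
        step = 0
        while step < max_steps:
            do_step(step)                  # one full optimizer step, all micro-batches
            step += 1
            if stop.requested:             # only checked between steps
                path = os.path.join(ckpt_dir, f"ckpt_{step:07d}.pt")
                save_atomic({"step": step}, path)
                print(f"stop signal received -- checkpoint saved at step {step}, exiting cleanly")
                break
        return step
    finally:
        if previous is not None:
            signal.signal(signal.SIGTERM, previous)   # restore the old handler
```

To rehearse it without a GPU, pass a do_step that calls os.kill(os.getpid(), signal.SIGTERM) during step index 60: the loop completes that step, writes ckpt_0000061.pt atomically, prints the message above, and returns 61. Keeping the handler trivial matters because Python runs signal handlers in the main thread between bytecode instructions, so saving inside the handler could snapshot the model halfway through a step.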

Verify the checkpoint is complete and carries all five state components:

python - <<'EOF'
import torch
import glob
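ckpt_path = sorted(glob.glob("checkpoints/ckpt_*.pt"))[-1]   # zero-padded names sort by step
# weights_only=False: the file also holds non-tensor state (RNG, config); PyTorch 2.6+
# defaults to weights_only=True, which can refuse to unpickle some of it
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)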
assert set(ckpt) >= {"model", "optimizer", "step", "config", "rng"}, "missing state!"
assert len(ckpt["optimizer"]["state"]) > 0, "optimizer moments missing!"
print("checkpoint OK at step", ckpt["step"])
EOF

Relaunch with the same command:

python src/train.py configs/smoke.yaml
# -> resumed from checkpoints/ckpt_0000061.pt at step 61

It runs steps 62–120 and exits. Now check W&B: because we used a fixed run id with resume="allow", both processes wrote to the same run, and train/loss should be a single unbroken curve through step 61. There are two failure signatures to look for. A loss spike right after step 61 means optimizer state didn’t restore: fresh AdamW moments make the first updates roughly lr × sign(g) per parameter, full-size steps at the current, post-warmup learning rate, which is exactly the situation warmup exists to avoid. A duplicate or forked run in the W&B project means the resume-id logic is broken. Seeing neither is your license to launch the real thing.
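Why does lost optimizer state show up as a spike? Adam’s bias correction makes the very first update from zeroed moments about lr × sign(g) per parameter, however noisy the gradient is, while warmed-up moments average that noise away. A scalar sketch with a made-up alternating gradient stream (plain Adam, with AdamW’s decoupled weight decay omitted; β₂ = 0.95 as in this lesson):

```python
import math

# Sketch: size of one Adam update with restored vs. freshly zeroed moments.
# Assumptions: scalar parameter, plain Adam (decoupled weight decay omitted),
# beta2 = 0.95 as in this lesson; the gradient stream is synthetic.


def adam_step(g, m, v, t, lr, b1=0.9, b2=0.95, eps=1e-8):
    t += 1
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)              # bias correction
    v_hat = v / (1 - b2 ** t)
    return lr * m_hat / (math.sqrt(v_hat) + eps), m, v, t


def grad(i):
    return 0.1 + (1.0 if i % 2 == 0 else -1.0)   # small signal buried in noise


lr = 6e-4
m = v = 0.0
t = 0
for i in range(1000):                      # "before preemption": moments warm up
    _, m, v, t = adam_step(grad(i), m, v, t, lr)

restored, *_ = adam_step(grad(1000), m, v, t, lr)   # resume with saved moments
fresh, *_ = adam_step(grad(1000), 0.0, 0.0, 0, lr)  # resume with a fresh optimizer
print(f"restored: {abs(restored):.2e}  fresh: {abs(fresh):.2e}")
```

With these numbers the fresh-optimizer update comes out several times larger than the restored one. Applied to every parameter at once, that is the bump you would see in train/loss.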

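One more optional check for the rigorous: on a single GPU, the RNG restoration makes the resumed run bit-exact. To see it, first move the drill’s checkpoints aside (mkdir -p checkpoints_firedrill && mv checkpoints/ckpt_*.pt checkpoints_firedrill/), since train.py auto-resumes from the newest file in checkpoints/ and would otherwise resume from the drill’s final checkpoint instead of starting fresh. Then run 0→120 uninterrupted under a different run name and compare train/loss at step 100 between the two runs; they should match to all printed digits. Whichever checks you run, leave checkpoints/ empty before Lesson 7 so the real launch starts from step 0 instead of resuming a drill.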
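The mechanism behind that bit-exact claim, and behind the multi-GPU re-seeding caveat in the DDP section, fits in a few lines. A minimal sketch, assuming Python’s random.Random as a stand-in for the torch/NumPy generators that train.py stores in the checkpoint; sample_offsets and run_steps are hypothetical helpers and the token count is made up:

```python
import random

# Sketch: restoring RNG state replays the exact batches an uninterrupted run draws;
# re-seeding does not. Assumptions: random.Random stands in for the torch/NumPy
# generators; sample_offsets and run_steps are hypothetical helpers.

N_TOKENS, BLOCK, BATCH = 10_000_000, 1024, 16


def sample_offsets(rng):
    # Random start positions into the flat token array, as a memmap sampler draws them.
    return [rng.randrange(N_TOKENS - BLOCK - 1) for _ in range(BATCH)]


def run_steps(rng, n_steps):
    return [sample_offsets(rng) for _ in range(n_steps)]


uninterrupted = run_steps(random.Random(1337), 120)

rng = random.Random(1337)
before_kill = run_steps(rng, 61)
saved_state = rng.getstate()               # what goes into the checkpoint

resumed = random.Random()                  # a brand-new process after preemption
resumed.setstate(saved_state)
after_resume = run_steps(resumed, 120 - 61)

reseeded = run_steps(random.Random(1337), 120 - 61)   # re-seed instead of restore

print("restored state matches:", before_kill + after_resume == uninterrupted)
print("re-seeding matches:", reseeded == uninterrupted[61:])
```

The re-seeded generator replays the batches from step 0 rather than continuing from step 61, which is the “slightly different data order” a resumed multi-GPU run sees on the non-restored ranks.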

Key takeaways

  • Chinchilla-optimal is ~20 tokens/param (≈2.5B for 124M); we train to 4B (~32 tok/param) because over-training small models buys inference-time quality per dollar — and it’s ~1 epoch of our corpus.
  • The effective batch is ~0.5M tokens/step, built from 16×1024-token micro-batches × 32 gradient-accumulation steps; dividing each micro-loss by 32 makes accumulation mathematically identical to one giant batch.
  • bf16 needs no GradScaler: it shares fp32’s 8-bit exponent range, so small gradients don’t underflow the way they do in fp16, and loss scaling has nothing to fix. torch.compile adds 40–70% throughput for one line of code and a few minutes of first-step compile time.
  • AdamW at wd 0.1, β₂ 0.95, with no decay on 1-D params (only the RMSNorm gains, since the model has no biases); LR is a pure function of step — 2k warmup, cosine 6e-4 → 6e-5 — so restoring the step restores the schedule.
  • A checkpoint is five things: model, optimizer moments, step, config, and RNG state. Save atomically (tmp + os.replace), keep last k, auto-resume from the newest file so relaunch equals launch.
  • SIGTERM sets a flag; the loop finishes its current step, saves, and exits, so preemption costs you one step instead of up to 25 minutes. W&B’s fixed run id keeps the curve unbroken across restarts.
  • DDP is ~15 lines and zero hyperparameter changes — detect torchrun’s env vars, wrap the model, sync gradients only on the last micro-step, let rank 0 do the talking.
  • tokens/s is money: at ~50k tok/s, 4B tokens ≈ 22–23 hours ≈ $8–12 on a $0.35–0.45/hr 4090. Always smoke-test the full engine for $0.10 before committing $10.
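The “LR is a pure function of step” takeaway in code: nothing about the schedule lives in optimizer or scheduler state, so restoring step restores the learning rate. A sketch assuming linear warmup and a step budget derived from 4B tokens ÷ 524,288 tokens per step; get_lr and the constants are illustrative, and train.py’s own formula may differ in detail:

```python
import math

# Sketch: warmup + cosine LR as a pure function of the step.
# Assumptions: linear warmup; MAX_STEPS derived from 4B tokens / 524,288 tokens
# per step; get_lr is an illustrative name, not necessarily train.py's.

MAX_LR, MIN_LR = 6e-4, 6e-5              # cosine decays to 10% of peak
WARMUP_STEPS = 2000
MAX_STEPS = math.ceil(4e9 / (16 * 1024 * 32))   # 7630


def get_lr(step):
    if step < WARMUP_STEPS:
        return MAX_LR * (step + 1) / WARMUP_STEPS     # linear ramp to the peak
    if step >= MAX_STEPS:
        return MIN_LR
    progress = (step - WARMUP_STEPS) / (MAX_STEPS - WARMUP_STEPS)   # 0 -> 1
    return MIN_LR + 0.5 * (1 + math.cos(math.pi * progress)) * (MAX_LR - MIN_LR)
```

Resuming at step 61, or at step 6,000, just calls get_lr with the restored step; in this formulation there is no scheduler object whose internal counter could drift out of sync with the checkpoint.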

Coming up

In Lesson 7 we pull the trigger: launch the full 4B-token run on vast.ai, learn to babysit it from the W&B dashboard (and diagnose the loss spikes and throughput dips you might see), then evaluate the finished base model with eval_ppl.py and watch it write its first real paragraphs.



© Kader Mohideen