Kader Mohideen

In this chapter

  • 44.1 Tokenization: turning text into integers
  • 44.2 The GPU memory hierarchy: why bandwidth, not FLOPs, is the bottleneck
  • 44.3 FlashAttention: making attention IO-aware
  • 44.4 Scaling laws: spending a compute budget wisely
  • 44.5 Distributed training: splitting the work across many GPUs
  • 44.6 Mixture-of-Experts: more parameters, not more compute
  • 44.7 The training stack: precision, checkpointing, and memory math
  • 44.8 The data pipeline: web-scale curation
  • 44.9 Parameter-efficient fine-tuning: adapting without retraining
  • 44.10 Post-training: from base model to assistant
  • 44.11 Inference systems: serving the model efficiently
  • 44.1 — Which parallelism should I use?
  • 44.2 — Quick reference
  • 44.3 — Key takeaways
  • 44.4 — See also

Chapter 44 — 🏗️ LLM Systems: Building LLMs from Scratch

Chapter 23 told you what a large language model is — a Transformer decoder trained to predict the next token. This chapter is about the machinery that makes one actually exist: how raw text becomes tokens, why a GPU spends most of its time waiting on memory rather than computing, how you split a trillion-parameter model across thousands of chips, and how you serve it without melting your budget. The model architecture is the easy part; the systems around it are where frontier labs spend their effort.

🧭 In context: LLM engineering / infrastructure · used to train and serve frontier language models · key idea: the bottleneck is rarely the math — it is memory bandwidth, data quality, and how you shard the work.

💡 Remember this: building an LLM is a systems problem, not a math problem — the wins come from moving fewer bytes (tokenization, FlashAttention, KV cache, quantization), sharding work across many GPUs, and feeding the model clean data, not from a cleverer architecture.

44.1 Tokenization: turning text into integers

A neural network only ever sees numbers. Before any training happens, you need a fixed, finite vocabulary that maps text fragments to integer IDs. The naive options both fail: one ID per word leaves you helpless on typos and rare words (and explodes for languages with rich morphology), while one ID per character makes sequences punishingly long and forces the model to relearn spelling from scratch. The winning compromise is subword tokenization — common words stay whole, rare words break into pieces, and nothing is ever out-of-vocabulary.

The dominant algorithm is Byte-Pair Encoding (BPE). The idea is borrowed from a 1990s compression scheme: start with the smallest units you have, then repeatedly merge the most frequent adjacent pair into a new unit. Each merge becomes a vocabulary entry. You stop when you hit your target vocabulary size.

Intuition first: think of how a child learns to read. They don’t memorize every word as a whole picture, nor do they sound out every single letter forever. They learn chunks — -ing, -tion, un-, the — and snap new words together from familiar pieces. BPE discovers exactly those chunks automatically, just by counting which pairs of symbols show up next to each other most often.

Each merge step glues the most common neighboring pair into a single new symbol — watch l·o·w collapse into one token:

(Figure: the symbols l, o, w; the most frequent pair “l·o” is merged into one new token “lo”.)

A worked merge example

Take a tiny corpus where word counts are: low ×5, lower ×2, newest ×6, widest ×3. We start from single characters. (Some BPE implementations also append an end-of-word marker, such as ·, so the tokenizer can later glue words back together; we leave it out here to keep the counts small.) Count every adjacent pair across the corpus, weighted by word frequency:

| Step | Most frequent pair | Count | New token |
|---|---|---|---|
| 1 | e + s | 6+3 = 9 | es |
| 2 | es + t | 9 | est |
| 3 | l + o | 5+2 = 7 | lo |
| 4 | lo + w | 7 | low |

Ties go to the pair counted first: s + t also reaches 9 at step 1, and o + w ties l + o at step 3.

After four merges, newest tokenizes as n e w est, lowest (never seen in training) tokenizes cleanly as low est, and we have learned reusable chunks. That last point is the whole game: BPE generalizes to unseen words by reusing learned fragments.

from collections import Counter

def get_pairs(word):                       # word is a tuple of symbols
    return [(word[i], word[i+1]) for i in range(len(word)-1)]

def bpe_train(corpus, num_merges):
    # corpus: dict of word(str) -> count
    vocab = {tuple(w): c for w, c in corpus.items()}   # start from single characters
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for word, c in vocab.items():
            for p in get_pairs(word):
                pairs[p] += c              # weight by word frequency
        if not pairs: break
        best = pairs.most_common(1)[0][0]  # greedy: most frequent pair
        merges.append(best)
        # apply the merge everywhere
        new_vocab = {}
        for word, c in vocab.items():
            w, i = [], 0
            while i < len(word):
                if i < len(word)-1 and (word[i], word[i+1]) == best:
                    w.append(word[i] + word[i+1]); i += 2
                else:
                    w.append(word[i]); i += 1
            new_vocab[tuple(w)] = c
        vocab = new_vocab
    return merges

corpus = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
print(bpe_train(corpus, 4))
# [('e','s'), ('es','t'), ('l','o'), ('lo','w')]

At inference time you don’t recount anything — you just apply the learned merge list in order to each new word.
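A minimal sketch of that encoding step, replaying the four merges from the table above in the order they were learned. The helper name `bpe_encode` is our own illustration, not a library API, and it works on one pre-split word at a time.

```python
def bpe_encode(word, merges):
    """Tokenize one word by replaying learned merges in training order."""
    symbols = list(word)                      # start from single characters
    for left, right in merges:                # earliest-learned merge first
        merged, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and symbols[i] == left and symbols[i + 1] == right:
                merged.append(left + right)   # glue the pair into one symbol
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

# Merge list from the toy corpus above
merges = [("e", "s"), ("es", "t"), ("l", "o"), ("lo", "w")]
print(bpe_encode("lowest", merges))   # ['low', 'est']
print(bpe_encode("newest", merges))   # ['n', 'e', 'w', 'est']
```

Order matters: applying `("lo", "w")` before `("l", "o")` would never fire, because the symbol `lo` would not exist yet.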

Using a real tokenizer

In practice you almost never hand-roll BPE: you either train one with Hugging Face tokenizers (the fast Rust-backed library) or load an existing one. Here are both: training a byte-level BPE from scratch, and inspecting how a production tokenizer chops text.

# 1) Train a byte-level BPE tokenizer with Hugging Face `tokenizers`
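from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers

tok = Tokenizer(models.BPE())
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
tok.decoder = decoders.ByteLevel()           # map byte symbols back to text on decode
trainer = trainers.BpeTrainer(vocab_size=8000,
                              special_tokens=["<|endoftext|>"],
                              # seed all 256 byte symbols so no input is ever OOV
                              initial_alphabet=pre_tokenizers.ByteLevel.alphabet())
tok.train(["corpus.txt"], trainer)          # learns the merge list
print(tok.encode("lowest newest").tokens)   # subword pieces; spaces show up as 'Ġ'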

# 2) Inspect a production tokenizer (GPT-style, via tiktoken)
import tiktoken
enc = tiktoken.get_encoding("cl100k_base")   # GPT-4-class vocab
print(enc.encode("The number 1234 costs $5."))
print([enc.decode([t]) for t in enc.encode("1234")])  # how digits split
# e.g. ['123', '4'] — note arithmetic-unfriendly chunking

Byte-level BPE and vocabulary size

What happens with an emoji, a Korean character, or a byte sequence the corpus never contained? Character-level BPE would emit an unknown token. Byte-level BPE (used by GPT-2 onward) sidesteps this entirely: it runs BPE not over Unicode characters but over the 256 possible bytes. Any string in any language or encoding decomposes into bytes, so there is no out-of-vocabulary case, ever. The base vocabulary is exactly 256 symbols, and merges build up from there.
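A quick way to see why there is no out-of-vocabulary case: any string turns into a sequence of values in 0–255 once encoded. This sketch assumes UTF-8 as the byte encoding (what GPT-2-style tokenizers use); `to_byte_symbols` is a hypothetical helper name.

```python
def to_byte_symbols(text: str) -> list[int]:
    # Byte-level BPE's base alphabet is the 256 possible byte values,
    # so every string maps onto it through its UTF-8 encoding.
    return list(text.encode("utf-8"))

for s in ["low", "한", "😀"]:
    ids = to_byte_symbols(s)
    assert all(0 <= b < 256 for b in ids)
    assert bytes(ids).decode("utf-8") == s     # lossless round trip
    print(s, ids)
# low [108, 111, 119]
# 한 [237, 149, 156]
# 😀 [240, 159, 152, 128]
```

Merges then build larger symbols on top of these 256, so frequent multi-byte characters usually end up as single tokens while rare ones fall back to their raw bytes.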

Two close cousins worth naming are WordPiece (used by BERT) and Unigram (used by SentencePiece and the T5 family). WordPiece merges the pair that most increases corpus likelihood rather than the most frequent pair; Unigram goes the other way, starting from a large candidate vocabulary and pruning the tokens whose removal hurts likelihood least. SentencePiece is the tooling (it implements both BPE and Unigram) that makes subword tokenization language-agnostic: it treats the raw input (spaces included, encoded as ▁) as just another character, so it needs no pre-tokenizer and round-trips perfectly. For building an LLM today, byte-level BPE and Unigram-via-SentencePiece are the two defaults you actually choose between.
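Unigram's training loop prunes a vocabulary, but encoding with a finished Unigram model is a small dynamic program: pick the segmentation whose pieces have the highest total log-probability (a Viterbi search). A minimal sketch with made-up log-probabilities; `unigram_segment` is a hypothetical name, and SentencePiece's real implementation differs in detail.

```python
import math

def unigram_segment(word, logp):
    """Highest-probability segmentation of `word` under a unigram piece model."""
    n = len(word)
    best = [0.0] + [-math.inf] * n   # best[i]: best total log-prob of word[:i]
    back = [0] * (n + 1)             # back[i]: start index of the last piece
    for end in range(1, n + 1):
        for start in range(end):
            piece = word[start:end]
            if piece in logp and best[start] + logp[piece] > best[end]:
                best[end] = best[start] + logp[piece]
                back[end] = start
    if best[n] == -math.inf:
        raise ValueError("word cannot be segmented with this vocabulary")
    pieces, end = [], n
    while end > 0:                    # follow back-pointers from the end
        pieces.append(word[back[end]:end])
        end = back[end]
    return pieces[::-1]

# Hypothetical log-probabilities; single characters act as a fallback
logp = {c: -4.0 for c in "lowestnid"}
logp.update({"low": -2.0, "est": -2.5, "lo": -3.0, "es": -3.0})
print(unigram_segment("lowest", logp))   # ['low', 'est']
```

Unlike BPE there is no merge order to replay: the segmentation is whatever the probabilities favor, which is also what lets Unigram sample alternative segmentations during training.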

Vocabulary size is a genuine tradeoff:

| Vocab size | Sequence length | Embedding params | Word and script coverage |
|---|---|---|---|
| Small (~8K) | Long (more tokens per word) | Cheap | Poor: splits even common words |
| Medium (~50K) | Moderate | Moderate | Good (GPT-2/3 used ~50K) |
| Large (~128K–256K) | Short (fewer tokens per word) | Expensive | Excellent, multilingual |

A bigger vocabulary means each token carries more text, so sequences are shorter and you burn less compute per word of training data — but the embedding and output-projection matrices grow as \(V \times d_{model}\), and very rare tokens get too few gradient updates to learn good representations. Modern multilingual models trend toward 128K–256K to fairly represent non-English scripts.
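To make the \(V \times d_{model}\) cost concrete, here is a quick calculation assuming \(d_{model} = 4096\) (an illustrative hidden size, not a specific model) and separate, untied input-embedding and output-projection matrices, so the vocabulary costs \(2 V d_{model}\) parameters.

```python
def vocab_params(vocab_size, d_model=4096, tied=False):
    # input embedding (V x d) plus output projection (d x V), unless they share weights
    per_matrix = vocab_size * d_model
    return per_matrix if tied else 2 * per_matrix

for v in (8_000, 50_000, 128_000, 256_000):
    print(f"V = {v:>7,}: {vocab_params(v) / 1e6:>6,.0f}M parameters")
```

Under these assumptions a 256K vocabulary costs about 2.1B parameters on its own, a sizable slice of a 7B-class model; tying the two matrices halves that.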

Warning

Pitfall: tokenization quietly shapes everything downstream. A model is bad at arithmetic partly because 1234 might be split into the tokens 12 and 34; it struggles to reverse strings because it never sees characters; and per-token pricing means a verbose tokenizer for your language costs you real money. The tokenizer is frozen before training, so if you get it wrong you pay for the life of the model.

44.2 The GPU memory hierarchy: why bandwidth, not FLOPs, is the bottleneck

Here is the counterintuitive fact that governs LLM systems: a modern GPU can do arithmetic far faster than it can fetch the numbers to operate on. An NVIDIA A100 delivers roughly 312 teraFLOP/s of bf16 compute but only about 2 TB/s of memory bandwidth from its main memory. That gap means many operations finish their math and then sit idle, waiting for the next batch of data to arrive.

Everyday analogy: imagine a world-class chef (the compute units) who can chop and cook almost instantly, but ingredients arrive from the pantry down a single narrow hallway (the memory bus). If a recipe needs lots of trips to the pantry for very little cooking, the chef spends the day standing around waiting for deliveries. The fix is not a faster chef — it is doing more cooking per pantry trip.

GPU memory is a hierarchy, and the speeds differ by orders of magnitude:

  • Registers / SRAM (on-chip): ~20 MB, ~19 TB/s
  • HBM (GPU main memory): 40–80 GB, ~2 TB/s
  • CPU RAM / NVLink to other GPUs: 100s of GB, ~100s of GB/s

Smaller and faster at the top; larger and slower at the bottom.

SRAM is the tiny, blazing-fast scratchpad physically next to the compute cores. HBM (High-Bandwidth Memory) is the gigabytes of “GPU RAM” where your weights and activations live. Every operation must move data from HBM into SRAM, compute, and write results back. When the time spent moving data dominates the time spent computing, the kernel is memory-bound; when compute dominates, it is compute-bound.

Arithmetic intensity, worked

The deciding ratio is arithmetic intensity = FLOPs performed per byte read from HBM. Compare it to the hardware’s ratio of (peak FLOP/s) / (bandwidth). For the A100 that ceiling is \(312\text{e}12 / 2\text{e}12 \approx 156\) FLOPs/byte. An operation must do at least ~156 floating-point ops per byte it loads just to keep the compute units busy.

\[ \text{arithmetic intensity} = \frac{\text{FLOPs performed}}{\text{bytes moved from HBM}} \]

In words: how much useful math you squeeze out of each byte you bothered to fetch from main memory. Also written: \(I = \text{FLOPs} / \text{bytes}\); the op is compute-bound when \(I > \text{(peak FLOP/s)}/\text{(bandwidth)}\) and memory-bound otherwise.

Consider adding two vectors of length \(n\) (bf16, 2 bytes each): you read \(2n\) values, write \(n\), and do \(n\) additions. That is \(n\) FLOPs for \(6n\) bytes — an intensity of \(1/6\). You are nowhere near 156, so vector addition is hopelessly memory-bound; the GPU’s compute sits ~99% idle. A large matrix multiply, by contrast, reuses each loaded value across many multiply-adds and can reach intensities in the hundreds — that is why GPUs love big matmuls and hate elementwise ops.

The same picture drawn against the hardware ceiling is the roofline: below the ridge point you are bandwidth-limited (the slanted roof), above it you are FLOP-limited (the flat roof).

(Figure: roofline plot of attainable FLOP/s against arithmetic intensity in FLOP/byte. The ridge point sits at ≈ 156; vector add (⅙) lies deep in the memory-bound region and a big matmul in the compute-bound region.)

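The intensity arithmetic and the roofline cap fit in a few lines. This sketch reuses the A100 figures quoted above and assumes an idealized matmul in which each input matrix is read from HBM once and the output written once (perfect on-chip reuse); real kernels land somewhat below that.

```python
PEAK_FLOPS = 312e12                   # A100 bf16 peak (figure from the text)
HBM_BANDWIDTH = 2e12                  # ~2 TB/s (figure from the text)
RIDGE = PEAK_FLOPS / HBM_BANDWIDTH    # ~156 FLOP/byte

def vector_add_intensity(n, bytes_per_elem=2):
    flops = n                                   # one add per element
    bytes_moved = 3 * n * bytes_per_elem        # read a, read b, write c
    return flops / bytes_moved

def matmul_intensity(m, k, n, bytes_per_elem=2):
    flops = 2 * m * k * n                       # a multiply and an add per term
    # idealized: A and B read once, C written once
    bytes_moved = (m * k + k * n + m * n) * bytes_per_elem
    return flops / bytes_moved

def attainable_flops(intensity):
    # roofline: bandwidth caps you below the ridge, peak compute above it
    return min(PEAK_FLOPS, intensity * HBM_BANDWIDTH)

for name, i in [("vector add", vector_add_intensity(1 << 20)),
                ("4096^3 matmul", matmul_intensity(4096, 4096, 4096))]:
    kind = "compute-bound" if i > RIDGE else "memory-bound"
    print(f"{name}: {i:.2f} FLOP/byte, "
          f"{attainable_flops(i) / 1e12:.1f} TFLOP/s max ({kind})")
```

The vector add tops out near 0.3 TFLOP/s, about a thousandth of peak, while the square matmul clears the ridge by a wide margin.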
Tip

Intuition: the lesson of modern GPU programming is stop moving data. Fuse many small operations into one kernel so intermediate results stay in SRAM instead of round-tripping through HBM. Most LLM kernel engineering is, at heart, the art of touching HBM as few times as possible.
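A back-of-the-envelope count of what fusion saves, assuming a chain of unary elementwise ops (say scale, add bias, ReLU) whose scalar constants cost no meaningful traffic. Each unfused kernel reads its input from HBM and writes its output back; the fused kernel does one read and one write. The function names are illustrative.

```python
def hbm_bytes_unfused(n, num_ops, bytes_per_elem=2):
    # each elementwise kernel reads its input from HBM and writes its output back
    return num_ops * 2 * n * bytes_per_elem

def hbm_bytes_fused(n, bytes_per_elem=2):
    # one fused kernel: read the input once, keep intermediates on-chip, write once
    return 2 * n * bytes_per_elem

n = 1 << 24                                   # 16M bf16 activations
unfused = hbm_bytes_unfused(n, num_ops=3)     # e.g. scale, add bias, ReLU as 3 kernels
fused = hbm_bytes_fused(n)
print(f"unfused {unfused / 2**20:.0f} MiB, fused {fused / 2**20:.0f} MiB, "
      f"{unfused / fused:.0f}x less traffic")
```

Because each of these ops is memory-bound, cutting HBM traffic by the number of fused ops translates almost directly into the same speedup.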

44.3 FlashAttention: making attention IO-aware

Standard self-attention is a textbook example of a needlessly memory-bound operation. For a sequence of length \(N\), the naive implementation computes the score matrix \(S = QK^\top\) (size \(N \times N\)), writes it to HBM, reads it back to apply softmax, writes the probabilities, and reads them again to multiply by \(V\). That \(N \times N\) matrix is the problem: for \(N = 8192\) it holds about 67 million entries per attention head, written and re-read several times. The expensive part is not the math; it is shuttling that giant matrix in and out of HBM.
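To put sizes on that matrix, here is a quick calculation assuming bf16 scores (2 bytes each), a batch of one, and 32 attention heads (an illustrative head count, not a specific model).

```python
def score_matrix_bytes(seq_len, n_heads=32, batch=1, bytes_per_elem=2):
    # one N x N score matrix per head per sequence, stored in bf16
    return batch * n_heads * seq_len * seq_len * bytes_per_elem

for n in (2048, 8192, 32768):
    per_head = score_matrix_bytes(n, n_heads=1)
    total = score_matrix_bytes(n)
    print(f"N={n:>6}: {per_head / 2**20:,.0f} MiB per head, "
          f"{total / 2**30:,.1f} GiB for 32 heads")
```

At \(N = 8192\) that is 128 MiB per head and 4 GiB per layer under these assumptions, and the cost grows quadratically with context length.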

Intuition first: imagine grading a giant pile of exams where you only need the class average and the top scorer. The wasteful way is to lay every exam out on a huge table (you need a warehouse), then walk the whole table again to average. The smart way is to process exams in small stacks that fit on your desk, keeping just a running total and a running maximum — you never need the whole pile spread out at once, and you get the exact same answer. FlashAttention is that running-total trick applied to attention.

FlashAttention computes the exact same result without ever materializing the full \(N \times N\) matrix in HBM. The trick is tiling: split \(Q\), \(K\), \(V\) into blocks that fit in SRAM, and process the attention block by block, keeping a running softmax as you go.

flowchart LR
    A["Q, K, V in HBM"] --> B["Load a block of Q,K,V into SRAM"]
    B --> C["Compute block scores QKᵀ in SRAM"]
    C --> D["Update running max + sum<br/>(online softmax)"]
    D --> E["Accumulate weighted V in SRAM"]
    E --> F{"More K,V blocks?"}
    F -->|yes| B
    F -->|no| G["Write final output to HBM once"]

The mathematical enabler is the online softmax, and the plain-English version is short: keep a running answer, and whenever a new block contains a bigger score than anything you’ve seen, shrink the old running total to the new scale before adding the block in. You carry two running numbers, the biggest score so far (\(m\)) and the sum of the exponentiated weights so far (\(\ell\)), and each block does one rescale-then-add. Because every exponent is taken after subtracting the largest score seen so far, no exponential ever overflows, and the final answer is mathematically identical to a softmax over the whole row at once (it differs only by floating-point rounding). The whole pile of scores never has to exist at the same time.

The block-by-block sweep, with each pass keeping only a tile in fast SRAM:

(Figure: K,V blocks 1 through 5 stream through SRAM one at a time while the running max \(m\) and running sum \(\ell\) are updated; the full N×N matrix is never stored, and the output is written to HBM once at the end.)

The softmax itself, written in the numerically stable shifted form FlashAttention uses block-by-block:

\[ \text{softmax}(s_j) = \frac{e^{s_j - m}}{\sum_k e^{s_k - m}}, \qquad m = \max_k s_k \]

In words: subtract the largest score before exponentiating so nothing overflows, then divide each by the total — and that running max/total can be updated one block at a time. Also written: \(\text{softmax}(s)_j = e^{s_j} / \sum_k e^{s_k}\) (the unshifted form; subtracting \(m\) leaves the result unchanged because the constant cancels top and bottom).

import numpy as np
# online softmax over two blocks == one softmax over the concatenation
def online(blocks):
    m, l, acc = -np.inf, 0.0, 0.0          # running max, normalizer, sum(p*v)
    for s, v in blocks:                     # s: scores, v: values (toy: scalars)
        bm = s.max()
        new_m = max(m, bm)
        l = l*np.exp(m-new_m) + np.exp(s-new_m).sum()
        acc = acc*np.exp(m-new_m) + (np.exp(s-new_m)*v).sum()
        m = new_m
    return acc / l

s = np.array([1.0, 3.0, 2.0, 0.5]); v = np.array([10., 20., 30., 40.])
ref = (np.exp(s-s.max())/np.exp(s-s.max()).sum() * v).sum()
got = online([(s[:2], v[:2]), (s[2:], v[2:])])   # split into 2 blocks
assert abs(ref - got) < 1e-9                       # exact, not approximate
print(round(got, 6))

The payoff is dramatic. Naive attention needs \(O(Nd + N^2)\) HBM accesses, where \(d\) is the head dimension; FlashAttention needs \(O(N^2 d^2 / M)\), where \(M\) is the SRAM size. Since \(d^2\) is typically many times smaller than \(M\), that is often a 5–20× reduction in memory traffic and a 2–4× wall-clock speedup, with lower memory use because the \(N \times N\) matrix is never stored. It is exact, not an approximation. This single IO-aware rewrite is much of what made long context windows economically feasible.

Using FlashAttention in practice

You rarely call FlashAttention by hand — modern PyTorch ships it behind a single function, and Hugging Face enables it with a config flag.

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q); v = torch.randn_like(q)

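# Plain call: PyTorch dispatches to a fused FlashAttention kernel when shapes/dtypes/GPU allow
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Optional: restrict SDPA to the flash backend (raises instead of falling back silently)
from torch.nn.attention import SDPBackend, sdpa_kernel   # recent PyTorch (2.3+)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)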

# Hugging Face: opt in per model
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "meta-llama/Llama-3.1-8B",
#     attn_implementation="flash_attention_2",
#     torch_dtype=torch.bfloat16)

A note on Triton kernels

FlashAttention is written as a custom GPU kernel — a low-level program that controls exactly which bytes move between HBM and SRAM. Historically that meant CUDA C++, which is painful. Triton (an OpenAI-originated Python-like language) lets you write fused kernels at the block level — you describe how to tile the data and Triton handles the register allocation and memory coalescing. Much modern kernel work, including teaching versions of FlashAttention, is written in Triton precisely because it makes IO-aware programming accessible without dropping all the way to CUDA.

44.4 Scaling laws: spending a compute budget wisely

Suppose someone hands you a fixed compute budget — say, the GPU-hours for \(10^{22}\) floating-point operations. Should you train a huge model on a little data, or a smaller model on a lot of data? Scaling laws answer this empirically: they are power-law relationships showing that loss falls predictably and smoothly as you increase parameters \(N\), data \(D\), and compute \(C\).

The first influential study (Kaplan et al., 2020, OpenAI) concluded that, given more compute, you should mostly make the model bigger and increase data only modestly. Models of that era (including GPT-3 at 175B parameters trained on ~300B tokens) followed this prescription and ended up large and comparatively data-starved.

Two years later, Hoffmann et al., 2022 — the Chinchilla paper (DeepMind) — re-ran the analysis more carefully and overturned the recommendation. Their finding: for compute-optimal training, parameters and tokens should scale in roughly equal proportion. A useful rule of thumb that fell out of it is about 20 tokens per parameter. GPT-3 was badly undertrained by this standard; Chinchilla (70B params, 1.4T tokens) beat the much larger Gopher (280B) using the same compute.

The Chinchilla fit itself is a clean, memorable form — loss decomposes into an irreducible floor plus two power-law terms, one that shrinks with parameters and one that shrinks with data:

\[ L(N, D) = E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}} \]

In words: the loss you reach is an unbeatable floor \(E\) (the inherent randomness of language) plus a penalty for having too few parameters plus a penalty for having seen too little data — and both penalties shrink as smooth power laws. Also written: in log space the two penalty terms are straight lines, \(\log(L - E) \approx \log A - \alpha \log N\) when data is plentiful (and symmetrically \(\log B - \beta \log D\) when parameters are plentiful) — which is why scaling-law plots are log-log straight lines.

The fitted exponents, \(\alpha \approx 0.34\) and \(\beta \approx 0.28\), are close enough that the optimal allocation splits compute roughly evenly between growing \(N\) and growing \(D\). \(E\) is the entropy of natural text, which no amount of training removes.
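Plugging numbers into the fit shows why Chinchilla beat Gopher. This sketch uses the parametric fit commonly quoted from Hoffmann et al. (2022), \(E = 1.69\), \(A = 406.4\), \(B = 410.7\); treat the constants as illustrative, since they were fitted to one specific dataset and loss.

```python
# Parametric fit as commonly quoted from the Chinchilla paper (illustrative)
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28

def chinchilla_loss(n_params, n_tokens):
    # irreducible floor + too-few-parameters penalty + too-little-data penalty
    return E + A / n_params**ALPHA + B / n_tokens**BETA

chinchilla = chinchilla_loss(70e9, 1.4e12)   # 70B params, 1.4T tokens
gopher = chinchilla_loss(280e9, 300e9)       # 280B params, 300B tokens
print(f"Chinchilla ≈ {chinchilla:.3f}, Gopher ≈ {gopher:.3f}")
```

Gopher's larger model shrinks the parameter penalty, but its much smaller token count leaves a bigger data penalty, so the smaller, longer-trained model lands at lower predicted loss.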

The standard approximation for the training cost of a dense Transformer is its FLOP count:

\[ C \approx 6 N D \]

In words: total training cost is about six floating-point operations for every (parameter, token) pair — two for the forward pass and four for the backward pass. Also written: \(C \approx C_{\text{fwd}} + C_{\text{bwd}} = 2ND + 4ND\), so the FLOPs scale linearly in both how big the model is and how much data it sees.

Here \(C\) is total training FLOPs, \(N\) is the parameter count, and \(D\) is the number of training tokens.

Worked: allocating a budget

Take \(C = 10^{22}\) FLOPs. The Chinchilla heuristic says set \(D \approx 20N\). Substitute into \(C = 6ND\):

\[ 10^{22} = 6 \cdot N \cdot 20N = 120\,N^2 \;\Rightarrow\; N^2 = \frac{10^{22}}{120} \approx 8.3\text{e}19 \]

\[ N \approx 9.1\text{e}9 \approx 9\text{ billion parameters}, \qquad D \approx 20N \approx 1.8\text{e}11 = 180\text{ billion tokens.} \]

C = 1e22
# Chinchilla: D = 20 N, and C = 6 N D  ->  C = 120 N^2
N = (C / 120) ** 0.5
D = 20 * N
print(f"params ≈ {N/1e9:.1f}B, tokens ≈ {D/1e9:.0f}B")
# params ≈ 9.1B, tokens ≈ 183B

So a ~9B model on ~180B tokens is compute-optimal for that budget.

Warning

Pitfall: “compute-optimal” optimizes training cost only. If a model will serve billions of inference requests, it is often worth deliberately overtraining a smaller model far past the 20:1 point — it costs more to train but is cheaper to run forever after. This is exactly why many deployed open models (the Llama family, for instance) are trained on trillions of tokens well beyond Chinchilla-optimal: inference economics, not training economics, win.

44.5 Distributed training: splitting the work across many GPUs

A frontier model does not fit on one GPU — not the parameters, not the optimizer state, not the activations. Training is spread across thousands of GPUs using several orthogonal forms of parallelism that compose together. The art is combining them to keep every GPU busy while minimizing the communication between them.

flowchart TB
    subgraph DP["Data Parallel (replicas)"]
      direction LR
      subgraph R1["Replica 1"]
        direction TB
        P1A["Pipeline stage 1<br/>layers 1–8<br/>(tensor-sharded)"]
        P1B["Pipeline stage 2<br/>layers 9–16<br/>(tensor-sharded)"]
        P1A --> P1B
      end
      subgraph R2["Replica 2"]
        direction TB
        P2A["Pipeline stage 1<br/>layers 1–8"]
        P2B["Pipeline stage 2<br/>layers 9–16"]
        P2A --> P2B
      end
    end
    R1 -. "all-reduce gradients" .- R2

Data parallelism (DP) is the simplest: replicate the whole model on each GPU, give each a different slice of the batch, and after the backward pass all-reduce the gradients so every replica stays in sync. It scales throughput but requires the model to fit on one device — which for large models it does not.
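Why averaging after an all-reduce keeps replicas in sync: with a mean loss and equal-size slices, the average of the per-replica gradients equals the full-batch gradient. A minimal NumPy sketch that simulates 4 replicas on a toy linear-regression loss (the model and data are made up).

```python
import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(64, 5)), rng.normal(size=64)
w = rng.normal(size=5)                        # identical weights on every replica

def grad_mse(Xb, yb, w):
    # gradient of mean((Xb @ w - yb) ** 2) with respect to w
    return 2 * Xb.T @ (Xb @ w - yb) / len(yb)

full = grad_mse(X, y, w)                      # what one giant GPU would compute
shards = np.split(np.arange(64), 4)           # 4 replicas, equal-size batch slices
local = [grad_mse(X[idx], y[idx], w) for idx in shards]
synced = sum(local) / len(local)              # all-reduce (sum), then divide by replica count
assert np.allclose(full, synced)
print(np.round(synced, 4))
```

Every replica applies the same `synced` gradient to the same weights, so they stay bit-identical step after step without ever exchanging parameters.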

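Tensor parallelism (TP), also called intra-layer parallelism, splits individual matrices across GPUs. A weight matrix \(W\) in a feed-forward layer can be cut into column shards \([W_1, W_2]\): each GPU computes \(xW_i\) on the full input, and the output slices are concatenated (an all-gather). Alternatively \(W\) is cut into row shards, each GPU multiplies its slice of the input by its slice of \(W\), and the partial results are summed with an all-reduce; Megatron-style MLPs pair a column split with a row split so only one all-reduce is needed in the MLP's forward pass. This splits both compute and memory within a layer, but the per-layer communication is heavy, so TP is normally confined to GPUs inside one node connected by fast NVLink.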
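Both sharding patterns reproduce the unsharded product exactly. A NumPy sketch that simulates two "GPUs" as two array slices (sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))        # 4 tokens, hidden size 8
W = rng.normal(size=(8, 6))

# Column parallel: each GPU holds some columns of W and sees the full x
W1, W2 = np.split(W, 2, axis=1)
col_out = np.concatenate([x @ W1, x @ W2], axis=1)   # gather the output slices

# Row parallel: each GPU holds some rows of W and the matching slice of x
Wa, Wb = np.split(W, 2, axis=0)
xa, xb = np.split(x, 2, axis=1)
row_out = xa @ Wa + xb @ Wb                          # all-reduce = sum of partials

assert np.allclose(col_out, x @ W) and np.allclose(row_out, x @ W)
```

A column-parallel output is already split along the hidden dimension, which is exactly the input layout a following row-parallel layer wants; that is why the pairing needs only the final sum.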

Pipeline parallelism (PP) splits the model by layers: GPU 0 holds layers 1–8, GPU 1 holds layers 9–16, and so on. Activations flow forward stage to stage, gradients flow back. The danger is the pipeline bubble — stages idling while waiting for the first microbatch to reach them. The fix is to chop the batch into many microbatches so stages stay busy in an overlapping schedule (GPipe, 1F1B).

The bubble is worth seeing concretely. With \(p\) pipeline stages and \(m\) microbatches, the fraction of time a stage sits idle is roughly \((p-1)/(m + p - 1)\) — so the cure is simply more microbatches:

\[ \text{bubble fraction} = \frac{p-1}{m + p - 1} \]

In words: the share of wasted GPU time is the number of “warm-up/drain” stages over the total number of scheduled steps — pour in more microbatches \(m\) and the waste vanishes. Also written: as \(m \to \infty\), \(\frac{p-1}{m+p-1} \to 0\); e.g. with \(p=4\) stages, \(m=4\) microbatches wastes \(3/7 \approx 43\%\), but \(m=32\) wastes only \(3/35 \approx 9\%\).
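The same formula as code, sweeping the number of microbatches for a 4-stage pipeline:

```python
def bubble_fraction(p, m):
    """Idle share of a GPipe-style schedule with p stages and m microbatches."""
    return (p - 1) / (m + p - 1)

for m in (1, 4, 8, 32, 128):
    print(f"p=4, m={m:>3}: {bubble_fraction(4, m):.1%} idle")
```

With a single microbatch three of the four stages are always waiting (75% idle); the practical limit on \(m\) is that each microbatch must stay large enough to keep the per-stage matmuls efficient.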

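(Figure: pipeline schedule for stages 1–4, distinguishing useful work from idle bubble slots while the pipeline fills and drains; with more microbatches the stretch of useful work grows and the bubble’s share shrinks.)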

Sequence / context parallelism splits along the sequence length dimension, so a million-token context can be spread across devices that each hold a slice of the positions. This is what makes very long context windows trainable when activations for the full sequence would not fit on one GPU.

In practice these compose as a grid: a frontier run might use 8-way tensor parallelism within a node, 12-way pipeline parallelism across nodes, and data parallelism on top of that, a combination often called 3D parallelism.
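One way to picture the grid is as a mapping from a GPU's global rank to its coordinates. This sketch assumes tensor parallelism is the fastest-varying dimension (so each TP group is 8 consecutive ranks, i.e. one 8-GPU node), a common convention but an assumption here; the data-parallel degree of 4 is hypothetical.

```python
TP, PP, DP = 8, 12, 4            # 8 x 12 x 4 = 384 GPUs (DP degree is hypothetical)

def rank_to_coords(rank, tp=TP, pp=PP, dp=DP):
    """Map a global rank to (dp, pp, tp) indices, tensor-parallel fastest-varying."""
    if not 0 <= rank < tp * pp * dp:
        raise ValueError("rank out of range")
    tp_rank = rank % tp              # 8 consecutive ranks share a TP group (one node, assumed)
    pp_rank = (rank // tp) % pp      # next: which pipeline stage
    dp_rank = rank // (tp * pp)      # outermost: which data-parallel replica
    return dp_rank, pp_rank, tp_rank

for r in (0, 7, 8, 95, 96, 383):
    print(r, rank_to_coords(r))
```

Putting TP innermost keeps its heavy per-layer traffic on NVLink, while the lighter pipeline and data-parallel traffic crosses the slower links between nodes.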

Collective communication: the primitives underneath

Every form of parallelism above is built on a handful of collective operations — group communication patterns implemented by NCCL on NVIDIA hardware. Three matter most:

| Collective | What it does | Used by |
|---|---|---|
| all-reduce | Sum a tensor across all GPUs; everyone gets the result | DP gradient sync |
| all-gather | Each GPU has a shard; everyone ends with the full concatenation | FSDP weight gather, TP |
| reduce-scatter | Sum across GPUs, but each GPU keeps only its shard of the result | FSDP gradient reduction |

Intuition: all-reduce is “everyone shouts their number, everyone walks away knowing the total.” A clever ring implementation moves only about \(2(k-1)/k\) times the tensor size per GPU regardless of how many GPUs \(k\) you have, which is why gradient sync stays affordable at scale. Note the identity that ZeRO exploits: all-reduce = reduce-scatter then all-gather — so sharding the optimizer step (next subsection) does not add a fundamentally new communication cost, it just splits the all-reduce into its two halves.
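The identity all-reduce = reduce-scatter then all-gather can be checked by simulating the collectives on lists of NumPy arrays, one entry per "GPU". These helpers are toy stand-ins, not NCCL calls, and they assume the tensor length divides evenly by the number of GPUs.

```python
import numpy as np

k = 4                                              # simulated GPUs
rng = np.random.default_rng(0)
local = [rng.normal(size=8) for _ in range(k)]     # each GPU's gradient

def reduce_scatter(tensors):
    # GPU r ends up holding only the sum of everyone's r-th chunk
    chunks = [np.split(t, len(tensors)) for t in tensors]
    return [sum(c[r] for c in chunks) for r in range(len(tensors))]

def all_gather(shards):
    # every GPU ends up with the concatenation of all shards
    full = np.concatenate(shards)
    return [full.copy() for _ in shards]

all_reduced = all_gather(reduce_scatter(local))
assert all(np.allclose(t, sum(local)) for t in all_reduced)

# Ring all-reduce traffic per GPU, as a multiple of the tensor size
for k_gpus in (2, 8, 1024):
    print(k_gpus, 2 * (k_gpus - 1) / k_gpus)
```

The traffic ratio creeps toward 2 but never exceeds it, which is the "regardless of how many GPUs" property above.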

ZeRO / FSDP: sharding the optimizer state

Even with the model replicated, a huge hidden cost is the optimizer state. Training with Adam in mixed precision, each parameter drags along: an fp32 master copy of the weight (4 bytes), an fp32 momentum (4 bytes), and an fp32 variance (4 bytes) — 12 bytes of optimizer state per parameter, on top of the bf16 weight and gradient. For a 7B model that is ~84 GB of optimizer state alone, far past a single GPU.

ZeRO (Zero Redundancy Optimizer, from DeepSpeed) and the equivalent FSDP (Fully Sharded Data Parallel, in PyTorch) notice that under plain data parallelism every replica stores an identical copy of all that state — pure redundancy. They shard it across the data-parallel group instead:

| Stage | What is sharded | Memory per GPU |
|---|---|---|
| ZeRO-1 | Optimizer states | Big drop |
| ZeRO-2 | + gradients | Bigger drop |
| ZeRO-3 / FSDP | + parameters themselves | Shrinks as 1/#GPUs |
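The same table with numbers, using the per-parameter byte counts from the previous paragraph (2 bytes bf16 weight, 2 bytes bf16 gradient, 12 bytes fp32 optimizer state). This sketch ignores activations, temporary all-gather buffers, and fragmentation, so real usage is higher; the 7B model on 64 GPUs is an illustrative setup, and "ZeRO-0" means plain data parallelism.

```python
def zero_memory_gb(n_params, n_gpus, stage, opt_bytes=12):
    weights = 2 * n_params          # bf16 weights
    grads = 2 * n_params            # bf16 gradients
    opt = opt_bytes * n_params      # fp32 master weights + Adam momentum + variance
    if stage >= 1:
        opt /= n_gpus               # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= n_gpus             # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= n_gpus           # ZeRO-3 / FSDP: also shard parameters
    return (weights + grads + opt) / 1e9

for stage in range(4):
    print(f"ZeRO-{stage}: {zero_memory_gb(7e9, 64, stage):6.1f} GB per GPU")
```

Model state drops from 112 GB per GPU (impossible on an 80 GB card) to under 2 GB at stage 3, which is what frees memory for activations and larger batches.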

At ZeRO-3 / FSDP, each GPU permanently holds only \(1/k\) of the parameters (for \(k\) GPUs); when a layer is needed, its full weights are all-gathered just in time, used, then discarded. The tradeoff is more communication for far less memory — letting you train models that would otherwise be impossible, at the cost of bandwidth. (This connects to the AI infrastructure discussion in Ch 30.)

In PyTorch this is a thin wrapper around your model — the framework inserts the all-gather/reduce-scatter automatically:

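import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# One process per GPU (e.g. launched with torchrun, which sets LOCAL_RANK)
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

mp = MixedPrecision(param_dtype=torch.bfloat16,      # compute in bf16
                    reduce_dtype=torch.float32)       # reduce grads in fp32

# `model` (an nn.Module) and `loader` are assumed to be defined elsewhere
model = FSDP(model,
             sharding_strategy=ShardingStrategy.FULL_SHARD,  # = ZeRO-3
             mixed_precision=mp,
             device_id=torch.cuda.current_device())

# Build the optimizer AFTER wrapping so it sees the sharded parameters (lr is illustrative)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Activation checkpointing pairs naturally with FSDP (see 44.7)
# from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
#     apply_activation_checkpointing)

# The training loop looks completely ordinary; FSDP handles the sharding:
for batch in loader:
    loss = model(**batch).loss
    loss.backward()        # reduce-scatter gradients across shards
    optimizer.step()       # each rank updates only its 1/k of params
    optimizer.zero_grad()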

44.6 Mixture-of-Experts: more parameters, not more compute

Every dense model pays the same price for every token: all parameters fire on every input. Mixture-of-Experts (MoE) breaks that link. Replace the single feed-forward block in a Transformer layer with \(E\) parallel feed-forward “experts” plus a small router network. For each token, the router picks the top-\(k\) experts (often \(k=2\)) and sends the token only to those. The other experts stay dark for that token.

Everyday analogy: a hospital does not send every patient to every doctor. A triage nurse (the router) glances at each patient and routes them to the two most relevant specialists. The hospital can employ a hundred specialists (huge total capacity) while any single patient only ever consumes two consultations (small per-patient cost). MoE gives a model a large staff of specialists but only “consults” a couple per token.

The consequence is the headline property: you grow total parameters (knowledge capacity) without growing the active compute per token. With 8 experts of which 2 fire, each the size of the original dense FFN, a layer holds 8× the FFN parameters but spends only ~2× the FFN compute per token, so it carries 4× more FFN parameters than any single token uses. Mixtral 8×7B, for example, has ~47B total parameters but activates only ~13B per token.

Where this is used: MoE is no longer exotic. It is how several frontier and open models hit their quality-per-FLOP. Mixtral 8×7B (8 experts, top-2) runs at roughly the per-token compute of a ~13B dense model while carrying ~47B parameters’ worth of knowledge. DeepSeek and Qwen ship MoE variants with dozens to hundreds of fine-grained experts. The trade is always the same: you pay in memory and routing complexity to buy capacity without a proportional increase in per-token compute.

flowchart LR
    T["token hidden state"] --> R["Router<br/>(linear → softmax)"]
    R -->|top-2 scores| G{"gate"}
    G -->|g₂| E2["Expert 2 (FFN)"]
    G -->|g₄| E4["Expert 4 (FFN)"]
    E1["Expert 1"]:::off
    E3["Expert 3"]:::off
    E5["Expert 5"]:::off
    E2 --> S["weighted sum"]
    E4 --> S
    S --> O["layer output"]
    classDef off fill:#eee,stroke:#bbb,color:#999;

The MoE layer output for a token \(x\) is the gated sum over only the chosen experts:

\[ y = \sum_{i \in \text{top-}k} g_i \, E_i(x), \qquad g = \text{softmax}\big(\text{top-}k(x W_r)\big) \]

In words: run the token through the few experts the router picked, then blend their outputs weighted by the router’s confidence in each. Also written: \(y = \sum_{i=1}^{E} g_i\,E_i(x)\) where \(g_i = 0\) for every expert outside the top-\(k\) — a sum over all experts in which all but \(k\) gates are zeroed out.
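The routing equation can be sketched in NumPy as below. This is a minimal illustration under stated assumptions, not Mixtral's or any library's implementation. The sizes, the tiny ReLU experts and the random router weights `W_r` are all hypothetical, and a real layer would batch tokens per expert instead of looping over them. The sketch follows the top-k-then-softmax form above and counts how many expert calls actually happen.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, h, E, k = 6, 8, 16, 4, 2      # tokens, model dim, expert hidden dim, experts, top-k (toy sizes)

def make_expert(rng, d, h):
    """A tiny two-layer ReLU feed-forward block standing in for one expert FFN."""
    W1 = rng.normal(scale=0.1, size=(d, h))
    W2 = rng.normal(scale=0.1, size=(h, d))
    return lambda v: np.maximum(v @ W1, 0.0) @ W2

experts = [make_expert(rng, d, h) for _ in range(E)]
W_r = rng.normal(size=(d, E))       # router: one linear layer producing E scores per token

def moe_forward(x, W_r, experts, k):
    logits = x @ W_r                                  # (T, E) router scores
    top = np.argsort(logits, axis=1)[:, -k:]          # the k highest-scoring experts per token
    y = np.zeros_like(x)
    calls = np.zeros(len(experts), dtype=int)
    for t in range(x.shape[0]):
        s = logits[t, top[t]]
        g = np.exp(s - s.max())
        g /= g.sum()                                  # softmax over the top-k scores only
        for g_i, i in zip(g, top[t]):
            y[t] += g_i * experts[i](x[t])            # only the chosen experts run
            calls[i] += 1
    return y, top, calls

x = rng.normal(size=(T, d))
y, top, calls = moe_forward(x, W_r, experts, k)
print("experts per token:\n", top)
print("expert calls:", calls, "total:", calls.sum(), "= T*k, not T*E")
```

Only `T*k` expert calls happen, so per-token FFN compute scales with k while parameter count scales with E.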

The load-balancing problem

Left alone, the router collapses: it learns to favor a few experts, those get more gradient, get better, attract even more tokens — a rich-get-richer death spiral that leaves most experts untrained and wastes the parameters you paid for. The standard fix is an auxiliary load-balancing loss added to the training objective, which penalizes uneven routing and pushes traffic toward a uniform spread across experts.

A common formulation (Switch Transformer) for \(E\) experts over a batch is:

\[ \mathcal{L}_{aux} = \alpha \, E \sum_{i=1}^{E} f_i \, P_i \]

In words: for each expert, multiply the fraction of tokens it actually received by the average probability the router gave it, then sum across experts. When the router favors an expert, both factors grow together. The sum is therefore large under collapse and falls to its designed value under perfectly uniform routing, so minimizing it spreads the load. Also written: \(\mathcal{L}_{aux} = \alpha\,E\,\langle f, P\rangle\), the (scaled) dot product of the routing-fraction vector \(f\) and the mean-probability vector \(P\). At uniform routing, \(f_i = P_i = 1/E\), it equals \(\alpha\).

where \(f_i\) is the fraction of tokens actually routed to expert \(i\) and \(P_i\) is the average router probability assigned to expert \(i\). The product is minimized when both are uniform (\(1/E\) each), so the loss nudges the router toward balanced utilization; \(\alpha\) is a small weight (e.g. 0.01).

import numpy as np
# 6 tokens, 4 experts: router probabilities (rows sum to 1)
P = np.array([[.7,.1,.1,.1],[.6,.2,.1,.1],[.65,.15,.1,.1],
              [.5,.3,.1,.1],[.55,.25,.1,.1],[.6,.2,.1,.1]])
top1 = P.argmax(1)                         # everyone picks expert 0 -> imbalance
f = np.bincount(top1, minlength=4) / len(P)   # fraction routed
Pbar = P.mean(0)                              # mean router prob
aux = 4 * np.sum(f * Pbar)
print("frac routed:", f, " aux loss:", round(aux,3))
# frac routed: [1. 0. 0. 0.]  aux loss: 2.4   <- far above the 1.0 of perfect balance

The aux loss of 2.4 (versus 1.0 for perfectly uniform routing) is exactly the signal that pushes the router to stop dumping every token on expert 0. Real MoE systems also give each expert a fixed expert capacity: a cap on how many tokens it processes per batch, which keeps the expert computations at static shapes. Tokens that overflow a full expert are simply dropped. They skip the expert FFN and pass through on the residual connection. Uneven routing therefore directly throws away computation, which is another reason the balancing loss earns its keep.
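The capacity rule can be sketched as follows. This is a hypothetical top-1 simplification in the spirit of Switch-style routing. The `capacity_factor` of 1.25, the use of `math.ceil`, and first-come-first-served slot filling are assumptions (implementations differ), and the helper names are invented for illustration.

```python
import math
import numpy as np

def apply_capacity(assign, n_experts, capacity_factor=1.25):
    """assign[t] = top-1 expert of token t. Returns (mask of tokens that get a slot, capacity)."""
    T = len(assign)
    capacity = math.ceil(capacity_factor * T / n_experts)   # slots per expert (rounding is an assumption)
    used = np.zeros(n_experts, dtype=int)
    keep = np.zeros(T, dtype=bool)
    for t, e in enumerate(assign):           # fill slots in token order
        if used[e] < capacity:
            used[e] += 1
            keep[t] = True
    return keep, capacity

def moe_with_residual(x, expert_out, keep):
    """Dropped tokens skip the expert FFN; the residual connection passes x through unchanged."""
    return x + np.where(keep[:, None], expert_out, 0.0)

collapsed = np.zeros(6, dtype=int)           # every token routed to expert 0, as in the example above
balanced = np.array([0, 1, 2, 3, 0, 1])
for name, assign in [("collapsed", collapsed), ("balanced", balanced)]:
    keep, cap = apply_capacity(assign, n_experts=4)
    print(f"{name}: capacity {cap} per expert, dropped {int((~keep).sum())} of {len(assign)} tokens")

keep_collapsed, _ = apply_capacity(collapsed, n_experts=4)
out = moe_with_residual(np.ones((6, 2)), np.full((6, 2), 5.0), keep_collapsed)
```

With collapsed routing, two-thirds of the tokens get no expert at all. With balanced routing, every token fits.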

Warning

Pitfall: MoE trades compute for memory and complexity. All \(E\) experts must still fit in memory even though only \(k\) run — so MoE saves FLOPs, not VRAM. Routing also adds an all-to-all communication step (tokens must travel to whichever GPU holds their chosen expert), which can dominate at scale, and uneven routing causes load spikes that are a genuine systems headache to serve.

44.7 The training stack: precision, checkpointing, and memory math

Beyond parallelism, a handful of techniques decide whether a training run fits in memory and finishes on time.

Mixed precision. Storing and computing in 16-bit instead of 32-bit halves memory and roughly doubles throughput on tensor-core hardware. The standard today is bf16 (bfloat16): it keeps the full 8-bit exponent range of fp32 — so it does not overflow or underflow the way fp16 does — while sacrificing mantissa precision. A master copy of the weights is kept in fp32 for the optimizer update; the forward and backward passes run in bf16. (With the older fp16, the narrow exponent range forced an extra trick, loss scaling — multiply the loss by a large constant before backprop so tiny gradients don’t flush to zero, then divide it back out. bf16’s wide range mostly removes the need.) The frontier is pushing further to fp8 (8-bit floats) for the matmuls, with careful per-tensor scaling to preserve dynamic range.

Format | Bits (sign / exponent / mantissa) | Range | Precision | Use
--- | --- | --- | --- | ---
fp32 | 1 / 8 / 23 | wide | high | master weights, optimizer
fp16 | 1 / 5 / 10 | narrow | medium | older mixed precision (+ loss scaling)
bf16 | 1 / 8 / 7 | wide | low | default compute today
fp8 (E4M3) | 1 / 4 / 3 | narrow | very low | frontier matmuls, per-tensor scaled
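The range-versus-precision trade in the table can be checked numerically. NumPy has no bfloat16 type, so this sketch emulates it by rounding fp32 values to their top 16 bits (round-half-to-even). The `to_bf16` helper is a hypothetical illustration and skips NaN/Inf handling.

```python
import numpy as np

def to_bf16(x):
    """Round fp32 values to the nearest bf16 (round-half-to-even) and return them as fp32.
    bf16 keeps fp32's sign and 8 exponent bits but only the top 7 mantissa bits.
    NaN/Inf inputs are not handled."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    bits = bits + (((bits >> 16) & 1) + 0x7FFF)   # rounding bias before truncation
    return (bits & 0xFFFF0000).view(np.float32)

values = np.array([70000.0, 1e-8, 1.0 + 2**-9], dtype=np.float32)
with np.errstate(over="ignore"):
    fp16 = values.astype(np.float16)              # 70000 overflows to inf, 1e-8 underflows to 0
bf16 = to_bf16(values)                            # both stay finite and nonzero; 1 + 2^-9 rounds to 1.0
print("fp32:", values)
print("fp16:", fp16)
print("bf16:", bf16)
```

fp16 keeps the small step above 1.0 but cannot represent the large or tiny values. bf16 does the opposite, which is exactly the trade that removes the need for loss scaling.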

Gradient (activation) checkpointing. The backward pass needs the activations from the forward pass, and storing every layer’s activations is often the single biggest memory consumer. Checkpointing keeps only a sparse set of activations and recomputes the rest during the backward pass — trading roughly 30% more compute for a large drop in activation memory. It is the standard knob for fitting longer sequences or bigger batches.

In PyTorch, mixed precision and checkpointing are each a few lines:

import torch
from torch.utils.checkpoint import checkpoint

# Assumes model, optimizer, loader, transformer_block and x are defined elsewhere.
# Mixed precision: autocast runs matmuls in bf16 and keeps precision-sensitive ops
# (e.g. softmax, norms, reductions) in fp32; the fp32 parameters serve as the master copy.
for batch in loader:
    with torch.autocast("cuda", dtype=torch.bfloat16):
        loss = model(**batch).loss
    loss.backward()           # bf16 needs no GradScaler; fp16 would
    optimizer.step()
    optimizer.zero_grad()

# Activation checkpointing: wrap an expensive block so it is recomputed in backward
def block_forward(x):
    return transformer_block(x)
y = checkpoint(block_forward, x, use_reentrant=False)   # ~30% more compute, big memory cut
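What `checkpoint` does under the hood can be sketched without PyTorch. This NumPy toy uses a hypothetical chain of 16 `tanh` layers split into segments of 4. On the forward pass it keeps only each segment's input. During backprop it recomputes one segment at a time. It illustrates the simple segment-checkpointing idea, not PyTorch's implementation.

```python
import numpy as np

def layer(W, a):
    return np.tanh(W @ a)

def grad_full(Ws, x, g_out):
    acts = [x]
    for W in Ws:
        acts.append(layer(W, acts[-1]))           # keep every layer's activation
    g = g_out
    for W, a_out in zip(reversed(Ws), reversed(acts[1:])):
        g = W.T @ (g * (1 - a_out ** 2))          # backprop through tanh(W a)
    return g, len(acts)                            # arrays held at peak

def grad_checkpointed(Ws, x, g_out, seg):
    ckpts, a = {}, x
    for i, W in enumerate(Ws):                    # forward: keep only segment inputs
        if i % seg == 0:
            ckpts[i] = a
        a = layer(W, a)
    g, peak = g_out, 0
    for start in sorted(ckpts, reverse=True):      # backward: one segment at a time, last first
        acts = [ckpts[start]]
        for W in Ws[start:start + seg]:
            acts.append(layer(W, acts[-1]))        # recompute this segment's activations
        peak = max(peak, len(ckpts) + len(acts) - 1)   # checkpoints + one recomputed segment
        for W, a_out in zip(reversed(Ws[start:start + seg]), reversed(acts[1:])):
            g = W.T @ (g * (1 - a_out ** 2))
    return g, peak

rng = np.random.default_rng(0)
Ws = [rng.normal(scale=0.3, size=(8, 8)) for _ in range(16)]
x, g_out = rng.normal(size=8), np.ones(8)
g1, held_full = grad_full(Ws, x, g_out)
g2, held_ckpt = grad_checkpointed(Ws, x, g_out, seg=4)
print("same gradient:", np.allclose(g1, g2), "| arrays held:", held_full, "vs", held_ckpt)
```

The gradient is identical. Peak storage drops from 17 activations to 8, and the price is running every layer's forward twice.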

Worked: optimizer-state memory math

Let’s size the memory for a 7B-parameter model trained with Adam in bf16. Per parameter:

Item | Precision | Bytes/param
--- | --- | ---
Weights (compute copy) | bf16 | 2
Gradients | bf16 | 2
Master weights | fp32 | 4
Adam momentum \(m\) | fp32 | 4
Adam variance \(v\) | fp32 | 4
Total | | 16

params = 7e9
bytes_per_param = 2 + 2 + 4 + 4 + 4      # weights, grads, master, m, v
gb = params * bytes_per_param / 1e9
print(f"{gb:.0f} GB just for model + optimizer state")
# 112 GB

That 112 GB — before you account for a single byte of activations — already exceeds one 80 GB A100. This number is the entire reason ZeRO/FSDP sharding (44.5) exists: a 7B model is “small” by frontier standards yet its training footprint blows past a top-end GPU. Activations add more on top, which is what gradient checkpointing claws back.
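Sharding divides that bill across GPUs. The per-GPU byte counts below follow the ZeRO paper's accounting for mixed-precision Adam: 2 bytes of bf16 weights, 2 of gradients, and K = 12 bytes of fp32 optimizer state per parameter, matching the table above. The 8-GPU data-parallel group is an assumption. Activations and temporary buffers, such as parameters all-gathered for a layer under ZeRO-3, are ignored.

```python
params = 7e9
K = 12   # bytes/param of fp32 optimizer state: master weights + Adam m + Adam v

def per_gpu_gb(n_gpus):
    """Model + optimizer bytes per GPU (in GB) at each sharding level, ZeRO-paper accounting."""
    return {
        "no sharding (plain DP)": (2 + 2 + K) * params / 1e9,
        "ZeRO-1: shard optimizer state": (2 + 2 + K / n_gpus) * params / 1e9,
        "ZeRO-2: + shard gradients": (2 + (2 + K) / n_gpus) * params / 1e9,
        "ZeRO-3 / FSDP FULL_SHARD: + shard params": (2 + 2 + K) / n_gpus * params / 1e9,
    }

for stage, gb in per_gpu_gb(8).items():
    print(f"{stage:42s} {gb:6.2f} GB per GPU")
```

On 8 GPUs, full sharding cuts the 112 GB to 14 GB per GPU. That leaves most of an 80 GB card free for activations.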

44.8 The data pipeline: web-scale curation

A frontier model is only as good as the trillions of tokens it eats, and raw web text is mostly garbage. The data pipeline that turns a web crawl into a training corpus is unglamorous and is widely considered one of the biggest real differentiators between labs — more than architecture. The stages, in order:

flowchart LR
    A["Raw web crawl<br/>(e.g. Common Crawl)"] --> B["Text extraction<br/>strip HTML/boilerplate"]
    B --> C["Language ID<br/>+ filtering"]
    C --> D["Quality filtering<br/>heuristics + classifier"]
    D --> E["Deduplication<br/>exact + near (MinHash)"]
    E --> F["Decontamination<br/>remove eval/benchmark data"]
    F --> G["Tokenize + shuffle<br/>+ mix sources"]
    G --> H["Training-ready shards"]

Quality filtering removes machine-generated spam, keyword-stuffed pages, and gibberish — early pipelines used hand-built heuristics (bad punctuation ratios, too few real words), later ones a learned classifier scoring how “reference-like” a page is. Deduplication is the highest-leverage step: the web is enormously repetitive, and training on duplicates wastes compute, encourages memorization, and hurts generalization. Exact dedup catches identical documents; near-dedup via MinHash + locality-sensitive hashing catches the far more common near-copies (the same article on a hundred mirror sites). Decontamination removes any text overlapping your evaluation benchmarks — skip it and your reported scores are inflated by the model having literally seen the test set.
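A toy version of the first stages might look like this. The three heuristic rules and their thresholds (`min_words`, `min_alpha_frac`, `max_top_word_share`) are hypothetical stand-ins for the kinds of checks described above. Exact dedup here hashes a lowercased, whitespace-normalized copy of each document with SHA-1. Production pipelines use many more rules and tuned values.

```python
import hashlib
import re
from collections import Counter

def passes_heuristics(doc, min_words=50, min_alpha_frac=0.8, max_top_word_share=0.2):
    words = doc.split()
    if len(words) < min_words:                                   # too short to be a real page
        return False
    if sum(any(c.isalpha() for c in w) for w in words) / len(words) < min_alpha_frac:
        return False                                             # mostly numbers / symbols
    top_count = Counter(w.lower() for w in words).most_common(1)[0][1]
    return top_count / len(words) <= max_top_word_share          # reject keyword stuffing

def normalize(doc):
    return re.sub(r"\s+", " ", doc.lower()).strip()

def exact_dedup_and_filter(docs):
    seen, kept = set(), []
    for doc in docs:
        digest = hashlib.sha1(normalize(doc).encode("utf-8")).hexdigest()
        if digest in seen:                                       # identical after normalization
            continue
        seen.add(digest)
        if passes_heuristics(doc):
            kept.append(doc)
    return kept

good = "Deduplication removes repeated documents so the model sees more unique text. " * 5
docs = [good,
        "buy cheap pills " * 20,        # keyword stuffing
        "1234 5678 " * 30,              # no real words
        good.upper(),                   # exact duplicate once normalized
        "Click here."]                  # too short
kept = exact_dedup_and_filter(docs)
print(len(kept), "of", len(docs), "documents kept")
```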

The one piece worth seeing mechanically is near-dedup. MinHash estimates the Jaccard similarity of two documents’ shingle (n-gram) sets cheaply: hash each document’s shingles many ways, keep the minimum hash per function, and the fraction of matching minima estimates the overlap — no pairwise comparison of full documents required.

\[ J(A, B) = \frac{|A \cap B|}{|A \cup B|} \;\approx\; \Pr\big[\min h(A) = \min h(B)\big] \]

In words: the true overlap between two shingle sets (shared shingles over total distinct shingles) equals the probability that a random hash function picks the same minimum from both — so counting matching minima across many hashes estimates similarity without ever comparing documents directly. Also written: \(\hat{J} = \frac{1}{n}\sum_{i=1}^{n}\mathbb{1}[\min h_i(A) = \min h_i(B)]\), the fraction of the \(n\) hash functions whose minima agree.

import numpy as np

def shingles(text, k=5):
    """Set of k-word shingles (overlapping word n-grams)."""
    w = text.split()
    return {" ".join(w[i:i + k]) for i in range(len(w) - k + 1)}

def minhash(s, seeds):
    """MinHash signature: for each seeded hash function, the minimum hash over the set."""
    return np.array([min(hash((seed, sh)) & 0xffffffff for sh in s) for seed in seeds])

seeds = range(128)
a = "the quick brown fox jumps over the lazy dog every single morning"
b = "the quick brown fox jumps over the lazy dog every single evening"   # near-copy
c = "completely unrelated text about distributed training and gpus here"
ma, mb, mc = (minhash(shingles(t), seeds) for t in (a, b, c))
print("a~b est:", (ma == mb).mean(), "  a~c est:", (ma == mc).mean())
# a~b est: ~0.78 (true Jaccard 7/9: 7 shared of 9 distinct shingles; flag as near-dup)
# a~c est: ~0.0 (kept). Python's str hash is salted per process, so estimates vary run to run.
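The LSH half of MinHash + LSH is what avoids comparing every pair of signatures. A common banding scheme, sketched here under assumptions, splits each 128-value signature into 32 bands of 4 rows. Two documents become candidate near-duplicates only if some band matches exactly. The band and row counts are illustrative. This version also swaps Python's salted `hash()` for `hashlib.blake2b` so that results are reproducible.

```python
import hashlib
from collections import defaultdict

def shingles(text, k=5):
    w = text.split()
    return {" ".join(w[i:i + k]) for i in range(len(w) - k + 1)}

def signature(text, n_hash=128):
    """MinHash signature using deterministic seeded 64-bit blake2b hashes."""
    sh = shingles(text)
    return [min(int.from_bytes(hashlib.blake2b(f"{i}:{s}".encode(), digest_size=8).digest(), "big")
                for s in sh)
            for i in range(n_hash)]

def lsh_candidate_pairs(sigs, bands=32, rows=4):
    """Bucket each band of each signature; docs sharing any bucket become candidate pairs."""
    buckets = defaultdict(set)
    for doc_id, sig in sigs.items():
        for b in range(bands):
            buckets[(b, tuple(sig[b * rows:(b + 1) * rows]))].add(doc_id)
    pairs = set()
    for ids in buckets.values():
        ids = sorted(ids)
        pairs.update((ids[i], ids[j]) for i in range(len(ids)) for j in range(i + 1, len(ids)))
    return pairs

def candidate_prob(s, bands=32, rows=4):
    """Probability that docs with Jaccard similarity s share at least one identical band."""
    return 1 - (1 - s ** rows) ** bands

docs = {
    "a": "the quick brown fox jumps over the lazy dog every single morning",
    "b": "the quick brown fox jumps over the lazy dog every single evening",
    "c": "completely unrelated text about distributed training and gpus here",
}
sigs = {name: signature(text) for name, text in docs.items()}
pairs = lsh_candidate_pairs(sigs)
print("candidate pairs:", pairs)
print(f"P(candidate) at J=0.78: {candidate_prob(7 / 9):.6f}, at J=0.2: {candidate_prob(0.2):.3f}")
```

The `candidate_prob` curve is the design knob. More rows per band make matches stricter, and more bands make them more forgiving, so `bands` and `rows` together set the similarity threshold where pairs start being flagged.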

Finally, mixing: you don’t train on the raw source proportions. Curated high-quality sources (Wikipedia, books, code) are typically up-weighted relative to bulk web text, and the mixture ratios are themselves a tuned hyperparameter that materially affects the final model.
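One way to see what up-weighting means is to convert mixture weights into passes over each source. Every number below (source sizes, target weights, the 1T-token budget) is hypothetical. The point is only that a small, high-quality source given a large weight gets repeated several times.

```python
import numpy as np

sizes = {"web": 1000e9, "code": 100e9, "books": 20e9, "wikipedia": 5e9}    # unique tokens available
mixture = {"web": 0.80, "code": 0.12, "books": 0.05, "wikipedia": 0.03}     # target sampling weights
budget = 1e12                                                               # training tokens

epochs = {src: mixture[src] * budget / sizes[src] for src in sizes}
for src, ep in epochs.items():
    print(f"{src:10s} {mixture[src]:4.0%} of tokens -> {ep:.1f} passes over the source")

# At train time, each new sequence is drawn from a source chosen by the mixture weights.
rng = np.random.default_rng(0)
draws = rng.choice(list(mixture), size=10_000, p=list(mixture.values()))
print({src: round(float((draws == src).mean()), 3) for src in mixture})
```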

Tip

Intuition: more tokens is not better; more unique, high-quality tokens is better. A smaller, aggressively deduplicated and filtered corpus routinely beats a larger raw one at the same compute. Data work is where a lot of the quiet quality gains in modern models actually come from.

44.9 Parameter-efficient fine-tuning: adapting without retraining

The 112 GB above was for training a 7B model from scratch — full fine-tuning of an existing model costs the same, because you still hold the optimizer state for every parameter. But most people who adapt an open model do not want to move all 7 billion weights; they want to teach it one new skill or style on a modest GPU. Parameter-efficient fine-tuning (PEFT) freezes the pretrained weights and trains a tiny number of new ones. (Chapter 45 is devoted to transfer learning, fine-tuning and PEFT; this section is the systems-level summary.)

Intuition first: you don’t rewrite an entire textbook to add a chapter of margin notes. PEFT keeps the original “textbook” (the frozen base model) untouched and trains a thin set of “sticky notes” — a few million parameters — that nudge the model’s behavior. Because there is no optimizer state for the frozen 99%+ of weights, the memory cost collapses.

The dominant method is LoRA (Low-Rank Adaptation). The key observation: the update a fine-tune applies to a big weight matrix \(W \in \mathbb{R}^{d \times d}\) is empirically low-rank — it can be well approximated by the product of two skinny matrices. So instead of learning \(\Delta W\) directly (which is \(d^2\) numbers), you learn \(\Delta W = BA\) where \(A \in \mathbb{R}^{r \times d}\) and \(B \in \mathbb{R}^{d \times r}\) with rank \(r \ll d\).

\[ W' = W + \frac{\alpha}{r}\,\Delta W = W + \frac{\alpha}{r}\,BA \]

In words: keep the frozen weight and add a small correction built from two thin matrices, scaled by \(\alpha/r\) — you only ever train \(A\) and \(B\). Also written: the layer’s output is \(h = Wx + \tfrac{\alpha}{r}B(Ax)\); with \(d=4096\) and \(r=8\) that is \(2 \cdot 4096 \cdot 8 \approx 66\text{K}\) trainable params instead of \(4096^2 \approx 16.8\text{M}\) — a ~250× reduction per matrix.

Figure: frozen \(W\) (\(d \times d\)) plus trainable \(B\) (\(d \times r\)) times \(A\) (\(r \times d\)), tiny at rank \(r\); \(\Delta W = BA\) is low-rank.

The wins compound: tiny optimizer state (only \(A,B\) get momentum/variance), the frozen base can be loaded once and quantized (this is QLoRA — a 4-bit frozen base plus LoRA adapters fine-tunes a 70B model on a single 48 GB GPU), and adapters are a few megabytes you can swap per task without touching the base. At inference you can fold \(BA\) back into \(W\) so there is zero added latency.
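The LoRA arithmetic above can be checked with a NumPy sketch. It follows the LoRA convention of initializing \(B\) to zero, so the adapted layer starts out identical to the frozen one, and initializing \(A\) randomly. The sizes match the \(d=4096, r=8\) example. \(\alpha=16\), the random values and the pretend-trained \(B\) are assumptions for illustration.

```python
import numpy as np

d, r, alpha = 4096, 8, 16
rng = np.random.default_rng(0)
W = rng.standard_normal((d, d), dtype=np.float32) * 0.02   # frozen pretrained weight
A = rng.standard_normal((r, d), dtype=np.float32) * 0.01   # trainable, random init
B = np.zeros((d, r), dtype=np.float32)                      # trainable, zero init: delta W = 0 at start
x = rng.standard_normal(d, dtype=np.float32)

def lora_forward(x):
    # h = W x + (alpha / r) * B (A x); the d x d update is never materialized
    return W @ x + (alpha / r) * (B @ (A @ x))

same_at_init = np.allclose(lora_forward(x), W @ x)          # B = 0, so nothing changes yet

B = rng.standard_normal((d, r), dtype=np.float32) * 0.01    # pretend training moved B
W_merged = W + (alpha / r) * (B @ A)                        # fold the adapter in for inference
merged_matches = np.allclose(W_merged @ x, lora_forward(x), atol=1e-3)

trainable, full = A.size + B.size, W.size
print(same_at_init, merged_matches, f"{trainable:,} trainable vs {full:,} ({full // trainable}x fewer)")
```

Merging produces an ordinary \(d \times d\) matrix, which is why a merged adapter adds no inference latency.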

# LoRA fine-tuning with Hugging Face PEFT (+ 4-bit base = QLoRA)
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb = BitsAndBytesConfig(load_in_4bit=True,
                         bnb_4bit_quant_type="nf4",
                         bnb_4bit_compute_dtype="bfloat16")
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B", quantization_config=bnb, device_map="auto")

lora = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05,
                  target_modules=["q_proj", "v_proj"],  # adapt attention projections
                  task_type="CAUSAL_LM")
model = get_peft_model(base, lora)
model.print_trainable_parameters()
# trainable params: ~3.4M || all params: ~8B || trainable%: 0.04
# ...then train with a normal Trainer; only A,B receive gradients
Tip

Intuition: full fine-tuning and LoRA usually land at similar quality for adapting a model to a new domain or instruction style, because the adaptation really is low-rank. Reach for full fine-tuning only when you are changing the model’s fundamental capabilities, not its surface behavior.

44.10 Post-training: from base model to assistant

A model that has only done next-token prediction is a brilliant autocomplete, not an assistant: ask it a question and it might continue with more questions, because that is what often follows a question on the web. Turning a pretrained base into something helpful, honest, and harmless is post-training, and it has two main stages. (Chapter 46 covers alignment and evaluation in depth.)

Supervised fine-tuning (SFT). First you fine-tune on a curated set of (prompt, ideal-response) pairs — high-quality demonstrations of the behavior you want: answering questions, following instructions, refusing harmful requests, using a consistent format. This is ordinary supervised next-token training, just on demonstration data rather than raw web text. It teaches the model the shape of being an assistant.

Preference optimization (RLHF and successors). SFT alone can’t easily teach “this answer is better than that one” when both are plausible. The fix is to learn from human preferences: collect pairs where a human marked one response as preferred over another, then push the model toward the winners and away from the losers.

flowchart LR
    A["Pretrained<br/>base model"] --> B["SFT<br/>(demonstrations)"]
    B --> C["Collect preference pairs<br/>(human picks better of two)"]
    C --> D["RLHF: train reward model<br/>+ PPO against it"]
    C --> E["DPO: optimize preferences<br/>directly, no reward model"]
    D --> F["Aligned<br/>assistant"]
    E --> F

Classic RLHF does this in two steps: train a reward model to predict which response a human prefers, then use reinforcement learning (PPO) to fine-tune the LLM to maximize that reward — while a KL penalty keeps it from drifting too far from the SFT model and “hacking” the reward. It works but is finicky: you maintain a separate reward model and an unstable RL loop.

Direct Preference Optimization (DPO) is the now-common simplification. A clever derivation shows you can skip the explicit reward model and the RL entirely, optimizing the policy directly on preference pairs with a simple classification-style loss:

\[ \mathcal{L}_{DPO} = -\,\mathbb{E}_{(x, y_w, y_l)}\Big[\log \sigma\Big(\beta \log \tfrac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log \tfrac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)}\Big)\Big] \]

In words: raise the model’s probability of the human-preferred answer \(y_w\) relative to the rejected one \(y_l\), measured against a frozen reference model so it doesn’t wander off — it’s just a logistic loss on the difference of log-probability ratios. Also written: with the per-response “implicit reward” \(\hat r(y) = \beta\log\frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}\), the loss is simply \(-\log\sigma(\hat r(y_w) - \hat r(y_l))\) — the same form as training a binary classifier to prefer the winner.
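Given per-response log-probabilities, the loss itself is only a few lines. This NumPy sketch assumes the inputs are summed token log-probs \(\log\pi(y\mid x)\) for each chosen and rejected response, under both the policy and the frozen reference. The numbers are made up. It uses \(-\log\sigma(m) = \log(1+e^{-m})\), computed stably with `np.logaddexp`.

```python
import numpy as np

def dpo_loss(pi_w, pi_l, ref_w, ref_l, beta=0.1):
    """Mean DPO loss over a batch of preference pairs (arrays of sequence log-probs)."""
    margin = beta * ((pi_w - ref_w) - (pi_l - ref_l))   # implicit reward of winner minus loser
    return float(np.mean(np.logaddexp(0.0, -margin)))   # -log sigmoid(margin)

ref_w = np.array([-42.0, -17.5])    # reference log-probs of the chosen responses (made up)
ref_l = np.array([-40.0, -19.0])    # reference log-probs of the rejected responses (made up)

start = dpo_loss(ref_w, ref_l, ref_w, ref_l)            # policy == reference: loss is log 2
better = dpo_loss(ref_w + 10.0, ref_l, ref_w, ref_l)    # policy now favors the chosen answers
print(round(start, 4), round(better, 4))
```

Training starts at \(\log 2\) because the policy is initialized as the reference. The loss falls only as the policy separates winners from losers relative to that baseline.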

# DPO with the TRL library — no reward model, no PPO loop
from trl import DPOTrainer, DPOConfig
# dataset rows: {"prompt": ..., "chosen": ..., "rejected": ...}
trainer = DPOTrainer(
    model=policy_model,            # the SFT model being tuned
    ref_model=reference_model,     # frozen SFT copy (the π_ref baseline)
    args=DPOConfig(beta=0.1, learning_rate=5e-7, output_dir="dpo-out"),
    train_dataset=pref_dataset,
    processing_class=tokenizer,
)
trainer.train()
Note

This pairs naturally with PEFT (44.9). SFT and DPO are very often done with LoRA adapters on a frozen, quantized base, which is how a small team can align a 70B model without a training cluster. Newer variants follow the same skeleton of collecting a feedback signal and optimizing toward it. RLAIF has an LLM provide the preference labels. GRPO, used for reasoning models, scores a group of sampled answers with a reward and pushes the policy toward the above-average ones.

44.11 Inference systems: serving the model efficiently

Training happens once; inference happens billions of times, so its economics dominate the lifetime cost of a model. LLM inference has a peculiar structure: it runs in two distinct phases. Prefill ingests the whole prompt in one parallel forward pass — this is compute-bound and fast per token. Decode then emits one token at a time, and each step re-reads the entire model from HBM to produce a single token — this is memory-bound and is where the latency lives. Recognizing that decode is memory-bound is the key that unlocks almost every trick below. (Chapter 47 treats model serving and deployment in full; here is how these pieces fit the systems picture.)

Everyday analogy for the two phases: prefill is like reading a whole letter at a glance — your eyes take in the page in parallel. Decode is like writing the reply by hand, one word at a time, where each word depends on all the words before it. Reading is fast and parallel; writing is serial and slow, and that is where you spend your afternoon.

Here is the decode bottleneck made visible: to emit each single token, the entire model has to be hauled out of HBM again — the whole weight matrix pulses across the bus for one word of output.

Figure: 14 GB of weights stream from HBM over a ~2 TB/s bus while compute sits mostly idle; one full sweep of the model = one output token.

KV cache. Generating token \(t\) requires attending to all previous tokens’ keys and values. Recomputing them every step would be quadratic; instead you cache the K and V tensors for every past token and reuse them. This is the single most important inference optimization — but the cache grows linearly with sequence length and batch size and quickly becomes the dominant memory consumer, larger than the model itself for long contexts. Its size is concrete and worth being able to estimate:

\[ \text{KV bytes} = 2 \times L \times n_{kv} \times d_{head} \times N_{tok} \times \text{bytes} \]

In words: the cache stores a key and a value (the leading 2) for every layer, every key/value head, every head dimension, and every token in the sequence — multiply them all by the bytes per number. Also written: \(\text{KV bytes} = 2\,L\,n_{kv}\,d_{head}\,N_{tok}\,b\); since \(n_{kv}\,d_{head} = d_{kv}\) (the total key/value width per layer), this is \(2\,L\,d_{kv}\,N_{tok}\,b\).

The leading 2 is K and V, \(L\) is layers, \(n_{kv}\) the number of key/value heads, \(N_{tok}\) the sequence length. For a 7B-class model (32 layers, 32 KV heads, 128 head dim) at bf16, a single 8K-token sequence costs \(2 \times 32 \times 32 \times 128 \times 8192 \times 2 \approx 4.3\) GB — which is exactly why grouped-query attention (fewer \(n_{kv}\)) and cache quantization matter so much, and why memory, not compute, caps your batch size.
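The formula translates directly into a helper. The 8-KV-head configuration stands in for a grouped-query-attention variant, for example a Llama-3-8B-style layout. The 66 GB budget assumes an 80 GB GPU minus 14 GB of bf16 weights and ignores activations and framework overhead, so treat the counts as rough upper bounds.

```python
def kv_cache_bytes(layers, n_kv_heads, d_head, n_tokens, bytes_per_elem=2):
    """KV cache size for one sequence: K and V, per layer, per KV head, per head dim, per token."""
    return 2 * layers * n_kv_heads * d_head * n_tokens * bytes_per_elem

configs = {
    "MHA, 32 KV heads, bf16": kv_cache_bytes(32, 32, 128, 8192),
    "GQA, 8 KV heads, bf16": kv_cache_bytes(32, 8, 128, 8192),
    "GQA, 8 KV heads, int8": kv_cache_bytes(32, 8, 128, 8192, bytes_per_elem=1),
}
free_bytes = 80e9 - 14e9      # assumed: 80 GB GPU minus 14 GB of weights
for name, b in configs.items():
    print(f"{name}: {b / 1e9:.2f} GB per 8K sequence -> {int(free_bytes // b)} sequences fit")
```

Cutting KV heads 4× and then halving the cache precision raises the number of concurrent 8K sequences roughly 8×. That is the batch-size ceiling the paragraph describes.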

PagedAttention / vLLM. The KV cache is awkward to store: a naive contiguous allocation per request wastes huge amounts of memory to fragmentation, because you must reserve for the maximum possible length. PagedAttention (the core idea in vLLM) borrows virtual-memory paging from operating systems: the cache is split into fixed-size blocks allocated on demand and tracked by a block table, so memory is used only as tokens are actually generated. The result is far higher GPU utilization and bigger effective batch sizes.
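The bookkeeping behind paging can be sketched in a few lines. This is a toy allocator, not vLLM's code. The class name `PagedKVAllocator`, the 16-token block size and the error raised on exhaustion are assumptions, and the sketch tracks only which physical blocks hold each sequence's K/V, not the tensors themselves. At any moment a sequence holds ceil(tokens / block_size) blocks instead of a reservation for its maximum length.

```python
class PagedKVAllocator:
    """Toy bookkeeping for a paged KV cache (illustrative, not vLLM's implementation)."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free = list(range(num_blocks))   # physical block ids not in use
        self.tables = {}                      # seq_id -> block table (list of physical block ids)
        self.lengths = {}                     # seq_id -> number of tokens cached

    def append_token(self, seq_id):
        n = self.lengths.get(seq_id, 0)
        table = self.tables.setdefault(seq_id, [])
        if n % self.block_size == 0:          # no block yet, or the last one is full
            if not self.free:
                raise MemoryError("out of KV blocks; a real scheduler would preempt a request")
            table.append(self.free.pop())
        self.lengths[seq_id] = n + 1
        # this token's K/V lives in block table[n // block_size], slot n % block_size
        return table[n // self.block_size], n % self.block_size

    def release(self, seq_id):
        self.free.extend(self.tables.pop(seq_id))   # blocks return to the pool immediately
        del self.lengths[seq_id]

alloc = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(17):
    alloc.append_token("req-1")               # 17 tokens -> 2 blocks (15 slots left empty)
for _ in range(3):
    alloc.append_token("req-2")               # 3 tokens -> 1 block
blocks_req1 = len(alloc.tables["req-1"])
print("block tables:", alloc.tables, "free:", len(alloc.free))
alloc.release("req-1")                        # request finished
print("free after release:", len(alloc.free))
```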

Continuous batching. Requests arrive at different times and finish at different lengths. Static batching forces fast requests to wait for the slowest in their batch. Continuous (in-flight) batching instead lets a finished sequence drop out and a newly arrived one slot into the batch at the next step, keeping the GPU saturated. It is one of the largest throughput wins in production serving.
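A step-count toy makes the difference concrete. It assumes every active sequence emits one token per decode step and ignores prefill and arrival times. It uses four hypothetical requests needing 8, 2, 2 and 2 tokens, with room for 2 sequences per batch.

```python
def static_batching_steps(lengths, batch_size):
    # each batch runs until its longest request finishes; finished slots sit idle
    return sum(max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size))

def continuous_batching_steps(lengths, batch_size):
    queue, active, steps = list(lengths), [], 0
    while queue or active:
        while queue and len(active) < batch_size:   # refill free slots immediately
            active.append(queue.pop(0))
        steps += 1                                   # one decode step: every active seq emits a token
        active = [left - 1 for left in active if left > 1]   # drop sequences that just finished
    return steps

lengths = [8, 2, 2, 2]        # tokens each request must generate (hypothetical)
static = static_batching_steps(lengths, 2)
continuous = continuous_batching_steps(lengths, 2)
useful = sum(lengths)
print(f"static: {static} steps, slot utilization {useful / (2 * static):.0%}")
print(f"continuous: {continuous} steps, slot utilization {useful / (2 * continuous):.0%}")
```

The short requests no longer wait behind the 8-token one. They flow through the second slot while it runs.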

In practice you get PagedAttention and continuous batching for free by serving with vLLM rather than implementing them:

# vLLM applies PagedAttention + continuous batching automatically
from vllm import LLM, SamplingParams

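llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          # quantization="awq" (int4 weights -> fewer HBM bytes/token) needs an
          # AWQ-quantized checkpoint, so it is left out for this unquantized model.
          gpu_memory_utilization=0.9)   # fraction of GPU memory vLLM may claim (weights + KV cache)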

params = SamplingParams(temperature=0.7, max_tokens=256)
# Pass many prompts at once; the scheduler batches them in-flight
outputs = llm.generate(["Explain the KV cache.", "What is FlashAttention?"], params)
for o in outputs:
    print(o.outputs[0].text)

Quantization. Since decode is memory-bound, shrinking the weights directly speeds it up. Post-training quantization to int8 or int4 (methods like GPTQ and AWQ) cuts the bytes read from HBM per token — roughly halving or quartering the dominant cost — for a small, often negligible, quality loss. A 70B model that needs 140 GB in bf16 fits in ~35 GB at int4.
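The basic mechanics can be sketched with plain round-to-nearest int8 quantization, using one absmax scale per output row. This is deliberately simpler than GPTQ or AWQ, which choose the rounding or per-channel scaling more carefully to protect important weights. The random matrix stands in for a real layer.

```python
import numpy as np

def quantize_int8_rows(W):
    """Symmetric absmax int8 quantization with one scale per output row (assumes no all-zero row)."""
    scale = np.abs(W).max(axis=1, keepdims=True) / 127.0
    q = np.clip(np.round(W / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
W = rng.standard_normal((1024, 1024), dtype=np.float32) * 0.02   # stand-in for a weight matrix
q, scale = quantize_int8_rows(W)
W_hat = dequantize(q, scale)
rel_err = float(np.linalg.norm(W - W_hat) / np.linalg.norm(W))
int8_bytes = q.nbytes + scale.nbytes
print(f"bf16: {W.size * 2:,} bytes -> int8: {int8_bytes:,} bytes, relative error {rel_err:.4f}")
```

The scales add a small overhead, but the bytes read per decode step roughly halve relative to bf16. Int4 halves them again.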

Speculative decoding. The clever one. Run a small, cheap draft model to propose several tokens ahead, then let the big target model verify all of them in a single forward pass. Because verifying \(k\) tokens in parallel costs about the same as generating one (decode is memory-bound, remember — the weights are read once either way), every accepted draft token is nearly free.

flowchart LR
    A["context"] --> B["Draft model<br/>proposes 4 tokens<br/>(cheap, fast)"]
    B --> C["Target model<br/>verifies all 4 in<br/>ONE forward pass"]
    C --> D{"accept prefix<br/>that matches?"}
    D -->|"3 of 4 ok"| E["keep 3 tokens<br/>+ 1 free correction"]
    D -->|"all rejected"| F["fall back to 1<br/>target token"]
    E --> A
    F --> A

The output is provably identical to what the target model would have produced alone — the draft only proposes, the target always decides (via a rejection-sampling check that preserves the target’s exact distribution) — so it is a pure latency win (commonly 2–3×) with no quality cost. When drafts are accurate, you get several tokens per expensive forward pass; when they are wrong, you fall back gracefully to normal decoding.
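The acceptance rule can be sketched as follows, after the speculative sampling scheme the paragraph describes. Accept draft token \(x\) with probability \(\min(1, p(x)/q(x))\). On the first rejection, resample from the normalized residual \(\max(0, p - q)\) and stop. If every draft survives, take one bonus token from the target. The arrays `p` (target, one row per position plus one) and `q` (draft) are assumed to come from a single target forward pass and from the draft model. Here they are toy distributions. The Monte Carlo check at the end tests that the first emitted token follows the target's distribution.

```python
import numpy as np

def speculative_step(p, q, draft, rng):
    """p: (k+1, V) target probs; q: (k, V) draft probs; draft: k token ids sampled from q.
    Returns the tokens emitted this step (between 1 and k+1)."""
    out = []
    for i, x in enumerate(draft):
        if rng.random() < min(1.0, p[i, x] / q[i, x]):
            out.append(int(x))                              # accept the draft token
        else:
            residual = np.maximum(p[i] - q[i], 0.0)          # resample where target exceeds draft
            out.append(int(rng.choice(len(residual), p=residual / residual.sum())))
            return out                                       # stop at the first rejection
    out.append(int(rng.choice(p.shape[1], p=p[-1])))         # all accepted: one bonus target token
    return out

# Monte Carlo check with one draft token (k = 1) over a 3-token vocabulary (toy numbers)
rng = np.random.default_rng(0)
p0 = np.array([0.5, 0.3, 0.2])      # target distribution
q0 = np.array([0.2, 0.5, 0.3])      # deliberately different draft distribution
P, Q = np.stack([p0, p0]), q0[None, :]
counts = np.zeros(3)
trials = 20_000
for _ in range(trials):
    x = rng.choice(3, p=q0)
    counts[speculative_step(P, Q, [x], rng)[0]] += 1
freq = counts / trials
print("first-token frequencies:", freq.round(3), "target:", p0)
```

The draft here favors token 1, but the emitted frequencies still track the target distribution. The draft's quality affects only how many tokens are accepted per pass, not what comes out.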

Throughput vs. latency

A last framing that ties the serving knobs together: you are always trading two metrics. Latency is what one user feels — time-to-first-token (dominated by prefill) and inter-token latency (dominated by decode). Throughput is total tokens/second across all users, which is what determines cost-per-token. Bigger batches raise throughput but worsen per-user latency; speculative decoding cuts latency but spends extra compute; quantization helps both. There is no single optimum — you tune the batch size and these techniques to a target latency SLA, then maximize throughput under it.

A small worked example makes the cost concrete. Suppose decode reads the full 7B model (14 GB in bf16) from HBM once per token on an A100 (~2 TB/s):

model_bytes = 14e9          # 7B params * 2 bytes (bf16)
bandwidth   = 2e12          # A100 HBM ~2 TB/s
sec_per_tok = model_bytes / bandwidth
print(f"~{1/sec_per_tok:.0f} tokens/sec, single stream (memory-bound floor)")
# ~143 tokens/sec  -> int4 quantization (3.5 GB) would roughly 4x this

The number is set by bandwidth divided by bytes read, not by FLOPs — which is precisely why quantization (fewer bytes) and large batches (amortize the read across many sequences) are the two levers that move it.
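Extending that calculation to a batch shows both levers at once. This keeps the same bandwidth-only assumption, which breaks down once large batches make decode compute-bound. Each step reads the weights once plus every sequence's KV cache. The cache size assumes a grouped-query model with 8 KV heads at a full 8K-token context, computed with the KV formula above.

```python
model_bytes = 14e9                           # 7B params in bf16
kv_per_seq = 2 * 32 * 8 * 128 * 8192 * 2     # KV formula: 32 layers, 8 KV heads, d_head 128, 8K tokens, bf16
bandwidth = 2e12                             # A100-class HBM, ~2 TB/s

results = {}
for batch in (1, 8, 32):
    step = (model_bytes + batch * kv_per_seq) / bandwidth   # seconds per decode step
    results[batch] = (batch / step, 1 / step)               # (total tok/s, tok/s per user)
    print(f"batch {batch:2d}: {batch / step:6.0f} tok/s total, {1 / step:5.1f} tok/s per user")
```

Total throughput climbs with batch size because the weight read is shared. Each user's token rate falls because every added sequence brings its own cache traffic. This is the throughput-versus-latency trade in numbers.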

44.1 — Which parallelism should I use?

When a model or its training state won’t fit, the choice of parallelism follows a short decision tree — reach for the cheapest-communication option that solves your actual constraint:

flowchart TD
    A["Model + optimizer state<br/>won't fit on one GPU?"] -->|"No, just want<br/>faster throughput"| DP["Data Parallel<br/>(replicate, all-reduce grads)"]
    A -->|"Yes — out of memory"| B{"What overflows?"}
    B -->|"Optimizer state /<br/>gradients / params"| Z["ZeRO / FSDP<br/>(shard state across DP group)"]
    B -->|"A single layer's<br/>weights too big"| TP["Tensor Parallel<br/>(intra-node, NVLink)"]
    B -->|"Whole model too deep<br/>for one node"| PP["Pipeline Parallel<br/>(split by layers, + microbatches)"]
    B -->|"Activations for a<br/>very long sequence"| SP["Sequence / Context Parallel"]
    Z -.->|"compose at frontier scale"| TP
    TP -.-> PP

Rule of thumb: start with FSDP/ZeRO (cheapest to adopt, biggest memory win), add tensor parallelism only inside a fast-NVLink node, and pipeline parallelism only when the model is too deep to fit even sharded — combining all three is 3D parallelism.

44.2 — Quick reference

Term / formula | Meaning | When / why it matters
--- | --- | ---
Byte-level BPE | Subword tokenizer over 256 bytes; merge frequent pairs | Default tokenizer; guarantees no out-of-vocabulary token
Vocab size (\(V \times d_{model}\)) | Tokens/word vs. embedding cost vs. rare-token quality | Bigger = shorter sequences but pricier embeddings; ~128K–256K for multilingual
Arithmetic intensity \(I=\text{FLOPs}/\text{bytes}\) | Math done per byte read from HBM | \(I <\) ridge (~156 on A100) ⇒ memory-bound; the reason to fuse kernels
FlashAttention | Exact attention via SRAM tiling + online softmax | Never stores the \(N\times N\) matrix; enables long context, 2–4× faster
Chinchilla \(L=E+A/N^\alpha+B/D^\beta\) | Loss = floor + parameter penalty + data penalty | Scaling-law fit; predicts loss from \(N\) and \(D\)
\(C \approx 6ND\) | Training FLOPs ≈ 6 × params × tokens | Project model+data size from a compute budget (~20 tokens/param)
Data / Tensor / Pipeline parallel | Split batch / matrices / layers across GPUs | Compose as 3D parallelism to fit and scale frontier training
Pipeline bubble \(\frac{p-1}{m+p-1}\) | Fraction of idle stage time | More microbatches \(m\) ⇒ less waste
ZeRO / FSDP | Shard optimizer state, grads, params across DP group | Fit models far larger than one GPU; trade bandwidth for memory
all-reduce / all-gather / reduce-scatter | Collective comm primitives (NCCL) | Underlie DP grad sync, FSDP gather/reduce
MoE \(y=\sum_{i\in\text{top-}k} g_i E_i(x)\) | Route each token to a few experts | More params, not more compute/token; needs load-balancing loss
Mixed precision (bf16 / fp8) | Compute in 16/8-bit, fp32 master weights | Halves memory, ~2× throughput; bf16 avoids fp16 loss-scaling
Gradient checkpointing | Recompute activations in backward | ~30% more compute for a big activation-memory cut
Optimizer memory (16 B/param) | bf16 w+g + fp32 master+m+v with Adam | 7B model ≈ 112 GB before activations; why FSDP exists
MinHash + LSH | Estimate Jaccard similarity cheaply | Near-deduplication of web corpora; highest-leverage data step
LoRA \(W'=W+\frac{\alpha}{r}BA\) | Train a low-rank update, freeze base | ~250×/matrix fewer trainable params; QLoRA adds 4-bit base
SFT → RLHF / DPO | Demonstrations, then preference optimization | Turns a base autocomplete into a helpful assistant
KV cache (\(2 L\,n_{kv} d_{head} N_{tok} b\)) | Cache past keys/values to avoid recompute | Top inference optimization; dominates memory at long context
PagedAttention / continuous batching | OS-style paging + in-flight batch swaps | vLLM defaults; raise GPU utilization and throughput
Speculative decoding | Draft proposes, target verifies in one pass | 2–3× latency win, output provably identical

44.3 — Key takeaways

  • Tokenization is foundational and frozen. Byte-level BPE guarantees no out-of-vocabulary tokens (WordPiece and Unigram/SentencePiece are the alternatives); vocabulary size trades sequence length and embedding cost against rare-token quality. Tokenizer quirks explain a surprising amount of model behavior (bad arithmetic, per-language cost).
  • Memory bandwidth, not FLOPs, is usually the bottleneck. The HBM↔︎SRAM gap means most LLM kernel work is about moving less data. Arithmetic intensity versus the roofline ridge decides whether an op is compute- or memory-bound.
  • FlashAttention computes exact attention without materializing the \(N\times N\) matrix, via tiling and online softmax — a key enabler of long context. Custom kernels are increasingly written in Triton; in practice you call it via scaled_dot_product_attention or attn_implementation="flash_attention_2".
  • Scaling laws make loss predictable via \(L(N,D)=E+A/N^{\alpha}+B/D^{\beta}\). Chinchilla’s ~20 tokens/parameter overturned Kaplan; \(C \approx 6ND\) lets you project model and data size from a compute budget — though inference economics often justify overtraining smaller models.
  • Distributed training composes data, tensor, pipeline, and sequence parallelism (3D parallelism); the pipeline bubble shrinks with more microbatches; the collective primitives (all-reduce, all-gather, reduce-scatter) underlie all of it; ZeRO/FSDP shard optimizer state, gradients, and parameters to fit models that dwarf a single GPU.
  • Mixture-of-Experts grows parameters without proportional compute via top-\(k\) routing, kept healthy by a load-balancing loss and expert-capacity limits — but all experts must fit in memory and routing adds all-to-all communication.
  • The training stack (bf16/fp8 mixed precision, loss scaling for fp16, gradient checkpointing, Adam’s 16 bytes/param) decides what fits; a “small” 7B model needs ~112 GB before activations.
  • Data quality — dedup (MinHash/LSH), filtering, decontamination, mixing — is a top differentiator; unique high-quality tokens beat raw volume.
  • PEFT/LoRA freezes the base and trains a low-rank update \(\Delta W = BA\), applied with scale \(\alpha/r\), slashing trainable parameters ~250×; QLoRA adds a 4-bit frozen base so a 70B model fine-tunes on one GPU.
  • Post-training turns a base model into an assistant: supervised fine-tuning on demonstrations, then preference optimization (RLHF with a reward model + PPO, or the simpler reward-model-free DPO).
  • Inference splits into compute-bound prefill and memory-bound decode; KV cache (size it with the formula), PagedAttention/vLLM, continuous batching, quantization, and speculative decoding together trade latency against throughput to make serving affordable.

44.4 — See also

  • Chapter 23 — Large Language Models: the Transformer decoder architecture, pretraining objective, and what these systems are built to train and serve.
  • Chapter 17 — Attention & Transformers: the attention math that FlashAttention (44.3) rewrites to be IO-aware.
  • Chapter 30 — AI Infrastructure & Efficient Inference: the broader infrastructure context for distributed training (44.5) and the KV-cache/quantization serving techniques in 44.11.
  • Chapter 45 — Post-Training I (Transfer / Fine-Tuning / PEFT): the full treatment of LoRA/QLoRA and parameter-efficient fine-tuning summarized in 44.9.
  • Chapter 46 — Post-Training II (Alignment / Evaluation): RLHF, DPO, and the alignment pipeline that 44.10 introduces.
  • Chapter 47 — Model Serving & Deployment: the production-serving context for prefill/decode, PagedAttention, and continuous batching in 44.11.
  • Chapter 03 — Optimization: Adam and the optimizer state whose memory footprint (44.7) drives ZeRO/FSDP.

↪ The thread continues → Chapter 45 · 🎚️ Post-Training I — Transfer, Fine-Tuning & PEFT

You can build a base model, but a raw base model is a brilliant amnesiac with no manners. Post-training teaches it to specialize cheaply.



 

© Kader Mohideen