graph TD
A[Workload: same op over huge arrays?] -->|No, branchy serial logic| B[CPU: few fast cores]
A -->|Yes, parallel matmuls| C{Need flexibility?}
C -->|Yes, varied ops + ecosystem| D[GPU: 1000s of cores + tensor cores]
C -->|No, pure matmul at scale| E[TPU: systolic array]
Chapter 30 — 🚀 AI Infrastructure & Efficient Inference
Training a model is only half the story; running it cheaply and quickly for millions of requests is the other half, and it is where most of the money goes. This chapter is about the hardware that makes large models possible and the tricks that make them affordable to serve — from how GPUs differ from CPUs, to spreading one giant model across hundreds of chips, to squeezing a 70-billion-parameter model onto a single card. It sits at the production end of the ML lifecycle: the bridge between a trained checkpoint and a service that answers in milliseconds.
🧭 In context: Production, Tooling & Infrastructure · used to train and serve large models fast and cheaply · the one key idea: modern AI is bottlenecked by memory bandwidth and memory capacity, not raw compute — almost every technique here is about moving fewer bytes.
💡 Remember this: Serving large models is a memory problem, not a math problem — almost every speedup here works by moving fewer bytes, not by doing less arithmetic.
30.1 — GPUs vs TPUs vs CPUs and why parallelism/memory matter
Imagine you need to add a million pairs of numbers. A CPU (central processing unit) is a brilliant short-order cook: a handful of very fast, very clever cores that handle one complicated order at a time, with deep branch prediction and huge caches. A GPU (graphics processing unit) is a stadium kitchen: thousands of simpler cooks who can only follow the same recipe in lockstep, but there are thousands of them. For the million additions — all identical, no branching — the stadium wins by a mile. Deep learning is exactly this: the same matrix multiply applied across enormous arrays, which is embarrassingly parallel.
The core operation is the matrix multiply (matmul, or in neural nets a GEMM, short for general matrix multiply). Multiplying an \(m\times k\) matrix by a \(k\times n\) matrix takes \(m\cdot k\cdot n\) multiply-adds, all independent of one another. GPUs add dedicated tensor cores that do a small matmul (e.g. a \(4\times4\) block) in one instruction, giving tens to hundreds of TFLOP/s.
The matmul cost formula. For a single GEMM the number of arithmetic operations is
\[\text{FLOPs} = 2 \cdot m \cdot k \cdot n\]
In words: multiplying an \(m\times k\) matrix by a \(k\times n\) matrix costs two operations (one multiply, one add) for each of the \(m\cdot k\cdot n\) multiply-add terms, because each of the \(m\cdot n\) output entries is an inner product of length \(k\).
Also written: \(\text{FLOPs} = 2\,mkn = 2\sum_{i=1}^{m}\sum_{j=1}^{n}\sum_{l=1}^{k} 1\) — the triple sum just counts every multiply-add that builds the \(m\times n\) result.
A TPU (tensor processing unit, Google’s custom chip) pushes this further with a systolic array: a grid of multiply-accumulate cells where data flows through and partial sums accumulate as they pass, so one matmul streams through the whole grid without re-reading operands from memory. TPUs are less flexible than GPUs but extremely efficient for the matmul-heavy workloads of training and inference.
Why memory matters more than FLOPs. Here is the part that surprises people. A modern GPU can do, say, 1000 TFLOP/s of compute but only read ~3 TB/s from its HBM (high-bandwidth memory). The ratio — FLOPs per byte the chip can sustain — is called the arithmetic intensity break-even. If your operation does few FLOPs per byte loaded, the cores sit idle waiting for data: you are memory-bound. If it does many FLOPs per byte, you are compute-bound. This is the roofline model.
The roofline says the achievable speed is whichever ceiling you hit first — the memory wall or the compute roof:
\[\text{achievable FLOP/s} = \min\big(\,\text{peak FLOP/s},\ \ \text{intensity}\times\text{bandwidth}\,\big)\]
In words: you can never go faster than the chip’s raw compute peak, and you also can never go faster than the rate at which bytes arrive times how much math you do per byte — the lower of the two ceilings wins.
Also written: with intensity \(I\) (FLOP/byte), bandwidth \(B\) (byte/s), peak \(P\) (FLOP/s): \(\text{rate} = \min(P,\ I\cdot B)\); you are memory-bound when \(I\cdot B < P\), i.e. when \(I < P/B\) (the break-even intensity).
A small worked check. Suppose a GPU has 312 TFLOP/s (fp16) and 2 TB/s bandwidth. Break-even intensity \(= (312\times10^{12}) / (2\times10^{12}) = 156\) FLOPs per byte. A big matmul reuses each loaded value many times → intensity in the hundreds → compute-bound (good, cores busy). But generating one token at a time in an LLM loads gigabytes of weights to do one vector-matrix product (intensity ≈ 1–2) → wildly memory-bound. That single fact explains why LLM decoding is slow and why nearly every inference trick targets bytes moved, not math done.
The sketch above is a cartoon of the roofline: the real plot is conventionally drawn on log-log axes, where the memory-bound region is a straight diagonal of slope 1 and the compute-bound region is a flat horizontal roof. The break-even tick is schematic, not to scale — read it as “low intensity sits on the slanted wall, high intensity sits under the flat roof,” not as a precise coordinate.
Rule of thumb: training and prompt prefill are usually compute-bound (lots of reuse); single-token decode is memory-bound. Optimize the right phase — speeding up math does nothing for a memory-bound step.
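The roofline arithmetic above can be reproduced in a few lines. This is a minimal sketch, not a profiler: the helper names (`matmul_flops`, `matmul_bytes`, `roofline`) and the two example shapes are illustrative assumptions, and the byte count assumes an idealized kernel that reads each input once and writes the output once.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for an (m x k) @ (k x n) product: one multiply + one add per term."""
    return 2 * m * k * n


def matmul_bytes(m: int, k: int, n: int, bytes_per_elem: int = 2) -> int:
    """Idealized HBM traffic: read both inputs once, write the output once."""
    return bytes_per_elem * (m * k + k * n + m * n)


def roofline(intensity: float, peak_flops: float, bandwidth: float) -> float:
    """Achievable FLOP/s = min(compute roof, intensity * memory bandwidth)."""
    return min(peak_flops, intensity * bandwidth)


PEAK = 312e12      # FLOP/s, the fp16 figure from the worked check
BANDWIDTH = 2e12   # bytes/s

break_even = PEAK / BANDWIDTH  # FLOPs per byte needed to keep the cores busy

# A large square matmul (training/prefill-like) vs. a single-token decode step.
shapes = {"big matmul": (4096, 4096, 4096), "decode step": (1, 4096, 4096)}
for name, (m, k, n) in shapes.items():
    intensity = matmul_flops(m, k, n) / matmul_bytes(m, k, n)
    rate = roofline(intensity, PEAK, BANDWIDTH)
    bound = "compute-bound" if intensity >= break_even else "memory-bound"
    print(f"{name}: {intensity:.1f} FLOP/byte, "
          f"{rate / 1e12:.1f} TFLOP/s achievable, {bound}")
```

Under these assumptions the square matmul lands far above the break-even intensity, while the decode-shaped product sits near 1 FLOP/byte and can use well under 1% of the compute peak.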
30.1a — The memory hierarchy and why kernel fusion / FlashAttention win
Think of memory like a desk: the GPU has a vast warehouse (HBM, slow to walk to) and a tiny on-chip notepad (SRAM, instant but holds almost nothing). Every time a computation writes a result back to the warehouse and then walks back to fetch it for the next step, that round-trip is the cost — not the math. A GPU’s SRAM is ~100× faster than HBM but ~1000× smaller, so the single biggest software win for a memory-bound workload is to keep intermediate values on the notepad and never round-trip them to the warehouse.
This is the whole idea behind kernel fusion: instead of running, say, a matmul, then a separate bias-add, then a separate activation — each one re-reading the data from HBM — you fuse them into one kernel that loads once, does all three on-chip, and writes once. FlashAttention is the famous example for transformers: classic attention materializes the full \(n\times n\) attention-score matrix in HBM (huge for long sequences), but FlashAttention computes attention in tiles that fit in SRAM, never writing the giant matrix out. The result is the same numbers, far fewer bytes moved, and it is now standard in every serious inference stack.
The same operation, the slow way versus the fused way — every HBM round-trip you erase is the speedup:
graph TD
A[Registers / SRAM: tiny, ~100x faster] --> B[L2 cache]
B --> C[HBM: gigabytes, the bottleneck]
C --> D[CPU RAM / NVMe: huge, very slow]
A -.kernel fusion keeps work up here.-> A
Whenever you see “we made attention/an op faster without changing the result,” it is almost always a memory-traffic trick: fuse operations, tile to fit SRAM, and avoid writing big intermediates back to HBM.
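The tiling idea can be sketched in NumPy. This is an illustration of the online-softmax trick that lets attention be computed one key/value tile at a time, not FlashAttention itself: a real kernel also tiles the queries and keeps the running statistics in on-chip SRAM. The function names, tile size, and random inputs are assumptions made for the example.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference version: materializes the full (n x n) score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    p = p / p.sum(axis=-1, keepdims=True)
    return p @ v


def tiled_attention(q, k, v, tile=16):
    """Same result, but only an (n x tile) block of scores exists at a time."""
    n, d = q.shape
    out = np.zeros((n, v.shape[1]))
    row_max = np.full((n, 1), -np.inf)   # running max of each score row
    row_sum = np.zeros((n, 1))           # running softmax denominator
    for start in range(0, k.shape[0], tile):
        k_t = k[start:start + tile]
        v_t = v[start:start + tile]
        s = q @ k_t.T / np.sqrt(d)
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)   # rescale earlier partial sums
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ v_t
        row_max = new_max
    return out / row_sum


rng = np.random.default_rng(0)
q = rng.standard_normal((40, 8))
k = rng.standard_normal((40, 8))
v = rng.standard_normal((40, 8))
print(np.allclose(naive_attention(q, k, v), tiled_attention(q, k, v)))
```

The two functions agree to floating-point tolerance; the difference is that the tiled version never needs storage proportional to the square of the sequence length.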
30.2 — Distributed training: data, model, tensor, and pipeline parallelism; ZeRO/FSDP
When a model or its training batch no longer fits on one GPU, you split the work across many. There are three main axes along which to cut the cake (data, tensor, and pipeline), and large runs combine all of them (“3D parallelism”).
Data parallelism (DP) — the simplest. Every GPU holds a full copy of the model and processes a different slice of the batch. After each backward pass, GPUs average their gradients with an all-reduce (every device ends up with the summed gradient) and step in sync. Scales throughput linearly until the all-reduce communication dominates. The catch: each GPU must store the entire model, gradients, and optimizer states — fine for a ResNet, impossible for a 70B model.
Tensor parallelism (TP) — split a single layer. A big weight matrix \(W\) is cut into column-shards \(W = [W_1 \mid W_2]\); GPU 1 computes \(xW_1\), GPU 2 computes \(xW_2\), and an all-gather stitches the outputs. This splits one matmul across devices, so it needs very fast interconnect (NVLink) and is kept within a node.
Pipeline parallelism (PP) — split by layers. GPU 1 holds layers 1–8, GPU 2 holds 9–16, etc. Activations flow forward like an assembly line. Naively this leaves most GPUs idle (the “bubble”); micro-batching keeps the pipeline full by streaming many small batches through the stages.
Model parallelism is the umbrella term for TP + PP — any scheme where the model itself (not just the data) is split.
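The column-sharding identity behind tensor parallelism is easy to check numerically. The sketch below simulates two devices with plain NumPy arrays; the shapes and variable names are arbitrary assumptions, and a real implementation would run each shard on its own GPU and use a collective all-gather instead of `np.concatenate`.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))   # activations: batch of 4, hidden size 8
W = rng.standard_normal((8, 6))   # the full weight matrix

# Column-shard W across two hypothetical devices: W = [W1 | W2].
W1, W2 = np.split(W, 2, axis=1)

y1 = x @ W1   # computed on "GPU 1"
y2 = x @ W2   # computed on "GPU 2"

# The all-gather step: stitch the partial outputs back together by column.
y = np.concatenate([y1, y2], axis=1)

print(np.allclose(y, x @ W))
```

Each simulated device holds half of the weight columns and does half of the arithmetic, and the concatenated result matches the unsharded product.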
graph LR
subgraph DP[Data Parallel]
d1[GPU full model · batch A]
d2[GPU full model · batch B]
d1 <-->|all-reduce grads| d2
end
subgraph PP[Pipeline Parallel]
p1[GPU: layers 1-8] --> p2[GPU: layers 9-16]
end
subgraph TP[Tensor Parallel]
t1[GPU: W left half] -.all-gather.- t2[GPU: W right half]
end
ZeRO / FSDP — the clever fix for data parallelism’s memory waste. In plain DP, \(N\) GPUs redundantly store \(N\) identical copies of parameters, gradients, and optimizer states. ZeRO (Zero Redundancy Optimizer) shards these across GPUs instead: each device owns only \(1/N\) of them, and gathers the full parameters for a layer just in time for its forward/backward, then frees them. FSDP (Fully Sharded Data Parallel, PyTorch’s implementation) is the same idea. You get data-parallel simplicity with model-parallel memory savings, paying in extra communication.
A quick memory illustration of why this matters: training in mixed precision, the optimizer states (Adam’s two moments plus an fp32 master copy) dominate. A common accounting is ~16 bytes per parameter of training state (2 fp16 weight + 2 fp16 grad + 4+4+4 fp32 master/momentum/variance ≈ 16). For a 7B model that is ~112 GB — already past a single 80 GB GPU. Shard it across 8 GPUs with ZeRO and each holds ~14 GB. That is the whole game.
ZeRO comes in three stages. The picture below is the fastest way to get it: think of the per-parameter training state as three stacked bricks — weights, gradients, optimizer states — and each stage slices one more brick into \(1/N\) pieces spread over the GPUs.
The per-GPU memory for stage 1, written as a formula:
\[\text{mem per GPU} \approx \underbrace{2\Psi}_{\text{fp16 weights}} + \underbrace{2\Psi}_{\text{fp16 grads}} + \frac{\overbrace{K\Psi}^{\text{optimizer states}}}{N_{\text{shard}}}\quad\text{(ZeRO-1, }K\approx 12\text{)}\]
In words: every GPU still keeps a full copy of the 16-bit weights and gradients, but the heavy optimizer states (the fp32 master copy plus Adam’s two moments, about 12 bytes/param) get divided by the number of GPUs you shard across.
Also written: for \(\Psi\) parameters across \(N\) GPUs, ZeRO-1 ≈ \(4\Psi + 12\Psi/N\) bytes/GPU; ZeRO-2 also shards gradients → \(2\Psi + (2\Psi+12\Psi)/N\); ZeRO-3 (FSDP) shards everything → \(\approx 16\Psi/N\) bytes/GPU.
A worked example with small numbers makes the stages tangible. Take a 1B-parameter model (\(\Psi = 1\text{B}\)) on \(N=8\) GPUs, using the 16-bytes/param accounting (2 weight + 2 grad + 12 optimizer):
| Setup | Formula | Bytes/param | Per-GPU memory |
|---|---|---|---|
| Plain DP | \(16\Psi\) | 16 | 16 GB (every GPU full) |
| ZeRO-1 | \(4\Psi + 12\Psi/8\) | \(4 + 1.5 = 5.5\) | 5.5 GB |
| ZeRO-2 | \(2\Psi + 14\Psi/8\) | \(2 + 1.75 = 3.75\) | 3.75 GB |
| ZeRO-3 | \(16\Psi/8\) | \(2\) | 2 GB |
Same model, same 8 GPUs — going from plain DP to ZeRO-3 drops each GPU’s footprint from 16 GB to 2 GB, an 8× saving, purely by not storing redundant copies.
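The per-GPU numbers follow directly from the bytes-per-parameter accounting, so they can be sketched as a small calculator. The function names and the convention that stage 0 means plain data parallelism are assumptions for this example; it counts only weights, gradients, and optimizer state, not activations or communication buffers.

```python
def zero_bytes_per_param(stage: int, n_gpus: int,
                         weight_bytes: float = 2.0,
                         grad_bytes: float = 2.0,
                         opt_bytes: float = 12.0) -> float:
    """Training-state bytes per parameter held by each GPU.

    stage 0 = plain DP, 1 = shard optimizer states,
    2 = also shard gradients, 3 = also shard parameters.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    weights = weight_bytes / n_gpus if stage >= 3 else weight_bytes
    grads = grad_bytes / n_gpus if stage >= 2 else grad_bytes
    opt = opt_bytes / n_gpus if stage >= 1 else opt_bytes
    return weights + grads + opt


def per_gpu_gb(n_params: float, stage: int, n_gpus: int) -> float:
    """Per-GPU training state in GB (10**9 bytes)."""
    return n_params * zero_bytes_per_param(stage, n_gpus) / 1e9


for stage in range(4):
    print(f"stage {stage}: {per_gpu_gb(1e9, stage, 8):.2f} GB per GPU")
```

Calling `per_gpu_gb(7e9, 3, 8)` gives the 14 GB figure quoted earlier for a 7B model sharded across 8 GPUs.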
| Strategy | What it splits | Each GPU stores | Main cost | Use when |
|---|---|---|---|---|
| Data (DP) | the batch | full model | all-reduce grads | model fits on 1 GPU |
| Tensor (TP) | one layer’s matmul | a slice of each layer | high-bw all-gather | layer too big; fast NVLink |
| Pipeline (PP) | groups of layers | some layers | pipeline bubble | many layers, slower links OK |
| ZeRO-1 | optimizer states | full params+grads, \(1/N\) opt | light extra comm | DP but opt states don’t fit |
| ZeRO-2 | + gradients | full params, \(1/N\) grad+opt | moderate comm | need more savings |
| ZeRO-3 / FSDP | + parameters | \(1/N\) of everything | gather params JIT | model itself doesn’t fit |
30.2a — Doing it in PyTorch (FSDP)
In practice you almost never hand-roll sharding; the framework wraps your model. PyTorch FSDP is one wrapper call:
import os
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
# launched with: torchrun --nproc_per_node=4 train.py
torch.distributed.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # bind each process to its own GPU
model = MyTransformer().cuda()  # MyTransformer stands in for your own nn.Module
# shard params, grads, and optimizer state across all GPUs (ZeRO-3 style)
model = FSDP(
    model,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,   # compute in bf16
        reduce_dtype=torch.bfloat16,  # reduce gradients in bf16
    ),
    device_id=torch.cuda.current_device(),
)
opt = torch.optim.AdamW(model.parameters(), lr=2e-4)  # build the optimizer after wrapping
for batch in loader:  # each GPU gets a different slice (e.g. via a DistributedSampler)
    opt.zero_grad()
    loss = model(batch).loss
    loss.backward()  # FSDP gathers/frees shards just-in-time
    opt.step()
The Hugging Face accelerate library wraps the same machinery behind a config file, and DeepSpeed exposes the three ZeRO stages directly via a JSON config; pick whichever your stack already uses.
30.3 — Mixed precision
By default, numbers in a network are 32-bit floats (fp32). But most of training tolerates far less precision. Mixed-precision training runs the heavy matmuls in 16-bit — either fp16 (5 exponent, 10 mantissa bits) or bf16 (8 exponent, 7 mantissa bits) — while keeping a few sensitive pieces in fp32. The payoff: half the memory for activations and weights, and 2–8× faster matmuls on tensor cores, with essentially no loss in final accuracy.
The two 16-bit formats trade range vs precision. Picture a number as a ruler: bf16 keeps the long ruler (fp32’s full exponent range, so it almost never runs off the end into overflow) but marks it coarsely; fp16 marks a short ruler finely, so it reads small differences well but falls off the end (overflow to infinity, or underflow of tiny gradients to zero) easily. Because of that, fp16 training needs loss scaling: multiply the loss by a big constant \(S\) before backprop so small gradients land in fp16’s representable range, then divide by \(S\) before the optimizer step. bf16 usually needs no scaling — which is why it has become the default on hardware that supports it.
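The underflow problem and its fix can be seen with NumPy's `float16` type. This is a toy sketch: the gradient value `1e-8` and the scale `S = 1024` are made-up numbers chosen to show the effect, and NumPy has no bf16 type, so only the fp16 side is demonstrated.

```python
import numpy as np

grad = np.float32(1e-8)   # a tiny but meaningful gradient (illustrative value)
S = np.float32(1024.0)    # loss-scale constant (illustrative value)

# Cast straight to fp16: the value is below fp16's smallest subnormal
# (about 6e-8) and rounds to zero.
unscaled = np.float16(grad)

# Scale first, cast to fp16, then unscale in fp32: the value survives.
scaled = np.float16(grad * S)
recovered = np.float32(scaled) / S

print("direct cast to fp16:", unscaled)
print("scaled, cast, then unscaled:", recovered)
```

Scaling the loss has the same effect on every gradient at once, since backpropagation is linear in the loss.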
The three formats laid side by side — same total story, different split between range and precision:
graph TD
A[fp32 master weights] -->|cast| B[bf16/fp16 weights]
B --> C[forward and backward in 16-bit, fast tensor-core matmuls]
C --> D[gradients]
D -->|unscale if fp16| E[update fp32 master weights]
E --> A
Concretely: keep an fp32 master copy of the weights (so tiny updates aren’t lost to rounding), cast to 16-bit for the forward/backward, accumulate matmul results in fp32 inside the tensor core, and apply the optimizer step to the fp32 master. For inference, you often skip the master copy entirely and just run bf16 or fp16 weights.
In PyTorch this is one context manager plus a gradient scaler, not a rewrite:
import torch
scaler = torch.cuda.amp.GradScaler()  # loss scaling (needed for fp16 only)
for x, y in loader:
    opt.zero_grad()
    with torch.autocast("cuda", dtype=torch.float16):  # or torch.bfloat16
        loss = loss_fn(model(x), y)  # matmuls run in 16-bit
    scaler.scale(loss).backward()  # scale the loss, then backprop
    scaler.step(opt)  # unscales the grads; skips the step if they contain inf/NaN
    scaler.update()  # adjusts the scale factor for the next iteration
# with bf16 you can usually drop the GradScaler entirely
Don’t reach for fp16 on hardware that supports bf16 unless you have a reason. fp16 without loss scaling silently NaNs out partway through training; bf16 sidesteps the whole class of overflow/underflow bugs. The convenience is worth the slightly coarser mantissa.
🆕 The newest rung: fp8. The latest accelerators (NVIDIA Hopper/Blackwell) add fp8 training and inference — 8-bit floats in two flavours, e4m3 (more precision) for forward activations/weights and e5m2 (more range) for gradients. It roughly halves memory and bandwidth again versus bf16, but needs per-tensor scaling factors to keep values in range. Treat it as bf16’s successor for the very largest runs, not a default yet.
30.4 — Inference optimization: quantization, pruning, distillation, KV cache, batching, speculative decoding
Serving a trained model has its own bottlenecks, distinct from training: the memory to hold the weights and the per-token KV cache, and the bytes moved each decode step. The six techniques below each attack one or both, and they stack. This is the longest section in the chapter, so it is broken into labelled sub-parts (30.4a through 30.4f); treat each as its own mini-section.
30.4a — Quantization (int8 / nf4 / GPTQ / AWQ)
Quantization stores weights in fewer bits than fp16. The simplest is int8: map each weight to an 8-bit integer with a per-channel scale, \(w \approx s\cdot q\) where \(q\in[-127,127]\). That halves memory versus fp16 and, since decode is memory-bound, roughly doubles decode speed. Going further, 4-bit formats cut memory by 4×. nf4 (NormalFloat-4, from QLoRA) is a 4-bit type whose 16 levels are placed to match a normal distribution — weights are roughly Gaussian, so the levels land where the mass is.
The quantize/dequantize pair is just a scale-and-round, then a scale-back:
\[q = \operatorname{round}\!\Big(\frac{w}{s}\Big),\qquad \hat w = s\cdot q,\qquad s = \frac{\max|w|}{2^{b-1}-1}\]
In words: divide each weight by a shared step size \(s\) and round to the nearest integer code; to read the weight back, multiply the code by \(s\) again. The step size is chosen so the largest-magnitude weight maps to the largest code a signed \(b\)-bit integer can hold.
Also written: with \(b=8\) the denominator is \(2^7-1=127\), so \(q=\operatorname{round}(127\,w/\max|w|)\) and \(\hat w = q\cdot\max|w|/127\) — a symmetric per-tensor (or per-channel) linear map.
The naive way to round every weight independently loses accuracy at 4-bit. GPTQ and AWQ are smarter post-training quantizers: GPTQ rounds weights one column at a time while adjusting the rest to compensate for the error (using second-order/Hessian information from a small calibration set); AWQ (Activation-aware Weight Quantization) notices that a few weight channels carry most of the activation magnitude and protects those by scaling, quantizing the rest aggressively. Both recover near-fp16 quality at 4-bit.
The picture below shows what “round to the nearest code” actually does — continuous fp16 weights snap onto a small grid of allowed levels:
Tiny worked example of int8 with a per-tensor scale:
import numpy as np
w = np.array([0.12, -0.30, 0.61, -0.02]) # fp16 weights
s = np.max(np.abs(w)) / 127 # per-tensor scale = 0.61/127
q = np.round(w / s).astype(np.int8) # -> int8 codes
w_hat = q * s # dequantized
print(q) # [ 25 -62 127 -4]
print(np.round(w_hat,3)) # [0.12 -0.298 0.61 -0.019] ~= original
The 4 weights went from 4×2 = 8 bytes (fp16) to 4×1 = 4 bytes of codes plus one shared scale, and the reconstruction is within ~0.002 of the originals. At scale, that error is small enough that a well-calibrated int8 model is essentially lossless.
A visual of where the bits go — the same weight at three precisions, and the resulting memory for a 13B model:
| Format | Bits | Memory vs fp16 | Quality | Notes |
|---|---|---|---|---|
| fp16/bf16 | 16 | 1× | baseline | default |
| int8 | 8 | 0.5× | ~lossless | LLM.int8() handles outliers |
| GPTQ | 4 | 0.25× | near-baseline | error-compensating, needs calibration |
| AWQ | 4 | 0.25× | near-baseline | protects salient channels, fast kernels |
| nf4 | 4 | 0.25× | near-baseline | Gaussian-optimal, used in QLoRA |
In practice you rarely implement a quantizer; you load a pre-quantized checkpoint or apply one with a library. With Hugging Face, loading a 4-bit model is a flag:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # Gaussian-optimal 4-bit
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-hf", quantization_config=cfg, device_map="auto",
)
# a 26 GB fp16 model now occupies ~6.5 GB, so it fits a single consumer GPU
QLoRA combines this with low-rank adapters: freeze the nf4-quantized base, train only small LoRA matrices on top, and you can fine-tune a 13B model on one consumer GPU, because the quantized base never needs gradients.
30.4b — Pruning
Pruning deletes weights judged unimportant (e.g. those nearest zero), turning a dense network sparse. Unstructured pruning zeros individual weights anywhere in the matrix — high compression, but the resulting irregular sparsity pattern rarely speeds up real hardware, which wants dense, regular blocks. Structured pruning removes whole units (channels, attention heads, even entire layers) — it achieves less compression for the same accuracy, but it actually runs faster because the matrices genuinely shrink and stay dense. A middle ground, 2:4 sparsity, forces exactly 2 of every contiguous 4 weights to zero, a regular pattern that modern NVIDIA tensor cores accelerate directly (~2× on those matmuls).
A tiny numeric example makes the three flavours concrete. Take one row of 8 weights and prune ~50%:
original: [ 0.9 -0.1 0.7 0.05 -0.8 0.02 0.6 -0.04 ]
unstructured: [ 0.9 0 0.7 0 -0.8 0 0.6 0 ] # smallest-magnitude zeroed, anywhere
2:4 sparsity: [ 0.9 0 0.7 0 -0.8 0 0.6 0 ] # exactly 2 kept per group of 4
structured: drop the whole column/head this row feeds → matrix shrinks 8→4 wide
Pruning to a flashy “90% sparse” number is misleading if it is unstructured — on a normal GPU that model usually runs at the same speed as the dense one, because the hardware still streams the zeros. Only structured pruning or hardware-supported patterns like 2:4 turn sparsity into real wall-clock speedup. Quote the speedup you measured, not the sparsity you achieved.
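The difference between unstructured and 2:4 pruning shows up once the large weights are not evenly spread. Here is a minimal NumPy sketch; the function names are assumptions for this example, it uses plain magnitude as the importance score, and it only produces the sparsity pattern (the speedup itself comes from hardware kernels, which this does not model).

```python
import numpy as np


def prune_unstructured(row, sparsity=0.5):
    """Zero the smallest-magnitude weights anywhere in the row."""
    w = np.asarray(row, dtype=float)
    n_drop = int(round(sparsity * w.size))
    drop = np.argsort(np.abs(w))[:n_drop]
    pruned = w.copy()
    pruned[drop] = 0.0
    return pruned


def prune_2_of_4(row):
    """Zero the 2 smallest-magnitude weights in every contiguous group of 4."""
    w = np.asarray(row, dtype=float)
    if w.size % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    groups = w.reshape(-1, 4).copy()
    drop = np.argsort(np.abs(groups), axis=1)[:, :2]
    np.put_along_axis(groups, drop, 0.0, axis=1)
    return groups.reshape(w.shape)


even = [0.9, -0.1, 0.7, 0.05, -0.8, 0.02, 0.6, -0.04]   # the row used above
skewed = [0.9, 0.8, 0.7, 0.6, 0.01, 0.02, 0.03, 0.04]   # big weights clustered

for name, row in [("even", even), ("skewed", skewed)]:
    print(name, "unstructured:", prune_unstructured(row))
    print(name, "2:4 sparsity:", prune_2_of_4(row))
```

On the evenly spread row both methods pick the same zeros. On the skewed row, unstructured pruning wipes out the whole second group, while 2:4 is forced to drop two large weights from the first group and keep two small ones in the second. That constraint is the price of a pattern the hardware can exploit.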
30.4c — Knowledge distillation
Knowledge distillation trains a small student model to mimic a large teacher. Instead of (or alongside) the hard labels, the student matches the teacher’s full output probabilities — the “soft targets,” softened with a temperature \(T\) — which carry richer information (“this image is mostly cat, a little dog, definitely not car”). The result is a compact model that punches above its size. DistilBERT and many small instruct models are distilled this way.
\[\mathcal{L} = \alpha\,\underbrace{\text{CE}(y,\ p_{\text{student}})}_{\text{true labels}} + (1-\alpha)\,T^2\,\underbrace{\text{KL}\!\big(p^T_{\text{teacher}}\,\|\,p^T_{\text{student}}\big)}_{\text{match soft targets}}\]
In words: the student’s loss is a blend of two goals — get the real answer right (the cross-entropy on true labels) and imitate the teacher’s softened opinion about every class (the KL term) — with \(\alpha\) setting how much you weight each and \(T^2\) keeping the soft-target gradients comparable in size.
Also written: \(\mathcal{L} = \alpha\,\mathrm{CE}(y,p_s) + (1-\alpha)\,T^2\sum_i p^T_{t,i}\,\log\!\frac{p^T_{t,i}}{p^T_{s,i}}\), where \(p^T\) denotes a softmax taken at temperature \(T\).
The temperature is the heart of it. A softmax with temperature \(T\) computes \(p_i = \dfrac{e^{z_i/T}}{\sum_j e^{z_j/T}}\). At \(T=1\) it is the ordinary softmax; raising \(T\) flattens the distribution, exposing the small probabilities (“dark knowledge”) that tell the student how the teacher relates the wrong classes to each other. The \(T^2\) factor on the KL term rescales its gradient back to the same magnitude as the hard-label term.
In words: dividing the logits by \(T\) before the exponential squashes the gaps between them, so a confident “87% cat” relaxes into roughly “67% cat, 24% dog” at \(T=2\), revealing which wrong answers the teacher thought were close calls.
Also written: \(p_i = \operatorname{softmax}(z/T)_i\); as \(T\to\infty\) it approaches the uniform distribution \(1/K\), and as \(T\to 0\) it approaches a one-hot at \(\arg\max_i z_i\).
The bars below show the same teacher logits at two temperatures; raising \(T\) lifts the hidden “dog” signal into view:
A tiny number-in/number-out example. Suppose for one image the teacher’s logits over three classes {cat, dog, car} are \(z = [3.0,\ 1.0,\ -1.0]\).
import numpy as np
z = np.array([3.0, 1.0, -1.0])  # teacher logits: cat, dog, car
def softmax_T(z, T):
    e = np.exp((z - z.max()) / T)
    return e / e.sum()
print(np.round(softmax_T(z, 1), 3))  # T=1 -> [0.867 0.117 0.016] (very peaked)
print(np.round(softmax_T(z, 2), 3))  # T=2 -> [0.665 0.245 0.09 ] (softened)
At \(T=1\) the teacher says “cat, 87%” and the dog/car signal is nearly invisible. At \(T=2\) the same logits become [0.665, 0.245, 0.090]: the student now sees that dog is a plausible alternative and car is the least likely of the three, which encodes the relationship “cats look more like dogs than like cars.” Training the student to reproduce the softened [0.665, 0.245, 0.090] (rather than a one-hot [1, 0, 0]) is what transfers that nuance, and is why a distilled student often beats a same-size model trained on hard labels alone.
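The blended distillation loss from earlier in this section can be written out for a single example. This is a sketch in NumPy, not a training recipe: the function names, the student logits, and the default \(\alpha = 0.5\), \(T = 2\) are assumptions chosen for illustration.

```python
import numpy as np


def softmax(z, T=1.0):
    """Softmax at temperature T (T=1 is the ordinary softmax)."""
    z = np.asarray(z, dtype=float)
    e = np.exp((z - z.max()) / T)
    return e / e.sum()


def distillation_loss(student_logits, teacher_logits, label, alpha=0.5, T=2.0):
    """alpha * CE(hard label) + (1 - alpha) * T^2 * KL(teacher_T || student_T)."""
    # Hard-label term: ordinary cross-entropy at T = 1.
    ce = -np.log(softmax(student_logits)[label])
    # Soft-target term: KL divergence between the temperature-softened outputs.
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    kl = np.sum(p_t * np.log(p_t / p_s))
    return alpha * ce + (1 - alpha) * T ** 2 * kl


teacher = [3.0, 1.0, -1.0]        # cat, dog, car
good_student = [2.5, 0.8, -1.2]   # similar ranking and spacing
bad_student = [2.5, -1.2, 0.8]    # right top class, wrong "dark knowledge"

print(distillation_loss(good_student, teacher, label=0))
print(distillation_loss(bad_student, teacher, label=0))
```

Both hypothetical students put the same logit on the correct class, so their hard-label terms are identical; the second is penalized only through the KL term, for ranking car above dog.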
30.4d — The KV cache
In a transformer, generating each new token attends to all previous tokens via their keys and values. Recomputing those for the whole history at every step would be quadratic and wasteful. The KV cache stores the K and V vectors of every past token so each new step only computes K/V for the one new token and reuses the rest. This is what makes autoregressive decoding linear instead of quadratic — but the cache grows with sequence length and batch size and can dwarf the weights in memory.
\[\text{KV bytes} = 2 \times L \times n_{\text{layers}} \times d_{\text{model}} \times \text{batch} \times \text{bytes}\]
In words: the cache holds two vectors (a key and a value) for every token position, in every layer, across every dimension of the model, for every request in the batch — multiply all those counts together, times the bytes per number.
Also written: \(\text{KV bytes} = 2\,b\,L\,n_{\text{layers}}\,d_{\text{model}}\,B\) with \(b\) = bytes per element (2 for fp16), \(L\) = sequence length, \(B\) = batch size; it grows linearly in both \(L\) and \(B\).
(the leading 2 is for K and V). For a 7B model (32 layers, \(d=4096\)), one sequence of 2048 tokens in fp16 is \(2\times2048\times32\times4096\times2 \approx 1.0\) GB — per request. The weights are fixed, but the KV cache grows linearly with context length, so at long context (or large batch) it overtakes the weights and becomes the thing that fills the GPU:
The picture is the whole lesson: weights are a flat red line, but the blue KV bars climb with context until, at long context or high batch, the cache is the dominant consumer of GPU memory. That is why the next two techniques — paging the cache and batching around it — exist.
Shrinking the cache directly. Two architectural tricks attack the KV cache at the source, and both are now standard in production LLMs:
- Grouped-query attention (GQA) / multi-query attention (MQA): instead of every attention head having its own K and V, several query heads share one K/V head. If 32 query heads share 8 K/V groups, the KV cache shrinks 4× — which is why models like Llama-2-70B and Mistral use GQA.
- KV-cache quantization: store the cached K and V in int8 or 4-bit instead of fp16, halving or quartering the cache with little quality loss.
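The cache-size formula and both shrinking tricks fit in one small function. This is a sketch under the same simplified accounting as the formula above; the function name and the `n_heads` / `n_kv_heads` parameters are assumptions for this example, with GQA modelled as scaling the cached width by the ratio of K/V heads to query heads.

```python
def kv_cache_bytes(seq_len: int, n_layers: int, d_model: int,
                   batch: int = 1, bytes_per_elem: int = 2,
                   n_heads: int = 0, n_kv_heads: int = 0) -> int:
    """2 (K and V) x tokens x layers x cached width x batch x bytes per element."""
    width = d_model
    if n_heads and n_kv_heads:
        # GQA/MQA: only n_kv_heads of the n_heads head-slices are cached.
        width = d_model * n_kv_heads // n_heads
    return 2 * seq_len * n_layers * width * batch * bytes_per_elem


GB = 1e9
base = kv_cache_bytes(2048, 32, 4096)                        # fp16, full heads
gqa = kv_cache_bytes(2048, 32, 4096, n_heads=32, n_kv_heads=8)
gqa_int8 = kv_cache_bytes(2048, 32, 4096, bytes_per_elem=1,
                          n_heads=32, n_kv_heads=8)

print(f"fp16, full heads : {base / GB:.2f} GB per sequence")
print(f"fp16, GQA 32->8  : {gqa / GB:.2f} GB per sequence")
print(f"int8, GQA 32->8  : {gqa_int8 / GB:.2f} GB per sequence")
```

Because every factor multiplies, the savings compound: a 4× reduction from GQA and a 2× reduction from int8 storage together give an 8× smaller cache for the same context.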
30.4e — Batching and continuous batching (vLLM / PagedAttention)
Because decode is memory-bound, processing one request wastes the GPU — you load all the weights to serve a single token. Batching runs many requests together so each weight load serves many tokens, multiplying throughput nearly for free. Static batching is clumsy, though: requests finish at different times, and a batch that waits for its slowest member leaves GPUs idle.
Continuous batching (a.k.a. in-flight batching) fixes this by scheduling at the token level: as soon as one sequence finishes, a new one slots into its place, so the batch is always full. vLLM is the popular server that does this. Its key trick is PagedAttention: instead of reserving one big contiguous block of memory for each request’s KV cache (which wastes a lot to fragmentation and over-reservation), it stores the cache in small fixed-size pages like an operating system’s virtual memory. Pages are allocated on demand and can even be shared across requests with a common prefix. The result is far less wasted memory and much higher throughput.
The animation contrasts the two: in a static batch, finished requests leave dead slots until the slowest one ends; continuous batching backfills them instantly.
graph TD
subgraph Static["Static batch (idle gaps)"]
s1[req1 short, then idle]
s2[req2 very short, then idle]
s3[req3 long, holds the batch]
end
subgraph Cont["Continuous batch (always full)"]
c1[req1 finishes] --> c4[req4 fills the slot]
c2[req2 still running]
c3[req3 finishes] --> c5[req5 fills the slot]
end
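The idle-slot effect can be quantified with a toy scheduler. This is a deliberately simplified sketch, not how vLLM is implemented: it assumes one decode step per generated token, a fixed number of batch slots, no prefill cost, and requests served in arrival order. The function names and request lengths are made up for the example.

```python
import heapq


def static_batch_steps(lengths, slots):
    """Each batch of `slots` requests runs until its longest member finishes."""
    return sum(max(lengths[i:i + slots]) for i in range(0, len(lengths), slots))


def continuous_batch_steps(lengths, slots):
    """A freed slot is refilled immediately by the next waiting request."""
    free_at = [0] * slots          # min-heap: the step at which each slot frees up
    heapq.heapify(free_at)
    for length in lengths:
        start = heapq.heappop(free_at)           # earliest available slot
        heapq.heappush(free_at, start + length)
    return max(free_at)


# Output lengths (in tokens) of eight hypothetical requests, four batch slots.
lengths = [10, 2, 50, 5, 8, 40, 3, 6]

print("static    :", static_batch_steps(lengths, slots=4), "decode steps")
print("continuous:", continuous_batch_steps(lengths, slots=4), "decode steps")
```

With these numbers the static schedule needs 90 decode steps and the continuous one needs 50, because the short requests no longer wait for the 50-token request before their slots are reused.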
Standing this up is a few lines — vLLM does continuous batching and PagedAttention for you:
from vllm import LLM, SamplingParams
# to serve AWQ 4-bit weights, point `model` at an AWQ-quantized checkpoint and add quantization="awq"
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
params = SamplingParams(temperature=0.7, max_tokens=128)
# pass MANY prompts at once; vLLM batches them continuously under the hood
prompts = ["Explain the KV cache.", "What is PagedAttention?", "Define TTFT."]
for out in llm.generate(prompts, params):
    print(out.outputs[0].text)
30.4f — Speculative decoding
Decoding is serial: one token at a time, each needing a full pass over the big model. Speculative decoding breaks that serial bottleneck. A small, cheap draft model proposes the next \(k\) tokens quickly; the big target model then verifies all \(k\) in a single parallel forward pass (it can score many positions at once). Tokens the target agrees with are accepted; at the first disagreement the target’s own token is used instead and the rest of the draft is discarded.
The key property is correctness: speculative decoding is distribution-preserving. With greedy decoding the output is exactly what the target alone would have produced; with sampling, a rejection-sampling correction step makes the emitted tokens follow the target’s distribution exactly, not just approximately. Either way there is no quality change: the draft model only proposes, never overrules. The speedup comes entirely from doing fewer serial big-model passes, and its size depends on the acceptance rate: a draft that matches the target often (easy, predictable text) yields a larger speedup than one that rarely agrees. In practice this lands around 2–3× for typical workloads, but it is acceptance-rate-dependent, not a fixed constant.
The expected speedup has a clean closed form. If the draft proposes \(k\) tokens and each is accepted independently with probability \(\alpha\), the expected number of tokens produced per verification round (the accepted draft tokens plus the one token the target itself always contributes) is
\[\mathbb{E}[\text{tokens per round}] = \frac{1 - \alpha^{k+1}}{1 - \alpha}\]
In words: with acceptance rate \(\alpha\), you usually clear several tokens per expensive target pass instead of one; the higher the draft agrees, the more tokens each verification round advances.
Also written: \(\sum_{i=0}^{k}\alpha^{i} = 1+\alpha+\alpha^2+\dots+\alpha^k\) — a geometric series; e.g. \(\alpha=0.8,\,k=4\) gives \(\approx 3.36\) tokens per round versus 1 without speculation.
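The closed form and the geometric series can be checked against each other in a few lines. This is only a numerical sanity check under the same independence assumption as the formula; the function names are made up for illustration.

```python
def expected_tokens_closed_form(alpha, k):
    """(1 - alpha**(k+1)) / (1 - alpha), valid for 0 <= alpha < 1."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def expected_tokens_series(alpha, k):
    """The same quantity written as 1 + alpha + alpha**2 + ... + alpha**k."""
    return sum(alpha ** i for i in range(k + 1))


for k in (2, 4, 8):
    closed = expected_tokens_closed_form(0.8, k)
    series = expected_tokens_series(0.8, k)
    print(k, round(closed, 2), round(series, 2))
# 2 2.44 2.44
# 4 3.36 3.36
# 8 4.33 4.33
```

Doubling the draft length from 4 to 8 adds less than one expected token per round at \(\alpha = 0.8\), which suggests why very long drafts tend to pay off only when the acceptance rate is high.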
graph LR
A[draft model proposes k tokens, cheap and serial] --> B[target model verifies all k in one parallel pass]
B --> C{accept prefix that matches}
C -->|all k accepted| D[advance k tokens plus one bonus target token]
C -->|first mismatch| E[take target's token, redraft from there]
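The propose-then-verify loop in the diagram can be sketched for the greedy case with two toy "models". Everything here is an illustrative assumption: `draft_next` and `target_next` are stand-in functions over integer tokens, not real networks, and the verification loop only mimics what a real system would get from one batched target forward pass. When all \(k\) drafts match, the sketch also takes the target's next token, which is why a round can advance up to \(k+1\) tokens.

```python
def target_next(ctx):
    """Toy target model: always counts up modulo 10."""
    return (ctx[-1] + 1) % 10


def draft_next(ctx):
    """Toy draft model: agrees with the target except right after a 3."""
    return 0 if ctx[-1] == 3 else (ctx[-1] + 1) % 10


def speculative_step(context, k):
    """One round of greedy speculative decoding; returns the new tokens."""
    # 1. Draft proposes k tokens serially (cheap).
    proposal, ctx = [], list(context)
    for _ in range(k):
        tok = draft_next(ctx)
        proposal.append(tok)
        ctx.append(tok)

    # 2. Target checks each position (one parallel pass in a real system).
    accepted, ctx = [], list(context)
    for tok in proposal:
        target_tok = target_next(ctx)
        if target_tok != tok:
            accepted.append(target_tok)   # first mismatch: take target's token
            return accepted
        accepted.append(tok)
        ctx.append(tok)
    accepted.append(target_next(ctx))     # all matched: bonus target token
    return accepted


def generate_speculative(context, n_new, k):
    out = list(context)
    while len(out) - len(context) < n_new:
        out.extend(speculative_step(out, k))
    return out[len(context):len(context) + n_new]


def generate_target_only(context, n_new):
    out = list(context)
    for _ in range(n_new):
        out.append(target_next(out))
    return out[len(context):]


print(speculative_step([0], k=4))   # [1, 2, 3, 4]
print(generate_speculative([0], 12, k=4) == generate_target_only([0], 12))  # True
```

The final comparison is the point of the sketch: under greedy decoding the speculative output is token-for-token what the target alone would produce, only reached in fewer target rounds.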
Stack these — they’re mostly orthogonal. A typical production recipe: AWQ 4-bit weights + PagedAttention KV cache + continuous batching + speculative decoding. Each attacks a different bottleneck, so the speedups roughly multiply.
Quantization and pruning are lossy — always re-evaluate quality on your own task after applying them, not just on a generic benchmark. A model that looks “near-baseline” on perplexity can quietly regress on your specific domain, formatting, or long-context behaviour.
30.5 — Throughput vs latency: TTFT, tokens/sec, cost math
Two numbers describe serving speed, and they pull against each other. Latency is how fast one user gets their answer; throughput is how many tokens the whole system produces per second across all users. Batching raises throughput (good for cost) but can raise latency (each request waits to be grouped).
For LLMs, latency splits into two phases:
- TTFT (time to first token) — the wait before the first word appears. Dominated by prefill: running the entire prompt through the model once to build its KV cache. Scales with prompt length and is compute-bound.
- TPOT / ITL (time per output token, or inter-token latency) — the gap between subsequent tokens once generation starts. Dominated by memory-bound decode. Its inverse is the tokens/sec a user perceives.
Total time for a response \(\approx \text{TTFT} + (\text{output tokens}-1)\times\text{TPOT}\).
In words: the user waits once for the prompt to be digested (TTFT), then waits one short gap (TPOT) for every additional word that streams out — add them up for the whole response.
Also written: \(T_{\text{total}} \approx \text{TTFT} + (N_{\text{out}}-1)\cdot\text{TPOT}\), and the perceived streaming speed is \(\text{tokens/sec} = 1/\text{TPOT}\).
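As a quick worked instance of the formula, the snippet below plugs in illustrative numbers. The TTFT of 0.5 s, TPOT of 40 ms, and 200 output tokens are assumptions chosen for the arithmetic, not measurements.

```python
def response_time_s(ttft_s, tpot_s, n_out):
    """Total response time: TTFT plus one TPOT for each token after the first."""
    return ttft_s + (n_out - 1) * tpot_s


ttft_s, tpot_s, n_out = 0.5, 0.04, 200   # illustrative values
print(round(response_time_s(ttft_s, tpot_s, n_out), 2))  # 8.46 seconds in total
print(round(1 / tpot_s, 1))                              # 25.0 tokens/sec perceived
```

For a long answer the decode term dominates: here TTFT accounts for 0.5 s of roughly 8.5 s, so shaving TPOT matters more than shaving TTFT once responses run to hundreds of tokens.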
The timeline below shows the two phases: one big upfront prefill, then a steady drip of decode steps, with tokens streaming out one TPOT apart.
graph LR
A[Prompt arrives] -->|prefill: process whole prompt| B[First token · TTFT]
B -->|decode 1 token| C[token 2]
C -->|decode 1 token| D[token 3 ...]
B -. inter-token latency = TPOT .- C
Cost math. Inference cost ties almost directly to GPU-hours and tokens. Suppose a GPU costs $2/hour and your server sustains 2500 output tokens/sec (with batching). That’s \(2500\times3600 = 9{,}000{,}000\) tokens/hour, so cost \(= \$2 / 9\text{M} \approx \$0.22\) per million output tokens. Now halve the weights with int8: decode speed roughly doubles to 5000 tok/s, dropping cost to ~$0.11 per million tokens. This is the lever — throughput per dollar — that quantization and batching pull. (Prefill/input tokens are usually cheaper per token because prefill is compute-efficient, which is why providers price input and output tokens differently.)
The cost-per-token formula behind that arithmetic:
\[\text{cost per token} = \frac{\text{GPU \$/hour}}{3600 \times \text{tokens/sec}}\]
In words: divide what the hardware costs each second by how many tokens it produces each second — faster decode or cheaper GPUs both push the price down, in direct proportion.
Also written: $/token \(= C_{\text{hr}} / (3600\,R)\) where \(C_{\text{hr}}\) is GPU $/hour and \(R\) is tokens/sec; per million tokens, multiply by \(10^6\).
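The formula translates directly into a helper that reproduces the numbers from the cost example above. The function name is hypothetical; the inputs are the $2/hour and 2500 or 5000 tokens/sec figures from the text.

```python
def cost_per_million_tokens(gpu_dollars_per_hour, tokens_per_sec):
    """Dollars per one million output tokens for a single GPU."""
    tokens_per_hour = 3600 * tokens_per_sec
    return gpu_dollars_per_hour / tokens_per_hour * 1e6


print(round(cost_per_million_tokens(2.0, 2500), 2))  # 0.22 (fp16 example)
print(round(cost_per_million_tokens(2.0, 5000), 2))  # 0.11 (int8 example)
```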
Don’t optimize throughput blindly. Cranking batch size to maximize tokens/sec can push TTFT and inter-token latency past what feels responsive (people notice chat slower than ~20–30 tokens/sec). Set a latency SLA first, then maximize throughput within it.
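One way to apply "SLA first, throughput second" is to profile TPOT at several batch sizes and pick the highest-throughput setting that still meets the limit. The sketch below assumes such a profile exists; the numbers in `tpot_by_batch` are placeholders invented for illustration, not benchmark results, and the helper name is hypothetical.

```python
# batch size -> TPOT in seconds. PLACEHOLDER numbers for illustration only;
# replace them with measurements from your own load tests.
tpot_by_batch = {1: 0.025, 8: 0.029, 32: 0.036, 128: 0.053}


def best_batch_under_sla(tpot_by_batch, max_tpot_s):
    """Return (batch, total tokens/sec) with the highest throughput within the SLA."""
    best = None
    for batch, tpot in tpot_by_batch.items():
        if tpot > max_tpot_s:
            continue                   # violates the latency SLA
        throughput = batch / tpot      # each of `batch` streams emits 1/tpot tok/s
        if best is None or throughput > best[1]:
            best = (batch, throughput)
    return best


batch, throughput = best_batch_under_sla(tpot_by_batch, max_tpot_s=0.04)
print(batch, round(throughput))   # 32 889
```

With a 40 ms TPOT limit (25 tokens/sec per user), batch 128 would give the most raw throughput in this made-up profile but is rejected, so the choice falls to batch 32.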
30.6 — Edge vs cloud serving
Where the model runs is its own design axis. Cloud serving puts the model on rented datacenter GPUs behind an API: you get near-unlimited scale, the biggest models, easy updates, and pay-per-use economics — at the price of network round-trips, ongoing cost, and sending user data off-device. Edge serving runs the model on the device itself — phone, laptop, car, camera, sensor hub — giving low latency (no network), offline operation, and privacy (data never leaves), but constrained by the device’s tiny memory, compute, and battery.
The constraints flip which techniques matter. On the edge, aggressive 4-bit (or lower) quantization, distillation to a small student, and pruning are not optional niceties but the price of fitting at all; a 70B cloud model becomes a distilled, quantized 3B model on a phone. In the cloud, you instead lean on batching, PagedAttention, and multi-GPU parallelism to maximize throughput per dollar.
A concrete edge stack worth knowing: llama.cpp (with the GGUF quantized format) runs LLMs in 4-bit on a laptop CPU or a phone with no GPU at all; ONNX Runtime and Core ML (Apple) / TensorFlow Lite (Android) compile a small model down to the device’s native accelerator. The common thread is that the heavy compression of §30.4 is what makes the model small enough to ship inside an app.
graph TD
R[Request] --> Q{Latency-critical, private, or offline?}
Q -->|Yes| E[Edge: small quantized/distilled model on-device]
Q -->|No, need biggest model + scale| C[Cloud: large model, batched multi-GPU]
E -.fallback for hard queries.-> C
| | Edge | Cloud |
|---|---|---|
| Latency | very low (no network) | network round-trip |
| Model size | small (≤ few B params) | up to frontier scale |
| Privacy | data stays on device | data leaves device |
| Cost model | one-time device | pay-per-use, ongoing |
| Key techniques | 4-bit quant, distillation, pruning | batching, PagedAttention, parallelism |
| Updates | ship new app build | swap server-side |
A common production pattern is hybrid: a small edge model handles the easy, frequent, latency-sensitive cases on-device and routes only the hard ones to a big cloud model — getting the privacy and speed of edge with the capability of cloud as a fallback.
30.7 — Worked GPU-memory estimate
Let’s put the chapter’s numbers together and answer a concrete question every practitioner faces: will this fit, and on how many GPUs? Take Llama-2-13B and an 80 GB A100.
Inference, fp16. Weights are the floor: \(13\text{B} \times 2\text{ bytes} = 26\text{ GB}\). Quantize to int8 and it’s \(13\text{ GB}\); to 4-bit (GPTQ/AWQ/nf4), \(\approx 6.5\text{ GB}\). On top sits the KV cache. With 40 layers and \(d=5120\), one 4096-token sequence in fp16 is \(2\times4096\times40\times5120\times2 \approx 3.4\text{ GB}\). So fp16 inference of one long sequence ≈ \(26 + 3.4 \approx 30\text{ GB}\) — fits one 80 GB card with room to batch; 4-bit weights leave room for many concurrent sequences.
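The KV-cache arithmetic above can be wrapped in a small helper to see how many such sequences fit beside the weights. This is a rough sketch under the same assumptions as the text (full multi-head attention, fp16 cache) and it ignores activations and framework overhead; the function names are illustrative.

```python
def kv_cache_gb(seq_len, n_layers, d_model, bytes_per_elem=2, batch=1):
    """KV cache in GB: 2 (K and V) x batch x tokens x layers x width x bytes."""
    return 2 * batch * seq_len * n_layers * d_model * bytes_per_elem / 1e9


def max_sequences(gpu_gb, weights_gb, per_seq_gb):
    """Whole sequences that fit in the memory left after the weights."""
    return int((gpu_gb - weights_gb) // per_seq_gb)


per_seq = kv_cache_gb(seq_len=4096, n_layers=40, d_model=5120)
print(round(per_seq, 2))                  # 3.36 GB per 4096-token sequence
print(max_sequences(80, 26.0, per_seq))   # fp16 weights  -> 16
print(max_sequences(80, 6.5, per_seq))    # 4-bit weights -> 21
```

At full 4096-token context the cache, not the weights, sets the batch ceiling: about 16 concurrent sequences with fp16 weights on an 80 GB card and about 21 with 4-bit weights, before any overhead is subtracted.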
Training, fp16 + Adam. Now the optimizer states dominate. Using the ~16 bytes/param accounting (fp16 weight + fp16 grad + fp32 master + fp32 momentum + fp32 variance):
\[13\times10^{9} \times 16\text{ bytes} \approx 208\text{ GB of state}\]
In words: every parameter drags along about 16 bytes of training bookkeeping, so a 13-billion-parameter model needs over 200 GB just for its state — far more than any single card holds.
Also written: \(M_{\text{train}} \approx 16\,\Psi\) bytes for \(\Psi\) parameters; sharded over \(N\) GPUs with ZeRO-3 it becomes \(\approx 16\,\Psi / N\) per GPU.
Add activations (which scale with batch size and sequence length and can be tens of GB more). 208 GB does not fit one 80 GB GPU. With ZeRO/FSDP across 4 GPUs, each holds \(\approx 52\text{ GB}\) of state plus its activation share — now it fits, which is exactly why sharded data-parallel training is standard.
# Quick estimator: does it fit?
def fits(params_b, bytes_per_param, n_gpu, gpu_gb=80, overhead_gb=4):
    total = params_b * bytes_per_param       # GB of model state
    per_gpu = total / n_gpu + overhead_gb    # ZeRO shards state across GPUs
    return round(per_gpu, 1), per_gpu <= gpu_gb

print(fits(13, 2, 1))    # fp16 inference weights -> (30.0, True)
print(fits(13, 16, 1))   # fp16+Adam training     -> (212.0, False)
print(fits(13, 16, 4))   # same, sharded x4       -> (56.0, True)
Back-of-envelope rules: inference fp16 ≈ 2 bytes/param, int8 ≈ 1, 4-bit ≈ 0.5; full Adam training ≈ 16 bytes/param. Then add the KV cache for inference and activations for training. Reach for this estimate before renting GPUs: it tells you the precision and GPU count in one minute.
30.8 — Quick reference
| Term / formula | Meaning | When / why it matters |
|---|---|---|
| Roofline: \(\min(P,\,I\cdot B)\) | speed = lower of compute peak and intensity×bandwidth | decide if a step is memory- or compute-bound before optimizing |
| Arithmetic intensity \(I\) | FLOPs done per byte loaded | low \(I\) (LLM decode ≈1–2) → memory-bound; high \(I\) (matmul) → compute-bound |
| GEMM FLOPs \(=2mkn\) | cost of an \(m\times k\) by \(k\times n\) matmul | sizing compute for any layer |
| Kernel fusion / FlashAttention | load once, compute on SRAM, store once | erase HBM round-trips; same numbers, far fewer bytes moved |
| Data parallel (DP) | full model per GPU, split the batch | model fits one GPU, want more throughput |
| Tensor parallel (TP) | split one layer’s matmul across GPUs | a single layer too big; needs fast NVLink |
| Pipeline parallel (PP) | split layers into stages | many layers, slower links OK; watch the bubble |
| ZeRO/FSDP (3 stages) | shard optimizer→grads→params, \(\approx 16\Psi/N\) at stage 3 | DP simplicity with model-parallel memory savings |
| Mixed precision (bf16/fp16) | 16-bit matmuls + fp32 master copy | half the memory, 2–8× faster; prefer bf16 (no loss scaling) |
| Quantization \(\hat w = s\cdot q\) | store weights in 8/4 bits via scale-and-round | int8 ≈ lossless 2× shrink; nf4/GPTQ/AWQ 4× at near-baseline |
| Pruning (unstructured / 2:4 / structured) | zero unimportant weights | only structured / 2:4 give real wall-clock speedup |
| Distillation, soft targets at temp \(T\) | small student mimics teacher’s softened probs | compact model that punches above its size |
| KV cache \(=2bLn_{\text{layers}}d_{\text{model}}B\) | store past K/V so decode is linear | grows with context/batch; can exceed the weights |
| GQA/MQA, KV quantization | share or compress K/V heads | shrink the cache at the source |
| PagedAttention + continuous batching | page KV like virtual memory, backfill slots | reclaim wasted KV memory, keep batches full (vLLM) |
| Speculative decoding, \(\frac{1-\alpha^{k+1}}{1-\alpha}\) | draft proposes \(k\), target verifies in parallel | distribution-preserving ~2–3× speedup, acceptance-rate-dependent |
| TTFT / TPOT | first-token wait (prefill) / inter-token gap (decode) | the throughput–latency tradeoff; set a latency SLA first |
| Cost/token \(=\dfrac{C_{\text{hr}}}{3600\,R}\) | GPU $/hour over tokens/sec | faster decode or cheaper GPUs lower price proportionally |
| Memory rule of thumb | ≈2 B/param fp16, ≈1 int8, ≈0.5 4-bit, ≈16 Adam training | one-minute fit check before renting GPUs |
30.9 — Key takeaways
- Much of modern AI, LLM decode above all, is memory-bound rather than compute-bound: GPUs/TPUs have far more FLOPs than memory bandwidth, so most inference tricks work by moving fewer bytes (the roofline model). Prefill and large matmuls remain compute-bound; the memory hierarchy (SRAM vs HBM) is why kernel fusion / FlashAttention win.
- Four parallelism axes — data, tensor, pipeline, and ZeRO/FSDP sharding — combine to train models too big for one GPU; ZeRO’s three stages progressively shard optimizer states, gradients, then parameters.
- Mixed precision (bf16/fp16 with an fp32 master copy) halves memory and speeds matmuls; prefer bf16 to dodge fp16’s overflow/loss-scaling headaches, with fp8 emerging on the newest hardware.
- Quantization (int8, 4-bit nf4/GPTQ/AWQ), pruning (unstructured / structured / 2:4), and distillation (soft targets at temperature \(T\)) shrink models for cheaper, faster serving — 4-bit cuts weight memory 4× at near-baseline quality; QLoRA fine-tunes on top of a frozen 4-bit base.
- The KV cache makes decoding linear but grows with context until it can exceed the weights in size; GQA/MQA and KV quantization shrink it at the source, while PagedAttention + continuous batching (vLLM) reclaim wasted KV memory and keep batches full; speculative decoding is distribution-preserving and adds an acceptance-rate-dependent ~2–3× speedup.
- Serving is a throughput–latency tradeoff measured by TTFT and tokens/sec; cost per million tokens falls directly out of tokens/sec per GPU-dollar.
- Edge vs cloud flips which techniques matter; a memory estimate (≈2 B/param inference, ≈16 B/param Adam training) tells you precision and GPU count before you spend a cent.
30.10 — See also
- MLOps & Deployment — serving infrastructure, monitoring, and the deployment lifecycle around these optimized models.
- Tools & Frameworks — PyTorch/FSDP, vLLM, and the libraries that implement this chapter’s techniques.
- Large Language Models — the transformer decoding loop, KV cache, and scaling that make inference efficiency critical.
- Attention & Transformers — the attention mechanism whose keys/values the KV cache and PagedAttention store.
- Neural Networks (Core) — backpropagation and optimizer states that drive the training-memory math.
- Optimization — Adam and the optimizer states whose memory ZeRO/FSDP shards.
↪ The thread continues → Chapter 31 · 🧰 Tools & Frameworks
Infrastructure is the engine; the next chapter is the cockpit — the everyday tools and frameworks (NumPy to PyTorch to vLLM) you actually type to get the work done.