flowchart LR
A["prompt idx<br/>(B, T)"] --> B["crop to last<br/>block_size tokens"]
B --> C["model forward<br/>logits (B, T, vocab)"]
C --> D["take last step<br/>logits[:, -1, :]<br/>(B, vocab)"]
D --> E["temperature /<br/>top-k / top-p"]
E --> F["softmax +<br/>multinomial sample"]
F --> G["append token<br/>idx (B, T+1)"]
G -->|repeat| B
⚡ Building Transformers from Scratch with PyTorch · Lesson 9 — Generation & Decoding: Making Your GPT Speak
Lesson 9 — Generation & Decoding: Making Your GPT Speak
In the previous lesson you watched the loss curve fall and saved a checkpoint of your trained tiny-GPT. In this lesson we cash that in. Training taught the model a conditional distribution \(P(x_t \mid x_{<t})\) — but a distribution is not text. Turning it into text is the job of decoding, and the decisions you make here (temperature, top-k, top-p) change the character of the output as much as another thousand training steps would. We’ll build the autoregressive generate() loop from scratch, implement the three classic sampling strategies by hand, and then confront the dirty secret of naive generation: it recomputes almost everything, every step. We’ll fix that with a KV cache — the single most important inference optimization in every production LLM — and measure the speedup on our own model.
🎯 In this lesson you will: write the autoregressive generate() loop with context cropping, implement temperature / top-k / top-p sampling from raw logits, understand why naive generation is quadratic in work, build a working KV cache into your model, benchmark the speedup on real hardware
The autoregressive loop
A decoder-only transformer generates text one token at a time: feed in the sequence so far, read the model’s prediction for the next token, append it, repeat. The theory is in the encyclopedia’s Attention & Transformers chapter; the flowchart at the top of this lesson shows the mechanical loop.
Three details in that diagram carry all the correctness weight, so let’s write the code and walk them.
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None, top_p=None):
    """idx: (B, T) tensor of token ids (the prompt). Returns (B, T + max_new_tokens)."""
    model.eval()
    for _ in range(max_new_tokens):
        # 1) crop: the model has learned positions 0..block_size-1 and nothing else
        idx_cond = idx[:, -model.config.block_size:]
        # 2) forward: logits has shape (B, T, vocab_size), a prediction at EVERY position
        logits, _ = model(idx_cond)
        # 3) we only want the prediction after the LAST token
        logits = logits[:, -1, :]  # (B, vocab_size)
        idx_next = sample_next(logits, temperature, top_k, top_p)  # (B, 1)
        idx = torch.cat((idx, idx_next), dim=1)  # grow the sequence
    return idx

Why crop (step 1)? Your model uses learned positional embeddings: an nn.Embedding(block_size, n_embd) table from Lesson 3. With block_size=256 the table covers positions 0..255, so position 256 simply does not exist. Feed a sequence longer than block_size and you get an index-out-of-range crash (or, with other position schemes, silent garbage the model never trained on). So once the generated sequence exceeds the context window, we slide: keep only the most recent block_size tokens. The model forgets the distant past; that’s the price of a finite window.
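The sliding behavior of the crop can be seen without a model at all. The sketch below uses a plain Python list in place of the `(B, T)` tensor; `crop_to_window` and the toy `block_size = 4` are illustrative names and values, not part of the lesson's code.

```python
def crop_to_window(tokens, block_size):
    """Keep only the most recent block_size tokens (mirrors idx[:, -block_size:])."""
    return tokens[-block_size:]

block_size = 4                    # toy window, chosen small so the slide is visible
tokens = [10, 11, 12]             # a 3-token "prompt"

for new_token in (13, 14, 15):    # stand-ins for sampled tokens
    window = crop_to_window(tokens, block_size)
    assert len(window) <= block_size   # what the model would actually see
    tokens.append(new_token)

print(tokens)                              # full history keeps growing
print(crop_to_window(tokens, block_size))  # the model only ever sees the tail
```

A negative slice shorter than the list length returns the whole list, which is why the same line works both before and after the sequence outgrows the window.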
Why the last step only (step 3)? During training, predicting at every position was the whole point — one forward pass gave us T supervised predictions (Lesson 8). During generation, positions 0..T-2 predict tokens we already have. Only logits[:, -1, :] — the distribution over what comes after the final token — is new information. Note the inefficiency hiding here: we computed T × vocab_size logits and threw away all but one row. Hold that thought; it’s the motivation for the KV cache later.
Why @torch.no_grad() and model.eval()? No gradients are needed, so no_grad skips building the autograd graph (big memory and speed win), and eval() disables dropout. Left active, dropout would randomly zero activations on every forward pass, adding noise to the logits on top of the sampling randomness you actually control.
The sampling step is where all of this lesson’s interesting decisions live, so we’ve factored it out. Let’s build sample_next piece by piece.
Temperature: reshaping the distribution
The model emits logits — unnormalized scores. Softmax turns them into probabilities, and temperature \(\tau\) is a knob applied before the softmax:
\[ P(x_t = i) = \frac{\exp(z_i / \tau)}{\sum_j \exp(z_j / \tau)} \]
Dividing by \(\tau < 1\) stretches the gaps between logits, so the softmax sharpens toward the argmax — conservative, repetitive text. \(\tau > 1\) compresses the gaps, flattening the distribution — adventurous, error-prone text. At \(\tau \to 0\) you recover greedy decoding (always pick the argmax); at \(\tau \to \infty\) you’re sampling uniformly from the vocabulary.
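To make the sharpening and flattening concrete, here is a minimal pure-Python sketch of the formula above. `softmax_with_temperature` is a hypothetical helper written for this illustration (the lesson's real code uses `F.softmax` on tensors), and the five logits are arbitrary.

```python
import math

def softmax_with_temperature(logits, tau):
    """Softmax of logits / tau, with the usual max-subtraction for stability."""
    scaled = [z / tau for z in logits]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5, -1.0, -3.0]
for tau in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, tau)
    print(f"tau={tau}: " + ", ".join(f"{p:.3f}" for p in probs))
```

Reading the three printed rows top to bottom, the mass on the first token shrinks and the tail grows as \(\tau\) increases, while every row still sums to 1.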
The implementation is one line, plus a guard for the greedy case (dividing by 0.0 gives inf logits and multinomial will crash on the resulting NaNs):
def sample_next(logits, temperature=1.0, top_k=None, top_p=None):
    """logits: (B, vocab_size) raw scores for the next token. Returns (B, 1) token ids."""
    if temperature == 0.0:  # greedy decoding
        return logits.argmax(dim=-1, keepdim=True)
    logits = logits / temperature
    if top_k is not None:
        logits = top_k_filter(logits, top_k)
    if top_p is not None:
        logits = top_p_filter(logits, top_p)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

Order matters and is worth stating explicitly: temperature first, then truncation filters, then softmax, then sample. Temperature changes which tokens survive a top-p cutoff (a flatter distribution needs more tokens to reach the same cumulative mass), so applying it after the filter would give different, and conventionally wrong, behavior.
torch.multinomial(probs, 1) draws one index per batch row according to the given probabilities. This is the entire source of “creativity” in a language model: a single categorical sample per token.
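The claim that temperature changes which tokens survive a top-p cutoff can be checked numerically. This is a pure-Python sketch under the same keep rule the lesson uses for top-p (keep a token while the mass ranked above it does not exceed p); `tempered_probs` and `nucleus_size` are hypothetical helpers for this illustration only.

```python
import math

def tempered_probs(logits, tau):
    """Softmax of logits / tau, computed stably."""
    scaled = [z / tau for z in logits]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def nucleus_size(probs, p):
    """Count tokens kept by top-p: keep a token while the mass ranked above it is <= p."""
    kept, mass_above = 0, 0.0
    for prob in sorted(probs, reverse=True):
        if mass_above > p:
            break
        kept += 1
        mass_above += prob
    return kept

logits = [2.0, 1.0, 0.5, -1.0, -3.0]
sharp = nucleus_size(tempered_probs(logits, 0.5), p=0.9)
flat = nucleus_size(tempered_probs(logits, 2.0), p=0.9)
print(f"nucleus at tau=0.5: {sharp} tokens, at tau=2.0: {flat} tokens")
```

On this row the nucleus grows from two tokens to four as the temperature rises, so filtering before or after the temperature step keeps different tokens.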
Top-k and top-p, from scratch
Pure temperature sampling has a failure mode: even at \(\tau = 1\), the vocabulary’s long tail collectively holds a few percent of probability mass, so every few dozen tokens you sample something genuinely stupid — and because generation is autoregressive, one bad token poisons everything after it. Truncation strategies cut the tail off before sampling.
Top-k keeps only the k highest-scoring tokens:
def top_k_filter(logits, k):
    k = min(k, logits.size(-1))  # don't ask for more than vocab_size
    kth_vals = torch.topk(logits, k, dim=-1).values[:, -1, None]  # (B, 1): the k-th largest
    return logits.masked_fill(logits < kth_vals, float('-inf'))

The trick: torch.topk returns values sorted descending, so values[:, -1] is the k-th largest logit per row. Everything strictly below it gets -inf, which softmax maps to exactly zero probability, so the masked tokens can never be sampled. (Logits exactly tied with the k-th value all survive, so a row can keep slightly more than k tokens.) This is the same -inf-then-softmax masking pattern you used for causal attention in Lesson 4; it’s the idiom for “delete these entries from a distribution.”
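The "exactly zero" part of the -inf idiom is easy to verify by hand. The sketch below is a single-row, pure-Python analogue of `top_k_filter`; `top_k_mask` and `softmax` here are illustrative stand-ins, not the lesson's tensor code.

```python
import math

def top_k_mask(logits, k):
    """Single-row analogue of top_k_filter: everything below the k-th largest becomes -inf."""
    k = min(k, len(logits))
    kth_value = sorted(logits, reverse=True)[k - 1]
    return [z if z >= kth_value else float("-inf") for z in logits]

def softmax(logits):
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]  # math.exp(-inf) is exactly 0.0
    total = sum(exps)
    return [e / total for e in exps]

row = [2.0, 1.0, 0.5, -1.0, -3.0]
probs = softmax(top_k_mask(row, k=2))
print(probs)
```

The three masked entries come out as a true `0.0`, not a tiny positive number, and the two survivors are renormalized to sum to 1 by the softmax itself.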
Top-p (nucleus) sampling is adaptive where top-k is rigid. Instead of a fixed count, keep the smallest set of tokens whose cumulative probability exceeds p. When the model is confident, that nucleus might be 3 tokens; when it’s uncertain, it might be 300. Fixed k gets both cases wrong.
def top_p_filter(logits, p):
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)  # (B, vocab), ascending toward 1.0
    # drop a token iff the mass BEFORE it already exceeds p
    # (cum - own prob = mass of everything ranked above it)
    drop = (cum_probs - sorted_probs) > p  # first token: 0 > p is False → always kept
    sorted_logits = sorted_logits.masked_fill(drop, float('-inf'))
    # unsort: scatter the filtered logits back to original vocabulary positions
    out = torch.full_like(logits, float('-inf'))
    out.scatter_(dim=-1, index=sorted_idx, src=sorted_logits)
    return out

Walk the subtle parts:
- Sort first. Nucleus membership is defined on the ranked distribution, so we sort logits descending and softmax them (softmax is permutation-equivariant, so softmax-then-sort and sort-then-softmax agree).
- The cum_probs - sorted_probs trick. Naively you’d write drop = cum_probs > p, but that drops the token that crosses the threshold. With a confident model, the top token alone can have probability > p, leaving you with an all -inf row, NaNs out of softmax, and a crash in multinomial. Subtracting each token’s own probability asks instead: “did the tokens ranked above me already cover p?” The top-ranked token has zero mass above it, so it is provably always kept. (Many tutorials do this by shifting the mask right by one position; the subtraction is equivalent and states the intent more directly.)
- Unsorting with scatter_. We filtered in sorted order, but multinomial needs probabilities indexed by real token ids. out.scatter_(-1, sorted_idx, src) writes src[b, j] into out[b, sorted_idx[b, j]], which is exactly the inverse of the sort’s gather. Starting from a full_like(-inf) canvas means anything not explicitly written back stays deleted.
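The difference between the naive rule and the guarded rule shows up as soon as one token dominates. The following pure-Python sketch applies both rules to an already sorted row; `nucleus_keep_flags` and the `guarded` switch are hypothetical names for this illustration, and the logits are chosen so the top token holds about 0.99 of the mass.

```python
import math

def softmax(logits):
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def nucleus_keep_flags(sorted_probs, p, guarded=True):
    """Keep/drop flags for probabilities that are already sorted descending."""
    flags, cum = [], 0.0
    for prob in sorted_probs:
        cum += prob
        mass_above = cum - prob          # mass of everything ranked above this token
        drop = (mass_above > p) if guarded else (cum > p)
        flags.append(not drop)
    return flags

confident = softmax([5.0, 0.0, 0.0])     # already in descending order
naive = nucleus_keep_flags(confident, p=0.8, guarded=False)
guarded = nucleus_keep_flags(confident, p=0.8, guarded=True)
print("naive  :", naive)
print("guarded:", guarded)
```

The naive rule drops every token, including the one the model is nearly certain about, while the guarded rule keeps exactly that token.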
Quick sanity check — always run one of these after writing masking code:
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
print(F.softmax(logits, dim=-1))                     # before filtering
print(F.softmax(top_p_filter(logits, 0.8), dim=-1))  # after top-p with p=0.8

tensor([[0.6070, 0.2233, 0.1354, 0.0302, 0.0041]])  # before: five nonzero entries
tensor([[0.7311, 0.2689, 0.0000, 0.0000, 0.0000]])  # after p=0.8: tail deleted, rest renormalized

The top two tokens together hold 0.83 of the mass, which already exceeds p=0.8, so they form the nucleus and the other three are deleted. The survivors’ probabilities went up: softmax over the filtered logits automatically renormalizes the nucleus to sum to 1. You never renormalize by hand.
Why naive generation is slow: the recomputation problem
Run generate() for 500 tokens and watch what happens to the work per step. At step \(t\), we forward a sequence of length \(t\): every layer recomputes keys, values, and attention for all \(t\) positions — but positions \(0..t-2\) are identical to what we computed last step. The transformer is deterministic; token 3’s key vector in layer 2 doesn’t change because token 47 arrived. We’re recomputing an entire prefix, every step, to obtain one new row of logits.
Count it up for generating \(N\) tokens: the no-cache loop forwards \(1 + 2 + \dots + N = O(N^2)\) token positions total, and since attention itself is quadratic in sequence length, the attention FLOPs sum to \(O(N^3)\). With a cache it’s \(O(N)\) token-forwards and \(O(N^2)\) attention FLOPs — the same asymptotics as a single training forward pass.
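The position count above can be tallied directly. This is a small arithmetic sketch, assuming a 6-token prompt and ignoring the block_size crop; `naive_positions` and `cached_positions` are illustrative helpers, not part of the model code.

```python
def naive_positions(prompt_len, n_new):
    """Token positions forwarded by the no-cache loop (ignoring the block_size crop)."""
    return sum(prompt_len + step for step in range(n_new))

def cached_positions(prompt_len, n_new):
    """Prefill forwards the prompt once; every decode step forwards one token."""
    return prompt_len + n_new

for n_new in (50, 250, 1000):
    naive = naive_positions(6, n_new)
    cached = cached_positions(6, n_new)
    print(f"n_new={n_new:5d}  naive={naive:7d}  cached={cached:5d}  ratio={naive / cached:.1f}")
```

Position counts are not wall-clock time: fixed per-step overheads mean measured speedups are usually much smaller than this ratio, but the ratio does show how the gap widens as the generated length grows.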
The fix follows from one observation about the attention equation. At the new position \(t\), attention needs:
- the query for position \(t\) only — queries at old positions produced old outputs we don’t need;
- the keys and values for positions \(0..t\) — old ones included, but they never change.
So: cache K and V per layer, and each step feed the model only the one new token. Compute its q, k, v; append k and v to the cache; attend the single query against the full cached K and V. Note what we do not cache: queries (only the newest is used) and attention weights (recomputed fresh, since the new query attends over everything). Hence the name KV cache.
One pleasant consequence: when the input is a single token, its query may attend to every cached position — they’re all in its past — so no causal mask is needed during decode steps. The triangle mask from Lesson 4 only matters when multiple query positions coexist in one forward pass (the “prefill” of the prompt).
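The equivalence between a cached decode step and the last row of full causal attention can be checked on a toy example. The sketch below is single-head attention in pure Python with made-up 2-dimensional vectors; `causal_attention` and `cached_step` are hypothetical helpers, and the projections that would normally produce q, k, and v are skipped.

```python
import math

def softmax(xs):
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def causal_attention(qs, ks, vs):
    """Full causal attention: the query at position t sees keys 0..t."""
    d = len(qs[0])
    outputs = []
    for t, q in enumerate(qs):
        weights = softmax([dot(q, ks[j]) / math.sqrt(d) for j in range(t + 1)])
        outputs.append([sum(weights[j] * vs[j][c] for j in range(t + 1))
                        for c in range(len(vs[0]))])
    return outputs

def cached_step(q_new, k_cache, v_cache):
    """One decode step: a single query against every cached key/value, no mask."""
    d = len(q_new)
    weights = softmax([dot(q_new, k) / math.sqrt(d) for k in k_cache])
    return [sum(w * v[c] for w, v in zip(weights, v_cache))
            for c in range(len(v_cache[0]))]

qs = [[0.1, 0.3], [0.5, -0.2], [-0.4, 0.9]]
ks = [[0.2, 0.1], [-0.3, 0.7], [0.6, 0.4]]
vs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

full_last = causal_attention(qs, ks, vs)[-1]   # recompute everything, take the last row
step = cached_step(qs[-1], ks, vs)             # newest query against cached K and V
assert all(abs(a - b) < 1e-12 for a, b in zip(full_last, step))
print(full_last, step)
```

Only the newest query is used in `cached_step`; the older queries never appear, which is why they are not worth caching.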
Implementing the KV cache
We’ll thread an optional cache argument through the three levels of Lesson 7’s model: attention → block → GPT. Each level passes it down and hands the updated cache back up. First, attention (this is Lesson 5’s CausalSelfAttention.forward with ~6 new lines):
import math
import torch.nn as nn

class CausalSelfAttention(nn.Module):
    # __init__ unchanged from Lesson 5: self.qkv, self.proj, self.n_head, self.head_dim
    def forward(self, x, cache=None):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim), as in Lesson 5
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        if cache is not None:  # decode step: prepend history
            past_k, past_v = cache  # each (B, n_head, T_past, head_dim)
            k = torch.cat([past_k, k], dim=2)  # (B, n_head, T_past + T, head_dim)
            v = torch.cat([past_v, v], dim=2)
        new_cache = (k, v)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, nh, T, T_total)
        if T > 1:
            # prefill: multiple queries in one pass -> causal mask needed.
            # In our generate loop, prefill always starts with an empty cache, so
            # T_total == T and the plain lower-triangular mask from Lesson 4 applies.
            mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            att = att.masked_fill(~mask, float('-inf'))
        # T == 1: one query, attending to its entire past -> no mask.
        y = att.softmax(dim=-1) @ v  # (B, nh, T, head_dim)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y), new_cache

The shape story: during decode, q is (B, nh, 1, hd) while k is (B, nh, T_total, hd), so att is (B, nh, 1, T_total): one row of attention weights per head instead of a full T×T matrix. That single row is the saving. The classic bug here is concatenating on the wrong dim: the cache grows along the sequence axis, dim=2 in our (B, nh, T, hd) layout. Concatenate on dim=1 and you stack along the head axis instead. In this layout PyTorch normally rejects that with a size-mismatch error, but the message points at torch.cat or the following matmul rather than at the real cause, which makes it a miserable bug to find. If in doubt, assert that k.size(2) == T_past + T right after the concat.
The block just forwards the cache through (MLP and LayerNorms are position-wise — they need no history):
class Block(nn.Module):
    # __init__ unchanged from Lesson 6
    def forward(self, x, cache=None):
        attn_out, new_cache = self.attn(self.ln1(x), cache=cache)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x, new_cache

The GPT level has the one genuinely easy-to-miss detail: positional offset. When we feed a single token that is really position 41, it must receive positional embedding 41, not 0. The cache itself tells us the offset:
class GPT(nn.Module):
    # __init__ and training forward() unchanged from Lesson 7
    def forward_cached(self, idx, caches=None):
        B, T = idx.shape
        past_len = 0 if caches is None else caches[0][0].size(2)  # layer-0 cached K length
        assert past_len + T <= self.config.block_size, "KV cache overflows the context window"
        pos = torch.arange(past_len, past_len + T, device=idx.device)  # THE offset
        x = self.tok_emb(idx) + self.pos_emb(pos)
        new_caches = []
        for i, block in enumerate(self.blocks):
            x, c = block(x, cache=None if caches is None else caches[i])
            new_caches.append(c)
        x = self.ln_f(x)
        return self.lm_head(x), new_caches

Get pos wrong (pass arange(0, T) every step) and generation still runs, but every decoded token believes it’s at position 0. Output degrades into repetitive sludge with no error message. This is the KV-cache bug to remember.
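The offset bookkeeping can be traced without any tensors. In this sketch `position_ids` is a hypothetical stand-in for the `torch.arange(past_len, past_len + T)` call, and the 6-token prompt with three decode steps is an arbitrary example.

```python
def position_ids(past_len, num_new):
    """Positions for the tokens in this forward pass, offset by the cache length."""
    return list(range(past_len, past_len + num_new))

prompt_len = 6
seen = position_ids(0, prompt_len)       # prefill: the whole prompt, empty cache
past_len = prompt_len
for _ in range(3):                       # three one-token decode steps
    seen += position_ids(past_len, 1)
    past_len += 1                        # the cache grew by one entry

buggy = position_ids(0, 1)               # the bug: a decode step that ignores the cache length
print("correct positions:", seen)
print("buggy decode step:", buggy)
```

With the offset, the positions handed to the model across prefill and decode form one unbroken range; without it, every decode step claims position 0.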
The assert marks a real limitation: with learned absolute positions, you can’t just drop the oldest cache entry and keep going, because every cached K was computed with its position baked in — sliding the window shifts positions and invalidates the whole cache. The honest simple policy is: generate up to block_size total, and if you need more, recompute the cache from the cropped sequence. Lesson 10 covers how RoPE-style positions make real LLMs handle this gracefully.
Now the cached generation loop — prefill once, then one-token steps, reusing the exact sample_next from earlier:
@torch.no_grad()
def generate_cached(model, idx, max_new_tokens, temperature=1.0, top_k=None, top_p=None):
    model.eval()
    assert idx.size(1) + max_new_tokens <= model.config.block_size
    logits, caches = model.forward_cached(idx)  # PREFILL: whole prompt, one pass
    for _ in range(max_new_tokens):
        idx_next = sample_next(logits[:, -1, :], temperature, top_k, top_p)
        idx = torch.cat([idx, idx_next], dim=1)
        logits, caches = model.forward_cached(idx_next, caches)  # DECODE: one token
    return idx

Compare with the naive loop: there we fed idx[:, -block_size:] (everything) each step; here we feed idx_next (one token) and the cache carries the rest. These two phases, prefill (parallel, compute-bound) and decode (serial, memory-bound), are the vocabulary of every LLM-serving discussion you’ll ever read; you’ve now implemented both.
Correctness check before benchmarking: under greedy decoding (temperature=0.0), cached and uncached generation must produce identical token ids. Greedy removes sampling randomness, so no seed is involved; the only thing that could legitimately differ is float noise flipping an argmax between two nearly tied logits, which is rare in practice:

prompt = torch.tensor([encode("ROMEO:")], device=device)
a = generate(model, prompt, 200, temperature=0.0)
b = generate_cached(model, prompt, 200, temperature=0.0)
assert torch.equal(a, b), "cache changes the math somewhere: find it before trusting speedups"

If this assert fires, suspect (in order): the positional offset, the concat dim, or a mask applied during decode steps.
Measuring the speedup
Optimizations you don’t measure are optimizations you don’t have:
import time

def bench(fn, *args, **kwargs):
    if device == "cuda":
        torch.cuda.synchronize()  # CUDA is async: without this you time the *launch*
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    if device == "cuda":
        torch.cuda.synchronize()
    return out, time.perf_counter() - t0

prompt = torch.tensor([encode("ROMEO:")], device=device)
n_new = model.config.block_size - prompt.size(1)  # fill the window
_ = generate(model, prompt, 10)         # warmup (kernel compile, caches)
_ = generate_cached(model, prompt, 10)  # warm up the cached path too, so both start equal
_, t_naive = bench(generate, model, prompt, n_new, temperature=0.8, top_k=50)
_, t_cached = bench(generate_cached, model, prompt, n_new, temperature=0.8, top_k=50)
print(f"no cache: {t_naive:.2f}s cached: {t_cached:.2f}s speedup: {t_naive/t_cached:.1f}x")

Representative numbers for our tiny-GPT (6 layers, n_embd=384, block_size=256, char-level):
no cache: 8.31s cached: 1.94s speedup: 4.3x
Two honest observations about that number. First, the speedup grows with sequence length — the naive loop’s cost per step is proportional to current length, the cached loop’s is nearly flat, so at block_size=1024 the same experiment gives ~12x, and at production context lengths the naive loop is simply unusable. Second, on GPU with a tiny model you may see less than the FLOP math predicts: one-token forwards are so small that kernel-launch overhead, not arithmetic, dominates. The cache’s true payoff appears at scale — which is precisely why every serving stack (and further tricks like paged attention) is built around it.
The cost you pay is memory: \(2 \times n_\text{layer} \times B \times n_\text{head} \times T \times d_\text{head}\) floats.
| Model | Cache per token | Full context |
|---|---|---|
| our tiny-GPT (6L, 384d) | ~18 KB (fp32) | ~4.7 MB @ T=256 |
| a 7B LLM (32L, 4096d) | ~0.5 MB (fp16) | ~64 GB @ T=128k |
For us, invisible. For real models, the KV cache — not the weights — is often what limits how many users fit on one GPU.
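The table rows follow directly from the memory formula. The sketch below reproduces them, assuming plain multi-head attention (so \(n_\text{head} \times d_\text{head}\) equals the embedding width), batch size 1, and "128k" read as 128 × 1024 tokens; `kv_cache_bytes` is an illustrative helper, not part of the lesson's model.

```python
def kv_cache_bytes(n_layer, n_embd, seq_len, batch=1, bytes_per_value=4):
    """Bytes for K and V across all layers; uses n_head * d_head == n_embd."""
    return 2 * n_layer * batch * n_embd * seq_len * bytes_per_value

# tiny-GPT from this course, fp32 (4 bytes per value)
tiny_per_token = kv_cache_bytes(6, 384, seq_len=1)
tiny_full = kv_cache_bytes(6, 384, seq_len=256)

# a 7B-class model, fp16 (2 bytes per value)
big_per_token = kv_cache_bytes(32, 4096, seq_len=1, bytes_per_value=2)
big_full = kv_cache_bytes(32, 4096, seq_len=128 * 1024, bytes_per_value=2)

print(tiny_per_token, tiny_full)                  # 18432 4718592
print(big_per_token / 2**20, big_full / 2**30)    # 0.5 64.0
```

The cache scales linearly in batch size as well, so the 64 GB figure is per sequence; models that share keys and values across heads would need less than this estimate.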
Hearing the difference: temperature in practice
Numbers are abstract; text isn’t. Same trained checkpoint, same prompt, same seed; only the knobs change:
for temp in (0.5, 0.9, 1.4):
    torch.manual_seed(42)
    out = generate_cached(model, prompt, 250, temperature=temp, top_p=0.95)
    print(f"--- temperature {temp} ---")
    print(decode(out[0].tolist()))

Typical output from a tiny char-level GPT trained on Shakespeare (yours will differ, the character won’t):
--- temperature 0.5 ---
ROMEO:
I will not stay the world with him,
And therefore the prince and the world,
And therefore the world with the prince...
--- temperature 0.9 ---
ROMEO:
What says the gallant? give me thy hand, good friar;
Her eyes are almost dead with weeping sorrow,
That ever I should live to see thee married.
--- temperature 1.4 ---
ROMEO:
Wjy, hazkl! umbition'd greep-fork'd quaths,
Ere naxt bidsong! O crypole, hew?...
At 0.5 the model rides its safest grooves — grammatical, hypnotic, going nowhere, repeating “the world” because that path is always high-probability. At 0.9 with top_p=0.95 you get the sweet spot: coherent surface, real variety. At 1.4 the tail tokens win often enough that even spelling falls apart — a vivid reminder that for a char-level model, every wrong character is a wrong word. There is no universally correct setting; temperature ≈ 0.7–1.0 with top-p ≈ 0.9–0.95 is the standard starting range, and factual tasks want the low end while creative tasks tolerate the high end.
🧪 Your task
Implement min-p sampling — a newer truncation rule that adapts to model confidence even more directly than top-p: keep every token whose probability is at least min_p × (the probability of the most likely token). Confident distribution → tiny survivor set; flat distribution → large one. Write min_p_filter(logits, min_p) following the conventions of top_k_filter/top_p_filter (take and return raw logits, delete with -inf), wire it into sample_next as a min_p=None argument, and verify: (1) on torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]]) with min_p=0.2, exactly three tokens survive; (2) at temperature=1.4 your model’s output stays spellable, unlike the plain-sampling sludge above.
Hint: no sorting needed — compute probs = softmax(logits), take the row-wise max with keepdim=True, and mask where probs < min_p * max. The argmax token trivially survives its own threshold, so the all--inf row problem can’t occur.
Solution
def min_p_filter(logits, min_p):
    """Keep tokens with prob >= min_p * prob(argmax). logits: (B, vocab)."""
    probs = F.softmax(logits, dim=-1)
    threshold = min_p * probs.max(dim=-1, keepdim=True).values  # (B, 1)
    return logits.masked_fill(probs < threshold, float('-inf'))

def sample_next(logits, temperature=1.0, top_k=None, top_p=None, min_p=None):
    if temperature == 0.0:
        return logits.argmax(dim=-1, keepdim=True)
    logits = logits / temperature
    if top_k is not None:
        logits = top_k_filter(logits, top_k)
    if top_p is not None:
        logits = top_p_filter(logits, top_p)
    if min_p is not None:
        logits = min_p_filter(logits, min_p)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# --- check 1: survivor count on a known row ---
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
filtered = min_p_filter(logits, 0.2)
assert (filtered > float('-inf')).sum().item() == 3, filtered
# probs are [.607, .223, .135, .030, .004]; threshold = 0.2*.607 = .121
# -> tokens 0,1,2 survive, 3 and 4 are deleted.

# --- check 2: high temperature stays coherent ---
# NOTE: generate_cached must also take min_p=None and pass it on to sample_next,
# otherwise the second call below raises a TypeError.
torch.manual_seed(42)
out = generate_cached(model, prompt, 250, temperature=1.4)  # sludge
torch.manual_seed(42)
out_mp = generate_cached(model, prompt, 250, temperature=1.4, min_p=0.1)  # words again
print(decode(out_mp[0].tolist()))

Why it works at high temperature: temperature flattens the distribution, but the ratio between a garbage token and the top token stays informative. Min-p cuts on that ratio, so it deletes the tail that temperature just inflated. Note that min_p_filter runs on post-temperature logits like the other filters; because both probs and threshold come from the same tempered distribution, the test depends only on each token’s probability relative to the top token.
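One way to see what the ratio test does is to rewrite it in terms of raw logits. This is a short derivation under the lesson’s own softmax definition, writing \(z_i\) for the pre-temperature logits, \(z_{\max}\) for the largest of them, and \(m\) as shorthand for min_p (with \(0 < m \le 1\) and \(\tau > 0\)):

\[ \frac{\exp(z_i/\tau)}{\sum_j \exp(z_j/\tau)} \;\ge\; m \cdot \frac{\exp(z_{\max}/\tau)}{\sum_j \exp(z_j/\tau)} \iff \frac{z_i - z_{\max}}{\tau} \ge \ln m \iff z_i \;\ge\; z_{\max} - \tau \,\lvert \ln m \rvert \]

The normalizer cancels, so min-p keeps exactly the tokens whose logit lies within a margin of \(\tau \lvert \ln m \rvert\) below the top logit. For min_p=0.1 and \(\tau = 1.4\) that margin is about \(1.4 \times 2.303 \approx 3.22\) logit units: the cut widens with temperature, but tokens far below the top logit stay excluded no how much mass the flattened tail holds in total.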
Key takeaways
- Generation is a loop: crop to block_size, forward, keep only logits[:, -1, :], filter, sample, append.
- Temperature scales logits before softmax: \(\tau<1\) sharpens toward greedy, \(\tau>1\) flattens toward uniform; \(\tau=0\) needs an explicit argmax branch.
- Top-k, top-p, and min-p all use one idiom: set unwanted logits to -inf, let softmax zero and renormalize. Top-p needs the sort/cumsum/scatter-unsort dance, plus the cum_probs - sorted_probs > p guard so the top token always survives.
- Naive generation recomputes the entire prefix each step, \(O(N^2)\) token-forwards in total. Keys and values of past tokens never change, so cache them and feed one token per step: that’s the KV cache, and prefill/decode are its two phases.
- The three KV-cache bugs: wrong concat dim (must be the sequence axis), a causal mask applied to single-token decode, and, nastiest of all, forgetting the positional offset. Always assert cached == uncached under greedy decoding before benchmarking.
- Speedup grows with context length; cache memory, not compute, is what constrains real-model serving.
In the next lesson, Lesson 10, we put your tiny-GPT next to the real thing: what GPT-2/3-class models change (scale, data, RoPE, RMSNorm, flash attention) and how everything you hand-built maps onto them.