📖 Build Your Own Wikipedia LLM · Lesson 9 — SFT: Teaching WikiGPT to Follow Instructions
Lesson 9 — SFT: Teaching WikiGPT to Follow Instructions
In Lesson 7 you finished the big run: a WikiGPT-124M base model that, given any prefix, continues it in fluent encyclopedia-English. In Lesson 8 you built the other half of the equation: ~50k synthetic instruction examples (grounded QA, summarization, extraction) generated by Qwen2.5-7B-Instruct, quality-filtered, split into train/held-out, and published to your public GitHub repo. Right now those two artifacts don’t know about each other. Ask your base model “What is the capital of France?” and it doesn’t answer. It continues, usually with something like “is a common question in European geography textbooks. The question of capitals…”. It’s a text predictor, not an assistant.
Supervised fine-tuning (SFT) closes that gap. Mechanically it is almost boringly identical to pretraining (same model, same cross-entropy loss, same training loop), with exactly three changes: the data is rendered through a chat template, the loss is masked so only assistant tokens train, and the learning rate is tiny so we nudge the model instead of bulldozing it. This lesson builds `src/train_sft.py` by reusing the Lesson 6 engine, runs it in under an hour for under a dollar, and proves the difference with before/after demos and a small judge-free eval harness.
🎯 In this lesson you will: render dialogs into token ids with the chat template, implement loss masking (labels=-100 outside assistant spans) with a token-level worked example, pack short dialogs into 1024-token blocks, write and run src/train_sft.py at lr 3e-5 for 3 epochs (<1 h, <$1 on the 4090), compare base vs SFT on held-out prompts, build an exact-match + heuristics eval harness, and push wikigpt-124m-sft to the Hugging Face Hub.
Why this works at all, and why it’s cheap
Pretraining taught the model the distribution of Wikipedia text: facts, grammar, discourse structure, 4B tokens’ worth. SFT does not add knowledge. It adds a behavioral prior: “when you see `<|user|>...<|end|><|assistant|>`, what follows is a direct, concise answer that stops.” That’s a tiny amount of information compared to what pretraining learned. It is why ~15M tokens of SFT data move behavior dramatically while 4B tokens were needed for competence, and why the learning rate must be ~10× smaller: at pretraining’s peak lr you’d catastrophically overwrite the knowledge you spent $10 acquiring.
The whole pipeline for this lesson:

```mermaid
flowchart LR
A[data/sft/train.jsonl<br/>from Lesson 8] --> B[render chat template<br/>user / assistant / end tokens]
B --> C[loss mask<br/>labels = -100 outside<br/>assistant spans]
C --> D[pack into 1024-token<br/>blocks]
D --> E[train_sft.py<br/>lr 3e-5 cosine, 3 epochs<br/>bf16 + compile]
F[checkpoints/ckpt.pt<br/>base model, Lesson 7] --> E
E --> G[checkpoints/sft.pt]
G --> H[eval: exact match +<br/>heuristics on held-out]
G --> I[HF Hub:<br/>wikigpt-124m-sft]
```
The chat template: from JSONL to token ids
Back in Lesson 4 we reserved three special tokens in the tokenizer — <|user|>, <|assistant|>, <|end|> — precisely for this moment. They were never seen during pretraining, so their embeddings are essentially untrained; SFT will give them meaning from scratch. A Lesson 8 example:
```json
{"messages": [
  {"role": "user", "content": "According to the passage, what year did the Hanseatic League decline?\n\nPassage: ..."},
  {"role": "assistant", "content": "The passage states the Hanseatic League declined in the late 16th century."}
]}
```

renders to this token stream:

```text
<|user|> According to the passage, ... <|end|> <|assistant|> The passage states ... <|end|>
```
Design decisions, each load-bearing:
- `<|end|>` closes every turn, including the user’s. The model needs an unambiguous signal that the prompt is over; without it, it can’t tell “user is done” from “user paused mid-sentence”.
- The assistant’s `<|end|>` is part of the training target. This is how the model learns to stop. Skip this and your SFT model answers correctly and then rambles forever, the single most common SFT bug.
- No system prompt. At 124M, a system turn is capacity spent on tokens the model can barely use. The behavior is the system prompt.
- No newlines around special tokens. They’re single ids; decoration just wastes context.
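To make the layout concrete without a trained tokenizer, here is a minimal string-level sketch of the template. The helper names `render_text` and `inference_prompt` are hypothetical; the lesson’s real pipeline renders directly to token ids via `render_and_mask`. It also shows the inference-time prompt, which stops right after `<|assistant|>` so the model’s first generated token begins the answer.

```python
# Hypothetical string-level view of the chat template (the real code works on token ids).
USER, ASSISTANT, END = "<|user|>", "<|assistant|>", "<|end|>"


def render_text(messages):
    """Every turn is <role token> + content + <|end|>; no system turn, no newlines."""
    parts = []
    for turn in messages:
        role_tok = USER if turn["role"] == "user" else ASSISTANT
        parts.append(role_tok + turn["content"] + END)
    return "".join(parts)


def inference_prompt(question):
    """At generation time the prompt ends right after <|assistant|>; the model writes the rest."""
    return USER + question + END + ASSISTANT


dialog = [
    {"role": "user", "content": "Capital of France?"},
    {"role": "assistant", "content": "Paris."},
]
print(render_text(dialog))                    # <|user|>Capital of France?<|end|><|assistant|>Paris.<|end|>
print(inference_prompt("Capital of France?"))  # <|user|>Capital of France?<|end|><|assistant|>
```

Note that the training string begins with exactly the inference prompt. That is why training and inference must share one template: any mismatch (an extra newline, a missing `<|end|>`) puts the model in a context it never saw during SFT.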
Loss masking: labels=-100, and exactly why
Here is the crux of SFT. If you trained with plain next-token loss on the rendered sequence, the model would spend most of its gradient budget learning to imitate users — predicting question tokens, passage tokens, typos and all. We only want gradients from the assistant span. PyTorch’s F.cross_entropy(..., ignore_index=-100) gives us this for free: any position whose label is -100 contributes zero loss and zero gradient.
Formally, with \(A\) = the set of positions inside assistant spans (answer tokens plus the closing <|end|>):
\[ \mathcal{L}_{\text{SFT}} = -\frac{1}{|A|} \sum_{t \in A} \log p_\theta(x_t \mid x_{<t}) \]
Every token still flows through the forward pass — the model attends over the full question — but only assistant positions appear in the loss.
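A tiny pure-Python check of that formula, using invented per-position log-probabilities: positions labeled -100 are skipped entirely, and the mean divides by \(|A|\), the number of trainable positions, not by the sequence length. This mirrors what `F.cross_entropy(..., ignore_index=-100)` computes with its default mean reduction, but `masked_nll` is an illustrative re-implementation, not PyTorch’s code.

```python
import math

IGNORE = -100


def masked_nll(target_logprobs, labels):
    """Mean negative log-likelihood over positions whose label != IGNORE.

    target_logprobs[t] stands for log p(x_t | x_<t) at position t (invented numbers here);
    masked positions contribute nothing, and the divisor counts only trainable positions.
    """
    kept = [lp for lp, lab in zip(target_logprobs, labels) if lab != IGNORE]
    if not kept:
        raise ValueError("no trainable positions")  # guard against an all-masked batch
    return -sum(kept) / len(kept)


# 3 prompt positions (masked) + 2 answer positions (trained); probabilities are made up
logprobs = [math.log(0.01), math.log(0.02), math.log(0.5), math.log(0.25), math.log(0.5)]
labels = [IGNORE, IGNORE, IGNORE, 6021, 32767]
loss = masked_nll(logprobs, labels)
print(f"{loss:.4f}")  # 1.0397 = (ln 4 + ln 2) / 2; the masked prompt positions never enter
```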
The masking function:
```python
IGNORE = -100  # PyTorch cross_entropy's default ignore_index

def render_and_mask(tok, messages, uid, aid, eid):
    """Render a dialog to (ids, labels). labels==-100 everywhere except
    assistant content tokens and the assistant's closing <|end|>."""
    ids, labels = [], []
    for turn in messages:
        content_ids = tok.encode(turn["content"]).ids
        if turn["role"] == "user":
            span = [uid] + content_ids + [eid]
            ids += span
            labels += [IGNORE] * len(span)  # never learn to be the user
        else:  # assistant
            ids += [aid] + content_ids + [eid]
            labels += [IGNORE] + content_ids + [eid]  # train content + stop token
    return ids, labels
```

Worked token-level example
Take the dialog user: Capital of France? → assistant: Paris. (token ids illustrative — yours depend on your BPE merges):
| pos | token | id | label |
|---|---|---|---|
| 0 | `<\|user\|>` | 32765 | -100 |
| 1 | Capital | 9214 | -100 |
| 2 | of | 291 | -100 |
| 3 | France | 4870 | -100 |
| 4 | ? | 34 | -100 |
| 5 | `<\|end\|>` | 32767 | -100 |
| 6 | `<\|assistant\|>` | 32766 | -100 |
| 7 | Paris | 6021 | 6021 |
| 8 | . | 17 | 17 |
| 9 | `<\|end\|>` | 32767 | 32767 |
Now the subtlety that trips everyone up: the shift. In the training loop, position \(t\)’s logits are scored against label \(t{+}1\) (logits[:, :-1] vs labels[:, 1:]). So the position that learns to produce Paris is position 6 — the <|assistant|> token — because after shifting, label 7 (Paris, unmasked) becomes its target. That’s why <|assistant|> itself carries label -100 yet the model still learns to start answering the instant it sees it. Three predictions train in this example:
| position t | input token at t | shifted target (label t+1) | learns |
|---|---|---|---|
| 6 | `<\|assistant\|>` | Paris | how to begin an answer |
| 7 | Paris | . | how to continue it |
| 8 | . | `<\|end\|>` | how to stop |
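If you instead applied the mask after shifting, keyed on whether each predicting (input) position lies inside the assistant content, you would silently drop the “begin” prediction at position 6, because `<|assistant|>` is not itself answer content. Build labels aligned to ids as above, then do the standard shift in the loss, and it works out.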
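To watch the table and the shift fall out of the actual code, here is a sketch that runs the lesson’s `render_and_mask` logic against a stub tokenizer. `StubTok` and its tiny vocabulary are hypothetical stand-ins that reuse the illustrative ids from the table (including 32765–32767 for the special tokens); a real BPE tokenizer will produce different ids and splits.

```python
from types import SimpleNamespace

IGNORE = -100
UID, AID, EID = 32765, 32766, 32767  # illustrative special-token ids from the table


class StubTok:
    """Hypothetical stand-in for tokenizers.Tokenizer: whitespace split over a fixed vocab."""
    vocab = {"Capital": 9214, "of": 291, "France": 4870, "?": 34, "Paris": 6021, ".": 17}

    def encode(self, text):
        # return an object with an .ids list, like a tokenizers Encoding
        return SimpleNamespace(ids=[self.vocab[w] for w in text.split()])


def render_and_mask(tok, messages, uid, aid, eid):
    """Same logic as the lesson's render_and_mask."""
    ids, labels = [], []
    for turn in messages:
        content_ids = tok.encode(turn["content"]).ids
        if turn["role"] == "user":
            span = [uid] + content_ids + [eid]
            ids += span
            labels += [IGNORE] * len(span)
        else:
            ids += [aid] + content_ids + [eid]
            labels += [IGNORE] + content_ids + [eid]
    return ids, labels


msgs = [{"role": "user", "content": "Capital of France ?"},  # spaces so the stub can split
        {"role": "assistant", "content": "Paris ."}]
ids, labels = render_and_mask(StubTok(), msgs, UID, AID, EID)

# the standard shift: the logits at position t are scored against labels[t + 1]
inputs, targets = ids[:-1], labels[1:]
trained = [(t, inputs[t], tgt) for t, tgt in enumerate(targets) if tgt != IGNORE]
print(trained)  # [(6, 32766, 6021), (7, 6021, 17), (8, 17, 32767)]
```

Exactly three (position, input, target) triples survive, matching the three rows of the shift table, and the first one has `<|assistant|>` as its input.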
Packing short dialogs
Your SFT examples average ~300 tokens; padding each one to block_size=1024 would waste ~70% of every batch. So we pack: concatenate rendered examples greedily into 1024-token blocks, exactly like pretraining packed articles in Lesson 6, carrying the label mask along. Whenever a block is closed, its unused tail is padded with `<|end|>` at label -100, so padding never enters the loss.
One honest caveat: with plain causal attention, example B in a block can attend back into example A. The rigorous fix is block-diagonal attention masks (or position-id resets), and production stacks do that. At 124M with each example ending in a hard <|end|>, measured impact is nil, and the fix would mean touching model.py’s attention path. We skip it and say so.
```python
def pack(examples, block_size, eid):
    """Greedy-pack (ids, labels) pairs into fixed blocks. Drops examples
    longer than block_size (a handful of over-long passages from Lesson 8)."""
    blocks, cur_ids, cur_lab = [], [], []
    for ids, labels in examples:
        if len(ids) > block_size:
            continue  # ~0.5% of the Lesson 8 set; not worth truncation logic
        if len(cur_ids) + len(ids) > block_size:
            # close the current block: pad its tail with <|end|>, masked out of the loss
            pad = block_size - len(cur_ids)
            blocks.append((cur_ids + [eid] * pad, cur_lab + [IGNORE] * pad))
            cur_ids, cur_lab = [], []
        cur_ids += ids
        cur_lab += labels
    if cur_ids:  # flush the final partial block
        pad = block_size - len(cur_ids)
        blocks.append((cur_ids + [eid] * pad, cur_lab + [IGNORE] * pad))
    return blocks
```

Packing takes utilization from ~30% to ~97%, which is most of the reason this whole lesson costs under a dollar.
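The lesson deliberately skips block-diagonal attention, but for reference, here is a minimal pure-Python sketch of what that fix computes for one packed block: a segment id per position, position ids that restart at each example, and a causal mask that also forbids attending across examples. It assumes you record each example’s length while packing (the lesson’s `pack` does not return lengths), and `segment_ids`, `reset_position_ids` and `allowed` are hypothetical helpers; wiring the mask into `model.py` is not shown.

```python
def segment_ids(lengths, block_size):
    """Example index for every position, given the lengths of the examples packed
    into one block; the padded tail gets its own segment id."""
    seg = []
    for k, n in enumerate(lengths):
        seg += [k] * n
    seg += [len(lengths)] * (block_size - len(seg))  # padding segment
    return seg


def reset_position_ids(seg):
    """Position ids restart at 0 at every example boundary."""
    out = []
    for i, s in enumerate(seg):
        out.append(0 if i == 0 or s != seg[i - 1] else out[-1] + 1)
    return out


def allowed(seg):
    """Block-diagonal causal mask: query i may attend key j iff j <= i and both are in one example."""
    n = len(seg)
    return [[j <= i and seg[i] == seg[j] for j in range(n)] for i in range(n)]


# a 12-position block holding a 6-token example, a 4-token example, and 2 padding slots
seg = segment_ids([6, 4], block_size=12)
mask = allowed(seg)
print(seg)                      # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2]
print(reset_position_ids(seg))  # [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 0, 1]
print(mask[6][5], mask[6][6])   # False True: example 1 cannot see example 0
```

In PyTorch, a boolean matrix of this shape can be passed as `attn_mask` to `F.scaled_dot_product_attention`, where `True` marks pairs that may attend.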
src/train_sft.py — the full file
The engine is Lesson 6’s, stripped of what SFT doesn’t need (no streaming memmap dataset — the packed set fits in RAM; no mid-epoch resume — a full run is 20 minutes) and with the masked-label loss added. Config first:
```yaml
# configs/sft.yaml
base_ckpt: checkpoints/ckpt.pt     # Lesson 7 output
tokenizer: tokenizer/tokenizer.json
train_data: data/sft/train.jsonl   # from Lesson 8
out_ckpt: checkpoints/sft.pt
block_size: 1024
batch_size: 16        # 16×1024 ≈ 16k tokens per micro-batch; fits 24GB in bf16
grad_accum: 4         # effective 64k tokens per optimizer step
epochs: 3
lr: 3.0e-5            # ~10x below pretraining peak: nudge, don't bulldoze
min_lr: 3.0e-6
warmup_steps: 100
weight_decay: 0.1
grad_clip: 1.0
wandb_project: wikillm
wandb_run: wikillm-sft
sample_every: 200     # log generations to W&B during training
```

Why these numbers: 3e-5 sits in the usual band for full fine-tuning (1e-5 to 5e-5); at 3e-4 the model overfits the template within half an epoch and forgets facts (eval perplexity on plain Wikipedia text visibly degrades; try it). Three epochs, because SFT sets are small and behavior keeps improving through epochs 2–3 before memorization sets in; watch the sample generations, not just the loss. Cosine decay to a 10% floor: the same schedule shape as Lesson 6, reused verbatim.
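Plugging the config into the schedule gives a feel for the run’s shape. This sketch assumes ~15M packed tokens per epoch (50k dialogs × ~300 tokens, the lesson’s own estimates); in the real script `steps_per_epoch` comes from your actual block count, so your numbers will differ somewhat.

```python
import math

cfg = {"lr": 3.0e-5, "min_lr": 3.0e-6, "warmup_steps": 100,
       "batch_size": 16, "grad_accum": 4, "block_size": 1024, "epochs": 3}


def get_lr(step, total, cfg):
    # same shape as train_sft.py: linear warmup, then cosine decay down to min_lr
    if step < cfg["warmup_steps"]:
        return cfg["lr"] * (step + 1) / cfg["warmup_steps"]
    t = (step - cfg["warmup_steps"]) / max(1, total - cfg["warmup_steps"])
    return cfg["min_lr"] + 0.5 * (cfg["lr"] - cfg["min_lr"]) * (1 + math.cos(math.pi * t))


tokens_per_epoch = 50_000 * 300                  # assumption: ~50k dialogs x ~300 tokens
blocks = tokens_per_epoch // cfg["block_size"]   # packed 1024-token blocks
steps_per_epoch = blocks // (cfg["batch_size"] * cfg["grad_accum"])
total = steps_per_epoch * cfg["epochs"]
print(steps_per_epoch, total)  # 228 684
for s in (0, 99, total // 2, total):
    print(s, f"{get_lr(s, total, cfg):.2e}")
```

Under these assumptions warmup covers roughly the first 15% of the run, and `sample_every: 200` yields about three sample snapshots (steps 200, 400 and 600).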
```python
# src/train_sft.py
"""Full-parameter SFT of WikiGPT-124M on the Lesson 8 synthetic dataset.
Reuses the Lesson 6 engine: bf16 autocast, torch.compile, AdamW, cosine lr.
New here: chat rendering, loss masking (labels=-100), dialog packing."""
import json, math, time, argparse
import torch
import torch.nn.functional as F
import yaml
from tokenizers import Tokenizer
from model import GPT, GPTConfig  # Lesson 5

IGNORE = -100


# ---------------- data: render -> mask -> pack ----------------
def render_and_mask(tok, messages, uid, aid, eid):
    ids, labels = [], []
    for turn in messages:
        content_ids = tok.encode(turn["content"]).ids
        if turn["role"] == "user":
            span = [uid] + content_ids + [eid]
            ids += span; labels += [IGNORE] * len(span)
        else:
            ids += [aid] + content_ids + [eid]
            labels += [IGNORE] + content_ids + [eid]
    return ids, labels


def pack(examples, block_size, eid):
    blocks, cur_ids, cur_lab = [], [], []
    for ids, labels in examples:
        if len(ids) > block_size:
            continue
        if len(cur_ids) + len(ids) > block_size:
            pad = block_size - len(cur_ids)
            blocks.append((cur_ids + [eid] * pad, cur_lab + [IGNORE] * pad))
            cur_ids, cur_lab = [], []
        cur_ids += ids; cur_lab += labels
    if cur_ids:
        pad = block_size - len(cur_ids)
        blocks.append((cur_ids + [eid] * pad, cur_lab + [IGNORE] * pad))
    return blocks


def build_dataset(cfg, tok):
    uid = tok.token_to_id("<|user|>"); aid = tok.token_to_id("<|assistant|>")
    eid = tok.token_to_id("<|end|>")
    assert None not in (uid, aid, eid), "special tokens missing: retrain tokenizer per Lesson 4"
    examples = []
    with open(cfg["train_data"]) as f:
        for line in f:
            examples.append(render_and_mask(tok, json.loads(line)["messages"], uid, aid, eid))
    blocks = pack(examples, cfg["block_size"], eid)
    x = torch.tensor([b[0] for b in blocks], dtype=torch.long)
    y = torch.tensor([b[1] for b in blocks], dtype=torch.long)
    frac = (y != IGNORE).float().mean().item()
    print(f"{len(examples)} dialogs -> {len(blocks)} packed blocks, "
          f"{x.numel()/1e6:.1f}M tokens/epoch, {frac:.0%} of positions in loss")
    return x, y


# ---------------- schedule (Lesson 6, verbatim shape) ----------------
def get_lr(step, total, cfg):
    if step < cfg["warmup_steps"]:
        return cfg["lr"] * (step + 1) / cfg["warmup_steps"]
    t = (step - cfg["warmup_steps"]) / max(1, total - cfg["warmup_steps"])
    return cfg["min_lr"] + 0.5 * (cfg["lr"] - cfg["min_lr"]) * (1 + math.cos(math.pi * t))


# ---------------- main ----------------
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", default="configs/sft.yaml")
    cfg = yaml.safe_load(open(ap.parse_args().config))
    device = "cuda"
    torch.manual_seed(1337)
    tok = Tokenizer.from_file(cfg["tokenizer"])
    x_all, y_all = build_dataset(cfg, tok)

    # resume the BASE model weights (not its optimizer: SFT gets a fresh AdamW)
    ckpt = torch.load(cfg["base_ckpt"], map_location="cpu")
    model = GPT(GPTConfig(**ckpt["model_args"]))
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model = torch.compile(model)
    optim = torch.optim.AdamW(model.parameters(), lr=cfg["lr"],
                              betas=(0.9, 0.95), weight_decay=cfg["weight_decay"])

    import wandb
    wandb.init(project=cfg["wandb_project"], name=cfg["wandb_run"], config=cfg)

    B, accum = cfg["batch_size"], cfg["grad_accum"]
    steps_per_epoch = len(x_all) // (B * accum)
    total_steps = steps_per_epoch * cfg["epochs"]
    print(f"{steps_per_epoch} steps/epoch, {total_steps} total")

    step, t0 = 0, time.time()
    for epoch in range(cfg["epochs"]):
        perm = torch.randperm(len(x_all))  # fresh shuffle of the packed blocks each epoch
        for s in range(steps_per_epoch):
            lr = get_lr(step, total_steps, cfg)
            for g in optim.param_groups:
                g["lr"] = lr
            optim.zero_grad(set_to_none=True)
            loss_acc = 0.0
            for micro in range(accum):
                idx = perm[(s * accum + micro) * B:(s * accum + micro + 1) * B]
                xb = x_all[idx].to(device, non_blocking=True)
                yb = y_all[idx].to(device, non_blocking=True)
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    logits = model(xb)
                    # the shift: position t predicts label t+1; -100 drops masked positions
                    loss = F.cross_entropy(
                        logits[:, :-1].reshape(-1, logits.size(-1)),
                        yb[:, 1:].reshape(-1),
                        ignore_index=IGNORE) / accum
                loss.backward()
                loss_acc += loss.item()
            gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
            optim.step()

            if step % 10 == 0:
                dt = time.time() - t0; t0 = time.time()
                toks = 10 * B * accum * cfg["block_size"] if step else B * accum * cfg["block_size"]
                print(f"epoch {epoch} step {step}/{total_steps} loss {loss_acc:.4f} "
                      f"lr {lr:.2e} gnorm {gnorm:.2f} {toks/max(dt,1e-9)/1e3:.0f}k tok/s")
                wandb.log({"sft/loss": loss_acc, "sft/lr": lr,
                           "sft/grad_norm": gnorm.item(), "sft/epoch": epoch}, step=step)
            if step % cfg["sample_every"] == 0 and step > 0:
                wandb.log({"sft/sample": wandb.Html(
                    "<pre>" + chat_sample(model, tok, "What is photosynthesis?") + "</pre>")},
                    step=step)
            step += 1

    raw = getattr(model, "_orig_mod", model)  # unwrap torch.compile
    torch.save({"model": raw.state_dict(), "model_args": ckpt["model_args"],
                "sft_config": cfg}, cfg["out_ckpt"])
    print(f"saved {cfg['out_ckpt']}")


@torch.no_grad()
def chat_sample(model, tok, prompt, max_new=200):
    """Greedy chat completion; also used by judge_eval.py."""
    was_training = model.training
    model.eval()
    uid, aid, eid = (tok.token_to_id(t) for t in ("<|user|>", "<|assistant|>", "<|end|>"))
    ids = [uid] + tok.encode(prompt).ids + [eid, aid]
    x = torch.tensor([ids], device="cuda")
    for _ in range(max_new):
        with torch.autocast("cuda", dtype=torch.bfloat16):
            logits = model(x[:, -1024:])
        nxt = logits[0, -1].argmax().item()
        if nxt == eid:
            break
        x = torch.cat([x, torch.tensor([[nxt]], device="cuda")], dim=1)
    model.train(was_training)  # restore the caller's mode (train during SFT, eval in judge_eval)
    return tok.decode(x[0, len(ids):].tolist())


if __name__ == "__main__":
    main()
```

Line-by-line notes on the non-obvious parts:
- Fresh optimizer, not the pretraining one. Lesson 7’s AdamW moments encode “moving fast through Wikipedia”; reusing them at 3e-5 causes a loss spike in the first 50 steps. We load only `ckpt["model"]`.
- The whole packed dataset lives as two int64 tensors in RAM (~15M tokens per epoch × 2 tensors × 8 bytes ≈ 240MB; all three epochs reuse the same tensors in a fresh random order), so no dataloader machinery is needed at this scale.
- `getattr(model, "_orig_mod", model)` unwraps `torch.compile`, so the saved state dict has clean key names and loads without compile.
- Sanity number to watch: the `% of positions in loss` printout. For the Lesson 8 mix it should land around 35–45%. Near 100% means user turns are being trained: either your mask isn’t applied, or a role name didn’t match (`render_and_mask` sends every non-`"user"` turn, e.g. `"human"`, down the assistant branch). Near 0% means assistant turns are missing or empty.
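If a checkpoint was ever saved from the compiled wrapper without that unwrap, its keys carry an `_orig_mod.` prefix and `load_state_dict` on a plain `GPT` fails with missing/unexpected keys. A minimal repair sketch, assuming the prefix is exactly `_orig_mod.` (the attribute name `torch.compile`’s wrapper uses); the helper name is hypothetical, and since it only renames keys, the demo uses plain values instead of tensors.

```python
PREFIX = "_orig_mod."  # attribute under which torch.compile's wrapper holds the original module


def strip_compile_prefix(state_dict):
    """Copy of state_dict with a leading '_orig_mod.' removed from every key that has it."""
    return {(k[len(PREFIX):] if k.startswith(PREFIX) else k): v for k, v in state_dict.items()}


# hypothetical key names; a real state dict maps them to tensors
compiled_keys = {"_orig_mod.transformer.wte.weight": 1, "_orig_mod.lm_head.weight": 2}
print(strip_compile_prefix(compiled_keys))  # {'transformer.wte.weight': 1, 'lm_head.weight': 2}
```

Keys without the prefix pass through unchanged, so running it on an already-clean state dict is harmless.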
Launch on vast.ai: <1 hour, <$1
Same workflow as Lesson 7, smaller bill. If your Lesson 7/8 instance is still alive, skip straight to rsync; otherwise:
```bash
# find a 4090, create with the usual image
vastai search offers 'gpu_name=RTX_4090 num_gpus=1 inet_down>200' -o 'dph+' | head -5
vastai create instance <OFFER_ID> --image pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel --disk 60

# ship code + base checkpoint + SFT data (checkpoint is ~1.5GB, allow a few minutes);
# requirements.txt rides along so the pip install below works on a fresh instance
rsync -avz -e "ssh -p PORT" wikillm/src wikillm/configs wikillm/tokenizer wikillm/requirements.txt root@HOST:/root/wikillm/
rsync -avz -e "ssh -p PORT" wikillm/checkpoints/ckpt.pt root@HOST:/root/wikillm/checkpoints/
rsync -avz -e "ssh -p PORT" wikillm/data/sft root@HOST:/root/wikillm/data/

# on the instance
ssh -p PORT root@HOST
cd /root/wikillm && pip install -r requirements.txt && wandb login
tmux new -s sft
python src/train_sft.py --config configs/sft.yaml
```

Budget check: ~50k dialogs ≈ 15M packed tokens per epoch, ≈ 45M tokens processed over 3 epochs. At the ~45–50k tok/s you measured in Lesson 6 (packing keeps us near pretraining throughput), that’s ~15–17 minutes of steady-state training, plus torch.compile warmup, sample logging, rsync and setup: comfortably under an hour, $0.20–0.40 at $0.35–0.45/hr. Watch the `wikillm-sft` run in W&B: loss falls fast from ~2.5 to ~1.6 in epoch 1, then grinds; the interesting panel is `sft/sample`, where you can literally watch answers get shorter and more direct epoch by epoch.
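The budget arithmetic as a tiny helper, so you can redo it with your own measured throughput and hourly price. `run_budget` is hypothetical, and the 20 minutes of fixed overhead (rsync, setup, compile warmup) below is an assumption, not a measured figure.

```python
def run_budget(tokens, tok_per_s, overhead_min, usd_per_hr):
    """Rough (training minutes, total minutes, dollar cost) for a run plus fixed overhead."""
    train_min = tokens / tok_per_s / 60
    total_min = train_min + overhead_min
    return train_min, total_min, total_min / 60 * usd_per_hr


tokens = 3 * 15e6  # 3 epochs x ~15M packed tokens (the lesson's estimate)
for tps in (45e3, 50e3):
    train_min, total_min, usd = run_budget(tokens, tps, overhead_min=20, usd_per_hr=0.45)
    print(f"{tps/1e3:.0f}k tok/s: {train_min:.1f} min training, "
          f"{total_min:.1f} min total, ${usd:.2f}")
```

Even with generous overhead, both cases stay well under an hour and under $0.30 at $0.45/hr.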
Before and after: the demo that justifies the lesson
Run the same held-out prompts through both checkpoints (`sample.py` from Lesson 7 for the base, as a raw continuation with no template, and `chat_sample` for SFT). Representative behavior at this scale:

Prompt: What is the Great Barrier Reef and where is it located?

Base (`ckpt.pt`), raw continuation:

> What is the Great Barrier Reef and where is it located? is one of several questions addressed in the report. The Great Barrier Reef Marine Park Authority was established in 1975 under the Great Barrier Reef Marine Park Act. The authority is responsible for the management of the marine park, which covers an area of approximately…

(continues for 200 tokens, never answers, never stops)

SFT (`sft.pt`), chat template:

> The Great Barrier Reef is the world’s largest coral reef system, composed of thousands of individual reefs. It is located in the Coral Sea, off the coast of Queensland, Australia.
Same weights, essentially — the knowledge was already in the base model. SFT taught it the format of answering: address the question, be concise, stop at <|end|>. Also demo a summarization and an extraction prompt from the held-out set; those show the largest before/after gap because the base model has no notion of “the passage above” as a referent. And be honest in your notes: at 124M the SFT model will still confidently hallucinate on questions outside Wikipedia-frequency facts — SFT shapes behavior, it does not create knowledge. That gap is what Lesson 10’s preference tuning starts to address.
A judge-free eval harness
Loss is a weak proxy for “is it a better assistant”. Before reaching for LLM-as-judge (that’s Lesson 10 territory — this file grows a judge half then), we can measure a lot with exact matching and heuristics on the held-out closed-QA split from Lesson 8 (data/sft/heldout.jsonl — questions whose answers are short, unambiguous spans, withheld from training):
```python
# src/judge_eval.py (part 1: judge-free metrics; Lesson 10 adds the LLM judge)
import json, re, argparse, torch
from tokenizers import Tokenizer
from model import GPT, GPTConfig
from train_sft import chat_sample


def normalize(s):
    s = s.lower().strip()
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    s = re.sub(r"[^a-z0-9 ]", "", re.sub(r"\s+", " ", s))
    return re.sub(r"\s+", " ", s).strip()  # re-collapse gaps left by removed punctuation


def distinct3(text):
    toks = text.split()
    tri = [tuple(toks[i:i+3]) for i in range(len(toks) - 2)]
    if not tri:
        return 1.0  # under 3 words: nothing can repeat (don't score short answers as 0)
    return len(set(tri)) / len(tri)  # 1.0 = no repetition loops


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="checkpoints/sft.pt")
    ap.add_argument("--eval_data", default="data/sft/heldout.jsonl")
    args = ap.parse_args()
    tok = Tokenizer.from_file("tokenizer/tokenizer.json")
    ckpt = torch.load(args.ckpt, map_location="cpu")
    model = GPT(GPTConfig(**ckpt["model_args"])).to("cuda").eval()
    model.load_state_dict(ckpt["model"])
    rows = [json.loads(l) for l in open(args.eval_data)]
    em = contains = stopped = d3_sum = len_ok = 0
    for r in rows:
        q = r["messages"][0]["content"]; gold = r["messages"][1]["content"]
        out = chat_sample(model, tok, q, max_new=200)
        em += normalize(out) == normalize(gold)
        contains += normalize(gold) in normalize(out)  # right span, wordier
        stopped += len(tok.encode(out).ids) < 200  # emitted <|end|> within budget (approx., via re-encoding)
        len_ok += len(out.split()) <= 3 * max(1, len(gold.split()))  # at most 3x the gold word count
        d3_sum += distinct3(out)
    n = len(rows)
    print(f"n={n} exact_match={em/n:.3f} contains_answer={contains/n:.3f} "
          f"stopped={stopped/n:.3f} length_ok={len_ok/n:.3f} distinct3={d3_sum/n:.3f}")


if __name__ == "__main__":
    main()
```

What each metric catches, and rough targets for a healthy run:
| metric | catches | base model | good SFT |
|---|---|---|---|
| `exact_match` | wrong or unfocused answers | ~0.00 | 0.15–0.30 |
| `contains_answer` | right knowledge, verbose delivery | ~0.05 | 0.40–0.60 |
| `stopped` | broken `<\|end\|>` training (the #1 bug) | ~0.00 | >0.95 |
| `length_ok` | rambling | low | >0.85 |
| `distinct3` | degenerate repetition loops | varies | >0.90 |
Run it on both checkpoints — --ckpt checkpoints/ckpt.pt gives your base-model row for free (the base scores near zero on stopped, which is the single clearest quantitative signature of what SFT bought you). If stopped is low on the SFT model, your <|end|> label got masked — go re-read the masking table.
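A quick way to build intuition for the text metrics before spending GPU time: run the harness’s `normalize` and `distinct3` (copied here so the snippet is self-contained) on a few hand-written outputs. The three candidate answers are invented examples, not model outputs.

```python
import re


def normalize(s):
    # same normalization as judge_eval.py
    s = s.lower().strip()
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    s = re.sub(r"[^a-z0-9 ]", "", re.sub(r"\s+", " ", s))
    return re.sub(r"\s+", " ", s).strip()


def distinct3(text):
    toks = text.split()
    tri = [tuple(toks[i:i + 3]) for i in range(len(toks) - 2)]
    return 1.0 if not tri else len(set(tri)) / len(tri)


gold = "Paris"
for out in ["Paris.", "The capital of France is Paris.", "Paris is Paris is Paris is Paris"]:
    print(f"{out!r:36} em={normalize(out) == normalize(gold)} "
          f"contains={normalize(gold) in normalize(out)} d3={distinct3(out):.2f}")
```

The first answer scores on everything, the second only on `contains_answer`, and the looping third is flagged by `distinct3` (0.40). Note that `contains_answer` is a plain substring test, so it is generous: gold “Paris” would also match “Parisian”.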
Why full fine-tuning, not LoRA
You’ll see LoRA (low-rank adapters) recommended everywhere for fine-tuning, so it deserves an explicit decision, not a default:
- The entire premise of LoRA is that full FT doesn’t fit. A 7B model in full FT needs ~112GB for weights, grads and AdamW moments in mixed precision; LoRA trains ~0.5% of parameters, so it fits on one GPU. Our 124M model needs ~2GB for all of that, so it fits 12× over in the 4090’s 24GB. The constraint LoRA relaxes does not exist here.
- Full FT is the quality ceiling; LoRA approximates it. A rank-r update caps each weight delta at rank r. At 124M we can afford the exact thing, so taking the approximation would mean paying extra (in hyperparameters you’d have to tune: rank, alpha, target modules) to get less.
- It would be more code, not less: adapter injection into `model.py`, merge-at-save logic, an extra config axis. The lazy choice and the correct choice coincide.
The crossover is real, though: at 7B+ (say, fine-tuning the Qwen2.5-7B teacher itself) LoRA/QLoRA becomes the only single-GPU path, and it’s the right tool there — see the encyclopedia’s PEFT entry for the mechanics. Rule of thumb: full FT while weights+optimizer fit in VRAM with headroom; LoRA when they don’t.
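The memory figures above follow from simple per-parameter arithmetic. This sketch assumes one common accounting for mixed-precision AdamW, 16 bytes per parameter (bf16 weights and grads at 2 bytes each, plus fp32 master weights and the two Adam moments at 4 bytes each), and ignores activations; the LoRA count assumes a single rank-r adapter on one 768×768 projection.

```python
# assumption: bf16 weights (2) + bf16 grads (2) + fp32 master weights, m, v (4 + 4 + 4)
BYTES_PER_PARAM = 16


def full_ft_gb(n_params):
    """Weights + grads + AdamW state in GB, excluding activations."""
    return n_params * BYTES_PER_PARAM / 1e9


def lora_params(d_in, d_out, r):
    """A rank-r adapter on a d_out x d_in weight adds B (d_out x r) and A (r x d_in)."""
    return r * (d_in + d_out)


print(f"7B full FT:   {full_ft_gb(7e9):.0f} GB")    # 112 GB
print(f"124M full FT: {full_ft_gb(124e6):.1f} GB")  # 2.0 GB
d = 768
extra = lora_params(d, d, 8)
print(f"768x768 weight: {d * d} params; rank-8 LoRA adds {extra} ({extra / (d * d):.1%})")
```

The adapter is tiny relative to its weight, which is exactly why LoRA rescues 7B-scale training and buys nothing at 124M, where the full 2GB already fits.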
Ship the checkpoint: wikigpt-124m-sft on the Hub
Your Lesson 8 dataset is already public on GitHub; now the model that consumed it goes public too, so the next lessons (and other people) can pull it anywhere:
```bash
pip install huggingface_hub
huggingface-cli login   # paste a WRITE token from hf.co/settings/tokens
```

```python
# push_sft.py: run once from wikillm/ (local machine or instance, either works)
from huggingface_hub import HfApi

api = HfApi()
repo = api.create_repo("YOUR_USERNAME/wikigpt-124m-sft", exist_ok=True).repo_id
api.upload_file(path_or_fileobj="checkpoints/sft.pt", path_in_repo="sft.pt", repo_id=repo)
api.upload_file(path_or_fileobj="tokenizer/tokenizer.json", path_in_repo="tokenizer.json", repo_id=repo)
api.upload_file(path_or_fileobj="configs/sft.yaml", path_in_repo="sft.yaml", repo_id=repo)
print(f"https://huggingface.co/{repo}")
```

Add a README model card on the Hub with: the architecture line (124M, 12L/12H/768d, RoPE, SwiGLU, RMSNorm, vocab 32768), the chat format (`<|user|>...<|end|><|assistant|>...<|end|>`, greedy or low-temperature decoding), a link to your GitHub dataset repo, your eval table from the harness, and an honest limitations note (English-only, Wikipedia-knowledge-only, hallucinates outside it). Note that this is a raw PyTorch state dict loaded with your `model.py`, not a transformers-format repo; Lesson 11 wraps it for serving.
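One way to assemble that card, as a sketch: the function below just concatenates the items listed above into README.md text. `model_card` is a hypothetical helper, the dataset URL and eval line are placeholders you replace with your own, and the Hub does not require this exact layout.

```python
def model_card(repo_id, dataset_url, eval_line):
    """Assemble a minimal README.md body from caller-supplied values."""
    return "\n".join([
        "---",
        "language: en",
        "---",
        f"# {repo_id.split('/')[-1]}",
        "",
        "124M decoder-only GPT: 12L/12H/768d, RoPE, SwiGLU, RMSNorm, vocab 32768.",
        "",
        "## Chat format",
        "`<|user|>...<|end|><|assistant|>...<|end|>`; use greedy or low-temperature decoding.",
        "",
        "## Data",
        f"SFT data: {dataset_url}",
        "",
        "## Eval (judge-free harness, held-out closed QA)",
        eval_line,
        "",
        "## Limitations",
        "English-only; knowledge limited to Wikipedia; hallucinates outside it.",
        "Raw PyTorch state dict for the course's model.py, not a transformers-format repo.",
    ]) + "\n"


card = model_card("YOUR_USERNAME/wikigpt-124m-sft",
                  "https://github.com/YOUR_USERNAME/YOUR_DATASET_REPO",  # hypothetical placeholder
                  "n=... exact_match=... stopped=...")  # paste your judge_eval.py output line
print(card)
```

Uploading it can reuse the same call as the other files, since `upload_file` also accepts raw bytes: `api.upload_file(path_or_fileobj=card.encode(), path_in_repo="README.md", repo_id=repo)`.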
🧪 Your task
The Lesson 8 dataset is single-turn, but render_and_mask already loops over turns — so verify it’s actually multi-turn-safe, with a test instead of a vibe. Write a small self-check that (1) builds a 2-user/2-assistant dialog, (2) renders it, and (3) asserts three invariants: every label is either -100 or equals the id at the same position; no token inside a user span (including both its <|user|> and <|end|>) has a trainable label; both assistant <|end|> tokens are trainable. Then compute what fraction of trainable positions in your real packed training set are <|end|> tokens, and explain in one sentence why that number being high (~10%+) is a feature, not a bug.
Solution
```python
# tests/test_masking.py  (run from wikillm/: python tests/test_masking.py)
import json
import sys
from tokenizers import Tokenizer

sys.path.insert(0, "src")
from train_sft import render_and_mask, IGNORE

tok = Tokenizer.from_file("tokenizer/tokenizer.json")
uid, aid, eid = (tok.token_to_id(t) for t in ("<|user|>", "<|assistant|>", "<|end|>"))
msgs = [
    {"role": "user", "content": "Who wrote Hamlet?"},
    {"role": "assistant", "content": "William Shakespeare."},
    {"role": "user", "content": "When was it written?"},
    {"role": "assistant", "content": "Around 1600."},
]
ids, labels = render_and_mask(tok, msgs, uid, aid, eid)

# (1) labels are -100 or echo the id at the same position
assert all(l == IGNORE or l == i for i, l in zip(ids, labels))

# (2) walk the sequence and check that user spans are fully masked
in_user = False
for i, l in zip(ids, labels):
    if i == uid:
        in_user = True
    if in_user:
        assert l == IGNORE, "user span leaked into loss"
    if in_user and i == eid:
        in_user = False  # the user's <|end|> was masked; now the span closes

# (3) both assistant <|end|> tokens are trainable
trainable_ends = sum(1 for i, l in zip(ids, labels) if i == eid and l == eid)
assert trainable_ends == 2, f"expected 2 trainable <|end|>, got {trainable_ends}"
print("masking invariants hold")

# fraction of trainable positions that are <|end|> on the real set
total = ends = 0
for line in open("data/sft/train.jsonl"):
    _, labs = render_and_mask(tok, json.loads(line)["messages"], uid, aid, eid)
    t = [l for l in labs if l != IGNORE]
    total += len(t); ends += sum(1 for l in t if l == eid)
print(f"<|end|> share of trainable tokens: {ends/total:.1%}")
```

Typical output: `<|end|> share of trainable tokens:` ~1–3% for long answers, up to ~10% on short-answer-heavy closed-QA mixes. It’s a feature because every training example contributes at least one “stop here” gradient (exactly one in the single-turn Lesson 8 set): stopping is the one behavior reinforced by 100% of examples, which is why a correctly-masked SFT model learns to terminate reliably in less than one epoch, while a model whose `<|end|>` got masked never learns it at all.
Key takeaways
- SFT is pretraining with three edits: chat-templated data, loss masking (labels=-100 outside assistant spans), and ~10× lower lr. The engine, loss function, and schedule shape are reused from Lesson 6.
- The label array aligns with the input ids and is shifted in the loss; that’s why `<|assistant|>` carries -100 yet the model still learns to begin answers: its shifted target is the first answer token.
- The assistant’s `<|end|>` must be a trainable label; masking it produces a model that answers correctly and never stops. It is the most common SFT bug, and the `stopped` metric catches it instantly.
- Packing dialogs into 1024-token blocks takes batch utilization from ~30% to ~97%; we accept cross-example attention without block-diagonal masks at this scale, deliberately.
- 3 epochs at lr 3e-5 cosine over ~50k examples: ~20 minutes on the 4090, well under $1, tracked as `wikillm-sft` in the `wikillm` W&B project.
- SFT changes behavior, not knowledge: the base model already knew the answers; SFT taught it to answer, concisely, and stop. It still hallucinates beyond its Wikipedia knowledge.
- Full FT beats LoRA at 124M because the constraint LoRA exists to relax (optimizer state not fitting in VRAM) doesn’t apply; LoRA becomes the right tool at 7B+.
- The checkpoint now lives publicly as `wikigpt-124m-sft` on the HF Hub, alongside your public dataset repo from Lesson 8.
Coming up
Your model now answers — but between two valid answers, it has no idea which one a human would prefer; in Lesson 10 you’ll generate preference pairs with your teacher and implement DPO from scratch to teach it exactly that.