flowchart LR
A["optimizer.zero_grad()<br/>clear old gradients"] --> B["pred = model(xb)<br/>forward pass"]
B --> C["loss = loss_fn(pred, yb)<br/>scalar score"]
C --> D["loss.backward()<br/>compute gradients"]
D --> E["optimizer.step()<br/>update weights"]
E -->|next batch| A
style A fill:#f59e0b33,stroke:#f59e0b
style B fill:#6366f133,stroke:#6366f1
style C fill:#ec489933,stroke:#ec4899
style D fill:#38bdf833,stroke:#38bdf8
style E fill:#22c55e33,stroke:#22c55e
🔥 Deep Learning with PyTorch · Day 4 — The Training Loop, Properly
Day 4 — The Training Loop, Properly
Yesterday you built Dataset and DataLoader pipelines that hand you clean, shuffled mini-batches on demand. Today you close the circuit: batches flow in, gradients flow back, and weights actually move. The training loop is the beating heart of every PyTorch project — and it’s also where most silent bugs live. Forgetting zero_grad(), feeding softmax outputs into CrossEntropyLoss, validating with dropout still active — none of these crash. They just quietly train a worse model. By the end of today you’ll understand every line of the canonical loop well enough to spot those bugs on sight, and you’ll have wrapped it all into a reusable train() function with per-epoch validation that you’ll use for the rest of the course.
🎯 Today you will: dissect the five-step training loop and prove why each step exists, choose between MSELoss and CrossEntropyLoss (and learn the logits-not-softmax rule), compare SGD and Adam on the same problem, master model.train() / model.eval() / torch.no_grad(), and ship a reusable train() function with validation tracking
The anatomy of one training step
Every PyTorch training loop — from a linear regression to a 70-billion-parameter language model — is the same five steps repeated:
Let’s build the pieces we need, then run exactly one step under a microscope. First, data — a two-moons classification set built from raw tensors (Day 3 skills, no sklearn required):
import math
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, random_split
torch.manual_seed(42)
def make_moons(n=2000, noise=0.1):
    t = torch.rand(n // 2) * math.pi  # angles in [0, π)
    upper = torch.stack([torch.cos(t), torch.sin(t)], dim=1)
    lower = torch.stack([1 - torch.cos(t), 0.5 - torch.sin(t)], dim=1)
    X = torch.cat([upper, lower]) + noise * torch.randn(n, 2)
    y = torch.cat([torch.zeros(n // 2), torch.ones(n // 2)]).long()
    return X, y

X, y = make_moons()
print(X.shape, X.dtype, y.shape, y.dtype)
torch.Size([2000, 2]) torch.float32 torch.Size([2000]) torch.int64
Two details are load-bearing here. The inputs are float32 — the default parameter dtype of every nn layer, so no casting surprises. The labels are int64 (.long()) — CrossEntropyLoss demands integer class indices and will throw RuntimeError: expected scalar type Long if you hand it floats. Getting dtypes right at the dataset boundary saves you from sprinkling .float() and .long() all over the loop.
Now split and wrap, exactly as on Day 3:
ds = TensorDataset(X, y)
train_ds, val_ds = random_split(ds, [1600, 400],
                                generator=torch.Generator().manual_seed(0))
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=256)  # no shuffle needed
Note the asymmetry: the training loader shuffles (gradient descent on ordered data can bias updates), the validation loader doesn’t (order can’t affect a metric you’re only averaging). Validation also gets a bigger batch size: no gradients will be stored, so memory is cheap there. More on that soon.
And a model — a small MLP with dropout, deliberately included so we can see train() vs eval() matter later:
model = nn.Sequential(
    nn.Linear(2, 64),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(64, 2),  # 2 output logits, one per class
)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
The final layer has no softmax. That’s not an omission: it’s the single most important rule of today, and we’ll dedicate a whole section to it.
Now, one training step, expanded to slow motion:
xb, yb = next(iter(train_loader)) # one mini-batch: [64, 2], [64]
# 1) clear gradients left over from the previous step
optimizer.zero_grad()
# 2) forward: [64, 2] inputs -> [64, 2] logits
logits = model(xb)
# 3) reduce the whole batch to ONE scalar
loss = loss_fn(logits, yb)
print("loss:", loss.item()) # ≈ 0.72 (near ln(2), i.e. random guessing)
# 4) backward: populate .grad on every parameter
loss.backward()
print("grad norm of first layer:", model[0].weight.grad.norm().item())
# 5) update: w <- w - lr * grad, for every parameter
optimizer.step()
loss: 0.7241...
grad norm of first layer: 0.31...
Two sanity checks are baked into that output, and you should get in the habit of reading them. An untrained 2-class classifier should sit near \(-\ln(1/2) \approx 0.693\) — if your initial loss is 4.7 on a 2-class problem, something upstream is broken (for \(C\) classes, expect \(\ln C\)). And the gradient norm being a sensible non-zero number tells you the graph is connected — a grad of None or exactly zero means you detached something.
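The \(\ln C\) baseline is quick to check without a model. The sketch below uses plain Python (no PyTorch) to compute the cross-entropy of a classifier that gives every class the same probability. The helper name `expected_initial_loss` is made up for this illustration.

```python
import math

def expected_initial_loss(num_classes: int) -> float:
    """Cross-entropy of a uniform prediction over num_classes classes.

    A uniform classifier assigns the true class probability 1/C,
    so the loss is -ln(1/C) = ln(C).
    """
    return -math.log(1.0 / num_classes)

for c in (2, 10, 1000):
    print(f"{c:5d} classes -> expected initial loss {expected_initial_loss(c):.3f}")
```

For 2, 10 and 1000 classes this prints roughly 0.693, 2.303 and 6.908. A freshly initialized model will not be perfectly uniform, so treat these as ballpark figures rather than exact targets.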
Why zero_grad() exists at all
PyTorch accumulates gradients into .grad rather than overwriting them. Watch:
w = torch.tensor([1.0], requires_grad=True)
loss = (3 * w).sum()
loss.backward()
print(w.grad) # tensor([3.])
loss = (3 * w).sum()
loss.backward()
print(w.grad) # tensor([6.]) <- accumulated, not replaced!
Forget zero_grad() in your loop and every step applies the sum of all previous batches’ gradients. The updates get larger and staler with each step, and training typically oscillates or diverges into NaN, often after looking fine for a few dozen steps. It’s a maddening bug precisely because nothing errors.
Why does PyTorch do this? Because accumulation is a feature when you want it: gradient accumulation simulates large batches on small GPUs by calling backward() on several micro-batches before a single step() + zero_grad(). The default just puts the responsibility on you. Two idioms you’ll see:
optimizer.zero_grad() # the classic
optimizer.zero_grad(set_to_none=True) # default since PyTorch 2.0: frees memory
                                      # by setting .grad = None instead of zeroing
In PyTorch 2.x, set_to_none=True is already the default, so plain optimizer.zero_grad() does the efficient thing. Position it anywhere between the previous step() and the next backward(); start-of-loop is conventional because it’s hard to forget there.
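To see why accumulating over micro-batches can stand in for one large batch, here is a small sketch in plain Python with a hand-derived gradient. It assumes a toy 1-D model (`pred = w * x`) with a mean-squared-error loss, and it assumes each micro-batch loss is divided by the number of micro-batches before being accumulated. That scaling is a common convention rather than something stated above, and it is what makes the sums match when the loss is a per-batch mean.

```python
# Hypothetical 1-D model: prediction = w * x, loss = mean squared error.
def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given samples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]

# One big batch of 8 samples.
full_grad = mse_grad(w, xs, ys)

# Four micro-batches of 2; `grad` plays the role of the accumulating .grad.
micro = 2
num_micro = len(xs) // micro
grad = 0.0  # the "zero_grad()" moment
for i in range(0, len(xs), micro):
    # dividing by num_micro mirrors calling (loss / num_micro).backward()
    grad += mse_grad(w, xs[i:i + micro], ys[i:i + micro]) / num_micro

print(full_grad, grad)  # the two values agree up to float rounding
```

In a real loop the same idea would mean calling backward() on each scaled micro-batch loss and only then step() followed by zero_grad(). The equality holds exactly only when the micro-batches are the same size.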
Loss functions: MSE, CrossEntropy, and the logits rule
The loss function defines what “better” means. Pick the wrong one and gradient descent will happily optimize the wrong thing.
| Loss | Task | Model outputs | Targets | Shapes (batch \(N\)) |
|---|---|---|---|---|
| nn.MSELoss | regression | real values | real values | [N, *] vs [N, *] (must match) |
| nn.CrossEntropyLoss | multi-class | raw logits | class indices (int64) | [N, C] vs [N] |
| nn.BCEWithLogitsLoss | binary / multi-label | raw logits | floats in {0, 1} | [N, *] vs [N, *] |
For regression, MSELoss is exactly what you’d write by hand:
\[\mathcal{L}_{\text{MSE}} = \frac{1}{N}\sum_{i=1}^{N} (\hat{y}_i - y_i)^2\]
mse = nn.MSELoss()
pred = torch.tensor([[2.5], [0.0], [1.0]])
target = torch.tensor([[3.0], [0.0], [2.0]])
print(mse(pred, target)) # ((0.5)² + 0² + 1²) / 3
tensor(0.4167)
One shape trap: if pred is [N, 1] and target is [N], broadcasting turns your loss into an [N, N] comparison averaged down to a plausible-looking scalar. Modern PyTorch warns about this, but the fix is yours: keep shapes identical (target.unsqueeze(1) or pred.squeeze(1)).
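The size of that broadcasting error is easy to underestimate. The following sketch reproduces it with nested Python lists instead of tensors, reusing the three-sample example from above. It mimics what broadcasting [N, 1] against [N] computes and is not PyTorch code.

```python
pred = [[2.5], [0.0], [1.0]]   # shape [N, 1], as a model would return it
target = [3.0, 0.0, 2.0]       # shape [N], missing the trailing dimension

n = len(target)

# Intended loss: compare row i with target i.
intended = sum((pred[i][0] - target[i]) ** 2 for i in range(n)) / n

# What broadcasting [N, 1] against [N] computes: every row against every target.
broadcast = sum((pred[i][0] - target[j]) ** 2
                for i in range(n) for j in range(n)) / (n * n)

print(f"intended:  {intended:.4f}")   # 0.4167
print(f"broadcast: {broadcast:.4f}")  # 2.8611
```

Both results are ordinary-looking scalars, which is why the mistake tends to go unnoticed unless you read the warning or check shapes explicitly.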
For classification, CrossEntropyLoss compares a predicted distribution against the true class:
\[\mathcal{L}_{\text{CE}} = -\frac{1}{N}\sum_{i=1}^{N} \log p_{i, y_i} \qquad \text{where } p_i = \mathrm{softmax}(z_i)\]
And here is the rule that trips up everyone once: nn.CrossEntropyLoss applies softmax internally. Feed it raw logits. Never put a softmax layer before it.
Let’s prove the damage numerically:
import torch.nn.functional as F
logits = torch.tensor([[4.0, -2.0, 0.0]]) # model is very confident: class 0
target = torch.tensor([0])
print("logits -> CE: ", F.cross_entropy(logits, target).item())
probs = logits.softmax(dim=1) # [[0.9796, 0.0024, 0.0179]]
print("softmax -> CE (BUG): ", F.cross_entropy(probs, target).item())
logits -> CE: 0.0206
softmax -> CE (BUG): 0.5645
The correct pipeline sees a confident, correct prediction and reports a near-zero loss. The buggy pipeline squashes [4, -2, 0] into [0.98, 0.002, 0.02], which, treated as logits, is a much flatter distribution (class 0 gets only about 0.57), and reports substantial loss no matter how good the model gets. Because softmax outputs are confined to [0, 1], the floor of the buggy loss is far above zero, gradients shrink toward nothing, and training plateaus early. Again: no error message. Just a mediocre model.
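How high is that floor? The sketch below is plain Python with hand-written `softmax` and `cross_entropy` helpers, which are illustrative stand-ins and not the PyTorch functions. It pushes the correct-class logit higher and higher and compares the two pipelines for a three-class problem.

```python
import math

def cross_entropy(scores, target):
    """-log softmax(scores)[target], via a stable log-sum-exp."""
    m = max(scores)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_sum_exp - scores[target]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

for confidence in (1.0, 4.0, 16.0, 64.0):
    logits = [confidence, 0.0, 0.0]            # class 0 is correct
    good = cross_entropy(logits, 0)
    buggy = cross_entropy(softmax(logits), 0)  # softmax applied twice
    print(f"logit {confidence:5.1f} | correct {good:.4f} | buggy {buggy:.4f}")

# Limit of the buggy loss when the probabilities approach [1, 0, 0]
floor = math.log(1.0 + 2.0 / math.e)
print(f"buggy floor for 3 classes: {floor:.4f}")
```

The correct loss heads to zero as confidence grows, while the double-softmax loss stalls at about 0.55, the value of \(\ln(1 + 2/e)\). With more classes the floor is higher still, because the best the buggy pipeline can feed the loss is a "logit" of 1 against \(C - 1\) zeros.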
The same rule holds for binary classification: use nn.BCEWithLogitsLoss on raw logits rather than sigmoid + nn.BCELoss — the fused version is also more numerically stable via the log-sum-exp trick, exactly the trick you met when you built softmax by hand in the encyclopedia’s numerical-stability entry.
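The stability claim can be checked on a single number. The sketch below compares a naive sigmoid-then-log formula with one standard log-sum-exp rearrangement. It is written in plain Python, both function names are illustrative, and the rearrangement shown is a common textbook form that is not claimed to be PyTorch's exact implementation.

```python
import math

def bce_naive(logit, target):
    """Sigmoid followed by the textbook BCE formula (unstable)."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

def bce_with_logits(logit, target):
    """Algebraically the same loss, rearranged so log(0) cannot occur."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

# Moderate logit: both versions agree.
print(bce_naive(2.0, 1.0), bce_with_logits(2.0, 1.0))

# Confidently wrong prediction: sigmoid(40) rounds to exactly 1.0 in float64,
# so the naive version ends up evaluating log(0).
try:
    bce_naive(40.0, 0.0)
except ValueError as err:
    print("naive version failed:", err)
print("stable version:", bce_with_logits(40.0, 0.0))
```

In tensor code the naive path usually surfaces as `inf` or `nan` rather than an exception, which is harder to notice. Keeping the computation in logit space avoids the problem entirely.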
When you do need probabilities — for reporting, or picking the predicted class — apply softmax outside the loss, typically at inference:
with torch.no_grad():
    probs = model(xb).softmax(dim=1)  # fine here: not feeding a loss
    predicted_class = probs.argmax(dim=1)  # (argmax of logits gives the same answer)
Optimizers: what step() actually does
loss.backward() computes gradients; the optimizer decides what to do with them. Plain SGD is one line of math applied to every parameter:
\[\theta \leftarrow \theta - \eta \, \nabla_\theta \mathcal{L}\]
Everything else (momentum, Adam, AdamW) is a refinement of how the gradient is smoothed and scaled per parameter. You know the theory from the encyclopedia; here’s what matters practically in PyTorch:
# The three you'll actually use, in ascending order of "just works":
opt_sgd = torch.optim.SGD(model.parameters(), lr=0.1)
opt_mom = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
opt_adamw = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
Note the learning rates differ by two orders of magnitude, and that’s not arbitrary. Adam-family optimizers normalize each parameter’s update by a running estimate of its gradient magnitude, so lr means something different: 1e-3 is the standard Adam/AdamW starting point, while raw SGD often wants 1e-2 to 1e-1. Copying an SGD learning rate into Adam is a classic way to blow up training.
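To make "lr means something different" concrete, here is a single-parameter sketch of the very first update under each rule, in plain Python. It is a simplification: it uses the textbook Adam formulas with the usual defaults (beta1=0.9, beta2=0.999, eps=1e-8), leaves out weight decay, and only looks at step 1. The function names are illustrative.

```python
import math

def sgd_step(grad, lr=0.1):
    """Plain SGD: the step is proportional to the raw gradient."""
    return lr * grad

def adam_first_step(grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Size of the very first Adam update for a single parameter."""
    m = (1 - beta1) * grad        # running mean of gradients
    v = (1 - beta2) * grad ** 2   # running mean of squared gradients
    m_hat = m / (1 - beta1)       # bias correction at step t = 1
    v_hat = v / (1 - beta2)
    return lr * m_hat / (math.sqrt(v_hat) + eps)

for g in (0.001, 1.0, 1000.0):
    print(f"grad {g:8.3f} | SGD step {sgd_step(g):10.4f} "
          f"| Adam step {adam_first_step(g):.6f}")
```

Across six orders of magnitude of gradient size the SGD step scales with the gradient, while the Adam step stays close to lr itself. That is why an lr of 0.1, reasonable for SGD, asks Adam to move every parameter by roughly 0.1 per step.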
Let’s race SGD against AdamW on our moons, holding everything else fixed:
def quick_train(opt_name, epochs=15):
    torch.manual_seed(42)  # identical init for a fair fight
    m = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),
                      nn.Dropout(0.2), nn.Linear(64, 2))
    opt = (torch.optim.SGD(m.parameters(), lr=0.1) if opt_name == "sgd"
           else torch.optim.AdamW(m.parameters(), lr=1e-3, weight_decay=0.01))
    for _ in range(epochs):
        m.train()
        for xb, yb in train_loader:
            opt.zero_grad()
            loss = loss_fn(m(xb), yb)
            loss.backward()
            opt.step()
    return loss.item()

print("SGD final batch loss: ", round(quick_train("sgd"), 4))
print("AdamW final batch loss:", round(quick_train("adamw"), 4))
SGD final batch loss: 0.1042
AdamW final batch loss: 0.0891
On a problem this easy both win; the differences show up at scale. The working defaults for this course: AdamW with lr=1e-3 as the first thing you try, SGD-with-momentum when you’re chasing the last percent on vision tasks (Day 7 will add learning-rate schedules on top). Prefer AdamW over classic Adam — it decouples weight decay from the adaptive scaling, which is what you almost always want when regularizing.
One more API point: the optimizer holds a reference to model.parameters() from construction time. Create the optimizer after the model is on its final device, and if you ever rebuild the model, rebuild the optimizer too — a stale optimizer pointing at orphaned parameters updates nothing, silently.
model.train(), model.eval(), and torch.no_grad()
Two switches control evaluation-time behavior, they do different things, and you need both.
model.train() / model.eval() toggle layer behavior. Dropout drops activations in train mode and becomes a no-op in eval mode; BatchNorm (Day 7) uses batch statistics in train mode and running averages in eval mode. Watch dropout misbehave:
model.train()
x = torch.randn(1, 2)
print(model(x)) # run twice: DIFFERENT outputs — dropout is sampling masks
print(model(x))
model.eval()
print(model(x)) # run twice: IDENTICAL — dropout disabled
print(model(x))
tensor([[ 0.213, -0.145]], grad_fn=<AddmmBackward0>)
tensor([[ 0.187, -0.201]], grad_fn=<AddmmBackward0>)
tensor([[ 0.198, -0.176]], grad_fn=<AddmmBackward0>)
tensor([[ 0.198, -0.176]], grad_fn=<AddmmBackward0>)
Validate in train mode and your metrics are both noisy and pessimistic — you’re evaluating a randomly-lobotomized model. This is the most common cause of “my validation accuracy jumps around by 5% between runs”.
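What the two modes do to activations can be sketched without a network. The function below is a plain-Python stand-in for a dropout layer, not the `nn.Dropout` source. It follows the inverted-dropout scheme, in which surviving activations are scaled by 1/(1-p) during training so their expected value matches eval mode.

```python
import random

def dropout(values, p=0.2, training=True, rng=random):
    """Inverted dropout on a list of activations."""
    if not training:
        return list(values)  # eval mode: identity
    keep = 1.0 - p
    # zero each value with probability p, rescale survivors by 1/keep
    return [v / keep if rng.random() < keep else 0.0 for v in values]

rng = random.Random(0)
acts = [1.0] * 8
print("train:", dropout(acts, training=True, rng=rng))
print("train:", dropout(acts, training=True, rng=rng))  # a freshly sampled mask
print("eval: ", dropout(acts, training=False))
print("eval: ", dropout(acts, training=False))           # always the same
```

In train mode every value is either 0.0 or 1.25, and the pattern is resampled on each call. In eval mode the input passes through untouched. Validation metrics computed in train mode inherit that randomness.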
torch.no_grad() toggles autograd bookkeeping. Inside the context, PyTorch skips building the computation graph: no grad_fn, no saved intermediate activations. You can’t call backward(), and you don’t pay memory for the ability to:
model.eval()
with torch.no_grad():
    out = model(x)
print(out.requires_grad, out.grad_fn) # False None
For a small MLP the savings are trivial; for a real network, the saved activations for backward dominate memory, which is exactly why our validation loader could afford batch_size=256. PyTorch 1.9 and later also offer torch.inference_mode(), a stricter, slightly faster variant for pure inference. It is fine to use in validation, but no_grad() is the habit that always works.
The pairing rule, worth memorizing as a 2×2:
| | graph built? | dropout/BN mode |
|---|---|---|
| training step | yes | train |
| validation step | no (no_grad) | eval |
Miss eval() → wrong numbers. Miss no_grad() → correct numbers, wasted memory, and validation loss tensors that keep graphs alive. Miss both → both problems. Always both.
The reusable train() function
Time to assemble everything into the function you’ll import for the rest of the course. Design goals: takes any model/loaders/loss/optimizer, validates every epoch, returns a history you can plot, and handles devices.
def accuracy(logits, targets):
    """Fraction of correct predictions. logits: [N, C], targets: [N]."""
    return (logits.argmax(dim=1) == targets).float().mean().item()
argmax(dim=1) collapses [N, C] logits to [N] predicted classes; the comparison yields a bool tensor; .float().mean() turns hit-rate into a fraction. No softmax needed: argmax of logits equals argmax of probabilities, because softmax is monotonic.
Now the loop itself, with each design decision annotated below:
def train(model, train_loader, val_loader, loss_fn, optimizer,
          epochs, device="cpu"):
    model.to(device)
    history = {"train_loss": [], "val_loss": [], "val_acc": []}
    for epoch in range(1, epochs + 1):
        # ---- training phase ----
        model.train()
        running = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()
            running += loss.item() * xb.size(0)  # sum, weighted by batch size
        train_loss = running / len(train_loader.dataset)
        # ---- validation phase ----
        model.eval()
        val_running, val_correct = 0.0, 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                logits = model(xb)
                val_running += loss_fn(logits, yb).item() * xb.size(0)
                val_correct += (logits.argmax(dim=1) == yb).sum().item()
        val_loss = val_running / len(val_loader.dataset)
        val_acc = val_correct / len(val_loader.dataset)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        print(f"epoch {epoch:3d} | train {train_loss:.4f} "
              f"| val {val_loss:.4f} | acc {val_acc:.3f}")
    return history
The details that separate this from a naive loop:
- `loss.item()`, not `loss`. `.item()` extracts a Python float and detaches from the graph. Accumulating raw loss tensors (`running += loss`) keeps every batch’s entire computation graph alive: a textbook PyTorch memory leak that grows all epoch.
- Weighted averaging. `loss_fn` returns the mean over the batch. If the last batch is smaller, averaging the per-batch means over-weights it. (1600 happens to divide evenly by 64, so every training batch has 64 samples here, but the 400-sample validation set at `batch_size=256` yields batches of 256 and 144.) Multiplying by `xb.size(0)` and dividing by the dataset size gives the exact per-sample mean regardless of batch sizes.
- Correct counts, not accuracy means. Same reasoning for accuracy: sum raw correct counts, divide once at the end.
- Device moves inside the loop. The model moves once; the data moves per batch, because the `DataLoader` yields CPU tensors (Day 3’s workers run on CPU). This pattern scales unchanged to GPU: just pass `device="cuda"`.
- Both switches in both phases. `model.train()` is set at the top of every epoch, not once before the loop, because the validation phase flipped it to `eval` at the end of the previous epoch. Forgetting this trains epochs 2+ with dropout off.
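The weighted-averaging point is worth one quick calculation. The numbers below are made up: two per-batch mean losses for a 400-sample set served in batches of 256 and 144, which is the split that batch_size=256 produces on the validation set.

```python
# Hypothetical per-batch mean losses; only the batch sizes come from the setup above.
batch_sizes = [256, 144]
batch_means = [0.10, 0.30]

# Naive: average the batch means, treating both batches as equally important.
naive = sum(batch_means) / len(batch_means)

# Weighted: multiply each mean by its batch size, divide by the total sample count.
weighted = sum(m * n for m, n in zip(batch_means, batch_sizes)) / sum(batch_sizes)

print(f"mean of batch means: {naive:.4f}")     # 0.2000
print(f"per-sample mean:     {weighted:.4f}")  # 0.1720
```

The naive figure overstates the loss here because the smaller batch, which happens to be the worse one, counts as much as the larger. The `loss.item() * xb.size(0)` pattern in train() computes the weighted version.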
Run it:
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),
                      nn.Dropout(0.2), nn.Linear(64, 2))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
history = train(model, train_loader, val_loader,
                nn.CrossEntropyLoss(), optimizer, epochs=30)
epoch 1 | train 0.6716 | val 0.6295 | acc 0.828
epoch 2 | train 0.5771 | val 0.5273 | acc 0.860
epoch 3 | train 0.4841 | val 0.4297 | acc 0.880
...
epoch 15 | train 0.1729 | val 0.1364 | acc 0.955
...
epoch 30 | train 0.1042 | val 0.0801 | acc 0.978
Two readings worth internalizing. First, the curve shape: fast early drops, then diminishing returns — healthy. Second, validation loss sits below training loss here, which surprises people: training loss is measured with dropout active (a handicapped model) while validation is measured with dropout off. That gap direction is normal for dropout-regularized models; it flips into the classic overfitting picture (val above train, then rising) when the model starts memorizing — which is exactly what your history dict lets you detect:
import matplotlib.pyplot as plt
plt.plot(history["train_loss"], label="train")
plt.plot(history["val_loss"], label="val")
plt.xlabel("epoch"); plt.ylabel("loss"); plt.legend(); plt.show()
Keep train() and accuracy() in a file, training.py, because Days 5 through 8 import them. Everything we add later (schedulers, gradient clipping, mixed precision) slots into this skeleton without changing its shape.
🧪 Your task
Real training runs don’t use a fixed epoch count — they stop when validation stops improving and keep the best weights, not the last ones. Extend train() with early stopping: add a patience parameter (default 5); track the best validation loss seen so far and snapshot the model’s weights whenever it improves; if patience consecutive epochs pass without improvement, stop, restore the best weights into the model, and print which epoch they came from. Test it by training your moons model with epochs=200, patience=5 — it should stop long before 200.
Hint: model.state_dict() returns a dict of parameter tensors, and model.load_state_dict(d) restores them — but the tensors in the dict are references to the live weights, so snapshot with copy.deepcopy(model.state_dict()) or the “best” weights will keep training along with the model.
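If the reference-versus-copy distinction feels abstract, this sketch shows it with ordinary Python lists standing in for parameter tensors. The dictionary layout is illustrative and is not a real state_dict.

```python
import copy

# Lists stand in for parameter tensors: both are mutable objects
# that a training step modifies in place.
state = {"weight": [1.0, 2.0], "bias": [0.5]}

by_reference = dict(state)       # new dict, same underlying lists
snapshot = copy.deepcopy(state)  # new dict AND new lists

# Simulate one more training step that updates the weights in place.
state["weight"][0] = 99.0

print(by_reference["weight"])  # [99.0, 2.0]: followed the live weights
print(snapshot["weight"])      # [1.0, 2.0]: frozen at snapshot time
```

Only the deep copy still holds the old values, which is the property a "best weights" snapshot needs.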
Solution
import copy
def train(model, train_loader, val_loader, loss_fn, optimizer,
          epochs, device="cpu", patience=5):
    model.to(device)
    history = {"train_loss": [], "val_loss": [], "val_acc": []}
    best_val = float("inf")
    best_state = None
    best_epoch = 0
    bad_epochs = 0
    for epoch in range(1, epochs + 1):
        # ---- training phase ----
        model.train()
        running = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()
            running += loss.item() * xb.size(0)
        train_loss = running / len(train_loader.dataset)
        # ---- validation phase ----
        model.eval()
        val_running, val_correct = 0.0, 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                logits = model(xb)
                val_running += loss_fn(logits, yb).item() * xb.size(0)
                val_correct += (logits.argmax(dim=1) == yb).sum().item()
        val_loss = val_running / len(val_loader.dataset)
        val_acc = val_correct / len(val_loader.dataset)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        print(f"epoch {epoch:3d} | train {train_loss:.4f} "
              f"| val {val_loss:.4f} | acc {val_acc:.3f}")
        # ---- early stopping bookkeeping ----
        if val_loss < best_val:
            best_val = val_loss
            best_epoch = epoch
            best_state = copy.deepcopy(model.state_dict())  # deepcopy is essential
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                print(f"early stop at epoch {epoch}; "
                      f"restoring best weights from epoch {best_epoch} "
                      f"(val {best_val:.4f})")
                break
    if best_state is not None:
        model.load_state_dict(best_state)
    return history

# --- test on the moons ---
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),
                      nn.Dropout(0.2), nn.Linear(64, 2))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
history = train(model, train_loader, val_loader, nn.CrossEntropyLoss(),
                optimizer, epochs=200, patience=5)
Typical output: validation loss bottoms out somewhere around epoch 60–90 on this problem, wobbles for five epochs, and the run halts with the best weights restored, using only a fraction of the 200-epoch budget. The deepcopy is the part most solutions get wrong: state_dict() gives you references, so without copying, “best_state” would silently track the current (later, possibly worse) weights, and the restore would be a no-op.
Key takeaways
- The loop is always the same five beats: `zero_grad → forward → loss → backward → step`. Gradients accumulate by design, so a forgotten `zero_grad()` derails training silently.
- `CrossEntropyLoss` eats raw logits and softmaxes internally; adding your own softmax double-normalizes, caps confidence, and starves gradients, with no error message. Same logic: `BCEWithLogitsLoss` over `sigmoid` + `BCELoss`.
- `MSELoss` for regression with matching shapes; `CrossEntropyLoss` takes `[N, C]` logits vs `[N]` int64 class indices.
- AdamW at `lr=1e-3` is the sane default; SGD usually wants a learning rate 10–100× larger. Never copy an lr across optimizer families.
- Validation needs both switches: `model.eval()` (fixes dropout/BatchNorm behavior) and `torch.no_grad()` (skips graph building, saves memory). They are independent and not interchangeable.
- Track losses with `.item()` (detach!), weight per-batch means by batch size, and return a history dict. A reusable `train()` in `training.py` now powers the rest of the course.
Tomorrow: classification end to end — real image data, from raw pixels to a confusion matrix, using today’s train() unchanged.