flowchart TD
F["model.fit(train_ds, validation_data=val_ds, epochs=N)"] --> E{"for each epoch"}
E --> B{"for each batch"}
B --> TS["train_step(batch)"]
TS --> FWD["forward pass<br/>y_pred = model(x, training=True)"]
FWD --> L["loss = compute_loss(y, y_pred)"]
L --> G["GradientTape → gradients"]
G --> O["optimizer applies updates<br/>θ ← θ − η∇θL"]
O --> M["metrics.update_state(y, y_pred)"]
M --> B
B -->|"epoch ends"| V["evaluate on validation_data<br/>(test_step per batch)"]
V --> CB["callbacks: on_epoch_end(logs)<br/>EarlyStopping · ModelCheckpoint<br/>ReduceLROnPlateau · TensorBoard"]
CB -->|"continue"| E
CB -->|"stop_training = True"| DONE["return History"]
Deep Learning with TensorFlow & Keras · Day 4: compile/fit, the Callback Zoo, and Rolling Your Own Training Loop
Day 4: compile/fit, the Callback Zoo, and Rolling Your Own Training Loop
Yesterday you built tf.data pipelines that stream batches to a model faster than the GPU can eat them. Today you learn what happens to those batches once they arrive and, more importantly, who is in charge of that process. Keras gives you a dial with three settings: at one end, model.compile() + model.fit() does everything for you; at the other end, you write the whole training loop yourself with tf.GradientTape (the same tape you met on Day 1); and in the middle sits Keras 3's train_step override, which lets you customize the math while keeping fit()'s machinery: callbacks, progress bars, distribution, all of it. If you come from PyTorch, here's the one-line orientation: the raw GradientTape loop is the PyTorch training loop you already write by hand, and fit() is what PyTorch users install Lightning to get. By the end of today you will have trained the same model all three ways and know exactly when to reach for each.
🎯 Today you will: wire losses/metrics/optimizers through compile(), train with fit() + validation_data and four essential callbacks, rewrite the identical training as a raw GradientTape loop, override train_step for the best-of-both middle path, and learn the decision rule for choosing between them.
Setup: the model and pipeline we'll train three times
We need one fixed setup so the three training styles are directly comparable: Fashion-MNIST, a tf.data pipeline exactly like Day 3's, and a small MLP built with the Functional API from Day 2.
import tensorflow as tf
import keras
keras.utils.set_random_seed(42) # reproducible across all three runs
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# Carve out a validation split BEFORE building datasets; never validate on test data.
x_val, y_val = x_train[-5000:], y_train[-5000:]
x_train, y_train = x_train[:-5000], y_train[:-5000]
BATCH = 128
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10_000)
    .batch(BATCH)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH).prefetch(tf.data.AUTOTUNE)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH).prefetch(tf.data.AUTOTUNE)

Two deliberate choices here. First, set_random_seed pins Python, NumPy, and backend RNGs at once, so weight initialization and shuffling match across our three experiments; without it, comparing loss curves is comparing noise. Second, the validation split is taken from the training data. fit() will happily accept validation_split=0.1 and do this for you, but only for array inputs: it cannot split a tf.data.Dataset (it has no idea how long the dataset is). Since Day 3 taught you to live in tf.data, you split manually and pass validation_data=val_ds.
def build_model() -> keras.Model:
    inputs = keras.Input(shape=(28, 28))
    x = keras.layers.Flatten()(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dense(128, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)  # NOTE: no softmax, raw logits
    return keras.Model(inputs, outputs, name="fashion_mlp")

model = build_model()

The output layer has no activation. It emits logits: raw, unnormalized scores of shape (batch, 10). This is on purpose, and it sets up the single most common Keras bug, which we defuse in the next section.
Anatomy of compile(): loss, metrics, optimizer
compile() doesn't train anything. It attaches three things to the model: how wrong it is (loss), how to fix it (optimizer), and what to report (metrics).
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)

You could write optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]; the strings are aliases for default-configured objects. Prefer the object form the moment you care about any hyperparameter, because the string form hides two decisions:
The learning rate. "adam" means Adam(learning_rate=0.001). Fine here, but you'll want control on Day 7, and passing the object is how you get it.
from_logits. This is the trap. The categorical cross-entropy for one example with true class \(y\) is
\[ \mathcal{L} = -\log p_y, \qquad p = \mathrm{softmax}(z) \]
where \(z\) are the logits. Someone has to apply that softmax. Either your last layer does it (activation="softmax") and the loss receives probabilities, or the loss does it internally (from_logits=True) and receives logits. Pick exactly one. The string "sparse_categorical_crossentropy" defaults to from_logits=False. If your model outputs logits and you use the string, Keras treats raw scores like probabilities: loss values go weird, training limps along at mediocre accuracy, and nothing crashes, which is the worst failure mode in deep learning. Keeping logits in the model and from_logits=True in the loss is the numerically stabler choice (the log-softmax is fused inside the loss), so that's our house style.
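To make the "numerically stabler" point concrete, here is a minimal pure-Python sketch of the loss above computed directly from logits. It does not use Keras; the helper names log_softmax and cross_entropy_from_logits and the example scores are made up for illustration. It only shows that the fused form agrees with the explicit softmax on ordinary inputs and stays finite where the explicit route overflows.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax via the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]

def cross_entropy_from_logits(logits, true_class):
    """L = -log p_y with p = softmax(z), computed without ever forming p."""
    return -log_softmax(logits)[true_class]

logits = [2.0, 0.5, -1.0]   # made-up scores for a 3-class example
true_class = 0

fused = cross_entropy_from_logits(logits, true_class)

# The same quantity the long way: explicit softmax, then -log p_y.
exps = [math.exp(z) for z in logits]
probs = [e / sum(exps) for e in exps]
explicit = -math.log(probs[true_class])

# Extreme logits: the fused form is fine, the explicit softmax overflows.
big = [1000.0, 0.0, -1000.0]
fused_big = cross_entropy_from_logits(big, 0)
try:
    math.exp(big[0])
    explicit_overflows = False
except OverflowError:
    explicit_overflows = True
```

The two routes agree on well-behaved logits; the difference only shows up at the extremes, which is exactly where a half-trained network tends to wander.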
Sparse vs non-sparse. Our labels are integer class indices, shape (batch,): that's the Sparse... family. If you one-hot encode to (batch, 10), use CategoricalCrossentropy instead. Mismatching them raises a shape error, which at least fails loudly.
Metrics vs loss. The loss is differentiated; metrics are only watched. Metrics are stateful objects: SparseCategoricalAccuracy keeps running total and count variables, update_state(y, y_pred) folds each batch in, result() reads the epoch-so-far value, and reset_state() zeroes it between epochs. fit() manages this lifecycle for you. In the custom loop, you will, and forgetting reset_state() is the classic bug that makes epoch 2's accuracy mysteriously blend with epoch 1's.
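The lifecycle is easier to see with a toy stand-in. The class below is a hypothetical pure-Python sketch, not the Keras implementation; it only mimics the total/count bookkeeping and the three method names described above, and the labels and predictions are made up.

```python
class RunningAccuracy:
    """Toy stand-in for a stateful accuracy metric (not the Keras class)."""

    def __init__(self):
        self.total = 0.0   # correct predictions seen so far
        self.count = 0.0   # examples seen so far

    def update_state(self, y_true, y_pred):
        # Fold one batch of integer labels and predicted classes into the sums.
        self.total += sum(1 for t, p in zip(y_true, y_pred) if t == p)
        self.count += len(y_true)

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0.0
        self.count = 0.0


epoch1 = ([0, 1, 2, 3], [0, 1, 0, 0])   # 2 of 4 correct
epoch2 = ([0, 1, 2, 3], [0, 1, 2, 3])   # 4 of 4 correct

# Correct lifecycle: reset between epochs.
acc = RunningAccuracy()
acc.update_state(*epoch1)
epoch1_acc = acc.result()     # 0.5
acc.reset_state()
acc.update_state(*epoch2)
epoch2_acc = acc.result()     # 1.0

# The classic bug: no reset, so epoch 2 blends with epoch 1.
buggy = RunningAccuracy()
buggy.update_state(*epoch1)
buggy.update_state(*epoch2)
blended_acc = buggy.result()  # 6 of 8 = 0.75
```

The blended value is not wrong arithmetic, it is just answering a different question (accuracy since the start of training), which is why the bug is easy to miss.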
fit() and the callback zoo
Here is what fit() actually runs (the flowchart at the top of this page). Keep this picture in your head; everything else today is about which parts of it you take over.
Everything inside train_step is the math; everything around it is bookkeeping. Callbacks are objects that get poked at defined moments (on_train_begin, on_epoch_end, on_batch_end, …) with a logs dict containing the current metric values. Four of them cover 95% of real training runs:
callbacks = [
    # 1. Stop when val_loss stops improving; roll back to the best epoch's weights.
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=3,                 # tolerate 3 stagnant epochs before stopping
        restore_best_weights=True,  # otherwise you keep the LAST (worse) weights!
        verbose=1,                  # print the "early stopping" message
    ),
    # 2. Continuously save the best model seen so far to disk.
    keras.callbacks.ModelCheckpoint(
        filepath="checkpoints/best.keras",
        monitor="val_accuracy",
        mode="max",                 # "accuracy improving" means going UP
        save_best_only=True,
    ),
    # 3. When val_loss plateaus, cut the learning rate; this often buys a late improvement.
    keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,                 # new_lr = lr * 0.5
        patience=2,                 # trigger faster than EarlyStopping does
        min_lr=1e-5,
        verbose=1,                  # print a message whenever the LR is cut
    ),
    # 4. Log everything for the TensorBoard UI.
    keras.callbacks.TensorBoard(log_dir="logs/fit_run", histogram_freq=1),
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=30,                      # an upper bound; EarlyStopping decides the real number
    callbacks=callbacks,
)

Notes on the choreography, because the defaults bite:
- patience interplay. ReduceLROnPlateau has patience=2, EarlyStopping has patience=3. That ordering is deliberate: the LR gets cut before early stopping gives up, so the model gets one cheap rescue attempt. Set them the other way around and the LR schedule never fires.
- restore_best_weights=True is not the default. Without it, EarlyStopping halts training but leaves the model holding the weights from the final (by definition disappointing) epoch.
- mode. Checkpointing monitors val_accuracy, where bigger is better, hence mode="max". Keras infers the mode for common names, but being explicit costs nothing and reads clearly.
- .keras filepath. Keras 3's native saving format (whole model: architecture + weights + optimizer state). More on this on Day 9.
- epochs=30 is a budget, not a plan. With EarlyStopping in place, you set epochs generously and let validation loss decide.
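The patience interplay can be checked on paper. The function below is a simplified pure-Python sketch of the two counters, not the Keras callbacks themselves: simulate_callbacks is a hypothetical name, the val_loss history is made up, and the sketch assumes no min_delta, no cooldown, and a single shared notion of "best loss so far".

```python
def simulate_callbacks(val_losses, lr=1e-3, lr_patience=2, stop_patience=3,
                       factor=0.5, min_lr=1e-5):
    """Replay a val_loss history through simplified plateau / early-stop logic."""
    events = []
    best, best_epoch = float("inf"), 0
    lr_wait = stop_wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:                      # improvement: reset both counters
            best, best_epoch = loss, epoch
            lr_wait = stop_wait = 0
            continue
        lr_wait += 1
        stop_wait += 1
        if lr_wait >= lr_patience and lr > min_lr:
            lr = max(lr * factor, min_lr)    # the cheap rescue attempt
            lr_wait = 0
            events.append(("reduce_lr", epoch, lr))
        if stop_wait >= stop_patience:
            events.append(("stop", epoch, best_epoch))
            break
    return events


history = [0.40, 0.35, 0.33, 0.34, 0.335, 0.34, 0.34]  # made-up plateau after epoch 3

as_configured = simulate_callbacks(history, lr_patience=2, stop_patience=3)
swapped = simulate_callbacks(history, lr_patience=3, stop_patience=2)
```

With the patiences as configured, the LR is halved at epoch 5 and training stops at epoch 6 with epoch 3 as the best. With the patiences swapped, the only event is the stop at epoch 5: the LR cut never gets its turn.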
Typical output (note the LR cut at epoch 9 and the early stop at 12):
Epoch 1/30
430/430 ━━━━━━━━━━ 3s 5ms/step - accuracy: 0.8225 - loss: 0.5081 - val_accuracy: 0.8618 - val_loss: 0.3800
Epoch 2/30
430/430 ━━━━━━━━━━ 2s 4ms/step - accuracy: 0.8710 - loss: 0.3524 - val_accuracy: 0.8720 - val_loss: 0.3492
...
Epoch 9/30
430/430 ━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9124 - loss: 0.2351 - val_accuracy: 0.8896 - val_loss: 0.3241
Epoch 9: ReduceLROnPlateau reducing learning rate to 0.0005.
...
Epoch 12/30
430/430 ━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9301 - loss: 0.1882 - val_accuracy: 0.8912 - val_loss: 0.3305
Epoch 12: early stopping. Restoring model weights from the end of the best epoch: 9.
Launch TensorBoard from a terminal with tensorboard --logdir logs and open http://localhost:6006. You get live loss/metric curves, the LR schedule, and (thanks to histogram_freq=1) weight histograms per layer per epoch. The returned history.history dict holds the same per-epoch numbers ({"loss": [...], "val_loss": [...], ...}) if you'd rather matplotlib them.
Finally, honest numbers come from data the callbacks never saw:
test_loss, test_acc = model.evaluate(test_ds, verbose=0)
print(f"test accuracy: {test_acc:.4f}")  # ≈ 0.884

The validation set steered EarlyStopping, checkpointing, and the LR schedule, so it's been "used up" as an unbiased estimate. The test set gives the number you report.
The same training as a raw GradientTape loop
Now we throw all of that away and write it ourselves. This is Day 1's GradientTape grown up: instead of differentiating \(y = x^2\), we differentiate a full model's loss with respect to every weight, and we do it once per batch, thousands of times. Structurally, what follows is a PyTorch training loop with the names changed: tape.gradient where PyTorch has loss.backward(), optimizer.apply_gradients where PyTorch has optimizer.step(), and no zero_grad() at all because the tape is rebuilt fresh each step, so gradients can't accumulate by accident.
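Before the Keras version, the three-phase rhythm of every step (forward pass, gradient, update) can be sketched on a one-weight toy problem in plain Python. Everything here is illustrative: the data is made up, the gradient is derived by hand where the real loop would use tape.gradient, and the update is plain SGD rather than Adam.

```python
# Toy problem: recover w = 3 in y = w * x from four made-up points.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0 * x for x in xs]

w = 0.0     # the single "trainable variable"
lr = 0.05   # learning rate (eta)

def mse(w):
    """Forward pass and loss in one: mean squared error of the prediction w * x."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def mse_grad(w):
    """d/dw of the loss, mean(2 * (w*x - y) * x), derived by hand."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

losses = []
for step in range(20):
    loss = mse(w)        # 1. forward pass + loss
    g = mse_grad(w)      # 2. gradient (the tape's job in the real loop)
    w = w - lr * g       # 3. update (the optimizer's job in the real loop)
    losses.append(loss)
```

The real loop has the same skeleton; the only differences are that the gradient is computed automatically for every weight and the optimizer owns the update rule.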
Stage 1: the pieces compile() used to hold, now as plain variables:
model = build_model() # fresh weights, same seed lineage
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc = keras.metrics.SparseCategoricalAccuracy()
val_acc = keras.metrics.SparseCategoricalAccuracy()
train_loss = keras.metrics.Mean()  # running average of per-batch losses

keras.metrics.Mean deserves a word: the per-batch loss jumps around, so what fit() prints is a running mean over the epoch. We replicate that with a Mean metric fed one scalar per step.
Stage 2: one training step, compiled to a graph:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)   # (128, 10) float32
        loss = loss_fn(y, logits)          # () scalar
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss.update_state(loss)
    train_acc.update_state(y, logits)
    return loss

@tf.function
def val_step(x, y):
    logits = model(x, training=False)
    val_acc.update_state(y, logits)

Line by line, because every line is load-bearing:
- with tf.GradientTape() as tape: the tape records every op on watched variables inside this block only. The forward pass and the loss must both happen inside; move loss = loss_fn(...) outside the with and tape.gradient returns Nones (the tape never saw the loss being computed), and apply_gradients throws.
- model(x, training=True): the training flag switches layer behavior (Dropout drops, BatchNorm uses batch statistics). Our MLP has neither, but wire it correctly now; on Day 6 and 7 this flag is the difference between training and silently broken training. Note it's model(x, ...), the direct call, not model.predict(), which is a batched-inference convenience wrapper and has no business inside a training loop.
- tape.gradient(loss, model.trainable_variables): returns a list of tensors, one per weight, each shaped exactly like its variable: for our MLP that's [(784,256), (256,), (256,128), (128,), (128,10), (10,)]. Use trainable_variables, not variables: BatchNorm's moving averages, for instance, are variables that must not receive gradient updates.
- optimizer.apply_gradients(zip(grads, vars)): the update \(\theta \leftarrow \theta - \eta \, \hat{g}\) (with Adam's per-parameter scaling folded into \(\hat g\)). The optimizer pairs each gradient with its variable via the zip; the two lists are in matching order because both came from the same trainable_variables call.
- @tf.function: traces the Python into a TensorFlow graph on first call, then replays the compiled graph. On this small model it's a 2–3× speedup. Two rules from Day 1 still apply: (a) don't put Python side effects (print, list appends) inside and expect them per-step, because they run only during tracing; (b) keep input shapes consistent, because each new shape triggers an expensive retrace. Our training set has 55,000 examples and 55,000 % 128 = 88, so the final batch is a partial batch, which does cost one extra trace. That is harmless here, but drop_remainder=True in .batch() is the fix when it isn't.
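The partial-batch remark is plain arithmetic on the setup above (60,000 Fashion-MNIST training images minus the 5,000 held out for validation, batch size 128). A small sketch, with variable names chosen here for illustration:

```python
import math

NUM_TRAIN = 60_000 - 5_000   # training images left after the validation split
BATCH = 128

steps_per_epoch = math.ceil(NUM_TRAIN / BATCH)   # batches per epoch when the remainder is kept
last_batch_size = NUM_TRAIN % BATCH or BATCH     # size of the final, possibly partial, batch
steps_if_dropped = NUM_TRAIN // BATCH            # batches per epoch with drop_remainder=True

# Distinct batch sizes train_step sees, i.e. how many input shapes get traced.
distinct_batch_sizes = {BATCH, last_batch_size}
```

This gives 430 steps per epoch (the 430/430 in the fit() progress bar), a final batch of 88 examples, and therefore two input shapes for the traced function; dropping the remainder leaves 429 full batches and a single shape.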
Stage 3: the outer loop, i.e. everything fit() was doing around the math:
import os

EPOCHS = 10
best_val = 0.0
os.makedirs("checkpoints", exist_ok=True)  # make sure the target directory exists before saving

for epoch in range(EPOCHS):
    # Metrics are stateful: zero them at the start of every epoch.
    train_loss.reset_state(); train_acc.reset_state(); val_acc.reset_state()
    for x, y in train_ds:
        train_step(x, y)
    for x, y in val_ds:
        val_step(x, y)
    va = float(val_acc.result())
    print(f"epoch {epoch+1:2d} | loss {float(train_loss.result()):.4f} "
          f"| acc {float(train_acc.result()):.4f} | val_acc {va:.4f}")
    if va > best_val:  # a hand-rolled ModelCheckpoint
        best_val = va
        model.save_weights("checkpoints/manual_best.weights.h5")

epoch  1 | loss 0.5062 | acc 0.8231 | val_acc 0.8606
epoch  2 | loss 0.3531 | acc 0.8708 | val_acc 0.8724
...
epoch 10 | loss 0.2210 | acc 0.9178 | val_acc 0.8874
Same numbers as fit(), within noise, as they must be, since it's the same math. Now tally what you paid for the control: the metric reset_state() calls, the validation loop, the progress printing, and best-model tracking are all yours to write and yours to get wrong. And we didn't implement early stopping, LR-on-plateau, or TensorBoard logging; each is another 5–15 lines. This is why the raw loop is the tool of last resort, not the default: you write it when the structure of training itself is nonstandard (GANs alternating two optimizers, reinforcement learning, gradient accumulation across micro-batches, custom multi-loss balancing). For everything else, there's a better deal.
The middle path: overriding train_step
Here's the Keras 3 move that most tutorials undersell: subclass keras.Model, override only train_step, and keep everything else. You rewrite the ~8 lines of math you actually care about; fit() still supplies the epoch loop, tf.function compilation, callbacks, progress bar, validation, and (later) multi-GPU distribution, none of which you re-implement.
class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data  # one batch, as fit() unpacked it
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)  # the loss given to compile()
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # Update the metrics configured in compile()
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

The contract, piece by piece:
- data is exactly one element of whatever you passed to fit(): for our train_ds, an (x, y) tuple. (Pass sample_weight in your dataset and it arrives as a 3-tuple; unpack accordingly.)
- self.compute_loss(y=y, y_pred=y_pred) is Keras 3's bridge to whatever loss you gave compile(), including any regularization losses layers have registered. Using it (rather than a hardcoded loss_fn) keeps your subclass reusable with any compiled loss.
- The metrics loop feeds the same metric objects compile() created; the built-in loss tracker takes the scalar, the rest take (y, y_pred).
- The return dict is what the progress bar prints and, crucially, what callbacks receive as logs. Return {"loss": ..., "accuracy": ...} and EarlyStopping(monitor="val_loss") works on your custom step unchanged, because fit() runs your metrics against validation_data too (via the default test_step; override that as well if your evaluation math is also custom).
- No @tf.function needed: fit() wraps your train_step in one automatically.
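To see why the returned dict matters, here is a deliberately tiny pure-Python model of the machinery around train_step. It is a hypothetical sketch, not Keras's implementation: mini_fit and StopOnRise are made-up names, the steps return canned numbers instead of training anything, and callbacks signal a stop through a return value where Keras uses on_epoch_end and model.stop_training.

```python
def mini_fit(train_step, test_step, train_batches, val_batches, callbacks, epochs):
    """Heavily simplified skeleton of what fit() does around train_step."""
    history = []
    for epoch in range(epochs):
        logs = {}
        for batch in train_batches:
            logs = train_step(batch)        # the dict your override returns
        val_logs = {}
        for batch in val_batches:
            val_logs = test_step(batch)
        # Validation results are merged in under a "val_" prefix.
        logs = {**logs, **{"val_" + k: v for k, v in val_logs.items()}}
        history.append(logs)
        # Every callback sees the merged dict and may ask for training to stop.
        stop_flags = [cb(epoch, logs) for cb in callbacks]
        if any(stop_flags):
            break
    return history


class StopOnRise:
    """Toy callback: request a stop as soon as val_loss gets worse."""

    def __init__(self):
        self.best = float("inf")

    def __call__(self, epoch, logs):
        if logs["val_loss"] < self.best:
            self.best = logs["val_loss"]
            return False
        return True


# Stand-in steps that report canned numbers (no model, no gradients).
fake_val_losses = iter([0.40, 0.35, 0.36, 0.37])
history = mini_fit(
    train_step=lambda batch: {"loss": 0.5, "accuracy": 0.8},
    test_step=lambda batch: {"loss": next(fake_val_losses), "accuracy": 0.85},
    train_batches=[None, None],   # two dummy training batches per epoch
    val_batches=[None],           # one dummy validation batch per epoch
    callbacks=[StopOnRise()],
    epochs=10,
)
```

The callback never touches the model; it only reads keys out of logs. That is why a custom train_step that returns the usual names keeps every stock callback working.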
Now use it exactly like any Keras model. Note that building it requires the functional-style call, or subclass with layers as on Day 2:
inputs = keras.Input(shape=(28, 28))
x = keras.layers.Flatten()(inputs)
x = keras.layers.Dense(256, activation="relu")(x)
x = keras.layers.Dense(128, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)
model = CustomModel(inputs, outputs)
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)
model.fit(train_ds, validation_data=val_ds, epochs=10, callbacks=callbacks)  # everything still works

The demo above reproduces the default behavior, deliberately, so you can verify the plumbing (it should match the earlier runs). The payoff comes when you change the math: clip gradients before applying them (grads = [tf.clip_by_norm(g, 1.0) for g in grads]), add a custom auxiliary loss, implement mixup inside the step, log gradient norms as a metric. Each of those is a 1–3 line edit here, versus owning a whole loop.
One caveat worth knowing: this backend-specific train_step (it uses tf.GradientTape) ties the model to the TensorFlow backend. Keras 3 also runs on JAX and PyTorch backends, each with its own override idiom. That is fine for this course, where TF is the point, but don't be surprised when a JAX tutorial's train_step looks different.
Choosing your altitude
| | compile + fit | train_step override | raw GradientTape loop |
|---|---|---|---|
| Epoch/batch loop | Keras | Keras | you |
| Forward/backward/update | Keras | you | you |
| Callbacks, progress, validation | Keras | Keras | you |
| Metric lifecycle (reset_state) | Keras | Keras | you |
| tf.function compilation | automatic | automatic | you add it |
| Typical use | almost everything | clipping, mixup, aux losses, custom logging | GANs, RL, exotic loop structure |
The decision rule: start at level 1. Drop to level 2 the moment you need to touch the per-batch math. Drop to level 3 only when the shape of the loop itself is nonstandard. Moving down is always possible; the mistake is starting at the bottom "for flexibility" and then hand-maintaining early stopping forever.
🧪 Your task
Implement gradient clipping with gradient-norm logging via a train_step override. Concretely: subclass keras.Model so that (a) gradients are clipped by global norm to 1.0 before being applied (tf.clip_by_global_norm), and (b) the pre-clip global gradient norm appears in the progress bar and history as a metric named grad_norm. Train it on train_ds for 5 epochs with the EarlyStopping callback attached, and confirm grad_norm shows up in the logs and decreases over training.
Hint: tf.clip_by_global_norm(grads, 1.0) returns a tuple (clipped_grads, global_norm), so it hands you both things you need in one call. Track the norm with a keras.metrics.Mean created in __init__, and remember that any metric you want fit() to reset each epoch must be returned by a metrics property (or simply included in your returned dict and tracked manually).
Solution
import tensorflow as tf
import keras
class ClippedModel(keras.Model):
    def __init__(self, *args, clip_norm=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.clip_norm = clip_norm
        self.grad_norm_tracker = keras.metrics.Mean(name="grad_norm")

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        # one call gives clipped grads AND the pre-clip global norm
        clipped, global_norm = tf.clip_by_global_norm(grads, self.clip_norm)
        self.optimizer.apply_gradients(zip(clipped, self.trainable_variables))
        self.grad_norm_tracker.update_state(global_norm)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            elif metric.name != "grad_norm":  # ours is fed above
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    @property
    def metrics(self):
        # include our tracker so fit() resets it every epoch
        return super().metrics + [self.grad_norm_tracker]

inputs = keras.Input(shape=(28, 28))
x = keras.layers.Flatten()(inputs)
x = keras.layers.Dense(256, activation="relu")(x)
x = keras.layers.Dense(128, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)
model = ClippedModel(inputs, outputs, clip_norm=1.0)
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5,
    callbacks=[keras.callbacks.EarlyStopping(monitor="val_loss", patience=3,
                                             restore_best_weights=True)],
)
print(history.history["grad_norm"])
# e.g. [1.9137, 0.9821, 0.7414, 0.6350, 0.5561]: large early, shrinking as training settles

Expected behavior: epoch 1's mean grad_norm exceeds 1.0 (clipping is actually firing early on), then it drops below the threshold as the loss surface flattens, at which point the clip becomes a no-op safety net, which is exactly what you want. Callbacks, validation, and the progress bar all worked untouched: that's the middle path earning its keep.
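The "no-op safety net" behavior follows from the clipping formula itself: every gradient is scaled by clip_norm / max(global_norm, clip_norm). The sketch below re-implements that formula in pure Python on nested lists to make it visible. The function name clip_by_global_norm_py and the gradient values are made up for illustration; this is not the TensorFlow implementation.

```python
import math

def clip_by_global_norm_py(grads, clip_norm):
    """Global-norm clipping over a list of flat gradient lists.

    Returns (clipped_grads, global_norm), where global_norm is the PRE-clip norm.
    """
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # The scale is exactly 1.0 when the norm is already within bounds.
    scale = clip_norm / max(global_norm, clip_norm)
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, global_norm


early_grads = [[3.0, 0.0], [0.0, 4.0]]   # made-up values, global norm 5.0
late_grads = [[0.3, 0.0], [0.0, 0.4]]    # made-up values, global norm 0.5

clipped_early, norm_early = clip_by_global_norm_py(early_grads, 1.0)
clipped_late, norm_late = clip_by_global_norm_py(late_grads, 1.0)

# Norm of the early gradients after clipping: rescaled down to the threshold.
post_clip_norm = math.sqrt(sum(g * g for grad in clipped_early for g in grad))
```

The early gradients come back rescaled to norm 1.0 with their direction unchanged, while the late gradients pass through untouched. In both cases the returned norm is the pre-clip value, which is what the grad_norm metric logs.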
Key takeaways
- compile() attaches loss, optimizer, and metrics; prefer object form over strings, and match from_logits=True to a no-softmax output layer, because the mismatch trains badly without crashing.
- fit(validation_data=...) + four callbacks (EarlyStopping with restore_best_weights=True, ModelCheckpoint, ReduceLROnPlateau with shorter patience than EarlyStopping, TensorBoard) is the production-default training recipe.
- The raw loop is GradientTape around the forward pass, tape.gradient on trainable_variables, optimizer.apply_gradients, and @tf.function for speed; you also inherit all the bookkeeping fit() was doing, including metric reset_state().
- Keras 3's train_step override is the middle path: rewrite only the per-batch math, keep callbacks/validation/distribution for free; self.compute_loss and the metrics loop are the contract.
- Decision rule: fit() by default → train_step when the step's math changes → full custom loop only when the loop's structure changes.
Tomorrow we put Days 1–4 together into a complete classification project: real dataset, preprocessing to prediction, with all the evaluation you should never skip.