Kader Mohideen

  • Day 4: compile/fit, the Callback Zoo, and Rolling Your Own Training Loop
    • Setup: the model and pipeline we’ll train three times
    • Anatomy of compile(): loss, metrics, optimizer
    • fit() and the callback zoo
    • The same training as a raw GradientTape loop
    • The middle path: overriding train_step
    • Choosing your altitude
  • 🧪 Your task
    • Key takeaways

📊 Deep Learning with TensorFlow & Keras · Day 4: compile/fit, the Callback Zoo, and Rolling Your Own Training Loop

🏠 📊 Course home  |  ← Day 03  |  Day 05 →  |  📚 All mini-courses


Day 4: compile/fit, the Callback Zoo, and Rolling Your Own Training Loop

Yesterday you built tf.data pipelines that stream batches to a model faster than the GPU can eat them. Today you learn what happens to those batches once they arrive and, more importantly, who is in charge of that process. Keras gives you a dial with three settings. At one end, model.compile() + model.fit() does everything for you. At the other end, you write the whole training loop yourself with tf.GradientTape (the same tape you met on Day 1). In the middle sits Keras 3’s train_step override, which lets you customize the math while keeping fit()’s machinery: callbacks, progress bars, distribution, all of it. If you come from PyTorch, here’s the one-line orientation: the raw GradientTape loop is the PyTorch training loop you already write by hand, and fit() is what PyTorch users install Lightning to get. By the end of today you will have trained the same model all three ways and know exactly when to reach for each.

🎯 Today you will: wire losses/metrics/optimizers through compile(), train with fit() + validation_data and four essential callbacks, rewrite the identical training as a raw GradientTape loop, override train_step for the best-of-both middle path, and learn the decision rule for choosing between them

Setup: the model and pipeline we’ll train three times

We need one fixed setup so the three training styles are directly comparable. We’ll use Fashion-MNIST, a tf.data pipeline exactly like Day 3’s, and a small MLP built with the Functional API from Day 2.

import tensorflow as tf
import keras

keras.utils.set_random_seed(42)   # reproducible across all three runs

(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Carve out a validation split BEFORE building datasets; never validate on test data.
x_val, y_val = x_train[-5000:], y_train[-5000:]
x_train, y_train = x_train[:-5000], y_train[:-5000]

BATCH = 128
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10_000)
    .batch(BATCH)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH).prefetch(tf.data.AUTOTUNE)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH).prefetch(tf.data.AUTOTUNE)

Two deliberate choices here. First, set_random_seed pins the Python, NumPy, and backend RNGs in one call, so weight initialization and shuffling are reproducible from run to run. Without it, comparing loss curves is comparing noise. Each model you build afterwards still draws from the advancing RNG state, so re-seed right before a build_model() call when you want two runs to start from the same initial weights. Second, the validation split is taken from the training data. fit() will happily accept validation_split=0.1 and do this for you, but only for array inputs. It cannot split a tf.data.Dataset, because it has no idea how long the dataset is. Since Day 3 taught you to live in tf.data, you split manually and pass validation_data=val_ds.
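The manual split above takes the last 5,000 examples. That is also where, per the Keras fit() documentation, validation_split draws from for array inputs: the last samples, before any shuffling. Taking the tail is only safe if the data isn't ordered. Below is a stdlib-only sketch of the failure mode and the usual fix. It uses a hypothetical holdout_tail helper and toy class-sorted labels, and the fix is to shuffle the (x, y) pairs together with a fixed seed before splitting.

```python
import random

def holdout_tail(xs, ys, n_val):
    """Split off the last n_val samples as validation data, like x_train[-5000:] above."""
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")
    if not 0 < n_val < len(xs):
        raise ValueError("n_val must be between 1 and len(xs) - 1")
    cut = len(xs) - n_val
    return (xs[:cut], ys[:cut]), (xs[cut:], ys[cut:])

# Toy data sorted by class: the tail holds nothing but class 2.
labels = [0] * 4 + [1] * 4 + [2] * 4
features = list(range(len(labels)))   # stand-ins for images; features[i] pairs with labels[i]
_, (val_x, val_y) = holdout_tail(features, labels, 4)
print(val_y)   # [2, 2, 2, 2]

# Shuffling the (x, y) PAIRS with a fixed seed before splitting keeps them aligned.
pairs = list(zip(features, labels))
random.Random(42).shuffle(pairs)
xs = [x for x, _ in pairs]
ys = [y for _, y in pairs]
(train_x, train_y), (val_x, val_y) = holdout_tail(xs, ys, 4)
print(sorted(set(val_y)))   # the class mix depends on the seed
```

Shuffling pairs rather than shuffling xs and ys separately is the whole point: every image keeps its own label.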

def build_model() -> keras.Model:
    inputs = keras.Input(shape=(28, 28))
    x = keras.layers.Flatten()(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dense(128, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)          # NOTE: no softmax, just raw logits
    return keras.Model(inputs, outputs, name="fashion_mlp")

model = build_model()

The output layer has no activation. It emits logits β€” raw, unnormalized scores of shape (batch, 10). This is on purpose, and it sets up the single most common Keras bug, which we defuse in the next section.

Anatomy of compile(): loss, metrics, optimizer

compile() doesn’t train anything. It attaches three things to the model: how wrong it is (loss), how to fix it (optimizer), and what to report (metrics).

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)

You could write optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] β€” the strings are aliases for default-configured objects. Prefer the object form the moment you care about any hyperparameter, because the string form hides two decisions:

The learning rate. "adam" means Adam(learning_rate=0.001). Fine here, but you’ll want control on Day 7, and passing the object is how you get it.

from_logits. This is the trap. The categorical cross-entropy for one example with true class \(y\) is

\[ \mathcal{L} = -\log p_y, \qquad p = \mathrm{softmax}(z) \]

where \(z\) are the logits. Someone has to apply that softmax. Either your last layer does it (activation="softmax") and the loss receives probabilities, or the loss does it internally (from_logits=True) and receives logits. Pick exactly one. The string "sparse_categorical_crossentropy" defaults to from_logits=False. If your model outputs logits and you use the string, Keras treats raw scores like probabilities. Loss values go weird, training limps along at mediocre accuracy, and nothing crashes, which is the worst failure mode in deep learning. Keeping logits in the model and from_logits=True in the loss is the numerically stabler choice (the log-softmax is fused inside the loss), so that’s our house style.
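To see why you pick exactly one, here is a stdlib-only sketch of the arithmetic. These are our own helper functions, not Keras internals. xent_from_logits computes \(-\log \mathrm{softmax}(z)_y\) using the max-subtraction trick that keeps a fused log-softmax stable. Feeding it probabilities instead of logits applies softmax twice, which is what a softmax output layer plus from_logits=True amounts to. With 10 classes, the loss can then never fall below \(\log(1 + 9/e) \approx 1.46\), however confident and correct the model is.

```python
import math

def log_softmax(z):
    """Numerically stable log-softmax: subtract the max logit before exponentiating."""
    m = max(z)
    log_sum = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - log_sum for v in z]

def softmax(z):
    return [math.exp(v) for v in log_softmax(z)]

def xent_from_logits(z, y):
    """Cross-entropy -log softmax(z)[y], i.e. what from_logits=True asks the loss to do."""
    return -log_softmax(z)[y]

# A confident, correct prediction over 10 classes (true class 3).
logits = [0.0] * 10
logits[3] = 12.0

correct = xent_from_logits(logits, 3)           # softmax applied once
double = xent_from_logits(softmax(logits), 3)   # softmax layer AND from_logits=True
floor = math.log(1 + 9 / math.e)                # best possible loss once softmax runs twice

print(f"once: {correct:.6f}  twice: {double:.4f}  floor: {floor:.4f}")

# The max-subtraction keeps huge logits finite; math.exp(1000) alone raises OverflowError.
big = xent_from_logits([1000.0, 0.0], 1)
print(big)   # 1000.0
```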

Sparse vs non-sparse. Our labels are integer class indices of shape (batch,), which is what SparseCategoricalCrossentropy expects. If you one-hot encode them to (batch, 10), use CategoricalCrossentropy instead. Mismatching the two raises a shape error, which at least fails loudly.
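The two losses compute the same number and differ only in the target layout. The stdlib sketch below uses hypothetical helper names, not the Keras classes themselves. An integer label indexes straight into the log-softmax, while a one-hot target selects the same entry through a dot product.

```python
import math

def log_softmax(z):
    m = max(z)
    log_sum = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - log_sum for v in z]

def sparse_xent(logits, label):
    """Integer label per example: the SparseCategoricalCrossentropy target layout."""
    return -log_softmax(logits)[label]

def categorical_xent(logits, one_hot):
    """One-hot target of length num_classes: the CategoricalCrossentropy target layout."""
    return -sum(t * lp for t, lp in zip(one_hot, log_softmax(logits)))

def to_one_hot(label, num_classes):
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

logits = [2.0, -1.0, 0.5, 3.0]
label = 3
a = sparse_xent(logits, label)
b = categorical_xent(logits, to_one_hot(label, 4))
print(math.isclose(a, b))   # True
```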

Metrics vs loss. The loss is differentiated; metrics are only watched. Metrics are stateful objects. SparseCategoricalAccuracy keeps running total and count variables: update_state(y, y_pred) folds each batch in, result() reads the epoch-so-far value, and reset_state() zeroes it between epochs. fit() manages this lifecycle for you. In the custom loop you manage it yourself, and forgetting reset_state() is the classic bug that makes epoch 2’s accuracy mysteriously blend with epoch 1’s.
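That lifecycle is easy to get wrong, so here is a toy, pure-Python stand-in for a stateful accuracy metric. The RunningAccuracy class is hypothetical. Unlike the real SparseCategoricalAccuracy, it takes predicted class ids rather than logits. It shows what skipping reset_state() between epochs does: the two epochs blend into one number.

```python
class RunningAccuracy:
    """Toy stateful metric: update_state folds a batch in, result reads, reset_state zeroes."""

    def __init__(self):
        self.reset_state()

    def reset_state(self):
        self.total = 0   # correct predictions seen so far
        self.count = 0   # predictions seen so far

    def update_state(self, y_true, y_pred):
        self.total += sum(int(t == p) for t, p in zip(y_true, y_pred))
        self.count += len(y_true)

    def result(self):
        return self.total / self.count if self.count else 0.0


epoch1 = [([0, 1, 2, 3], [0, 1, 0, 0])]   # 2 of 4 correct
epoch2 = [([0, 1, 2, 3], [0, 1, 2, 3])]   # 4 of 4 correct

acc = RunningAccuracy()
for batch in epoch1:
    acc.update_state(*batch)
# Forgot acc.reset_state() here...
for batch in epoch2:
    acc.update_state(*batch)
blended = acc.result()                     # (2 + 4) / 8: epoch 2 looks worse than it is

acc.reset_state()                          # the correct per-epoch reset
for batch in epoch2:
    acc.update_state(*batch)
print(blended, acc.result())               # 0.75 1.0
```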

fit() and the callback zoo

Here is what fit() actually runs. Keep this picture in your head: everything else today is about which parts of it you take over.

flowchart TD
    F["model.fit(train_ds, validation_data=val_ds, epochs=N)"] --> E{"for each epoch"}
    E --> B{"for each batch"}
    B --> TS["train_step(batch)"]
    TS --> FWD["forward pass<br/>y_pred = model(x, training=True)"]
    FWD --> L["loss = compute_loss(y, y_pred)"]
    L --> G["GradientTape → gradients"]
    G --> O["optimizer applies updates<br/>θ ← θ − η∇θL"]
    O --> M["metrics.update_state(y, y_pred)"]
    M --> B
    B -->|"epoch ends"| V["evaluate on validation_data<br/>(test_step per batch)"]
    V --> CB["callbacks: on_epoch_end(logs)<br/>EarlyStopping · ModelCheckpoint<br/>ReduceLROnPlateau · TensorBoard"]
    CB -->|"continue"| E
    CB -->|"stop_training = True"| DONE["return History"]

Everything inside train_step is the math; everything around it is bookkeeping. Callbacks are objects whose hook methods fit() calls at defined moments (on_train_begin, on_epoch_end, on_train_batch_end, …), passing a logs dict with the current metric values. Four of them cover most real training runs:

callbacks = [
    # 1. Stop when val_loss stops improving; roll back to the best epoch's weights.
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=3,                  # tolerate 3 stagnant epochs before stopping
        restore_best_weights=True,   # otherwise you keep the LAST (worse) weights!
    ),
    # 2. Continuously save the best model seen so far to disk.
    keras.callbacks.ModelCheckpoint(
        filepath="checkpoints/best.keras",
        monitor="val_accuracy",
        mode="max",                  # "accuracy improving" means going UP
        save_best_only=True,
    ),
    # 3. When val_loss plateaus, cut the learning rate; this often buys a late improvement.
    keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,                  # new_lr = lr * 0.5
        patience=2,                  # trigger faster than EarlyStopping does
        min_lr=1e-5,
    ),
    # 4. Log everything for the TensorBoard UI.
    keras.callbacks.TensorBoard(log_dir="logs/fit_run", histogram_freq=1),
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=30,                       # an upper bound; EarlyStopping decides the real number
    callbacks=callbacks,
)

Notes on the choreography, because the defaults bite:

  • patience interplay. ReduceLROnPlateau has patience=2, EarlyStopping has patience=3. That ordering is deliberate: the LR gets cut before early stopping gives up, so the model gets one cheap rescue attempt. Set them the other way around and early stopping ends the run before the LR cut ever gets a chance to fire.
  • restore_best_weights=True is not the default. Without it, EarlyStopping halts training but leaves the model holding the weights from the final epoch. By definition, that epoch came after `patience` epochs without improvement.
  • mode. Checkpointing monitors val_accuracy, where bigger is better, hence mode="max". Keras infers the mode for common names, but being explicit costs nothing and reads clearly.
  • .keras filepath. Keras 3’s native saving format (whole model: architecture + weights + optimizer state). More on this on Day 9.
  • epochs=30 is a budget, not a plan. With EarlyStopping in place, you set epochs generously and let validation loss decide.
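The patience ordering is easier to see on a concrete curve. The sketch below is a deliberately simplified replay of the two callbacks' counters over a made-up val_loss curve. The simulate helper is hypothetical and ignores min_delta, cooldown, and the other options the real callbacks have. The curve is fixed input and does not react to the learning rate; the point is only when each callback fires.

```python
def simulate(val_losses, lr_patience, stop_patience, lr=1e-3, factor=0.5, min_lr=1e-5):
    """Replay simplified ReduceLROnPlateau + EarlyStopping counters on a val_loss curve.

    Hypothetical helper: no min_delta, no cooldown, no baseline.
    Returns (epoch, event, value) tuples in the order they fire.
    """
    best, best_epoch = float("inf"), 0
    lr_wait = stop_wait = 0
    events = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:                      # improvement resets both counters
            best, best_epoch = loss, epoch
            lr_wait = stop_wait = 0
            continue
        lr_wait += 1
        stop_wait += 1
        if lr_wait >= lr_patience and lr > min_lr:
            lr = max(lr * factor, min_lr)    # new_lr = lr * factor, floored at min_lr
            lr_wait = 0
            events.append((epoch, "reduce_lr", lr))
        if stop_wait >= stop_patience:       # EarlyStopping gives up; best epoch is restored
            events.append((epoch, "stop_restore_best", best_epoch))
            break
    return events

# Made-up val_loss curve: stalls after epoch 3, improves again at epoch 6.
curve = [0.50, 0.40, 0.35, 0.36, 0.37, 0.33, 0.34, 0.35, 0.36]

print(simulate(curve, lr_patience=2, stop_patience=3))
# [(5, 'reduce_lr', 0.0005), (8, 'reduce_lr', 0.00025), (9, 'stop_restore_best', 6)]
print(simulate(curve, lr_patience=3, stop_patience=2))
# [(5, 'stop_restore_best', 3)]
```

With ReduceLROnPlateau on the shorter fuse, the LR is halved at epoch 5 and the run survives long enough to reach the improvement at epoch 6. With the fuses swapped, training stops at epoch 5 and the cut never happens.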

Typical output (exact numbers vary from run to run). Note the LR cut at epoch 9 and the early stop at epoch 12:

Epoch 1/30
430/430 ━━━━━━━━━━ 3s 5ms/step - accuracy: 0.8225 - loss: 0.5081 - val_accuracy: 0.8618 - val_loss: 0.3800
Epoch 2/30
430/430 ━━━━━━━━━━ 2s 4ms/step - accuracy: 0.8710 - loss: 0.3524 - val_accuracy: 0.8720 - val_loss: 0.3492
...
Epoch 9/30
430/430 ━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9124 - loss: 0.2351 - val_accuracy: 0.8896 - val_loss: 0.3241
Epoch 9: ReduceLROnPlateau reducing learning rate to 0.0005.
...
Epoch 12/30
430/430 ━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9301 - loss: 0.1882 - val_accuracy: 0.8912 - val_loss: 0.3305
Epoch 12: early stopping. Restoring model weights from the end of the best epoch: 9.

Launch TensorBoard from a terminal with tensorboard --logdir logs and open http://localhost:6006 in a browser. You get live loss/metric curves, the LR schedule, and (thanks to histogram_freq=1) weight histograms per layer per epoch. The returned history.history dict holds the same per-epoch numbers ({"loss": [...], "val_loss": [...], ...}) if you’d rather matplotlib them.

Finally, honest numbers come from data the callbacks never saw:

test_loss, test_acc = model.evaluate(test_ds, verbose=0)
print(f"test accuracy: {test_acc:.4f}")   # ≈ 0.884

The validation set steered EarlyStopping, checkpointing, and the LR schedule, so it has been “used up” as an unbiased estimate. The test set gives the number you report.

The same training as a raw GradientTape loop

Now we throw all of that away and write it ourselves. This is Day 1’s GradientTape grown up. Instead of differentiating \(y = x^2\), we differentiate a full model’s loss with respect to every weight, and we do it once per batch, thousands of times. Structurally, what follows is a PyTorch training loop with the names changed: tape.gradient where PyTorch has loss.backward(), and optimizer.apply_gradients where PyTorch has optimizer.step(). There is no zero_grad() at all, because the tape is rebuilt fresh each step, so gradients can’t accumulate by accident.

Stage 1: the pieces compile() used to hold, now as plain variables:

keras.utils.set_random_seed(42)                # re-seed so this run's fresh weights are reproducible
model = build_model()                          # fresh, untrained weights

optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc = keras.metrics.SparseCategoricalAccuracy()
val_acc = keras.metrics.SparseCategoricalAccuracy()
train_loss = keras.metrics.Mean()              # running average of per-batch losses

keras.metrics.Mean deserves a word: the per-batch loss jumps around, so what fit() prints is a running mean over the epoch. We replicate that with a Mean metric fed one scalar per step.
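One subtlety of feeding a Mean one scalar per step is that every batch counts once. That includes the partial final batch, which holds only 88 examples because 55,000 % 128 = 88. The small stdlib sketch below uses made-up loss values to contrast that with a per-example average. If you ever want the latter, keras.metrics.Mean's update_state also accepts a sample_weight argument. With 430 batches per epoch, a single underweighted batch barely moves the number, which is why the simple version is fine here.

```python
def per_step_mean(batch_losses):
    """What a Mean metric fed one scalar per step reports: every batch counts once."""
    return sum(batch_losses) / len(batch_losses)

def per_example_mean(batch_losses, batch_sizes):
    """Average over examples: weight each batch's mean loss by how many examples it had."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# Made-up losses: two full batches of 128, then the partial final batch of 88.
losses = [0.30, 0.30, 0.60]
sizes = [128, 128, 88]

print(round(per_step_mean(losses), 4))            # 0.4
print(round(per_example_mean(losses, sizes), 4))  # 0.3767
```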

Stage 2: one training step, compiled to a graph:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)               # (batch, 10) float32; batch is 128 except the last
        loss = loss_fn(y, logits)                      # () scalar
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss.update_state(loss)
    train_acc.update_state(y, logits)
    return loss

@tf.function
def val_step(x, y):
    logits = model(x, training=False)
    val_acc.update_state(y, logits)

Line by line, because every line is load-bearing:

  • The with tf.GradientTape() as tape: block is where recording happens. The tape records every op on watched variables, but only inside this block. The forward pass and the loss must both happen inside. Move loss = loss_fn(...) outside the with block and tape.gradient returns Nones (the tape never saw the loss being computed), and apply_gradients throws.
  • model(x, training=True): the training flag switches layer behavior (Dropout drops, BatchNorm uses batch statistics). Our MLP has neither, but wire it correctly now; on Days 6 and 7 this flag is the difference between training and silently broken training. Note that it’s model(x, ...), the direct call. It is not model.predict(), which is a batched-inference convenience wrapper and has no business inside a training loop.
  • tape.gradient(loss, model.trainable_variables): returns a list of tensors, one per weight, each shaped exactly like its variable. For our MLP that’s [(784,256), (256,), (256,128), (128,), (128,10), (10,)]. Use trainable_variables, not variables: BatchNorm’s moving averages, for instance, are variables that must not receive gradient updates.
  • optimizer.apply_gradients(zip(grads, vars)): the update \(\theta \leftarrow \theta - \eta \, \hat{g}\), with Adam’s per-parameter scaling folded into \(\hat g\). The zip pairs each gradient with its variable. The two lists line up because tape.gradient returns gradients in the same order as the variable list it was given.
  • @tf.function: traces the Python into a TensorFlow graph on first call, then replays the compiled graph. On this small model it’s a 2–3× speedup. Two rules from Day 1 still apply. (a) Don’t put Python side effects (print, list appends) inside and expect them per step, because they run only during tracing. (b) Keep input shapes consistent, because each new shape triggers an expensive retrace. Our 55,000 training examples end in a partial batch (55,000 % 128 = 88), and the 5,000 validation examples end in one of 8, so each function pays for one extra trace. That’s harmless here, but drop_remainder=True in .batch() is the fix when it isn’t.
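Two numbers from the bullets above can be checked with a few lines of arithmetic. The first is the per-variable gradient shapes: one kernel and one bias per Dense layer, in the order listed. The second is how an epoch splits into batches. Both helpers are hypothetical, stdlib-only sketches. batch_plan mirrors what .batch(BATCH) does with and without drop_remainder=True.

```python
import math

def dense_param_shapes(layer_sizes):
    """Kernel and bias shapes for a stack of Dense layers, one pair per layer."""
    shapes = []
    for fan_in, fan_out in zip(layer_sizes, layer_sizes[1:]):
        shapes.append((fan_in, fan_out))  # kernel
        shapes.append((fan_out,))         # bias
    return shapes

def batch_plan(num_examples, batch_size, drop_remainder=False):
    """Return (batches per epoch, size of the last batch) for .batch(batch_size)."""
    full, rest = divmod(num_examples, batch_size)
    if rest == 0 or drop_remainder:
        return full, batch_size
    return full + 1, rest

# Flatten(28x28) -> Dense(256) -> Dense(128) -> Dense(10), as in build_model().
shapes = dense_param_shapes([28 * 28, 256, 128, 10])
print(shapes)                                  # [(784, 256), (256,), (256, 128), (128,), (128, 10), (10,)]
print(sum(math.prod(s) for s in shapes))       # 235146 trainable weights, one gradient entry each

print(batch_plan(55_000, 128))                       # (430, 88): 429 full batches + one partial
print(batch_plan(5_000, 128))                        # (40, 8)
print(batch_plan(55_000, 128, drop_remainder=True))  # (429, 128): the 88 leftovers are skipped
```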

Stage 3: the outer loop, i.e. everything fit() was doing around the math:

EPOCHS = 10
best_val = 0.0

for epoch in range(EPOCHS):
    train_loss.reset_state(); train_acc.reset_state(); val_acc.reset_state()

    for x, y in train_ds:
        train_step(x, y)

    for x, y in val_ds:
        val_step(x, y)

    va = float(val_acc.result())
    print(f"epoch {epoch+1:2d} | loss {float(train_loss.result()):.4f} "
          f"| acc {float(train_acc.result()):.4f} | val_acc {va:.4f}")

    if va > best_val:                       # a hand-rolled ModelCheckpoint
        best_val = va
        model.save_weights("checkpoints/manual_best.weights.h5")

epoch  1 | loss 0.5062 | acc 0.8231 | val_acc 0.8606
epoch  2 | loss 0.3531 | acc 0.8708 | val_acc 0.8724
...
epoch 10 | loss 0.2210 | acc 0.9178 | val_acc 0.8874

Same numbers as fit(), within run-to-run noise, as they must be, since it’s the same math. Now tally what you paid for the control: the metric reset_state() calls, the validation loop, the progress printing, and best-model tracking are all yours to write and yours to get wrong. And we didn’t implement early stopping, LR-on-plateau, or TensorBoard logging; each is another 5–15 lines. This is why the raw loop is the tool of last resort, not the default. You write it when the structure of training itself is nonstandard: GANs alternating two optimizers, reinforcement learning, gradient accumulation across micro-batches, custom multi-loss balancing. For everything else, there’s a better deal.
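Gradient accumulation is a good example of a change to the loop’s shape: you run k micro-batches, sum their gradients, and apply one update. In TensorFlow that would mean adding the tape.gradient lists element-wise across k steps and calling optimizer.apply_gradients once per k. The stdlib sketch below only checks the underlying arithmetic. It assumes a hypothetical one-parameter model \(\hat y = w x\) with a mean squared error loss and equal-size micro-batches. Under those assumptions, averaging the micro-batch gradients reproduces the full-batch gradient.

```python
import math

def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)**2) over one batch, for the one-parameter model y_hat = w*x."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]   # toy data, roughly y = 2x
w = 0.5

full_batch = mse_grad(w, xs, ys)

# Accumulate over micro-batches of 2, then average: one optimizer update per 4 steps.
micro = 2
accumulated, steps = 0.0, 0
for i in range(0, len(xs), micro):
    accumulated += mse_grad(w, xs[i:i + micro], ys[i:i + micro])
    steps += 1
averaged = accumulated / steps

learning_rate = 0.01
w_next = w - learning_rate * averaged                # the single update a raw loop would apply
print(steps, math.isclose(full_batch, averaged, rel_tol=1e-12))   # 4 True
```

With unequal micro-batches (say, a partial last one), you would weight each gradient by its batch size instead of taking a plain average.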

The middle path: overriding train_step

Here’s the Keras 3 move that most tutorials undersell: subclass keras.Model, override only train_step, and keep everything else. You rewrite the ~8 lines of math you actually care about. fit() still supplies the epoch loop, tf.function compilation, callbacks, progress bar, validation, and (later) multi-GPU distribution, none of which you re-implement.

class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data                                    # one batch, as fit() unpacked it

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)   # the loss given to compile()

        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        # Update the metrics configured in compile()
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

The contract, piece by piece:

  • data is exactly one element of whatever you passed to fit(): for our train_ds, an (x, y) tuple. (Pass sample_weight in your dataset and it arrives as a 3-tuple; unpack accordingly.)
  • self.compute_loss(y=y, y_pred=y_pred) is Keras 3’s bridge to whatever loss you gave compile() β€” including any regularization losses layers have registered. Using it (rather than a hardcoded loss_fn) keeps your subclass reusable with any compiled loss.
  • The metrics loop feeds the same metric objects compile() created; the built-in loss tracker takes the scalar, the rest take (y, y_pred).
  • The return dict is what the progress bar prints and, crucially, what callbacks receive as logs. Return {"loss": ..., "accuracy": ...} and EarlyStopping(monitor="val_loss") works on your custom step unchanged. That’s because fit() runs your metrics against validation_data too, via the default test_step; override that as well if your evaluation math is also custom.
  • No @tf.function needed: fit() wraps your train_step in one automatically.
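The 2-tuple vs 3-tuple case from the first bullet is usually handled with a small unpacking step at the top of train_step. Keras exports keras.utils.unpack_x_y_sample_weight for this job. The stdlib version below (hypothetical name unpack_batch) just mirrors the idea so the branching is visible. In a real train_step you would then pass the weights on as self.compute_loss(y=y, y_pred=y_pred, sample_weight=sample_weight).

```python
def unpack_batch(data):
    """Return (x, y, sample_weight) from one dataset element; missing parts become None."""
    if isinstance(data, list):
        data = tuple(data)
    if not isinstance(data, tuple):
        return data, None, None              # a bare x (e.g. inference-only data)
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        return data[0], data[1], None
    if len(data) == 3:
        return data[0], data[1], data[2]
    raise ValueError(f"Expected 1 to 3 elements, got {len(data)}")

print(unpack_batch(("x_batch", "y_batch")))             # ('x_batch', 'y_batch', None)
print(unpack_batch(("x_batch", "y_batch", "weights")))  # ('x_batch', 'y_batch', 'weights')
```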

Now build it with the Functional API, passing inputs and outputs to CustomModel instead of keras.Model (or define the layers in a subclass, as on Day 2). Then use it exactly like any other Keras model:

inputs = keras.Input(shape=(28, 28))
x = keras.layers.Flatten()(inputs)
x = keras.layers.Dense(256, activation="relu")(x)
x = keras.layers.Dense(128, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)
model = CustomModel(inputs, outputs)

model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)
model.fit(train_ds, validation_data=val_ds, epochs=10, callbacks=callbacks)  # everything still works

The demo above reproduces the default behavior on purpose, so you can verify the plumbing (it should match the earlier runs). The payoff comes when you change the math. You can clip gradients before applying them (grads = [tf.clip_by_norm(g, 1.0) for g in grads]), add a custom auxiliary loss, implement mixup inside the step, or log gradient norms as a metric. Each of those is a 1–3 line edit here, versus owning a whole loop.
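The one-liner in the paragraph above clips each gradient tensor by its own norm, whereas the exercise below uses tf.clip_by_global_norm. The difference matters. Per-tensor clipping can change the direction of the overall update, while global clipping rescales every tensor by the same factor. The stdlib sketch reimplements both rules as the TensorFlow docs describe them (scale by clip / max(norm, clip)) on a tiny hypothetical list of gradients.

```python
import math

def l2(v):
    return math.sqrt(sum(x * x for x in v))

def clip_by_norm(v, clip):
    """Per-tensor rule: rescale v only if its own L2 norm exceeds clip."""
    n = l2(v)
    return [x * clip / max(n, clip) for x in v]

def clip_by_global_norm(tensors, clip):
    """Joint rule: one scale factor from the norm of ALL tensors taken together."""
    global_norm = math.sqrt(sum(l2(t) ** 2 for t in tensors))
    scale = clip / max(global_norm, clip)
    return [[x * scale for x in t] for t in tensors], global_norm

grads = [[3.0, 4.0], [0.0, 0.5]]      # two "gradient tensors" with norms 5.0 and 0.5

per_tensor = [clip_by_norm(g, 1.0) for g in grads]
joint, gnorm = clip_by_global_norm(grads, 1.0)

print(per_tensor)       # [[0.6, 0.8], [0.0, 0.5]]: second tensor untouched, ratios changed
print(round(gnorm, 4))  # 5.0249, the pre-clip global norm
print(joint)            # every entry scaled by 1 / 5.0249, so ratios are preserved
```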

One caveat worth knowing: because this train_step uses tf.GradientTape, it ties the model to the TensorFlow backend. Keras 3 also runs on JAX and PyTorch backends, each with its own override idiom. That’s fine for this course, where TF is the point, but don’t be surprised when a JAX tutorial’s train_step looks different.

Choosing your altitude

[Figure: Three altitudes, one dial: control vs. convenience. From left (more built-in machinery) to right (more control). compile() + fit(): standard supervised training, ~5 lines to write, callbacks ✓, distribution ✓, custom math ✗ (the default). train_step override: custom step math with a standard loop shape, ~15 lines, callbacks ✓, distribution ✓, custom math ✓ (the sweet spot). GradientTape loop: nonstandard training (GANs, RL, accumulation), ~40+ lines, callbacks ✗ (DIY), custom math ✓✓ (last resort).]

|  | compile + fit | train_step override | raw GradientTape loop |
|---|---|---|---|
| Epoch/batch loop | Keras | Keras | you |
| Forward/backward/update | Keras | you | you |
| Callbacks, progress, validation | Keras | Keras | you |
| Metric lifecycle (reset_state) | Keras | Keras | you |
| tf.function compilation | automatic | automatic | you add it |
| Typical use | almost everything | clipping, mixup, aux losses, custom logging | GANs, RL, exotic loop structure |

The decision rule: start with compile() + fit(). Switch to a train_step override the moment you need to touch the per-batch math. Drop to a raw GradientTape loop only when the shape of the loop itself is nonstandard. Moving toward more control is always possible later; the mistake is starting with the raw loop “for flexibility” and then hand-maintaining early stopping forever.

🧪 Your task

Implement gradient clipping with gradient-norm logging via a train_step override. Concretely: subclass keras.Model so that (a) gradients are clipped by global norm to 1.0 before being applied (tf.clip_by_global_norm), and (b) the pre-clip global gradient norm appears in the progress bar and history as a metric named grad_norm. Train it on train_ds for 5 epochs with the EarlyStopping callback attached, and confirm grad_norm shows up in the logs and decreases over training.

Hint: tf.clip_by_global_norm(grads, 1.0) returns a tuple (clipped_grads, global_norm), so one call hands you both things you need. Track the norm with a keras.metrics.Mean created in __init__, and make sure the model’s metrics property returns it. fit() resets the metrics listed there at the start of every epoch, so a tracker it can’t see would keep averaging across epochs.

Solution
import tensorflow as tf
import keras

class ClippedModel(keras.Model):
    def __init__(self, *args, clip_norm=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.clip_norm = clip_norm
        self.grad_norm_tracker = keras.metrics.Mean(name="grad_norm")

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        grads = tape.gradient(loss, self.trainable_variables)
        # one call gives clipped grads AND the pre-clip global norm
        clipped, global_norm = tf.clip_by_global_norm(grads, self.clip_norm)
        self.optimizer.apply_gradients(zip(clipped, self.trainable_variables))

        self.grad_norm_tracker.update_state(global_norm)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            elif metric.name != "grad_norm":       # ours is fed above
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    @property
    def metrics(self):
        # include our tracker so fit() resets it every epoch
        return super().metrics + [self.grad_norm_tracker]


inputs = keras.Input(shape=(28, 28))
x = keras.layers.Flatten()(inputs)
x = keras.layers.Dense(256, activation="relu")(x)
x = keras.layers.Dense(128, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)
model = ClippedModel(inputs, outputs, clip_norm=1.0)

model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5,
    callbacks=[keras.callbacks.EarlyStopping(monitor="val_loss", patience=3,
                                             restore_best_weights=True)],
)

print(history.history["grad_norm"])
# e.g. [1.9137, 0.9821, 0.7414, 0.6350, 0.5561]: large early, shrinking as training settles

Expected behavior: epoch 1’s mean grad_norm exceeds 1.0, so clipping is actually firing early on. It then drops below the threshold as the loss surface flattens. From then on the clip is mostly a no-op safety net (an epoch mean below 1.0 can still include the occasional clipped batch), which is exactly what you want. Callbacks, validation, and the progress bar all worked untouched: that’s the middle path earning its keep.

Key takeaways

  • compile() attaches loss, optimizer, and metrics. Prefer the object form over strings, and match from_logits=True to a no-softmax output layer, because the mismatch trains badly without crashing.
  • fit(validation_data=...) + four callbacks (EarlyStopping with restore_best_weights=True, ModelCheckpoint, ReduceLROnPlateau with shorter patience than EarlyStopping, TensorBoard) is the production-default training recipe.
  • The raw loop is GradientTape around the forward pass, tape.gradient on trainable_variables, optimizer.apply_gradients, @tf.function for speed β€” and you inherit all the bookkeeping fit() was doing, including metric reset_state().
  • Keras 3’s train_step override is the middle path: rewrite only the per-batch math, keep callbacks/validation/distribution for free; self.compute_loss and the metrics loop are the contract.
  • Decision rule: fit() by default → train_step when the step’s math changes → full custom loop only when the loop’s structure changes.

Tomorrow we put Days 1–4 together into a complete classification project: a real dataset, from preprocessing to prediction, with all the evaluation you should never skip.



© Kader Mohideen