Kader Mohideen
πŸ“Š Deep Learning with TensorFlow & Keras Β· Day 5 β€” Classification End to End: From Raw Pixels to a Confusion Matrix

🏠 πŸ“Š Course home  |  ← Day 04  |  Day 06 β†’  |  πŸ“š All mini-courses


Day 5 - Classification End to End: From Raw Pixels to a Confusion Matrix

Over the last four days you built every component in isolation: tensors and GradientTape on Day 1, three ways to define a model on Day 2, tf.data pipelines on Day 3, and the compile/fit machinery versus hand-rolled loops on Day 4. Today you assemble the whole thing into one honest, production-shaped project: an MLP classifier on Fashion-MNIST that handles the messy realities a tutorial usually skips. That means normalization that lives inside the model, a class imbalance you have to compensate for, learning curves you can actually read, a confusion matrix that tells you where the model fails, and a checkpoint of the best epoch (not the last one). By the end you’ll have a single runnable script that is a template for every classification project you do after this course.

🎯 Today you will: bake a Normalization layer into the model with adapt(), correct a skewed dataset with class_weight, checkpoint the best epoch with ModelCheckpoint, diagnose training health from learning curves, and read a confusion matrix like a radiologist reads an X-ray

The shape of a real training run

Most tutorials show you fit() and stop. A real classification project has a pipeline with more stations, and skipping any one of them produces a model that looks fine on paper and embarrasses you in production. Here is the full route we travel today:

flowchart LR
    A[Raw data<br/>uint8 pixels] --> B[Split<br/>train / val / test]
    B --> C[adapt<br/>Normalization layer<br/>on TRAIN only]
    B --> D[Count classes →<br/>class_weight dict]
    C --> E[Model = normalizer<br/>+ MLP + logits]
    D --> F[fit with class_weight<br/>+ ModelCheckpoint<br/>+ EarlyStopping]
    E --> F
    F --> G[Learning curves:<br/>is it healthy?]
    F --> H[best_fashion.keras<br/>best epoch, not last]
    H --> I[Confusion matrix<br/>on TEST set]

Two decisions in this diagram are the ones people get wrong most often, so let’s name them up front:

  1. Statistics flow from train data only. The Normalization layer’s mean and standard deviation come from the training split. If you adapt() on the full dataset, information about the test set leaks into your model. On Fashion-MNIST the damage is negligible; on a small medical dataset it can fabricate several points of accuracy that evaporate in the real world.
  2. The saved model is the best epoch, not the final one. Validation accuracy peaks and then decays as the model overfits. ModelCheckpoint(save_best_only=True) freezes the peak.

Data: load, split, and manufacture an imbalance

Fashion-MNIST ships with Keras, so loading is one line. It’s 70,000 grayscale 28×28 images (60,000 train, 10,000 test) across 10 clothing classes, a drop-in replacement for MNIST that is hard enough to make mistakes visible.

import numpy as np
import keras

(x_train_full, y_train_full), (x_test, y_test) = \
    keras.datasets.fashion_mnist.load_data()

print(x_train_full.shape, x_train_full.dtype)   # (60000, 28, 28) uint8
print(y_train_full.shape, y_train_full.dtype)   # (60000,) uint8
print(x_train_full.min(), x_train_full.max())   # 0 255

CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

Note what you got: raw uint8 in [0, 255], not floats in [0, 1]. We are deliberately not going to divide by 255 here in NumPy; that normalization decision belongs inside the model, and you’ll see why in the next section.

First, carve out a validation set. The test set is sacred: you touch it exactly once, at the very end. The validation set is your working thermometer during training:

VAL_SIZE = 5000
x_val,   y_val   = x_train_full[-VAL_SIZE:], y_train_full[-VAL_SIZE:]
x_train, y_train = x_train_full[:-VAL_SIZE], y_train_full[:-VAL_SIZE]

The order matters: split before you compute any statistic or reshape any distribution. The validation set must look like the world you’ll be evaluated on: untouched and balanced (here, close to balanced, since it is a 5,000-image slice of a balanced set).

Fashion-MNIST is perfectly balanced (6,000 training images per class), which is lovely and completely unrepresentative of real life. Fraud detection, defect inspection, rare-disease screening: real classification problems are skewed, often violently. So we’ll manufacture an imbalance in the training set to learn the tools for dealing with one. We gut the three footwear classes down to 300 examples each:

rng = np.random.default_rng(42)

RARE = {5: 300, 7: 300, 9: 300}   # Sandal, Sneaker, Ankle boot

keep = np.ones(len(y_train), dtype=bool)
for cls, n_keep in RARE.items():
    idx = np.where(y_train == cls)[0]
    drop = rng.choice(idx, size=len(idx) - n_keep, replace=False)
    keep[drop] = False

x_train, y_train = x_train[keep], y_train[keep]
print(np.bincount(y_train))
[5500 5493 5504 5514 5488  300 5514  300 5486  300]

The mechanics: np.where(y_train == cls)[0] gives the row indices of one class, rng.choice(..., replace=False) samples which ones to delete, and the boolean keep mask applies all deletions in a single vectorized pass. Never delete rows in a Python loop with np.delete; it re-allocates the entire array every call.

Now footwear is ~18× rarer than everything else in training, while validation and test remain balanced. A naive model will learn that predicting “Sneaker” is rarely worth the risk, and its overall accuracy will barely show it, because sneakers are only 10% of the test set. This is exactly the trap where accuracy lies and a confusion matrix tells the truth.
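To put a number on how much a balanced test set can hide, here is a small arithmetic sketch. The scenario is hypothetical (a model that is 90% right on nine classes and never predicts Sneaker), not a result from this lesson’s run.

```python
import numpy as np

# Hypothetical scenario: 10 balanced classes, 1,000 test images each.
n_per_class = 1000
recall = np.full(10, 0.90)   # assumed recall on the nine classes the model handles
recall[7] = 0.0              # the model never predicts Sneaker (class 7)

correct = (recall * n_per_class).sum()
accuracy = correct / (10 * n_per_class)

print(f"overall accuracy:   {accuracy:.2%}")    # 81.00%
print(f"worst-class recall: {recall.min():.0%}")  # 0%
```

An 81% headline with one class at zero: a dead class can cost at most 10 points of accuracy here, which is why the per-class view matters.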

Normalization as a layer, not a preprocessing script

The classic way to normalize is x = (x - mean) / std in NumPy before training. It works, until you deploy. Six months later, someone feeds the served model raw [0, 255] pixels because the normalization constants lived in a preprocessing script that didn’t ship with the model. This is one of the most common production ML bugs in existence.

Keras’ answer is keras.layers.Normalization: a layer that stores the statistics as part of the model. Whatever you export on Day 9 will contain the mean and variance inside the graph. There is no preprocessing script to lose.

from keras import layers

normalizer = layers.Normalization(axis=None)
normalizer.adapt(x_train.astype("float32"))

print(float(normalizer.mean), float(normalizer.variance) ** 0.5)
72.939 90.020

Three things deserve a close look:

  • axis=None means “compute one scalar mean and one scalar variance over everything”. That’s right for grayscale images, where every pixel shares one intensity distribution. For tabular data with, say, 13 features of wildly different scales, you’d use axis=-1 to get 13 per-feature means. For RGB images you’d normalize per channel. Getting this axis wrong is silent: the model still trains, just worse.
  • adapt() is not training. It’s a single statistics pass: the layer iterates over the data once and stores mean and variance as non-trainable weights. No gradients, no epochs. This is the Keras 3 idiom for any stateful preprocessing layer (StringLookup, TextVectorization, Discretization all work the same way). The PyTorch equivalent is computing x.mean()/x.std() yourself and hard-coding them into a transforms.Normalize; Keras just formalizes the pattern and welds the result to the model.
  • We adapted on x_train only, the post-imbalance training split. Validation and test pixels never influence these numbers.
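The axis choice is easier to see in plain NumPy. The sketch below uses random stand-in arrays (not Fashion-MNIST) and mimics what the two settings compute: one scalar pair for axis=None, one pair per feature for axis=-1. It is an illustration of the statistics, not a reproduction of the layer’s internals.

```python
import numpy as np

rng = np.random.default_rng(0)

# Grayscale images: one shared intensity distribution -> one scalar mean/variance.
images = rng.integers(0, 256, size=(100, 28, 28)).astype("float32")
scalar_mean = images.mean()          # the statistic axis=None asks for
scalar_var = images.var()

# Tabular data: 13 features on different scales -> one mean/variance per feature.
table = rng.normal(size=(100, 13)) * np.arange(1, 14)
feature_mean = table.mean(axis=0)    # reduce over rows, keep the last axis
feature_var = table.var(axis=0)

print(np.shape(scalar_mean), feature_mean.shape)   # () (13,)
```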

In training and at inference alike, the layer computes \(\hat{x} = \frac{x - \mu}{\max(\sigma, \varepsilon)}\), where \(\sigma = \sqrt{\sigma^2}\) is the stored standard deviation and \(\varepsilon\) only guards against division by zero. This turns our [0, 255] pixels into roughly zero-mean, unit-variance inputs, the regime where gradient descent with default learning rates behaves well. Feed a raw-pixel MLP the unnormalized values and watch your first-epoch loss start around 60 instead of 2.3, because the initial logits are scaled by inputs ~90× too large.
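As a sanity check on the formula, this is a hand-rolled scalar standardization in NumPy on stand-in pixel data. The helper name normalize and the eps value are assumptions for illustration; the real layer stores its statistics as weights instead of passing them around.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for training pixels: random values in [0, 255].
x_demo = rng.integers(0, 256, size=(512, 28, 28)).astype("float32")

mu = float(x_demo.mean())
sigma = float(x_demo.std())

def normalize(x, mu, sigma, eps=1e-7):
    """Scalar standardization with a guard against a zero standard deviation."""
    return (x - mu) / max(sigma, eps)

z = normalize(x_demo, mu, sigma)
print(abs(float(z.mean())) < 1e-3, abs(float(z.std()) - 1.0) < 1e-3)   # True True
```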

Class weights, the model, and checkpointing the best epoch

Weighting the loss

With 300 sneakers versus 5,500 shirts, each gradient step sees ~18 shirt examples for every sneaker. The standard fix is to scale each example’s loss by a per-class weight so rare classes shout louder. The usual “balanced” heuristic:

\[w_c = \frac{N}{C \cdot n_c}\]

where \(N\) is the total number of training examples, \(C\) the number of classes, and \(n_c\) the count of class \(c\). A class at exactly average frequency gets weight 1.0; our rare classes get ~13.1 (39,399 / (10 · 300)).

counts = np.bincount(y_train)
n_classes = len(counts)
class_weight = {c: len(y_train) / (n_classes * counts[c])
                for c in range(n_classes)}

for c in (0, 5):
    print(f"{CLASS_NAMES[c]:>12}: n={counts[c]:>5}  w={class_weight[c]:.2f}")
 T-shirt/top: n= 5500  w=0.72
      Sandal: n=  300  w=13.13
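Why is this heuristic called “balanced”? A quick check, using the class counts printed earlier in this lesson: every class ends up with the same total weight \(n_c \cdot w_c = N / C\), so the average weight per training example is exactly 1 and the overall loss scale is unchanged.

```python
import numpy as np

# Class counts from the imbalanced training split printed above.
counts = np.array([5500, 5493, 5504, 5514, 5488, 300, 5514, 300, 5486, 300])
N, C = counts.sum(), len(counts)
w = N / (C * counts)                      # w_c = N / (C * n_c)

# Each class contributes the same total weight, N / C ...
print(np.allclose(counts * w, N / C))     # True
# ... so the mean weight over all training examples is 1.
print(np.isclose((counts * w).sum() / N, 1.0))   # True
```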

In Keras this is a plain dict passed to fit(class_weight=...): Keras multiplies each sample’s loss by the weight of its true class before averaging. The PyTorch equivalent is nn.CrossEntropyLoss(weight=torch.tensor([...])). It is the same idea, with one difference: PyTorch’s default mean reduction divides by the sum of the weights in the batch rather than by the batch size. Keras also attaches the weights to the training call rather than the loss object, which means your evaluation loss stays unweighted, usually what you want.

The effect on the loss: instead of \(\mathcal{L} = \frac{1}{B}\sum_i \ell_i\), each batch computes \(\mathcal{L} = \frac{1}{B}\sum_i w_{y_i}\, \ell_i\). Misclassifying a sandal now costs ~18 shirts’ worth of gradient (the ratio of the two class weights, 5,514 / 300).
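The weighted batch loss can be sketched in a few lines of NumPy. This is an illustration of the formula above on a toy batch with made-up weights, not the code Keras runs; the function name weighted_sparse_ce is hypothetical.

```python
import numpy as np

def weighted_sparse_ce(logits, labels, class_weight):
    """Batch mean of w[y_i] * cross_entropy_i, computed from raw logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)            # stability shift
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    per_example = -log_probs[np.arange(len(labels)), labels]        # l_i
    weights = np.array([class_weight[int(y)] for y in labels])      # w_{y_i}
    return float((weights * per_example).mean())                    # divide by B

# Toy batch: 3 classes, all-zero logits, so every l_i equals ln(3).
logits = np.zeros((4, 3))
labels = np.array([0, 0, 0, 2])
toy_weights = {0: 0.5, 1: 1.0, 2: 4.0}    # hypothetical class weights

loss = weighted_sparse_ce(logits, labels, toy_weights)
print(np.isclose(loss, np.log(3) * (0.5 * 3 + 4.0) / 4))   # True
```

The single class-2 example contributes 4.0 of the 5.5 total weight in this batch, which is the “shouting louder” effect in miniature.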

The model

A straightforward MLP, functional style (Day 2), with the normalizer as its first layer:

def build_model(normalizer: layers.Normalization) -> keras.Model:
    inputs = keras.Input(shape=(28, 28))
    x = normalizer(inputs)              # stats baked into the graph
    x = layers.Flatten()(x)             # (batch, 28, 28) -> (batch, 784)
    x = layers.Dense(256, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(10)(x)       # logits: no softmax here
    return keras.Model(inputs, outputs, name="fashion_mlp")

model = build_model(normalizer)
model.summary()
Model: "fashion_mlp"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Layer (type)               ┃ Output Shape      ┃   Param # ┃
┑━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ input_layer (InputLayer)   │ (None, 28, 28)    │         0 │
│ normalization (Normalizat… │ (None, 28, 28)    │         3 │
│ flatten (Flatten)          │ (None, 784)       │         0 │
│ dense (Dense)              │ (None, 256)       │   200,960 │
│ dropout (Dropout)          │ (None, 256)       │         0 │
│ dense_1 (Dense)            │ (None, 128)       │    32,896 │
│ dropout_1 (Dropout)        │ (None, 128)       │         0 │
│ dense_2 (Dense)            │ (None, 10)        │     1,290 │
└────────────────────────────┴───────────────────┴───────────┘
 Total params: 235,149 (918.55 KB)
 Trainable params: 235,146 (918.54 KB)
 Non-trainable params: 3 (16.00 B)

Those 3 non-trainable params are the normalizer’s mean, variance, and count: proof the statistics live in the model. And notice the output layer has no softmax: we emit raw logits and tell the loss about it. Computing log(softmax(x)) as two separate ops is numerically unstable for extreme logits; from_logits=True fuses them into a stable log-sum-exp, exactly like PyTorch’s CrossEntropyLoss (which never wants a softmax in the model either).
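To make the instability concrete, here is a minimal NumPy sketch (not Keras internals) that compares the two-step log(softmax(z)) with the shifted log-sum-exp form on deliberately extreme logits. Both function names are made up for illustration.

```python
import numpy as np

def naive_log_softmax(z):
    """Two separate ops: softmax first, then log."""
    e = np.exp(z)
    return np.log(e / e.sum())

def stable_log_softmax(z):
    """Fused form: subtract the max, then use log-sum-exp."""
    shifted = z - z.max()
    return shifted - np.log(np.exp(shifted).sum())

z = np.array([1000.0, 0.0, -1000.0])

# exp(1000) overflows to inf, so the naive version yields nan and -inf.
with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
    naive = naive_log_softmax(z)
stable = stable_log_softmax(z)

print(np.isfinite(naive).all())    # False
print(stable.tolist())             # [0.0, -1000.0, -2000.0]
```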

model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)

Sparse because our labels are integer class ids (3), not one-hot vectors ([0,0,0,1,...]). Using CategoricalCrossentropy with integer labels raises a shape error, one of the two most common Keras beginner mistakes. The other is forgetting from_logits=True, which doesn’t crash at all; it just trains a mediocre model silently, because Keras then treats your raw logits as probabilities.
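The two label formats carry the same information, and the two losses compute the same number. A small NumPy sketch, using a uniform prediction purely for illustration:

```python
import numpy as np

labels = np.array([3, 0, 9])                  # sparse: integer class ids, shape (3,)
one_hot = np.eye(10)[labels]                  # dense: one-hot rows, shape (3, 10)

log_probs = np.log(np.full((3, 10), 0.1))     # a uniform prediction over 10 classes

sparse_loss = -log_probs[np.arange(3), labels]       # pick the true-class entry
dense_loss = -(one_hot * log_probs).sum(axis=1)      # dot with the one-hot row

print(one_hot.shape, np.allclose(sparse_loss, dense_loss))   # (3, 10) True
```

The sparse form skips building the (batch, classes) label matrix, which is why it is the natural choice when labels already arrive as integers.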

Callbacks: save the peak, stop the decay

callbacks = [
    keras.callbacks.ModelCheckpoint(
        "best_fashion.keras",
        monitor="val_accuracy",
        save_best_only=True,
        verbose=1,
    ),
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=5,
        restore_best_weights=True,
    ),
]

history = model.fit(
    x_train.astype("float32"), y_train,
    validation_data=(x_val.astype("float32"), y_val),
    epochs=40,
    batch_size=128,
    class_weight=class_weight,
    callbacks=callbacks,
    verbose=2,
)
Epoch 1/40
Epoch 1: val_accuracy improved from -inf to 0.84280, saving model to best_fashion.keras
308/308 - 3s - loss: 0.7842 - accuracy: 0.7676 - val_loss: 0.4494 - val_accuracy: 0.8428
...
Epoch 14/40
Epoch 14: val_accuracy improved from 0.88700 to 0.88900, saving model to best_fashion.keras
...
Epoch 21/40
Epoch 21: val_accuracy did not improve from 0.88900
308/308 - 2s - loss: 0.2412 - accuracy: 0.9147 - val_loss: 0.3421 - val_accuracy: 0.8852

(Your exact numbers will differ, since dropout and shuffling are stochastic, but the shape will match. The 308 steps per epoch are the 39,399 training images divided into batches of 128.)

Dissect the setup:

  • ModelCheckpoint with save_best_only=True writes a full .keras file (architecture + weights + optimizer state + our normalizer stats) every time val_accuracy sets a new record, and only then. The monitored quantity’s name must match a logged metric exactly: monitor "val_acc" by mistake and you get a warning per epoch and no checkpoints.
  • EarlyStopping(patience=5) kills the run after 5 epochs without val_loss improvement, and restore_best_weights=True rolls the in-memory model back to its best epoch. Belt and suspenders with the checkpoint file: the file survives a crash, the restore fixes the live object. Note that the two callbacks watch different metrics here (val_accuracy for the file, val_loss for the restore), so they can settle on different epochs; the final evaluation in this lesson uses the file.
  • In PyTorch this whole block is code you write yourself: track best_val, torch.save(model.state_dict(), ...) inside an if, a patience counter, a manual re-load. Keras callbacks are that loop’s hooks, prepackaged. Day 4’s custom-loop skills tell you what’s inside the box; today you get to just use the box.
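For a feel of what is inside the box, here is a framework-free sketch of the bookkeeping: track the best validation loss, count epochs without improvement, stop when patience runs out. The function name and the list of validation losses are invented for illustration; a real loop would train and evaluate where this one just reads a number.

```python
def run_with_patience(val_losses, patience=5):
    """Return (best_epoch, stop_epoch), both 1-based, for a sequence of val losses."""
    best_loss = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # New record: this is where a checkpoint would be written.
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                # Patience exhausted: this is where best weights would be restored.
                return best_epoch, epoch
    return best_epoch, len(val_losses)

fake_val_losses = [0.45, 0.40, 0.37, 0.35, 0.36, 0.36, 0.37, 0.38, 0.39, 0.40]
print(run_with_patience(fake_val_losses))   # (4, 9)
```

The best epoch is 4, and the run stops at epoch 9 after five epochs in a row without a new minimum.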

Evaluation: learning curves, then the confusion matrix

Reading the curves

fit() returned a History object whose .history dict holds one list per metric per epoch. Plot loss and accuracy side by side, train against validation:

import matplotlib.pyplot as plt

hist = history.history
epochs = range(1, len(hist["loss"]) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))
ax1.plot(epochs, hist["loss"], label="train")
ax1.plot(epochs, hist["val_loss"], label="val")
ax1.set(title="Loss", xlabel="epoch"); ax1.legend()

ax2.plot(epochs, hist["accuracy"], label="train")
ax2.plot(epochs, hist["val_accuracy"], label="val")
ax2.set(title="Accuracy", xlabel="epoch"); ax2.legend()
fig.tight_layout(); fig.savefig("curves.png", dpi=120)

There are three curve shapes you’ll meet again and again. Learn to recognize them on sight:

[Figure: three learning-curve shapes; solid = train loss, dashed = val loss]

  • Underfitting: both losses high and flat → bigger model, train longer.
  • Healthy: both fall, small steady gap → ship it; checkpoint here.
  • Overfitting: val loss turns back up → regularize (Day 7).

Our run lands between “healthy” and the early stage of “overfitting”: training loss keeps dropping while validation loss bottoms out around epoch 13–16 and drifts up. That drift is precisely why the checkpoint monitors validation, and why EarlyStopping ends the run around epoch 20 instead of wasting 20 more.

One subtlety specific to Keras: the training metrics in history are computed during the epoch (a running average while dropout is active and weights are still moving), whereas validation metrics are computed after the epoch with dropout off. So in epoch 1 you’ll often see val_accuracy > accuracy. It’s not a bug; the two numbers aren’t measured under the same conditions.

The confusion matrix

Overall accuracy is a single number averaged over a balanced test set, which means our starved footwear classes can be quietly terrible without moving it much. The confusion matrix breaks the score open. Load the checkpointed model (not the possibly-overfit final state) and touch the test set for the first and only time:

import tensorflow as tf

best = keras.models.load_model("best_fashion.keras")

test_loss, test_acc = best.evaluate(x_test.astype("float32"), y_test, verbose=0)
print(f"test accuracy: {test_acc:.4f}")          # ≈ 0.88

logits = best.predict(x_test.astype("float32"), verbose=0)  # (10000, 10)
y_pred = logits.argmax(axis=1)                              # (10000,)

cm = tf.math.confusion_matrix(y_test, y_pred, num_classes=10).numpy()

tf.math.confusion_matrix returns a 10×10 integer tensor: rows are true classes, columns are predictions (the same convention as scikit-learn, but always check, because some libraries transpose it). Entry \((i, j)\) counts test images of class \(i\) that the model called class \(j\).
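The rows-are-truth convention is easy to verify by building a matrix by hand. This is a NumPy sketch of the counting rule on a tiny made-up example, not the TensorFlow implementation; confusion_matrix_np is a hypothetical helper name.

```python
import numpy as np

def confusion_matrix_np(y_true, y_pred, num_classes):
    """cm[i, j] = number of examples with true class i predicted as class j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)   # unbuffered add, so repeated pairs accumulate
    return cm

y_true_demo = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred_demo = np.array([0, 1, 1, 1, 2, 0, 2])

cm_demo = confusion_matrix_np(y_true_demo, y_pred_demo, num_classes=3)
print(cm_demo.tolist())   # [[1, 1, 0], [0, 2, 0], [1, 0, 2]]
```

Row 2 reads: three true class-2 examples, two predicted correctly and one called class 0. Each row sums to the number of true examples of that class.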

[Figure: annotated confusion matrix; columns = predicted class, rows = true class]

  • Diagonal = correct predictions; row-normalized, it gives recall.
  • Cell (Shirt, Coat) = 137: 137 shirts predicted as coats, a specific, fixable failure.
  • Column sums → how often a class is predicted (precision’s denominator).

Print it with names, plus per-class recall (the diagonal divided by row sums):

recall = cm.diagonal() / cm.sum(axis=1)
for name, r in zip(CLASS_NAMES, recall):
    bar = "█" * int(r * 30)
    print(f"{name:>12} {r:6.1%} {bar}")
 T-shirt/top  84.3% █████████████████████████
     Trouser  97.4% █████████████████████████████
    Pullover  81.0% ████████████████████████
       Dress  89.1% ██████████████████████████
        Coat  83.2% ████████████████████████
      Sandal  93.8% ████████████████████████████
       Shirt  67.5% ████████████████████
     Sneaker  92.6% ███████████████████████████
         Bag  96.5% ████████████████████████████
  Ankle boot  94.1% ████████████████████████████

Two findings jump out, and neither is visible in the headline 88%:

  • The class weights worked. Despite training on only 300 examples each, the footwear classes recover to ~93% recall. Rerun without class_weight (do it; it is one argument to delete) and watch Sandal/Sneaker/Ankle-boot recall crater into the 60s–70s while overall accuracy drops only ~3 points. That gap between “accuracy barely moved” and “a third of sneakers misclassified” is the entire argument for per-class evaluation.
  • Shirt is the real problem child, and it was never rare. Look along the Shirt row: its errors flow into T-shirt/top, Pullover, and Coat. These classes genuinely overlap visually at 28×28 grayscale, and a fully-connected net that destroyed all spatial structure with Flatten can’t tell a collar from a crew neck. No amount of reweighting fixes confusion between classes the features can’t separate. That’s a representation problem, which is precisely tomorrow’s topic.
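“Look along the row” can be automated: zero the diagonal and rank the remaining cells to list the biggest error flows, and divide the diagonal by column sums for per-class precision. The 3×3 matrix and class subset below are invented for illustration and are not this lesson’s results.

```python
import numpy as np

names = ["Shirt", "Coat", "Bag"]            # hypothetical 3-class subset
cm_small = np.array([[70, 25,  5],          # rows = true class
                     [10, 85,  5],          # columns = predicted class
                     [ 2,  3, 95]])

recall = cm_small.diagonal() / cm_small.sum(axis=1)      # diagonal / row sums
precision = cm_small.diagonal() / cm_small.sum(axis=0)   # diagonal / column sums

# Rank the off-diagonal cells to find the largest confusions.
off = cm_small.copy()
np.fill_diagonal(off, 0)
order = np.argsort(off, axis=None)[::-1][:2]             # two largest cells
rows, cols = np.unravel_index(order, off.shape)
top = [(names[i], names[j], int(off[i, j])) for i, j in zip(rows, cols)]

print(top)   # [('Shirt', 'Coat', 25), ('Coat', 'Shirt', 10)]
```

Recall and precision answer different questions: here Coat has decent recall (0.85) but the weakest precision, because 25 shirts were filed under it.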

A heatmap makes the error flows obvious at a glance:

fig, ax = plt.subplots(figsize=(7, 6))
im = ax.imshow(cm, cmap="Blues")
ax.set_xticks(range(10), CLASS_NAMES, rotation=45, ha="right")
ax.set_yticks(range(10), CLASS_NAMES)
ax.set(xlabel="predicted", ylabel="true")
for i in range(10):
    for j in range(10):
        if cm[i, j] > 0:
            ax.text(j, i, cm[i, j], ha="center", va="center",
                    color="white" if cm[i, j] > cm.max() / 2 else "black",
                    fontsize=7)
fig.colorbar(im); fig.tight_layout(); fig.savefig("confusion.png", dpi=120)

The complete script

Everything above, assembled in dependency order. Save as day5_fashion.py; it runs end to end in a few minutes on CPU.

"""Day 5 - Fashion-MNIST classification, end to end."""
import numpy as np
import tensorflow as tf
import keras
from keras import layers
import matplotlib.pyplot as plt

keras.utils.set_random_seed(42)

CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# ---- 1. data: load, split, manufacture imbalance -------------------
(x_full, y_full), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_val, y_val = x_full[-5000:], y_full[-5000:]
x_train, y_train = x_full[:-5000], y_full[:-5000]

rng = np.random.default_rng(42)
keep = np.ones(len(y_train), dtype=bool)
for cls in (5, 7, 9):                                  # starve footwear
    idx = np.where(y_train == cls)[0]
    keep[rng.choice(idx, size=len(idx) - 300, replace=False)] = False
x_train, y_train = x_train[keep], y_train[keep]

x_train = x_train.astype("float32")
x_val   = x_val.astype("float32")
x_test  = x_test.astype("float32")

# ---- 2. class weights ----------------------------------------------
counts = np.bincount(y_train)
class_weight = {c: len(y_train) / (len(counts) * counts[c])
                for c in range(len(counts))}

# ---- 3. normalization: stats from TRAIN only, stored in the model --
normalizer = layers.Normalization(axis=None)
normalizer.adapt(x_train)

# ---- 4. model -------------------------------------------------------
inputs = keras.Input(shape=(28, 28))
x = layers.Flatten()(normalizer(inputs))
x = layers.Dropout(0.3)(layers.Dense(256, activation="relu")(x))
x = layers.Dropout(0.3)(layers.Dense(128, activation="relu")(x))
outputs = layers.Dense(10)(x)                          # logits
model = keras.Model(inputs, outputs, name="fashion_mlp")

model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)

# ---- 5. train with best-epoch checkpointing ------------------------
history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=40, batch_size=128,
    class_weight=class_weight,
    callbacks=[
        keras.callbacks.ModelCheckpoint("best_fashion.keras",
                                        monitor="val_accuracy",
                                        save_best_only=True),
        keras.callbacks.EarlyStopping(monitor="val_loss", patience=5,
                                      restore_best_weights=True),
    ],
    verbose=2,
)

# ---- 6. learning curves --------------------------------------------
h, ep = history.history, range(1, len(history.history["loss"]) + 1)
fig, (a1, a2) = plt.subplots(1, 2, figsize=(11, 4))
a1.plot(ep, h["loss"], label="train"); a1.plot(ep, h["val_loss"], label="val")
a1.set(title="Loss", xlabel="epoch"); a1.legend()
a2.plot(ep, h["accuracy"], label="train"); a2.plot(ep, h["val_accuracy"], label="val")
a2.set(title="Accuracy", xlabel="epoch"); a2.legend()
fig.tight_layout(); fig.savefig("curves.png", dpi=120)

# ---- 7. final evaluation: best model, test set, once ---------------
best = keras.models.load_model("best_fashion.keras")
_, test_acc = best.evaluate(x_test, y_test, verbose=0)
print(f"\ntest accuracy: {test_acc:.4f}")

y_pred = best.predict(x_test, verbose=0).argmax(axis=1)
cm = tf.math.confusion_matrix(y_test, y_pred, num_classes=10).numpy()
recall = cm.diagonal() / cm.sum(axis=1)
for name, r in zip(CLASS_NAMES, recall):
    print(f"{name:>12} {r:6.1%}")

# sanity checks: the run is broken if any of these fail
assert test_acc > 0.80, "model failed to train"
assert recall[7] > 0.80, "class weights failed: Sneaker recall collapsed"
assert best.layers[1].count_params() == 3  # normalizer stats travel with the model

The three asserts at the bottom are worth keeping in every training script you write: an accuracy floor, a check on the specific thing you engineered for (rare-class recall), and a structural check that the preprocessing is inside the model. Silent regressions die loudly.

πŸ§ͺ Your task

The class-weight formula treated the imbalance as something to fix in the loss. The other classic remedy fixes it in the data: oversampling β€” repeat the rare examples so each batch sees them more often.

Your task: instead of passing class_weight to fit(), build an oversampled training set in NumPy β€” repeat each rare class’s 300 examples until every class has roughly the same count β€” train on it (no class_weight argument), and compare per-class recall on the test set against today’s class-weighted run. Which footwear class benefits more from which method? Is overall accuracy different?

Hint: np.repeat on indices does the heavy lifting: get each rare class’s indices, tile them up to ~5,500 with np.resize(idx, 5500), concatenate with all other indices, then shuffle the combined index array before slicing x_train/y_train β€” otherwise all repeated sandals land in consecutive batches and training destabilizes.
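If the difference between tiling and repeating is hazy, this tiny example on made-up indices shows it. np.resize cycles through the array until the target length is reached; np.repeat duplicates each element in place.

```python
import numpy as np

idx_demo = np.array([10, 11, 12])

tiled = np.resize(idx_demo, 8)       # cycle through the array up to length 8
repeated = np.repeat(idx_demo, 2)    # each element twice, length fixed at 2 * len

print(tiled.tolist())      # [10, 11, 12, 10, 11, 12, 10, 11]
print(repeated.tolist())   # [10, 10, 11, 11, 12, 12]
```

np.resize is handy here because the target count does not need to be a multiple of 300.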

Solution
import numpy as np
import keras
from keras import layers
import tensorflow as tf

keras.utils.set_random_seed(42)

# --- rebuild the imbalanced split exactly as in the lesson ---
(x_full, y_full), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_val, y_val = x_full[-5000:].astype("float32"), y_full[-5000:]
x_train, y_train = x_full[:-5000], y_full[:-5000]

rng = np.random.default_rng(42)
keep = np.ones(len(y_train), dtype=bool)
for cls in (5, 7, 9):
    idx = np.where(y_train == cls)[0]
    keep[rng.choice(idx, size=len(idx) - 300, replace=False)] = False
x_train, y_train = x_train[keep].astype("float32"), y_train[keep]

# --- oversample: tile rare-class indices up to the majority count ---
target = int(np.bincount(y_train).max())        # ~5500
parts = []
for cls in range(10):
    idx = np.where(y_train == cls)[0]
    if len(idx) < target:
        idx = np.resize(idx, target)            # cycles through the 300, repeating
    parts.append(idx)
all_idx = rng.permutation(np.concatenate(parts))  # SHUFFLE: critical
x_bal, y_bal = x_train[all_idx], y_train[all_idx]
print(np.bincount(y_bal))   # every class == target (~5500): balanced by repetition

# --- same model as the lesson, NO class_weight ---
normalizer = layers.Normalization(axis=None)
normalizer.adapt(x_train)                       # stats from ORIGINAL train data

inputs = keras.Input(shape=(28, 28))
x = layers.Flatten()(normalizer(inputs))
x = layers.Dropout(0.3)(layers.Dense(256, activation="relu")(x))
x = layers.Dropout(0.3)(layers.Dense(128, activation="relu")(x))
model = keras.Model(inputs, layers.Dense(10)(x))

model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)

model.fit(
    x_bal, y_bal,
    validation_data=(x_val, y_val),
    epochs=40, batch_size=128,
    callbacks=[
        keras.callbacks.ModelCheckpoint("best_oversampled.keras",
                                        monitor="val_accuracy",
                                        save_best_only=True),
        keras.callbacks.EarlyStopping(monitor="val_loss", patience=5,
                                      restore_best_weights=True),
    ],
    verbose=2,
)

# --- compare per-class recall against the class-weighted run ---
best = keras.models.load_model("best_oversampled.keras")
y_pred = best.predict(x_test.astype("float32"), verbose=0).argmax(axis=1)
cm = tf.math.confusion_matrix(y_test, y_pred, num_classes=10).numpy()
recall = cm.diagonal() / cm.sum(axis=1)

CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
for name, r in zip(CLASS_NAMES, recall):
    print(f"{name:>12} {r:6.1%}")

What you should observe: the two methods land within a couple of points of each other on footwear recall and on overall accuracy, which is the real lesson. Mathematically, repeating a sample \(k\) times and weighting its loss by \(k\) produce the same expected gradient. They differ in variance (oversampling shows the literal same 300 images over and over, so the model can memorize them; watch the train/val gap grow a bit faster) and in cost (oversampling makes each epoch ~40% longer here, since the dataset physically grew from about 39,400 to about 55,100 examples). class_weight is one argument and zero data copies; it’s the default choice. Oversampling earns its keep when you combine it with data augmentation, so each repetition is a different distorted view. That trick arrives on Day 7.
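The “same gradient” claim can be checked numerically. The sketch below swaps in a toy linear model with squared error (chosen for brevity, not the lesson’s cross-entropy MLP) and compares the full-batch gradient when one example is physically repeated k times against the gradient when it appears once with weight k. All names and data are invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=3)                 # parameters of a toy linear model
X = rng.normal(size=(4, 3))            # four examples, three features
y = rng.normal(size=4)

def grad_weighted_sse(w, X, y, weights=None):
    """Gradient of sum_i weights_i * (x_i . w - y_i)^2 with respect to w."""
    if weights is None:
        weights = np.ones(len(y))
    residual = X @ w - y
    return 2 * (weights * residual) @ X

k = 3
# Option A: oversample, i.e. physically repeat the last example k times.
X_rep = np.vstack([X[:-1], np.repeat(X[-1:], k, axis=0)])
y_rep = np.concatenate([y[:-1], np.repeat(y[-1:], k)])
g_oversampled = grad_weighted_sse(w, X_rep, y_rep)

# Option B: keep one copy and weight its loss by k.
g_weighted = grad_weighted_sse(w, X, y, weights=np.array([1.0, 1.0, 1.0, k]))

print(np.allclose(g_oversampled, g_weighted))   # True
```

With mini-batches the equality holds only in expectation, which is where the variance difference described above comes from.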

Key takeaways

  • Split before you compute anything: adapt(), class counts, and imbalance surgery all happen on the training split only; validation and test stay pristine.
  • layers.Normalization + adapt() welds preprocessing statistics into the model as non-trainable weights. The exported model on Day 9 will carry them, so the “forgot to normalize at inference” bug becomes impossible.
  • class_weight={c: N/(C·n_c)} in fit() scales each example’s loss by its class weight. It is the Keras counterpart of PyTorch’s CrossEntropyLoss(weight=...), and it rescued 300-example classes to ~93% recall here.
  • Output raw logits and use from_logits=True; Sparse losses/metrics for integer labels. The wrong combination either crashes (shape error) or, worse, silently trains a weaker model.
  • ModelCheckpoint(save_best_only=True) + EarlyStopping(restore_best_weights=True): ship the validation peak, not the overfit tail.
  • Overall accuracy averages away per-class failure; the confusion matrix and per-class recall show which classes fail and where the errors go. Here they revealed that our worst class (Shirt) isn’t a data-quantity problem but a representation problem: Flatten threw away the spatial structure needed to distinguish it.

Tomorrow we stop flattening: convolutional layers keep the 2D structure of the image, and the shirt-versus-pullover confusion that stumped today’s MLP starts to melt.


🏠 πŸ“Š Course home  |  ← Day 04  |  Day 06 β†’  |  πŸ“š All mini-courses

 

© Kader Mohideen