flowchart LR
A[Raw data<br/>uint8 pixels] --> B[Split<br/>train / val / test]
B --> C[adapt<br/>Normalization layer<br/>on TRAIN only]
B --> D[Count classes →<br/>class_weight dict]
C --> E[Model = normalizer<br/>+ MLP + logits]
D --> F[fit with class_weight<br/>+ ModelCheckpoint<br/>+ EarlyStopping]
E --> F
F --> G[Learning curves:<br/>is it healthy?]
F --> H[best_fashion.keras<br/>best epoch, not last]
H --> I[Confusion matrix<br/>on TEST set]
Deep Learning with TensorFlow & Keras · Day 5 · Classification End to End: From Raw Pixels to a Confusion Matrix
Over the last four days you built every component in isolation: tensors and GradientTape on Day 1, three ways to define a model on Day 2, tf.data pipelines on Day 3, and the compile/fit machinery versus hand-rolled loops on Day 4. Today you assemble the whole thing into one honest, production-shaped project: an MLP classifier on Fashion-MNIST that handles the messy realities a tutorial usually skips. That means normalization that lives inside the model, a class imbalance you have to compensate for, learning curves you can actually read, a confusion matrix that tells you where the model fails, and a checkpoint of the best epoch (not the last one). By the end you'll have a single runnable script that serves as a template for every classification project you do after this course.
Today you will: bake a Normalization layer into the model with adapt(), correct a skewed dataset with class_weight, checkpoint the best epoch with ModelCheckpoint, diagnose training health from learning curves, and read a confusion matrix like a radiologist reads an X-ray.
The shape of a real training run
Most tutorials show you fit() and stop. A real classification project has a pipeline with more stations, and skipping any one of them produces a model that looks fine on paper and embarrasses you in production. The flowchart at the top of this lesson maps the full route we travel today.
Two decisions in this diagram are the ones people get wrong most often, so let's name them up front:
- Statistics flow from train data only. The Normalization layer's mean and standard deviation come from the training split. If you adapt() on the full dataset, information about the test set leaks into your model. On Fashion-MNIST the damage is negligible; on a small medical dataset it can inflate your measured accuracy by several points that evaporate in the real world.
- The saved model is the best epoch, not the final one. Validation accuracy peaks and then decays as the model overfits. ModelCheckpoint(save_best_only=True) freezes the peak.
Data: load, split, and manufacture an imbalance
Fashion-MNIST ships with Keras, so loading is one line. It's 70,000 grayscale 28×28 images across 10 clothing classes: a drop-in replacement for MNIST that is hard enough to make mistakes visible.
import numpy as np
import keras
(x_train_full, y_train_full), (x_test, y_test) = \
    keras.datasets.fashion_mnist.load_data()
print(x_train_full.shape, x_train_full.dtype) # (60000, 28, 28) uint8
print(y_train_full.shape, y_train_full.dtype) # (60000,) uint8
print(x_train_full.min(), x_train_full.max()) # 0 255
CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
Note what you got: raw uint8 in [0, 255], not floats in [0, 1]. We are deliberately not going to divide by 255 here in NumPy: that normalization decision belongs inside the model, and you'll see why in the next section.
First, carve out a validation set. The test set is sacred: you touch it exactly once, at the very end. The validation set is your working thermometer during training:
VAL_SIZE = 5000
x_val, y_val = x_train_full[-VAL_SIZE:], y_train_full[-VAL_SIZE:]
x_train, y_train = x_train_full[:-VAL_SIZE], y_train_full[:-VAL_SIZE]
The order matters: split before you compute any statistic or reshape any distribution. The validation set must look like the world you'll be evaluated on: balanced and untouched.
Fashion-MNIST is perfectly balanced (6,000 training images per class), which is lovely and completely unrepresentative of real life. Fraud detection, defect inspection, rare-disease screening: real classification problems are skewed, often violently. So we'll manufacture an imbalance in the training set to learn the tools for dealing with one. We gut the three footwear classes down to 300 examples each:
rng = np.random.default_rng(42)
RARE = {5: 300, 7: 300, 9: 300} # Sandal, Sneaker, Ankle boot
keep = np.ones(len(y_train), dtype=bool)
for cls, n_keep in RARE.items():
    idx = np.where(y_train == cls)[0]
    drop = rng.choice(idx, size=len(idx) - n_keep, replace=False)
    keep[drop] = False
x_train, y_train = x_train[keep], y_train[keep]
print(np.bincount(y_train))
[5500 5493 5504 5514 5488  300 5514  300 5486  300]
The mechanics: np.where(y_train == cls)[0] gives the row indices of one class, rng.choice(..., replace=False) samples which ones to delete, and the boolean keep mask applies all deletions in a single vectorized pass. Never delete rows in a Python loop with np.delete: it re-allocates the entire array on every call.
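A toy check of the mask mechanics, on a made-up label array rather than Fashion-MNIST: marking rows in a boolean keep mask and slicing once gives exactly the same result as a single np.delete call with all the collected indices.

```python
import numpy as np

# Toy stand-in for y_train: 3 classes, 6 rows each (made-up data).
y = np.repeat(np.arange(3), 6)
rng = np.random.default_rng(0)

keep = np.ones(len(y), dtype=bool)
dropped = []
for cls, n_keep in {2: 2}.items():          # starve class 2 down to 2 rows
    idx = np.where(y == cls)[0]             # row indices of this class
    drop = rng.choice(idx, size=len(idx) - n_keep, replace=False)
    keep[drop] = False                      # only marks rows; nothing is copied yet
    dropped.append(drop)

y_masked = y[keep]                                   # one allocation for all deletions
y_deleted = np.delete(y, np.concatenate(dropped))    # equivalent single np.delete call
print(np.bincount(y_masked))                         # [6 6 2]
```

Collecting indices first and deleting once is what keeps the cost linear; calling np.delete once per row would copy the array every time.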
Now footwear is ~18× rarer than everything else in training, while validation and test remain balanced. A naive model will learn that predicting "Sneaker" is rarely worth the risk, and its overall accuracy will barely show it, because sneakers are only 10% of the test set. This is exactly the trap where accuracy lies and a confusion matrix tells the truth.
Normalization as a layer, not a preprocessing script
The classic way to normalize is x = (x - mean) / std in NumPy before training. It works, until you deploy. Six months later, someone feeds the served model raw [0, 255] pixels because the normalization constants lived in a preprocessing script that didn't ship with the model. This is one of the most common production ML bugs in existence.
Keras' answer is keras.layers.Normalization: a layer that stores the statistics as part of the model. Whatever you export on Day 9 will contain the mean and variance inside the graph. There is no preprocessing script to lose.
from keras import layers
normalizer = layers.Normalization(axis=None)
normalizer.adapt(x_train.astype("float32"))
print(f"{float(normalizer.mean):.3f} {float(normalizer.variance) ** 0.5:.3f}")
72.939 90.020
Three things deserve a close look:
- axis=None means "compute one scalar mean and one scalar variance over everything". That's right for grayscale images, where every pixel shares one intensity distribution. For tabular data with, say, 13 features of wildly different scales, you'd use axis=-1 to get 13 per-feature means. For RGB images you'd normalize per channel (axis=-1 on a (height, width, 3) input). Getting this axis wrong is silent: the model still trains, just worse.
- adapt() is not training. It's a single statistics pass: the layer iterates over the data once and stores mean and variance as non-trainable weights. No gradients, no epochs. This is the Keras 3 idiom for any stateful preprocessing layer (StringLookup, TextVectorization, and Discretization all work the same way). The PyTorch equivalent is computing x.mean()/x.std() yourself and hard-coding them into a transforms.Normalize; Keras just formalizes the pattern and welds the result to the model.
- We adapted on x_train only, the post-imbalance training split. Validation and test pixels never influence these numbers.
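A minimal NumPy sketch of the statistics adapt() stores, on made-up arrays rather than Fashion-MNIST: axis=None reduces everything to one scalar mean/variance pair, while axis=-1 keeps one pair per feature. It ignores the small epsilon Keras uses to guard the division, so treat it as an illustration of the axis semantics, not a re-implementation of the layer.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up stand-ins: "images" share one intensity scale, "tabular" columns do not.
images = rng.integers(0, 256, size=(100, 28, 28)).astype("float32")
tabular = rng.normal(loc=[0.0, 50.0, 1e4], scale=[1.0, 10.0, 2e3], size=(500, 3))

# axis=None: one scalar mean and one scalar variance over every value.
img_mean, img_var = images.mean(), images.var()
img_norm = (images - img_mean) / np.sqrt(img_var)

# axis=-1: one mean and variance per feature (last axis), broadcast over rows.
tab_mean, tab_var = tabular.mean(axis=0), tabular.var(axis=0)
tab_norm = (tabular - tab_mean) / np.sqrt(tab_var)

print(img_mean.shape, tab_mean.shape)   # () (3,)
print(round(float(img_norm.std()), 3), tab_norm.std(axis=0).round(3))
```

Normalizing the tabular array with a single scalar instead would leave the third column dominating everything, which is the silent failure the axis bullet warns about.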
In every forward pass, training and inference alike, the layer computes \(\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}\), turning our [0, 255] pixels into roughly zero-mean, unit-variance inputs: the regime where gradient descent with default learning rates behaves well. Feed a raw-pixel MLP the unnormalized values and its first-epoch loss can start around 60 instead of 2.3 (that is, ln 10, the loss of a uniform guess over 10 classes), because the initial logits are driven by inputs whose scale is ~90× larger than the standardized ones.
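The 2.3 baseline is ln 10, the cross-entropy of a model that assigns equal probability to all 10 classes. The NumPy sketch below shows the direction of the effect using random made-up pixels and a single randomly initialized 784 → 10 linear layer with Glorot-uniform-style weights (an assumption standing in for the real MLP): the same weights give a loss near the baseline on standardized inputs and a far larger one on raw 0-255 values. It does not try to reproduce the figure of about 60, which depends on the actual architecture.

```python
import numpy as np

def cross_entropy(logits, labels):
    # Stable softmax cross-entropy: logsumexp(z) - z[label], averaged over rows.
    m = logits.max(axis=1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return float(np.mean(lse - logits[np.arange(len(labels)), labels]))

rng = np.random.default_rng(0)
n, d, k = 1000, 784, 10
labels = rng.integers(0, k, size=n)
raw = rng.integers(0, 256, size=(n, d)).astype(np.float64)   # made-up "pixels"
normed = (raw - raw.mean()) / raw.std()                      # scalar standardization

# Glorot-uniform-style weights for one linear layer (assumption, not the lesson's MLP).
limit = np.sqrt(6.0 / (d + k))
W = rng.uniform(-limit, limit, size=(d, k))

uniform_loss = cross_entropy(np.zeros((n, k)), labels)   # all-equal logits -> ln 10
norm_loss = cross_entropy(normed @ W, labels)
raw_loss = cross_entropy(raw @ W, labels)
print(f"uniform {uniform_loss:.3f}  normalized {norm_loss:.2f}  raw {raw_loss:.1f}")
```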
Class weights, the model, and checkpointing the best epoch
Weighting the loss
With 300 sneakers versus ~5,500 shirts, each gradient step sees ~18 shirt examples for every sneaker. The standard fix is to scale each example's loss by a per-class weight so rare classes shout louder. The usual "balanced" heuristic:
\[w_c = \frac{N}{C \cdot n_c}\]
where \(N\) is the total number of training examples, \(C\) the number of classes, and \(n_c\) the count of class \(c\). A class at exactly average frequency gets weight 1.0; with the counts above (\(N = 39{,}399\)), our rare classes get ~13.1.
counts = np.bincount(y_train)
n_classes = len(counts)
class_weight = {c: len(y_train) / (n_classes * counts[c])
                for c in range(n_classes)}
for c in (0, 5):
    print(f"{CLASS_NAMES[c]:>12}: n={counts[c]:>5} w={class_weight[c]:.2f}")
 T-shirt/top: n= 5500 w=0.72
      Sandal: n=  300 w=13.13
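A quick arithmetic check of the heuristic, using the class counts printed by np.bincount earlier in the lesson: after weighting, every class contributes the same total weight \(N/C\) per epoch, and a rare class's weight relative to a common one is simply the inverse of their count ratio.

```python
import numpy as np

# Per-class counts printed by np.bincount(y_train) earlier in the lesson.
counts = np.array([5500, 5493, 5504, 5514, 5488, 300, 5514, 300, 5486, 300])
N, C = counts.sum(), len(counts)
w = N / (C * counts)                     # w_c = N / (C * n_c)

# n_c * w_c = N / C: every class carries the same total weight per epoch.
per_class_total = counts * w
print(N, per_class_total.round(6))
print(round(w[5] / w[6], 2))             # 18.38 (Sandal vs Shirt = 5514 / 300)
```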
In Keras this is a plain dict passed to fit(class_weight=...): Keras multiplies each sample's loss by the weight of its true class before averaging. The PyTorch equivalent is nn.CrossEntropyLoss(weight=torch.tensor([...])). The per-class scaling is the same, although PyTorch's default mean reduction divides by the sum of the batch's weights rather than by the batch size. Keras attaches the weights to the training call rather than to the loss object, which means your evaluation loss stays unweighted, which is usually what you want.
The effect on the loss: instead of \(\mathcal{L} = \frac{1}{B}\sum_i \ell_i\), each batch computes \(\mathcal{L} = \frac{1}{B}\sum_i w_{y_i}\, \ell_i\). A sandal example now weighs about 18 times as much as a shirt example, exactly offsetting their 18:1 frequency ratio.
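The choice of denominator matters when comparing frameworks. This NumPy sketch, on a made-up batch with made-up weights, computes the weighted batch loss as the formula above writes it (divide by the batch size \(B\)), the way PyTorch's CrossEntropyLoss with weight=... and the default mean reduction normalizes it (divide by the sum of the weights in the batch), and the unweighted mean that evaluation reports.

```python
import numpy as np

# Made-up per-example losses and labels for a batch of 4 (classes 0 and 5).
losses = np.array([0.2, 0.4, 1.5, 0.3])
labels = np.array([0, 0, 5, 0])
class_weight = {0: 1.0, 5: 10.0}         # illustrative weights, not the lesson's values

w = np.array([class_weight[int(y)] for y in labels])
per_batch_size = (w * losses).sum() / len(losses)   # (1/B) * sum_i w_{y_i} * l_i
per_weight_sum = (w * losses).sum() / w.sum()       # PyTorch-style weighted mean
unweighted = losses.mean()                          # plain mean, no class weights
print(per_batch_size, per_weight_sum, unweighted)
```

The two weighted versions differ only by the batch-dependent factor \(\sum_i w_{y_i} / B\), so both push the same direction per class; their magnitudes, and therefore effective learning rates, are not identical.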
The model
A straightforward MLP, functional style (Day 2), with the normalizer as its first layer:
def build_model(normalizer: layers.Normalization) -> keras.Model:
    inputs = keras.Input(shape=(28, 28))
    x = normalizer(inputs)                  # stats baked into the graph
    x = layers.Flatten()(x)                 # (batch, 28, 28) -> (batch, 784)
    x = layers.Dense(256, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(10)(x)           # logits, no softmax here
    return keras.Model(inputs, outputs, name="fashion_mlp")
model = build_model(normalizer)
model.summary()
Model: "fashion_mlp"
+----------------------------+-------------------+-----------+
| Layer (type)               | Output Shape      |   Param # |
+============================+===================+===========+
| input_layer (InputLayer)   | (None, 28, 28)    |         0 |
| normalization (Normalizat… | (None, 28, 28)    |         3 |
| flatten (Flatten)          | (None, 784)       |         0 |
| dense (Dense)              | (None, 256)       |   200,960 |
| dropout (Dropout)          | (None, 256)       |         0 |
| dense_1 (Dense)            | (None, 128)       |    32,896 |
| dropout_1 (Dropout)        | (None, 128)       |         0 |
| dense_2 (Dense)            | (None, 10)        |     1,290 |
+----------------------------+-------------------+-----------+
Total params: 235,149 (918.55 KB)
Trainable params: 235,146 (918.54 KB)
Non-trainable params: 3 (16.00 B)
Those 3 non-trainable params are the normalizer's mean, variance, and count: proof that the statistics live in the model. And notice the output layer has no softmax: we emit raw logits and tell the loss about it. Computing log(softmax(x)) as two separate ops is numerically unstable for extreme logits; from_logits=True fuses them into a stable log-sum-exp, exactly like PyTorch's CrossEntropyLoss (which never wants a softmax in the model either).
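A two-class NumPy illustration of that instability, using made-up logits rather than the model's: taking softmax and then log as separate steps overflows to an infinite loss, while the fused log-sum-exp form (the same idea the from_logits=True path relies on) stays finite and exact.

```python
import numpy as np

logits = np.array([1000.0, 0.0])    # one extreme logit (made-up)
label = 1

with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
    # Two separate ops: exp(1000) overflows to inf, so p[label] becomes 0.
    probs = np.exp(logits) / np.exp(logits).sum()
    naive = -np.log(probs[label])

# Fused form: cross-entropy = logsumexp(z) - z[label], with the max factored out.
m = logits.max()
stable = (m + np.log(np.exp(logits - m).sum())) - logits[label]
print(naive, stable)                # inf 1000.0
```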
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)
Sparse because our labels are integer class ids (3), not one-hot vectors ([0,0,0,1,...]). Using CategoricalCrossentropy with integer labels raises a shape error, one of the two most common Keras beginner mistakes. The other, forgetting from_logits=True, is worse because it doesn't crash at all: it silently trains a mediocre model, because Keras then treats your raw logits as probabilities.
Callbacks: save the peak, stop the decay
callbacks = [
    keras.callbacks.ModelCheckpoint(
        "best_fashion.keras",
        monitor="val_accuracy",
        save_best_only=True,
        verbose=1,
    ),
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=5,
        restore_best_weights=True,
    ),
]
history = model.fit(
    x_train.astype("float32"), y_train,
    validation_data=(x_val.astype("float32"), y_val),
    epochs=40,
    batch_size=128,
    class_weight=class_weight,
    callbacks=callbacks,
    verbose=2,
)
Epoch 1/40
Epoch 1: val_accuracy improved from -inf to 0.84280, saving model to best_fashion.keras
308/308 - 3s - loss: 0.7842 - accuracy: 0.7676 - val_loss: 0.4494 - val_accuracy: 0.8428
...
Epoch 14/40
Epoch 14: val_accuracy improved from 0.88700 to 0.88900, saving model to best_fashion.keras
...
Epoch 21/40
Epoch 21: val_accuracy did not improve from 0.88900
308/308 - 2s - loss: 0.2412 - accuracy: 0.9147 - val_loss: 0.3421 - val_accuracy: 0.8852
(Your exact numbers will differ, since dropout and shuffling are stochastic, but the shape will match.)
Dissect the setup:
- ModelCheckpoint with save_best_only=True writes a full .keras file (architecture + weights + optimizer state + our normalizer stats) every time val_accuracy sets a new record, and only then. The monitored quantity's name must match a logged metric exactly: monitor "val_acc" by mistake and you get a warning every epoch and no checkpoints.
- EarlyStopping(patience=5) stops the run after 5 epochs without val_loss improvement, and restore_best_weights=True rolls the in-memory model back to the epoch with the lowest val_loss. Belt and suspenders with the checkpoint file: the file survives a crash, and the restore fixes the live object. Note that the two callbacks monitor different quantities, so the epoch saved in best_fashion.keras (best val_accuracy) and the epoch restored in memory (best val_loss) can differ.
- In PyTorch this whole block is code you write yourself: track best_val, call torch.save(model.state_dict(), ...) inside an if, keep a patience counter, and re-load manually. Keras callbacks are that loop's hooks, prepackaged. Day 4's custom-loop skills tell you what's inside the box; today you get to just use the box.
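To see how the two callbacks' bookkeeping interacts, here is a plain-Python sketch of roughly the loop you would write by hand in PyTorch. It is a simplified model of the behavior (strict improvement, no min_delta, no baseline), an assumption rather than the Keras source, run on made-up validation curves; simulate_callbacks is a hypothetical helper name.

```python
import math

def simulate_callbacks(val_acc, val_loss, patience=5):
    """Hypothetical, simplified model of ModelCheckpoint(monitor="val_accuracy",
    save_best_only=True) plus EarlyStopping(monitor="val_loss",
    restore_best_weights=True). Epochs are 1-indexed.
    Returns (saved_epoch, stopped_epoch, restored_epoch)."""
    best_acc, saved_epoch = -math.inf, None
    best_loss, restored_epoch, wait = math.inf, None, 0
    for epoch, (acc, loss) in enumerate(zip(val_acc, val_loss), start=1):
        if acc > best_acc:                  # checkpoint: new val_accuracy record
            best_acc, saved_epoch = acc, epoch
        if loss < best_loss:                # early stopping: new val_loss record
            best_loss, restored_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:            # patience exhausted: stop after this epoch
                return saved_epoch, epoch, restored_epoch
    return saved_epoch, len(val_loss), restored_epoch

# Made-up curves: val_loss bottoms out at epoch 4, val_accuracy peaks at epoch 5.
val_loss = [0.50, 0.40, 0.35, 0.33, 0.34, 0.36, 0.37, 0.38, 0.39, 0.40, 0.41]
val_acc = [0.80, 0.84, 0.86, 0.86, 0.87, 0.865, 0.86, 0.86, 0.85, 0.85, 0.84]
print(simulate_callbacks(val_acc, val_loss))   # (5, 9, 4)
```

The file on disk holds epoch 5, the live model is rolled back to epoch 4, and training stops at epoch 9, five epochs after the last val_loss improvement.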
Evaluation: learning curves, then the confusion matrix
Reading the curves
fit() returned a History object whose .history dict holds one list per metric per epoch. Plot loss and accuracy side by side, train against validation:
import matplotlib.pyplot as plt
hist = history.history
epochs = range(1, len(hist["loss"]) + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))
ax1.plot(epochs, hist["loss"], label="train")
ax1.plot(epochs, hist["val_loss"], label="val")
ax1.set(title="Loss", xlabel="epoch"); ax1.legend()
ax2.plot(epochs, hist["accuracy"], label="train")
ax2.plot(epochs, hist["val_accuracy"], label="val")
ax2.set(title="Accuracy", xlabel="epoch"); ax2.legend()
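fig.tight_layout(); fig.savefig("curves.png", dpi=120)
There are three curve shapes you'll meet again and again, and you should learn to recognize them on sight: healthy (training and validation loss fall together and level off close to each other), overfitting (training loss keeps falling while validation loss bottoms out and turns upward), and underfitting (both losses plateau early at a high value).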
Our run lands between "healthy" and the early stage of "overfitting": training loss keeps dropping while validation loss bottoms out around epochs 13 to 16 and then drifts up. That drift is precisely why the checkpoint monitors validation, and why EarlyStopping ends the run around epoch 20 instead of wasting 20 more.
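To turn the visual check into numbers, here is a hypothetical helper (read_curves is not a Keras API) that reads a History.history-style dict, finds the epoch with the lowest val_loss, and flags the overfitting pattern when validation loss has climbed back up while training loss kept falling. The rise_tol threshold of 0.02 is an arbitrary assumption; tune it to your loss scale. The curves below are made up.

```python
import numpy as np

def read_curves(history, rise_tol=0.02):
    """Hypothetical helper: summarize a History.history-style dict.
    rise_tol is an arbitrary threshold for calling a val_loss rise overfitting."""
    train = np.asarray(history["loss"])
    val = np.asarray(history["val_loss"])
    best = int(val.argmin())                    # 0-indexed epoch of lowest val_loss
    rise = float(val[-1] - val[best])           # how far val_loss drifted back up
    train_still_falling = bool(train[-1] < train[best])
    if rise > rise_tol and train_still_falling:
        verdict = "overfitting after the best epoch"
    else:
        verdict = "no clear overfitting"
    return {"best_epoch": best + 1, "val_rise": round(rise, 4), "verdict": verdict}

# Made-up curves with the same shape as the run described above.
hist = {"loss": [0.78, 0.45, 0.39, 0.35, 0.32, 0.30, 0.28, 0.27],
        "val_loss": [0.45, 0.38, 0.35, 0.34, 0.345, 0.35, 0.36, 0.37]}
print(read_curves(hist))
```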
One subtlety specific to Keras: the training metrics in history are computed during the epoch (a running average while dropout is active and the weights are still moving), whereas validation metrics are computed after the epoch with dropout off. So in epoch 1 you'll often see val_accuracy > accuracy. It's not a bug; the two numbers aren't measured under the same conditions.
The confusion matrix
Overall accuracy is a single number averaged over a balanced test set, which means our starved footwear classes can be quietly terrible without moving it much. The confusion matrix breaks the score open. Load the checkpointed model (not the possibly overfit final state) and touch the test set for the first and only time:
import tensorflow as tf
best = keras.models.load_model("best_fashion.keras")
test_loss, test_acc = best.evaluate(x_test.astype("float32"), y_test, verbose=0)
print(f"test accuracy: {test_acc:.4f}")  # ≈ 0.88
logits = best.predict(x_test.astype("float32"), verbose=0) # (10000, 10)
y_pred = logits.argmax(axis=1) # (10000,)
cm = tf.math.confusion_matrix(y_test, y_pred, num_classes=10).numpy()
tf.math.confusion_matrix returns a 10×10 integer tensor: rows are true classes, columns are predictions (the same convention as scikit-learn, but always check, because some libraries transpose it). Entry \((i, j)\) counts test images of class \(i\) that the model called class \(j\).
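To make the row/column convention concrete, this NumPy sketch builds the same kind of matrix by hand on a tiny made-up example. It uses np.add.at so that repeated (true, predicted) pairs accumulate instead of overwriting each other. Reading along a row gives recall for that true class; reading down a column gives the column-wise counterpart, precision for that predicted class.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    # Rows = true class, columns = predicted class.
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)          # unbuffered: repeated pairs all count
    return cm

y_true = np.array([0, 0, 1, 1, 1, 2])           # made-up labels
y_pred = np.array([0, 1, 1, 1, 2, 2])
cm = confusion_matrix(y_true, y_pred, 3)
recall = cm.diagonal() / cm.sum(axis=1)         # per true class (row-wise)
precision = cm.diagonal() / cm.sum(axis=0)      # per predicted class (column-wise)
print(cm)
print(recall.round(3), precision.round(3))
```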
Compute per-class recall (the diagonal divided by the row sums) and print it next to each class name:
recall = cm.diagonal() / cm.sum(axis=1)
for name, r in zip(CLASS_NAMES, recall):
    bar = "█" * int(r * 30)
    print(f"{name:>12} {r:6.1%} {bar}")
 T-shirt/top  84.3% █████████████████████████
     Trouser  97.4% █████████████████████████████
    Pullover  81.0% ████████████████████████
       Dress  89.1% ██████████████████████████
        Coat  83.2% ████████████████████████
      Sandal  93.8% ████████████████████████████
       Shirt  67.5% ████████████████████
     Sneaker  92.6% ███████████████████████████
         Bag  96.5% ████████████████████████████
  Ankle boot  94.1% ████████████████████████████
Two findings jump out, and neither is visible in the headline accuracy number:
- The class weights worked. Despite training on only 300 examples each, the footwear classes recover to ~93% recall. Rerun without class_weight (do it; it's one argument to delete) and watch Sandal/Sneaker/Ankle-boot recall crater into the 60s and 70s while overall accuracy drops by only a few points, because those three classes make up just 30% of a balanced test set. That gap between "accuracy barely moved" and "a third of sneakers misclassified" is the entire argument for per-class evaluation.
- Shirt is the real problem child, and it was never rare. Look along the Shirt row: its errors flow into T-shirt/top, Pullover, and Coat. These classes genuinely overlap visually at 28×28 grayscale, and a fully connected net that only sees the image after Flatten has turned it into 784 numbers has no built-in notion of neighboring pixels, so it struggles to tell a collar from a crew neck. No amount of reweighting fixes confusion between classes the features can't separate. That's a representation problem, which is precisely tomorrow's topic.
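The dilution argument is plain arithmetic. With 1,000 test images per class, overall accuracy equals the mean of the ten per-class recalls, so a drop confined to the three footwear classes moves the headline number by only 3/10 of that drop. This sketch uses the recall table above; the 0.70 "what-if" value is made up for illustration and assumes the other seven classes stay unchanged.

```python
import numpy as np

# Per-class test recalls from the table above (1,000 test images per class).
recall = np.array([0.843, 0.974, 0.810, 0.891, 0.832,
                   0.938, 0.675, 0.926, 0.965, 0.941])
footwear = [5, 7, 9]                    # Sandal, Sneaker, Ankle boot

# On a balanced test set, overall accuracy is exactly the mean per-class recall.
print(f"{recall.mean():.4f}")           # 0.8795

# Hypothetical: footwear recall collapses to 0.70 and nothing else changes.
what_if = recall.copy()
what_if[footwear] = 0.70
drop_footwear = (recall[footwear] - 0.70).mean()
drop_overall = recall.mean() - what_if.mean()
print(f"{drop_footwear:.4f} {drop_overall:.4f}")   # 0.2350 0.0705
```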
A heatmap makes the error flows obvious at a glance:
fig, ax = plt.subplots(figsize=(7, 6))
im = ax.imshow(cm, cmap="Blues")
ax.set_xticks(range(10), CLASS_NAMES, rotation=45, ha="right")
ax.set_yticks(range(10), CLASS_NAMES)
ax.set(xlabel="predicted", ylabel="true")
for i in range(10):
    for j in range(10):
        if cm[i, j] > 0:
            ax.text(j, i, cm[i, j], ha="center", va="center",
                    color="white" if cm[i, j] > cm.max() / 2 else "black",
                    fontsize=7)
fig.colorbar(im); fig.tight_layout(); fig.savefig("confusion.png", dpi=120)
The complete script
Everything above, assembled in dependency order. Save as day5_fashion.py; it runs end to end in a few minutes on CPU.
"""Day 5: Fashion-MNIST classification, end to end."""
import numpy as np
import tensorflow as tf
import keras
from keras import layers
import matplotlib.pyplot as plt
keras.utils.set_random_seed(42)
CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
# ---- 1. data: load, split, manufacture imbalance -------------------
(x_full, y_full), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_val, y_val = x_full[-5000:], y_full[-5000:]
x_train, y_train = x_full[:-5000], y_full[:-5000]
rng = np.random.default_rng(42)
keep = np.ones(len(y_train), dtype=bool)
for cls in (5, 7, 9): # starve footwear
    idx = np.where(y_train == cls)[0]
    keep[rng.choice(idx, size=len(idx) - 300, replace=False)] = False
x_train, y_train = x_train[keep], y_train[keep]
x_train = x_train.astype("float32")
x_val = x_val.astype("float32")
x_test = x_test.astype("float32")
# ---- 2. class weights ----------------------------------------------
counts = np.bincount(y_train)
class_weight = {c: len(y_train) / (len(counts) * counts[c])
                for c in range(len(counts))}
# ---- 3. normalization: stats from TRAIN only, stored in the model --
normalizer = layers.Normalization(axis=None)
normalizer.adapt(x_train)
# ---- 4. model -------------------------------------------------------
inputs = keras.Input(shape=(28, 28))
x = layers.Flatten()(normalizer(inputs))
x = layers.Dropout(0.3)(layers.Dense(256, activation="relu")(x))
x = layers.Dropout(0.3)(layers.Dense(128, activation="relu")(x))
outputs = layers.Dense(10)(x) # logits
model = keras.Model(inputs, outputs, name="fashion_mlp")
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)
# ---- 5. train with best-epoch checkpointing ------------------------
history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=40, batch_size=128,
    class_weight=class_weight,
    callbacks=[
        keras.callbacks.ModelCheckpoint("best_fashion.keras",
                                        monitor="val_accuracy",
                                        save_best_only=True),
        keras.callbacks.EarlyStopping(monitor="val_loss", patience=5,
                                      restore_best_weights=True),
    ],
    verbose=2,
)
# ---- 6. learning curves --------------------------------------------
h, ep = history.history, range(1, len(history.history["loss"]) + 1)
fig, (a1, a2) = plt.subplots(1, 2, figsize=(11, 4))
a1.plot(ep, h["loss"], label="train"); a1.plot(ep, h["val_loss"], label="val")
a1.set(title="Loss", xlabel="epoch"); a1.legend()
a2.plot(ep, h["accuracy"], label="train"); a2.plot(ep, h["val_accuracy"], label="val")
a2.set(title="Accuracy", xlabel="epoch"); a2.legend()
fig.tight_layout(); fig.savefig("curves.png", dpi=120)
# ---- 7. final evaluation: best model, test set, once ---------------
best = keras.models.load_model("best_fashion.keras")
_, test_acc = best.evaluate(x_test, y_test, verbose=0)
print(f"\ntest accuracy: {test_acc:.4f}")
y_pred = best.predict(x_test, verbose=0).argmax(axis=1)
cm = tf.math.confusion_matrix(y_test, y_pred, num_classes=10).numpy()
recall = cm.diagonal() / cm.sum(axis=1)
for name, r in zip(CLASS_NAMES, recall):
    print(f"{name:>12} {r:6.1%}")
# sanity checks β the run is broken if any of these fail
assert test_acc > 0.80, "model failed to train"
assert recall[7] > 0.80, "class weights failed: Sneaker recall collapsed"
assert best.layers[1].count_params() == 3  # normalizer stats travel with the model
The three asserts at the bottom are worth keeping in every training script you write: an accuracy floor, a check on the specific thing you engineered for (rare-class recall), and a structural check that the preprocessing is inside the model. They turn silent regressions into loud failures.
Your task
The class-weight formula treated the imbalance as something to fix in the loss. The other classic remedy fixes it in the data: oversampling, which repeats the rare examples so each batch sees them more often.
Your task: instead of passing class_weight to fit(), build an oversampled training set in NumPy (repeat each rare class's 300 examples until every class has roughly the same count), train on it without the class_weight argument, and compare per-class recall on the test set against today's class-weighted run. Which footwear class benefits more from which method? Is overall accuracy different?
Hint: index arrays do the heavy lifting. Get each rare class's indices, tile them up to ~5,500 with np.resize(idx, 5500), which cycles through the 300 indices (np.repeat would instead repeat each index in place), concatenate with all other indices, then shuffle the combined index array before slicing x_train/y_train. Otherwise all the repeated sandals land in consecutive batches and training destabilizes.
Solution
import numpy as np
import keras
from keras import layers
import tensorflow as tf
keras.utils.set_random_seed(42)
# --- rebuild the imbalanced split exactly as in the lesson ---
(x_full, y_full), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_val, y_val = x_full[-5000:].astype("float32"), y_full[-5000:]
x_train, y_train = x_full[:-5000], y_full[:-5000]
rng = np.random.default_rng(42)
keep = np.ones(len(y_train), dtype=bool)
for cls in (5, 7, 9):
    idx = np.where(y_train == cls)[0]
    keep[rng.choice(idx, size=len(idx) - 300, replace=False)] = False
x_train, y_train = x_train[keep].astype("float32"), y_train[keep]
# --- oversample: tile rare-class indices up to the majority count ---
target = int(np.bincount(y_train).max()) # ~5500
parts = []
for cls in range(10):
    idx = np.where(y_train == cls)[0]
    if len(idx) < target:
        idx = np.resize(idx, target)  # cycles through the 300, repeating
    parts.append(idx)
all_idx = rng.permutation(np.concatenate(parts))  # SHUFFLE: critical
x_bal, y_bal = x_train[all_idx], y_train[all_idx]
print(np.bincount(y_bal))  # ~[5500 5500 ... 5500]: balanced by repetition
# --- same model as the lesson, NO class_weight ---
normalizer = layers.Normalization(axis=None)
normalizer.adapt(x_train) # stats from ORIGINAL train data
inputs = keras.Input(shape=(28, 28))
x = layers.Flatten()(normalizer(inputs))
x = layers.Dropout(0.3)(layers.Dense(256, activation="relu")(x))
x = layers.Dropout(0.3)(layers.Dense(128, activation="relu")(x))
model = keras.Model(inputs, layers.Dense(10)(x))
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)
model.fit(
    x_bal, y_bal,
    validation_data=(x_val, y_val),
    epochs=40, batch_size=128,
    callbacks=[
        keras.callbacks.ModelCheckpoint("best_oversampled.keras",
                                        monitor="val_accuracy",
                                        save_best_only=True),
        keras.callbacks.EarlyStopping(monitor="val_loss", patience=5,
                                      restore_best_weights=True),
    ],
    verbose=2,
)
# --- compare per-class recall against the class-weighted run ---
best = keras.models.load_model("best_oversampled.keras")
y_pred = best.predict(x_test.astype("float32"), verbose=0).argmax(axis=1)
cm = tf.math.confusion_matrix(y_test, y_pred, num_classes=10).numpy()
recall = cm.diagonal() / cm.sum(axis=1)
CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
for name, r in zip(CLASS_NAMES, recall):
    print(f"{name:>12} {r:6.1%}")
What you should observe: the two methods land within a couple of points of each other on footwear recall and on overall accuracy, which is the real lesson. Mathematically, repeating a sample \(k\) times and weighting its loss by \(k\) produce the same expected gradient; they differ in variance (oversampling shows the literal same 300 images over and over, so the model can memorize them; watch the train/val gap grow a bit faster) and in cost (oversampling makes each epoch ~40% longer here, since the training set physically grew from about 39,400 to about 55,000 rows). class_weight is one argument and zero data copies; it's the default choice. Oversampling earns its keep when you combine it with data augmentation, so each repetition is a different distorted view; that trick arrives on Day 7.
Key takeaways
- Split before you compute anything: adapt(), class counts, and imbalance surgery all happen on the training split only; validation and test stay pristine.
- layers.Normalization + adapt() welds preprocessing statistics into the model as non-trainable weights. The exported model on Day 9 will carry them, so the "forgot to normalize at inference" bug becomes impossible.
- class_weight={c: N/(C·n_c)} in fit() scales each example's loss by its class weight. It is the Keras counterpart of PyTorch's CrossEntropyLoss(weight=...), and here it rescued 300-example classes to ~93% recall.
- Output raw logits and use from_logits=True; use Sparse losses/metrics for integer labels. The wrong combination either crashes (shape error) or, worse, silently trains a weaker model.
- ModelCheckpoint(save_best_only=True) + EarlyStopping(restore_best_weights=True): ship the validation peak, not the overfit tail.
- Overall accuracy averages away per-class failure; the confusion matrix and per-class recall show which classes fail and where the errors go. Here they revealed that our worst class (Shirt) isn't a data-quantity problem but a representation problem: after Flatten, the MLP has no built-in way to exploit the spatial structure needed to distinguish it.
Tomorrow we stop flattening: convolutional layers keep the 2D structure of the image, and the shirt-versus-pullover confusion that stumped today's MLP starts to melt.