flowchart LR
A["Input 32×32×3"] --> B["Augmentation<br/>RandomFlip · RandomRotation · RandomZoom<br/>(training only)"]
B --> C["Rescaling 1/255"]
C --> D["Conv block ×2<br/>32 filters → pool"]
D --> E["Conv block ×2<br/>64 filters → pool"]
E --> F["Conv block<br/>128 filters → pool"]
F --> G["GlobalAveragePooling2D<br/>4×4×128 → 128"]
G --> H["Dense 10<br/>logits (no softmax)"]
📊 Deep Learning with TensorFlow & Keras · Lesson 6 — CNNs in Keras: Convolutions, Pooling, and Augmentation as Layers
In the previous lesson you built an end-to-end classifier with Dense layers, and it worked — but only because we flattened every image into a long vector and threw away the one thing images have that tabular data doesn’t: spatial structure. A pixel’s meaning depends on its neighbors. In this lesson you’ll build a convolutional neural network (CNN) that exploits that structure directly, train it on CIFAR-10, and pick up three habits that separate toy CNNs from production ones: doing the shape math in your head before Keras does it for you, baking augmentation into the model as layers, and choosing GlobalAveragePooling2D over Flatten when it matters. You’ll finish by pulling the trained first-layer filters out of the network and looking at what the model actually learned.
🎯 In this lesson you will: master the Conv2D/MaxPooling2D output-shape formula, build augmentation directly into a Keras model with RandomFlip/RandomRotation, train an image CNN on CIFAR-10, compare GlobalAveragePooling2D against Flatten by parameter count, and visualize what first-layer filters learn.
Why convolutions — locality and weight sharing
A Dense layer connecting a 32×32×3 image to 128 units needs \(32 \times 32 \times 3 \times 128 = 393{,}216\) weights, and every one of them is tied to a specific pixel position. Shift the image one pixel to the right and, as far as the network is concerned, it’s a completely different input. That’s two problems in one: too many parameters, and no translation tolerance.
A convolution fixes both with a single idea: instead of one weight per pixel, learn a small kernel (say 3×3) and slide it across the whole image, computing a dot product at every position. The same 27 weights (3×3×3 for an RGB input) get reused at every location — that’s weight sharing — and because the kernel only ever looks at a 3×3 neighborhood, it’s forced to learn local patterns like edges and color blobs. A pattern learned in the top-left corner is automatically detected in the bottom-right.
One kernel produces one 2-D feature map. A Conv2D(32, 3) layer learns 32 independent kernels, producing 32 feature maps stacked along the channel axis. Each map answers one question everywhere at once: “is there a vertical edge here? a green-to-red transition here?”
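To make the sliding-window idea concrete, here is a minimal NumPy sketch of a stride-1, "valid" convolution on one channels-last image, using the same (kh, kw, C_in, F) kernel layout Keras uses. Like Keras's Conv2D, it computes a cross-correlation (no kernel flip). The function name conv2d_valid is our own, and the explicit loops are for clarity, not speed. The second half checks the weight-sharing claim: shift the input one pixel and every feature map shifts one pixel with it.

```python
import numpy as np


def conv2d_valid(image, kernels, bias):
    """Stride-1, "valid" 2-D cross-correlation on one channels-last image.

    image:   (H, W, C_in)
    kernels: (kh, kw, C_in, F), the layout Keras uses for Conv2D kernels
    bias:    (F,)
    returns: (H - kh + 1, W - kw + 1, F), one feature map per kernel
    """
    kh, kw, _, f = kernels.shape
    h, w, _ = image.shape
    out = np.zeros((h - kh + 1, w - kw + 1, f), dtype=image.dtype)
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = image[i:i + kh, j:j + kw, :]  # local kh x kw x C_in neighborhood
            # The SAME kernels are applied at every (i, j): weight sharing.
            out[i, j, :] = np.tensordot(patch, kernels, axes=([0, 1, 2], [0, 1, 2])) + bias
    return out


rng = np.random.default_rng(0)
image = rng.normal(size=(12, 12, 3)).astype(np.float32)
kernels = rng.normal(size=(3, 3, 3, 4)).astype(np.float32)  # 4 filters, 3x3x3 each
bias = np.zeros(4, dtype=np.float32)

out = conv2d_valid(image, kernels, bias)
print(out.shape)  # (10, 10, 4)

# Shift the image one pixel to the right (zero-fill the new left column).
shifted = np.zeros_like(image)
shifted[:, 1:, :] = image[:, :-1, :]
out_shifted = conv2d_valid(shifted, kernels, bias)

# Away from the left edge, every feature map simply moves one pixel right too.
print(np.allclose(out_shifted[:, 1:, :], out[:, :-1, :]))  # True
```

The shift check is exactly the "learned in the top-left, detected in the bottom-right" property: nothing about the kernels depends on position.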
If you’re coming from the PyTorch course: Keras/TensorFlow default to channels-last tensors, (batch, height, width, channels), while PyTorch uses (batch, channels, height, width). This bites people porting code — a CIFAR-10 batch is (N, 32, 32, 3) here, not (N, 3, 32, 32).
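When porting activations between the two layouts, the fix is a transpose, not a reshape. A small NumPy sketch (the array names are illustrative) shows the correct conversion and why reshape is a trap: it produces the right shape while silently mixing values from different pixels.

```python
import numpy as np

# A tiny PyTorch-style batch: (N, C, H, W), filled with distinct values
batch_nchw = np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)

# Correct: move the channel axis to the end -> (N, H, W, C)
batch_nhwc = np.transpose(batch_nchw, (0, 2, 3, 1))
print(batch_nhwc.shape)     # (2, 4, 4, 3)
print(batch_nhwc[0, 0, 0])  # channel values 0, 16, 32 of pixel (0, 0)

# Wrong: reshape gives the right shape but reads values in memory order
wrong = batch_nchw.reshape(2, 4, 4, 3)
print(wrong[0, 0, 0])       # values 0, 1, 2: three neighboring pixels of channel 0

# Round trip back to channels-first
restored = np.transpose(batch_nhwc, (0, 3, 1, 2))
print(np.array_equal(restored, batch_nchw))  # True
```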
The shape math you must be able to do in your head
The single most useful formula in CNN work:
\[ o = \left\lfloor \frac{n + 2p - k}{s} \right\rfloor + 1 \]
where \(n\) is input size, \(k\) kernel size, \(s\) stride, \(p\) padding per side. Keras hides \(p\) behind two strings:
- padding="valid" (default): no padding, \(o = \lfloor (n-k)/s \rfloor + 1\). The map shrinks.
- padding="same": pad just enough that, with stride 1, output size equals input size (\(o = \lceil n/s \rceil\) in general).
And the two rules that go with it:
- Channels in don’t affect output spatial size, but they do affect parameter count: a Conv2D(F, k) on \(C\) input channels has \(k \cdot k \cdot C \cdot F + F\) parameters (the \(+F\) is biases).
- MaxPooling2D(2) halves height and width (stride defaults to pool size) and has zero parameters: it’s a fixed max over each 2×2 window.
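Both padding modes and the parameter rule fit in a few lines of plain Python. The helper names below are our own, not Keras API. This is a sketch for checking hand math, treating MaxPooling2D(2) as a 2×2 window with stride 2 and "valid" padding (its Keras defaults).

```python
import math


def conv_output_size(n, k, s=1, padding="valid"):
    """Output size along one spatial axis for Keras's two padding modes."""
    if padding == "valid":
        return (n - k) // s + 1  # floor((n - k) / s) + 1
    if padding == "same":
        return math.ceil(n / s)  # equals n at stride 1
    raise ValueError(f"unknown padding: {padding!r}")


def conv_output_size_explicit(n, k, s, p):
    """The general formula with p pixels of padding per side."""
    return (n + 2 * p - k) // s + 1


def conv2d_params(k, c_in, filters):
    """k*k*C_in weights per filter, plus one bias per filter."""
    return k * k * c_in * filters + filters


print(conv_output_size(32, 3))                  # 30  (valid)
print(conv_output_size(32, 3, padding="same"))  # 32
print(conv_output_size_explicit(32, 3, 1, 1))   # 32  ("same" for k=3 means p=1)
print(conv_output_size(32, 2, s=2))             # 16  (MaxPooling2D(2))
print(conv2d_params(3, 3, 32))                  # 896
```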
Let’s verify against Keras rather than trusting me. Keras 3 builds layers lazily, so we can probe shapes by just calling layers on a dummy tensor:
import keras
from keras import layers
import numpy as np
x = np.zeros((1, 32, 32, 3), dtype="float32") # one fake CIFAR image, channels-last
conv_valid = layers.Conv2D(32, 3, padding="valid")
conv_same = layers.Conv2D(32, 3, padding="same")
pool = layers.MaxPooling2D(2)
print("valid:", conv_valid(x).shape) # (1, 30, 30, 32) ← 32-3+1 = 30
print("same :", conv_same(x).shape) # (1, 32, 32, 32) ← padded to preserve size
print("pool :", pool(conv_same(x)).shape) # (1, 16, 16, 32) ← halved
valid: (1, 30, 30, 32)
same : (1, 32, 32, 32)
pool : (1, 16, 16, 32)
Check the parameter count against the formula. conv_same has \(3 \cdot 3 \cdot 3 \cdot 32 + 32 = 896\) parameters:
# conv_same was already built when we called it on x above; building twice is unnecessary
if not conv_same.built:
    conv_same.build((None, 32, 32, 3))
kernel, bias = conv_same.weights
print(kernel.shape, bias.shape) # (3, 3, 3, 32) (32,)
print(conv_same.count_params()) # 896
(3, 3, 3, 32) (32,)
896
Note the kernel weight layout: (kernel_h, kernel_w, channels_in, channels_out). We’ll use that layout again when we visualize filters. (PyTorch stores the transpose-ish (out, in, kh, kw) — another portability gotcha.)
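The same porting gotcha applies to weights. Here is a minimal NumPy sketch, using a random array as a stand-in for conv_same.get_weights()[0] and assuming nothing else differs between the two layers. In a real port you would also copy the bias and check padding behavior.

```python
import numpy as np

# Stand-in for conv_same.get_weights()[0]: Keras layout (kh, kw, in, out)
keras_kernel = np.random.default_rng(0).normal(size=(3, 3, 3, 32)).astype(np.float32)

# PyTorch Conv2d weight layout is (out, in, kh, kw)
torch_kernel = np.transpose(keras_kernel, (3, 2, 0, 1))
print(torch_kernel.shape)  # (32, 3, 3, 3)

# Filter 5, input channel 1: the same 3x3 patch in both layouts
assert np.array_equal(keras_kernel[:, :, 1, 5], torch_kernel[5, 1])

# And back: (out, in, kh, kw) -> (kh, kw, in, out)
restored = np.transpose(torch_kernel, (2, 3, 1, 0))
assert np.array_equal(restored, keras_kernel)
```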
Here’s the shape trace for the full network we’re about to build. Learn to produce this table by hand; it’s how you debug “expected shape X, got Y” errors before they happen:
| Layer | Output shape | Params | Why |
|---|---|---|---|
| Input | (32, 32, 3) | 0 | CIFAR-10 |
| Conv2D(32, 3, same) | (32, 32, 32) | 896 | 3·3·3·32+32 |
| Conv2D(32, 3, same) | (32, 32, 32) | 9,248 | 3·3·32·32+32 |
| MaxPooling2D(2) | (16, 16, 32) | 0 | halve |
| Conv2D(64, 3, same) | (16, 16, 64) | 18,496 | 3·3·32·64+64 |
| Conv2D(64, 3, same) | (16, 16, 64) | 36,928 | 3·3·64·64+64 |
| MaxPooling2D(2) | (8, 8, 64) | 0 | halve |
| Conv2D(128, 3, same) | (8, 8, 128) | 73,856 | 3·3·64·128+128 |
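| MaxPooling2D(2) | (4, 4, 128) | 0 | halve |
| GlobalAveragePooling2D | (128,) | 0 | mean over 4×4 |
| Dense(10) | (10,) | 1,290 | 128·10+10 |
| Total | | 140,714 | sum of the above |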
Notice the classic CNN rhythm: spatial size shrinks, channel depth grows. The network trades “where” resolution for “what” richness as you go deeper.
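The table can also be reproduced mechanically. Below is a plain-Python sketch in which SPEC is a hypothetical description of build_cnn (augmentation and Rescaling are left out because they change no shapes). It assumes "same" convs at stride 1 and MaxPooling2D with its default stride equal to the pool size.

```python
# Each entry: (kind, kernel/pool size, filters). A hypothetical mirror of build_cnn().
SPEC = [
    ("conv", 3, 32), ("conv", 3, 32), ("pool", 2, None),
    ("conv", 3, 64), ("conv", 3, 64), ("pool", 2, None),
    ("conv", 3, 128), ("pool", 2, None),
]


def trace(spec, h=32, w=32, c=3, num_classes=10):
    """Return (rows, total_params) for a conv/pool stack + GAP + Dense head."""
    rows, total = [], 0
    for kind, k, filters in spec:
        if kind == "conv":
            # padding="same", stride 1: spatial size unchanged, channels -> filters
            params = k * k * c * filters + filters
            c = filters
        else:
            # MaxPooling2D(k): stride k, "valid" padding, no parameters
            h, w = (h - k) // k + 1, (w - k) // k + 1
            params = 0
        rows.append((kind, (h, w, c), params))
        total += params
    head = c * num_classes + num_classes  # GAP yields a c-vector, then Dense
    rows.append(("gap", (c,), 0))
    rows.append(("dense", (num_classes,), head))
    return rows, total + head


rows, total = trace(SPEC)
for kind, shape, params in rows:
    print(f"{kind:6s} {str(shape):15s} {params:>7,}")
print(f"total  {'':15s} {total:>7,}")  # 140,714
```

Appending ("conv", 3, 256) and ("pool", 2, None) to SPEC is a quick way to check your hand math for the exercise at the end of the lesson.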
Augmentation as layers inside the model
CIFAR-10 has only 50,000 training images, and our network has enough capacity to memorize them. The cheapest fix is data augmentation: show the network randomly flipped and rotated variants so it can’t latch onto exact pixel arrangements.
In the PyTorch world you’d typically do this in the Dataset/transforms pipeline. Keras offers something nicer: augmentation as layers, inside the model itself.
data_augmentation = keras.Sequential(
[
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.05), # ±5% of a full turn ≈ ±18°
layers.RandomZoom(0.1),
],
name="augmentation",
)
Two things make this approach worth adopting:
- Train/inference asymmetry is automatic. These layers are active only when the model is called with training=True, which model.fit() does for you. At evaluation and inference time (model.evaluate, model.predict, serving), they become identity functions. No “oops, I augmented my test set” bugs, and no separate preprocessing code path to keep in sync when you deploy in Lesson 9.
- It runs where the model runs. On GPU, augmentation happens on-device as part of the forward pass instead of on CPU in the input pipeline. And because it’s part of the saved model, whoever loads your .keras file gets exactly the same preprocessing for free.
What breaks if you do it wrong: putting random augmentation in your tf.data pipeline and forgetting to disable it for the validation split is a classic way to get mysteriously low validation accuracy. The layer approach makes that mistake structurally impossible.
One caution: don’t add augmentations that destroy the label. Horizontal flips are safe for CIFAR-10 (a flipped cat is a cat); vertical flips are not (an upside-down truck is not a typical truck), and for digit datasets like MNIST even horizontal flips are wrong (a flipped “3” isn’t a “3”).
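To see what "identity at inference" means mechanically, here is a NumPy stand-in for a horizontal-flip layer. random_hflip is a hypothetical helper that mimics the behavior described above (flip each image with probability 0.5 when training, pass the batch through untouched otherwise); it is not Keras's implementation. Note which axis a horizontal flip reverses in a channels-last batch: width, axis 2.

```python
import numpy as np


def random_hflip(batch, rng, training):
    """Hypothetical stand-in for a RandomFlip("horizontal")-style layer.

    batch: channels-last images, shape (N, H, W, C).
    """
    if not training:
        return batch  # inference: identity, the input passes through untouched
    flip = rng.random(batch.shape[0]) < 0.5   # one coin toss per image
    out = batch.copy()
    out[flip] = out[flip][:, :, ::-1, :]       # reverse width (axis 2) = horizontal flip
    # Reversing axis 1 (height) would be a vertical flip: not label-safe for CIFAR-10.
    return out


rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(8, 32, 32, 3)).astype(np.uint8)

same = random_hflip(batch, rng, training=False)
print(same is batch)  # True: no augmentation at inference

augmented = random_hflip(batch, rng, training=True)
for original, new in zip(batch, augmented):
    # Each image is either untouched or exactly its mirror image.
    assert np.array_equal(new, original) or np.array_equal(new, original[:, ::-1, :])
```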
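The overall architecture we’re assembling: augmentation (training only), Rescaling, then three conv stages (two 32-filter convs, two 64-filter convs, one 128-filter conv), each followed by a 2×2 max pool, and finally GlobalAveragePooling2D and a 10-unit Dense head that outputs logits.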
Note the Rescaling(1/255) layer sits inside the model too, right after augmentation — same philosophy. The model accepts raw uint8-style pixel values and owns its own normalization. Ship the model, ship the preprocessing.
Build and train the CNN on CIFAR-10
Now the full, runnable program. Stage one — data. We keep the pipeline dead simple because the model owns preprocessing:
import keras
from keras import layers
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
print(x_train.shape, x_train.dtype, y_train.shape)
# (50000, 32, 32, 3) uint8 (50000, 1)
class_names = ["airplane", "automobile", "bird", "cat", "deer",
"dog", "frog", "horse", "ship", "truck"]
train_ds = (
tf.data.Dataset.from_tensor_slices((x_train, y_train))
.shuffle(10_000)
.batch(128)
.prefetch(tf.data.AUTOTUNE)
)
val_ds = (
tf.data.Dataset.from_tensor_slices((x_test, y_test))
.batch(256)
.prefetch(tf.data.AUTOTUNE)
)
This is the Lesson 3 pattern, minus any .map(): no normalization, no augmentation in the pipeline. Raw uint8 tensors flow straight into the model. (No manual cast is needed: the batches are converted to float32 inside the model, and Rescaling then maps them to [0, 1].)
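A quick sanity check on the batch sizes above: batch() keeps the final partial batch by default (drop_remainder=False), so the number of steps per epoch is a ceiling division. That is where the 391 in the training log comes from.

```python
import math

train_size, val_size = 50_000, 10_000
train_batch, val_batch = 128, 256

train_steps = math.ceil(train_size / train_batch)
val_steps = math.ceil(val_size / val_batch)

# steps per epoch, and the size of the last (partial) batch
print(train_steps, train_size - (train_steps - 1) * train_batch)  # 391 80
print(val_steps, val_size - (val_steps - 1) * val_batch)          # 40 16
```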
Stage two — the model, using the Functional API from Lesson 2:
def build_cnn(num_classes: int = 10) -> keras.Model:
    inputs = keras.Input(shape=(32, 32, 3))
    x = data_augmentation(inputs) # identity at inference time
    x = layers.Rescaling(1.0 / 255)(x)
    x = layers.Conv2D(32, 3, padding="same", activation="relu")(x)
    x = layers.Conv2D(32, 3, padding="same", activation="relu")(x)
    x = layers.MaxPooling2D(2)(x) # → 16×16×32
    x = layers.Conv2D(64, 3, padding="same", activation="relu")(x)
    x = layers.Conv2D(64, 3, padding="same", activation="relu")(x)
    x = layers.MaxPooling2D(2)(x) # → 8×8×64
    x = layers.Conv2D(128, 3, padding="same", activation="relu")(x)
    x = layers.MaxPooling2D(2)(x) # → 4×4×128
    x = layers.GlobalAveragePooling2D()(x) # → 128
    outputs = layers.Dense(num_classes)(x) # logits, no softmax
    return keras.Model(inputs, outputs, name="cifar10_cnn")
model = build_cnn()
model.summary()
Methodology notes, block by block:
- Two convs before the first two pools. Two stacked 3×3 convolutions see a 5×5 neighborhood (receptive fields compose) but with fewer parameters and an extra nonlinearity compared to one 5×5 conv. This is the VGG insight, and it’s still the default idiom.
- padding="same" everywhere. With valid padding, each 3×3 conv shaves 2 pixels off; on a 32×32 input you’d run out of image fast. Let pooling, a deliberate and controlled operation, be the only thing that shrinks the map.
- Logits out, no softmax. Exactly as in Lesson 5, we’ll pass from_logits=True to the loss. Softmax-then-log is numerically worse than a fused log-softmax, and Keras fuses it for you when you hand it logits.
- Notice what you didn’t have to compute: the flatten size. In PyTorch you’d hardcode nn.Linear(4*4*128, ...) and get a shape error when you change the architecture. Keras infers it at build time, and with GlobalAveragePooling2D there’s nothing to infer anyway, as the GlobalAveragePooling2D vs Flatten section explains.
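The receptive-field claim in the first note can be checked with the standard recurrence: each layer widens the view by (k - 1) times the current spacing between adjacent outputs (measured in input pixels), and each stride multiplies that spacing. Below is a plain-Python sketch; the helper name is ours, and it computes the nominal receptive field, ignoring border effects from "same" padding.

```python
def receptive_field(kernel_sizes, strides):
    """Nominal receptive field (one spatial axis) of a stack of conv/pool layers."""
    rf, jump = 1, 1  # jump = input-pixel distance between adjacent outputs
    for k, s in zip(kernel_sizes, strides):
        rf += (k - 1) * jump
        jump *= s
    return rf


# Two stacked 3x3 convs vs one 5x5 conv: the same 5x5 view
print(receptive_field([3, 3], [1, 1]))  # 5
print(receptive_field([5], [1]))        # 5

# ...but fewer parameters, here with 64 channels in and out
C = 64
two_3x3 = 2 * (3 * 3 * C * C + C)
one_5x5 = 5 * 5 * C * C + C
print(two_3x3, one_5x5)  # 73856 102464

# Whole stack of build_cnn: conv, conv, pool, conv, conv, pool, conv, pool
print(receptive_field([3, 3, 2, 3, 3, 2, 3, 2], [1, 1, 2, 1, 1, 2, 1, 2]))  # 28
```

So each of the final 4×4 positions summarizes a nominal 28×28 patch of the 32×32 input, which is part of why averaging them (GAP) loses little.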
The tail of model.summary():
conv2d_4 (Conv2D) (None, 8, 8, 128) 73,856
max_pooling2d_2 (MaxPooling2D) (None, 4, 4, 128) 0
global_average_pooling2d (None, 128) 0
dense (Dense) (None, 10) 1,290
Total params: 140,714 (549.66 KB)
Trainable params: 140,714 (549.66 KB)
140K parameters: about a third of the 393,216 weights that a single 128-unit Dense layer on the flattened raw image would need. That’s weight sharing doing its job.
Stage three — compile and fit, straight from Lesson 4’s compile/fit path:
model.compile(
optimizer=keras.optimizers.Adam(1e-3),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)
history = model.fit(train_ds, validation_data=val_ds, epochs=20)
Epoch 1/20
391/391 ━━━━━━━━━━ 14s 30ms/step - acc: 0.3194 - loss: 1.8221 - val_acc: 0.4788 - val_loss: 1.4304
...
Epoch 20/20
391/391 ━━━━━━━━━━ 12s 29ms/step - acc: 0.7761 - loss: 0.6382 - val_acc: 0.7723 - val_loss: 0.6650
Expect roughly 75–78% validation accuracy after 20 epochs (exact numbers vary by hardware and seed). Two observations worth internalizing:
- Training accuracy ≈ validation accuracy. The augmentation is doing its job — without it, this network hits ~99% train / ~72% val by epoch 20: textbook overfitting. Augmentation makes the training task harder (train accuracy is measured on augmented images), which is why train accuracy can even sit slightly below validation accuracy early on. That surprises people; it’s normal and healthy.
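- The previous lesson’s dense network plateaued around 45–50% on CIFAR-10. Same data, same loss, same optimizer, yet a ~28-point jump. Augmentation contributes part of that, but most of it is architecture: as noted above, even without augmentation this CNN reaches ~72% validation accuracy. That gap is the whole argument for convolutions.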
GlobalAveragePooling2D vs Flatten
At the end of the conv stack you have a (4, 4, 128) tensor and a decision: how do you get from a spatial feature map to a vector the classifier head can consume?
Option A — Flatten: unroll everything into a \(4 \cdot 4 \cdot 128 = 2048\)-vector. Every one of the 16 spatial positions keeps its own identity, and the following Dense(10) needs \(2048 \cdot 10 + 10 = 20{,}490\) parameters — each tied to a specific position in the final map.
Option B — GlobalAveragePooling2D (GAP): average each feature map over its spatial extent, (4, 4, 128) → (128,). The head becomes Dense(10) with \(128 \cdot 10 + 10 = 1{,}290\) parameters, and it can only ask “how strongly did filter \(j\) fire, anywhere?” — position is deliberately discarded.
Prove the parameter claim in three lines:
def head_params(pool_layer):
    inp = keras.Input(shape=(4, 4, 128))
    out = layers.Dense(10)(pool_layer(inp))
    return keras.Model(inp, out).count_params()
print("Flatten head:", head_params(layers.Flatten())) # 20490
print("GAP head :", head_params(layers.GlobalAveragePooling2D())) # 1290
When to choose which:
- GAP is the modern default (ResNet and EfficientNet both use it). Fewer parameters means less overfitting in the head, and, as a genuinely useful bonus, the model becomes input-size agnostic: a conv stack ending in GAP produces a 128-vector whether the input was 32×32 or 96×96. Flatten would produce different vector lengths and crash the Dense layer. This matters in Lesson 8, when pretrained backbones get reused at new resolutions.
- Flatten still earns its keep when the map is tiny and position is the signal, e.g., digit recognition where “loop in the top half” vs “loop in the bottom half” distinguishes 9 from 6, or any task on small centered inputs where you can afford the parameters.
For our CIFAR net, swapping GAP for Flatten adds ~19K head parameters and typically buys nothing but a slightly bigger train/val gap. We keep GAP.
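The input-size argument is easy to check with plain NumPy stand-ins for the two layers. The helper names are ours; for channels-last input, GlobalAveragePooling2D averages over the height and width axes. The 12×12 case is what our conv stack would produce from a 96×96 input after its three 2× pools.

```python
import numpy as np


def global_average_pool(fmap):
    """Channels-last (N, H, W, C) -> (N, C): the mean of each feature map."""
    return fmap.mean(axis=(1, 2))


def flatten(fmap):
    """(N, H, W, C) -> (N, H*W*C): every position keeps its own slot."""
    return fmap.reshape(fmap.shape[0], -1)


# Final map of our conv stack: 4x4 for a 32x32 input, 12x12 for a 96x96 input
for side in (4, 12):
    fmap = np.random.default_rng(0).random((1, side, side, 128), dtype=np.float32)
    print(side, global_average_pool(fmap).shape, flatten(fmap).shape)
# 4 (1, 128) (1, 2048)
# 12 (1, 128) (1, 18432)
```

A Dense(10) head built on the 128-vector works for both sizes; one built on the 2048-vector would reject the 18432-vector.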
What did the network learn? Visualizing first-layer filters
First-layer filters are the only ones you can interpret directly, because they operate on raw RGB: each is a 3×3×3 patch you can render as a tiny image. Deeper filters mix already-abstract channels and need fancier techniques (activation maximization — a Lesson 7+ rabbit hole). Let’s pull them out.
import matplotlib.pyplot as plt
# first Conv2D in the model (skip augmentation/rescaling layers)
first_conv = next(layer for layer in model.layers if isinstance(layer, layers.Conv2D))
kernels = first_conv.get_weights()[0] # (3, 3, 3, 32): kh, kw, in, out
print(kernels.shape)
# normalize each filter independently to [0, 1] for display
k_min = kernels.min(axis=(0, 1, 2), keepdims=True)
k_max = kernels.max(axis=(0, 1, 2), keepdims=True)
k_disp = (kernels - k_min) / (k_max - k_min + 1e-8)
fig, axes = plt.subplots(4, 8, figsize=(8, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(k_disp[:, :, :, i], interpolation="nearest")
    ax.set_axis_off()
fig.suptitle("First-layer 3×3 filters (RGB)")
plt.tight_layout()
plt.show()
Methodology notes: get_weights()[0] returns the kernel as a NumPy array in the (kh, kw, in, out) layout we confirmed earlier, so kernels[:, :, :, i] is filter \(i\) as a 3×3 RGB micro-image. The per-filter normalization matters: raw weights are small values centered near zero, and imshow would clip them to mud without it. And we grab the layer by type, not by index, because augmentation and rescaling layers sit in front of it.
What you’ll see is grainy at 3×3 resolution, but squint and the structure is there: oriented edge detectors and color-opponent blobs. The pattern is remarkably consistent across runs and architectures, and it even mirrors biology: neurons in the primary visual cortex (V1) respond to a similar vocabulary.
Nobody told the network to become an edge detector. Gradient descent on “name this object” rediscovered Gabor-like edge filters and color-opponent cells because they are useful first-stage features for almost any vision task. That universality is precisely why transfer learning works: Lesson 8 will take a network whose early layers learned this vocabulary on 1.4 million ImageNet photos and reuse it on your data.
To close the loop, watch a filter act on a real image by running the input through just the first conv layer. The Functional API makes carving out a sub-model trivial:
probe = keras.Model(model.inputs, first_conv.output)
fmap = probe(x_test[:1], training=False) # training=False: no augmentation
fmap = keras.ops.convert_to_numpy(fmap) # backend tensor -> NumPy array for matplotlib
print(fmap.shape) # (1, 32, 32, 32)
fig, axes = plt.subplots(2, 8, figsize=(10, 3))
for i, ax in enumerate(axes.flat):
    ax.imshow(fmap[0, :, :, i], cmap="viridis")
    ax.set_axis_off()
plt.show()
Each of those 16 panels is one filter’s “opinion map” of the same image: one lights up on horizontal contours, another on a particular color transition. training=False matters here: without it, the augmentation layers would randomly flip and rotate your probe image and you’d be visualizing responses to an input you never saw.
🧪 Your task
The current network downsamples 32→16→8→4 and stops. Add a fourth conv block — Conv2D(256, 3, padding="same", activation="relu") followed by MaxPooling2D(2) — before the GAP layer. Before running anything: (1) compute by hand the output shape after the new pool and the new block’s parameter count, and (2) predict how the Dense head’s parameter count changes. Then build it, check your numbers against model.summary(), and train for 20 epochs. Did validation accuracy improve?
Hint: the input to the new block is (4, 4, 128). Apply the conv-parameter formula \(k \cdot k \cdot C_{in} \cdot F + F\), and remember what GAP does to the head when the number of channels changes — the spatial size of GAP’s input is irrelevant to the head.
Solution
Hand math first. The new conv keeps spatial size (same padding, stride 1): (4, 4, 256). The pool halves it: (2, 2, 256). Parameter count of the new conv: \(3 \cdot 3 \cdot 128 \cdot 256 + 256 = 295{,}168\) — this one block has more than twice the parameters of the entire previous network, because cost scales with \(C_{in} \cdot C_{out}\). The GAP output is now 256-dim, so the head grows from \(128 \cdot 10 + 10 = 1{,}290\) to \(256 \cdot 10 + 10 = 2{,}570\).
def build_cnn_v2(num_classes: int = 10) -> keras.Model:
    inputs = keras.Input(shape=(32, 32, 3))
    x = data_augmentation(inputs)
    x = layers.Rescaling(1.0 / 255)(x)
    x = layers.Conv2D(32, 3, padding="same", activation="relu")(x)
    x = layers.Conv2D(32, 3, padding="same", activation="relu")(x)
    x = layers.MaxPooling2D(2)(x) # 16×16×32
    x = layers.Conv2D(64, 3, padding="same", activation="relu")(x)
    x = layers.Conv2D(64, 3, padding="same", activation="relu")(x)
    x = layers.MaxPooling2D(2)(x) # 8×8×64
    x = layers.Conv2D(128, 3, padding="same", activation="relu")(x)
    x = layers.MaxPooling2D(2)(x) # 4×4×128
    # --- new block ---
    x = layers.Conv2D(256, 3, padding="same", activation="relu")(x)
    x = layers.MaxPooling2D(2)(x) # 2×2×256
    x = layers.GlobalAveragePooling2D()(x) # 256
    outputs = layers.Dense(num_classes)(x)
    return keras.Model(inputs, outputs, name="cifar10_cnn_v2")
model_v2 = build_cnn_v2()
model_v2.summary() # confirm: (2,2,256), conv params 295,168, dense 2,570
model_v2.compile(
optimizer=keras.optimizers.Adam(1e-3),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)
history_v2 = model_v2.fit(train_ds, validation_data=val_ds, epochs=20)
print("v1 best val acc:", max(history.history["val_acc"]))
print("v2 best val acc:", max(history_v2.history["val_acc"]))
Typical result: v2 lands around 78–80% validation accuracy, a modest gain of a point or two for roughly tripling the parameter count (about 437K versus 141K), and the train/val gap starts opening up. That’s the diminishing-returns wall you hit when you scale capacity without scaling regularization, and it’s exactly the cliffhanger Lesson 7 resolves with dropout, batch normalization, weight decay, and systematic hyperparameter tuning.
Key takeaways
- Convolutions replace per-pixel weights with small shared kernels: fewer parameters, built-in translation tolerance, locality as an inductive bias.
- Output-shape formula: \(o = \lfloor (n + 2p - k)/s \rfloor + 1\); padding="same" preserves size at stride 1, and MaxPooling2D(2) halves it with zero parameters. Conv params: \(k^2 \cdot C_{in} \cdot F + F\).
- Keras is channels-last: (batch, H, W, C), the opposite of PyTorch’s channels-first (batch, C, H, W).
- Put RandomFlip/RandomRotation/RandomZoom (and Rescaling) inside the model: automatically off at inference, runs on-device, ships with the saved model. Pick augmentations that preserve the label.
- The CNN rhythm: spatial size shrinks while channel depth grows, trading “where” for “what”.
- GlobalAveragePooling2D beats Flatten as the default head: ~16× fewer head parameters here, and it makes the model input-size agnostic. Use Flatten only when position itself is the signal.
- First-layer filters converge to edge detectors and color-opponent blobs no matter the task: the universality that makes transfer learning possible.
In the next lesson: our bigger network is starting to memorize — Lesson 7 fights back with dropout, batch normalization, weight decay, and KerasTuner.