```mermaid
flowchart TB
A["Load backbone, weights='imagenet'<br/>include_top=False"] --> B["base_model.trainable = False<br/>call with training=False"]
B --> C["Phase 1: train head only<br/>Adam lr=1e-3, ~10 epochs"]
C --> D{"val accuracy<br/>plateaued?"}
D -- no --> C
D -- yes --> E["Unfreeze top blocks<br/>(keep layers < fine_tune_at frozen)"]
E --> F["RE-COMPILE with lr 10-100x lower<br/>Adam lr=1e-5"]
F --> G["Phase 2: fine-tune<br/>~10 more epochs, watch val_loss"]
G --> H["Best model → Day 9: save & serve"]
style C fill:#6366f140,stroke:#6366f1
style F fill:#f59e0b40,stroke:#f59e0b
style G fill:#22c55e40,stroke:#22c55e
```
Deep Learning with TensorFlow & Keras · Day 8 · Transfer Learning: Stand on a Million Shoulders
Day 8 · Transfer Learning: Stand on a Million Shoulders
Yesterday you fought overfitting the hard way (dropout, weight decay, augmentation, early stopping), squeezing generalization out of a model trained from scratch. Today you sidestep the fight entirely. Instead of asking a randomly initialized network to discover edges, textures, and shapes from your few thousand images, you borrow a network that already learned them from the roughly 1.3 million images in ImageNet's training set, bolt a new head onto it, and teach only the part that's specific to your problem. This is transfer learning, and on small datasets it isn't a nice-to-have: it's the difference between 75% accuracy and 98%. Keras makes the whole recipe almost embarrassingly short via keras.applications, but there are two famous traps (BatchNorm behavior and preprocessing mismatches) that silently ruin results, and we'll walk straight into both, deliberately, so you recognize them forever.
Today you will: load a pretrained MobileNetV2 backbone with include_top=False, attach a new classification head with augmentation and preprocessing baked into the model, train in two phases (frozen backbone → unfrozen top blocks at low LR), understand the BatchNorm training=False gotcha, and hit ~98% on cats-vs-dogs with only 2,000 training images.
Why a network trained on ImageNet helps with your problem
A convnet trained on ImageNet doesn't just learn "this is a Labrador." It learns a hierarchy of visual features, and the lower you go, the more universal they are: early layers respond to edges and color contrasts, middle layers to textures and simple patterns, late layers to object parts.
Edge detectors and texture filters are useful for almost any image task: cats, X-rays, satellite imagery. Only the last layers (and the classification head) encode "ImageNet's 1000 classes specifically." So the recipe writes itself:
1. Cut off the head (include_top=False): you don't want 1000 ImageNet logits.
2. Freeze the backbone and train only a small new head. The backbone acts as a fixed feature extractor.
3. Optionally unfreeze the top few blocks and continue with a tiny learning rate, letting the task-specific-ish late features adapt to your domain.
Skipping step 2 and fine-tuning everything from the start is the classic beginner mistake: your new head starts with random weights, so the error signal in the first steps is large and essentially random, and it flows back into the pretrained weights, wrecking them before the head learns anything. Freeze first, always.
If you did the PyTorch course: this is the same idea as torchvision.models + param.requires_grad_(False), but the ergonomics differ. In Keras, freezing is layer.trainable = False, and (crucially) it only takes effect in training after you compile(). Change trainable, forget to recompile, and nothing changes. Keep that in your head all day.
A genuinely small dataset
Transfer learning shines when data is scarce, so let's be scarce on purpose: the classic filtered cats-vs-dogs set, with 2,000 training and 1,000 validation images. Small enough that yesterday's from-scratch CNN would overfit within a handful of epochs.
```python
import pathlib

import keras
import tensorflow as tf

path = keras.utils.get_file(
    "cats_and_dogs_filtered.zip",
    origin="https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip",
    extract=True,
)
base_dir = pathlib.Path(path).parent / "cats_and_dogs_filtered_extracted" / "cats_and_dogs_filtered"

IMG_SIZE = (160, 160)
BATCH = 32

train_ds = keras.utils.image_dataset_from_directory(
    base_dir / "train", image_size=IMG_SIZE, batch_size=BATCH, shuffle=True, seed=42
)
val_ds = keras.utils.image_dataset_from_directory(
    base_dir / "validation", image_size=IMG_SIZE, batch_size=BATCH
)
```

```
Found 2000 files belonging to 2 classes.
Found 1000 files belonging to 2 classes.
```
Two Day 3 notes worth repeating. First, image_dataset_from_directory yields (images, labels) batches where images are float32 in [0, 255], not normalized. Every pretrained backbone has an opinion about the input range it was trained on, and feeding it the wrong range is the second-most-common way to get mysteriously bad accuracy (we'll handle it properly in a minute). Second, finish the pipeline the way Day 3 taught you:
```python
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.prefetch(AUTOTUNE)  # overlap input loading with training
val_ds = val_ds.prefetch(AUTOTUNE)
```

Carve out a test split if you want honest final numbers; for today we'll report on val_ds to keep the code focused on the transfer recipe itself.
The backbone: include_top=False and what comes out of it
keras.applications is a zoo of pretrained architectures (VGG16, ResNet50, MobileNetV2/V3, EfficientNetB0-B7, EfficientNetV2, ConvNeXt), all with ImageNet weights one keyword away. We'll start with MobileNetV2: small (2.3M backbone params), fast, and accurate enough to demonstrate everything.
```python
base_model = keras.applications.MobileNetV2(
    input_shape=IMG_SIZE + (3,),
    include_top=False,    # chop off the 1000-class ImageNet head
    weights="imagenet",   # download pretrained weights (~9 MB)
)
base_model.trainable = False  # freeze EVERYTHING for phase 1
print(base_model.output.shape)
```

```
(None, 5, 5, 1280)
```
Read that shape. With include_top=True you'd get (None, 1000): ImageNet probabilities, useless to us. With include_top=False the model stops at the last convolutional block: a 5×5 spatial grid of 1280-channel feature vectors (160 ÷ 32 = 5; the network downsamples by 32× in total). This is the "visual vocabulary" output: each of the 25 spatial positions is a 1280-dim description of what's there. Our head's first job is to collapse that grid into one vector, and the standard tool is GlobalAveragePooling2D: average over the 5×5 grid, giving (None, 1280). Compared to Flatten (which would give 32,000 features and a huge, overfit-prone Dense layer), global pooling is smaller, translation-tolerant, and works at any input resolution.
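To make that shape bookkeeping concrete, here is a minimal NumPy sketch of what GlobalAveragePooling2D and Flatten each do to a (batch, 5, 5, 1280) feature map, and how many parameters a one-logit Dense layer needs on top of each. NumPy stands in for the Keras layers, the random array is only a placeholder for real backbone output, and dense_params is a hypothetical helper, not a Keras function.

```python
import numpy as np


def dense_params(in_features: int, units: int = 1) -> int:
    """Weights plus biases of a Dense(units) layer on a flat input (hypothetical helper)."""
    return in_features * units + units


# Placeholder for MobileNetV2 output at 160x160: 160 / 32 = 5 cells per side
batch, h, w, c = 4, 160 // 32, 160 // 32, 1280
feature_map = np.random.default_rng(0).normal(size=(batch, h, w, c)).astype("float32")

# GlobalAveragePooling2D: average over the two spatial axes
gap = feature_map.mean(axis=(1, 2))
# Flatten: keep every spatial position as a separate feature
flat = feature_map.reshape(batch, -1)

print(gap.shape, flat.shape)        # (4, 1280) (4, 32000)
print(dense_params(gap.shape[1]))   # 1281
print(dense_params(flat.shape[1]))  # 32001

# At 224x224 the grid grows to 7x7, but the pooled vector stays 1280-wide
print(np.zeros((1, 224 // 32, 224 // 32, c)).mean(axis=(1, 2)).shape)  # (1, 1280)
```

The roughly 25× gap in head parameters is the "smaller" half of the argument; the last line is the "any input resolution" half.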
base_model.trainable = False recursively freezes all 154 layers. Verify the freeze the way you'd verify anything, by counting:
```python
import math

print(f"total params: {base_model.count_params():,}")
# Sum the sizes of every weight tensor that fit() would update
n_trainable = sum(math.prod(w.shape) for w in base_model.trainable_weights)
print(f"trainable params: {n_trainable:,}")
```

```
total params: 2,257,984
trainable params: 0
```
Preprocessing: every backbone has an opinion
MobileNetV2 was trained on inputs scaled to [-1, 1]. ResNet50 (the original) wants BGR channel order with ImageNet means subtracted. EfficientNetV2 embeds its rescaling inside the model and wants raw [0, 255]. Feed any of them the wrong range and you won't get an error; you'll get a model that trains, converges, and plateaus 10-20 points below where it should, with nothing to tell you why.
| Backbone | Expected input | Where preprocessing lives |
|---|---|---|
| MobileNetV2 | [-1, 1] | keras.applications.mobilenet_v2.preprocess_input |
| ResNet50 | BGR, mean-subtracted | keras.applications.resnet50.preprocess_input |
| EfficientNetV2B0 | [0, 255] raw | built into the model; pass pixels straight in |
The robust habit: always use the preprocess_input that ships in the same module as the backbone, and put it inside the model so it can never be forgotten at inference time. (Tomorrow, when we export for serving, this decision pays off: the served model accepts raw pixels.)
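Under the hood, the first two contracts in the table are plain arithmetic. The NumPy sketch below re-implements them so you can see the numbers: tf_mode mirrors the scaling that MobileNetV2's preprocess_input applies, and caffe_mode mirrors ResNet50's (flip RGB to BGR, subtract per-channel ImageNet means, no scaling). Both functions are hand-written stand-ins for illustration; in real code, call the preprocess_input from the backbone's own module.

```python
import numpy as np

# ImageNet channel means in BGR order, as used by "caffe"-style preprocessing
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype="float32")


def tf_mode(x):
    """MobileNetV2-style: [0, 255] -> [-1, 1]."""
    return x / 127.5 - 1.0


def caffe_mode(x):
    """ResNet50-style: reverse RGB -> BGR, then zero-center each channel (no scaling)."""
    return x[..., ::-1] - IMAGENET_MEAN_BGR


# Three RGB pixels: black, white, pure red
pixels = np.array([[0.0, 0.0, 0.0], [255.0, 255.0, 255.0], [255.0, 0.0, 0.0]], dtype="float32")
print(tf_mode(pixels))     # black -> -1, white -> +1, red -> [1, -1, -1]
print(caffe_mode(pixels))  # unscaled, mean-subtracted, channels reversed
```

Note the magnitudes: the same white pixel becomes 1.0 for one backbone and about 151 in the first channel for the other, which is why a mismatch fails silently instead of crashing.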
Assembling the model: augmentation + preprocessing + backbone + head
Now the full model, functional-style (Day 2), with Day 7's augmentation layers riding along:
```python
data_augmentation = keras.Sequential([
    keras.layers.RandomFlip("horizontal"),
    keras.layers.RandomRotation(0.1),
], name="augmentation")

preprocess = keras.applications.mobilenet_v2.preprocess_input

inputs = keras.Input(shape=IMG_SIZE + (3,))
x = data_augmentation(inputs)                 # active only during training, identity at inference
x = preprocess(x)                             # [0, 255] -> [-1, 1], baked into the graph
x = base_model(x, training=False)             # <-- the single most important line today
x = keras.layers.GlobalAveragePooling2D()(x)  # (None, 5, 5, 1280) -> (None, 1280)
x = keras.layers.Dropout(0.2)(x)
outputs = keras.layers.Dense(1)(x)            # one logit; the sigmoid lives in the loss
model = keras.Model(inputs, outputs)
```

Block by block:
- data_augmentation(inputs): augmentation as layers means it runs on the same device as the model, is exported with the model, and automatically becomes a no-op when training=False. Exactly Day 7's setup.
- preprocess(x): for MobileNetV2 this is just x / 127.5 - 1, but expressed via the official function so a backbone swap only requires changing one name.
- base_model(x, training=False): stop. This deserves its own section.
The BatchNorm trap: trainable=False is not training=False
Keras has two similarly named, completely different switches, and BatchNormalization is the layer where confusing them hurts:
- layer.trainable = False: a property. "Don't update this layer's weights during fit()."
- training=False: a call argument. "Run this layer in inference mode right now."
For most layers the distinction never comes up. For BatchNorm it is everything: a BatchNorm layer running in training mode normalizes each batch using that batch's mean and variance, and nudges its stored moving averages toward them. Keras special-cases one situation: setting trainable = False on a BatchNormalization layer also makes it run in inference mode, so the fully frozen backbone in phase 1 is protected. The trap springs in phase 2. Unfreeze the top blocks and their BatchNorm layers flip back to training mode; your cats-vs-dogs batches have different statistics than ImageNet, so those layers re-normalize every feature map with noisy small-batch statistics and overwrite the ImageNet moving averages, effectively scrambling the pretrained features the backbone worked so hard to learn. The symptom is brutal and classic: training accuracy looks fine, validation accuracy collapses, and the fine-tuned model ends up below where phase 1 ended.
base_model(x, training=False) pins the backbone into inference mode permanently for this graph, regardless of whether the outer model is training. BatchNorm always uses its ImageNet moving averages. Frozen or unfrozen, phase 1 or phase 2, this stays correct, which is why we set it once, here, and never touch it again.
PyTorch contrast, for those cross-referencing: PyTorch separates the same two concepts as requires_grad (per-parameter) vs module.eval() (per-module mode), and the same trap exists: freezing parameters without calling .eval() on the BatchNorm modules. PyTorch never links the two switches, and Keras links them only for frozen BatchNorm layers, which is exactly why the trap reappears the moment you unfreeze. Both frameworks make you learn this the hard way exactly once.
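Here is the difference between the two modes for a single feature, as a NumPy sketch. TinyBatchNorm is a hypothetical, stripped-down stand-in (one feature, gamma = 1, beta = 0) that follows the BatchNorm formula with momentum 0.99 and epsilon 1e-3, the Keras BatchNormalization defaults; the "new-domain" batch is synthetic.

```python
import numpy as np


class TinyBatchNorm:
    """Single-feature sketch of BatchNormalization's two modes (gamma=1, beta=0)."""

    def __init__(self, moving_mean, moving_var, momentum=0.99, epsilon=1e-3):
        self.moving_mean = moving_mean
        self.moving_var = moving_var
        self.momentum = momentum
        self.epsilon = epsilon

    def __call__(self, x, training):
        if training:
            # Training mode: normalize with THIS batch's statistics...
            mean, var = x.mean(), x.var()
            # ...and pull the stored (pretrained) statistics toward them
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            # Inference mode: use the stored statistics, leave them untouched
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.epsilon)


# "Pretrained" statistics: on the original data this feature averaged 0 with variance 1
bn = TinyBatchNorm(moving_mean=0.0, moving_var=1.0)

# A new-domain batch where the feature fires strongly (mean around 3)
batch = np.random.default_rng(0).normal(loc=3.0, scale=1.0, size=256)

out_inference = bn(batch, training=False)
print(round(float(out_inference.mean()), 2))  # stays near 3: "this feature is unusually active"

out_training = bn(batch, training=True)
print(round(float(out_training.mean()), 2))   # about 0: the signal is normalized away
print(round(float(bn.moving_mean), 3))        # the pretrained statistic has started to drift
```

Inference mode preserves the information "this feature is unusually active here"; training mode erases it and starts rewriting the stored statistics. Repeat that update for a few hundred steps and, at momentum 0.99, the ImageNet averages are essentially gone.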
Compile and inspect
```python
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
)
model.summary()
```

```
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃     Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (None, 160, 160, 3)    │           0 │
│ augmentation (Sequential)       │ (None, 160, 160, 3)    │           0 │
│ true_divide, subtract (ops)     │ (None, 160, 160, 3)    │           0 │
│ mobilenetv2_1.00_160 (Function… │ (None, 5, 5, 1280)     │   2,257,984 │
│ global_average_pooling2d        │ (None, 1280)           │           0 │
│ dropout (Dropout)               │ (None, 1280)           │           0 │
│ dense (Dense)                   │ (None, 1)              │       1,281 │
└─────────────────────────────────┴────────────────────────┴─────────────┘
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
```
1,281 trainable parameters. That's the whole learning problem now: one weight vector of length 1280 plus a bias, choosing a linear boundary in the backbone's feature space. With 2,000 examples and 1,281 parameters, overfitting (yesterday's arch-enemy) barely gets a foothold. This is the deep reason transfer learning works on small data: you're not training a deep network, you're training logistic regression on world-class features.
One logit + from_logits=True is the same numerical-stability discipline from Day 5: never put a sigmoid in the model and a plain crossentropy after it; let the loss fuse them.
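A minimal NumPy sketch of why the fused form matters. bce_naive computes a sigmoid and then a plain log-loss, the way a sigmoid-output model would feed an ordinary crossentropy; bce_from_logits uses the algebraically equivalent, numerically stable form max(z, 0) - z·y + log(1 + exp(-|z|)). Both are illustrative re-implementations, not Keras's internal code.

```python
import numpy as np


def bce_naive(y, z):
    """Sigmoid first, then plain crossentropy on the probability."""
    p = 1.0 / (1.0 + np.exp(-z))
    with np.errstate(divide="ignore"):  # silence the log(0) warning so we can inspect the result
        return -(y * np.log(p) + (1 - y) * np.log(1 - p))


def bce_from_logits(y, z):
    """Fused form: max(z, 0) - z*y + log(1 + exp(-|z|)); never evaluates log(0)."""
    return np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))


z = np.float32(100.0)  # a confidently wrong logit...
y = np.float32(0.0)    # ...for a label-0 example
print(bce_naive(y, z))        # inf: the sigmoid rounded to exactly 1.0
print(bce_from_logits(y, z))  # 100.0: the true loss
```

Libraries typically guard the probability path (for example by clipping), but the fused form avoids the problem at the source rather than patching it.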
Phase 1: train the head
```python
initial_epochs = 10
history = model.fit(train_ds, validation_data=val_ds, epochs=initial_epochs)
```

```
Epoch 1/10
63/63 ━━━━━━━━━━━━━━━━━━━━ 8s - acc: 0.72 - loss: 0.52 - val_acc: 0.93 - val_loss: 0.20
Epoch 2/10
63/63 ━━━━━━━━━━━━━━━━━━━━ 3s - acc: 0.89 - loss: 0.28 - val_acc: 0.96 - val_loss: 0.13
...
Epoch 10/10
63/63 ━━━━━━━━━━━━━━━━━━━━ 3s - acc: 0.95 - loss: 0.14 - val_acc: 0.975 - val_loss: 0.076
```
Notice two things you'd never see training from scratch. First, validation accuracy is 93% after one epoch: the features were already good, and the head just had to find a direction in feature space. Second, val_acc runs above train acc. That's not a bug: dropout and augmentation are active during training but off during validation, and the reported training accuracy is averaged over the epoch while validation is measured at the end. From-scratch models on 2,000 images would be deep into overfitting by epoch 10; this one hasn't started.
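The averaging effect is easy to check with arithmetic. The sketch below assumes, purely for illustration, that per-batch training accuracy climbs linearly from 0.60 to 0.90 over one 63-batch epoch (made-up numbers, not the run above). The progress bar's epoch average lands well below what the end-of-epoch weights can do, and validation only ever sees the end-of-epoch weights.

```python
# Hypothetical per-batch training accuracies during one epoch, rising as the head learns
batch_acc = [0.60 + 0.30 * i / 62 for i in range(63)]  # 63 batches, 0.60 -> 0.90

# What the progress bar reports as "acc": the running mean over the whole epoch
reported_train_acc = sum(batch_acc) / len(batch_acc)

# What validation effectively evaluates: the weights as they stand at the END of the epoch
end_of_epoch_acc = batch_acc[-1]

print(f"reported train acc: {reported_train_acc:.2f}")  # 0.75
print(f"end-of-epoch acc:   {end_of_epoch_acc:.2f}")    # 0.90
```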
~97.5% with a linear head. Now let's go get the last point.
Phase 2: unfreeze the top blocks, drop the learning rate
The backbone's late layers encode ImageNet-flavored object parts. Our task is close to ImageNet's domain (photos of animals), but "close" isn't "identical": those layers can do a bit better if we let them adapt. The recipe:
```python
base_model.trainable = True  # unfreeze everything...
fine_tune_at = 100           # ...then re-freeze the bottom
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

print(f"layers in backbone: {len(base_model.layers)}")
print(f"unfrozen layers: {sum(l.trainable for l in base_model.layers)}")
```

```
layers in backbone: 154
unfrozen layers: 54
```
Why layer 100? MobileNetV2 stacks 17 inverted-residual blocks (Keras names them expanded_conv and block_1 through block_16), and layer index 100 falls inside block_11, so we're adapting roughly the top third of the network (54 of 154 layers): the most task-specific part of the feature hierarchy described at the start of the lesson. It's a knob, not a law: closer domains → unfreeze less; farther domains (medical, satellite) → unfreeze more.
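If a bare index like 100 feels brittle, one option is to derive fine_tune_at from layer names instead. The helper below is hypothetical (not a Keras API), and names is a short stand-in for [layer.name for layer in base_model.layers]. It assumes the block_N_ naming used by Keras's MobileNetV2, so print your own model's layer names to confirm before relying on it.

```python
def first_layer_index(layer_names, prefix):
    """Index of the first layer whose name starts with `prefix` (hypothetical helper)."""
    for i, name in enumerate(layer_names):
        if name.startswith(prefix):
            return i
    raise ValueError(f"no layer name starts with {prefix!r}")


# Stand-in for [layer.name for layer in base_model.layers]; the real list has 154 names
names = ["input_layer", "Conv1", "bn_Conv1", "block_1_expand", "block_10_expand",
         "block_11_expand", "block_11_expand_BN", "block_12_expand", "Conv_1", "out_relu"]

# The trailing "_" keeps "block_1_" from also matching block_10, block_11, ...
fine_tune_at = first_layer_index(names, "block_11_")
print(fine_tune_at, len(names) - fine_tune_at)  # 5 5
```

On the real backbone you would pass [l.name for l in base_model.layers] and keep the freezing loop unchanged; the index then tracks the block you meant even if you change input size or backbone variant.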
Now the part everyone forgets:
```python
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),  # 100x lower!
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
)
model.summary(show_trainable=False)
```

```
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
```
Three deliberate choices in that recompile:
- Recompile at all. trainable changes are inert until compile() runs again. If your "fine-tuning" run trains suspiciously fast and improves nothing, you almost certainly skipped this.
- Learning rate 1e-5, down from 1e-3. The pretrained weights are already good; large steps would overwrite them faster than your 2,000 images can teach. Rule of thumb: fine-tune at 10-100× below your head-training LR. This is a blunt version of PyTorch-style discriminative learning rates (parameter groups with per-group LRs): one low LR for everything unfrozen, which in practice gets you most of the benefit.
- training=False stays. We set it when we built the graph, so BatchNorm keeps using ImageNet statistics even though the surrounding conv weights now update. For small datasets this is exactly what you want: 63 batches per epoch is nowhere near enough to re-estimate stable batch statistics.
Continue training from where phase 1 stopped, so the history and any LR schedules line up:
```python
fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs

history_ft = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=total_epochs,
    initial_epoch=len(history.epoch),  # start at epoch index 10 (shown as "Epoch 11/20")
    callbacks=[
        keras.callbacks.EarlyStopping(  # Day 7 muscle memory
            monitor="val_loss", patience=3, restore_best_weights=True
        )
    ],
)
```

```
Epoch 11/20
63/63 ━━━━━━━━━━━━━━━━━━━━ 9s - acc: 0.94 - loss: 0.15 - val_acc: 0.977 - val_loss: 0.068
Epoch 12/20
63/63 ━━━━━━━━━━━━━━━━━━━━ 6s - acc: 0.95 - loss: 0.12 - val_acc: 0.980 - val_loss: 0.060
...
Epoch 17/20
63/63 ━━━━━━━━━━━━━━━━━━━━ 6s - acc: 0.97 - loss: 0.077 - val_acc: 0.984 - val_loss: 0.049
```
From 97.5% to ~98.4%. A point of accuracy might sound small, but it's a 36% reduction in error rate (2.5% → 1.6%), on 2,000 training images, from ten minutes of compute. Plot both phases end to end and the transition tells the story:
```python
import matplotlib.pyplot as plt

# Stitch the two phases into one continuous curve per metric
acc = history.history["val_acc"] + history_ft.history["val_acc"]
loss = history.history["val_loss"] + history_ft.history["val_loss"]

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(10, 3))
for ax, series, name in [(ax_acc, acc, "val accuracy"), (ax_loss, loss, "val loss")]:
    ax.plot(series, label=name)
    ax.axvline(initial_epochs - 0.5, ls="--", c="gray", label="start fine-tuning")
    ax.set_xlabel("epoch")
    ax.grid(alpha=0.3)
    ax.legend()
plt.show()
```

You should see a flat-ish accuracy plateau through epoch 10, then a visible step upward right after the dashed line (and a matching step down in val loss). If instead you see a cliff downward at the dashed line, diagnose in this order: LR too high (most likely), forgot training=False on the backbone (BatchNorm statistics getting trashed), or forgot to recompile (nothing actually changed and the "drop" is noise).
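The "36% reduction in error rate" quoted for phase 2 is a relative error reduction: the fraction of the remaining mistakes that disappeared. A small sketch of the arithmetic (relative_error_reduction is an illustrative helper, not a library function):

```python
def relative_error_reduction(acc_before: float, acc_after: float) -> float:
    """Fraction of the remaining errors removed when accuracy goes from before to after."""
    err_before, err_after = 1.0 - acc_before, 1.0 - acc_after
    return (err_before - err_after) / err_before


print(f"{relative_error_reduction(0.975, 0.984):.0%}")  # 36%: errors drop from 2.5% to 1.6%
print(f"{relative_error_reduction(0.50, 0.51):.0%}")    # 2%: the same one point, near chance
```

The same one-point gain is worth far more near the ceiling than near chance, which is why late-stage improvements are usually reported this way.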
Swapping backbones: EfficientNetV2 in three edited lines
The whole point of keras.applications sharing one interface is that upgrading the backbone is a find-and-replace, not a rewrite. EfficientNetV2B0 is a stronger, still-small model, and it demonstrates the built-in-preprocessing variant:
```python
base_model = keras.applications.EfficientNetV2B0(
    input_shape=IMG_SIZE + (3,),
    include_top=False,
    weights="imagenet",
    include_preprocessing=True,  # rescaling lives INSIDE the model
)
base_model.trainable = False

inputs = keras.Input(shape=IMG_SIZE + (3,))
x = data_augmentation(inputs)
x = base_model(x, training=False)  # note: NO preprocess_input; raw [0, 255] goes in
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
```

The deleted preprocess(x) line is the point: EfficientNetV2's first layers are its preprocessing, so adding MobileNet's [-1, 1] scaling on top would double-preprocess and tank accuracy, the exact class of silent bug the preprocessing table warned about. When you swap backbones, the preprocessing decision must move with it. Everything else (the freeze, the two phases, the recompile discipline) is identical.
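To see why double preprocessing degrades accuracy instead of raising an error, here is a NumPy sketch. It assumes, as a simplifying stand-in, that the backbone's built-in step is a [0, 255] → [-1, 1] rescale (to_unit_range is hypothetical, and EfficientNetV2's actual internal preprocessing may use a different formula), then shows what happens when MobileNet-style scaling is applied first.

```python
import numpy as np


def to_unit_range(x):
    """Hypothetical stand-in for a backbone's built-in rescaling: [0, 255] -> [-1, 1]."""
    return x / 127.5 - 1.0


# Raw pixel values, as image_dataset_from_directory yields them
raw = np.linspace(0.0, 255.0, 6)

once = to_unit_range(raw)                   # correct: the backbone rescales raw pixels itself
twice = to_unit_range(to_unit_range(raw))   # bug: we rescaled first, then the backbone did it again

print(once.min(), once.max())    # -1.0 1.0
print(twice.min(), twice.max())  # every value now sits within ~0.008 of -1
```

After the second rescale the whole image is squeezed into a sliver near -1, so the network effectively sees a nearly black, nearly uniform frame. It still runs, just on almost no signal.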
A rough guide for choosing, all at include_top=False sizes:
| Backbone | Backbone params | ImageNet top-1 | Sweet spot |
|---|---|---|---|
| MobileNetV2 | 2.3M | 71.3% | mobile/edge, fast iteration |
| EfficientNetV2B0 | 5.9M | 78.7% | best small-model accuracy |
| ResNet50 | 23.6M | 74.9% | the sturdy baseline everyone knows |
| ConvNeXtTiny | 27.8M | 81.3% | when accuracy matters and GPU is decent |
Your task
Take today's two-phase recipe and apply it to a harder, less ImageNet-like dataset: tf_flowers (5 classes, ~3,700 images, loadable via tfds.load("tf_flowers") or the Keras download URL). Build an EfficientNetV2B0 version end to end: split 80/10/10, train the head (phase 1), then fine-tune the top ~30% of layers at low LR (phase 2). Report validation accuracy after each phase; you should see phase 1 land around 88-91% and phase 2 add several points. Since this is 5-class, the head and loss must change from today's binary setup.
Hint: the head becomes Dense(5) and the loss becomes SparseCategoricalCrossentropy(from_logits=True) (labels from image_dataset_from_directory are integer-encoded). Remember include_preprocessing=True means no preprocess_input, and remember what must happen after you flip trainable flags.
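For reference, here is what a sparse categorical crossentropy on logits computes, as a NumPy sketch (sparse_ce_from_logits is an illustrative re-implementation, not the Keras source): a log-softmax over the 5 logits, made numerically stable by subtracting each row's max, then the negative log-probability of each integer label, averaged over the batch.

```python
import numpy as np


def sparse_ce_from_logits(labels, logits):
    """Mean cross-entropy for integer labels and raw (unnormalized) logits."""
    # Stable log-softmax: subtract the row max before exponentiating (log-sum-exp trick)
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of each example's true class
    picked = log_probs[np.arange(len(labels)), labels]
    return float(-picked.mean())


labels = np.array([0, 3])   # integer-encoded, like image_dataset_from_directory's default
uniform = np.zeros((2, 5))  # a 5-class head with no opinion yet
print(round(sparse_ce_from_logits(labels, uniform), 4))  # 1.6094, i.e. log(5)
```

A freshly initialized 5-class head should start near log(5) ≈ 1.61, a handy sanity check for the first loss you see in phase 1.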
Solution
```python
import pathlib

import keras
import tensorflow as tf

# --- data ---
data_dir = keras.utils.get_file(
    "flower_photos.tgz",
    origin="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
    extract=True,
)
data_dir = pathlib.Path(data_dir).parent / "flower_photos_extracted" / "flower_photos"

IMG_SIZE, BATCH = (224, 224), 32
train_ds, val_ds = keras.utils.image_dataset_from_directory(
    data_dir, validation_split=0.2, subset="both", seed=42,
    image_size=IMG_SIZE, batch_size=BATCH,
)

# carve val into val/test halves (by batch)
n = val_ds.cardinality() // 2
test_ds = val_ds.take(n)
val_ds = val_ds.skip(n)

AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.prefetch(AUTOTUNE)
val_ds, test_ds = val_ds.prefetch(AUTOTUNE), test_ds.prefetch(AUTOTUNE)

# --- model ---
base_model = keras.applications.EfficientNetV2B0(
    input_shape=IMG_SIZE + (3,), include_top=False,
    weights="imagenet", include_preprocessing=True,
)
base_model.trainable = False

augment = keras.Sequential([
    keras.layers.RandomFlip("horizontal"),
    keras.layers.RandomRotation(0.1),
    keras.layers.RandomZoom(0.1),
])

inputs = keras.Input(shape=IMG_SIZE + (3,))
x = augment(inputs)
x = base_model(x, training=False)   # BatchNorm pinned to inference mode
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.3)(x)
outputs = keras.layers.Dense(5)(x)  # 5 classes, logits
model = keras.Model(inputs, outputs)

# --- phase 1: head only ---
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
h1 = model.fit(train_ds, validation_data=val_ds, epochs=8)
print("phase 1 val acc:", max(h1.history["val_accuracy"]))

# --- phase 2: unfreeze top ~30%, low LR, RECOMPILE ---
base_model.trainable = True
fine_tune_at = int(len(base_model.layers) * 0.7)
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

model.compile(  # <- mandatory after changing trainable
    optimizer=keras.optimizers.Adam(1e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
h2 = model.fit(
    train_ds, validation_data=val_ds,
    epochs=8 + 10, initial_epoch=len(h1.epoch),
    callbacks=[keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=3, restore_best_weights=True)],
)
print("phase 2 val acc:", max(h2.history["val_accuracy"]))

# --- honest final number ---
loss, acc = model.evaluate(test_ds)
print(f"test accuracy: {acc:.3f}")
```

Typical result: ~90% after phase 1, ~94-96% after phase 2. Flowers are less ImageNet-central than cats and dogs (ImageNet has flower classes, but fewer and coarser), so fine-tuning buys more here, which is exactly the pattern the feature-hierarchy picture predicts: the farther your domain from the pretraining domain, the more the late layers need to move.
Key takeaways
- Pretrained backbones learn a feature hierarchy: early layers are universal, late layers are task-flavored, the head is task-specific. Transfer learning keeps the universal parts and replaces the rest.
- include_top=False + GlobalAveragePooling2D + your own Dense head is the standard surgery; with the backbone frozen you're training ~1K parameters, so tiny datasets stop being a problem.
- The two-phase recipe: freeze → train head at normal LR → unfreeze top blocks → recompile at 10-100× lower LR → fine-tune with early stopping. Never fine-tune under a randomly initialized head.
- trainable=False (don't update weights) and training=False (run in inference mode) are different switches; call the backbone with training=False so BatchNorm keeps its ImageNet statistics through both phases.
- Changing trainable does nothing until you compile() again: the most common silent failure in Keras fine-tuning.
- Every backbone has a preprocessing contract (MobileNetV2 wants [-1, 1]; EfficientNetV2 wants raw [0, 255] with include_preprocessing=True). Bake it into the model so inference can never get it wrong.
Tomorrow, that "bake it into the model" discipline pays out: we take this fine-tuned model, save it properly, and deploy it as a served endpoint that accepts raw images → Day 9: Saving & deployment.