Kader Mohideen

In this lesson

  • Lesson 9 — Saving, Exporting & Serving Your Model
    • Checkpoints done right: state_dict, not pickled models
    • Inference done right: eval(), inference_mode(), and torch.compile
    • Exporting: getting the model out of Python
    • Serving: a minimal FastAPI inference endpoint
    • 🧪 Your task
    • Key takeaways

🔥 Deep Learning with PyTorch · Lesson 9 — Saving, Exporting & Serving Your Model



Lesson 9 — Saving, Exporting & Serving Your Model

Over eight lessons you took raw tensors all the way to a fine-tuned ResNet that classifies real images. But a model that only lives inside your training notebook is a science project, not software. In this lesson we close the loop: you’ll checkpoint the Lesson 8 model correctly (there are more wrong ways than right ways), make inference fast with torch.inference_mode() and torch.compile, export it to portable formats (TorchScript, torch.export, ONNX), and stand up a real HTTP endpoint with FastAPI that anyone can curl. This is the difference between “I trained a model” and “I shipped a model.”

🎯 In this lesson you will: save and load checkpoints with state_dict best practices, run inference correctly with inference_mode and torch.compile, export a model via TorchScript / torch.export / ONNX, serve predictions from a FastAPI endpoint, and map out where to go after this course

Checkpoints done right: state_dict, not pickled models

There are two ways to save a PyTorch model, and one of them will eventually ruin your week. Let’s start by rebuilding the Lesson 8 architecture so this lesson runs standalone:

import torch
from torch import nn
from torchvision import models

def build_model(num_classes: int = 10) -> nn.Module:
    # Architecture only — weights will come from OUR checkpoint,
    # so weights=None (we don't need the ImageNet download here).
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

model = build_model()

The tempting-but-wrong way to save it:

# DON'T do this
torch.save(model, "model_full.pt")

This pickles the entire object, including references to your class definitions and their module paths (for your own classes, that effectively means your directory layout). Move the file to a machine where the class lives at a different import path, or upgrade torchvision, and the load fails with an obscure ModuleNotFoundError or silently constructs something subtly different. On PyTorch 2.6 and later, where torch.load defaults to weights_only=True, it won't load at all unless you explicitly opt back into full unpickling. Pickled models are a time bomb.

The right way saves only the state_dict: an ordered dictionary mapping parameter names to tensors. Architecture stays in code (versioned in git), weights stay in the file:

torch.save(model.state_dict(), "resnet18_finetuned.pt")

sd = model.state_dict()
print(type(sd))
print(list(sd.keys())[:4])
print(sd["fc.weight"].shape)
<class 'collections.OrderedDict'>
['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean']
torch.Size([10, 512])

Notice bn1.running_mean in there — the state_dict includes buffers (BatchNorm running statistics), not just trainable parameters. This matters: a model restored without its BN statistics will produce garbage in eval mode, which is why you always save/load the full state_dict rather than iterating over model.parameters().
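You can see the parameters-versus-buffers split directly on a single layer. This is a minimal sketch on a standalone nn.BatchNorm2d (chosen here for illustration; any BatchNorm layer in the ResNet behaves the same way):

```python
import torch
from torch import nn

bn = nn.BatchNorm2d(4)

param_names = [name for name, _ in bn.named_parameters()]   # what model.parameters() yields
buffer_names = [name for name, _ in bn.named_buffers()]     # running statistics
state_names = list(bn.state_dict().keys())                  # what you should save

print(param_names)    # ['weight', 'bias']
print(buffer_names)   # ['running_mean', 'running_var', 'num_batches_tracked']
print(state_names)    # ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
```

Saving only the parameters would drop exactly the three entries that eval-mode BatchNorm depends on.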

Here’s the anatomy of what you just saved:

checkpoint.pt: a plain dict (save everything you need to resume)
  "model": state_dict (params + buffers, name → tensor)
      conv1.weight            [64, 3, 7, 7]
      bn1.weight              [64]
      bn1.running_mean        [64]   ← buffer!
      layer1.0.conv1.weight   …
      …
      fc.weight               [10, 512]
      fc.bias                 [10]
  "optimizer": state_dict (momentum / Adam moments…; needed to RESUME training, not to serve)
  metadata:
      "epoch": 12
      "val_acc": 0.917
      "classes": ["plane", "car", …]

For training checkpoints (as opposed to final weights), save a dict with everything needed to resume — the pattern you’d bolt onto Lesson 4’s training loop:

def save_checkpoint(path, model, optimizer, epoch, val_acc, classes):
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "val_acc": val_acc,
        "classes": classes,
    }, path)
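The resume side isn't shown in the lesson, so here is one possible counterpart. It is a sketch only: load_checkpoint is a hypothetical helper name, and the toy nn.Linear plus SGD-with-momentum setup stands in for the real model and optimizer. It follows the same loading rules as the next section, and it moves the model to its device before restoring the optimizer, because Optimizer.load_state_dict casts saved state tensors to each parameter's device.

```python
import os
import tempfile

import torch
from torch import nn

def save_checkpoint(path, model, optimizer, epoch, val_acc, classes):
    # Same as above, repeated so this sketch runs on its own.
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "val_acc": val_acc,
        "classes": classes,
    }, path)

def load_checkpoint(path, model, optimizer, device="cpu"):
    # Hypothetical counterpart: restores model + optimizer in place, returns metadata.
    ckpt = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model"])
    model.to(device)                              # move params first...
    optimizer.load_state_dict(ckpt["optimizer"])  # ...optimizer state follows each param's device
    return ckpt["epoch"], ckpt["val_acc"], ckpt["classes"]

# Toy round trip: one momentum step creates optimizer state worth saving.
torch.manual_seed(0)
model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(8, 4)).sum().backward()
opt.step()

path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
save_checkpoint(path, model, opt, epoch=3, val_acc=0.5, classes=["a", "b"])

model2 = nn.Linear(4, 2)
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.9)
start_epoch, best_acc, classes = load_checkpoint(path, model2, opt2)
print(start_epoch + 1)   # resume training from the next epoch
```

Without the optimizer state, a resumed run restarts momentum (or Adam's moment estimates) from zero, which can cause a visible bump in the loss curve.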

Now the loading side, where the real best practices live:

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = build_model(num_classes=10)
state = torch.load(
    "resnet18_finetuned.pt",
    map_location="cpu",      # 1
    weights_only=True,       # 2
)
model.load_state_dict(state)  # 3 — strict=True by default
model = model.to(DEVICE)

Three things to internalize here:

  1. map_location="cpu" — tensors remember which device they were saved from. A checkpoint written on cuda:0 will try to deserialize onto cuda:0, which crashes on a CPU-only server or a machine with fewer GPUs. Always load to CPU first, then .to(DEVICE) explicitly. It costs nothing and works everywhere.
  2. weights_only=True — torch.load historically used pickle, which can execute arbitrary code from a malicious file. weights_only=True restricts deserialization to tensors and safe primitives. It became the default in PyTorch 2.6, but pass it explicitly so your code is safe and self-documenting on any version.
  3. strict=True (the default): load_state_dict verifies that the checkpoint keys and the model keys match exactly, so a missing or extra layer becomes a loud error instead of a silently half-loaded model. Tensor shapes are checked whether or not strict is set, so if you rebuilt the model with num_classes=100 by mistake, you get a size-mismatch error either way. Only reach for strict=False deliberately (e.g. loading a backbone without its head, as in Lesson 8), and when you do, check the return value:
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # keys the model has but the file doesn't
print(result.unexpected_keys)  # keys the file has but the model doesn't

If you ignore that return value, strict=False becomes a machine for creating models that are randomly initialized in exactly the places you care about.
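A related trap: strict=False only forgives missing or unexpected keys, not mismatched shapes. To reuse a backbone under a differently sized head, you have to drop the head's keys yourself. A minimal sketch, where Net is a hypothetical toy stand-in for "backbone + fc head" (like resnet18 with its model.fc):

```python
import torch
from torch import nn

class Net(nn.Module):
    # Hypothetical toy stand-in for "backbone + fc head".
    def __init__(self, num_classes: int):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(self.backbone(x))

src = Net(num_classes=10).state_dict()   # e.g. weights trained for 10 classes
dst = Net(num_classes=3)                  # new task: 3 classes, fresh head

# strict=False does NOT forgive shape mismatches: this still raises.
raised = False
try:
    dst.load_state_dict(src, strict=False)
except RuntimeError as err:
    raised = "size mismatch" in str(err)
print(raised)   # True

# So filter out the head, load the rest non-strictly, and CHECK the result.
backbone_only = {k: v for k, v in src.items() if not k.startswith("fc.")}
result = dst.load_state_dict(backbone_only, strict=False)
print(result.missing_keys)      # only the new head should be missing
print(result.unexpected_keys)   # []
```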

Inference done right: eval(), inference_mode(), and torch.compile

Serving a model means running it thousands of times without training it once. Three switches control how efficient and correct that is.

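Switch 1: model.eval(). This changes layer behavior: Dropout stops dropping, and BatchNorm uses its saved running statistics instead of the current batch's statistics. Forget it and each request is normalized against its own batch (a batch of one against itself), so predictions become noisy nonsense; worse, every train-mode forward pass also updates the running statistics, quietly overwriting what you trained. This is a correctness switch, not a speed switch.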
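Both effects are easy to reproduce on a single BatchNorm layer. In this sketch (a standalone nn.BatchNorm1d with made-up inputs, purely illustrative), the same "request" gets a different output in train mode depending on what else shares its batch, and the train-mode calls mutate the layer's buffers:

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
sample = torch.randn(1, 3)    # "our" request
others = torch.randn(7, 3)    # whatever else happens to share the batch

def out_for_sample(neighbours: torch.Tensor) -> torch.Tensor:
    # Output row for `sample` when batched together with `neighbours`.
    return bn(torch.cat([sample, neighbours]))[0]

bn.train()
a, b = out_for_sample(others), out_for_sample(others * 5)
print(torch.allclose(a, b))            # False: result depends on the neighbours
print(bn.num_batches_tracked.item())   # 2: train-mode calls mutated the buffers

bn.eval()
c, d = out_for_sample(others), out_for_sample(others * 5)
print(torch.allclose(c, d))            # True: fixed running statistics
```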

Switch 2: torch.inference_mode(). This changes autograd behavior. You met torch.no_grad() in Lesson 1; inference_mode is its stricter, faster sibling: no graph is built, no gradients are tracked, and it also skips the view-tracking and version-counter bookkeeping that autograd would need if these tensors were later used in a backward pass. The price is that tensors created under inference_mode can never be used in autograd, which for serving is a feature, not a limitation.

model.eval()

x = torch.randn(1, 3, 224, 224, device=DEVICE)

with torch.inference_mode():
    logits = model(x)

print(logits.shape, logits.requires_grad)
torch.Size([1, 10]) False
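The "can never re-enter autograd" rule is what separates the two context managers in practice. A small sketch with a toy weight vector w (illustrative, not from the lesson's model): an output computed under no_grad is an ordinary tensor that can join a later graph, while one computed under inference_mode is an inference tensor that autograd refuses to save.

```python
import torch

w = torch.ones(3, requires_grad=True)

with torch.no_grad():
    t_ng = w * 2          # ordinary tensor, just not tracked
with torch.inference_mode():
    t_im = w * 2          # inference tensor

print(t_ng.is_inference(), t_im.is_inference())   # False True

(w * t_ng).sum().backward()   # fine: a no_grad output can join a later graph

try:
    (w * t_im).sum().backward()   # mul needs to save t_im for backward
    escaped = True
except RuntimeError as err:
    escaped = False
    print("RuntimeError:", str(err).splitlines()[0])
```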

The mistake to avoid: calling .eval() but forgetting inference_mode() (you waste memory building graphs you never use), or the reverse (you save memory but BatchNorm/Dropout still behave as if training). You need both, every time. Bake them into one function so you can’t forget:

@torch.inference_mode()
def predict(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    return model(x)   # assumes model.eval() was called once at load time

Switch 3: torch.compile. PyTorch 2.x can JIT-compile your model into fused kernels. For inference it’s one line:

compiled = torch.compile(model, mode="reduce-overhead")

# First call is SLOW — this is compilation, not inference.
with torch.inference_mode():
    _ = compiled(x)          # warmup: triggers compile for this input shape
    logits = compiled(x)     # now fast

A quick honest benchmark (on GPU, always synchronize before reading the clock — kernel launches are async, and timing without synchronize() measures how fast Python can enqueue work, not how fast the GPU does it):

import time

def bench(fn, x, iters=50):
    with torch.inference_mode():
        for _ in range(5):        # warmup
            fn(x)
        if DEVICE == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(x)
        if DEVICE == "cuda":
            torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1000

xb = torch.randn(32, 3, 224, 224, device=DEVICE)
print(f"eager:    {bench(model, xb):.2f} ms/batch")
print(f"compiled: {bench(compiled, xb):.2f} ms/batch")
eager:    14.31 ms/batch     # representative GPU numbers; yours will differ
compiled: 9.87 ms/batch
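If you'd rather not hand-roll the timing loop, PyTorch ships torch.utils.benchmark.Timer, which according to its documentation performs warmups and synchronizes CUDA when needed. A minimal sketch against a toy MLP (the model, sizes and run helper are illustrative assumptions). Note that Timer defaults to a single thread, so pass num_threads for realistic CPU numbers, and call a compiled model once yourself first so compilation isn't part of what you measure.

```python
import torch
from torch import nn
from torch.utils import benchmark

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).eval()  # toy stand-in
x = torch.randn(32, 512)

def run(m: nn.Module, inp: torch.Tensor) -> torch.Tensor:
    with torch.inference_mode():
        return m(inp)

timer = benchmark.Timer(
    stmt="run(model, x)",
    globals={"run": run, "model": model, "x": x},
    num_threads=torch.get_num_threads(),   # Timer's default is 1 thread
)
m = timer.blocked_autorange(min_run_time=0.2)   # returns a Measurement
print(f"median: {m.median * 1e3:.3f} ms/batch")
```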

As a rough guide, expect something like 1.3–2× on a GPU for a convnet and a smaller gain on CPU, but measure on your own hardware and batch sizes. Two caveats for serving:

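  • Compilation is shape-specialized. By default the first call specializes on the input shape it sees; a new shape triggers a recompile (seconds of latency for that one request), after which torch.compile may fall back to a dynamic-shape graph that can be less optimized. With mode="reduce-overhead", CUDA graphs are also recorded per distinct shape. In a server, pad or bucket inputs to a small, fixed set of shapes.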
  • torch.compile wraps the model and optimizes it just-in-time inside your running Python process; it is not an export format. The compiled artifact doesn't leave the process. For portability, you export (next section).
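Bucketing can be as simple as zero-padding each batch up to the next size in a fixed list and slicing the real rows back out. A sketch under stated assumptions: BUCKETS, pad_to_bucket and predict_bucketed are hypothetical names, and the trick relies on eval mode, where rows don't influence each other (in train mode, BatchNorm would mix the padding into the statistics).

```python
import torch
from torch import nn

BUCKETS = (1, 2, 4, 8, 16, 32)   # hypothetical batch sizes the server warms up / compiles for

def pad_to_bucket(x: torch.Tensor, buckets=BUCKETS):
    """Zero-pad a batch [n, ...] up to the smallest bucket >= n; return (padded, n)."""
    n = x.shape[0]
    size = next((b for b in buckets if b >= n), None)
    if size is None:
        raise ValueError(f"batch of {n} exceeds the largest bucket ({buckets[-1]})")
    pad = x.new_zeros((size - n, *x.shape[1:]))
    return torch.cat([x, pad]), n

def predict_bucketed(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    padded, n = pad_to_bucket(x)
    with torch.inference_mode():
        return model(padded)[:n]   # discard the rows computed for padding

model = nn.Linear(6, 3).eval()     # toy stand-in for the real (eval-mode) model
x = torch.randn(5, 6)              # 5 requests -> padded to a batch of 8
padded, n = pad_to_bucket(x)
print(padded.shape, n)             # torch.Size([8, 6]) 5
out = predict_bucketed(model, x)
```

With six buckets, the compiled model sees at most six distinct batch shapes, so every recompile can happen during warmup instead of on a live request.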

Exporting: getting the model out of Python

Everything so far assumes the consumer of your model runs Python with PyTorch installed. Often they don’t: a C++ service, a mobile app, a Java shop that standardized on ONNX Runtime. Exporting converts your model into a self-contained graph another runtime can execute.

flowchart LR
    T["Trained nn.Module<br/>+ state_dict"] --> E{Who runs it?}
    E -->|"Python service<br/>(this course)"| P["state_dict + code<br/>+ torch.compile<br/>→ FastAPI"]
    E -->|"C++ / legacy<br/>PyTorch runtime"| TS["TorchScript<br/>jit.trace / jit.script"]
    E -->|"Modern PyTorch<br/>2.x pipeline"| EX["torch.export<br/>ExportedProgram"]
    EX --> AOT["AOTInductor /<br/>ExecuTorch (mobile)"]
    E -->|"Any other stack<br/>(ONNX Runtime, TensorRT…)"| OX["torch.onnx.export<br/>→ model.onnx"]

TorchScript is the classic path, and you’ll see it everywhere in existing codebases. Two flavors:

model_cpu = build_model()
model_cpu.load_state_dict(torch.load("resnet18_finetuned.pt",
                                     map_location="cpu", weights_only=True))
model_cpu.eval()

example = torch.randn(1, 3, 224, 224)

# trace: RUN the model once, record the ops that executed
traced = torch.jit.trace(model_cpu, example)
traced.save("resnet18_traced.pt")

# later, anywhere (no build_model() needed — architecture is in the file):
loaded = torch.jit.load("resnet18_traced.pt")

The trap with tracing: it records one execution path. Any if that depends on tensor values gets frozen at whatever branch the example input took; the tracer emits a TracerWarning, but it is easy to miss, and the saved model gives no sign of the problem at runtime. A model with if x.mean() > 0: ... traced with a positive-mean example will take that branch forever. torch.jit.script compiles the actual Python source instead, preserving control flow, but chokes on dynamic Python. Rule of thumb: trace pure feed-forward models like our ResNet; anything with data-dependent branching needs script or, better, the modern path:
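Here is the frozen branch in a few lines. Gate is a hypothetical toy module containing the if x.mean() > 0 case from the paragraph above. It is traced with a positive-mean input and then called with a negative-mean one:

```python
import warnings

import torch
from torch import nn

class Gate(nn.Module):
    # Hypothetical toy module with a data-dependent branch.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.mean() > 0:
            return x * 2
        return -x

gate = Gate()
pos = torch.ones(3)    # positive-mean example used for tracing
neg = -torch.ones(3)   # a later request with negative mean

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    traced = torch.jit.trace(gate, pos)   # only the `x * 2` branch is recorded

print(gate(neg))     # tensor([1., 1., 1.])     eager: takes the `-x` branch
print(traced(neg))   # tensor([-2., -2., -2.])  trace: frozen `x * 2`, no error
print([type(w.message).__name__ for w in caught])   # the warnings you'd have to notice
```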

torch.export is the PyTorch 2.x successor. It builds on the same PyTorch 2 graph-capture stack that powers torch.compile to produce a full, sound graph with no fallbacks to Python, and it errors loudly on data-dependent control flow instead of silently baking in a branch:

ep = torch.export.export(model_cpu, (example,))
torch.export.save(ep, "resnet18.pt2")
print(type(ep))
<class 'torch.export.exported_program.ExportedProgram'>

An ExportedProgram is the input to downstream compilers: AOTInductor for server binaries, ExecuTorch for mobile/edge. For this course, know that it exists and that new projects should prefer it over TorchScript.

ONNX is the lingua franca when the target isn’t PyTorch at all. The modern exporter rides on torch.export under the hood:

torch.onnx.export(
    model_cpu, (example,), "resnet18.onnx",
    input_names=["image"], output_names=["logits"],
    dynamo=True,   # the modern torch.export-based path
)

An export you haven’t verified is a rumor. Always check numerical parity against the original:

import onnxruntime as ort   # pip install onnxruntime
import numpy as np

sess = ort.InferenceSession("resnet18.onnx")
onnx_out = sess.run(None, {"image": example.numpy()})[0]

with torch.inference_mode():
    torch_out = model_cpu(example).numpy()

np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
print("max abs diff:", np.abs(torch_out - onnx_out).max())
max abs diff: 2.3841858e-06

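Tiny float discrepancies are normal (different kernels, different op fusion). Large ones mean something in the conversion changed the math, such as an op lowered differently or a layer exported in the wrong mode, and you must not ship until you know why.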

| Path | Artifact | Needs Python? | Best for | Watch out for |
|---|---|---|---|---|
| state_dict + code | .pt weights | Yes | Python services, resuming training | Architecture must match exactly |
| TorchScript | self-contained .pt | No (libtorch) | Existing C++ deployments | Tracing freezes control flow silently |
| torch.export | .pt2 ExportedProgram | No (via AOTInductor/ExecuTorch) | New PyTorch 2.x pipelines, mobile | Errors on dynamic Python (by design) |
| ONNX | .onnx graph | No | Non-PyTorch runtimes, TensorRT | Verify parity; op coverage |

For this lesson’s server we stay on the first row — a Python service loading a state_dict is the simplest thing that works, and simplest-that-works is the correct engineering default.

Serving: a minimal FastAPI inference endpoint

Time to put the model behind HTTP. The architecture of every sane inference server is the same:

sequenceDiagram
    participant C as Client
    participant A as FastAPI app
    participant M as Model (loaded ONCE at startup)
    Note over A,M: startup: build_model → load_state_dict → eval() → warmup
    C->>A: POST /predict (image bytes)
    A->>A: decode + preprocess (resize, normalize) → [1,3,224,224]
    A->>M: forward pass under inference_mode
    M-->>A: logits [1,10]
    A->>A: softmax → top-k labels + probabilities
    A-->>C: JSON {"predictions": [...]}

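The one non-negotiable design rule: load the model once at startup, never per-request. Building ResNet-18 and reading its ~45 MB of weights can take on the order of a second, while a forward pass takes milliseconds. Load-per-request makes every request pay that setup cost, which can leave your server roughly 100× slower than the model itself.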

# serve.py
import io
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, UploadFile, HTTPException
from PIL import Image
from torchvision import transforms

from model import build_model   # the function from section 1

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CLASSES = ["plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# Same normalization the model was TRAINED with (Lesson 8) — mismatch here
# is the #1 cause of "it worked in the notebook but the API is dumb".
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

The preprocessing pipeline is copied verbatim from the evaluation-time transform used in training (not the random augmentations). This deserves emphasis: the model has no opinion about pixels; it learned a function of normalized tensors. Serve it un-normalized inputs and accuracy quietly collapses toward random while the API happily returns confident JSON. Train/serve preprocessing skew is one of the most common production ML bugs.
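One way to make that skew structurally harder is to keep the constants in a single module that both the training script and serve.py import. A pure-PyTorch sketch of what the Normalize step does, where normalize, IMAGENET_MEAN and IMAGENET_STD are hypothetical names for such a shared module. A pure-white pixel is 1.0 in every channel after ToTensor(), but the model was trained to see roughly 2.25, 2.43 and 2.64:

```python
import torch

# Hypothetical shared module (e.g. preprocessing.py) imported by BOTH training
# and serving code, so the constants cannot drift apart.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def normalize(img01: torch.Tensor) -> torch.Tensor:
    """img01: [3, H, W] float tensor in [0, 1] (what ToTensor() produces).
    Returns (x - mean) / std per channel, the same math as transforms.Normalize."""
    return (img01 - IMAGENET_MEAN) / IMAGENET_STD

white = torch.ones(3, 2, 2)        # a pure-white 2x2 image after ToTensor()
print(white[:, 0, 0])              # what the model sees if you skip Normalize
print(normalize(white)[:, 0, 0])   # what it saw during training
```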

state = {}   # holds the model between startup and shutdown

@asynccontextmanager
async def lifespan(app: FastAPI):
    model = build_model(num_classes=len(CLASSES))
    sd = torch.load("resnet18_finetuned.pt",
                    map_location="cpu", weights_only=True)
    model.load_state_dict(sd)
    model.to(DEVICE).eval()
    with torch.inference_mode():                       # warmup:
        model(torch.zeros(1, 3, 224, 224, device=DEVICE))  # first-call costs paid here
    state["model"] = model
    yield
    state.clear()

app = FastAPI(title="Lesson 9 classifier", lifespan=lifespan)

The lifespan context manager is FastAPI’s startup/shutdown hook. Everything expensive (building, loading, moving to device, eval(), and a warmup forward pass so the first real request doesn't pay one-time CUDA/cuDNN initialization costs) happens exactly once, before the server accepts traffic.

Now the endpoint itself:

@app.get("/health")
def health():
    return {"status": "ok", "device": DEVICE}

@app.post("/predict")
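def predict(file: UploadFile, top_k: int = 3):
    if top_k < 1:   # query params cross the trust boundary too; topk(-1) would 500
        raise HTTPException(status_code=422, detail="top_k must be >= 1")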
    try:
        img = Image.open(io.BytesIO(file.file.read())).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="not a decodable image")

    x = preprocess(img).unsqueeze(0).to(DEVICE)   # [3,224,224] -> [1,3,224,224]

    with torch.inference_mode():
        logits = state["model"](x)                 # [1, 10]
        probs = logits.softmax(dim=1)[0]           # [10]

    p, idx = probs.topk(min(top_k, len(CLASSES)))
    return {"predictions": [
        {"label": CLASSES[i], "prob": round(float(pi), 4)}
        for pi, i in zip(p.tolist(), idx.tolist())
    ]}

Walk the shapes: preprocess yields [3, 224, 224]; the model expects a batch dimension, so unsqueeze(0) makes it [1, 3, 224, 224] — forget this and you get a dimension error (or worse, on some models, silent misinterpretation). The logits come back [1, 10]; softmax turns them into probabilities

\[p_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}}\]

and [0] peels off the batch dimension. Note the trust-boundary handling: the upload is decoded inside a try and rejected with a 400, because a server that 500s on a corrupt JPEG is a server that pages you at 3am. Also note .tolist(): tensors aren’t JSON-serializable, so convert them to plain Python numbers at the boundary (the float() call is belt-and-braces).
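The post-processing half of the endpoint can be pulled into a plain function and tested without FastAPI. A sketch: postprocess is a hypothetical helper name, and the logits are made up for illustration. With a logit of 2 for "cat", 1 for "dog" and 0 for the other eight classes, the softmax formula gives p_cat = e^2 / (e^2 + e + 8) ≈ 0.408.

```python
import json

import torch

CLASSES = ["plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

def postprocess(logits: torch.Tensor, top_k: int = 3) -> dict:
    """logits: [1, num_classes] -> JSON-safe dict, mirroring the /predict endpoint."""
    probs = logits.softmax(dim=1)[0]                 # [num_classes], sums to 1
    p, idx = probs.topk(min(top_k, len(CLASSES)))
    return {"predictions": [
        {"label": CLASSES[i], "prob": round(pi, 4)}  # .tolist() already yields Python floats/ints
        for pi, i in zip(p.tolist(), idx.tolist())
    ]}

logits = torch.zeros(1, 10)
logits[0, 3] = 2.0   # "cat"
logits[0, 5] = 1.0   # "dog"
out = postprocess(logits, top_k=2)
print(json.dumps(out))   # serializes cleanly: no tensors left
```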

Run it and hit it:

pip install fastapi uvicorn python-multipart pillow
uvicorn serve:app --host 0.0.0.0 --port 8000
curl -s http://localhost:8000/health
curl -s -X POST "http://localhost:8000/predict?top_k=3" \
     -F "file=@cat.jpg" | python -m json.tool
{
    "predictions": [
        {"label": "cat",  "prob": 0.9231},
        {"label": "dog",  "prob": 0.0512},
        {"label": "deer", "prob": 0.0104}
    ]
}

That’s a deployed model. Everything beyond this — batching concurrent requests, autoscaling, model versioning, monitoring for drift — is real and important, but it’s MLOps territory (see the pointer at the end), and this minimal server is the correct foundation all of it builds on.

🧪 Your task

Prove your serialization round-trip is airtight, end to end. Write a script roundtrip.py that:

  1. Builds the model, loads resnet18_finetuned.pt, and computes logits for a fixed input torch.manual_seed(0); x = torch.randn(4, 3, 224, 224) in eval + inference mode.
  2. Exports the model with torch.jit.trace and saves it.
  3. In a fresh torch.jit.load (no call to build_model), computes logits for the same input.
  4. Asserts the two outputs match with torch.allclose(..., atol=1e-5) — and asserts they don’t match if you “forget” model.eval() before tracing (proving to yourself that BatchNorm mode is baked into the export).

Hint: trace the model twice — once in eval() mode, once after model.train() — and compare each traced module’s output against the eval-mode reference. torch.jit.trace records the module in whatever mode it’s currently in.

Solution
# roundtrip.py
import torch
from model import build_model

# --- 1. reference output from the real model ---
model = build_model(num_classes=10)
sd = torch.load("resnet18_finetuned.pt", map_location="cpu", weights_only=True)
model.load_state_dict(sd)
model.eval()

torch.manual_seed(0)
x = torch.randn(4, 3, 224, 224)

with torch.inference_mode():
    ref = model(x)
print("reference logits:", ref.shape)   # torch.Size([4, 10])

# --- 2. export (correctly, in eval mode) ---
traced_eval = torch.jit.trace(model, x)
traced_eval.save("rt_eval.pt")

# --- 3. fresh load, no build_model ---
loaded = torch.jit.load("rt_eval.pt")
loaded.eval()
with torch.inference_mode():
    out = loaded(x)

# --- 4a. round-trip parity ---
assert torch.allclose(ref, out, atol=1e-5), "round-trip mismatch!"
print("eval-mode round trip: OK, max diff =",
      (ref - out).abs().max().item())

# --- 4b. the negative control: trace with train-mode BatchNorm ---
import copy
bad_model = copy.deepcopy(model).train()   # "forgot" eval() before export; a copy,
                                           # so tracing's BN stat updates can't touch `model`
traced_train = torch.jit.trace(bad_model, x)

with torch.inference_mode():
    bad = traced_train(x)

assert not torch.allclose(ref, bad, atol=1e-5), \
    "train-mode trace unexpectedly matched — BN stats identical?"
print("train-mode trace differs as expected, max diff =",
      (ref - bad).abs().max().item())

Expected output:

reference logits: torch.Size([4, 10])
eval-mode round trip: OK, max diff = 0.0  (may be a tiny nonzero value)
train-mode trace differs as expected, max diff = 3.87  (magnitude varies)

The second assertion is the lesson: the traced file froze BatchNorm in training mode, so it normalizes each request batch by its own statistics forever, no matter what mode the loader sets. Export mode is part of the artifact.

Key takeaways

  • Save state_dict, never the pickled model object; keep architecture in code and weights in the file.
  • Load with map_location="cpu" and weights_only=True, move to device explicitly, and treat strict=False (and its return value) as a deliberate, checked decision.
  • Inference needs both switches: model.eval() for layer behavior (BatchNorm/Dropout) and torch.inference_mode() for autograd; missing either is a bug, not a slowdown.
  • torch.compile gives real speedups in-process but is not an export format; it specializes on input shapes, so keep serving shapes fixed.
  • Export paths: TorchScript for legacy/C++ (trace freezes control flow — beware), torch.export for modern PyTorch pipelines, ONNX for everything else — and always verify numerical parity before shipping.
  • A serving endpoint loads the model once at startup, reuses the exact training preprocessing, validates inputs at the trust boundary, and converts tensors to plain Python at the JSON boundary.

That’s the course — nine lessons from torch.tensor([1., 2., 3.]) to a model answering HTTP requests. Where next: the site’s Transformers mini-course takes the nn.Module skills you now have into attention and language models, and the MLOps mini-course picks up exactly where this lesson’s FastAPI server left off — versioning, monitoring, and scaling the thing you just shipped.



 

© Kader Mohideen