```mermaid
flowchart LR
T["Trained nn.Module<br/>+ state_dict"] --> E{Who runs it?}
E -->|"Python service<br/>(this course)"| P["state_dict + code<br/>+ torch.compile<br/>→ FastAPI"]
E -->|"C++ / legacy<br/>PyTorch runtime"| TS["TorchScript<br/>jit.trace / jit.script"]
E -->|"Modern PyTorch<br/>2.x pipeline"| EX["torch.export<br/>ExportedProgram"]
EX --> AOT["AOTInductor /<br/>ExecuTorch (mobile)"]
E -->|"Any other stack<br/>(ONNX Runtime, TensorRT…)"| OX["torch.onnx.export<br/>→ model.onnx"]
```
🔥 Deep Learning with PyTorch · Lesson 9 — Saving, Exporting & Serving Your Model
Lesson 9 — Saving, Exporting & Serving Your Model
Over eight lessons you took raw tensors all the way to a fine-tuned ResNet that classifies real images. But a model that only lives inside your training notebook is a science project, not software. In this lesson we close the loop: you’ll checkpoint the Lesson 8 model correctly (there are more wrong ways than right ways), make inference fast with torch.inference_mode() and torch.compile, export it to portable formats (TorchScript, torch.export, ONNX), and stand up a real HTTP endpoint with FastAPI that anyone can curl. This is the difference between “I trained a model” and “I shipped a model.”
🎯 In this lesson you will: save and load checkpoints with state_dict best practices, run inference correctly with inference_mode and torch.compile, export a model via TorchScript / torch.export / ONNX, serve predictions from a FastAPI endpoint, and map out where to go after this course.
Checkpoints done right: state_dict, not pickled models
There are two ways to save a PyTorch model, and one of them will eventually ruin your week. Let’s start by rebuilding the Lesson 8 architecture so this lesson runs standalone:
```python
import torch
from torch import nn
from torchvision import models

def build_model(num_classes: int = 10) -> nn.Module:
    # Architecture only: weights will come from OUR checkpoint,
    # so weights=None (we don't need the ImageNet download here).
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

model = build_model()
```

The tempting-but-wrong way to save it:
```python
# DON'T do this
torch.save(model, "model_full.pt")
```

This pickles the entire object, including references to your class definitions by their module paths, which ties the file to your package layout and even your directory structure. Move the file to a machine where torchvision's ResNet class lives at a different import path, or upgrade torchvision, and the load fails with an obscure ModuleNotFoundError or AttributeError, or silently constructs something subtly different. Pickled models are a time bomb.
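To see the fuse on that time bomb, here is a minimal sketch using a tiny hypothetical stand-in model (build_tiny, not the lesson's ResNet) and in-memory buffers instead of files. It assumes a PyTorch version that accepts the weights_only argument (1.13 or newer). The safe loader refuses the pickled module outright, while the state_dict round-trips cleanly.

```python
import io
import pickle

import torch
from torch import nn


def build_tiny() -> nn.Module:
    # Hypothetical stand-in for build_model(): the architecture lives in code.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


model = build_tiny()

# Anti-pattern: pickle the whole module object (class references and all).
full_buf = io.BytesIO()
torch.save(model, full_buf)
full_buf.seek(0)
try:
    torch.load(full_buf, weights_only=True)
    full_load_refused = False
except pickle.UnpicklingError:
    # The restricted unpickler rejects arbitrary classes such as nn.Sequential.
    full_load_refused = True

# Recommended: save only the state_dict (parameter/buffer names -> tensors).
sd_buf = io.BytesIO()
torch.save(model.state_dict(), sd_buf)
sd_buf.seek(0)
restored = build_tiny()
restored.load_state_dict(torch.load(sd_buf, weights_only=True))

print("pickled-model load refused:", full_load_refused)
```

Loading the pickled module would require weights_only=False, which re-enables arbitrary code execution during unpickling; that is exactly the trade this lesson steers you away from.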
The right way saves only the state_dict: an ordered dictionary mapping parameter names to tensors. Architecture stays in code (versioned in git), weights stay in the file:
```python
torch.save(model.state_dict(), "resnet18_finetuned.pt")

sd = model.state_dict()
print(type(sd))
print(list(sd.keys())[:4])
print(sd["fc.weight"].shape)
```

```text
<class 'collections.OrderedDict'>
['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean']
torch.Size([10, 512])
```
Notice bn1.running_mean in there — the state_dict includes buffers (BatchNorm running statistics), not just trainable parameters. This matters: a model restored without its BN statistics will produce garbage in eval mode, which is why you always save/load the full state_dict rather than iterating over model.parameters().
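Here’s the anatomy of what you just saved: a flat mapping from dotted module paths (conv1.weight, layer1.0.bn1.running_var, fc.bias, …) to tensors, mixing learnable parameters with BatchNorm buffers.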
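You can check that split yourself. The sketch below uses a small hypothetical conv net (tiny, not the lesson's ResNet-18) and labels every state_dict entry as a parameter or a buffer. Only the BatchNorm layer contributes buffers, and iterating over model.parameters() alone would miss all three of them.

```python
import torch
from torch import nn

# Hypothetical miniature stand-in for ResNet-18: conv -> BatchNorm -> head.
tiny = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

sd = tiny.state_dict()
param_names = {name for name, _ in tiny.named_parameters()}
buffer_names = {name for name, _ in tiny.named_buffers()}

# Every state_dict key is either a trainable parameter or a buffer.
for name, tensor in sd.items():
    kind = "param" if name in param_names else "buffer"
    print(f"{name:24s} {kind:6s} {tuple(tensor.shape)}")
```

The BatchNorm entries 1.running_mean, 1.running_var and 1.num_batches_tracked are the ones a parameters-only save would silently drop.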
For training checkpoints (as opposed to final weights), save a dict with everything needed to resume — the pattern you’d bolt onto Lesson 4’s training loop:
```python
def save_checkpoint(path, model, optimizer, epoch, val_acc, classes):
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "val_acc": val_acc,
        "classes": classes,
    }, path)
```

Now the loading side, where the real best practices live:
```python
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = build_model(num_classes=10)
state = torch.load(
    "resnet18_finetuned.pt",
    map_location="cpu",   # 1
    weights_only=True,    # 2
)
model.load_state_dict(state)  # 3: strict=True by default
model = model.to(DEVICE)
```

Three things to internalize here:
- map_location="cpu": tensors remember which device they were saved from. A checkpoint written on cuda:0 will try to deserialize onto cuda:0, which crashes on a CPU-only server (and one written on cuda:3 crashes on a machine with fewer than four GPUs). Always load to CPU first, then call .to(DEVICE) explicitly. It costs nothing and works everywhere.
- weights_only=True: torch.load historically used pickle, which can execute arbitrary code from a malicious file. weights_only=True restricts deserialization to tensors and safe primitives. It became the default in PyTorch 2.6, but pass it explicitly so your code is safe and self-documenting on older versions too.
- strict=True (the default): load_state_dict verifies that the checkpoint keys and the model keys match exactly, so a renamed layer or a stray "module." prefix left over from nn.DataParallel fails loudly instead of leaving layers randomly initialized. (Shape mismatches, such as rebuilding the model with num_classes=100 by mistake, raise an error regardless of strict.) Only reach for strict=False deliberately, e.g. loading a backbone without its head as in Lesson 8, and when you do, check the return value:
```python
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # keys the model has but the file doesn't
print(result.unexpected_keys)  # keys the file has but the model doesn't
```

If you ignore that return value, strict=False becomes a machine for creating models that are randomly initialized in exactly the places you care about.
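The save and load halves can be exercised together. Below is a minimal resume sketch, assuming a hypothetical two-part model (a backbone layer plus a head named fc) in place of ResNet-18 and an in-memory buffer in place of a file. load_checkpoint is a hypothetical companion to the save_checkpoint helper above. The last part shows what strict=False reports for a head-less checkpoint, and that a shape mismatch raises even with strict=False.

```python
import io
from collections import OrderedDict

import torch
from torch import nn


def build_tiny(num_classes: int = 10) -> nn.Module:
    # Hypothetical stand-in: "backbone" + "fc" head, mirroring ResNet naming.
    return nn.Sequential(OrderedDict(
        backbone=nn.Linear(16, 32),
        act=nn.ReLU(),
        fc=nn.Linear(32, num_classes),
    ))


def save_checkpoint(path, model, optimizer, epoch, val_acc, classes):
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "val_acc": val_acc,
        "classes": classes,
    }, path)


def load_checkpoint(path, model, optimizer=None):
    # Hypothetical helper: load to CPU, use the safe unpickler, match keys strictly.
    ckpt = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model"])
    if optimizer is not None:
        # Build the optimizer after moving the model to its device;
        # load_state_dict then moves the saved state onto the params' device.
        optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"], ckpt["val_acc"], ckpt["classes"]


# One training step so the optimizer has momentum buffers worth saving.
torch.manual_seed(0)
model = build_tiny()
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(4, 16)).sum().backward()
opt.step()

buf = io.BytesIO()
save_checkpoint(buf, model, opt, epoch=3, val_acc=0.91, classes=["cat", "dog"])
buf.seek(0)

model2 = build_tiny()
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.9)
epoch, val_acc, classes = load_checkpoint(buf, model2, opt2)
print("resumed at epoch", epoch, "with val_acc", val_acc)

# Deliberate strict=False: reuse the backbone, keep a fresh 5-class head.
backbone_only = {k: v for k, v in model.state_dict().items()
                 if not k.startswith("fc.")}
result = build_tiny(num_classes=5).load_state_dict(backbone_only, strict=False)
print(result.missing_keys, result.unexpected_keys)

# Shape mismatches are NOT excused by strict=False.
try:
    build_tiny(num_classes=5).load_state_dict(model.state_dict(), strict=False)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
```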
Inference done right: eval(), inference_mode(), and torch.compile
Serving a model means running it thousands of times without training it once. Three switches control how efficient and correct that is.
Switch 1: model.eval(). This changes layer behavior: Dropout stops dropping, BatchNorm uses its saved running statistics instead of batch statistics. Forget it and a batch-of-one request will be normalized against itself — predictions become noisy nonsense. This is a correctness switch, not a speed switch.
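A quick way to see the difference: in eval mode a sample gets the same output whether it arrives alone or inside a larger batch, while in train mode BatchNorm normalizes with the batch's own statistics and also updates its stored running statistics as a side effect. This sketch uses a tiny hypothetical conv net rather than the lesson's ResNet.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical tiny conv net with BatchNorm, standing in for the ResNet.
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
batch = torch.randn(8, 3, 16, 16)
single = batch[:1]

with torch.no_grad():
    net.eval()                       # BatchNorm uses its stored running stats
    eval_alone = net(single)
    eval_in_batch = net(batch)[:1]

    net.train()                      # BatchNorm uses the current batch's stats
    stats_before = net[1].running_mean.clone()
    train_alone = net(single)
    train_in_batch = net(batch)[:1]
    stats_after = net[1].running_mean.clone()

eval_consistent = torch.allclose(eval_alone, eval_in_batch, atol=1e-5)
train_consistent = torch.allclose(train_alone, train_in_batch, atol=1e-5)
stats_drifted = not torch.equal(stats_before, stats_after)
print("eval:  same output alone vs. in a batch?", eval_consistent)
print("train: same output alone vs. in a batch?", train_consistent)
print("train-mode forward changed running_mean?", stats_drifted)
```

In a server left in train mode, every request would therefore both get a batch-dependent answer and nudge the stored statistics away from what was learned.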
Switch 2: torch.inference_mode(). This changes autograd behavior. You met torch.no_grad() in Lesson 1; inference_mode is its stricter, faster sibling: no graph is built, no gradients are tracked, and the version-counter and view-tracking bookkeeping that autograd would need if these tensors were used later is skipped too. The resulting tensors can never be used in a backward pass, which for serving is a feature, not a limitation.
```python
model.eval()
x = torch.randn(1, 3, 224, 224, device=DEVICE)
with torch.inference_mode():
    logits = model(x)
print(logits.shape, logits.requires_grad)
```

```text
torch.Size([1, 10]) False
```
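The "never used in a backward pass" part is easy to verify. In this sketch (toy tensors, not the lesson's model), a tensor produced under no_grad() can still feed a later autograd computation, whereas one produced under inference_mode() is flagged by Tensor.is_inference() and is rejected as soon as autograd tries to save it for backward.

```python
import torch

w = torch.randn(3, requires_grad=True)  # a trainable weight, created outside any context
x = torch.randn(3)

with torch.no_grad():
    y_nograd = x * 2
with torch.inference_mode():
    y_inf = x * 2

print(y_nograd.is_inference(), y_inf.is_inference())  # False True

# A no_grad() result is a normal tensor: it can join a later autograd graph.
(y_nograd * w).sum().backward()

# An inference tensor cannot be saved for backward, so this raises.
try:
    (y_inf * w).sum().backward()
    inference_tensor_accepted = True
except RuntimeError:
    inference_tensor_accepted = False

# If you truly need it in autograd, clone it outside inference mode first.
(y_inf.clone() * w).sum().backward()
```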
The mistake to avoid: calling .eval() but forgetting inference_mode() (you waste memory building graphs you never use), or the reverse (you save memory but BatchNorm/Dropout still behave as if training). You need both, every time. Bake them into one function so you can’t forget:
```python
@torch.inference_mode()
def predict(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    return model(x)  # assumes model.eval() was called once at load time
```

Switch 3: torch.compile. PyTorch 2.x can JIT-compile your model into fused kernels. For inference it’s one line:
```python
compiled = torch.compile(model, mode="reduce-overhead")

# First call is SLOW: this is compilation, not inference.
with torch.inference_mode():
    _ = compiled(x)       # warmup: triggers compile for this input shape
    logits = compiled(x)  # now fast
```

Here is a quick honest benchmark. On GPU, always synchronize before reading the clock: kernel launches are asynchronous, and timing without synchronize() measures how fast Python can enqueue work, not how fast the GPU does it.
```python
import time

def bench(fn, x, iters=50):
    with torch.inference_mode():
        for _ in range(5):  # warmup (also absorbs any recompile for a new shape)
            fn(x)
        if DEVICE == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(x)
        if DEVICE == "cuda":
            torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1000  # milliseconds per call

xb = torch.randn(32, 3, 224, 224, device=DEVICE)
print(f"eager: {bench(model, xb):.2f} ms/batch")
print(f"compiled: {bench(compiled, xb):.2f} ms/batch")
```

```text
eager: 14.31 ms/batch     # representative GPU numbers; yours will differ
compiled: 9.87 ms/batch
```
Expect roughly 1.3–2× on a GPU for a convnet; on CPU the gain is smaller. Two caveats for serving:
- Compilation is shape-specialized. Feed a new input shape and you trigger a recompile, costing seconds of latency for that one request (PyTorch may then try to generalize that dimension, but the first mismatch still pays). In a server, pad or bucket inputs to a small set of fixed shapes.
- torch.compile optimizes the model inside your Python process; it is not an export format. It returns a wrapper around your module, and the compiled code never leaves the process. For portability, you export, which is the next section.
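Here is one way to do that bucketing, as a sketch under stated assumptions: the bucket sizes, pad_to_bucket and run_bucketed are hypothetical names, and zero-padding is only safe because the model is in eval mode, where each sample is processed independently. In a real server you would pass the compiled module instead of the tiny eager net used here, so the compiler only ever sees a handful of batch shapes. (torch.compile also accepts dynamic=True, which asks for shape-generic code up front and can trade some per-call speed for fewer recompiles.)

```python
import torch
from torch import nn

BUCKETS = (1, 4, 8, 16, 32)  # hypothetical; pick sizes that match your traffic


def pad_to_bucket(x: torch.Tensor, buckets=BUCKETS):
    """Zero-pad the batch dimension up to the smallest bucket that fits."""
    n = x.shape[0]
    size = next((b for b in buckets if b >= n), None)
    if size is None:
        raise ValueError(f"batch of {n} exceeds the largest bucket ({buckets[-1]})")
    if size == n:
        return x, n
    padding = x.new_zeros((size - n, *x.shape[1:]))
    return torch.cat([x, padding], dim=0), n


@torch.inference_mode()
def run_bucketed(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    padded, n = pad_to_bucket(x)
    return model(padded)[:n]  # drop the rows that exist only as padding


# Usage with a tiny hypothetical eval-mode net standing in for the compiled model.
torch.manual_seed(0)
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
).eval()
requests = torch.randn(3, 3, 16, 16)  # 3 requests -> padded to the bucket of 4
out = run_bucketed(net, requests)
with torch.inference_mode():
    direct = net(requests)
print(out.shape)  # torch.Size([3, 4])
```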
Exporting: getting the model out of Python
Everything so far assumes the consumer of your model runs Python with PyTorch installed. Often they don’t: a C++ service, a mobile app, a Java shop that standardized on ONNX Runtime. Exporting converts your model into a self-contained graph another runtime can execute.
TorchScript is the classic path, and you’ll see it everywhere in existing codebases. Two flavors:
```python
model_cpu = build_model()
model_cpu.load_state_dict(torch.load("resnet18_finetuned.pt",
                                     map_location="cpu", weights_only=True))
model_cpu.eval()
example = torch.randn(1, 3, 224, 224)

# trace: RUN the model once, record the ops that executed
traced = torch.jit.trace(model_cpu, example)
traced.save("resnet18_traced.pt")

# later, anywhere (no build_model() needed: the architecture is in the file):
loaded = torch.jit.load("resnet18_traced.pt")
```

The trap with tracing: it records one execution path. Any if that depends on tensor values gets frozen at whatever branch the example input took, silently. A model with if x.mean() > 0: ... traced with a positive-mean example will take that branch forever. torch.jit.script compiles the actual Python source instead, preserving control flow, but chokes on dynamic Python. Rule of thumb: trace pure feed-forward models like our ResNet; anything with data-dependent branching needs script or, better, the modern path:
torch.export is the PyTorch 2.x successor. It uses the same graph-capture machinery as torch.compile (Dynamo) to produce a full, sound graph — and errors loudly on data-dependent control flow instead of silently baking in a branch:
```python
ep = torch.export.export(model_cpu, (example,))
torch.export.save(ep, "resnet18.pt2")
print(type(ep))
```

```text
<class 'torch.export.exported_program.ExportedProgram'>
```
An ExportedProgram is the input to downstream compilers: AOTInductor for server binaries, ExecuTorch for mobile/edge. For this course, know that it exists and that new projects should prefer it over TorchScript.
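The tracing trap described above is easy to reproduce. This sketch uses a made-up Gate module with a value-dependent if (an illustration, not part of the lesson's model): tracing with a positive example silently hard-codes the x * 2 branch, while torch.jit.script keeps the if/else. This silent branch freezing is the failure mode torch.export is described as raising an error on instead.

```python
import torch
from torch import nn


class Gate(nn.Module):
    # Data-dependent branch: which path runs depends on the VALUES in x.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.mean() > 0:
            return x * 2
        return -x


gate = Gate().eval()
pos = torch.ones(3)
neg = -torch.ones(3)

traced = torch.jit.trace(gate, pos)  # records only the "x * 2" path (emits a TracerWarning)
scripted = torch.jit.script(gate)    # compiles the source, keeping the if/else

print("eager   :", gate(neg))      # tensor([1., 1., 1.])
print("traced  :", traced(neg))    # tensor([-2., -2., -2.])  <- wrong branch, no error
print("scripted:", scripted(neg))  # tensor([1., 1., 1.])
```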
ONNX is the lingua franca when the target isn’t PyTorch at all. The modern exporter rides on torch.export under the hood:
```python
torch.onnx.export(
    model_cpu, (example,), "resnet18.onnx",
    input_names=["image"], output_names=["logits"],
    dynamo=True,  # the modern torch.export-based path
)
```

An export you haven’t verified is a rumor. Always check numerical parity against the original:
```python
import numpy as np
import onnxruntime as ort  # pip install onnxruntime

# Naming the provider explicitly avoids errors on GPU-enabled onnxruntime builds.
sess = ort.InferenceSession("resnet18.onnx", providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {"image": example.numpy()})[0]

with torch.inference_mode():
    torch_out = model_cpu(example).numpy()

np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
print("max abs diff:", np.abs(torch_out - onnx_out).max())
```

```text
max abs diff: 2.3841858e-06
```
Tiny float discrepancies are normal (different kernels, different op fusion); large ones mean an unsupported op got approximated and you must not ship.
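A single example input only tests one point. A slightly more thorough sketch compares the two models over several random inputs, including batch sizes other than the one used at export time, and reports the worst difference. max_abs_diff and random_batch are hypothetical helpers; here they compare a tiny eager net against its torch.jit.trace export so the example runs without onnxruntime, but the candidate can be any callable, such as a small wrapper that feeds x.numpy() to an onnxruntime session and converts the result back with torch.from_numpy. An export captured with static shapes may reject other batch sizes, which is itself worth discovering before you ship.

```python
import torch
from torch import nn


def max_abs_diff(reference, candidate, make_input, trials=5, seed=0):
    """Run both callables on the same random inputs; return the worst |difference|."""
    gen = torch.Generator().manual_seed(seed)
    worst = 0.0
    with torch.inference_mode():
        for _ in range(trials):
            x = make_input(gen)
            ref, out = reference(x), candidate(x)
            if ref.shape != out.shape:
                raise AssertionError(f"shape mismatch: {ref.shape} vs {out.shape}")
            worst = max(worst, (ref - out).abs().max().item())
    return worst


def random_batch(gen: torch.Generator) -> torch.Tensor:
    # Vary the batch size (1..8) so shape-specialized exports get exercised too.
    n = int(torch.randint(1, 9, (1,), generator=gen))
    return torch.randn(n, 3, 16, 16, generator=gen)


# Tiny hypothetical stand-in model and its TorchScript trace.
torch.manual_seed(0)
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 10),
).eval()
traced = torch.jit.trace(net, torch.randn(1, 3, 16, 16))

diff = max_abs_diff(net, traced, random_batch)
print("worst max abs diff:", diff)
```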
| Path | Artifact | Needs Python? | Best for | Watch out for |
|---|---|---|---|---|
| state_dict + code | .pt weights | Yes | Python services, resuming training | Architecture must match exactly |
| TorchScript | self-contained .pt | No (libtorch) | Existing C++ deployments | Tracing freezes control flow silently |
| torch.export | .pt2 ExportedProgram | No (via AOTInductor/ExecuTorch) | New PyTorch 2.x pipelines, mobile | Errors on dynamic Python (by design) |
| ONNX | .onnx graph | No | Non-PyTorch runtimes, TensorRT | Verify parity; op coverage |
For this lesson’s server we stay on the first row — a Python service loading a state_dict is the simplest thing that works, and simplest-that-works is the correct engineering default.
Serving: a minimal FastAPI inference endpoint
Time to put the model behind HTTP. The architecture of every sane inference server is the same:
```mermaid
sequenceDiagram
participant C as Client
participant A as FastAPI app
participant M as Model (loaded ONCE at startup)
Note over A,M: startup: build_model → load_state_dict → eval() → warmup
C->>A: POST /predict (image bytes)
A->>A: decode + preprocess (resize, normalize) → [1,3,224,224]
A->>M: forward pass under inference_mode
M-->>A: logits [1,10]
A->>A: softmax → top-k labels + probabilities
A-->>C: JSON {"predictions": [...]}
```
The one non-negotiable design rule: load the model once at startup, never per request. Loading ResNet-18 takes on the order of a second; a forward pass takes on the order of 10 ms. Loading per request therefore makes every request roughly 100× slower than the forward pass it exists to run.
```python
# serve.py
import io
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, UploadFile, HTTPException
from PIL import Image
from torchvision import transforms

from model import build_model  # the function from section 1

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CLASSES = ["plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# Same normalization the model was TRAINED with (Lesson 8). A mismatch here
# is the #1 cause of "it worked in the notebook but the API is dumb".
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```

The preprocessing pipeline is copied verbatim from training. This deserves emphasis: the model has no opinion about pixels; it learned a function of normalized tensors. Serve it un-normalized inputs and accuracy can quietly collapse toward chance while the API happily returns confident JSON. Train/serve preprocessing skew is one of the most common production ML bugs there is.
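To make "a function of normalized tensors" concrete, here is a small pure-PyTorch re-implementation of what ToTensor() followed by Normalize() compute with the ImageNet statistics above. to_model_input is a hypothetical helper written only for illustration; in practice you would import one shared preprocess object in both training and serving code rather than re-deriving it. A pure white pixel lands at roughly 2.2 to 2.6 per channel and a pure black pixel at roughly -1.8 to -2.1, so a model trained on that range sees something very different if the server feeds it raw [0, 1] values.

```python
import torch

MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def to_model_input(img_hwc_uint8: torch.Tensor) -> torch.Tensor:
    # ToTensor(): HWC uint8 in [0, 255] -> CHW float in [0, 1]
    x = img_hwc_uint8.permute(2, 0, 1).float() / 255.0
    # Normalize(): per-channel (x - mean) / std
    return (x - MEAN) / STD


white = torch.full((2, 2, 3), 255, dtype=torch.uint8)
black = torch.zeros((2, 2, 3), dtype=torch.uint8)
print(to_model_input(white)[:, 0, 0])  # ≈ tensor([2.2489, 2.4286, 2.6400])
print(to_model_input(black)[:, 0, 0])  # ≈ tensor([-2.1179, -2.0357, -1.8044])
```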
```python
state = {}  # holds the model between startup and shutdown

@asynccontextmanager
async def lifespan(app: FastAPI):
    model = build_model(num_classes=len(CLASSES))
    sd = torch.load("resnet18_finetuned.pt",
                    map_location="cpu", weights_only=True)
    model.load_state_dict(sd)
    model.to(DEVICE).eval()
    with torch.inference_mode():  # warmup: first-call costs are paid here
        model(torch.zeros(1, 3, 224, 224, device=DEVICE))
    state["model"] = model
    yield
    state.clear()

app = FastAPI(title="Lesson 9 classifier", lifespan=lifespan)
```

The lifespan context manager is FastAPI’s startup/shutdown hook. Everything expensive (building, loading, moving to device, eval(), and a warmup forward pass so the first real request doesn’t eat CUDA/cuDNN lazy-initialization cost, plus cuDNN autotuning if you enabled it) happens exactly once, before the server accepts traffic.
Now the endpoint itself:
```python
@app.get("/health")
def health():
    return {"status": "ok", "device": DEVICE}

@app.post("/predict")
def predict(file: UploadFile, top_k: int = 3):
    if top_k < 1:  # reject nonsense before it reaches torch.topk
        raise HTTPException(status_code=400, detail="top_k must be >= 1")
    try:
        img = Image.open(io.BytesIO(file.file.read())).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="not a decodable image")
    x = preprocess(img).unsqueeze(0).to(DEVICE)  # [3,224,224] -> [1,3,224,224]
    with torch.inference_mode():
        logits = state["model"](x)                # [1, 10]
        probs = logits.softmax(dim=1)[0]           # [10]
        p, idx = probs.topk(min(top_k, len(CLASSES)))
    return {"predictions": [
        {"label": CLASSES[i], "prob": round(float(pi), 4)}
        for pi, i in zip(p.tolist(), idx.tolist())
    ]}
```

Walk the shapes: preprocess yields [3, 224, 224]; the model expects a batch dimension, so unsqueeze(0) makes it [1, 3, 224, 224]. Forget this and you get a dimension error (or worse, on some models, silent misinterpretation). The logits come back [1, 10]; softmax turns them into probabilities
\[p_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}}\]
and [0] peels off the batch dimension. Note the trust-boundary handling: the upload is decoded inside a try and rejected with a 400 — a server that 500s on a corrupt JPEG is a server that pages you at 3am. Also note float(pi): tensors aren’t JSON-serializable; convert at the boundary.
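As a worked instance of the softmax formula with made-up logits z = (2, 1, 0): e^2 ≈ 7.389, e^1 ≈ 2.718 and e^0 = 1 sum to about 11.107, giving p ≈ (0.665, 0.245, 0.090). The sketch below checks that arithmetic against Tensor.softmax and shows why implementations subtract the largest logit m first: p_i = e^(z_i - m) / Σ_j e^(z_j - m) is algebraically identical (the e^(-m) factors cancel) but cannot overflow, because every exponent is at most 0.

```python
import torch

z = torch.tensor([2.0, 1.0, 0.0])               # example logits (made up)
manual = torch.exp(z) / torch.exp(z).sum()      # p_i = e^{z_i} / sum_j e^{z_j}
print(manual)                                   # ≈ tensor([0.6652, 0.2447, 0.0900])
print(torch.allclose(manual, z.softmax(dim=0))) # True

# Large logits overflow float32 in the naive formula: inf / inf -> nan.
big = z + 998.0                                 # (1000, 999, 998): same differences
naive = torch.exp(big) / torch.exp(big).sum()
shifted = torch.exp(big - big.max())            # largest exponent is now e^0 = 1
stable = shifted / shifted.sum()
print(naive)                                    # tensor([nan, nan, nan])
print(torch.allclose(stable, manual))           # True
```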
Run it and hit it:
```bash
pip install fastapi uvicorn python-multipart pillow
uvicorn serve:app --host 0.0.0.0 --port 8000
```

```bash
curl -s http://localhost:8000/health
curl -s -X POST "http://localhost:8000/predict?top_k=3" \
     -F "file=@cat.jpg" | python -m json.tool
```

```json
{
  "predictions": [
    {"label": "cat", "prob": 0.9231},
    {"label": "dog", "prob": 0.0512},
    {"label": "deer", "prob": 0.0104}
  ]
}
```

That’s a deployed model. Everything beyond this (batching concurrent requests, autoscaling, model versioning, monitoring for drift) is real and important, but it’s MLOps territory (see the pointer at the end), and this minimal server is the correct foundation all of it builds on.
🧪 Your task
Prove your serialization round-trip is airtight, end to end. Write a script roundtrip.py that:
- Builds the model, loads resnet18_finetuned.pt, and computes logits for a fixed input (torch.manual_seed(0); x = torch.randn(4, 3, 224, 224)) in eval + inference mode.
- Exports the model with torch.jit.trace and saves it.
- Reloads it with torch.jit.load (no call to build_model) and computes logits for the same input.
- Asserts the two outputs match with torch.allclose(..., atol=1e-5), and asserts they don’t match if you “forget” model.eval() before tracing (proving to yourself that BatchNorm mode is baked into the export).
Hint: trace the model twice — once in eval() mode, once after model.train() — and compare each traced module’s output against the eval-mode reference. torch.jit.trace records the module in whatever mode it’s currently in.
Solution
```python
# roundtrip.py
import torch

from model import build_model

# --- 1. reference output from the real model ---
model = build_model(num_classes=10)
sd = torch.load("resnet18_finetuned.pt", map_location="cpu", weights_only=True)
model.load_state_dict(sd)
model.eval()

torch.manual_seed(0)
x = torch.randn(4, 3, 224, 224)
with torch.inference_mode():
    ref = model(x)
print("reference logits:", ref.shape)  # torch.Size([4, 10])

# --- 2. export (correctly, in eval mode) ---
traced_eval = torch.jit.trace(model, x)
traced_eval.save("rt_eval.pt")

# --- 3. fresh load, no build_model ---
loaded = torch.jit.load("rt_eval.pt")
loaded.eval()
with torch.inference_mode():
    out = loaded(x)

# --- 4a. round-trip parity ---
assert torch.allclose(ref, out, atol=1e-5), "round-trip mismatch!"
print("eval-mode round trip: OK, max diff =",
      (ref - out).abs().max().item())

# --- 4b. the negative control: trace with train-mode BatchNorm ---
model.train()  # "forgot" eval() before export
# Note: running forward in train mode also updates model's BN running stats,
# so reload the checkpoint before reusing `model` for anything else.
traced_train = torch.jit.trace(model, x)
model.eval()   # restore the mode
with torch.inference_mode():
    bad = traced_train(x)
assert not torch.allclose(ref, bad, atol=1e-5), \
    "train-mode trace unexpectedly matched: BN stats identical?"
print("train-mode trace differs as expected, max diff =",
      (ref - bad).abs().max().item())
```

Expected output (your numbers will differ):

```text
reference logits: torch.Size([4, 10])
eval-mode round trip: OK, max diff = 0.0
train-mode trace differs as expected, max diff = 3.87 (magnitude varies)
```
The second assertion is the lesson: the traced file froze BatchNorm in training mode — it normalizes each request batch by its own statistics forever, no matter what mode the loader sets. Export mode is part of the artifact. (Tracing in train mode also emits a PyTorch warning for exactly this reason.)
Key takeaways
- Save state_dict, never the pickled model object; keep architecture in code and weights in the file.
- Load with map_location="cpu" and weights_only=True, move to device explicitly, and treat strict=False (and its return value) as a deliberate, checked decision.
- Inference needs both switches: model.eval() for layer behavior (BatchNorm/Dropout) and torch.inference_mode() for autograd. Missing eval() is a correctness bug; missing inference_mode() silently wastes memory and time.
- torch.compile gives real speedups in-process but is not an export format; it specializes on input shapes, so keep serving shapes fixed.
- Export paths: TorchScript for legacy/C++ (tracing freezes control flow, so beware), torch.export for modern PyTorch pipelines, ONNX for everything else; always verify numerical parity before shipping.
- A serving endpoint loads the model once at startup, reuses the exact training preprocessing, validates inputs at the trust boundary, and converts tensors to plain Python at the JSON boundary.
That’s the course — nine lessons from torch.tensor([1., 2., 3.]) to a model answering HTTP requests. Where next: the site’s Transformers mini-course takes the nn.Module skills you now have into attention and language models, and the MLOps mini-course picks up exactly where this lesson’s FastAPI server left off — versioning, monitoring, and scaling the thing you just shipped.