Kader Mohideen


  • Day 6 — Serving with FastAPI: a Typed Prediction API
    • The shape of a model service
    • Schemas first: pydantic as your contract
    • Loading the champion at startup
    • The endpoints: predict, batch, health, model-info
    • Running it for real: uvicorn workers
    • Proving it works: pytest + a tiny load test
    • 🧪 Your task
    • Key takeaways

🚢 ML in Production — MLOps · Day 6 — Serving with FastAPI: a Typed Prediction API



Day 6 — Serving with FastAPI: a Typed Prediction API

Yesterday we packaged the churn model’s world into a Docker image — same Python, same libraries, same code, everywhere. But an image that just contains a model is a filing cabinet, not a service. Today we give the model a front door: an HTTP API that loads the registry champion from Day 4, validates every incoming request against a typed schema, predicts one customer or a thousand at a time, and tells operators whether it’s alive and exactly which model version is answering. We’ll build it with FastAPI and pydantic, prove it correct with a pytest test-client suite that needs no MLflow server and no real model, and then point a tiny load-test script at it to see what one process can actually do. This is the layer everything after today builds on: Day 7 swaps the model for an LLM, Day 8 puts this API under CI/CD, and Day 9 bolts monitoring onto it.

🎯 Today you will: build a typed /predict endpoint with pydantic request/response schemas, load the registry champion at startup via the lifespan, add /predict/batch, /health, and /model-info endpoints with real error handling, run it under multiple uvicorn workers, and verify it with a pytest suite plus a load-test script.

The shape of a model service

Before code, fix the mental model. An inference service is a pipeline with a hard boundary in the middle: everything left of the boundary is untrusted bytes from the network; everything right of it is typed, validated data your model can safely consume. FastAPI’s whole value proposition is that pydantic enforces that boundary declaratively — you describe the shape once, and malformed input never reaches your model.

flowchart LR
    C[Client] -->|"POST /predict<br/>JSON bytes"| U[Uvicorn worker]
    U --> V{pydantic<br/>validation}
    V -->|invalid| E422["422 Unprocessable Entity<br/>(field-level errors)"]
    V -->|valid| D[DataFrame with<br/>fixed column order]
    D --> M["model.predict_proba<br/>(champion from registry)"]
    M --> R[Prediction schema]
    R -->|"JSON"| C
    S[(MLflow Registry<br/>churn-classifier@champion)] -.->|loaded once<br/>at startup| M

Three design decisions up front, all of which we’ll justify as we build:

  1. The model loads once, at startup, not per request. Loading a sklearn pipeline from the registry takes hundreds of milliseconds; a prediction takes single-digit milliseconds. Loading per request would make the loader 99% of your latency.
  2. Schemas are the API contract. The pydantic models double as documentation (FastAPI auto-generates OpenAPI docs at /docs) and as the validation layer. One source of truth.
  3. The single and batch endpoints share one prediction path. predict_proba on a 1-row DataFrame and a 1000-row DataFrame is the same call — vectorization means the batch endpoint is nearly free to add and much cheaper per row.
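The 99% in point 1 is plain arithmetic. A minimal sketch, with 300 ms and 3 ms chosen as illustrative values from the ranges quoted above rather than measurements:

```python
def loader_share(load_ms: float, predict_ms: float) -> float:
    """Fraction of each request spent loading, if the model were loaded per request."""
    return load_ms / (load_ms + predict_ms)


# Illustrative values: "hundreds of ms" to load, "single-digit ms" to predict.
share = loader_share(load_ms=300.0, predict_ms=3.0)
print(f"loader share of latency: {share:.1%}")  # loader share of latency: 99.0%
```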

Project layout after today (extending the repo from Days 2–5):

churn-service/
├── src/
│   ├── train.py          # Day 2–3
│   └── serve.py          # ← today
├── tests/
│   └── test_serve.py     # ← today
├── scripts/
│   └── load_test.py      # ← today
├── Dockerfile            # Day 5, CMD updated today
└── requirements.txt      # + fastapi, uvicorn, httpx

Add the serving dependencies (pin them — Day 2’s lesson still applies):

fastapi==0.115.12
uvicorn[standard]==0.34.2
httpx==0.28.1

httpx earns its place twice: FastAPI’s TestClient uses it under the hood, and our load-test script will use it directly.

Schemas first: pydantic as your contract

Start src/serve.py with the request schema. This is the most important code you’ll write today, because this is where every garbage input is stopped before your model ever sees it.

# src/serve.py
from __future__ import annotations

import logging
import os
from contextlib import asynccontextmanager

import mlflow
import pandas as pd
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, ConfigDict, Field, model_validator

logger = logging.getLogger("churn-api")

MODEL_NAME = os.getenv("MODEL_NAME", "churn-classifier")
MODEL_ALIAS = os.getenv("MODEL_ALIAS", "champion")
THRESHOLD = float(os.getenv("DECISION_THRESHOLD", "0.5"))

# The exact column order the model was trained on (Day 2's feature list).
FEATURES = [
    "tenure_months",
    "monthly_charges",
    "total_charges",
    "num_support_tickets",
    "is_month_to_month",
]

FEATURES looks trivial but it’s load-bearing. sklearn models trained on a DataFrame remember their feature names and order; feed columns in a different order via a bare numpy array and you get silently wrong predictions — the single most expensive class of serving bug, because nothing crashes. We will always build a DataFrame with these named columns and let sklearn’s feature-name check defend us.
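To make the failure concrete, here is a stdlib-only sketch. The weights and the positional scorer are hypothetical stand-ins for a fitted estimator that has been handed a bare array: it only knows column positions, so reordered input is scored without any error. Building the row by name from FEATURES, which is what df[FEATURES] does for the real service, removes the dependence on input order.

```python
FEATURES = ["tenure_months", "monthly_charges", "total_charges",
            "num_support_tickets", "is_month_to_month"]

# Made-up weights, one per position in FEATURES (illustration only).
WEIGHTS = [-0.05, 0.02, -0.0001, 0.4, 1.2]


def score_positional(row: list) -> float:
    """Dot product by position: nothing checks which feature sits where."""
    return sum(w * x for w, x in zip(WEIGHTS, row))


def to_row(record: dict) -> list:
    """Build the positional row by *name*, in training order."""
    return [float(record[name]) for name in FEATURES]


customer = {"tenure_months": 24, "monthly_charges": 79.5, "total_charges": 1908.0,
            "num_support_tickets": 3, "is_month_to_month": True}

# Bug: a client (or a refactor) serializes the same fields in a different order.
reordered = dict(sorted(customer.items()))            # alphabetical key order
wrong = score_positional([float(v) for v in reordered.values()])
right = score_positional(to_row(reordered))           # by name: dict order is irrelevant

print(wrong != right, right == score_positional(to_row(customer)))  # True True
```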

Now the request schema:

class CustomerFeatures(BaseModel):
    """One customer, as the model expects it."""

    model_config = ConfigDict(
        extra="forbid",  # unknown fields are a client bug — reject loudly
        json_schema_extra={
            "example": {
                "tenure_months": 24,
                "monthly_charges": 79.5,
                "total_charges": 1908.0,
                "num_support_tickets": 3,
                "is_month_to_month": True,
            }
        },
    )

    tenure_months: int = Field(ge=0, le=600, description="Months since signup")
    monthly_charges: float = Field(gt=0, lt=10_000, description="Current monthly bill (USD)")
    total_charges: float = Field(ge=0, description="Lifetime billed amount (USD)")
    num_support_tickets: int = Field(ge=0, le=1_000)
    is_month_to_month: bool = Field(description="True if on a month-to-month contract")

    @model_validator(mode="after")
    def charges_are_consistent(self) -> "CustomerFeatures":
        if self.tenure_months > 0 and self.total_charges < self.monthly_charges:
            raise ValueError(
                "total_charges cannot be less than one month of charges "
                "for a customer with tenure_months > 0"
            )
        return self

Walk through the choices:

  • extra="forbid" — by default pydantic silently drops unknown fields. That default is dangerous in ML serving: a client that misspells monthly_charges as monthlyCharges would get a 422 for the missing field here, but in a schema with defaults it would sail through with a default value and a quietly wrong prediction. Forbid makes typos loud.
  • Field(ge=..., le=...) bounds — these are physical-plausibility rails, not statistics. tenure_months=600 is a 50-year customer; anything beyond that is corrupted input, and it’s far better to 422 than to extrapolate. (Detecting distribution drift within plausible ranges is Day 9’s job — validation and monitoring are different layers.)
  • @model_validator(mode="after") — field validators see one field at a time; cross-field rules (total ≥ one month’s charges) need the whole model. mode="after" runs once all fields are individually valid, so you can safely read self.tenure_months.
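The extra="forbid" argument is easiest to see side by side. This sketch assumes pydantic v2 (as the code above does); Lenient and Strict are hypothetical two-field models, and the default on monthly_charges exists only to reproduce the "schema with defaults" case from the first bullet:

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class Lenient(BaseModel):
    """pydantic's default behavior (extra="ignore"): unknown keys are dropped silently."""

    tenure_months: int
    monthly_charges: float = 50.0  # hypothetical default, only to reproduce the bug


class Strict(Lenient):
    """Same fields, but any unknown key is a validation error."""

    model_config = ConfigDict(extra="forbid")


payload = {"tenure_months": 24, "monthlyCharges": 79.5}  # camelCase typo from a client

lenient = Lenient.model_validate(payload)
print(lenient.monthly_charges)  # 50.0: the client's 79.5 vanished without a trace

errors = None
try:
    Strict.model_validate(payload)
except ValidationError as exc:
    errors = exc.errors()
print(errors[0]["type"], errors[0]["loc"])  # extra_forbidden ('monthlyCharges',)
```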

The response schemas:

class Prediction(BaseModel):
    # In pydantic releases before 2.10, "model_version" collides with the protected
    # "model_" namespace and triggers a UserWarning; opting out keeps every 2.x quiet.
    model_config = ConfigDict(protected_namespaces=())

    churn_probability: float = Field(ge=0.0, le=1.0)
    churn_label: bool = Field(description="True if probability >= threshold")
    model_version: str


class BatchRequest(BaseModel):
    instances: list[CustomerFeatures] = Field(min_length=1, max_length=1024)


class BatchResponse(BaseModel):
    predictions: list[Prediction]
    count: int


class ModelInfo(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    name: str
    version: str
    alias: str
    run_id: str
    threshold: float
    features: list[str]

Two things worth pausing on. First, the protected_namespaces=() line guards against a real-world pydantic v2 gotcha: in releases before 2.10, any field starting with model_ clashes with pydantic’s own model_ namespace (model_dump, model_validate, and friends) and emits a UserWarning per class definition. Pydantic 2.10 narrowed the default to model_validate and model_dump, so on newer releases the line is a harmless statement of intent. In ML APIs you want fields like model_version, so you opt out deliberately rather than renaming your API contract around a library quirk. Second, max_length=1024 on the batch is backpressure: without a cap, one client posting a million rows turns your worker into a space heater and starves everyone else. Pick a cap, document it, return 422 beyond it.
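The batch cap uses the same Field machinery, applied to a list. A sketch assuming pydantic v2, with a hypothetical cap of 3 instead of 1024 so both limits are easy to trip; in the service, FastAPI turns each of these ValidationErrors into a 422:

```python
from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class Item(BaseModel):
    x: int


class Batch(BaseModel):
    # Same mechanism as BatchRequest; hypothetical cap of 3 instead of 1024.
    instances: list[Item] = Field(min_length=1, max_length=3)


def rejection(n: int) -> Optional[str]:
    """pydantic's error type for a batch of n items, or None if it is accepted."""
    try:
        Batch(instances=[{"x": i} for i in range(n)])
    except ValidationError as exc:
        return exc.errors()[0]["type"]
    return None


for n in (0, 1, 3, 4):
    print(n, rejection(n))  # 0 too_short, 1 None, 3 None, 4 too_long
```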

Loading the champion at startup

Day 4 left us with a registry model churn-classifier and an alias @champion pointing at the blessed version. The serving contract is: this API serves whatever the champion alias points to, and can prove which version that is.

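def load_champion() -> tuple[object, dict]:
    """Load the registry champion and its metadata. Module-level on purpose:
    tests monkeypatch this one function and never touch MLflow."""
    client = mlflow.MlflowClient()
    # Resolve the alias to a concrete version ONCE, then load exactly that version.
    # Loading "models:/name@alias" and resolving the alias in a second call could
    # race with someone moving the alias in between: report one version, serve another.
    mv = client.get_model_version_by_alias(MODEL_NAME, MODEL_ALIAS)
    uri = f"models:/{MODEL_NAME}/{mv.version}"
    logger.info("Loading model from %s (alias @%s)", uri, MODEL_ALIAS)
    model = mlflow.sklearn.load_model(uri)

    meta = {
        "name": MODEL_NAME,
        "version": str(mv.version),  # ModelInfo.version is a str; normalize defensively
        "alias": MODEL_ALIAS,
        "run_id": mv.run_id,
    }
    logger.info("Loaded %s v%s (run %s)", MODEL_NAME, mv.version, mv.run_id)
    return model, meta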

Why mlflow.sklearn.load_model and not the generic mlflow.pyfunc.load_model? Pyfunc gives you a lowest-common-denominator .predict() — for a classifier that usually means hard labels only. The sklearn flavor gives back the actual estimator, so we get predict_proba and can serve calibrated probabilities plus apply our own decision threshold. Serve probabilities whenever you can: the business can change the threshold without retraining, and Day 9’s monitoring needs the raw scores.
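What "change the threshold without retraining" means in practice: the model's output is a fixed set of probabilities, and the label is a comparison applied afterwards. A stdlib sketch with made-up scores:

```python
def labels(probas: list, threshold: float) -> list:
    """Apply a decision threshold to P(churn); the probabilities themselves never change."""
    return [p >= threshold for p in probas]


scores = [0.12, 0.31, 0.48, 0.73]  # made-up P(churn) values from one model version
print(labels(scores, 0.5))  # [False, False, False, True]
print(labels(scores, 0.3))  # [False, True, True, True]
```

A hard-label flavor would have frozen the first line's answer into the artifact; with probabilities, the second line is a config change.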

Now wire loading into the app’s lifespan — the modern replacement for the deprecated @app.on_event("startup"):

@asynccontextmanager
async def lifespan(app: FastAPI):
    # --- startup: runs once per worker process, before the first request ---
    app.state.model, app.state.meta = load_champion()
    yield
    # --- shutdown ---
    app.state.model = None
    logger.info("Model released, shutting down")


app = FastAPI(
    title="Churn Prediction API",
    version="1.0.0",
    lifespan=lifespan,
)

The methodology here matters more than the five lines:

  • app.state is FastAPI’s sanctioned place for per-process singletons. Avoid bare module globals for the model itself: app.state is visible in every handler via request.app.state, is repopulated by the lifespan every time the app starts (including each TestClient context in the tests), and makes the dependency explicit.
  • The lifespan runs per worker process. With --workers 4 you get four independent loads and four copies of the model in memory. We’ll come back to the memory math in the uvicorn section.
  • If load_champion() raises, the worker refuses to start. That’s the behavior you want — crash-fast at deploy time beats serving 500s at request time. Your orchestrator (Docker healthcheck from Day 5, or Kubernetes later) sees the dead process and keeps the old version serving.
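Starlette drives the lifespan as an async context manager: everything before yield is startup, everything after is shutdown, and an exception before yield aborts startup. The sketch below drives one by hand so both paths are visible without uvicorn; the SimpleNamespace app and the make_lifespan factory are hypothetical stand-ins, not FastAPI APIs:

```python
import asyncio
from contextlib import asynccontextmanager
from types import SimpleNamespace


def make_lifespan(loader):
    """Hypothetical factory: serve.py's lifespan with the loader injected."""

    @asynccontextmanager
    async def lifespan(app):
        app.state.model, app.state.meta = loader()  # startup; an exception here means no yield
        yield
        app.state.model = None  # shutdown

    return lifespan


async def simulate(loader) -> str:
    app = SimpleNamespace(state=SimpleNamespace(model=None, meta=None))  # stand-in, not FastAPI
    try:
        async with make_lifespan(loader)(app):
            during = app.state.model is not None  # what a request handler would see
    except RuntimeError as exc:
        return f"refused to start: {exc}"
    return f"loaded={during}, after shutdown model={app.state.model}"


def good_loader():
    return object(), {"version": "7"}


def broken_loader():
    raise RuntimeError("registry unreachable")


print(asyncio.run(simulate(good_loader)))    # loaded=True, after shutdown model=None
print(asyncio.run(simulate(broken_loader)))  # refused to start: registry unreachable
```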

The endpoints: predict, batch, health, model-info

One shared prediction path, four thin endpoints:

def _predict_df(model, meta: dict, df: pd.DataFrame) -> list[Prediction]:
    """The single prediction path. df must contain FEATURES columns."""
    proba = model.predict_proba(df[FEATURES])[:, 1]  # shape: (n,) — P(churn)
    version = str(meta["version"])
    return [
        Prediction(
            churn_probability=round(float(p), 6),
            churn_label=bool(p >= THRESHOLD),
            model_version=version,
        )
        for p in proba
    ]


def _require_model(request: Request):
    model = request.app.state.model
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return model

Shapes, because they’re where this goes wrong: predict_proba on an (n, 5) DataFrame returns (n, 2) — column 0 is P(stay), column 1 is P(churn). [:, 1] takes the churn column, giving (n,). The float(p) cast matters: numpy’s float64 scalars serialize fine, but being explicit at the boundary costs nothing and has saved many a TypeError when a model returns float32.
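The shape bookkeeping and the float32 trap, in isolation. A sketch with a made-up 3-row predict_proba output (numpy only, no model):

```python
import json

import numpy as np

# Made-up predict_proba output for n=3: column 0 = P(stay), column 1 = P(churn).
proba = np.array([[0.9, 0.1], [0.4, 0.6], [0.25, 0.75]])
churn = proba[:, 1]
print(proba.shape, churn.shape)  # (3, 2) (3,)

print(json.dumps(churn[0]))  # 0.1  (np.float64 subclasses Python float)

try:
    json.dumps(np.float32(churn[0]))
    float32_ok = True
except TypeError:  # "Object of type float32 is not JSON serializable"
    float32_ok = False
print("float32 serializable:", float32_ok)  # float32 serializable: False

print(json.dumps(float(np.float32(churn[0]))))  # an explicit float() cast fixes it
```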

@app.post("/predict", response_model=Prediction)
def predict(features: CustomerFeatures, request: Request) -> Prediction:
    model = _require_model(request)
    df = pd.DataFrame([features.model_dump()])          # shape (1, 5)
    return _predict_df(model, request.app.state.meta, df)[0]


@app.post("/predict/batch", response_model=BatchResponse)
def predict_batch(batch: BatchRequest, request: Request) -> BatchResponse:
    model = _require_model(request)
    df = pd.DataFrame([i.model_dump() for i in batch.instances])  # shape (n, 5)
    preds = _predict_df(model, request.app.state.meta, df)        # ONE model call
    return BatchResponse(predictions=preds, count=len(preds))

Note both handlers are plain def, not async def — and that’s deliberate, not laziness. predict_proba is blocking CPU work. In an async def handler it would run on the event loop and freeze every concurrent request for its duration. Declared as sync def, FastAPI automatically runs the handler in a threadpool, keeping the event loop free. Rule of thumb: async def only when you await something; blocking model inference gets def.
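The event-loop argument can be demonstrated with the standard library alone. In this sketch time.sleep stands in for blocking inference, and asyncio.to_thread plays the role of the threadpool that FastAPI (via Starlette) uses for plain def handlers; it is an analogue, not FastAPI's own code path. A heartbeat task counts how often the loop gets to run while "inference" is in progress:

```python
import asyncio
import time


def blocking_inference() -> None:
    time.sleep(0.3)  # stand-in for ~300 ms of blocking predict_proba work


async def heartbeat(stop: asyncio.Event) -> int:
    """Counts ticks; a frozen event loop produces almost none."""
    ticks = 0
    while not stop.is_set():
        await asyncio.sleep(0.02)
        ticks += 1
    return ticks


async def run(offload: bool) -> int:
    stop = asyncio.Event()
    hb = asyncio.create_task(heartbeat(stop))
    await asyncio.sleep(0)  # let the heartbeat start
    if offload:
        await asyncio.to_thread(blocking_inference)  # like a `def` handler in the threadpool
    else:
        blocking_inference()  # like calling it inside an `async def` handler
    stop.set()
    return await hb


async def main() -> tuple:
    blocked = await run(offload=False)
    offloaded = await run(offload=True)
    return blocked, offloaded


blocked, offloaded = asyncio.run(main())
print(f"heartbeat ticks: on the loop={blocked}, in a thread={offloaded}")
```

On the loop, the heartbeat gets a single tick in, after the blocking call has already finished; offloaded, it keeps ticking the whole time, which is what other in-flight requests experience too.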

Also note what the batch endpoint is: one model.predict_proba call for n rows. Looping predict n times would pay pandas/sklearn dispatch overhead n times; vectorized, a 100-row batch typically costs barely more than a single prediction.
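Vectorization does not change the answer, only how often you pay dispatch overhead. A sketch with a hypothetical numpy logistic model (not the churn pipeline) that follows the sklearn predict_proba contract:

```python
import numpy as np


class ToyLogistic:
    """Hypothetical stand-in for a fitted classifier: sigmoid(X @ w + b)."""

    def __init__(self, w: np.ndarray, b: float):
        self.w, self.b = w, b

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        p = 1.0 / (1.0 + np.exp(-(X @ self.w + self.b)))  # shape (n,)
        return np.column_stack([1.0 - p, p])              # shape (n, 2), sklearn-style


rng = np.random.default_rng(0)
model = ToyLogistic(w=rng.normal(size=5), b=-0.5)
X = rng.normal(size=(100, 5))

batched = model.predict_proba(X)[:, 1]  # ONE call for 100 rows
looped = np.array([model.predict_proba(X[i:i + 1])[0, 1] for i in range(len(X))])  # 100 calls

print(batched.shape, np.allclose(batched, looped))  # (100,) True
```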

The operational endpoints:

@app.get("/health")
def health(request: Request) -> JSONResponse:
    ok = request.app.state.model is not None
    return JSONResponse(
        {"status": "ok" if ok else "unavailable"},
        status_code=200 if ok else 503,
    )


@app.get("/model-info", response_model=ModelInfo)
def model_info(request: Request) -> ModelInfo:
    _require_model(request)
    return ModelInfo(**request.app.state.meta, threshold=THRESHOLD, features=FEATURES)

/health returning 503 with a body (not an exception, not a bare 200) is the contract load balancers and Docker healthchecks expect: 200 means “send me traffic”, anything else means “don’t”. /model-info is the endpoint you will be most grateful for during an incident — “which model version answered this?” should be one curl away, never a hunt through deploy logs.
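A Docker HEALTHCHECK or load balancer only needs the status code. If your image does not ship curl, a stdlib probe is enough; scripts/healthcheck.py and its default URL are hypothetical names for this sketch, and proxies are disabled because a local probe should never route through one:

```python
import urllib.request

# Local probe: ignore any http_proxy settings in the environment.
_opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))


def probe(url: str, timeout: float = 2.0) -> bool:
    """True only when the health endpoint answers HTTP 200."""
    try:
        with _opener.open(url, timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        # urllib raises HTTPError for 503 and other error statuses, and URLError
        # for refused connections or timeouts; both are OSError subclasses.
        return False


def main(argv: list) -> int:
    """Exit code for Docker: 0 = healthy, 1 = unhealthy."""
    url = argv[1] if len(argv) > 1 else "http://127.0.0.1:8000/health"
    return 0 if probe(url) else 1

# As a script (hypothetical scripts/healthcheck.py), finish with:
#     import sys; raise SystemExit(main(sys.argv))
```

Wired into the Dockerfile, that would look something like HEALTHCHECK --interval=30s --timeout=3s CMD python scripts/healthcheck.py; any non-zero exit marks the container unhealthy.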

Finally, the safety net:

@app.exception_handler(Exception)
async def unhandled_exception(request: Request, exc: Exception) -> JSONResponse:
    # Log the full traceback server-side; leak nothing to the client.
    logger.exception("Unhandled error on %s %s", request.method, request.url.path)
    return JSONResponse(status_code=500, content={"detail": "Internal server error"})

The division of labor for errors is now complete, and each failure mode has exactly one owner:

| Failure | Owner | Status | Client sees |
|---|---|---|---|
| Malformed JSON / wrong types / out-of-range values | pydantic (automatic) | 422 | Field-level error list |
| Unknown extra field | extra="forbid" | 422 | Which field is unexpected |
| Batch too large | max_length=1024 | 422 | The limit |
| Model not loaded | _require_model | 503 | “Model not loaded” |
| Anything unexpected (model raises, pandas chokes) | catch-all handler | 500 | Generic message; full traceback in logs only |

Never echo raw exceptions to clients: tracebacks leak file paths, library versions, and occasionally data. Log rich, respond poor.
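The same split, without FastAPI, so both halves are visible at once: the log record carries the full traceback, the response carries nothing internal. respond and buggy_handler are hypothetical names for this sketch:

```python
import logging

logger = logging.getLogger("churn-api.sketch")


def respond(handler, *args):
    """Hypothetical wrapper mirroring the catch-all handler: log rich, respond poor."""
    try:
        return 200, handler(*args)
    except Exception:
        logger.exception("Unhandled error in %s", handler.__name__)  # full traceback, server side
        return 500, {"detail": "Internal server error"}  # nothing internal reaches the client


def buggy_handler(row: dict) -> dict:
    # Hypothetical feature computation that divides by zero for brand-new customers.
    return {"charges_per_month": row["total_charges"] / row["tenure_months"]}


logging.basicConfig(level=logging.INFO)
status, body = respond(buggy_handler, {"total_charges": 0.0, "tenure_months": 0})
print(status, body)  # 500 {'detail': 'Internal server error'}
```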

Run it locally and poke it:

export MLFLOW_TRACKING_URI=http://127.0.0.1:5000   # Day 3's server
uvicorn src.serve:app --reload

curl -s localhost:8000/model-info | python -m json.tool
{
    "name": "churn-classifier",
    "version": "7",
    "alias": "champion",
    "run_id": "b7c9e2f1a4d84c3e",
    "threshold": 0.5,
    "features": ["tenure_months", "monthly_charges", "total_charges",
                 "num_support_tickets", "is_month_to_month"]
}
curl -s localhost:8000/predict -H 'content-type: application/json' -d '{
  "tenure_months": 2, "monthly_charges": 95.0, "total_charges": 190.0,
  "num_support_tickets": 8, "is_month_to_month": true}'
{"churn_probability": 0.873412, "churn_label": true, "model_version": "7"}

And open http://localhost:8000/docs — FastAPI has generated interactive OpenAPI documentation from your schemas, with the json_schema_extra example pre-filled. You wrote zero documentation code.

Running it for real: uvicorn workers

--reload is for development. In production you run multiple worker processes, because Python’s GIL means one process can’t use multiple cores for CPU-bound inference:

uvicorn src.serve:app --host 0.0.0.0 --port 8000 --workers 4

What actually happens: uvicorn binds the socket once, then starts 4 worker processes that all accept() on it, and the OS kernel distributes incoming connections among them. Each worker imports the app independently and runs its own event loop, its own lifespan, and therefore its own copy of the model:

[Figure: uvicorn --workers 4, one socket, four processes, four model copies. Clients connect to the shared socket :8000 and the kernel load-balances connections across the workers. Each worker holds its own event loop, threadpool, and model (~120 MB); a restarting worker re-runs the lifespan and reloads the model.]

The sizing arithmetic you should always do before picking --workers:

  • CPU: inference is CPU-bound, so workers beyond your core count just contend. Start with \(W = \text{cores}\) (or cores − 1 to leave room for the OS).
  • Memory: total ≈ \(W \times (\text{app} + \text{model})\). A 120 MB sklearn pipeline × 4 workers = ~500 MB before you serve a single request. For today’s model that’s trivial; for a 5 GB model it’s the reason Day 7 exists — you cannot scale LLM serving by forking copies, which is precisely why vLLM uses one engine with internal batching instead.
  • Throughput ceiling: if one prediction takes mean time \(\bar{t}\) seconds of CPU, the hard upper bound is

\[ \text{RPS}_{\max} \approx \frac{W}{\bar{t}} \]

With 4 workers and \(\bar{t} = 5\) ms, that’s ~800 req/s — a number we’ll test against shortly.
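The three rules fit in one napkin function. A sketch; the function name and the zero app overhead are assumptions for illustration, and the outputs are estimates and upper bounds, not measurements:

```python
def size_workers(cores: int, model_mb: float, app_mb: float, t_mean_s: float,
                 reserve_core: bool = False) -> dict:
    """Napkin sizing from the three rules above: estimates, not measurements."""
    workers = max(1, cores - 1) if reserve_core else cores  # CPU rule
    return {
        "workers": workers,
        "memory_mb": workers * (app_mb + model_mb),  # memory rule: one model copy per worker
        "rps_ceiling": workers / t_mean_s,           # throughput rule: RPS_max ≈ W / t̄
    }


print(size_workers(cores=4, model_mb=120, app_mb=0, t_mean_s=0.005))
# {'workers': 4, 'memory_mb': 480, 'rps_ceiling': 800.0}
print(size_workers(cores=4, model_mb=120, app_mb=0, t_mean_s=0.005, reserve_core=True))
```

With the text's numbers it reproduces the 800 req/s ceiling; the ~500 MB quoted above is these 480 MB of model copies, rounded up, before any per-process app overhead.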

Update Day 5’s Dockerfile CMD to match:

CMD ["uvicorn", "src.serve:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]

(You may see older tutorials front uvicorn with gunicorn for worker management. Since uvicorn 0.30 its own --workers supervisor is production-fit, and in Kubernetes you often run one worker per container and let the orchestrator do the replication — the container is the worker. Either way, the per-process model-copy math is identical.)

Proving it works: pytest + a tiny load test

The test suite must run in CI with no MLflow server and no trained model. That’s what the module-level load_champion was for: monkeypatch that one function and the entire app — lifespan included — runs against a fake.

# tests/test_serve.py
import numpy as np
import pytest
from fastapi.testclient import TestClient

import src.serve as serve


class FakeModel:
    """Deterministic stand-in: churn probability rises with support tickets."""

    def predict_proba(self, X):
        p = np.clip(X["num_support_tickets"].to_numpy() / 10.0, 0.01, 0.99)
        return np.column_stack([1.0 - p, p])   # (n, 2), like sklearn


FAKE_META = {"name": "churn-classifier", "version": "7",
             "alias": "champion", "run_id": "test-run"}

VALID = {
    "tenure_months": 24, "monthly_charges": 79.5, "total_charges": 1908.0,
    "num_support_tickets": 3, "is_month_to_month": True,
}


@pytest.fixture
def client(monkeypatch):
    monkeypatch.setattr(serve, "load_champion", lambda: (FakeModel(), FAKE_META))
    # `with` runs the lifespan: startup before the first request, shutdown after.
    with TestClient(serve.app) as c:
        yield c

Two mechanics to internalize. FakeModel mimics the exact sklearn contract we rely on — predict_proba returning (n, 2) — and it reads X["num_support_tickets"], which doubles as a regression test that we pass a DataFrame with named columns, not a bare array. And TestClient used as a context manager executes the real lifespan, so the tests cover the startup path too, not just the handlers.

Now the behaviors — one test per contract clause:

def test_predict_returns_typed_response(client):
    r = client.post("/predict", json=VALID)
    assert r.status_code == 200
    body = r.json()
    assert set(body) == {"churn_probability", "churn_label", "model_version"}
    assert body["churn_probability"] == pytest.approx(0.3)   # 3 tickets / 10
    assert body["churn_label"] is False                       # 0.3 < 0.5
    assert body["model_version"] == "7"


def test_threshold_flips_label(client):
    r = client.post("/predict", json={**VALID, "num_support_tickets": 8,
                                      "total_charges": 2000.0})
    assert r.json()["churn_label"] is True                    # 0.8 >= 0.5


def test_negative_tenure_rejected(client):
    r = client.post("/predict", json={**VALID, "tenure_months": -1})
    assert r.status_code == 422
    assert r.json()["detail"][0]["loc"] == ["body", "tenure_months"]


def test_unknown_field_rejected(client):
    r = client.post("/predict", json={**VALID, "monthlyCharges": 79.5})
    assert r.status_code == 422                               # extra="forbid"


def test_cross_field_rule(client):
    r = client.post("/predict", json={**VALID, "total_charges": 10.0})
    assert r.status_code == 422
    assert "total_charges" in r.text


def test_batch_is_vectorized_and_ordered(client):
    payload = {"instances": [
        {**VALID, "num_support_tickets": 1},
        {**VALID, "num_support_tickets": 9},
    ]}
    r = client.post("/predict/batch", json=payload)
    assert r.status_code == 200
    body = r.json()
    assert body["count"] == 2
    probs = [p["churn_probability"] for p in body["predictions"]]
    assert probs == pytest.approx([0.1, 0.9])   # order preserved


def test_batch_cap_enforced(client):
    payload = {"instances": [VALID] * 1025}
    r = client.post("/predict/batch", json=payload)
    assert r.status_code == 422


def test_health_ok(client):
    r = client.get("/health")
    assert r.status_code == 200 and r.json()["status"] == "ok"


def test_model_info_names_the_version(client):
    body = client.get("/model-info").json()
    assert body["version"] == "7" and body["features"][0] == "tenure_months"

$ python -m pytest tests/test_serve.py -q
.........                                                    [100%]
9 passed in 0.71s

Sub-second, no network, no MLflow, no Docker. This exact suite becomes a CI gate on Day 8.

The last claim to verify is the throughput math. A full load-testing tool (locust, k6) is overkill for a sanity check; ~40 lines of httpx + asyncio tells us what we need:

# scripts/load_test.py
"""Tiny load test: N requests at fixed concurrency, latency percentiles out."""
import asyncio
import statistics
import time

import httpx

URL = "http://localhost:8000/predict"
PAYLOAD = {
    "tenure_months": 24, "monthly_charges": 79.5, "total_charges": 1908.0,
    "num_support_tickets": 3, "is_month_to_month": True,
}
N_REQUESTS = 2000
CONCURRENCY = 32


async def worker(client: httpx.AsyncClient, sem: asyncio.Semaphore, out: list):
    async with sem:
        t0 = time.perf_counter()
        r = await client.post(URL, json=PAYLOAD)
        out.append((time.perf_counter() - t0, r.status_code))


async def main():
    sem = asyncio.Semaphore(CONCURRENCY)
    results: list[tuple[float, int]] = []
    async with httpx.AsyncClient(timeout=10.0) as client:
        t0 = time.perf_counter()
        await asyncio.gather(*(worker(client, sem, results) for _ in range(N_REQUESTS)))
        wall = time.perf_counter() - t0

    lat = sorted(t for t, _ in results)
    errors = sum(1 for _, s in results if s != 200)
    q = statistics.quantiles(lat, n=100)          # q[i] = (i+1)th percentile
    print(f"requests: {len(lat)}   errors: {errors}   wall: {wall:.2f}s")
    print(f"throughput: {len(lat) / wall:,.0f} req/s")
    print(f"latency  p50: {q[49]*1000:6.1f} ms   p95: {q[94]*1000:6.1f} ms   "
          f"p99: {q[98]*1000:6.1f} ms")


if __name__ == "__main__":
    asyncio.run(main())

The semaphore is the load model: at most 32 requests in flight at once, mimicking 32 concurrent clients. Typical output against 4 workers on a laptop:

requests: 2000   errors: 0   wall: 3.41s
throughput: 587 req/s
latency  p50:   48.3 ms   p95:   89.1 ms   p99:  142.7 ms

Read it like an SRE: throughput lands in the same order of magnitude as the \(W/\bar{t}\) estimate (never exactly — serialization, validation, and the network all tax it, and hardware never matches the napkin). And p99 is ~3× p50 — tail latency, not the median, is what your SLO should be written against, because at scale the 99th percentile is somebody’s every request. Re-run with --workers 1 and watch throughput drop and the tail stretch; now the worker math from the previous section is something you’ve measured, not read.
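One more cross-check for closed-loop load tests like this one is Little's law: with a fixed number of requests in flight, concurrency = throughput × mean latency. A sketch plugging in the sample output above:

```python
def littles_law_mean_latency(in_flight: int, throughput_rps: float) -> float:
    """Little's law L = λ·W, solved for W (mean seconds each request spends in flight)."""
    return in_flight / throughput_rps


# From the sample run above: CONCURRENCY = 32, throughput 587 req/s.
mean_s = littles_law_mean_latency(in_flight=32, throughput_rps=587)
print(f"implied mean latency: {mean_s * 1000:.1f} ms")  # implied mean latency: 54.5 ms
```

About 54.5 ms, a little above the 48.3 ms median, which is plausible for a latency distribution with a long right tail.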

🧪 Your task

The business team wants to experiment with the decision threshold without redeploying. Extend the batch endpoint to accept an optional threshold query parameter (e.g. POST /predict/batch?threshold=0.3) that overrides THRESHOLD for that request only. It must be validated to lie in \([0, 1]\) (reject 422 otherwise), it must not change churn_probability — only churn_label — and out-of-range or non-numeric values must never reach the prediction path. Add two pytest cases: one showing the same instance flips from false to true at a lower threshold, one showing threshold=1.5 is rejected.

Hint: FastAPI turns a handler parameter annotated threshold: float | None = Query(default=None, ge=0.0, le=1.0) into a validated query param — the same ge/le machinery pydantic gave you for the body. You’ll need _predict_df to accept the effective threshold instead of reading the module constant.

Solution
# --- src/serve.py changes ---
from fastapi import FastAPI, HTTPException, Query, Request   # add Query


def _predict_df(model, meta: dict, df: pd.DataFrame,
                threshold: float = THRESHOLD) -> list[Prediction]:
    proba = model.predict_proba(df[FEATURES])[:, 1]
    version = str(meta["version"])
    return [
        Prediction(
            churn_probability=round(float(p), 6),
            churn_label=bool(p >= threshold),
            model_version=version,
        )
        for p in proba
    ]


@app.post("/predict/batch", response_model=BatchResponse)
def predict_batch(
    batch: BatchRequest,
    request: Request,
    threshold: float | None = Query(
        default=None, ge=0.0, le=1.0,
        description="Optional per-request decision threshold; defaults to the service setting.",
    ),
) -> BatchResponse:
    model = _require_model(request)
    df = pd.DataFrame([i.model_dump() for i in batch.instances])
    eff = THRESHOLD if threshold is None else threshold
    preds = _predict_df(model, request.app.state.meta, df, threshold=eff)
    return BatchResponse(predictions=preds, count=len(preds))

# --- tests/test_serve.py additions ---
def test_threshold_override_flips_label(client):
    payload = {"instances": [{**VALID, "num_support_tickets": 3}]}  # p = 0.3

    default = client.post("/predict/batch", json=payload).json()
    assert default["predictions"][0]["churn_label"] is False        # 0.3 < 0.5

    lowered = client.post("/predict/batch?threshold=0.25", json=payload).json()
    assert lowered["predictions"][0]["churn_label"] is True         # 0.3 >= 0.25
    # probability itself must be untouched by the threshold
    assert lowered["predictions"][0]["churn_probability"] == pytest.approx(0.3)


def test_threshold_out_of_range_rejected(client):
    payload = {"instances": [VALID]}
    r = client.post("/predict/batch?threshold=1.5", json=payload)
    assert r.status_code == 422
    assert r.json()["detail"][0]["loc"] == ["query", "threshold"]

Notice how little changed: the validation is declarative (Query(ge=0.0, le=1.0), no if statements), the single /predict endpoint is untouched, and _predict_df stays the one shared prediction path, with the service THRESHOLD as its default argument so callers that pass nothing (like /predict) keep the old behavior.

Key takeaways

  • Pydantic schemas are the trust boundary: field bounds, extra="forbid", and cross-field model_validators stop garbage before the model ever sees it — and generate your OpenAPI docs for free.
  • Load the registry champion once per worker in the lifespan; crash-fast if it fails; expose the exact version via /model-info and readiness via a 200/503 /health.
  • Blocking inference belongs in plain def handlers (threadpool), not async def (event loop) — and single + batch endpoints should share one vectorized prediction path with a hard batch-size cap.
  • Errors have exactly one owner each: 422 from validation, 503 for no model, 500 (logged rich, returned poor) for everything else.
  • --workers W starts W processes with W model copies; budget memory as \(W \times (\text{app} + \text{model})\) and sanity-check throughput against \(W/\bar{t}\), then measure it, and read p99, not p50.
  • A monkeypatched load_champion plus TestClient gives you a full-contract test suite with zero infrastructure — your Day 8 CI gate.

Tomorrow the per-worker-model-copy trick hits a wall: the model is a multi-gigabyte LLM, and we serve it properly with vLLM — one engine, continuous batching, and OpenAI-compatible endpoints.



 

© Kader Mohideen