Code Advanced hard · 8 min

Production monitoring for PyTorch models

What you will learn
Instrument PyTorch models in production to track inference latency, memory usage, and model drift in real time.

Why this matters

A model that works in dev can fail silently in production: inputs shift, GPU memory creeps up, latency spikes under load. Without monitoring, you ship a ticking time bomb. This lesson shows how to catch model degradation, resource exhaustion, and data distribution shift before users notice.

Skip if: you're building a one-off research notebook or a local prototype. Monitoring code is also unnecessary if your inference runs entirely behind a managed ML platform (e.g., Vertex AI, SageMaker) that handles monitoring for you, though you still need to understand which metrics matter.

Explanation

Production monitoring for PyTorch means instrumenting your inference code to emit metrics about model behavior. These metrics cover how long predictions take, how much GPU memory they consume, what the input data distribution looks like, and whether model outputs are drifting from historical patterns.

Mechanically, you wrap inference calls with timing and memory measurement. You also log prediction confidence scores and class distributions. Then you compare live input statistics to your training baseline with a distance such as KL divergence or Wasserstein distance. You export these metrics to a time-series database (Prometheus, InfluxDB) and alert when a threshold is breached. Examples: 95th-percentile latency exceeds 500 ms, or an input feature's mean moves more than 3 standard deviations from its training mean.

When to use it: any production model whose outputs drive decisions. This matters most in latency-sensitive systems (recommendation, fraud detection, real-time search) and high-stakes domains (healthcare, finance). The earlier you instrument, the easier it is to diagnose production issues that would otherwise force you to roll back a model blindly.
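Both drift measures can be checked per feature with plain NumPy. The sketch below covers the 3-standard-deviation mean-shift rule and a Wasserstein distance. It is a minimal sketch under a few assumptions:

- Inputs arrive as 2-D arrays shaped (samples, features).
- `drift_report` and its parameters are hypothetical names.
- The 3-sigma default mirrors the example threshold above.
- The Wasserstein column approximates the 1-D Wasserstein-1 distance by averaging the gap between matching quantiles on a fixed grid.

```python
import numpy as np


def drift_report(baseline: np.ndarray, live: np.ndarray,
                 sigma_threshold: float = 3.0, n_quantiles: int = 101) -> dict:
    """Per-feature drift statistics for arrays shaped (n_samples, n_features)."""
    base_mean = baseline.mean(axis=0)
    # Floor the std so a constant training feature does not divide by zero
    base_std = np.maximum(baseline.std(axis=0), 1e-12)

    # How far each live feature mean moved, in units of training std
    mean_shift_sigmas = np.abs(live.mean(axis=0) - base_mean) / base_std

    # Approximate 1-D Wasserstein-1 per feature: W1 is the integral over t in [0, 1]
    # of |F^-1(t) - G^-1(t)|, estimated here by averaging over a grid of quantiles.
    grid = np.linspace(0.0, 1.0, n_quantiles)
    base_q = np.quantile(baseline, grid, axis=0)  # shape (n_quantiles, n_features)
    live_q = np.quantile(live, grid, axis=0)
    wasserstein = np.abs(base_q - live_q).mean(axis=0)

    return {
        "mean_shift_sigmas": mean_shift_sigmas,
        "wasserstein": wasserstein,
        "drifted_features": np.flatnonzero(mean_shift_sigmas > sigma_threshold).tolist(),
    }


rng = np.random.default_rng(0)
baseline = rng.normal(size=(5000, 3))
live = baseline[:1000].copy()
live[:, 1] += 5.0  # shift only feature 1
report = drift_report(baseline, live)
print(report["drifted_features"])  # [1]
```

Per-feature checks catch a shift confined to one column, which a single pooled mean/std over all features can dilute.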

Analogy

Monitoring a production model is like monitoring the health of a patient after surgery. You can't just discharge them and never check again. You track vital signs (latency, memory, throughput), watch for infections (data drift, input anomalies), and catch problems early by comparing today's readings to the baseline established before surgery. Without monitoring, you only find out the patient crashed when they show up in the ER.

Code

python
import math
import time
from dataclasses import dataclass
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn


@dataclass
class InferenceMetrics:
    latency_ms: float
    peak_memory_mb: float
    input_mean: float
    input_std: float
    output_class_dist: Dict[int, float]
    prediction_confidence: float  # mean of each sample's top softmax probability
    batch_size: int

class ProductionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 3)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class MonitoredInference:
    def __init__(self, model: nn.Module, device: str = 'cpu', baseline_input_mean: float = 0.0, baseline_input_std: float = 1.0):
        self.model = model.to(device)
        self.device = device
        self.is_cuda = 'cuda' in device
        self.baseline_input_mean = baseline_input_mean
        self.baseline_input_std = baseline_input_std
        self.metrics_history: List[InferenceMetrics] = []
    
    def compute_kl_divergence(self, current_mean: float, current_std: float) -> float:
        # KL(baseline || current) for two univariate Gaussians
        if not math.isfinite(current_std) or current_std <= 0:
            return float('inf')  # NaN or zero std: treat as maximal drift
        p_mean, p_std = self.baseline_input_mean, self.baseline_input_std
        q_mean, q_std = current_mean, current_std
        return math.log(q_std / p_std) + (p_std**2 + (p_mean - q_mean)**2) / (2 * q_std**2) - 0.5
    
    def predict(self, x: torch.Tensor) -> tuple[torch.Tensor, InferenceMetrics, float]:
        self.model.eval()
        if self.is_cuda:
            torch.cuda.reset_peak_memory_stats(self.device)
        
        start_time = time.perf_counter()
        with torch.no_grad():
            logits = self.model(x.to(self.device))
            predictions = torch.softmax(logits, dim=1)
        if self.is_cuda:
            torch.cuda.synchronize(self.device)  # kernels run asynchronously; wait before stopping the clock
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        
        peak_memory_mb = torch.cuda.max_memory_allocated(self.device) / (1024 ** 2) if self.is_cuda else 0.0
        
        input_mean = x.mean().item()
        input_std = x.std().item()  # NaN for a one-element input
        
        num_classes = predictions.shape[1]
        pred_classes = torch.argmax(predictions, dim=1)
        class_counts = torch.bincount(pred_classes, minlength=num_classes)
        class_dist = {i: class_counts[i].item() / len(pred_classes) for i in range(num_classes)}
        
        # Average of per-sample top probabilities (a batch-wide max would hide uncertain samples)
        mean_confidence = predictions.max(dim=1).values.mean().item()
        
        metrics = InferenceMetrics(
            latency_ms=elapsed_ms,
            peak_memory_mb=peak_memory_mb,
            input_mean=input_mean,
            input_std=input_std,
            output_class_dist=class_dist,
            prediction_confidence=mean_confidence,
            batch_size=len(pred_classes),
        )
        self.metrics_history.append(metrics)
        
        kl_div = self.compute_kl_divergence(input_mean, input_std)
        return predictions, metrics, kl_div
    
    def check_anomalies(self, metrics: InferenceMetrics, latency_threshold_ms: float = 100.0, confidence_threshold: float = 0.5,
                        memory_threshold_mb: float = 500.0, min_batch_for_imbalance: int = 20) -> Dict[str, bool]:
        return {
            'high_latency': metrics.latency_ms > latency_threshold_ms,
            'low_confidence': metrics.prediction_confidence < confidence_threshold,
            'high_memory': metrics.peak_memory_mb > memory_threshold_mb,
            # In a tiny batch one class trivially dominates, so only check larger batches
            'imbalanced_output': metrics.batch_size >= min_batch_for_imbalance
                                 and max(metrics.output_class_dist.values()) > 0.9,
        }
    
    def summarize_metrics(self) -> Dict:
        if not self.metrics_history:
            return {}
        latencies = [m.latency_ms for m in self.metrics_history]
        confidences = [m.prediction_confidence for m in self.metrics_history]
        return {
            'avg_latency_ms': float(np.mean(latencies)),
            'p95_latency_ms': float(np.percentile(latencies, 95)),
            'p99_latency_ms': float(np.percentile(latencies, 99)),
            'avg_confidence': float(np.mean(confidences)),
            'min_confidence': float(np.min(confidences)),
            'inference_count': len(self.metrics_history),
        }

torch.manual_seed(0)  # reproducible weights and inputs
model = ProductionModel()

# Baseline statistics from (synthetic) training data
X_train = torch.randn(1000, 10)
monitor = MonitoredInference(model, device='cpu',
                             baseline_input_mean=X_train.mean().item(),
                             baseline_input_std=X_train.std().item())

for i in range(5):
    batch = torch.randn(32, 10)  # in-distribution batch of 32 samples
    pred, metrics, kl = monitor.predict(batch)
    anomalies = monitor.check_anomalies(metrics)
    print(f"Batch {i}: latency={metrics.latency_ms:.2f}ms, confidence={metrics.prediction_confidence:.3f}, kl_div={kl:.4f}, anomalies={anomalies}")

print("\n--- SHIFTED INPUTS (mean + 5.0) ---")
for i in range(5):
    batch = torch.randn(32, 10) + 5.0  # same spread, shifted mean
    pred, metrics, kl = monitor.predict(batch)
    anomalies = monitor.check_anomalies(metrics)
    print(f"Drift Batch {i}: latency={metrics.latency_ms:.2f}ms, confidence={metrics.prediction_confidence:.3f}, kl_div={kl:.4f}, anomalies={anomalies}")

summary = monitor.summarize_metrics()
print(f"\nMetrics Summary: {summary}")
Output
Exact numbers vary with hardware and with the random weights and inputs. Treat the following as the pattern to look for, not a fixed transcript.

Normal batches: kl_div stays close to 0 but not exactly 0. Each batch is a finite sample whose mean and std scatter slightly around the baseline.

Shifted batches: kl_div jumps to roughly 12.5. With a baseline of about N(0, 1) and inputs of about N(5, 1), the Gaussian KL is log(1/1) + (1 + 5²)/(2 · 1²) - 0.5 = 12.5. Sampling noise moves each batch's value up or down a little.

Confidence: the model is untrained, so its softmax outputs are close to uniform. The top probability sits only slightly above 1/3, the minimum possible with three classes. On normal batches that is below the default 0.5 threshold, so expect low_confidence: True there.

Metrics Summary: reports:
- average, p95, and p99 latency (typically sub-millisecond on CPU for this tiny model)
- average and minimum confidence
- inference_count: 10

What just happened?

We wrapped model inference in a monitoring harness that tracks four signals:
1. Latency via <code>time.perf_counter()</code>.
2. Peak GPU memory via <code>torch.cuda.max_memory_allocated()</code>, reported as 0 on CPU.
3. Input distribution statistics (mean/std), compared against the training baseline with a Gaussian KL divergence to detect data drift.
4. Output class distribution and confidence scores, to spot model collapse.

The demo runs two scenarios: normal inputs, where KL stays near 0 (no drift), and inputs shifted by +5, where KL jumps to about 12.5 (severe distribution shift). The KL check summarizes the whole input tensor with a single mean and std. That catches global location and scale shifts, but it can miss a shift in one feature that other features offset. The anomaly checks flag violations of the latency, confidence, memory, and class-imbalance thresholds; in a real system these would trigger alerts.
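You can sanity-check the hand-written KL formula against `torch.distributions.kl_divergence`, which has a closed form for two `Normal` distributions. This is a minimal sketch; `closed_form_kl` is a hypothetical standalone copy of the formula used in `compute_kl_divergence`. The sketch also shows that KL is asymmetric. The argument order (baseline first, live statistics second) therefore matters when you pick alert thresholds.

```python
import math

import torch
from torch.distributions import Normal, kl_divergence


def closed_form_kl(p_mean: float, p_std: float, q_mean: float, q_std: float) -> float:
    # KL(P || Q) for univariate Gaussians, same expression as compute_kl_divergence
    return math.log(q_std / p_std) + (p_std**2 + (p_mean - q_mean)**2) / (2 * q_std**2) - 0.5


baseline = Normal(torch.tensor(0.0), torch.tensor(1.0))
shifted = Normal(torch.tensor(5.0), torch.tensor(1.0))
print(kl_divergence(baseline, shifted).item())  # 12.5
print(closed_form_kl(0.0, 1.0, 5.0, 1.0))       # 12.5

# KL is asymmetric: swapping the arguments changes the value
wider = Normal(torch.tensor(1.0), torch.tensor(2.0))
forward_kl = kl_divergence(baseline, wider).item()  # ~0.443
reverse_kl = kl_divergence(wider, baseline).item()  # ~1.307
print(forward_kl, reverse_kl)
```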

Common gotcha

Developers often compute KL divergence from the current batch's input statistics alone. A single batch can have an unrepresentative mean and variance, especially when it holds only a few samples. Instead, keep a rolling window of recent input statistics (for example, the last 1000 inferences) and compare that window to your training baseline.

Also avoid calling <code>torch.cuda.reset_peak_memory_stats()</code> on every prediction in a server that handles requests concurrently. The peak counter is shared by everything running on that device in the process, so one request's reset wipes out another request's peak. Reset once per monitoring window, not per prediction, and read the peak at the end of the window.
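A rolling window can be kept cheaply by storing a count, sum, and sum of squares per inference and pooling them on demand. This is a minimal sketch, assuming a window of 1000 inferences as suggested above; `RollingInputStats` is a hypothetical helper, not a PyTorch API. Feed it with something like `stats.update(x.flatten().tolist())`. Once the window holds at least two values, pass its `mean_std()` result to `compute_kl_divergence` in place of the single-batch statistics.

```python
import math
from collections import deque
from typing import Iterable, Tuple


class RollingInputStats:
    """Pooled mean/std over the most recent `window` inferences."""

    def __init__(self, window: int = 1000):
        # One (count, sum, sum of squares) entry per inference; old entries fall off
        self._entries = deque(maxlen=window)

    def update(self, values: Iterable[float]) -> None:
        vals = [float(v) for v in values]
        self._entries.append((len(vals), sum(vals), sum(v * v for v in vals)))

    def mean_std(self) -> Tuple[float, float]:
        n = sum(e[0] for e in self._entries)
        if n < 2:
            return float("nan"), float("nan")  # not enough data for a sample std
        s = sum(e[1] for e in self._entries)
        ss = sum(e[2] for e in self._entries)
        mean = s / n
        # Sample variance from the sums; clamp tiny negative values caused by rounding.
        # For inputs whose mean is large relative to their spread, a numerically
        # stabler scheme (e.g. Welford's algorithm) is preferable.
        var = max((ss - s * s / n) / (n - 1), 0.0)
        return mean, math.sqrt(var)


stats = RollingInputStats(window=2)
stats.update([100.0])            # oldest entry, evicted by the third update
stats.update([1.0, 2.0, 3.0])
stats.update([4.0, 5.0])
print(stats.mean_std())  # mean 3.0, std ~1.581
```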

Error recovery

KL divergence is NaN
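This happens when <code>current_std</code> is itself NaN. <code>torch.std()</code> returns NaN for a one-element input, because its default Bessel correction divides by n - 1 = 0. Any NaN or inf in the input also propagates into the mean and std. A guard written as <code>if current_std <= 0</code> does not catch this, because every comparison with NaN is False. Guard with <code>if not math.isfinite(current_std) or current_std <= 0: return float('inf')</code>, or skip the batch. The <code><= 0</code> half still matters: a std of exactly 0 (all inputs identical) would otherwise divide by zero. Real cause: a bug in input preprocessing, or an edge case such as a single-value batch.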
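NaN slips past a bare `<= 0` check because every comparison involving NaN evaluates to False. The sketch below treats NaN, inf, and non-positive standard deviations as maximal drift. It is a minimal sketch; `safe_gaussian_kl` is a hypothetical standalone helper that uses the same formula as `compute_kl_divergence`.

```python
import math


def safe_gaussian_kl(p_mean: float, p_std: float, q_mean: float, q_std: float) -> float:
    # Check both stds: NaN fails every comparison, so `std <= 0` alone lets NaN through
    for std in (p_std, q_std):
        if not math.isfinite(std) or std <= 0:
            return float("inf")  # degenerate statistics: report as maximal drift
    return math.log(q_std / p_std) + (p_std**2 + (p_mean - q_mean)**2) / (2 * q_std**2) - 0.5


nan = float("nan")
print(nan <= 0)                              # False
print(safe_gaussian_kl(0.0, 1.0, 0.0, nan))  # inf
print(safe_gaussian_kl(0.0, 1.0, 0.0, 0.0))  # inf
print(safe_gaussian_kl(0.0, 1.0, 0.0, 1.0))  # 0.0
```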
GPU memory reports 0 even on GPU
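You ran the code on CPU or never moved the model and tensors to the GPU. Check <code>model.to('cuda')</code> and <code>x.to(self.device)</code>. Also confirm that <code>self.device</code> names a CUDA device (for example <code>'cuda'</code> or <code>'cuda:0'</code>). The harness reports 0 whenever the device string does not contain <code>'cuda'</code>, and only then calls <code>torch.cuda.max_memory_allocated()</code> when it does.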
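A quick way to confirm where things actually live is to inspect the devices of the model's parameters and of the input tensor. This is a minimal sketch; `describe_devices` is a hypothetical helper. It parses the device string with `torch.device(...).type`, an alternative to substring checks like `'cuda' in device`. Constructing a `torch.device` does not require a GPU.

```python
import torch
import torch.nn as nn


def describe_devices(model: nn.Module, x: torch.Tensor, device: str) -> dict:
    return {
        "cuda_available": torch.cuda.is_available(),
        "requested_is_cuda": torch.device(device).type == "cuda",
        # A module can be split across devices, so collect every parameter's device type
        "model_devices": sorted({p.device.type for p in model.parameters()}),
        "input_device": x.device.type,
    }


model = nn.Linear(10, 3)
info = describe_devices(model, torch.randn(2, 10), device="cpu")
print(info)
```

A mismatch tells you which `.to(...)` call is missing. For example, `requested_is_cuda` may be True while `model_devices` is `['cpu']`.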
Latency varies wildly between calls
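The first calls pay one-time costs (lazy initialization, memory allocator growth, cold caches), so discard the first 1-2 predictions from latency histograms. Single-call timings are also noisy because of background processes and OS scheduling, so measure percentiles over 100+ inferences in production. On GPU, call <code>torch.cuda.synchronize()</code> before reading the clock. CUDA kernels launch asynchronously, so without a sync you time the launch rather than the computation.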
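Three practices from above fit in one small timing helper: discard warm-up calls, synchronize on GPU, and report percentiles rather than single calls. This is a minimal sketch; `measure_latency` is hypothetical, and the 500 ms default mirrors the example p95 alert threshold from the explanation. For a GPU model, pass `sync=torch.cuda.synchronize` so each timing covers the kernels' execution, not just their launch.

```python
import time
from typing import Callable, Iterable, Optional

import numpy as np


def measure_latency(fn: Callable, inputs: Iterable, warmup: int = 2,
                    sync: Optional[Callable[[], None]] = None,
                    p95_threshold_ms: float = 500.0) -> dict:
    """Time fn(x) for each input, dropping the first `warmup` calls."""
    latencies_ms = []
    for i, x in enumerate(inputs):
        if sync is not None:
            sync()  # make sure earlier queued work is not billed to this call
        start = time.perf_counter()
        fn(x)
        if sync is not None:
            sync()  # wait for asynchronous work (e.g. CUDA kernels) to finish
        elapsed_ms = (time.perf_counter() - start) * 1000
        if i >= warmup:
            latencies_ms.append(elapsed_ms)
    if not latencies_ms:
        raise ValueError("need more inputs than warm-up calls")
    p50, p95, p99 = np.percentile(latencies_ms, [50, 95, 99])
    return {
        "count": len(latencies_ms),
        "p50_ms": float(p50),
        "p95_ms": float(p95),
        "p99_ms": float(p99),
        "p95_alert": bool(p95 > p95_threshold_ms),
    }


stats = measure_latency(sum, [[1, 2, 3]] * 105, warmup=5)
print(stats["count"])  # 100
```

With a PyTorch model, wrap the call in `torch.no_grad()` and pass something like `lambda b: model(b)` as `fn`.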
torch.compile() makes early latency readings misleading
If you use <code>model = torch.compile(model)</code>, the first call includes graph compilation, which can add seconds or more depending on model size. Later calls reflect steady-state latency. An input with a new shape can still trigger recompilation and another spike. Log an <code>is_first_call</code> flag and exclude flagged calls from latency alerts.
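One way to implement the `is_first_call` flag is a thin wrapper that remembers which input shapes it has already seen. This is a minimal sketch under three assumptions:

- `CompileAwareTimer` is hypothetical.
- It keys only on shape, although dtype or device changes can also trigger recompilation.
- The `sync` argument (e.g. `torch.cuda.synchronize`) is optional and only needed for GPU models.

In practice you would wrap `torch.compile(model)`; the demo uses a stand-in object with a `.shape` attribute so it runs without a model.

```python
import time
from types import SimpleNamespace
from typing import Any, Callable, List, Optional, Tuple


class CompileAwareTimer:
    """Times each call and flags the first call seen for every input shape."""

    def __init__(self, fn: Callable[[Any], Any], sync: Optional[Callable[[], None]] = None):
        self.fn = fn
        self.sync = sync  # e.g. torch.cuda.synchronize for GPU models
        self.seen_shapes = set()
        self.records: List[Tuple[float, bool]] = []  # (latency_ms, is_first_call)

    def __call__(self, x):
        shape = tuple(getattr(x, "shape", ()))
        is_first_call = shape not in self.seen_shapes
        self.seen_shapes.add(shape)
        start = time.perf_counter()
        out = self.fn(x)
        if self.sync is not None:
            self.sync()
        self.records.append(((time.perf_counter() - start) * 1000, is_first_call))
        return out

    def alertable_latencies(self) -> List[float]:
        # Exclude first-call timings, which may include compilation
        return [ms for ms, first in self.records if not first]


timed = CompileAwareTimer(lambda x: x)
for shape in [(2, 3), (2, 3), (4, 3)]:
    timed(SimpleNamespace(shape=shape))  # stand-in for a tensor with .shape
print([first for _, first in timed.records])  # [True, False, True]
```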

Experienced dev note

The insight: data drift kills more models than code bugs. Your model's accuracy depends on the data it sees at inference time matching your training distribution. KL divergence on input statistics is cheap to compute. It can flag a shift as soon as it shows up in your monitoring window, before your accuracy metrics degrade noticeably and you scramble to retrain.

As a rule of thumb, spend about 80% of your monitoring effort on input drift and 20% on latency/memory. Compute statistics over a rolling window (e.g., the last 1000 inferences), not per batch, because per-batch statistics are too noisy. Finally, tie automated retraining triggers to KL divergence thresholds instead of letting alerts rot in Slack.
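A retraining trigger tied to KL thresholds works better with some hysteresis, so that one noisy window does not launch a job. This is a minimal sketch, assuming you already compute one KL value per monitoring window. `DriftTrigger`, the threshold of 1.0, and the patience of 3 windows are hypothetical choices to tune for your data. What happens when the trigger fires (for example, queueing a retraining pipeline) is left to your orchestration.

```python
import math
from dataclasses import dataclass, field


@dataclass
class DriftTrigger:
    """Fires after `patience` consecutive windows with KL above `threshold`."""
    threshold: float = 1.0
    patience: int = 3
    _streak: int = field(default=0, init=False)

    def observe(self, window_kl: float) -> bool:
        # Treat NaN as a breach: broken input statistics deserve attention too
        breached = math.isnan(window_kl) or window_kl > self.threshold
        self._streak = self._streak + 1 if breached else 0
        if self._streak >= self.patience:
            self._streak = 0  # require a fresh streak before firing again
            return True
        return False


trigger = DriftTrigger(threshold=1.0, patience=3)
fired = [trigger.observe(kl) for kl in [2.0, 2.0, 0.5, 2.0, 2.0, 2.0]]
print(fired)  # [False, False, False, False, False, True]
```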

Check your understanding

You notice KL divergence is 0.5 for 8 hours, then spikes to 8.0 over 2 minutes. Your model's latency, memory, and output confidences stay normal. What just happened, and should you immediately roll back the model or investigate further?

Answer hint

A correct answer recognizes that KL divergence measures a shift in the <strong>input</strong> distribution, not model degradation directly. A step change over two minutes, after hours of stability, points to a discrete event rather than gradual real-world drift. Examples: a new upstream data source, a schema migration, or a feature bug upstream.

Investigate the input side first (data pipeline logs, upstream service changes) before rolling back. An older model would receive the same shifted inputs, so a blind rollback is unlikely to fix anything. Normal latency and confidence do not prove the predictions are still correct, since models can be confidently wrong on out-of-distribution inputs. Check labeled outcomes or downstream metrics where you have them. This is the difference between data drift (fix the input) and model drift (retrain or roll back).

VERSION PyTorch 2.11.x (March 2026): torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() are stable. torch.compile() (added in 2.0) is now mature and commonly used in production. If your model uses it, expect a latency spike on the first batch and whenever recompilation occurs. torch.amp.autocast('cuda') (mixed precision, 2.0+) also affects memory metrics significantly, so always monitor with the same precision strategy as production.
NEXT

Once you monitor models in production, the next challenge is acting on those metrics: learn how to set up automated retraining pipelines that trigger when drift exceeds thresholds, which involves model versioning, A/B testing, and canary rollouts in PyTorch.
