Production monitoring for PyTorch models
Why this matters
A model that works in dev can fail silently in production: inputs shift, GPU memory creeps up, latency spikes under load. Without monitoring, you ship a ticking time bomb. This lesson shows you how to catch model degradation, resource exhaustion, and data distribution shift before users notice.
Explanation
Production monitoring for PyTorch means instrumenting your inference code to emit metrics about model behavior: how long predictions take, how much GPU memory they consume, what the input data distribution looks like, and whether model outputs are drifting from historical patterns. Mechanically, you wrap inference calls with timing and memory profiling, log prediction confidence scores and class distributions, and compare live input statistics to your training baseline using KL divergence or Wasserstein distance. You emit these metrics to a time-series database (Prometheus, InfluxDB) and alert when thresholds are breached, e.g., if 95th-percentile latency exceeds 500ms or if input feature means shift by more than 3 standard deviations.
When to use it: any production model that influences decisions, especially in latency-sensitive systems (recommendation, fraud detection, real-time search) or high-stakes domains (healthcare, finance). The earlier you instrument, the easier it is to diagnose production issues that would otherwise force you to roll back a model blindly.
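The two alert rules mentioned above can be sketched in plain Python. This is only an illustration of the logic, not a substitute for alerting rules in Prometheus or InfluxDB; the function names `p95` and `check_alerts`, the alert labels, and the default thresholds are assumptions made for this example.

```python
from statistics import mean


def p95(values):
    """Nearest-rank 95th percentile of a non-empty list."""
    ordered = sorted(values)
    # Integer ceiling of 0.95 * n, which avoids floating-point rounding surprises.
    rank = (95 * len(ordered) + 99) // 100
    return ordered[rank - 1]


def check_alerts(latencies_ms, live_feature_values, baseline_mean, baseline_std,
                 latency_limit_ms=500.0, max_sigma_shift=3.0):
    """Return the names of the alert rules that fire for this window."""
    alerts = []
    if p95(latencies_ms) > latency_limit_ms:
        alerts.append("p95_latency")
    # Shift of the live mean, measured in baseline standard deviations.
    shift = abs(mean(live_feature_values) - baseline_mean) / baseline_std
    if shift > max_sigma_shift:
        alerts.append("input_mean_shift")
    return alerts


# 100 requests: 90 fast, 10 slow, so the 95th percentile lands on a slow one.
latencies = [20.0] * 90 + [800.0] * 10
shifted_inputs = [5.1, 4.9, 5.0, 5.2]  # baseline is mean 0, std 1
print(check_alerts(latencies, shifted_inputs, baseline_mean=0.0, baseline_std=1.0))
# prints ['p95_latency', 'input_mean_shift']
```

Note that the average latency in this example is only 98ms, which is why the rule is written against a high percentile rather than the mean.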
Analogy
Monitoring a production model is like monitoring the health of a patient after surgery. You can't just discharge them and never check again. You track vital signs (latency, memory, throughput), watch for infections (data drift, input anomalies), and catch problems early by comparing today's readings to the baseline established before surgery. Without monitoring, you only find out the patient crashed when they show up in the ER.
Code
import torch
import torch.nn as nn
import time
import numpy as np
from dataclasses import dataclass
from typing import Dict, List
import warnings
warnings.filterwarnings('ignore')

@dataclass
class InferenceMetrics:
    latency_ms: float
    peak_memory_mb: float
    input_mean: float
    input_std: float
    output_class_dist: Dict[int, float]
    prediction_confidence: float

class ProductionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class MonitoredInference:
    def __init__(self, model: nn.Module, device: str = 'cpu', baseline_input_mean: float = 0.0, baseline_input_std: float = 1.0):
        self.model = model.to(device)
        self.device = device
        self.baseline_input_mean = baseline_input_mean
        self.baseline_input_std = baseline_input_std
        self.metrics_history: List[InferenceMetrics] = []

    def compute_kl_divergence(self, current_mean: float, current_std: float) -> float:
        # KL(P || Q) between univariate Gaussians: P = training baseline, Q = live inputs.
        # "not (... > 0)" also catches NaN, e.g. the std of a single value.
        if not (current_std > 0):
            return float('inf')
        p_mean, p_std = self.baseline_input_mean, self.baseline_input_std
        q_mean, q_std = current_mean, current_std
        kl = torch.log(torch.tensor(q_std / p_std)) + (p_std**2 + (p_mean - q_mean)**2) / (2 * q_std**2) - 0.5
        return kl.item()

    def predict(self, x: torch.Tensor) -> tuple[torch.Tensor, InferenceMetrics, float]:
        self.model.eval()
        on_cuda = 'cuda' in self.device
        if on_cuda:
            torch.cuda.reset_peak_memory_stats(self.device)
        start_time = time.perf_counter()
        with torch.no_grad():
            x_input = x.to(self.device)
            logits = self.model(x_input)
            predictions = torch.softmax(logits, dim=1)
        if on_cuda:
            # CUDA kernels run asynchronously; wait for them before reading the timer.
            torch.cuda.synchronize(self.device)
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        if on_cuda:
            peak_memory_mb = torch.cuda.max_memory_allocated(self.device) / (1024 ** 2)
        else:
            peak_memory_mb = 0.0

        # Input statistics pooled over all features in the batch.
        input_mean = x.mean().item()
        input_std = x.std().item()

        # Output statistics: predicted-class shares and mean top-1 confidence.
        num_classes = predictions.shape[1]
        pred_classes = torch.argmax(predictions, dim=1)
        class_counts = torch.bincount(pred_classes, minlength=num_classes)
        class_dist = {int(i): (class_counts[i].item() / len(pred_classes)) for i in range(num_classes)}
        max_confidence = predictions.max(dim=1).values.mean().item()

        metrics = InferenceMetrics(
            latency_ms=elapsed_ms,
            peak_memory_mb=peak_memory_mb,
            input_mean=input_mean,
            input_std=input_std,
            output_class_dist=class_dist,
            prediction_confidence=max_confidence
        )
        self.metrics_history.append(metrics)
        kl_div = self.compute_kl_divergence(input_mean, input_std)
        return predictions, metrics, kl_div

    def check_anomalies(self, metrics: InferenceMetrics, latency_threshold_ms: float = 100.0, confidence_threshold: float = 0.5) -> Dict[str, bool]:
        anomalies = {
            'high_latency': metrics.latency_ms > latency_threshold_ms,
            'low_confidence': metrics.prediction_confidence < confidence_threshold,
            'high_memory': metrics.peak_memory_mb > 500.0,
            'imbalanced_output': max(metrics.output_class_dist.values()) > 0.9
        }
        return anomalies

    def summarize_metrics(self) -> Dict:
        if not self.metrics_history:
            return {}
        latencies = [m.latency_ms for m in self.metrics_history]
        confidences = [m.prediction_confidence for m in self.metrics_history]
        return {
            'avg_latency_ms': np.mean(latencies),
            'p95_latency_ms': np.percentile(latencies, 95),
            'p99_latency_ms': np.percentile(latencies, 99),
            'avg_confidence': np.mean(confidences),
            'min_confidence': np.min(confidences),
            'inference_count': len(self.metrics_history)
        }

# Build the model and monitor, then derive the drift baseline from training data.
model = ProductionModel()
monitor = MonitoredInference(model, device='cpu', baseline_input_mean=0.0, baseline_input_std=1.0)
X_train = torch.randn(100, 10)
monitor.baseline_input_mean = X_train.mean().item()
monitor.baseline_input_std = X_train.std().item()

# Scenario 1: inputs drawn from the training distribution (one sample per call).
X_normal = torch.randn(5, 10)
for i, batch in enumerate(X_normal):
    pred, metrics, kl = monitor.predict(batch.unsqueeze(0))
    anomalies = monitor.check_anomalies(metrics)
    print(f"Batch {i}: latency={metrics.latency_ms:.2f}ms, confidence={metrics.prediction_confidence:.3f}, kl_div={kl:.4f}, anomalies={anomalies}")

# Scenario 2: every feature shifted by +5 to simulate drift.
X_drift = torch.randn(5, 10) + 5.0
print("\n--- DISTRIBUTION SHIFT DETECTED ---")
for i, batch in enumerate(X_drift):
    pred, metrics, kl = monitor.predict(batch.unsqueeze(0))
    anomalies = monitor.check_anomalies(metrics)
    print(f"Drift Batch {i}: latency={metrics.latency_ms:.2f}ms, confidence={metrics.prediction_confidence:.3f}, kl_div={kl:.4f}, anomalies={anomalies}")

summary = monitor.summarize_metrics()
print(f"\nMetrics Summary: {summary}")

Sample output (values vary from run to run)
Batch 0: latency=0.69ms, confidence=0.336, kl_div=0.0000, anomalies={'high_latency': False, 'low_confidence': False, 'high_memory': False, 'imbalanced_output': False}
Batch 1: latency=0.42ms, confidence=0.328, kl_div=0.0000, anomalies={'high_latency': False, 'low_confidence': False, 'high_memory': False, 'imbalanced_output': False}
Batch 2: latency=0.31ms, confidence=0.344, kl_div=0.0000, anomalies={'high_latency': False, 'low_confidence': False, 'high_memory': False, 'imbalanced_output': False}
Batch 3: latency=0.39ms, confidence=0.359, kl_div=0.0000, anomalies={'high_latency': False, 'low_confidence': False, 'high_memory': False, 'imbalanced_output': False}
Batch 4: latency=0.29ms, confidence=0.355, kl_div=0.0000, anomalies={'high_latency': False, 'low_confidence': False, 'high_memory': False, 'imbalanced_output': False}
--- DISTRIBUTION SHIFT DETECTED ---
Drift Batch 0: latency=0.33ms, confidence=0.337, kl_div=12.5037, anomalies={'high_latency': False, 'low_confidence': False, 'high_memory': False, 'imbalanced_output': False}
Drift Batch 1: latency=0.30ms, confidence=0.335, kl_div=12.5030, anomalies={'high_latency': False, 'low_confidence': False, 'high_memory': False, 'imbalanced_output': False}
Drift Batch 2: latency=0.32ms, confidence=0.344, kl_div=12.5004, anomalies={'high_latency': False, 'low_confidence': False, 'high_memory': False, 'imbalanced_output': False}
Drift Batch 3: latency=0.30ms, confidence=0.336, kl_div=12.5024, anomalies={'high_latency': False, 'low_confidence': False, 'high_memory': False, 'imbalanced_output': False}
Drift Batch 4: latency=0.31ms, confidence=0.333, kl_div=12.5032, anomalies={'high_latency': False, 'low_confidence': False, 'high_memory': False, 'imbalanced_output': False}
Metrics Summary: {'avg_latency_ms': 0.38, 'p95_latency_ms': 0.63, 'p99_latency_ms': 0.69, 'avg_confidence': 0.342, 'min_confidence': 0.328, 'inference_count': 10}
What just happened?
We wrapped model inference in a monitoring harness that tracks four signals: (1) latency via time.perf_counter(), (2) peak GPU memory via torch.cuda.max_memory_allocated() (reported as 0.0 here because the demo runs on CPU), (3) input mean and standard deviation, compared against the training baseline with KL divergence to detect data drift, and (4) output class distribution and confidence scores to spot model collapse. The run covers two scenarios: normal inputs, where KL stays near 0 (no drift), and inputs shifted by +5, where KL jumps to about 12.5 (severe distribution shift). The anomaly checks flag breaches of the latency, confidence, memory, and class-imbalance thresholds; in a real system these would trigger alerts.
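The value of about 12.5 follows from the closed-form KL divergence between two univariate Gaussians, which is the formula compute_kl_divergence evaluates. The sketch below restates it in plain Python (the function name `gaussian_kl` is an assumption for this example) and plugs in idealized values: a baseline with mean 0 and std 1, and live inputs with mean 5 and std 1.

```python
import math


def gaussian_kl(p_mean, p_std, q_mean, q_std):
    """KL(P || Q) for univariate Gaussians P (baseline) and Q (live)."""
    return (math.log(q_std / p_std)
            + (p_std ** 2 + (p_mean - q_mean) ** 2) / (2 * q_std ** 2)
            - 0.5)


print(gaussian_kl(0.0, 1.0, 0.0, 1.0))  # identical distributions -> 0.0
print(gaussian_kl(0.0, 1.0, 5.0, 1.0))  # mean shifted by 5 -> 12.5
```

When both standard deviations are equal, the log term vanishes and the divergence reduces to the squared mean shift divided by twice the variance, so a shift of 5 with unit variance gives 25 / 2 = 12.5. The divergence is asymmetric: swapping baseline and live statistics generally changes the result.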
Common gotcha
Developers often compute KL divergence from the current batch's input statistics alone, overlooking that a single batch can have an unrepresentative mean and variance. A more reliable approach is to keep a rolling window of recent input statistics (for example, the last 1000 inferences) and compare that window to your training baseline. Also be careful with torch.cuda.reset_peak_memory_stats() inside a tight production inference loop: the demo above resets it on every prediction to get a per-call peak, but doing so wipes the peak counter for any other profiler or dashboard that reads it. In production, reset it once per monitoring window rather than once per prediction.
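A rolling window like the one described above can be sketched with a bounded deque. This is a minimal illustration in plain Python; the class name `RollingInputStats` and its methods are assumptions for this example, and it treats inputs as a flat stream of scalar values (with PyTorch you could feed it `x.flatten().tolist()`).

```python
import math
from collections import deque


class RollingInputStats:
    """Mean and sample std over the most recent `window` scalar input values."""

    def __init__(self, window=1000):
        self.values = deque(maxlen=window)  # oldest values fall off automatically

    def update(self, batch_values):
        self.values.extend(batch_values)

    def mean_std(self):
        n = len(self.values)
        if n < 2:
            return None  # not enough data for a meaningful std
        mean = sum(self.values) / n
        var = sum((v - mean) ** 2 for v in self.values) / (n - 1)
        return mean, math.sqrt(var)


stats = RollingInputStats(window=4)
stats.update([1.0, 2.0])
stats.update([3.0, 4.0, 5.0])  # 1.0 is evicted; the window now holds 2, 3, 4, 5
print(stats.mean_std())        # mean is 3.5
```

The windowed mean and std, rather than a single batch's, are what you would pass to the KL computation. This version recomputes the statistics from scratch on each call, which is fine for windows of a few thousand values.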
Error recovery
KL divergence is NaN
GPU memory reports 0 even on GPU
Latency varies wildly between calls
torch.compile() makes latency impossible to interpret
Experienced dev note
The insight: data drift tends to break production models more often than code bugs do. Your model's accuracy depends on the inference-time data distribution matching the training distribution. KL divergence on summary statistics is cheap to compute and can surface a shift well before your accuracy metrics have degraded by 20% and you are scrambling to retrain. As a rule of thumb, spend roughly 80% of your monitoring effort on input drift and 20% on latency and memory. Always compute statistics over a rolling window (e.g., the last 1000 inferences) rather than per batch, because batch statistics are too noisy. Finally, set up automated retraining triggers tied to KL divergence thresholds; don't wait for alerts to rot in Slack.
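One way to tie a retraining trigger to a KL threshold without reacting to every noisy spike is to require the threshold to be exceeded for several consecutive monitoring windows. The sketch below is a hypothetical illustration: the class name `DriftTrigger`, the threshold of 1.0, and the patience of 3 windows are all assumptions, and what happens when the trigger fires (for example, starting a retraining job) is left to the caller.

```python
class DriftTrigger:
    """Fire once when KL stays above `threshold` for `patience` windows in a row."""

    def __init__(self, threshold=1.0, patience=3):
        self.threshold = threshold
        self.patience = patience
        self.streak = 0
        self.fired = False

    def observe(self, kl_value):
        """Record one window's KL; return True only on the window that fires."""
        if kl_value > self.threshold:
            self.streak += 1
        else:
            self.streak = 0
            self.fired = False  # re-arm once drift subsides
        if self.streak >= self.patience and not self.fired:
            self.fired = True
            return True
        return False


trigger = DriftTrigger(threshold=1.0, patience=3)
kl_per_window = [0.2, 8.0, 0.3, 8.1, 8.0, 7.9, 8.2]
decisions = [trigger.observe(kl) for kl in kl_per_window]
print(decisions)
# prints [False, False, False, False, False, True, False]
```

The isolated spike in the second window is ignored, and the trigger fires exactly once after three consecutive high windows instead of on every window that follows.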
Check your understanding
You notice KL divergence is 0.5 for 8 hours, then spikes to 8.0 over 2 minutes. Your model's latency, memory, and output confidences stay normal. What just happened, and should you immediately roll back the model or investigate further?
Answer hint
A correct answer recognizes that KL divergence here measures input distribution shift, not model degradation. If outputs and latency are healthy, the model itself is probably not the cause: the input data changed (maybe a new upstream data source, a schema migration, or an upstream feature bug). Investigate the input source first (data pipeline logs, upstream service changes) before rolling back. Rolling back blindly would be wrong because the model isn't broken; the environment changed. This is the difference between data drift (fix the input) and model drift (retrain or roll back).
torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() are stable APIs. torch.compile() (added in 2.0) is now mature and commonly used in production; if your model uses it, expect a latency spike on the first batch, when compilation happens. Mixed precision via torch.amp.autocast('cuda') also changes memory metrics significantly, so always monitor with the same precision strategy you run in production.
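To keep a one-time compilation spike from distorting latency percentiles, one option is to exclude the first call or calls from the steady-state statistics. The sketch below shows the idea in plain Python with a stand-in for a compiled model; `timed_calls` and `fake_predict` are hypothetical names, and the 50ms sleep merely simulates a slow first call.

```python
import time


def timed_calls(predict_fn, inputs, warmup=1):
    """Call predict_fn on each input; return latencies (ms) recorded after warm-up."""
    latencies_ms = []
    for i, item in enumerate(inputs):
        start = time.perf_counter()
        predict_fn(item)
        # On a GPU you would also call torch.cuda.synchronize() here so that
        # asynchronous kernels finish before the timer is read.
        elapsed_ms = (time.perf_counter() - start) * 1000
        if i >= warmup:  # skip the first `warmup` calls (compilation, cold caches)
            latencies_ms.append(elapsed_ms)
    return latencies_ms


# Stand-in for a compiled model: only the first call is artificially slow.
calls = {"count": 0}


def fake_predict(x):
    calls["count"] += 1
    if calls["count"] == 1:
        time.sleep(0.05)  # simulates a one-time compilation cost
    return x * 2


steady = timed_calls(fake_predict, [1, 2, 3, 4], warmup=1)
print(len(steady), all(t < 50.0 for t in steady))
```

You may still want to record the warm-up latency as a separate metric, since it affects how quickly a freshly deployed replica can start serving traffic.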