ONNX for deployment on non-PyTorch stacks
Why this matters
Production deployments often run models on C++ servers, mobile devices, browsers, or other runtimes and frameworks (TensorFlow, TensorRT, CoreML) where PyTorch is unavailable or adds unacceptable overhead. ONNX is a widely adopted interchange format designed to bridge that gap.
Explanation
What it is: ONNX (Open Neural Network Exchange) is a standardized format for representing trained neural networks, independent of the framework that created them. A PyTorch model exported to ONNX becomes a static computation graph that any ONNX runtime (or framework that reads ONNX) can execute.
How it works mechanically: PyTorch's torch.onnx.export() traces your model by running it once on a dummy input and recording every operation into an ONNX graph. The result is an .onnx file: essentially a protobuf containing the model's architecture and weights. You then load this file in your target environment (e.g., ONNX Runtime, TensorRT, CoreML) and run inference without PyTorch. Tracing has a critical constraint: it captures only the path that was executed for the dummy input, so dynamic control flow (if statements or loops that depend on tensor values) is frozen into whichever branch ran. Such control flow must either be avoided or preserved explicitly, for example by compiling the model with torch.jit.script() before export. Note that torch.jit.trace() does not help here, because it has the same limitation as the default export path.
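The tracing constraint is easier to see in miniature. The sketch below is a toy tracer written in plain Python; it is not how torch.onnx.export() is implemented, and the names Traced, trace, and replay are invented for this illustration. It only shows why recording the operations of a single run loses a data-dependent branch.

```python
# Toy illustration only: this is NOT PyTorch's exporter, just the same idea in miniature.
class Traced:
    """A value that records every arithmetic op applied to it."""

    def __init__(self, value, graph):
        self.value = value
        self.graph = graph

    def __mul__(self, k):
        self.graph.append(("mul", k))
        return Traced(self.value * k, self.graph)

    def __add__(self, k):
        self.graph.append(("add", k))
        return Traced(self.value + k, self.graph)


def forward(x):
    # Data-dependent branch: which ops run depends on the input value.
    if x.value > 0:
        return x * 2
    return x + 100


def trace(fn, example):
    """Run fn once on an example input and return the recorded op list."""
    graph = []
    fn(Traced(example, graph))
    return graph


def replay(graph, value):
    """Execute the recorded ops on a new input, as a runtime would."""
    for op, k in graph:
        value = value * k if op == "mul" else value + k
    return value


graph = trace(forward, 3.0)   # the example is positive, so only the "mul" branch is recorded
print(graph)                  # [('mul', 2)]
print(replay(graph, 3.0))     # 6.0, same as the original forward
print(replay(graph, -1.0))    # -2.0, whereas the original forward would return 99.0
```

The recorded graph contains no trace of the `if`: the branch taken for the example input is the only behavior the exported artifact knows about.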
When to use it: Export to ONNX at the end of training, validate that the exported model's outputs match PyTorch's within a small numerical tolerance, then deploy the .onnx file instead of the PyTorch checkpoint. This is especially useful for cloud inference (smaller containers without the full PyTorch dependency), edge deployment (mobile, IoT), and multi-framework pipelines (train in PyTorch, serve via TensorRT on NVIDIA hardware).
Analogy
Think of ONNX like compiled machine code: your PyTorch model is source code that's powerful but needs the full PyTorch framework to execute. ONNX is the compiled binary: typically lighter to ship and load, and runnable by any runtime that speaks the same instruction set. You lose the ability to recompile on the fly (no dynamic changes), but you gain portability and efficiency.
Code
import torch
import torch.nn as nn
import torch.onnx
import onnx
import onnxruntime as rt
import numpy as np


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 16)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(16, 3)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


model = SimpleNet()
model.eval()  # switch to inference behavior before exporting

# Export by tracing the model on a dummy input
dummy_input = torch.randn(1, 10)
torch.onnx.export(
    model,
    dummy_input,
    "simple_net.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=14,
    do_constant_folding=True,
    verbose=False
)
print("✓ Model exported to simple_net.onnx")

# Check that the exported file is a structurally valid ONNX model
onnx_model = onnx.load("simple_net.onnx")
onnx.checker.check_model(onnx_model)
print("✓ ONNX model validation passed")

# Run the same input through PyTorch and ONNX Runtime and compare
sess = rt.InferenceSession("simple_net.onnx", providers=['CPUExecutionProvider'])
test_input = np.random.randn(1, 10).astype(np.float32)
pytorch_output = model(torch.from_numpy(test_input)).detach().numpy()
onnx_input = {sess.get_inputs()[0].name: test_input}
onnx_output = sess.run(None, onnx_input)[0]
max_diff = np.max(np.abs(pytorch_output - onnx_output))
print(f"✓ Max difference between PyTorch and ONNX outputs: {max_diff:.2e}")
print(f"✓ Outputs match: {np.allclose(pytorch_output, onnx_output, atol=1e-5)}")

# Print input/output shapes stored in the ONNX graph
print("\nONNX model info:")
for input_tensor in onnx_model.graph.input:
    print(f" Input: {input_tensor.name}, shape: {[d.dim_value for d in input_tensor.type.tensor_type.shape.dim]}")
for output_tensor in onnx_model.graph.output:
    print(f" Output: {output_tensor.name}, shape: {[d.dim_value for d in output_tensor.type.tensor_type.shape.dim]}")

Output
✓ Model exported to simple_net.onnx
✓ ONNX model validation passed
✓ Max difference between PyTorch and ONNX outputs: 5.96e-07
✓ Outputs match: True

ONNX model info:
 Input: input, shape: [1, 10]
 Output: output, shape: [1, 3]
What just happened?
The code defined a simple PyTorch model, exported it to ONNX format using torch.onnx.export() with opset_version=14 (the operator set specification), then loaded the ONNX file and created an inference session using ONNX Runtime. It ran the same test input through both the PyTorch model and the ONNX runtime, verified the outputs matched to within floating-point tolerance (5.96e-07 difference is expected numerical variation from different execution paths), and printed the model's input/output shapes from the ONNX graph metadata.
Common gotcha
The most common mistake is exporting with model.train() still active, or forgetting to call model.eval() beforehand. Dropout and batch normalization behave differently in train mode, so the export can capture training-time behavior (stochastic dropout, normalization by per-batch statistics in place of the stored running statistics) instead of inference-time behavior. Always call model.eval() before exporting, and run the export inside torch.no_grad(). Second gotcha: assuming the exported model accepts any batch size. If you export with torch.randn(1, 10) and nothing else, ONNX may bake batch size 1 into the graph as a constant. Pass dynamic_axes={"input": {0: "batch_size"}} to the export call to allow variable batch sizes.
Error recovery
RuntimeError: ONNX export failed with operator not supported
ValueError: input size mismatch at export time
AssertionError from onnx.checker.check_model()
ort.InvalidArgument: 'input' does not have shape information
Experienced dev note
One critical insight: a passing export check is not proof of correctness. The fact that PyTorch and ONNX outputs match on a dummy input does NOT guarantee they'll match on production data, especially if: (1) your model has numerical instabilities that manifest only on edge cases, (2) you're using ops whose numerics differ slightly between PyTorch and the target ONNX runtime (e.g., layer norm, reductions), or (3) quantization is involved. Always validate on a representative sample of real data in your target runtime before deploying.
Second insight: opset_version matters enormously. Lower opsets (11–12) define fewer operations; higher opsets (17+, available in PyTorch 2.0+) have better coverage and are often better optimized by ONNX runtimes. Default to opset_version=17, or check what your target runtime supports.
Third: ONNX Runtime has different execution providers (CPU, CUDA, TensorRT, CoreML). Benchmark your exported model on the actual target provider before deployment: CUDA performance can differ dramatically from CPU, and TensorRT has its own separate quantization workflow.
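A validation pass over representative data might be structured like the sketch below. It is framework-agnostic and uses only NumPy: compare_on_samples is a hypothetical helper, and the two softmax functions are stand-ins for "the PyTorch model" and "the exported model in the target runtime", chosen because they agree on ordinary inputs and diverge on an edge case. In practice the two callables would wrap the real model and the real inference session.

```python
import numpy as np


def compare_on_samples(reference_fn, candidate_fn, batches, atol=1e-5, rtol=1e-4):
    """Return (worst absolute difference, indices of batches outside tolerance)."""
    worst = 0.0
    failures = []
    for i, batch in enumerate(batches):
        ref = np.asarray(reference_fn(batch))
        out = np.asarray(candidate_fn(batch))
        diff = np.abs(ref - out)
        # Treat NaN/inf anywhere in the comparison as an unbounded difference.
        max_diff = float(np.max(diff)) if np.all(np.isfinite(diff)) else float("inf")
        worst = max(worst, max_diff)
        if not np.allclose(ref, out, atol=atol, rtol=rtol):
            failures.append(i)
    return worst, failures


def stable_softmax(x):
    # Stand-in for the reference implementation.
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def naive_softmax(x):
    # Stand-in for a runtime whose numerics break down on large inputs.
    with np.errstate(over="ignore", invalid="ignore"):
        e = np.exp(x)
        return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
ordinary = [rng.standard_normal((4, 5)) for _ in range(2)]
edge_case = np.full((4, 5), 1000.0)  # exp() overflows in the naive version

# A check on ordinary inputs alone looks clean.
worst_ok, failures_ok = compare_on_samples(stable_softmax, naive_softmax, ordinary)
print(failures_ok)

# Adding the edge-case batch exposes the mismatch.
worst, failures = compare_on_samples(stable_softmax, naive_softmax, ordinary + [edge_case])
print(worst, failures)
```

The point of the design is that the helper reports which batches failed, not just a single pass/fail flag, so a mismatch can be traced back to the inputs that triggered it.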
Check your understanding
Why would exporting a model with model.train() still active cause inference to produce wrong results in production, and why does this specific mistake often go undetected during export validation?
Answer hint
A correct answer explains that (1) batch norm and dropout behave differently in train and eval mode (per-batch statistics vs stored running statistics, stochastic masking vs identity), (2) the ONNX export captures the train-time graph (active dropout, batch norm computed from batch statistics), (3) this traces correctly and exports without error, so validation passes, but (4) when the ONNX runtime executes, it runs the train-mode ops, which produce different outputs than the expected inference-mode behavior, and this only becomes visible when comparing against a model in eval mode or testing on real data where batch statistics matter.