How to export PyTorch model to ONNX
Quick answer

Use torch.onnx.export to convert a PyTorch model to ONNX format by providing the model, a sample input tensor, and the output file path. This exports the model graph and weights for interoperability with other frameworks.

Prerequisites

- Python 3.8+
- pip install torch>=2.0.0
- pip install onnx
Setup
Install the required packages torch and onnx using pip if you haven't already. Ensure you have Python 3.8 or newer.
pip install torch onnx

Step by step
This example shows how to export a simple PyTorch model to ONNX format. It includes defining a model, creating a dummy input, and calling torch.onnx.export. The exported ONNX file can be loaded by other frameworks.
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
model.eval()  # Set to evaluation mode

# Create a dummy input tensor with the correct shape
dummy_input = torch.randn(1, 10)

# Export the model to ONNX format
onnx_path = "simple_model.onnx"
torch.onnx.export(
    model,                     # model to export
    dummy_input,               # model input (or a tuple for multiple inputs)
    onnx_path,                 # where to save the ONNX file
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=17,          # ONNX opset version to export the model to
    do_constant_folding=True,  # whether to execute constant folding optimization
    input_names=["input"],     # the model's input names
    output_names=["output"],   # the model's output names
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},  # variable batch size
)
print(f"Model exported to {onnx_path}")

Output:
Model exported to simple_model.onnx
Common variations

- Change opset_version to target a different ONNX opset (this example uses 17).
- Export models with multiple inputs by passing a tuple as the dummy input and naming each input accordingly.
- Use dynamic_axes to allow variable input sizes, e.g., batch size.
- Models with data-dependent control flow may not export correctly via tracing, which records only the path taken by the dummy input; consider scripting the model first or using the TorchDynamo-based exporter.
Troubleshooting

- If you get errors about unsupported operators, upgrade PyTorch and ONNX to the latest versions.
- Check that the dummy input shape matches the model's expected input shape.
- Use torch.onnx.export with verbose=True to debug the export graph.
- Verify the exported ONNX model with onnx.checker.check_model to ensure validity.
import onnx

# Load and check the exported ONNX model
model_onnx = onnx.load("simple_model.onnx")
onnx.checker.check_model(model_onnx)
print("ONNX model is valid")

Output:
ONNX model is valid
Key takeaways

- Use torch.onnx.export with a dummy input to export PyTorch models to ONNX.
- Set dynamic_axes for flexible input sizes like the batch dimension.
- Verify the ONNX model with onnx.checker.check_model to catch export issues.