ONNX model pruning guide
Quick answer
Pruning an ONNX model involves removing unnecessary weights or nodes to reduce model size and improve inference speed. Use tools like onnxruntime with pruning libraries, or convert to PyTorch for structured pruning before exporting back to ONNX.

Prerequisites
- Python 3.8+
- pip install onnx onnxruntime onnx-simplifier torch torchvision
Setup
Install the necessary Python packages to work with ONNX models and pruning tools.
pip install onnx onnxruntime onnx-simplifier torch torchvision

Step by step pruning
This example shows how to load a PyTorch model, prune it using structured pruning, export it to ONNX, and simplify the resulting ONNX model. Note that prune.ln_structured zeroes out filters rather than physically removing them, so the exported file is not smaller by itself; the zeroed weights mainly benefit sparsity-aware runtimes and further optimization passes.
import torch
import torch.nn.utils.prune as prune
import torchvision.models as models
import onnx
from onnxsim import simplify
# Load pretrained PyTorch model
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # 'pretrained=True' is deprecated in newer torchvision
model.eval()
# Apply structured pruning to convolutional layers
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.ln_structured(module, name='weight', amount=0.3, n=2, dim=0)  # prune 30% of filters by L2 norm
# Remove pruning re-parametrization to make pruning permanent
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.remove(module, 'weight')
# Export pruned model to ONNX
dummy_input = torch.randn(1, 3, 224, 224)
onnx_path = 'pruned_resnet18.onnx'
torch.onnx.export(model, dummy_input, onnx_path, opset_version=13, input_names=['input'], output_names=['output'])
# Simplify ONNX model
onnx_model = onnx.load(onnx_path)
simplified_model, check = simplify(onnx_model)
assert check, 'Simplified ONNX model could not be validated'
onnx.save(simplified_model, 'pruned_resnet18_simplified.onnx')
print('Pruned and simplified ONNX model saved as pruned_resnet18_simplified.onnx')

Output
Pruned and simplified ONNX model saved as pruned_resnet18_simplified.onnx
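Under the hood, prune.ln_structured with n=2 and dim=0 ranks output filters by their L2 norm and zeroes the lowest-scoring fraction. The following is a minimal stdlib sketch of that selection logic, not PyTorch's actual implementation; the toy list-of-lists stands in for a Conv2d weight tensor.

```python
import math

def l2_structured_prune(filters, amount):
    """Zero the `amount` fraction of filters with the smallest L2 norm.

    `filters` is a list of flat weight lists, one per output filter --
    a toy stand-in for a Conv2d weight tensor (dim=0 = output filters).
    """
    norms = [math.sqrt(sum(w * w for w in f)) for f in filters]
    n_prune = int(amount * len(filters))
    # Indices of the n_prune filters with the smallest L2 norms
    to_zero = set(sorted(range(len(filters)), key=lambda i: norms[i])[:n_prune])
    return [[0.0] * len(f) if i in to_zero else list(f)
            for i, f in enumerate(filters)]

filters = [[3.0, 4.0], [0.1, 0.2], [1.0, 1.0], [0.05, 0.0]]
print(l2_structured_prune(filters, amount=0.5))
# the two weakest filters (indices 1 and 3) come back as all zeros
```

With amount=0.3 and dim=0, PyTorch does the same ranking over the first dimension of each Conv2d weight and applies a mask, which is why prune.remove is needed afterwards to bake the zeros into the weight tensor.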
Common variations
- Use torch.nn.utils.prune for unstructured pruning (individual weights) or structured pruning (filters/channels).
- Prune models directly in PyTorch before exporting to ONNX for best control.
- Use onnxruntime for inference speed benchmarking before and after pruning.
- For quantization combined with pruning, use onnxruntime.quantization tools.
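For contrast with the structured example above, unstructured (L1 magnitude) pruning zeroes individual weights regardless of which filter they belong to. A stdlib sketch of that idea, not torch's API:

```python
def l1_unstructured_prune(weights, amount):
    """Zero the `amount` fraction of individual weights with the smallest magnitude."""
    n_prune = int(amount * len(weights))
    # Indices of the n_prune smallest weights by absolute value
    to_zero = set(sorted(range(len(weights)), key=lambda i: abs(weights[i]))[:n_prune])
    return [0.0 if i in to_zero else w for i, w in enumerate(weights)]

w = [0.5, -0.01, 2.0, 0.001, -1.5]
print(l1_unstructured_prune(w, amount=0.4))
# the two smallest-magnitude weights (-0.01 and 0.001) are zeroed
```

Unstructured pruning usually preserves accuracy better at the same sparsity, but the resulting irregular zeros only speed up inference on runtimes with sparse-kernel support, whereas removing whole filters shrinks dense compute directly.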
Troubleshooting
- If onnxsim.simplify fails, check model opset compatibility and try upgrading onnx and onnx-simplifier.
- If pruning causes an accuracy drop, reduce the pruning amount or fine-tune the pruned model.
- Ensure pruning re-parametrizations are removed (prune.remove) before exporting to ONNX to avoid export errors.
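A quick way to confirm pruning actually took effect before exporting is to measure the sparsity of the weight tensors. A stdlib sketch of the check; in practice you would iterate over model.named_parameters() (or the ONNX graph's initializers) and flatten each tensor:

```python
def sparsity(weights):
    """Fraction of exactly-zero entries in a flat weight list."""
    if not weights:
        return 0.0
    return sum(1 for w in weights if w == 0.0) / len(weights)

# After 30% structured pruning, Conv2d weights should show roughly 30% zeros
pruned_weights = [0.0, 0.3, 0.0, -1.2, 0.0, 0.0]
print(f"sparsity: {sparsity(pruned_weights):.0%}")  # prints "sparsity: 67%"
```

If sparsity is near zero after you expected pruning, the masks were likely never applied or prune.remove was called on the wrong parameter name.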
Key Takeaways
- Prune models in PyTorch before exporting to ONNX for best pruning control.
- Use structured pruning to remove entire filters or channels, which translates into real speedups on dense hardware.
- Simplify ONNX models after pruning to optimize graph and reduce size.
- Test inference speed and accuracy before and after pruning to validate improvements.