ONNX for cross-platform deployment
Quick answer
Use ONNX to export models from frameworks like PyTorch or TensorFlow into a standardized format that runs efficiently across platforms. Deploy ONNX Runtime on target devices for cross-platform inference with consistent performance and compatibility.
Prerequisites
- Python 3.8+
- pip install onnx onnxruntime
- A trained ML model in PyTorch or TensorFlow
Setup
Install the necessary packages to work with ONNX and ONNX Runtime for model export and inference.
pip install onnx onnxruntime
Step by step
Export a PyTorch model to ONNX format and run inference using ONNX Runtime for cross-platform deployment.
import torch
import onnx
import onnxruntime as ort
import numpy as np

# Define a simple PyTorch model
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(3, 2)

    def forward(self, x):
        return self.linear(x)

# Instantiate and set model to eval mode
model = SimpleModel()
model.eval()

# Create dummy input tensor
dummy_input = torch.randn(1, 3)

# Export the model to ONNX format
onnx_model_path = "simple_model.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_model_path,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=14,
)

# Verify the ONNX model
onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)
print("ONNX model is valid.")

# Run inference with ONNX Runtime
ort_session = ort.InferenceSession(onnx_model_path)

# Prepare input as numpy array
input_data = dummy_input.numpy()

# Run the model
outputs = ort_session.run(None, {"input": input_data})
print("ONNX Runtime output:", outputs[0])
Output
ONNX model is valid.
ONNX Runtime output: [[-0.123456 0.789012]]
Common variations
You can export TensorFlow models to ONNX using the tf2onnx package (the tensorflow-onnx project). ONNX Runtime supports GPU acceleration and mobile platforms such as Android and iOS; install onnxruntime-gpu for GPU inference.
import onnxruntime as ort

# For GPU inference, list the CUDA provider first and keep the CPU provider
# as a fallback in case CUDA is unavailable
ort_session = ort.InferenceSession(
    "simple_model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

# For CPU-only inference
ort_session = ort.InferenceSession("simple_model.onnx", providers=["CPUExecutionProvider"])

print("Providers available:", ort.get_available_providers())
Output
Providers available: ['CUDAExecutionProvider', 'CPUExecutionProvider']
Troubleshooting
- If onnx.checker.check_model() fails, verify your model export parameters and opset version.
- For missing-operator errors, update onnxruntime to the latest version.
- On mobile, use the appropriate ONNX Runtime mobile package and follow the platform-specific integration guides.
Key Takeaways
- Export models to ONNX format for standardized cross-platform compatibility.
- Use ONNX Runtime to run inference efficiently on CPU, GPU, and mobile devices.
- Validate exported models with onnx.checker to avoid runtime errors.
- Choose ONNX Runtime execution providers to optimize performance for your target platform.
- Keep onnxruntime updated to support the latest operators and features.