ONNX supported frameworks
Quick answer
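The main frameworks with ONNX support are PyTorch, TensorFlow, scikit-learn, and XGBoost. PyTorch ships a built-in exporter (torch.onnx.export). TensorFlow, scikit-learn, and XGBoost models are converted with the separate packages tf2onnx, skl2onnx, and onnxmltools. The resulting .onnx files can then be run on any platform where onnxruntime is available.
Prerequisites
- Python 3.8+
- pip install onnx onnxruntime
- pip install torch torchvision (for PyTorch)
- pip install tensorflow tf2onnx (for TensorFlow)
- pip install scikit-learn skl2onnx (for scikit-learn)
- pip install xgboost onnxmltools (for XGBoost)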
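Before running the examples, it can help to confirm which of these packages are actually installed, including the converter packages tf2onnx, skl2onnx, and onnxmltools that the examples import. The sketch below is one possible check, assuming only the standard library plus (optionally) onnxruntime. It reads installed versions with importlib.metadata and lists the ONNX Runtime execution providers available on this machine. The helper names installed_versions and available_providers are hypothetical, chosen for this illustration.

```python
# Hypothetical environment check: report installed versions of the packages used
# in this guide and the ONNX Runtime execution providers available locally.
from importlib import metadata

# PyPI distribution names (the converters are separate packages from the frameworks).
PACKAGES = ["onnx", "onnxruntime", "onnxruntime-gpu", "torch", "torchvision",
            "tensorflow", "tf2onnx", "scikit-learn", "skl2onnx",
            "xgboost", "onnxmltools"]


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def available_providers():
    """Return ONNX Runtime execution providers, or [] if onnxruntime is missing."""
    try:
        import onnxruntime as ort
    except ImportError:
        return []
    return ort.get_available_providers()


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:16s} {version or 'not installed'}")
    print("ONNX Runtime providers:", available_providers())
```

A missing converter shows up as "not installed", which is usually quicker to spot here than as an ImportError halfway through an export script.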
Setup
Install ONNX, ONNX Runtime (used to run .onnx models), each framework, and its converter package.
pip install onnx onnxruntime torch torchvision tensorflow tf2onnx scikit-learn skl2onnx xgboost onnxmltools
Step by step
Export a simple model from each supported framework to ONNX format and run inference with onnxruntime.
import torch
import torchvision.models as models
import onnxruntime as ort
import numpy as np
# PyTorch example: export ResNet18 to ONNX
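# `pretrained=True` is deprecated since torchvision 0.13; use the weights enum instead
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)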
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
onnx_path = "resnet18.onnx"
torch.onnx.export(model, dummy_input, onnx_path, opset_version=13)
# Load and run inference with ONNX Runtime
ort_session = ort.InferenceSession(onnx_path)
inputs = {ort_session.get_inputs()[0].name: dummy_input.numpy()}
outputs = ort_session.run(None, inputs)
print(f"PyTorch ONNX output shape: {outputs[0].shape}")
# TensorFlow example: export a simple model to ONNX
import tensorflow as tf
import tf2onnx
# Create a simple TF model
class SimpleModel(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def __call__(self, x):
        return tf.math.square(x)
model = SimpleModel()
# Convert to ONNX
spec = (tf.TensorSpec((None, 3), tf.float32, name="input"),)
output_path = "simple_tf_model.onnx"
model_proto, _ = tf2onnx.convert.from_function(
model.__call__, input_signature=spec, opset=13, output_path=output_path)
# Run inference
ort_session_tf = ort.InferenceSession(output_path)
input_data = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
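# tf2onnx may not keep the input name exactly as "input", so read it from the session
outputs_tf = ort_session_tf.run(None, {ort_session_tf.get_inputs()[0].name: input_data})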
print(f"TensorFlow ONNX output: {outputs_tf[0]}")
# scikit-learn example: export a logistic regression model to ONNX
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
iris = load_iris()
X, y = iris.data, iris.target
model_sk = LogisticRegression(max_iter=1000)
model_sk.fit(X, y)
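initial_type = [("input", FloatTensorType([None, X.shape[1]]))]
# zipmap=False returns probabilities as an (n_samples, n_classes) array instead of a list of dicts
onnx_model = convert_sklearn(model_sk, initial_types=initial_type,
                             options={id(model_sk): {"zipmap": False}})
with open("logreg.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
# Run inference
ort_session_sk = ort.InferenceSession("logreg.onnx")
inputs_sk = {ort_session_sk.get_inputs()[0].name: X.astype(np.float32)}
# A converted classifier has two outputs: predicted labels, then class probabilities
labels_sk, probs_sk = ort_session_sk.run(None, inputs_sk)
print(f"scikit-learn ONNX label shape: {labels_sk.shape}")
print(f"scikit-learn ONNX probability shape: {probs_sk.shape}")
Output
PyTorch ONNX output shape: (1, 1000)
TensorFlow ONNX output: [[1. 4. 9.]]
scikit-learn ONNX label shape: (150,)
scikit-learn ONNX probability shape: (150, 3)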
Common variations
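You can target a different opset at export time (opset_version in torch.onnx.export, opset in tf2onnx). XGBoost has no built-in ONNX exporter; its models are converted with onnxmltools, as in this example. For inference, onnxruntime runs on CPU by default; GPU execution requires the onnxruntime-gpu package and the CUDAExecutionProvider.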
import xgboost as xgb
import onnxmltools
from onnxmltools.convert import convert_xgboost
from onnxmltools.convert.common.data_types import FloatTensorType
import numpy as np
import onnxruntime as ort
# Train XGBoost model
X = np.random.rand(100, 4).astype(np.float32)
y = np.random.randint(0, 2, 100)
dtrain = xgb.DMatrix(X, label=y)
params = {"objective": "binary:logistic", "eval_metric": "logloss"}
model_xgb = xgb.train(params, dtrain, num_boost_round=10)
# Convert to ONNX
initial_type = [("input", FloatTensorType([None, 4]))]
onnx_model_xgb = convert_xgboost(model_xgb, initial_types=initial_type)
with open("xgboost.onnx", "wb") as f:
    f.write(onnx_model_xgb.SerializeToString())
# Inference
ort_session_xgb = ort.InferenceSession("xgboost.onnx")
inputs_xgb = {ort_session_xgb.get_inputs()[0].name: X}
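# A binary:logistic booster is converted as a classifier: predicted labels, then class probabilities
labels_xgb, probs_xgb = ort_session_xgb.run(None, inputs_xgb)
print(f"XGBoost ONNX label shape: {labels_xgb.shape}")
print(f"XGBoost ONNX probability shape: {probs_xgb.shape}")
Output
XGBoost ONNX label shape: (100,)
XGBoost ONNX probability shape: (100, 2)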
Troubleshooting
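- If you see opset-related errors (for example, the runtime rejecting the model's opset), upgrade onnx and onnxruntime to recent versions, or re-export with a lower opset_version.
- Model export failures often stem from unsupported operators; check the exporter's or converter's documentation for the operators it supports at your target opset.
- For GPU inference, install onnxruntime-gpu, make sure your CUDA and cuDNN versions match that release's requirements, and pass providers=["CUDAExecutionProvider", "CPUExecutionProvider"] when creating the InferenceSession.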
Key Takeaways
- ONNX enables model interoperability across major ML frameworks like PyTorch, TensorFlow, and scikit-learn.
- Use onnxruntime for efficient cross-platform inference on CPU and GPU.
- Export models with a compatible opset_version to avoid conversion errors.
- Framework-specific exporters and converters such as torch.onnx, tf2onnx, skl2onnx, and onnxmltools simplify ONNX conversion.
- Troubleshoot by updating packages and verifying operator support in your target ONNX opset.