How to track hyperparameters with MLflow
Quick answer
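Use mlflow.log_params() to record hyperparameters during your MLflow training runs. It takes a dictionary of parameter names and values and stores them with the active run, so you can track, compare, and reproduce experiments in the MLflow UI or through the API.
Prerequisites
- Python 3.8+
- pip install "mlflow>=2.0" (quote the version specifier so the shell doesn't treat > as a redirect)
- scikit-learn, for the example below
- Basic familiarity with Python machine learning scripts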
Setup
Install MLflow and set up your environment to enable experiment tracking.
pip install mlflow scikit-learn
Step by step
This example demonstrates how to track hyperparameters and metrics in MLflow during a simple model training run.
import mlflow
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Load data
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
# Define hyperparameters
params = {
"n_estimators": 100,
"max_depth": 3,
"random_state": 42
}
# Start MLflow run
with mlflow.start_run():
    # Log all hyperparameters in a single call
    mlflow.log_params(params)
    # Train the model with the same hyperparameters that were logged
    model = RandomForestClassifier(**params)
    model.fit(X_train, y_train)
    # Predict and evaluate
    preds = model.predict(X_test)
    acc = accuracy_score(y_test, preds)
    # Log the evaluation metric
    mlflow.log_metric("accuracy", acc)
    print(f"Logged params: {params}")
    print(f"Logged accuracy: {acc}")
Output
Logged params: {'n_estimators': 100, 'max_depth': 3, 'random_state': 42}
Logged accuracy: 1.0
Common variations
You can also log hyperparameters asynchronously or send them to a different MLflow tracking server. Use mlflow.log_param() to log a single parameter and mlflow.log_params() to log several at once. For more advanced use, integrate MLflow with frameworks such as PyTorch or TensorFlow.
import mlflow
with mlflow.start_run():
    # Log a single hyperparameter
    mlflow.log_param("learning_rate", 0.01)
    # Log multiple hyperparameters from a dictionary
    mlflow.log_params({"batch_size": 32, "epochs": 10})
Troubleshooting
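If you don't see parameters in the MLflow UI, check that you are logging inside the mlflow.start_run() context you expect: if no run is active, MLflow starts one automatically, so parameters can land in a separate run you weren't looking at. Also verify that your tracking URI is set correctly with mlflow.set_tracking_uri() before the run starts if you are using a remote server; otherwise the run may be logged to a default local location instead.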
import mlflow
# Set the tracking URI before starting a run when using a remote server
mlflow.set_tracking_uri("http://your-mlflow-server:5000")
Key Takeaways
- Use mlflow.log_params() to track multiple hyperparameters in one call.
- Wrap logging calls in mlflow.start_run() so parameters and metrics are recorded in the run you intend.
- Set the MLflow tracking URI properly when using remote or cloud tracking servers.