How to do hyperparameter tuning with GridSearchCV
Quick answer
Use GridSearchCV from scikit-learn to tune hyperparameters of PyTorch models by wrapping the model in a sklearn.base.BaseEstimator-compatible class. Define the parameter grid and fit GridSearchCV to find the best hyperparameters automatically.
Prerequisites
Python 3.8+
pip install torch scikit-learn numpy
Setup
Install the required packages with pip install torch scikit-learn numpy. Import the necessary modules, including torch, sklearn.model_selection.GridSearchCV, and sklearn.base.BaseEstimator, to create a PyTorch model wrapper compatible with scikit-learn.
pip install torch scikit-learn numpy
Step by step
Create a PyTorch model class and wrap it in a BaseEstimator and ClassifierMixin subclass to implement fit, predict, and score methods. Define a parameter grid for hyperparameters like learning rate and number of epochs. Use GridSearchCV to search over the grid and fit on training data.
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score

# Define a simple PyTorch model
class SimpleNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=10):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 2)  # binary classification

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Wrapper to use a PyTorch model with sklearn
class PyTorchClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, input_dim=20, hidden_dim=10, lr=0.01, epochs=10, batch_size=32):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._build_model()

    def _build_model(self):
        self.model = SimpleNet(self.input_dim, self.hidden_dim).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def fit(self, X, y):
        self._build_model()  # rebuild so each fit starts from fresh weights
        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
        y_tensor = torch.tensor(y, dtype=torch.long).to(self.device)
        dataset = torch.utils.data.TensorDataset(X_tensor, y_tensor)
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        self.model.train()
        for epoch in range(self.epochs):
            for xb, yb in loader:
                self.optimizer.zero_grad()
                outputs = self.model(xb)
                loss = self.criterion(outputs, yb)
                loss.backward()
                self.optimizer.step()
        return self

    def predict(self, X):
        self.model.eval()
        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
        with torch.no_grad():
            outputs = self.model(X_tensor)
            _, preds = torch.max(outputs, 1)
        return preds.cpu().numpy()

    def score(self, X, y):
        preds = self.predict(X)
        return accuracy_score(y, preds)

# Generate synthetic data
X, y = make_classification(n_samples=500, n_features=20, n_informative=15, n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define parameter grid
param_grid = {
    'hidden_dim': [10, 20],
    'lr': [0.01, 0.001],
    'epochs': [10, 20]
}

# Initialize GridSearchCV
clf = GridSearchCV(PyTorchClassifier(input_dim=20), param_grid, cv=3, scoring='accuracy', verbose=1)

# Fit GridSearchCV
clf.fit(X_train, y_train)

# Best parameters and score
print('Best params:', clf.best_params_)
print('Best CV accuracy:', clf.best_score_)

# Evaluate on test set
test_acc = clf.score(X_test, y_test)
print('Test accuracy:', test_acc)
Output
Fitting 3 folds for each of 8 candidates, totalling 24 fits
Best params: {'epochs': 20, 'hidden_dim': 20, 'lr': 0.001}
Best CV accuracy: 0.91
Test accuracy: 0.92
Common variations
- Use RandomizedSearchCV for faster hyperparameter tuning with random sampling.
- Wrap more complex PyTorch models, including dropout or batch norm layers.
- Use GPU acceleration by setting device=torch.device('cuda') in the wrapper.
- Integrate early stopping by modifying the fit method.
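As a sketch of the first variation: RandomizedSearchCV samples a fixed number of candidates from distributions instead of exhausting every grid combination. The snippet below uses a scikit-learn SGDClassifier as a lightweight stand-in so it runs without PyTorch; the same call works unchanged with the PyTorchClassifier wrapper above, with keys like 'lr' and 'epochs' in the distributions dict.

```python
from scipy.stats import loguniform
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import RandomizedSearchCV

X, y = make_classification(n_samples=300, n_features=20, random_state=42)

# loguniform spreads sampled values evenly across orders of magnitude,
# which suits learning-rate-like hyperparameters.
param_distributions = {
    'alpha': loguniform(1e-5, 1e-1),
    'max_iter': [500, 1000],
}

# n_iter=5 tries only 5 sampled candidates instead of a full grid
search = RandomizedSearchCV(
    SGDClassifier(random_state=42),
    param_distributions,
    n_iter=5,
    cv=3,
    scoring='accuracy',
    random_state=42,
)
search.fit(X, y)
print(search.best_params_)
```

With a continuous distribution like loguniform, every candidate draws a fresh value, so the search covers the range more finely than a hand-picked grid of the same size.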
Troubleshooting
- If you get shape mismatch errors, verify that input feature dimensions match input_dim.
- If training is slow, reduce epochs or batch_size.
- Ensure fit rebuilds the model to reset weights for each hyperparameter combination.
- Check CUDA availability and device assignment if GPU use is intended.
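The rebuild requirement exists because GridSearchCV clones the estimator before every fit: clone() copies only the constructor parameters exposed via get_params, not any fitted state. A toy estimator (for illustration only, not the wrapper above) makes this visible:

```python
from sklearn.base import BaseEstimator, clone

class Toy(BaseEstimator):
    def __init__(self, lr=0.01):
        self.lr = lr  # constructor params are what clone() copies

    def fit(self, X, y):
        self.weights_ = [0.0]  # fitted state must be (re)created in fit
        return self

est = Toy(lr=0.5).fit([[1.0]], [0])
cloned = clone(est)

print(cloned.lr)                    # 0.5 -- the hyperparameter survives
print(hasattr(cloned, 'weights_'))  # False -- the fitted state does not
```

This is why the wrapper calls _build_model() at the top of fit: any model, criterion, or optimizer built earlier would otherwise leak stale weights between hyperparameter candidates.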
Key Takeaways
- Wrap PyTorch models in sklearn-compatible classes to use GridSearchCV.
- Define a clear parameter grid for hyperparameters like learning rate and epochs.
- Rebuild the model inside fit to reset weights for each parameter set.
- Use GridSearchCV's cv and scoring parameters to control validation and metrics.
- Test on a holdout set after tuning to confirm generalization performance.
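After tuning, GridSearchCV exposes more than best_params_: cv_results_ holds per-candidate mean and std scores, and best_estimator_ is the winner refit on the full training data (refit=True by default). A short sketch, again with a scikit-learn model standing in for the PyTorch wrapper so it runs without torch:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = make_classification(n_samples=200, n_features=20, random_state=42)

grid = GridSearchCV(
    LogisticRegression(max_iter=500),
    {'C': [0.1, 1.0, 10.0]},
    cv=3,
    scoring='accuracy',
)
grid.fit(X, y)

# Mean and std CV accuracy for every candidate, not just the best one
for params, mean, std in zip(grid.cv_results_['params'],
                             grid.cv_results_['mean_test_score'],
                             grid.cv_results_['std_test_score']):
    print(params, f'{mean:.3f} +/- {std:.3f}')

# best_estimator_ is already refit on all of X, ready to predict
print(grid.best_estimator_.predict(X[:3]))
```

Comparing the std across candidates helps judge whether the "best" configuration is genuinely better or within noise of its neighbors.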