What is overfitting in machine learning?
Overfitting in machine learning occurs when a model learns the training data too well, capturing noise and details that hurt its ability to generalize to new data. It results in low training error but high validation or test error, indicating poor performance on unseen data.

How it works
Overfitting happens when a model becomes excessively complex, learning not only the underlying patterns but also the random noise in the training dataset. Imagine memorizing answers to a practice test instead of understanding the concepts; the model performs well on the practice test but poorly on the actual exam. This reduces its ability to generalize to new, unseen data.
Concrete example
The following PyTorch example trains a simple neural network on synthetic data and shows how overfitting manifests as a divergence between training and validation loss.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# Generate synthetic data: a cubic function with noise
torch.manual_seed(0)
X = torch.linspace(-1, 1, 100).unsqueeze(1)
y = X.pow(3) + 0.1 * torch.randn(X.size())

# Split into train and validation sets
train_X, val_X = X[:80], X[80:]
train_y, val_y = y[:80], y[80:]
train_ds = TensorDataset(train_X, train_y)
val_ds = TensorDataset(val_X, val_y)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=16)

# Define a simple over-parameterized model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
        )

    def forward(self, x):
        return self.net(x)

model = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

train_losses = []
val_losses = []
for epoch in range(100):
    # Training pass
    model.train()
    train_loss = 0.0
    for xb, yb in train_loader:
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_loader.dataset)
    train_losses.append(train_loss)

    # Validation pass (no gradient tracking)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_loader:
            preds = model(xb)
            loss = criterion(preds, yb)
            val_loss += loss.item() * xb.size(0)
    val_loss /= len(val_loader.dataset)
    val_losses.append(val_loss)

# Plot the diverging loss curves
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.legend()
plt.title('Overfitting example: Training vs Validation Loss')
plt.show()
```

When to use it
Overfitting is a problem to avoid rather than a technique to apply. Watch for it whenever you train on limited or noisy data and your goal is robust performance on unseen inputs rather than perfect training accuracy. Common countermeasures include regularization, early stopping, simpler models, and collecting more data.
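Two of the countermeasures above are one-liners in PyTorch. The sketch below is illustrative (the model, learning rate, and penalty strengths are arbitrary choices, not tuned values): `weight_decay` on the optimizer adds an L2 penalty on the weights, and `nn.Dropout` randomly zeroes activations during training but is a no-op in evaluation mode.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Illustrative model; architecture and hyperparameters are arbitrary
model = nn.Sequential(nn.Linear(1, 100), nn.ReLU(), nn.Linear(100, 1))

# L2 regularization: weight_decay penalizes large weights at each step
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)

# Dropout: randomly zeroes activations while in train() mode
regularized = nn.Sequential(
    nn.Linear(1, 100),
    nn.ReLU(),
    nn.Dropout(p=0.5),  # active only during training
    nn.Linear(100, 1),
)
regularized.eval()  # dropout becomes the identity at evaluation time
```

Because dropout is disabled in `eval()` mode, repeated forward passes on the same input are deterministic at inference time.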
Key terms
| Term | Definition |
|---|---|
| Overfitting | Model fits training data too closely, capturing noise and losing generalization. |
| Generalization | Model's ability to perform well on unseen data. |
| Regularization | Techniques to reduce overfitting by constraining model complexity. |
| Early stopping | Halting training when validation loss stops improving to prevent overfitting. |
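The early-stopping rule in the table can be sketched as a small helper. The patience value and the validation-loss sequence below are illustrative assumptions, not outputs of the example above; real training loops typically also restore the best checkpoint when stopping.

```python
def train_with_early_stopping(val_losses, patience=5):
    """Return the epoch index at which training would stop, given a
    sequence of per-epoch validation losses (illustrative sketch)."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch  # stop: validation loss stopped improving
    return len(val_losses) - 1  # patience never exhausted

# Validation loss improves, then rises -- the classic overfitting curve
losses = [1.0, 0.8, 0.6, 0.5, 0.55, 0.6, 0.7, 0.8, 0.9, 1.0]
stop_epoch = train_with_early_stopping(losses, patience=3)  # stops at epoch 6
```

With `patience=3`, the best loss (0.5) occurs at epoch 3, and three consecutive non-improving epochs trigger a stop at epoch 6, well before the curve's worst values.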
Key takeaways
- Overfitting causes low training error but high validation error, indicating poor generalization.
- Monitor training and validation loss curves to detect overfitting during model training.
- Use simpler models, regularization, or early stopping to prevent overfitting in PyTorch.
- Overfitting is a critical issue when working with small or noisy datasets.
- Understanding overfitting helps build models that perform well on real-world data.