How to fine-tune a pretrained model in PyTorch
Quick answer
To fine-tune a pretrained model in PyTorch, load the model with pretrained weights, replace the final layer(s) to match your task, freeze or unfreeze layers as needed, then train on your dataset with a suitable optimizer and loss function. Use torchvision.models or Hugging Face transformers for pretrained models and customize the training loop accordingly.
Prerequisites
- Python 3.8+
- pip install torch torchvision
- Basic knowledge of PyTorch tensors and training loops
Setup
Install torch and torchvision if not already installed. Import necessary modules and set device for GPU acceleration if available.
pip install torch torchvision
Step by step
This example fine-tunes a pretrained ResNet18 model on a custom dataset with 2 classes. It replaces the final fully connected layer, freezes the pretrained layers initially, then trains the model for a few epochs.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Data transforms and loaders
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Replace with your dataset path
train_dataset = datasets.FakeData(size=64, num_classes=2, transform=transform)  # FakeData stand-in for demo; 2 classes to match the new head
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
# Load pretrained model
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # 'pretrained=True' is deprecated
# Freeze all layers
for param in model.parameters():
    param.requires_grad = False
# Replace the final layer for 2 classes
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)
model = model.to(device)
# Only parameters of final layer are being optimized
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
# Training loop
model.train()
for epoch in range(3):  # 3 epochs for demo
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')
Output
Epoch 1, Loss: 0.7102
Epoch 2, Loss: 0.6905
Epoch 3, Loss: 0.6801
Common variations
- Unfreeze some pretrained layers for fine-tuning after initial training.
- Use different pretrained models such as resnet50, efficientnet, or Hugging Face transformers models.
- Use learning rate schedulers to improve training.
- Implement validation loop and save best model checkpoints.
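The scheduler variation can be sketched like this (a minimal example with a stand-in linear model so the snippet is self-contained; the same two scheduler lines slot into the ResNet training loop above):

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)  # stand-in for the fine-tuned model
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Decay the learning rate by 10x every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

for epoch in range(8):
    # ... run the training pass for one epoch here ...
    optimizer.step()   # placeholder step; in practice this runs per batch
    scheduler.step()   # step once per epoch, after optimizer.step()
    print(f"Epoch {epoch+1}, lr = {optimizer.param_groups[0]['lr']:.6f}")
```

After the seventh epoch the learning rate drops from 0.001 to 0.0001, which often helps training converge once the new head has roughly fit the data.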
import torch
from torchvision import models
# Load pretrained ResNet50
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)  # 'pretrained=True' is deprecated
# Unfreeze last two layers
for name, param in model.named_parameters():
    if 'layer4' in name or 'fc' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False
# Replace final layer
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, 10) # For 10 classes
# Optimize only the trainable parameters, e.g.
# optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
# then continue with the loss and training loop as before
Troubleshooting
- If you get shape mismatch errors, verify the final layer output size matches your dataset classes.
- If training loss does not decrease, try lowering learning rate or unfreezing more layers.
- Use model.train() and model.eval() modes correctly during training and validation.
- Ensure input images are resized and normalized as expected by the pretrained model.
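A validation loop that switches modes correctly and keeps the best checkpoint can be sketched as follows (a minimal example using random tensors and a tiny linear model as stand-ins for the real dataset and network):

```python
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins: random features/labels and a tiny model
val_data = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
val_loader = DataLoader(val_data, batch_size=16)
model = nn.Linear(8, 2)

best_acc, best_state = 0.0, None
for epoch in range(3):
    model.train()
    # ... training pass over train_loader goes here ...

    model.eval()  # disable dropout and batchnorm updates for validation
    correct = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
    acc = correct / len(val_loader.dataset)

    if acc >= best_acc:  # keep a copy of the best weights
        best_acc, best_state = acc, copy.deepcopy(model.state_dict())
        # torch.save(best_state, 'best_model.pt') to persist to disk

print(f'Best val accuracy: {best_acc:.3f}')
```

Forgetting model.eval() during validation is a common source of noisy or misleading validation metrics, since dropout stays active and batchnorm keeps updating its running statistics.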
Key Takeaways
- Load pretrained models from torchvision or Hugging Face and replace final layers for your task.
- Freeze pretrained layers initially to train only new layers, then optionally unfreeze for fine-tuning.
- Use proper data transforms matching the pretrained model's expected input normalization.
- Implement a training loop with optimizer and loss, monitoring loss to ensure learning.
- Adjust learning rate and layer freezing if training stagnates or overfits.