What is transfer learning in deep learning?
Transfer learning is a deep learning technique in which a model trained on a large dataset is reused and fine-tuned on a smaller, related dataset to improve performance and reduce training time. It leverages the features learned on the original task to accelerate learning on a new task.
How it works
Transfer learning works by taking a neural network trained on a large dataset (source task) and adapting it to a new, often smaller dataset (target task). The early layers of the network, which learn general features like edges and textures, are kept mostly intact, while the later layers are fine-tuned to the specifics of the new task. This is similar to how humans apply prior knowledge to learn new skills faster.
Concrete example
Here is a PyTorch example using transfer learning with a pre-trained ResNet model for image classification on a new dataset:
```python
import torch
import torch.nn as nn
import torchvision.models as models

# Load a ResNet-18 pre-trained on ImageNet
# (the `weights` argument replaces the deprecated `pretrained=True`)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer for 10 classes;
# the new layer's parameters have requires_grad=True by default
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)

# Only parameters of the final layer will be updated
params_to_update = [p for p in model.parameters() if p.requires_grad]
print("Parameters to update:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
```

Output:

```
Parameters to update:
fc.weight
fc.bias
```
When to use it
Use transfer learning when you have limited labeled data for a new task but a related large dataset exists for pre-training. It speeds up training and improves accuracy. Avoid it when the new task is very different from the original task or when you have abundant data to train from scratch.
Key Takeaways
- Transfer learning leverages pre-trained models to reduce training time and improve performance on new tasks.
- Freezing early layers and fine-tuning later layers is a common transfer learning strategy in PyTorch.
- Use transfer learning when new task data is limited but related to the pre-trained model's domain.