Concept beginner · 3 min read

What is transfer learning in deep learning?

Quick answer
Transfer learning is a deep learning technique in which a model pre-trained on a large dataset is reused and fine-tuned on a smaller, related dataset. By leveraging features learned on the original task, it reduces training time and data requirements while often improving performance on the new task.

How it works

Transfer learning works by taking a neural network trained on a large dataset (source task) and adapting it to a new, often smaller dataset (target task). The early layers of the network, which learn general features like edges and textures, are kept mostly intact, while the later layers are fine-tuned to the specifics of the new task. This is similar to how humans apply prior knowledge to learn new skills faster.

Concrete example

Here is a PyTorch example using transfer learning with a pre-trained ResNet model for image classification on a new dataset:

python
import torch
import torch.nn as nn
import torchvision.models as models

# Load a pre-trained ResNet18 model (newer torchvision releases deprecate
# pretrained=True in favor of the weights argument)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer for 10 classes
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)

# Only parameters of the final layer will be updated
params_to_update = [p for p in model.parameters() if p.requires_grad]

print("Parameters to update:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
output
Parameters to update:
fc.weight
fc.bias
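To actually train the replaced head, pass only the trainable parameters to the optimizer; the frozen backbone then keeps its pre-trained weights. Here is a minimal sketch of that training step, using a small stand-in model in place of ResNet18 so it runs without downloading weights (the layer sizes are illustrative, not part of the original example):

```python
import torch
import torch.nn as nn

# Stand-in for a pre-trained network: an "early" layer plus a task head.
model = nn.Sequential(
    nn.Linear(8, 16),   # plays the role of the pre-trained backbone
    nn.ReLU(),
    nn.Linear(16, 10),  # plays the role of the new 10-class head
)

# Freeze everything, then re-enable gradients on the head only.
for param in model.parameters():
    param.requires_grad = False
for param in model[2].parameters():
    param.requires_grad = True

# The optimizer sees only the trainable (head) parameters.
params_to_update = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params_to_update, lr=0.01)

# One training step: frozen layers receive no gradient and stay unchanged.
x, y = torch.randn(4, 8), torch.randint(0, 10, (4,))
frozen_before = model[0].weight.clone()
loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()
optimizer.step()

assert torch.equal(model[0].weight, frozen_before)  # backbone untouched
```

The same pattern applies unchanged to the ResNet18 example above: construct the optimizer from `params_to_update` so that only `fc.weight` and `fc.bias` are updated.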

When to use it

Use transfer learning when you have limited labeled data for a new task but a related large dataset exists for pre-training. It speeds up training and improves accuracy. Avoid it when the new task is very different from the original task or when you have abundant data to train from scratch.
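When the target task is only loosely related to the source task, a common middle ground is to fine-tune the whole network but give the pre-trained layers a much smaller learning rate than the new head. A sketch using PyTorch optimizer parameter groups (the backbone/head split and learning rates here are illustrative assumptions, not values from the example above):

```python
import torch
import torch.nn as nn

# Hypothetical two-part model: a pre-trained backbone and a fresh head.
backbone = nn.Linear(8, 16)
head = nn.Linear(16, 10)

# Parameter groups: gentle updates for pre-trained weights,
# faster updates for the randomly initialized head.
optimizer = torch.optim.SGD([
    {"params": backbone.parameters(), "lr": 1e-4},
    {"params": head.parameters(), "lr": 1e-2},
])
```

This keeps the useful pre-trained features from being overwritten early in training while still letting the whole network adapt to the new task.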

Key Takeaways

  • Transfer learning leverages pre-trained models to reduce training time and improve performance on new tasks.
  • Freezing early layers and fine-tuning later layers is a common transfer learning strategy in PyTorch.
  • Use transfer learning when new task data is limited but related to the pre-trained model's domain.
Verified 2026-04 · resnet18