How to use a pretrained model in PyTorch
Quick answer

Use torchvision.models or other PyTorch model libraries to load pretrained models via the weights= argument (the older pretrained=True flag is deprecated). Then run inference by passing input tensors through the model, or fine-tune by replacing the final layers and training on your dataset.

Prerequisites

- Python 3.8+
- pip install torch torchvision
Setup
Install PyTorch and torchvision to access pretrained models. Use the following command to install the latest stable versions:
```shell
pip install torch torchvision
```

Step by step
This example loads a pretrained ResNet-18 model, prepares a dummy input tensor, runs inference, and prints the output tensor shape.
```python
import torch
import torchvision.models as models

# Load pretrained ResNet-18 model
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval()  # Set to evaluation mode

# Create a dummy input tensor (batch_size=1, channels=3, height=224, width=224)
input_tensor = torch.randn(1, 3, 224, 224)

# Run inference
with torch.no_grad():
    output = model(input_tensor)

print('Output tensor shape:', output.shape)
```

Output

```
Output tensor shape: torch.Size([1, 1000])
```
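The 1000 output values are raw logits over the ImageNet classes. A minimal sketch of turning them into probabilities and top-5 predictions, using a random tensor as a stand-in for the model output:

```python
import torch

# Stand-in for the [1, 1000] logits the model returns
# (in practice, use the `output` tensor from the inference step above)
logits = torch.randn(1, 1000)

# Softmax converts logits to probabilities that sum to 1
probs = torch.softmax(logits, dim=1)

# Top-5 predicted class indices and their probabilities, sorted descending
top5_probs, top5_idx = torch.topk(probs, k=5, dim=1)
print(top5_idx)
```

The indices returned by torch.topk map into the ImageNet label list for the weights you loaded.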
Common variations
To fine-tune a pretrained model, replace the final classification layer and train on your dataset. You can also load pretrained weights for other architectures like VGG, DenseNet, or custom models.
```python
import torch.nn as nn

# Replace final layer for 10 classes
model.fc = nn.Linear(model.fc.in_features, 10)
```
```python
# Example: switch to VGG16 pretrained model
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
vgg16.eval()
```

Troubleshooting
- If you get a warning about the deprecated pretrained=True argument, use weights=models.ResNet18_Weights.DEFAULT instead.
- Ensure the input tensor shape matches the model's requirements (usually 3x224x224 for ImageNet models).
- Set the model to eval() mode for inference to disable dropout and keep batch norm statistics fixed.
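Why eval() matters can be seen with dropout alone; a small sketch:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 10)

# In eval mode, dropout becomes a no-op, so outputs are deterministic
drop.eval()
out = drop(x)
print(torch.equal(out, x))  # True: input passes through unchanged
```

In train mode the same module would randomly zero about half the values on every call, so forgetting eval() makes inference results vary from run to run.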
Key Takeaways
- Load torchvision's pretrained models with the weights= parameter (pretrained=True is deprecated).
- Always set the model to eval() mode before inference to get consistent results.
- Modify the final layer to fine-tune pretrained models on your own classes.
- Input tensors must have the correct shape and normalization for the pretrained model.
- Use torch.no_grad() context to save memory and computation during inference.