# How to load a PyTorch model
## Quick answer

Use `torch.load` to read a saved PyTorch model file and `model.load_state_dict` to restore the weights into a model instance. Note that a state dict holds only the parameters, not the architecture, so you must re-create the model class before loading; you can then use the model for inference or further training.

## Prerequisites

- Python 3.8+
- `pip install "torch>=2.0"`
## Setup

Install PyTorch if it is not already installed, using the official installation command for your environment:

```shell
pip install torch torchvision
```

## Step by step
This example shows how to define a simple model, save its weights, and then load them back using `torch.load` and `load_state_dict`.
```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

# Instantiate the model
model = SimpleModel()

# Save the model state dict
PATH = "simple_model.pth"
torch.save(model.state_dict(), PATH)

# Load the state dict into a new model instance
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load(PATH))
loaded_model.eval()

# Test the loaded model with a dummy input
input_tensor = torch.randn(1, 10)
output = loaded_model(input_tensor)
print("Output shape:", output.shape)
```

Output:

```
Output shape: torch.Size([1, 2])
```
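The same pattern extends to resuming training: instead of saving only the model's state dict, save a checkpoint dictionary that also holds the optimizer state and the current epoch. This is a common convention, not a required API; the key names below (`model_state`, `optimizer_state`, `epoch`) are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

# Same toy model as above
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Save a checkpoint: model weights, optimizer state, and training progress
checkpoint = {
    "epoch": 5,  # example value; in practice, the epoch you stopped at
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint.pth")

# Restore everything into fresh objects to resume training
resumed_model = SimpleModel()
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.01)

checkpoint = torch.load("checkpoint.pth")
resumed_model.load_state_dict(checkpoint["model_state"])
resumed_optimizer.load_state_dict(checkpoint["optimizer_state"])
start_epoch = checkpoint["epoch"] + 1
print("Resuming from epoch", start_epoch)
```

Restoring the optimizer state matters for optimizers with momentum or adaptive statistics (SGD with momentum, Adam), which would otherwise restart from scratch.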
## Common variations

- Load the entire model object directly with `torch.load(PATH)` if it was saved with `torch.save(model)`, but this is less flexible and not recommended.
- Use the `map_location` argument of `torch.load` to load models saved on a GPU onto a CPU.
- For models from `torchvision` or Hugging Face, use their specific loading utilities.

```python
device = torch.device('cpu')
# Load a model saved on GPU onto the CPU
loaded_model.load_state_dict(torch.load(PATH, map_location=device))
```

## Troubleshooting
- If you get a `RuntimeError: size mismatch`, ensure the model architecture matches the saved weights.
- If `torch.load` fails with a CUDA error, use `map_location='cpu'` to load on the CPU.
- Always call `model.eval()` before inference to put dropout and batchnorm layers into evaluation mode.
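When keys or shapes disagree, it helps to see exactly which entries mismatch before deciding how to proceed. `load_state_dict` accepts `strict=False`, which loads whatever matches and returns the lists of missing and unexpected keys. The two deliberately mismatched architectures below are made up for illustration.

```python
import torch
import torch.nn as nn

# An "old" architecture whose weights were saved earlier
class OldModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

# A "new" architecture with an extra layer the old weights lack
class NewModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)
        self.head = nn.Linear(2, 1)

old_state = OldModel().state_dict()

# strict=False loads the matching keys and reports the rest
result = NewModel().load_state_dict(old_state, strict=False)
print("Missing keys:", result.missing_keys)        # layers with no saved weights
print("Unexpected keys:", result.unexpected_keys)  # saved weights with no layer
```

Here `missing_keys` lists the new `head` parameters; a genuine shape mismatch (e.g. a `Linear` with different dimensions) still raises an error even with `strict=False`, so that case requires fixing the model definition itself.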
## Key Takeaways

- Use `torch.save(model.state_dict())` and `model.load_state_dict(torch.load(PATH))` for flexible model saving and loading.
- Always match the model class definition when loading weights to avoid size-mismatch errors.
- Use `map_location` in `torch.load` to handle device differences between saving and loading.
- Call `model.eval()` before inference to ensure correct behavior of layers like dropout.
- Avoid saving the entire model object directly, for better portability and version compatibility.
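One closing note on safety: `torch.load` deserializes with pickle, so loading an untrusted file can execute arbitrary code. Recent PyTorch versions accept a `weights_only=True` argument (and newer releases default to it) that restricts deserialization to tensors and plain containers; a minimal sketch:

```python
import torch

# Save a plain dict of tensors
tensors = {"weight": torch.randn(2, 10)}
torch.save(tensors, "weights_only_demo.pth")

# weights_only=True refuses to unpickle arbitrary Python objects,
# which protects against maliciously crafted model files
restored = torch.load("weights_only_demo.pth", weights_only=True)
print(restored["weight"].shape)
```

Since state dicts are just dictionaries of tensors, the standard `state_dict` workflow in this article works unchanged under `weights_only=True`.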