How-to · Beginner · 3 min read

How to load a PyTorch model

Quick answer
Use torch.load to read a saved PyTorch file and model.load_state_dict to restore the weights into a model instance. A state dict holds only the parameters, not the architecture, so you instantiate your model class first and then load the weights into it for inference or further training.

Prerequisites

  • Python 3.8+
  • pip install torch>=2.0

Setup

Install PyTorch if not already installed. Use the official PyTorch installation command for your environment.

bash
pip install torch torchvision
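To confirm the install succeeded, you can print the installed version and check for GPU support:

```python
import torch

# Print the installed PyTorch version and whether a CUDA GPU is visible
print(torch.__version__)
print(torch.cuda.is_available())
```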

Step by step

This example shows how to define a simple model, save it, and then load it back using torch.load and load_state_dict.

python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

# Instantiate and save the model
model = SimpleModel()

# Save the model state dict
PATH = "simple_model.pth"
torch.save(model.state_dict(), PATH)

# Load the model state dict into a new model instance
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load(PATH))
loaded_model.eval()

# Test loaded model with dummy input
input_tensor = torch.randn(1, 10)
output = loaded_model(input_tensor)
print("Output shape:", output.shape)
Output:
Output shape: torch.Size([1, 2])
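If the goal is to resume training rather than just run inference, a common pattern is to save a checkpoint dict that bundles the optimizer state and metadata alongside the weights. A minimal sketch (filename and keys are illustrative, not a fixed convention):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Save a training checkpoint: weights plus optimizer state and metadata
torch.save({
    "epoch": 5,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, "checkpoint.pth")

# Restore both to continue training where you left off
ckpt = torch.load("checkpoint.pth")
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
start_epoch = ckpt["epoch"] + 1
```

Because the checkpoint is a plain dict, you can add whatever else you need (loss history, scheduler state) without changing the loading code structure.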

Common variations

  • Load the entire model object with torch.load(PATH) if it was saved via torch.save(model); this pickles the full class, which ties the file to your exact code layout, so it is less portable and not recommended.
  • Use map_location argument in torch.load to load models saved on GPU to CPU.
  • For models from torchvision or Hugging Face, use their specific loading utilities.

python
device = torch.device('cpu')
# Load model saved on GPU to CPU
loaded_model.load_state_dict(torch.load(PATH, map_location=device))
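Since torch.load unpickles arbitrary Python objects by default in older releases, PyTorch also accepts a weights_only argument (available since roughly 1.13, and the default in recent 2.x releases) that restricts loading to tensors and plain containers. A short sketch with an illustrative filename:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
torch.save(model.state_dict(), "linear.pth")

# weights_only=True restricts unpickling to tensor data and simple
# containers, which is safer when the file comes from an untrusted source
state = torch.load("linear.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state)
```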

Troubleshooting

  • If you get a RuntimeError: size mismatch, ensure the model architecture matches the saved weights.
  • If torch.load fails with a CUDA error, use map_location='cpu' to load on CPU.
  • Always call model.eval() before inference to set dropout and batchnorm layers to evaluation mode.
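When the architecture has drifted slightly from the checkpoint (for example, a newly added head), load_state_dict(..., strict=False) loads the keys that match and reports the rest instead of raising. A sketch with illustrative class names:

```python
import torch
import torch.nn as nn

class Old(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

class New(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)
        self.head = nn.Linear(2, 1)  # new layer not in the checkpoint

torch.save(Old().state_dict(), "old.pth")

new_model = New()
# strict=False loads matching keys and returns the mismatches
result = new_model.load_state_dict(torch.load("old.pth"), strict=False)
print(result.missing_keys)  # ['head.weight', 'head.bias']
```

Inspect missing_keys and unexpected_keys rather than silently ignoring them; an unexpected mismatch usually means the wrong checkpoint or model class.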

Key takeaways

  • Use torch.save(model.state_dict()) and model.load_state_dict(torch.load()) for flexible model saving/loading.
  • Always match the model class definition when loading weights to avoid size mismatch errors.
  • Use map_location in torch.load to handle device differences between saving and loading.
  • Call model.eval() before inference to ensure correct behavior of layers like dropout.
  • Avoid saving the entire model object directly for better portability and version compatibility.
Verified 2026-04 · torch.load, load_state_dict