How to save a PyTorch model
Quick answer

Use torch.save() to save a PyTorch model's state dictionary or the entire model object. Typically, save model.state_dict() for flexibility and reload it later with model.load_state_dict().

Prerequisites

- Python 3.8+
- pip install "torch>=2.0"
Setup
Install PyTorch if you haven't already, using the official command from the PyTorch installation guide. For example, to install the CPU-only version:

```shell
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```

Step by step
This example shows how to save and load a PyTorch model's state dictionary, which is the recommended approach for saving models.
```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

# Instantiate the model
model = SimpleModel()

# Save the model's state_dict
save_path = 'model_state.pth'
torch.save(model.state_dict(), save_path)
print(f'Model state_dict saved to {save_path}')

# Load the state_dict into a new instance of the same architecture
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load(save_path))
loaded_model.eval()  # Set to evaluation mode

# Verify the loaded model works
input_tensor = torch.randn(1, 10)
output = loaded_model(input_tensor)
print('Output from loaded model:', output)
```

Output:

```
Model state_dict saved to model_state.pth
Output from loaded model: tensor([[ 0.1234, -0.5678]], grad_fn=<AddmmBackward0>)
```
Common variations
You can also save the entire model object directly with torch.save(model, PATH), but this pickles the model class itself, tying the saved file to your exact code layout; it is less flexible and not recommended for production. To resume training (including distributed training), save checkpoints that bundle the model's state_dict together with the optimizer state. For asynchronous saving, offload torch.save to a background thread or process.
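The checkpoint idea mentioned above can be sketched as follows. This is a minimal illustration, not a fixed API: the filename, the dict keys, and the epoch value are common conventions you can choose freely.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Bundle everything needed to resume training into one dict
checkpoint = {
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')

# Later: restore the model and optimizer, then resume training
ckpt = torch.load('checkpoint.pth')
restored = nn.Linear(10, 2)
restored.load_state_dict(ckpt['model_state_dict'])
optimizer = torch.optim.SGD(restored.parameters(), lr=0.01)
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
start_epoch = ckpt['epoch'] + 1
```

Because the checkpoint contains only tensors and plain Python containers, it loads cleanly under torch.load's safe defaults, unlike a pickled full-model object.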
```python
import torch

# Save the entire model object (not recommended for production)
torch.save(model, 'full_model.pth')

# Load the entire model. weights_only=False is needed on recent PyTorch
# (2.6+), where torch.load defaults to weights_only=True and refuses to
# unpickle arbitrary objects such as a full model.
loaded_full_model = torch.load('full_model.pth', weights_only=False)
loaded_full_model.eval()
```

Troubleshooting
- If you get a RuntimeError: size mismatch when loading a state_dict, ensure the model architecture matches the saved one exactly.
- If loading a model saved on GPU in a CPU-only environment, use torch.load(PATH, map_location=torch.device('cpu')).
- Always call model.eval() before inference so that dropout and batch-norm layers behave correctly.
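The map_location fix can be sketched end to end. This is a self-contained illustration: the save happens on CPU here just so there is a file to load, and the filename is arbitrary.

```python
import torch
import torch.nn as nn

# Save a state_dict (imagine this happened on a GPU machine)
model = nn.Linear(10, 2)
torch.save(model.state_dict(), 'model_state.pth')

# On a CPU-only machine, remap all tensors to CPU at load time
cpu_model = nn.Linear(10, 2)
state = torch.load('model_state.pth', map_location=torch.device('cpu'))
cpu_model.load_state_dict(state)
cpu_model.eval()

print(all(p.device.type == 'cpu' for p in cpu_model.parameters()))  # True
```

Without map_location, torch.load tries to restore each tensor to the device it was saved from and raises an error if CUDA is unavailable.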
Key takeaways
- Use torch.save(model.state_dict(), PATH) to save model weights efficiently.
- Load weights with model.load_state_dict(torch.load(PATH)) and call model.eval() before inference.
- Saving the entire model object is possible but less portable and not recommended.
- Match model architecture exactly when loading saved weights to avoid errors.
- Use map_location argument to load GPU-trained models on CPU environments.