How-to · Beginner · 3 min read

How to save only model weights in PyTorch

Quick answer
Use torch.save(model.state_dict(), filepath) to save only the model weights in PyTorch. This saves the state_dict, which contains all learnable parameters, allowing efficient storage and later loading with model.load_state_dict().

PREREQUISITES

  • Python 3.8+
  • pip install "torch>=2.0" (quote the requirement so the shell doesn't treat > as a redirect)

Setup

Install PyTorch if you haven't already. Use the official command from the PyTorch installation guide for your platform. For example, to install the latest stable build:

bash
pip install torch torchvision torchaudio
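To confirm the install worked, print the installed version:

```python
import torch

# Prints the installed PyTorch version, e.g. a string like "2.x.y"
print(torch.__version__)
```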

Step by step

This example demonstrates how to save only the model weights (parameters) using state_dict() and then load them back into the model.

python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

# Instantiate the model
model = SimpleModel()

# Save only the model weights (state_dict)
save_path = 'model_weights.pth'
torch.save(model.state_dict(), save_path)
print(f'Model weights saved to {save_path}')

# To load weights back
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load(save_path, weights_only=True))  # weights_only avoids unpickling arbitrary objects
loaded_model.eval()
print('Model weights loaded successfully')
output
Model weights saved to model_weights.pth
Model weights loaded successfully
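An optional sanity check is to compare the reloaded state_dict against the original, tensor by tensor. This sketch reuses the SimpleModel class and file path from the example above:

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
torch.save(model.state_dict(), 'model_weights.pth')

loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('model_weights.pth', weights_only=True))

# torch.equal is True only when shape and values match exactly
match = all(
    torch.equal(loaded_model.state_dict()[k], v)
    for k, v in model.state_dict().items()
)
print(match)  # True
```

Since torch.equal checks both shape and values, True here means the round trip was lossless.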

Common variations

  • Save weights on GPU and load on CPU by adding map_location=torch.device('cpu') in torch.load().
  • Save weights asynchronously or with compression using third-party libraries.
  • Use torch.jit for saving scripted models, but note that this saves the full model, not just the weights.
python
import torch

# Save weights on GPU
model = SimpleModel().to('cuda')  # requires a CUDA-capable GPU
torch.save(model.state_dict(), 'gpu_weights.pth')

# Load weights on CPU
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('gpu_weights.pth', map_location=torch.device('cpu')))
loaded_model.eval()
output
No output (runs without error)
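The torch.jit variation mentioned above looks like this. A sketch reusing the SimpleModel class from the main example; the saved file bundles architecture and weights, so loading needs no Python class definition:

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

# Scripting bundles the architecture and the weights into one file
scripted = torch.jit.script(SimpleModel())
scripted.save('scripted_model.pt')

# No SimpleModel class needed at load time, unlike load_state_dict
restored = torch.jit.load('scripted_model.pt')
x = torch.randn(1, 10)
print(restored(x).shape)  # torch.Size([1, 2])
```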

Troubleshooting

  • If you get a RuntimeError: size mismatch when loading weights, ensure the model architecture matches exactly.
  • If loading on a different device, use map_location to avoid device mismatch errors.
  • Always call model.eval() after loading weights for inference to set dropout/batchnorm layers correctly.
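For the architecture-mismatch case, load_state_dict(strict=False) tolerates missing or unexpected keys and reports them instead of raising (it still errors on shape mismatches for keys that do exist). A sketch with two hypothetical models, BigModel and SmallModel:

```python
import torch
import torch.nn as nn

# Hypothetical mismatch: weights were saved from a model with an extra layer
class BigModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)
        self.extra = nn.Linear(2, 2)

class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

torch.save(BigModel().state_dict(), 'big_weights.pth')

model = SmallModel()
# strict=False skips keys absent from the target model and reports
# what was missing/unexpected rather than raising a RuntimeError
result = model.load_state_dict(
    torch.load('big_weights.pth', weights_only=True), strict=False
)
print(result.unexpected_keys)  # the keys from BigModel's extra layer
```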

Key Takeaways

  • Use model.state_dict() with torch.save() to save only model weights efficiently.
  • Load weights with model.load_state_dict(torch.load(filepath)) ensuring model architecture matches.
  • Use map_location in torch.load() to handle device differences between saving and loading.
  • Call model.eval() after loading weights for correct inference behavior.
  • Saving only weights reduces file size and improves flexibility compared to saving the entire model.
Verified 2026-04