
How to save model checkpoint in PyTorch

Quick answer
Use torch.save() to write a dictionary containing the model's state_dict (and, if you plan to resume training, the optimizer's state_dict) to a checkpoint file, then use torch.load() to read it back. Restoring both lets training continue from where it stopped instead of starting over.

PREREQUISITES

  • Python 3.8+
  • PyTorch 1.10+ (pip install torch)

Setup

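Install PyTorch if you haven't already. The exact command depends on your operating system and on whether you need a CUDA build, so check the selector in the official PyTorch installation guide. A basic pip install looks like this (torchvision is optional for the examples here):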

bash
pip install torch torchvision

Step by step

This example shows how to save and load a model checkpoint including the model's weights and optimizer state.

python
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

# Initialize model and optimizer
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Dummy training step
inputs = torch.randn(5, 10)
outputs = model(inputs)
loss = outputs.sum()
loss.backward()
optimizer.step()

# Save checkpoint
checkpoint_path = 'model_checkpoint.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 1
}, checkpoint_path)
print(f"Checkpoint saved to {checkpoint_path}")

# Load checkpoint (when resuming in a new process, recreate the model and optimizer first)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
print(f"Checkpoint loaded, resuming from epoch {epoch}")
output
Checkpoint saved to model_checkpoint.pth
Checkpoint loaded, resuming from epoch 1
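
The example above loads the checkpoint back into the same objects that produced it, so it does not show that anything was restored. The sketch below is one way to check the round trip: it loads the file into freshly constructed objects, as a new process would, and compares the weights. It uses nn.Linear as a stand-in for SimpleModel, adds momentum so the optimizer has state worth saving, and writes to a temporary directory. The file name and variable names are assumptions made for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Stand-in for the SimpleModel defined earlier
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One dummy step so the optimizer holds momentum buffers
model(torch.randn(5, 10)).sum().backward()
optimizer.step()

# Hypothetical file name inside a temporary directory
path = os.path.join(tempfile.mkdtemp(), "roundtrip.pth")
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": 1,
}, path)

# Fresh objects, as a new process would create before resuming
restored_model = nn.Linear(10, 2)
restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.01, momentum=0.9)

checkpoint = torch.load(path)
restored_model.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# Compare every saved tensor with its restored counterpart
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(),
                    restored_model.state_dict().values())
)
print(f"Weights match: {weights_match}")
```

The optimizer must be constructed with the restored model's parameters before load_state_dict() is called on it, because the saved state refers to parameters by position rather than by name.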

Common variations

  • Save only the model weights with torch.save(model.state_dict(), 'model_weights.pth').
  • Load weights with model.load_state_dict(torch.load('model_weights.pth')).
  • Save checkpoints at regular intervals inside the training loop (for example, every few epochs), or asynchronously so that disk writes do not block training.
  • Use map_location in torch.load() to load GPU-trained models on CPU.
python
model_weights_path = 'model_weights.pth'
torch.save(model.state_dict(), model_weights_path)
print(f"Model weights saved to {model_weights_path}")

# Loading weights only
model2 = SimpleModel()
model2.load_state_dict(torch.load(model_weights_path))
print("Model weights loaded into new model instance")
output
Model weights saved to model_weights.pth
Model weights loaded into new model instance
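
Saving during a training loop can be sketched as follows. This is a minimal illustration and not a full training script. The values num_epochs and save_every, the dummy data, and the file name are assumptions. Here the stored epoch is the zero-based index of the last completed epoch, so resuming starts at epoch + 1; other conventions work as long as saving and loading agree.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)  # stand-in for SimpleModel
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Hypothetical settings for illustration
checkpoint_path = os.path.join(tempfile.mkdtemp(), "train_checkpoint.pth")
num_epochs = 6
save_every = 2

# Resume if a checkpoint already exists, otherwise start from epoch 0
start_epoch = 0
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1

saved_epochs = []
for epoch in range(start_epoch, num_epochs):
    # Dummy batch; a real loop would iterate over a DataLoader
    inputs, targets = torch.randn(5, 10), torch.randn(5, 2)

    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # Overwrite the checkpoint every `save_every` epochs
    if (epoch + 1) % save_every == 0:
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "loss": loss.item(),
        }, checkpoint_path)
        saved_epochs.append(epoch)

print(f"Saved checkpoints after epochs: {saved_epochs}")
```

Overwriting a single file keeps disk usage constant. Including the epoch number in the file name is a common alternative when you want to keep earlier checkpoints.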

Troubleshooting

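  • If load_state_dict() raises a RuntimeError about missing or unexpected keys, the model you built does not match the one that was saved. Make sure the class definition and layer names are the same.
  • Pass map_location=torch.device('cpu') to torch.load() when loading a checkpoint saved on a GPU onto a CPU-only machine.
  • If FileNotFoundError occurs, check the checkpoint path; if PermissionError occurs, check the file permissions.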
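
The first two points can be illustrated with a short sketch. SmallModel and BiggerModel are hypothetical classes invented here to provoke a key mismatch. Passing strict=False to load_state_dict() is used only as a diagnostic that reports which keys differ; it does not fix a genuine architecture mismatch.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SmallModel(nn.Module):  # hypothetical: the model that was saved
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)


class BiggerModel(nn.Module):  # hypothetical: a mismatched architecture
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)
        self.extra = nn.Linear(2, 2)

    def forward(self, x):
        return self.extra(self.linear(x))


path = os.path.join(tempfile.mkdtemp(), "weights.pth")
torch.save(SmallModel().state_dict(), path)

# map_location places all loaded tensors on the CPU,
# even if they were saved from a GPU
state_dict = torch.load(path, map_location=torch.device("cpu"))

bigger = BiggerModel()

# Default strict loading raises RuntimeError on any key mismatch
strict_failed = False
try:
    bigger.load_state_dict(state_dict)
except RuntimeError:
    strict_failed = True

# strict=False loads the matching keys and reports the rest
result = bigger.load_state_dict(state_dict, strict=False)
print("Strict load failed:", strict_failed)
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
```

Missing keys are parameters the model expects but the file lacks. Unexpected keys are entries in the file that the model has no place for.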

Key Takeaways

  • Use torch.save() with a dictionary to save model and optimizer states together.
  • Load checkpoints with torch.load() and restore states to resume training seamlessly.
  • Saving only model.state_dict() is sufficient for inference or fine-tuning.
  • Use map_location in torch.load() to handle device mismatches between saving and loading.
  • Always verify model architecture matches when loading checkpoints to avoid errors.
Verified 2026-04