How to resume training from a checkpoint in PyTorch
Quick answer
To resume training in PyTorch, load the saved checkpoint dictionary with torch.load(), then restore the model and optimizer states with model.load_state_dict() and optimizer.load_state_dict(). Also restore the training epoch or iteration count so that training continues from where it left off.
Prerequisites
- Python 3.8+
- PyTorch 1.10+
- Basic knowledge of the PyTorch training loop
Setup
Ensure you have PyTorch installed. You can install it via pip if needed:
```shell
pip install torch torchvision
```
Also, prepare your model and optimizer as usual before loading the checkpoint.
Step by step
This example demonstrates saving a checkpoint during training and resuming from it later by restoring the model, optimizer, and epoch.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# Initialize model and optimizer
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Suppose this is the checkpoint path
checkpoint_path = 'checkpoint.pth'

# Simulate saving a checkpoint after some training
epoch = 5
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}
torch.save(checkpoint, checkpoint_path)

# Later, to resume training:
# 1. Load the checkpoint
checkpoint = torch.load(checkpoint_path)

# 2. Restore model and optimizer states
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# 3. Restore the epoch
start_epoch = checkpoint['epoch']
print(f"Resuming training from epoch {start_epoch}")
```
Output
```
Resuming training from epoch 5
```
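With the states restored, the training loop simply starts at start_epoch instead of 0. Here is a minimal end-to-end sketch; the synthetic data, MSE loss, and num_epochs target are placeholder choices for illustration:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical setup: a tiny model and synthetic data for illustration
model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)

# Save a checkpoint as if 5 epochs had already run
torch.save({'epoch': 5,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()},
           'checkpoint.pth')

# Resume: restore the states, then continue the loop from start_epoch
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

num_epochs = 10   # assumed total-epoch target
model.train()     # ensure layers like dropout/batchnorm are in training mode
for epoch in range(start_epoch, num_epochs):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
print(f"trained epochs {start_epoch} through {num_epochs - 1}")
```

Calling model.train() after loading matters when the checkpoint was saved around an evaluation pass that left the model in eval mode.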
Common variations
- Use map_location in torch.load() to load checkpoints saved on a GPU onto a CPU, or vice versa.
- Save and load learning rate schedulers the same way by adding their state dict to the checkpoint.
- For distributed training, load the checkpoint only on the main process, or handle it accordingly.
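The scheduler variation above can be sketched like this; StepLR with step_size=2 and gamma=0.5 is a stand-in for whatever scheduler you actually use:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

# Advance the scheduler a few steps, then checkpoint it with the rest
for _ in range(4):
    optimizer.step()
    scheduler.step()
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()},
           'checkpoint.pth')

# On resume, rebuild the scheduler, then restore its state
checkpoint = torch.load('checkpoint.pth')
fresh_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
fresh_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
print(fresh_scheduler.last_epoch)  # picks up where the saved scheduler stopped
```

Without this, a freshly constructed scheduler restarts its decay schedule from the beginning, so the learning rate jumps back up after resuming.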
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
```
Troubleshooting
- If you get a RuntimeError about missing keys, verify that the model architecture matches the checkpoint.
- If optimizer states fail to load, make sure the optimizer was initialized over the same parameters.
- For CUDA errors, pass map_location='cpu' to torch.load() when loading GPU checkpoints on CPU-only machines.
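When keys do not match, it helps to see exactly which ones: passing strict=False to load_state_dict() loads the keys that do match and returns the mismatches instead of raising. A small sketch, using a hypothetical rename of a 'head' layer to 'classifier':

```python
import torch
import torch.nn as nn

# Checkpoint saved from a model whose last layer was named 'head'
old_model = nn.Sequential()
old_model.add_module('backbone', nn.Linear(10, 5))
old_model.add_module('head', nn.Linear(5, 1))
torch.save(old_model.state_dict(), 'checkpoint.pth')

# The current model renamed 'head' to 'classifier', so keys mismatch
new_model = nn.Sequential()
new_model.add_module('backbone', nn.Linear(10, 5))
new_model.add_module('classifier', nn.Linear(5, 1))

state = torch.load('checkpoint.pth')
# strict=False loads the matching 'backbone' weights and reports the rest
result = new_model.load_state_dict(state, strict=False)
print(result.missing_keys)     # keys the model expects but the checkpoint lacks
print(result.unexpected_keys)  # keys the checkpoint has but the model does not
```

Inspecting the two lists usually makes it obvious whether the problem is a renamed layer, a changed architecture, or the wrong checkpoint file.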
Key Takeaways
- Always save and load both model and optimizer state dicts to resume training correctly.
- Restore the training epoch or iteration count from the checkpoint to continue seamlessly.
- Use map_location in torch.load to handle device mismatches between saving and loading.
- Ensure the model architecture matches the checkpoint to avoid state dict loading errors.