How-to · Beginner · 3 min read

How to resume training from a checkpoint in PyTorch

Quick answer
To resume training in PyTorch, load the saved checkpoint dictionary using torch.load(), then restore the model and optimizer states with model.load_state_dict() and optimizer.load_state_dict(). Also, restore the training epoch or iteration count to continue training from where it left off.

PREREQUISITES

  • Python 3.8+
  • PyTorch 1.10+
  • Basic knowledge of PyTorch training loop

Setup

Ensure you have PyTorch installed. You can install it via pip if needed:

bash
pip install torch torchvision

Also, construct your model and optimizer as usual before loading the checkpoint; the architecture and optimizer configuration must match what was used when the checkpoint was saved.

Step by step

This example demonstrates saving a checkpoint during training and resuming from it later by restoring the model, optimizer, and epoch.

python
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 1)
    def forward(self, x):
        return self.linear(x)

# Initialize model and optimizer
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Suppose this is the checkpoint path
checkpoint_path = 'checkpoint.pth'

# Simulate saving checkpoint after some training
epoch = 5
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}
torch.save(checkpoint, checkpoint_path)

# Later, to resume training:
# 1. Load checkpoint
checkpoint = torch.load(checkpoint_path)

# 2. Restore model and optimizer states
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# 3. Restore the epoch (resume at the next one, since epoch 5 already finished)
start_epoch = checkpoint['epoch'] + 1

print(f"Resuming training from epoch {start_epoch}")
output
Resuming training from epoch 6
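Once the states are restored, the training loop simply starts at the saved epoch instead of zero. A minimal self-contained sketch (the single random batch and MSE loss below are placeholders for your real DataLoader and criterion):

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Rebuild the model and optimizer with the same architecture/hyperparameters
model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Save a checkpoint as in the example above, then load it back
torch.save({'epoch': 5,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, 'checkpoint.pth')
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # epoch 5 is done; continue with 6

num_epochs = 8
criterion = nn.MSELoss()
for epoch in range(start_epoch, num_epochs):
    # Placeholder batch; replace with your DataLoader iteration
    x, y = torch.randn(4, 10), torch.randn(4, 1)
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")
```

Because `range(start_epoch, num_epochs)` picks up at epoch 6, no work from the first five epochs is repeated.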

Common variations

  • Use map_location in torch.load() to remap checkpoints across devices, e.g. to load a GPU-saved checkpoint on a CPU-only machine or vice versa.
  • Save and restore learning rate schedulers the same way by adding scheduler.state_dict() to the checkpoint dictionary.
  • For distributed training, load the checkpoint only on the main process (rank 0) or broadcast it to all ranks, and save from a single process to avoid corrupted files.
python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
model.to(device)  # move parameters to the target device before training
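A scheduler checkpoint works the same way as the model and optimizer ones. A short sketch using `StepLR` (the dictionary key names are arbitrary conventions, not required by PyTorch):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Halve the learning rate every 2 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

# Advance the schedule a few epochs, then checkpoint everything together
for _ in range(4):
    optimizer.step()
    scheduler.step()
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()}, 'ckpt.pth')

# On resume: restore all three so the LR schedule continues, not restarts
checkpoint = torch.load('ckpt.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
print(scheduler.get_last_lr())  # LR after two decays: 0.1 * 0.5**2 = 0.025
```

Without the scheduler state, a resumed run would silently restart at the initial learning rate.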

Troubleshooting

  • If you get a RuntimeError about missing or unexpected keys, verify that the model architecture (and module names) matches the one that produced the checkpoint.
  • If the optimizer state fails to load, ensure the optimizer was constructed over the same model parameters, with the same parameter groups.
  • For CUDA errors, pass map_location='cpu' to torch.load() when loading GPU checkpoints on a CPU-only machine.
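To debug key mismatches, `load_state_dict(..., strict=False)` reports the mismatched keys instead of raising. A sketch with two deliberately different layer names (`linear` vs. `fc` are illustrative names):

```python
import torch
import torch.nn as nn

# Checkpoint saved from a model whose layer is named 'linear'
saved = nn.Sequential()
saved.add_module('linear', nn.Linear(10, 1))
torch.save(saved.state_dict(), 'mismatch.pth')

# A model whose equivalent layer is named 'fc' instead
current = nn.Sequential()
current.add_module('fc', nn.Linear(10, 1))

state = torch.load('mismatch.pth')
# strict=False skips mismatches and returns them for inspection
result = current.load_state_dict(state, strict=False)
print(result.missing_keys)     # keys the model expects but the checkpoint lacks
print(result.unexpected_keys)  # keys in the checkpoint the model doesn't have
```

Listing the two key sets usually makes the fix obvious, e.g. a renamed layer or a `module.` prefix left over from `DataParallel`.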

Key Takeaways

  • Always save and load both model and optimizer state dicts to resume training correctly.
  • Restore the training epoch or iteration count from the checkpoint to continue seamlessly.
  • Use map_location in torch.load to handle device mismatches between saving and loading.
  • Ensure model architecture matches checkpoint to avoid state dict loading errors.
Verified 2026-04 · torch.load, model.load_state_dict, optimizer.load_state_dict