torch.save(): saving model state
Why this matters
Training takes hours. You need to checkpoint your work, share trained models with teammates, and deploy to production without retraining. torch.save() is how you freeze your model's learned parameters to disk.
Explanation
torch.save() writes PyTorch objects (models, tensors, optimizer state, arbitrary Python dicts) to disk using pickle. It's the standard way to persist trained weights. Mechanically, it serializes the object into a binary file that torch.load() can later deserialize back into memory with identical structure and values.
You typically save either the full model object (torch.save(model, 'model.pt')) or just its state dict (torch.save(model.state_dict(), 'weights.pt')). The state dict approach is preferred in production because it decouples saved weights from your code: if you refactor the model class, you can still load old weights by re-creating the architecture and calling model.load_state_dict(), as long as the parameter names and shapes still match. Saving full model objects is convenient for prototyping, but it ties the saved file to your class definition and can break when that class or the library version changes.
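To make the state dict idea concrete, here is a minimal sketch that prints what a state dict contains and round-trips it through torch.save() and torch.load(). The nn.Linear(3, 2) model is a hypothetical stand-in, and the in-memory buffer is an assumption made so the example leaves no files behind; a path such as 'weights.pt' works the same way.

```python
import io

import torch
import torch.nn as nn

# A tiny stand-in model (hypothetical, for illustration only).
model = nn.Linear(3, 2)

# state_dict() maps parameter names to tensors; no class code is stored.
state = model.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# Save to an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

# Loading returns a plain mapping of names to tensors.
restored_state = torch.load(buffer)

# Re-create the architecture, then copy the saved values into it.
restored = nn.Linear(3, 2)
restored.load_state_dict(restored_state)

same = all(torch.equal(state[k], restored.state_dict()[k]) for k in state)
print("Round trip preserved all tensors:", same)
```

Because the saved object is only a mapping of names to tensors, the loading side needs nothing but an instance whose parameter names and shapes line up.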
Analogy
Saving just the state dict is like saving a blueprint's measurements separately from the blueprint tool itself: you can rebuild the structure with a different tool later. Saving the whole model is like taking a photo of your workspace: it works now, but if your tools change, the photo becomes hard to interpret.
Code
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


model = SimpleNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Toy data: 4 samples, 10 features, 2 classes
X = torch.randn(4, 10)
y = torch.randint(0, 2, (4,))
loss_fn = nn.CrossEntropyLoss()

for epoch in range(2):
    optimizer.zero_grad()
    logits = model(X)
    loss = loss_fn(logits, y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch + 1} loss: {loss.item():.4f}")

# Bundle everything needed to resume training into one dict
checkpoint = {
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'epoch': epoch + 1,  # number of completed epochs (2 here)
}
torch.save(checkpoint, 'checkpoint.pt')
print("\nCheckpoint saved.")

# Restore into fresh model and optimizer instances
loaded_checkpoint = torch.load('checkpoint.pt')
model_new = SimpleNet()
model_new.load_state_dict(loaded_checkpoint['model_state'])
optimizer_new = torch.optim.Adam(model_new.parameters(), lr=0.001)
optimizer_new.load_state_dict(loaded_checkpoint['optimizer_state'])
print(f"Loaded checkpoint from epoch {loaded_checkpoint['epoch']}")

# Both models should give the same output for the same input
with torch.no_grad():
    test_input = torch.randn(1, 10)
    output_original = model(test_input)
    output_loaded = model_new(test_input)
print(f"\nOutputs match: {torch.allclose(output_original, output_loaded)}")

Output (loss values vary from run to run)
Epoch 1 loss: 0.7125
Epoch 2 loss: 0.6891

Checkpoint saved.
Loaded checkpoint from epoch 2

Outputs match: True
What just happened?
We created a small neural network, trained it for 2 epochs, then saved a checkpoint dict containing the model weights, optimizer state, and epoch number to disk. We then created a fresh model instance and loaded the saved state back into it, verifying that the restored model produces identical outputs to the original.
Common gotcha
Saving model.state_dict() is not the same as saving model directly. Saving the full model object pickles a reference to its class, so if you later rename, move, or restructure that class, torch.load() may fail or restore an object that no longer matches your code. Prefer state_dict() in production, even though saving the whole model feels more convenient. The other trap is evaluation mode: a state dict does not record whether the model was in training or evaluation mode, and a freshly constructed model starts in training mode. If your model uses dropout or batch norm, call model.eval() after load_state_dict() and before inference; otherwise its outputs will differ from what you expect.
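The training-mode trap is easy to check directly. The sketch below assumes a hypothetical two-layer model with dropout (make_model is an illustrative helper, not part of PyTorch) and shows that the mode is a property of the module instance, not of the saved state dict, so a freshly built model needs eval() after loading.

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Hypothetical model with dropout, used only for this illustration.
    return nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))


trained = make_model()
trained.eval()  # the mode is not recorded in state_dict()

buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

loaded = make_model()
loaded.load_state_dict(torch.load(buffer))

# A freshly constructed module starts in training mode, and
# load_state_dict() does not change that.
mode_after_load = loaded.training
print("training flag after load:", mode_after_load)

loaded.eval()  # switch dropout off before inference
x = torch.ones(1, 4)
with torch.no_grad():
    matches = torch.equal(loaded(x), trained(x))
print("eval outputs match:", matches)
```

Calling eval() on the loading side, right after load_state_dict(), keeps inference behavior independent of whatever mode the model was in when it was saved.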
Error recovery
RuntimeError: Couldn't deserialize object due to incomplete pickle data
KeyError when calling model.load_state_dict()
TypeError: cannot pickle '_thread.RLock' object
Experienced dev note
Always save state_dict(), not the full model. But here's what most developers miss: save more than weights. Your checkpoint dict should include the epoch number, learning rate, git commit hash, and any other hyperparameters. When you load it 6 months later, you'll know exactly what settings were used. Also: use meaningful filenames with timestamps or model version numbers. Saving to a single 'model.pt' that you overwrite every iteration is how you accidentally load a broken checkpoint after a bad training run.
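One way to follow this advice is sketched below. The build_checkpoint helper, the metadata keys, the "unknown" commit placeholder, and the filename pattern are all assumptions chosen for illustration; in a real project the commit hash would come from your version control system and the output directory from your own configuration.

```python
import tempfile
from datetime import datetime, timezone
from pathlib import Path

import torch
import torch.nn as nn

# Stand-in model and optimizer (hypothetical, for illustration only).
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def build_checkpoint(model, optimizer, epoch, extra):
    """Bundle weights, optimizer state, and run metadata in one dict."""
    return {
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "epoch": epoch,
        "lr": optimizer.param_groups[0]["lr"],
        **extra,
    }


checkpoint = build_checkpoint(
    model,
    optimizer,
    epoch=2,
    # Placeholder values; fill these in from your own run.
    extra={"git_commit": "unknown", "batch_size": 4},
)

# Timestamped, versioned filename so earlier checkpoints are never overwritten.
stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
out_dir = Path(tempfile.mkdtemp())  # temporary directory for this demo only
path = out_dir / f"simplenet_epoch{checkpoint['epoch']:03d}_{stamp}.pt"
torch.save(checkpoint, path)

loaded = torch.load(path)
print(path.name, loaded["epoch"], loaded["lr"], loaded["git_commit"])
```

Keeping the metadata as plain Python values (ints, floats, strings) inside the same dict means the file documents itself when someone opens it later.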
Check your understanding
You train a model, save its state_dict to disk, then load it into a completely new instance of your model class that you instantiated from scratch. Why is calling model.load_state_dict() safe here, whereas loading the full model object with torch.load() might fail?
Answer hint
The state_dict maps parameter and buffer names to tensors; it contains no class definition. As long as the new instance has parameters with the same names and shapes, load_state_dict() can copy the saved values into it. Loading a full model object requires the original class to be importable under the same module path and name where torch.load() executes, so if you renamed or moved the class, unpickling fails.
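The name matching mentioned in the hint can be seen in a small experiment. In this sketch, OldNet and RefactoredNet are hypothetical classes that differ only in a layer name; the key-renaming step is one possible workaround, assumed here for illustration, not the only way to migrate old weights.

```python
import torch
import torch.nn as nn


class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc1(x)


class RefactoredNet(nn.Module):
    # Same shapes as OldNet, but the layer was renamed from fc1 to hidden.
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(10, 5)

    def forward(self, x):
        return self.hidden(x)


old_state = OldNet().state_dict()
new_model = RefactoredNet()

# With the default strict=True, mismatched key names raise RuntimeError.
try:
    new_model.load_state_dict(old_state)
    strict_load_failed = False
except RuntimeError:
    strict_load_failed = True
print("strict load rejected old keys:", strict_load_failed)

# Because a state dict is just a mapping, its keys can be renamed to fit.
remapped = {k.replace("fc1.", "hidden.", 1): v for k, v in old_state.items()}
new_model.load_state_dict(remapped)

weights_copied = torch.equal(new_model.hidden.weight, old_state["fc1.weight"])
print("weights copied after remap:", weights_copied)
```

A pickled full model offers no comparable escape hatch, since unpickling needs the original class before any of its contents can be reached.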