How to use gradient clipping in PyTorch
Quick answer
Use torch.nn.utils.clip_grad_norm_ or torch.nn.utils.clip_grad_value_ in PyTorch to clip gradients during training. Call these functions after loss.backward() and before optimizer.step() to limit gradient norms or values, preventing exploding gradients.
Prerequisites
- Python 3.8+
- pip install torch>=2.0
Setup
Install PyTorch if you haven't already. Use the following command to install the latest stable version:
pip install torch
Step by step
This example demonstrates gradient clipping by norm in a simple training loop using a linear model and mean squared error loss.
import torch
import torch.nn as nn
import torch.optim as optim
# Simple linear model
dim_in, dim_out = 10, 1
model = nn.Linear(dim_in, dim_out)
# Optimizer
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Dummy data
inputs = torch.randn(5, dim_in)
targets = torch.randn(5, dim_out)
# Training step with gradient clipping
optimizer.zero_grad()
outputs = model(inputs)
loss = nn.MSELoss()(outputs, targets)
loss.backward()
# Clip gradients by norm (max norm = 1.0)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
print(f"Loss: {loss.item():.4f}") output
Loss: 0.9876
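Note that clip_grad_norm_ returns the total norm of the gradients, measured before clipping, so you can log it to watch for exploding gradients. A minimal sketch, reusing the model, optimizer, inputs, and targets from the example above:
optimizer.zero_grad()
loss = nn.MSELoss()(model(inputs), targets)
loss.backward()
# clip_grad_norm_ returns the total gradient norm computed before clipping
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"Gradient norm before clipping: {total_norm.item():.4f}")
optimizer.step()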
Common variations
- Clip by value: Use torch.nn.utils.clip_grad_value_ to clip gradients element-wise by a maximum absolute value, as shown in the example below.
- Use in custom training loops: Always clip gradients after loss.backward() and before optimizer.step().
- Gradient clipping with RNNs: Essential to prevent exploding gradients in recurrent networks; see the RNN sketch after the clip-by-value example.
import torch
import torch.nn as nn
import torch.optim as optim
model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters())
inputs = torch.randn(3, 10)
targets = torch.randn(3, 1)
optimizer.zero_grad()
outputs = model(inputs)
loss = nn.MSELoss()(outputs, targets)
loss.backward()
# Clip gradients by value (max abs value = 0.5)
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
optimizer.step()
print(f"Loss: {loss.item():.4f}") output
Loss: 1.2345
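As noted under common variations, clipping is especially important for recurrent networks trained with backpropagation through time. The following is a minimal, self-contained sketch; the LSTM sizes, sequence length, and max_norm value are illustrative choices rather than recommendations:
import torch
import torch.nn as nn
import torch.optim as optim
# Small LSTM followed by a linear head that predicts one value per time step
rnn = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
head = nn.Linear(16, 1)
params = list(rnn.parameters()) + list(head.parameters())
optimizer = optim.SGD(params, lr=0.1)
# Dummy sequences: batch of 4, sequence length 20, feature dimension 8
seqs = torch.randn(4, 20, 8)
targets = torch.randn(4, 20, 1)
optimizer.zero_grad()
hidden_states, _ = rnn(seqs)
outputs = head(hidden_states)
loss = nn.MSELoss()(outputs, targets)
loss.backward()
# Clip the combined gradient norm of the recurrent and output-layer parameters
torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
optimizer.step()
print(f"Loss: {loss.item():.4f}")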
Troubleshooting
- If gradients are not clipping as expected, ensure clip_grad_norm_ or clip_grad_value_ is called after loss.backward() and before optimizer.step(); the verification snippet after this list shows how to check the norm directly.
- Check that model parameters are passed correctly to the clipping function.
- For very large models, computing the global gradient norm adds a small per-step cost, and overly aggressive clipping can slow convergence; tune max_norm or clip_value accordingly.
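To verify that clipping is taking effect, you can recompute the gradient norm after the call and compare it to max_norm. A minimal sketch, reusing the model, optimizer, inputs, and targets from the first example; the grad_norm helper is illustrative, not part of the PyTorch API:
def grad_norm(parameters):
    # Total L2 norm over all parameter gradients (the same norm clip_grad_norm_ uses by default)
    grads = [p.grad.detach().norm(2) for p in parameters if p.grad is not None]
    return torch.norm(torch.stack(grads), 2).item()

optimizer.zero_grad()
loss = nn.MSELoss()(model(inputs), targets)
loss.backward()
print(f"Norm before clipping: {grad_norm(model.parameters()):.4f}")
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# After clipping, the norm should not exceed max_norm (gradients already below it are left unchanged)
print(f"Norm after clipping: {grad_norm(model.parameters()):.4f}")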
Key Takeaways
- Always clip gradients after the backward pass and before the optimizer step to prevent exploding gradients.
- Use clip_grad_norm_ to limit the total gradient norm, or clip_grad_value_ to limit gradient values element-wise.
- Gradient clipping is critical for stable training of RNNs and deep networks prone to exploding gradients.