RuntimeError
RuntimeError (Python built-in, raised by PyTorch autograd)
Stack trace
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate values of the first backward are already freed. Specify retain_graph=True when calling backward the first time.
Why it happens
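PyTorch autograd frees the intermediate values it saved for the backward pass as soon as backward() has run through a graph, unless retain_graph=True is passed. Any second backward pass through that graph, for example a second loss computed from the same outputs, therefore raises this RuntimeError. Gradient checkpointing saves memory by discarding intermediate activations during the forward pass and recomputing them during backward. It gives the error another route when the model's forward method is incompatible with checkpointing.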
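A minimal reproduction sketch in plain PyTorch. The small nn.Sequential model is an assumption standing in for the PEFT model, because the autograd behaviour shown here does not depend on PEFT. The same forward graph is backpropagated twice, once with and once without retain_graph=True on the first call:

```python
import torch
import torch.nn as nn

# Tiny stand-in model (assumption): the freeing of saved tensors works the same way in larger models.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
x = torch.randn(2, 4)

def second_backward_fails(retain_first: bool) -> bool:
    """Run two backward passes over one forward graph; report whether the second one raised."""
    model.zero_grad()
    loss = model(x).pow(2).mean()
    loss.backward(retain_graph=retain_first)  # frees saved tensors unless retain_graph=True
    try:
        loss.backward()  # second backward through the same graph
    except RuntimeError as err:
        print("RuntimeError:", str(err).splitlines()[0])
        return True
    return False

print(second_backward_fails(retain_first=False))  # True: saved values were already freed
print(second_backward_fails(retain_first=True))   # False: the first call kept the graph
```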
Detection
Search training logs for a RuntimeError whose message contains "Trying to backward through the graph a second time". Wrapping each training step in try/except and logging the full traceback shows which backward() call raised it.
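One way to implement that logging, sketched with the standard logging module. The helper name logged_step, the logger name and the message substring check are illustrative assumptions, not part of PyTorch or PEFT:

```python
import logging

import torch
import torch.nn as nn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("train")  # hypothetical logger name

def logged_step(step_fn, step: int) -> bool:
    """Run one training step; log graph-freed errors with their traceback, re-raise anything else."""
    try:
        step_fn()
        return True
    except RuntimeError as err:
        if "backward through the graph a second time" in str(err):
            logger.exception("step %d: backward through an already-freed graph", step)
            return False
        raise

# Toy model and a deliberately broken step to exercise the wrapper
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
x = torch.randn(2, 4)

def bad_step():
    loss = model(x).sum()
    loss.backward()
    loss.backward()  # second backward through the same graph

print(logged_step(bad_step, step=0))  # False, and the traceback is logged
```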
Causes & fixes
Cause: Gradient checkpointing is enabled, but the model's forward method returns multiple outputs in a form the checkpointing wrapper does not handle, or uses unsupported operations.
Fix: Change the forward method so each checkpointed segment returns a single tensor (or a plain tuple of tensors), or disable gradient checkpointing.
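For the first cause, a sketch of a checkpoint-friendly forward method using torch.utils.checkpoint.checkpoint directly, rather than the gradient_checkpointing_enable() wrapper, which may checkpoint differently. Block and TinyModel are hypothetical classes. Each checkpointed block takes and returns a single tensor, and the check compares gradients against a copy that does not checkpoint:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Hypothetical residual block: one tensor in, one tensor out."""
    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)

class TinyModel(nn.Module):
    """Hypothetical model whose forward checkpoints each block while training."""
    def __init__(self, dim: int = 8, depth: int = 3, use_checkpoint: bool = True):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Activations inside `block` are dropped now and recomputed during backward.
                # use_reentrant=False also works when the input does not require grad.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x  # a single tensor output

# Same weights with and without checkpointing should give the same gradients
torch.manual_seed(0)
ref = TinyModel(use_checkpoint=False)
ckpt = TinyModel(use_checkpoint=True)
ckpt.load_state_dict(ref.state_dict())
x = torch.randn(4, 8)
for m in (ref, ckpt):
    m.train()
    m(x).pow(2).mean().backward()
same = all(torch.allclose(a.grad, b.grad) for a, b in zip(ref.parameters(), ckpt.parameters()))
print(same)  # True
```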
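Cause: backward() is called more than once on the same forward graph (for example, once per loss term) without retain_graph=True, while gradient checkpointing is in use.
Fix: Prefer restructuring the training loop so each forward pass gets a single backward() call, for example by summing the losses first. If several calls are unavoidable, pass retain_graph=True to every call except the last; this keeps the saved activations in memory longer.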
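For the second cause, a sketch comparing the two fixes on a toy model: retain_graph=True on the first of two backward calls, versus summing the losses and calling backward() once. The auxiliary loss is a hypothetical stand-in for any second objective computed from the same forward pass:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))
x = torch.randn(3, 4)
target = torch.randn(3, 2)

def grads_after(strategy: str):
    """Return parameter gradients after one forward pass and the chosen backward strategy."""
    model.zero_grad()
    out = model(x)
    main_loss = nn.functional.mse_loss(out, target)
    aux_loss = out.abs().mean()  # hypothetical auxiliary objective sharing the same forward
    if strategy == "retain_graph":
        main_loss.backward(retain_graph=True)  # keep saved tensors for the next call
        aux_loss.backward()  # last backward on this graph may free it
    else:  # "combined"
        (main_loss + aux_loss).backward()  # one backward pass, graph freed once
    return [p.grad.clone() for p in model.parameters()]

a = grads_after("retain_graph")
b = grads_after("combined")
print(all(torch.allclose(g1, g2, atol=1e-6) for g1, g2 in zip(a, b)))  # True
```

Both strategies should yield the same gradients; the combined backward avoids keeping the graph alive between calls.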
Cause: An older PyTorch or PEFT release with known gradient checkpointing bugs is installed.
Fix: Upgrade PyTorch and PEFT to current stable releases and check their release notes for checkpointing fixes.
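For the third cause, a small sketch that records which versions are installed before comparing them against the PyTorch and PEFT release notes. It relies only on the standard library's importlib.metadata and assumes the distribution names torch and peft:

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages=("torch", "peft")) -> dict:
    """Return the installed version string of each distribution, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

print(installed_versions())
```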
Code: broken vs fixed
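Broken:
from peft import get_peft_model
import torch
# base_model, peft_config, inputs, labels, loss_fn and aux_loss_fn are placeholders defined elsewhere
model = get_peft_model(base_model, peft_config)
model.gradient_checkpointing_enable()  # enables checkpointing
optimizer = torch.optim.Adam(model.parameters())
outputs = model(inputs)
loss = loss_fn(outputs, labels)
aux_loss = aux_loss_fn(outputs)  # hypothetical second loss on the same forward pass
loss.backward()  # frees the saved intermediate values of the graph
aux_loss.backward()  # RuntimeError here: the graph was already freed
optimizer.step()
Fixed:
from peft import get_peft_model
import torch
# base_model, peft_config, inputs, labels, loss_fn and aux_loss_fn are placeholders defined elsewhere
model = get_peft_model(base_model, peft_config)
# Gradient checkpointing disabled to rule out an incompatible forward method
# model.gradient_checkpointing_enable()
optimizer = torch.optim.Adam(model.parameters())
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
aux_loss = aux_loss_fn(outputs)
loss.backward(retain_graph=True)  # keep saved values for the next backward
aux_loss.backward()  # last backward on this graph, so it may free the graph
# Alternative without retain_graph: (loss + aux_loss).backward()
optimizer.step()
print('Fine-tuning step completed without checkpointing error')
Workaround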
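Wrap the backward call in try/except RuntimeError. If it is caught, the saved values are already gone, so calling backward() again on the same loss fails too: rerun the forward pass and call backward() on the fresh loss (with retain_graph=True on all but the last call if several are needed), or temporarily disable gradient checkpointing.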
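A sketch of that workaround on a toy model. Because the saved values are already freed once the error is raised, the retry runs a fresh forward pass rather than calling backward() again on the same loss. compute_loss and step_with_retry are hypothetical names, and disabling checkpointing is not shown:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(2, 4), torch.randn(2, 1)

def compute_loss() -> torch.Tensor:
    """Hypothetical forward pass plus loss; builds a new autograd graph on every call."""
    return nn.functional.mse_loss(model(x), y)

def step_with_retry(loss: torch.Tensor) -> None:
    optimizer.zero_grad()
    try:
        loss.backward()
    except RuntimeError as err:
        if "backward through the graph a second time" not in str(err):
            raise
        # Retrying backward on the same loss would fail again: rebuild the graph instead.
        optimizer.zero_grad()
        compute_loss().backward()
    optimizer.step()

stale = compute_loss()
stale.backward()        # the graph behind `stale` is now freed
step_with_retry(stale)  # falls back to a fresh forward pass instead of crashing
```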
Prevention
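Use model architectures that support gradient checkpointing, call backward() once per forward pass (summing the losses first when several objectives share one forward pass), run a single training step as a smoke test before a full job, and keep PEFT and PyTorch up to date with releases that include checkpointing fixes.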
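A sketch of such a one-step smoke test, assuming a tiny batch and an MSE objective purely for illustration. smoke_test_step is a hypothetical helper that performs exactly one backward() per forward pass and reports whether any trainable weight changed:

```python
import torch
import torch.nn as nn

def smoke_test_step(model: nn.Module, batch: torch.Tensor, target: torch.Tensor) -> bool:
    """Run one forward/backward/step on a tiny batch; return True if any trainable weight changed."""
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.1)
    before = [p.detach().clone() for p in params]
    model.train()
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(batch), target)
    loss.backward()  # exactly one backward per forward
    optimizer.step()
    return any(not torch.equal(b, p.detach()) for b, p in zip(before, params))

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
print(smoke_test_step(model, torch.randn(2, 4), torch.randn(2, 1)))  # True
```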