RuntimeError
RuntimeError (Python built-in exception, raised during the PyTorch backward pass)
Stack trace
RuntimeError: You cannot use gradient checkpointing with LoRA layers unless you disable LoRA's gradient checkpointing or adjust the model accordingly.
File "train.py", line 123, in train_loop
loss.backward()
File "/usr/local/lib/python3.10/site-packages/torch/tensor.py", line 245, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/usr/local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
Variable._execution_engine.run_backward(*args, **kwargs) # Calls into C++ engine
Why it happens
Gradient checkpointing saves memory by discarding intermediate activations during the forward pass and recomputing them during the backward pass, so it needs explicit support in the model's forward implementation. If the base model or the LoRA implementation does not support it properly, or if conflicting flags are set, this RuntimeError is raised during backpropagation.
Detection
Monitor training logs for a RuntimeError that mentions a gradient checkpointing conflict when using LoRA. Before training starts, assert that gradient checkpointing is not enabled on both the base model and the LoRA layers at the same time.
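A pre-training check along these lines can be sketched as follows. It is an illustration only: `checkpointing_flags` and `assert_single_checkpointing` are hypothetical helpers, `is_gradient_checkpointing` is the property that Hugging Face Transformers models expose for their checkpointing state, and the `gradient_checkpointing` attribute on the LoRA config is the flag name used on this page, which your LoRA implementation may name differently. Stand-in objects replace the real model so the sketch runs on its own.

```python
from types import SimpleNamespace


def checkpointing_flags(base_model, lora_config):
    """Return (base_enabled, lora_enabled) as booleans.

    Missing attributes are treated as "disabled".
    """
    base_enabled = bool(getattr(base_model, "is_gradient_checkpointing", False))
    # Assumption: the LoRA-side flag is called `gradient_checkpointing`.
    lora_enabled = bool(getattr(lora_config, "gradient_checkpointing", False))
    return base_enabled, lora_enabled


def assert_single_checkpointing(base_model, lora_config):
    """Fail fast if checkpointing is enabled on both sides."""
    base_enabled, lora_enabled = checkpointing_flags(base_model, lora_config)
    if base_enabled and lora_enabled:
        raise AssertionError(
            "Gradient checkpointing is enabled on both the base model and the "
            "LoRA layers; enable it on only one of them."
        )


# Stand-in objects so the sketch runs without loading a real model.
fake_model = SimpleNamespace(is_gradient_checkpointing=True)
fake_lora = SimpleNamespace(gradient_checkpointing=False)
assert_single_checkpointing(fake_model, fake_lora)  # passes: only one side enabled
```

Calling the check once, right before the training loop, surfaces the conflict at startup instead of at the first `loss.backward()`.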
Causes & fixes
Cause: LoRA gradient checkpointing is enabled, but the base model does not support it properly.
Fix: Disable gradient checkpointing on the LoRA side (for example `lora_enable_gradient_checkpointing=False`, or whichever flag your LoRA implementation exposes), or update the base model to a version that supports gradient checkpointing with LoRA.
Cause: Both the base model and the LoRA layers have gradient checkpointing enabled, which causes a conflict.
Fix: Enable gradient checkpointing on either the base model or the LoRA layers, not both. If the base model already uses it, disable it on the LoRA layers.
Cause: An outdated or incompatible version of PEFT or bitsandbytes mishandles gradient checkpointing with LoRA.
Fix: Upgrade PEFT and bitsandbytes to the latest versions, which include fixes for gradient checkpointing support with LoRA.
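To make the version check concrete, here is a minimal sketch that compares installed package versions against minimums using only the standard library. The helper names and the minimum versions in `MINIMUMS` are placeholders invented for illustration, not versions known to contain a fix. The version parsing is deliberately simplified: it keeps only the leading integers and ignores pre-release ordering.

```python
from importlib import metadata


def parse_version(text):
    """Turn '0.11.1' or '0.43.0.dev0' into a tuple of leading integers."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break  # stop at the first non-numeric component, e.g. 'dev0'
        parts.append(int(digits))
    return tuple(parts)


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def outdated_packages(minimums):
    """Return {package: installed_version} for packages below their minimum."""
    stale = {}
    for package, minimum in minimums.items():
        found = installed_version(package)
        if found is not None and parse_version(found) < parse_version(minimum):
            stale[package] = found
    return stale


# Placeholder minimums; replace with the versions your project has validated.
MINIMUMS = {"peft": "0.0.0", "bitsandbytes": "0.0.0"}
print(outdated_packages(MINIMUMS))
```

Packages that are not installed are skipped rather than reported, so the check can run in environments that do not use bitsandbytes.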
Code: broken vs fixed
# Broken: gradient checkpointing is also requested on the LoRA side.
# base_model, input_ids, and labels are assumed to be defined earlier.
# Note: `gradient_checkpointing` is the flag name used on this page. It is not
# among the LoraConfig fields documented by Hugging Face PEFT, so verify it
# against your installed version.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    gradient_checkpointing=True,  # This causes the conflict
)
model = get_peft_model(base_model, lora_config)

# Training step that triggers the RuntimeError
loss = model(input_ids, labels=labels).loss
loss.backward()  # RuntimeError here due to gradient checkpointing conflict

# Fixed: gradient checkpointing is disabled on the LoRA side.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    gradient_checkpointing=False,  # Disabled to fix conflict
)
model = get_peft_model(base_model, lora_config)

loss = model(input_ids, labels=labels).loss
loss.backward()  # Works without RuntimeError
print("Training step completed successfully.")
Workaround
Wrap the backward call in a try/except for RuntimeError. If the message matches the gradient checkpointing conflict, disable LoRA gradient checkpointing dynamically, or fall back to training without checkpointing (at a higher memory cost), then rerun the forward pass before retrying the backward pass.
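The retry logic might look like the sketch below. `step_with_fallback` and `is_checkpointing_conflict` are hypothetical helpers, and the message match is a heuristic based on the error text shown on this page. In a real setup, `run_step` would perform the forward and backward pass, and `disable_checkpointing` might call the base model's `gradient_checkpointing_disable()` method from Hugging Face Transformers or flip the LoRA-side flag. A stand-in step is used here so the sketch runs without a model.

```python
def is_checkpointing_conflict(error):
    """Heuristic: match on the error message text."""
    return "gradient checkpointing" in str(error).lower()


def step_with_fallback(run_step, disable_checkpointing):
    """Run one training step; on a checkpointing conflict, disable and retry once.

    run_step: callable doing forward + backward, returning the loss value.
    disable_checkpointing: callable that turns checkpointing off.
    """
    try:
        return run_step()
    except RuntimeError as error:
        if not is_checkpointing_conflict(error):
            raise  # unrelated failure: do not hide it
        disable_checkpointing()
        # The forward pass has to be redone, so the whole step is retried.
        return run_step()


# Stand-in step so the sketch runs without a model.
state = {"checkpointing": True}


def fake_step():
    if state["checkpointing"]:
        raise RuntimeError("You cannot use gradient checkpointing with LoRA layers")
    return 0.5


def fake_disable():
    state["checkpointing"] = False


loss_value = step_with_fallback(fake_step, fake_disable)
```

Retrying only once keeps a persistent failure visible: if the second attempt also raises, the exception propagates to the caller.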
Prevention
Standardize on a single gradient checkpointing strategy by enabling it only on the base model or LoRA layers, and always test compatibility of your PEFT and bitsandbytes versions before training.