Checkpoint Management
Why this matters
Fine-tuning large models generates gigabytes of checkpoints. Without proper management, you'll run out of disk space mid-training, lose recovery points, or resume from a stale checkpoint. This directly impacts production training pipelines.
Explanation
Checkpoint management in modern fine-tuning involves controlling what SFTTrainer saves (via save_strategy, save_steps, save_total_limit), what it loads (via resume_from_checkpoint), and how it selects the best checkpoint (via load_best_model_at_end and metric_for_best_model).
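Mechanically, SFTTrainer inherits from Hugging Face's Trainer. At every save point (set by save_strategy and save_steps) it writes a checkpoint-<step> directory under output_dir. If evaluation is enabled (eval_strategy != 'no') and a metric_for_best_model is in effect, it also compares the latest metric against the best value seen so far and records the winning checkpoint in trainer.state.best_model_checkpoint. save_total_limit rotates by age, not by score: with save_total_limit=3, the oldest checkpoints are deleted so that at most 3 remain, and when load_best_model_at_end=True the best checkpoint is protected from deletion and kept alongside the most recent ones. When resuming, the trainer restores model weights, optimizer and scheduler state, and the training step counter from the checkpoint directory, so training continues as closely as possible to an uninterrupted run.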
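The rotation rule can be sketched in a few lines of plain Python. This is a simplified model of the documented behavior, not Trainer's actual implementation: the function names rotate_checkpoints and simulate_run are hypothetical, and the eval losses are made-up numbers for illustration.

```python
from typing import Dict, List, Optional


def rotate_checkpoints(
    steps_on_disk: List[int],
    save_total_limit: int,
    best_step: Optional[int] = None,
) -> List[int]:
    """Return the checkpoint steps that survive one rotation pass.

    Oldest checkpoints are dropped first. If best_step is given (modelling
    load_best_model_at_end=True), that checkpoint is never dropped.
    """
    ordered = sorted(steps_on_disk)
    limit = save_total_limit
    if best_step is not None and best_step in ordered:
        # With a limit of 1, keep both the newest and the best if they differ.
        if limit == 1 and best_step != ordered[-1]:
            limit = 2
        # Move the best checkpoint next to the newest so it is not deleted.
        ordered.remove(best_step)
        ordered.insert(max(len(ordered) - 1, 0), best_step)
    excess = max(len(ordered) - limit, 0)
    return sorted(ordered[excess:])


def simulate_run(
    eval_loss_by_step: Dict[int, float],
    save_total_limit: int,
    protect_best: bool,
) -> List[int]:
    """Save a checkpoint at every listed step and rotate after each save."""
    on_disk: List[int] = []
    best_step: Optional[int] = None
    for step, loss in eval_loss_by_step.items():
        # Lower eval loss is better; track the best step seen so far.
        if best_step is None or loss < eval_loss_by_step[best_step]:
            best_step = step
        on_disk.append(step)
        on_disk = rotate_checkpoints(
            on_disk, save_total_limit, best_step if protect_best else None
        )
    return on_disk


# Made-up eval losses: the model is best at step 200, then slowly overfits.
losses = {100: 2.9, 200: 2.1, 300: 2.3, 400: 2.4, 500: 2.6, 600: 2.8}
print(simulate_run(losses, 3, protect_best=False))  # [400, 500, 600]
print(simulate_run(losses, 3, protect_best=True))   # [200, 500, 600]
```

In the first call the best checkpoint (step 200) is gone by the end of the run; in the second it takes one of the three slots and only the two newest checkpoints accompany it.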
Use checkpoint management when a training run lasts longer than about 12 hours, when you run on shared infrastructure where interruption is likely, or when you iterate on hyperparameters and need to recover mid-run. Always set save_total_limit on large models to prevent storage bloat.
Analogy
Think of checkpoints as git commits: <code>save_strategy</code> decides when to commit (every N steps or every epoch), <code>save_total_limit</code> prunes the history so that only the most recent commits survive (plus the best one when <code>load_best_model_at_end=True</code>), and <code>resume_from_checkpoint</code> checks out a specific commit so you can continue work from there.
Code
import os

from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# GPT-2 ships without a pad token, so reuse the end-of-sequence token for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Tiny toy datasets: 4 training examples and 2 evaluation examples.
train_texts = [
    'The quick brown fox jumps over the lazy dog',
    'Machine learning requires large datasets',
    'Checkpoints enable training recovery',
    'Fine-tuning adapts pretrained models',
]
eval_texts = [
    'Fine-tuning is useful for domain adaptation',
    'Checkpoints save training state',
]


def build_dataset(texts):
    """Wrap a list of strings in a Dataset with a single 'text' column."""
    data = {'text': texts}
    return Dataset.from_dict(data)


train_dataset = build_dataset(train_texts)
eval_dataset = build_dataset(eval_texts)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias='none',
    task_type='CAUSAL_LM',
)
output_dir = './checkpoint_demo'
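training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=2,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Evaluate and save on the same schedule so every checkpoint has a metric.
    eval_strategy='steps',
    eval_steps=2,
    save_strategy='steps',
    save_steps=2,
    # Keep at most 2 checkpoints on disk; the oldest is deleted first.
    save_total_limit=2,
    # Track the best checkpoint by eval_loss (lower is better) and reload it at the end.
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    logging_steps=1,
    logging_dir='./logs',
    report_to=[],
    seed=42,
    bf16=False,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
    tokenizer=tokenizer,  # newer TRL releases name this argument processing_class
)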
print('Starting initial training...')
trainer.train()
checkpoint_dirs = [d for d in os.listdir(output_dir) if d.startswith('checkpoint-')]
checkpoint_dirs.sort(key=lambda x: int(x.split('-')[1]))
print(f'\nCheckpoints saved: {checkpoint_dirs}')
print(f'Total checkpoints (limited by save_total_limit=2): {len(checkpoint_dirs)}')
best_checkpoint = trainer.state.best_model_checkpoint
print(f'Best checkpoint by eval_loss: {best_checkpoint}')
if best_checkpoint:
    print(f'\nResuming from best checkpoint: {best_checkpoint}')
    resume_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        eval_strategy='steps',
        eval_steps=2,
        save_strategy='steps',
        save_steps=2,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model='eval_loss',
        greater_is_better=False,
        logging_steps=1,
        report_to=[],
        seed=42,
        bf16=False,
    )
    # Building a new trainer does not resume anything by itself: its state
    # starts empty, so global_step is 0 until train() loads a checkpoint.
    resume_trainer = SFTTrainer(
        model=model,
        args=resume_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=lora_config,
        tokenizer=tokenizer,
    )
    print(f'Training resumed from step {resume_trainer.state.global_step}')
    print('Note: In this demo, global_step resets because we create a new trainer instance.')
    print('In production, pass resume_from_checkpoint to trainer.train(resume_from_checkpoint=best_checkpoint)')
print(f'\nFinal best metric: {trainer.state.best_metric}')
print(f'Training completed. Model saved to: {output_dir}')
Output
Starting initial training...
Loading best model from /root/checkpoint_demo/checkpoint-2 (score: 3.197845458984375).
Checkpoints saved: ['checkpoint-2', 'checkpoint-4']
Total checkpoints (limited by save_total_limit=2): 2
Best checkpoint by eval_loss: /root/checkpoint_demo/checkpoint-2
Resuming from best checkpoint: /root/checkpoint_demo/checkpoint-2
Training resumed from step 0
Note: In this demo, global_step resets because we create a new trainer instance.
In production, pass resume_from_checkpoint to trainer.train(resume_from_checkpoint=best_checkpoint)
Final best metric: 3.197845458984375
Training completed. Model saved to: ./checkpoint_demo
What just happened?
The code trained a model with <code>save_steps=2</code> (save every 2 steps) and <code>save_total_limit=2</code> (keep at most 2 checkpoints on disk). After 4 training steps, two checkpoints had been written, checkpoint-2 and checkpoint-4, which is exactly the limit, so nothing had to be deleted; in a longer run the oldest checkpoints would be removed first. Because <code>load_best_model_at_end=True</code>, the trainer tracked eval_loss, identified checkpoint-2 as best, stored its path in <code>trainer.state.best_model_checkpoint</code>, and reloaded its weights when training finished. The second trainer is built but never trained: its <code>state.global_step</code> is 0 because a fresh trainer starts with empty state. The demo therefore shows the checkpoint directory structure and the best-model metadata, but it does not perform a true resumption. That requires passing a checkpoint path as <code>resume_from_checkpoint</code> to <code>train()</code>, which restores the weights, optimizer state, and step counter.
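To make the resumption step concrete, here is a minimal sketch of locating the newest checkpoint before calling train(). The helper find_latest_checkpoint is hypothetical and written for illustration; transformers also provides get_last_checkpoint in transformers.trainer_utils for the same purpose. The sketch assumes checkpoints follow the checkpoint-<step> naming seen in the output above.

```python
import os
import re
import tempfile
from typing import Optional

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered checkpoint-<step> directory, or None."""
    if not os.path.isdir(output_dir):
        return None
    candidates = []
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            # Compare steps as integers so checkpoint-10 ranks above checkpoint-4.
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    return os.path.join(output_dir, max(candidates)[1])


if __name__ == "__main__":
    # Self-contained demo using empty directories in a temporary folder.
    with tempfile.TemporaryDirectory() as tmp:
        for step in (2, 4, 10):
            os.makedirs(os.path.join(tmp, f"checkpoint-{step}"))
        print(find_latest_checkpoint(tmp))  # path ending in checkpoint-10

# With the trainer from the demo (not executed here):
#   latest = find_latest_checkpoint(output_dir)
#   trainer.train(resume_from_checkpoint=latest)  # None starts from scratch
```

Resuming from the newest checkpoint preserves the most training progress; the best checkpoint is the one to pick when choosing a model to evaluate or deploy.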
Common gotcha
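Many developers assume that save_total_limit keeps the best checkpoints. It does not: it always deletes the oldest checkpoints first and keeps the most recent ones. The best checkpoint is protected only when load_best_model_at_end=True, which in turn needs evaluation enabled (eval_strategy set to something other than 'no') on a schedule that lines up with saving. Without that, the trainer never ranks checkpoints, and if the model started overfitting late in the run, the best checkpoint may already have been deleted, leaving you with only the newest ones.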
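A small pre-flight check can catch this before a long run starts. The sketch below works on a plain dict of arguments; check_checkpoint_config is a hypothetical helper, and its rules mirror constraints that Trainer has enforced for load_best_model_at_end (matching save and eval strategies, save_steps a multiple of eval_steps). Exact checks may differ between transformers versions, so treat it as an assumption-laden illustration.

```python
from typing import Any, Dict, List


def check_checkpoint_config(cfg: Dict[str, Any]) -> List[str]:
    """Return human-readable problems found in a dict of training arguments."""
    problems: List[str] = []
    # Defaults below follow TrainingArguments: no evaluation, save every 500 steps.
    eval_strategy = cfg.get("eval_strategy", "no")
    save_strategy = cfg.get("save_strategy", "steps")

    if cfg.get("load_best_model_at_end", False):
        if eval_strategy == "no":
            problems.append("load_best_model_at_end needs eval_strategy != 'no'")
        elif eval_strategy != save_strategy:
            problems.append("save_strategy and eval_strategy should match")
        elif save_strategy == "steps":
            save_steps = cfg.get("save_steps", 500)
            eval_steps = cfg.get("eval_steps")
            if eval_steps and save_steps % eval_steps != 0:
                problems.append("save_steps should be a multiple of eval_steps")
    elif cfg.get("save_total_limit") is not None:
        # Not an error, but the best checkpoint is not protected from rotation.
        problems.append(
            "save_total_limit without load_best_model_at_end keeps only the "
            "most recent checkpoints; the best one may be deleted"
        )
    return problems


demo_cfg = {
    "eval_strategy": "steps",
    "eval_steps": 2,
    "save_strategy": "steps",
    "save_steps": 2,
    "save_total_limit": 2,
    "load_best_model_at_end": True,
}
print(check_checkpoint_config(demo_cfg))                 # []
print(check_checkpoint_config({"save_total_limit": 3}))  # one warning
```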
Error recovery
FileNotFoundError when resuming
RuntimeError: 'step' not in checkpoint during resume
OutOfMemoryError from too many checkpoints
Models differ after resuming from checkpoint
Experienced dev note
In production fine-tuning on cloud infrastructure, always set save_total_limit to prevent runaway storage costs: the weights of a 70B-parameter model alone take about 140 GB in 16-bit precision, and a full fine-tuning checkpoint that also stores Adam-style optimizer state is several times larger. Setting load_best_model_at_end=True with metric_for_best_model='eval_loss' is strongly recommended for production because the final checkpoint is often not the best one; without it you may deploy a worse model. One hidden trap: if you set save_strategy='epoch' but num_train_epochs=1, you get only one checkpoint at the end, with no recovery points for mid-epoch failures. Use save_strategy='steps' instead.
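The storage arithmetic behind that figure is easy to reproduce. This is a back-of-the-envelope sketch with hypothetical helper names; it counts model weights only, and the "limit plus one" peak assumes a new checkpoint is written before the oldest one is deleted.

```python
def weights_size_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Size of the model weights alone, in decimal gigabytes (2 bytes = 16-bit)."""
    return num_params * bytes_per_param / 1e9


def checkpoints_without_limit(total_steps: int, save_steps: int) -> int:
    """Number of checkpoints written if nothing is ever deleted."""
    return total_steps // save_steps


def peak_disk_gb(checkpoint_gb: float, save_total_limit: int) -> float:
    """Rough peak usage: the limit plus one checkpoint in flight (assumption)."""
    return checkpoint_gb * (save_total_limit + 1)


per_checkpoint = weights_size_gb(70e9)            # 70B parameters at 16-bit
unbounded = checkpoints_without_limit(10_000, 500)

print(per_checkpoint)                             # 140.0
print(unbounded * per_checkpoint)                 # 2800.0 GB with no limit
print(peak_disk_gb(per_checkpoint, 3))            # 560.0 GB with save_total_limit=3
```

With adapter-based fine-tuning such as LoRA, checkpoints typically hold only the adapter weights and are far smaller, so the numbers above apply to full-model checkpoints.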
Check your understanding
You train a 13B model for 10,000 steps with save_steps=500, save_total_limit=3, and eval_steps=500. After step 3000, your training job crashes. How many checkpoints exist on disk, and which one should you resume from? Explain why.
Show answer hint
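A correct answer explains: (1) six checkpoints were written (steps 500 through 3000), but <code>save_total_limit=3</code> deletes the oldest first, so three remain: checkpoint-2000, checkpoint-2500, and checkpoint-3000; (2) you resume from the latest one, checkpoint-3000, by passing its path (or <code>True</code>, which picks the last checkpoint in the output directory) as <code>resume_from_checkpoint</code> to <code>train()</code>, because it holds the most recent weights, optimizer state, and step counter; and (3) <code>trainer.state.best_model_checkpoint</code>, selected by the metric you set in <code>metric_for_best_model</code>, is the checkpoint to deploy, not the one to resume from, since resuming from an older checkpoint throws away training progress. If <code>load_best_model_at_end=True</code> were set and the best checkpoint were older than step 2000, it would be kept in place of checkpoint-2000.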
Pass the checkpoint path through the resume_from_checkpoint parameter of trainer.train(). Always pass it as a parameter to train(), not as part of SFTConfig.