Mixed precision training: bfloat16 vs float16 in practice
Why this matters
Mixed precision training can cut GPU memory use by up to about half (mostly activation memory, since a float32 copy of the weights is usually kept) and often speeds up training 1.5–2x on modern hardware. But choosing the wrong precision, or forgetting gradient scaling with float16, causes loss spikes or silent numerical errors that ruin your fine-tune. You need to know which format to use on your hardware.
Explanation
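What it is: Mixed precision training runs most forward and backward computation in a 16-bit format (bfloat16 or float16) to save memory and time, but keeps the weights and the weight update in float32 to maintain numerical stability. bfloat16 (brain floating point) uses 16 bits with 8 exponent bits and 7 mantissa bits, so it covers roughly the same range as float32 with less precision. float16 uses 16 bits with 5 exponent bits and 10 mantissa bits, so it is more precise but has a narrower range: values above 65504 overflow and values below about 6e-8 underflow to zero.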
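To make the range versus precision trade-off concrete, here is a minimal standard-library sketch that rounds a Python float to each format and back. The helper names `to_float16` and `to_bfloat16` are hypothetical and exist only for this illustration; real training code would rely on the framework's dtypes instead.

```python
import struct

def to_float16(x: float) -> float:
    """Round x to IEEE float16 (5 exponent bits, 10 mantissa bits) and back."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

def to_bfloat16(x: float) -> float:
    """Round x to bfloat16 (8 exponent bits, 7 mantissa bits) and back.

    bfloat16 is the top 16 bits of the float32 encoding, so we round away
    the low 16 bits (round-to-nearest-even) and then zero them.
    """
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack('<f', struct.pack('<I', bits))[0]

# Precision: float16 keeps more digits near 1.0 than bfloat16 does.
print(to_float16(1.001), to_bfloat16(1.001))

# Range (small values): 1e-8 underflows to zero in float16 but survives in bfloat16.
print(to_float16(1e-8), to_bfloat16(1e-8))

# Range (large values): 70000.0 exceeds float16's maximum of 65504.
# struct raises OverflowError here; on a GPU the value would become inf instead.
try:
    to_float16(70000.0)
except OverflowError:
    print('70000.0 does not fit in float16')
print(to_bfloat16(70000.0))
```

The same value is kept coarsely by bfloat16 and lost entirely by float16 once it leaves float16's range, which is the behavior that matters for gradients.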
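How it works mechanically: During the forward pass, matrix multiplications and similar operations run in reduced precision (bfloat16 or float16), making computation and activation memory cheaper, while precision-sensitive operations such as reductions and the loss stay in float32. With float16, the loss is scaled up (multiplied by a large number like 65536) before the backward pass, so every gradient is scaled by the same factor. The gradients are then divided by that factor again before the float32 weight update. This scaling prevents gradient underflow: numbers too small to represent in float16 would otherwise become zero. bfloat16 has roughly the same exponent range as float32, so it normally does not need loss scaling. Modern frameworks (PyTorch automatic mixed precision in torch.amp, previously torch.cuda.amp, and the Hugging Face Trainer) handle this automatically.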
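The underflow problem and the effect of scaling can be simulated without a GPU. This is a sketch under stated assumptions: the gradient value `1e-8` is made up for illustration, `to_float16` is a hypothetical helper, and the scale of 65536 is the one quoted above. It is not how a framework implements scaling internally.

```python
import struct

def to_float16(x: float) -> float:
    """Round x to float16 and back (hypothetical helper for this demo)."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

LOSS_SCALE = 65536.0  # 2**16
true_grad = 1e-8      # illustrative value, below float16's smallest subnormal (~6e-8)

# Without scaling, the gradient is flushed to zero when stored in float16.
unscaled = to_float16(true_grad)

# Scaling the loss scales every gradient by the same factor (chain rule),
# which lifts the value into float16's representable range.
scaled_fp16 = to_float16(true_grad * LOSS_SCALE)

# Unscale in higher precision before the optimizer step.
recovered = scaled_fp16 / LOSS_SCALE

print(unscaled)   # 0.0
print(recovered)  # close to 1e-8
```

In PyTorch the same idea is wrapped by a gradient scaler object: `scaler.scale(loss).backward()`, then `scaler.step(optimizer)` and `scaler.update()`, which also lowers the scale when it detects inf or NaN gradients.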
When to use which: Use bfloat16 on NVIDIA GPUs from the Ampere generation onward (A100, H100, RTX 30- and 40-series) and on recent AMD GPUs that support it: it's more numerically stable for LLM training and has native hardware support. Use float16 on older hardware (V100, T4, RTX 20-series) that lacks native bfloat16 support. bfloat16 forgives larger gradient swings; float16 requires loss scaling and more careful tuning.
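One way to encode this decision is a small helper that turns two hardware facts into the `bf16`/`fp16` flags. The function name `pick_precision_flags` is hypothetical; it is a sketch of the rule of thumb above, not part of any library.

```python
def pick_precision_flags(cuda_available: bool, bf16_supported: bool) -> dict:
    """Return bf16/fp16 flags following the rule of thumb above.

    - no GPU: plain float32 (both flags off)
    - GPU with native bfloat16: bf16
    - older GPU: fp16 (the trainer then applies loss scaling)
    """
    if not cuda_available:
        return {'bf16': False, 'fp16': False}
    if bf16_supported:
        return {'bf16': True, 'fp16': False}
    return {'bf16': False, 'fp16': True}

print(pick_precision_flags(cuda_available=True, bf16_supported=True))
print(pick_precision_flags(cuda_available=True, bf16_supported=False))
print(pick_precision_flags(cuda_available=False, bf16_supported=False))
```

With PyTorch installed, the two inputs could come from `torch.cuda.is_available()` and `torch.cuda.is_bf16_supported()`, and the resulting dict could be unpacked into the training config. Treat the support check as a first approximation, since what it reports can differ between PyTorch versions.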
Analogy
Think of float16 as a high-resolution zoom lens with a narrow field of view: it captures fine detail but numbers outside its range vanish. bfloat16 is a wider-angle lens with slightly less detail but you can see the whole scene. For LLM gradients (which swing wildly), the wider angle prevents clipping.
Code
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# No torch_dtype is passed, so the weights are loaded in float32
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

# SFTTrainer expects a datasets.Dataset rather than a plain list of dicts
train_dataset = Dataset.from_list([
    {'text': 'The quick brown fox jumps over the lazy dog. This is a test sentence.'},
    {'text': 'Machine learning is a subset of artificial intelligence.'},
    {'text': 'Fine-tuning improves model performance on specific tasks.'},
    {'text': 'Mixed precision training reduces memory consumption significantly.'},
])

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=['c_attn', 'c_proj'],
    lora_dropout=0.1,
    bias='none',
)

# Note: some argument names below (max_seq_length, tokenizer, dataset_text_field)
# differ between TRL releases; adjust them to the version you have installed.
sft_config_bfloat16 = SFTConfig(
    output_dir='./fine_tuned_bfloat16',
    num_train_epochs=1,
    per_device_train_batch_size=2,
    bf16=True,   # bfloat16 autocast, no loss scaling
    fp16=False,  # never enable both flags at once
    max_seq_length=128,
    logging_steps=1,
    save_steps=100,
)

trainer_bfloat16 = SFTTrainer(
    model=model,
    args=sft_config_bfloat16,
    train_dataset=train_dataset,
    peft_config=lora_config,
    tokenizer=tokenizer,
    dataset_text_field='text',
)

print('Training with bfloat16...')
trainer_bfloat16.train()

# PEFT wraps the targeted layers, so parameter names can change; take the first
# parameter of the wrapped model instead of hard-coding a state_dict key.
weight_sample = next(trainer_bfloat16.model.parameters())
print(f'Model weights dtype after bfloat16 training: {weight_sample.dtype}')
# The active mixed precision mode lives on the trainer's Accelerator object
print(f'Trainer mixed_precision setting: {trainer_bfloat16.accelerator.mixed_precision}')
# Loss scaling is only used for float16, so the fp16 flag is a reasonable proxy
print(f'Gradient scaling enabled: {trainer_bfloat16.args.fp16}')

Output:
Training with bfloat16...
[1/2 00:00 < 00:00, 4.12 it/s]
Model weights dtype after bfloat16 training: torch.float32
Trainer mixed_precision setting: bf16
Gradient scaling enabled: False
What just happened?
The code loaded a small model (gpt2) and configured SFTTrainer with bf16=True. The trainer automatically enabled bfloat16 mixed precision training. During training, the forward and backward passes ran their matrix math in bfloat16, but the weights stayed in float32, because the model was loaded in float32 and autocast only casts values per operation. The output shows that the mixed precision mode is 'bf16' and fp16 is disabled. Weight dtypes remain float32 because that is the dtype the optimizer updates: bfloat16 is only used in the compute graph, not for weight storage.
Common gotcha
Developers often think bf16=True means 'my weights are now bfloat16': they're not. With a model loaded in float32, the weights and optimizer state stay in float32. What changes is the forward/backward compute dtype. (Loading the model itself in bfloat16, for example through torch_dtype, is a separate choice.) If you set both bf16=True and fp16=True, the training arguments raise an error about conflicting precision settings before training starts. Also, gradient accumulation with mixed precision can in some setups magnify numerical errors: if loss becomes NaN, reducing accumulation steps is one thing to try.
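The split between storage dtype and compute dtype can be checked directly with PyTorch's autocast context. This is a minimal sketch that runs on CPU so no GPU is needed; it assumes a PyTorch build with CPU bfloat16 autocast, and the tiny `Linear` layer stands in for a real model.

```python
import torch

layer = torch.nn.Linear(4, 2)  # parameters are created in float32
x = torch.randn(3, 4)

# Inside autocast, the matmul runs on bfloat16 copies of the inputs and weights
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    y = layer(x)

# Backward pass: gradients flow back through the cast to the float32 parameters
y.float().sum().backward()

print(layer.weight.dtype)       # storage dtype of the weights
print(y.dtype)                  # compute dtype of the activation
print(layer.weight.grad.dtype)  # dtype of the gradient the optimizer would see
```

The weights and their gradients report float32 while the activation reports bfloat16, which is the same pattern the trainer output above shows.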
Error recovery
RuntimeError: expected scalar type Float but found BFloat16
Loss becomes NaN during training
TypeError: 'NoneType' object is not subscriptable with bf16=True
Experienced dev note
The moment you scale to real-world models (7B+), mixed precision stops being optional: it becomes required for 24GB VRAM constraints. But here's what matters: bfloat16 is the safe default on Ampere-generation or newer GPUs (A100, RTX 30-series and later). If you're on older hardware and training breaks mysteriously, it's often because float16 gradient scaling wasn't tuned. Save yourself 3 hours of debugging: start with bfloat16, monitor loss for spikes in the first 100 steps, and only fall back to float32 if you see NaNs. Also, gradient clipping (max_grad_norm) becomes more important with mixed precision: the Trainer default is 1.0, so leave it there unless you have a reason to change it.
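Monitoring the first steps for spikes can be automated with a few lines. The sketch below is illustrative only: `find_loss_spikes`, its `window` and `factor` parameters, and the sample loss values are all hypothetical choices, not a standard API or real training data.

```python
import math

def find_loss_spikes(losses, window=10, factor=3.0):
    """Return the step indices where the loss is NaN/inf, or exceeds
    `factor` times the mean of the previous `window` finite losses."""
    spikes = []
    for step, loss in enumerate(losses):
        if not math.isfinite(loss):
            spikes.append(step)
            continue
        recent = [v for v in losses[max(0, step - window):step] if math.isfinite(v)]
        if recent and loss > factor * (sum(recent) / len(recent)):
            spikes.append(step)
    return spikes

# Made-up loss history: a spike at step 4 and a NaN at step 6
history = [2.5, 2.4, 2.3, 2.2, 9.0, 2.1, float('nan')]
print(find_loss_spikes(history))  # [4, 6]
```

With `logging_steps=1`, the per-step losses could be collected from the trainer's logged history and passed to a check like this after the first 100 steps.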
Check your understanding
You train with bf16=True and notice your training loss is stable for 500 steps, then suddenly spikes to NaN in step 501. Your weights are already in float32 and learning rate is 5e-5. Why did this happen, and which single hyperparameter change is most likely to fix it without restarting? (Assume no obvious data issues.)
Answer hint
A correct answer recognizes that gradient accumulation or a data anomaly caused an extreme gradient value that mixed precision couldn't handle. The fix is a lower max_grad_norm (tighter gradient clipping) or a learning rate reduction, not precision switching, since bfloat16 is already set. Bonus insight: if this happens at a specific step, check whether a long sequence or outlier example appeared in the data at that batch.
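For intuition about what gradient clipping does to an extreme gradient, here is a plain-Python sketch of clipping by global L2 norm. The function `clip_by_global_norm` is a hypothetical simplification of the kind of clipping that max_grad_norm controls; library implementations operate on tensors and differ in details.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a flat list of gradient values so their L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm

# An outlier batch produces a gradient with norm 5.0; clipping shrinks it to norm 1.0
clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm_before)  # 5.0
print(clipped)      # roughly [0.6, 0.8], same direction, smaller step
```

The direction of the update is preserved and only its length is capped, which is why clipping can absorb a single outlier batch without changing the precision setting.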