Distributed training for full model fine-tuning
Why this matters
Full model fine-tuning on enterprise datasets usually requires distributing training across multiple GPUs: on a single GPU the job either runs out of memory or takes weeks. Knowing when and how to set up distributed training correctly prevents costly failed runs and speeds up iteration on large models.
Explanation
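Full model fine-tuning at scale means updating every parameter of an LLM, typically across multiple GPUs or machines. Unlike LoRA, which trains only small adapter weights, full fine-tuning must store gradients and optimizer states for the entire model alongside the weights, several times the memory of the weights alone. This memory explosion is why distribution becomes necessary: with Adam in mixed precision, even a 7B model's training state outgrows a single 80 GB GPU.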
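A rough sketch of where that memory goes, assuming the common mixed-precision Adam setup: 16-bit weights and gradients plus fp32 master weights and two fp32 Adam moment buffers, about 16 bytes per parameter. Activations, temporary buffers and framework overhead are deliberately left out, so real usage is higher. `training_state_gb` is a hypothetical helper, not a library API.

```python
# Rule-of-thumb training-state memory for full fine-tuning with Adam in mixed precision.
# Assumption: 16-bit weights and grads, fp32 master weights and fp32 Adam moments.
# Activations, temporary buffers and framework overhead are NOT included.
BYTES_PER_PARAM = {
    "weights_16bit": 2,
    "grads_16bit": 2,
    "master_weights_fp32": 4,
    "adam_exp_avg_fp32": 4,     # first moment (m)
    "adam_exp_avg_sq_fp32": 4,  # second moment (v)
}


def training_state_gb(num_params: float) -> dict[str, float]:
    """Approximate memory per component and in total, in GB (1 GB = 1e9 bytes)."""
    breakdown = {name: num_params * nbytes / 1e9 for name, nbytes in BYTES_PER_PARAM.items()}
    breakdown["total"] = sum(breakdown.values())
    return breakdown


if __name__ == "__main__":
    for billions in (1, 7, 13, 70):
        total = training_state_gb(billions * 1e9)["total"]
        print(f"{billions:>3}B params -> ~{total:,.0f} GB of weights, grads and optimizer state")
```

At 16 bytes per parameter, a 7B model needs about 112 GB and a 13B model about 208 GB before any activations, which is why a single 80 GB GPU is not enough for full fine-tuning at these sizes.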
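Mechanically, PyTorch's DistributedDataParallel (DDP) keeps a full replica of the model on each GPU, feeds each replica a different slice of the data, computes gradients locally, and then averages them across replicas with all-reduce operations (typically over NCCL on NVIDIA GPUs), so every replica applies the same update. The Hugging Face Trainer wraps this: when the script is started by a launcher such as torchrun or accelerate launch, which sets RANK, WORLD_SIZE and related environment variables, Trainer detects the multi-process environment and handles process-group setup, DDP wrapping, per-process data sharding, and gradient accumulation (plus loss scaling if you enable fp16). per_device_train_batch_size is then the batch size on each GPU, not the global batch.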
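To see why averaging gradients keeps the replicas equivalent to single-device training, here is a pure-Python sketch with no GPUs and no PyTorch. It assumes a mean-reduced loss (a 1-D least-squares fit) and equal-sized shards per rank; `all_reduce_mean` is a hypothetical stand-in for DDP's all-reduce, which sums gradients across ranks and divides by the world size.

```python
# Pure-Python illustration of DDP's gradient averaging (no GPUs, no torch).
# Assumptions: mean-reduced MSE loss, equal-sized shards, one scalar weight.


def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient of mean((w*x - y)**2) with respect to w over one shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values: list[float]) -> float:
    """Hypothetical stand-in for all-reduce: sum across ranks, divide by world size."""
    return sum(values) / len(values)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.0, 13.8, 16.2]
w = 0.5
world_size = 4

# Each rank gets an equal-sized shard. Contiguous slices keep this readable;
# PyTorch's DistributedSampler instead strides through the (shuffled) indices.
shard = len(xs) // world_size
local_grads = [
    grad_mse(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
    for r in range(world_size)
]
synced = all_reduce_mean(local_grads)  # what every replica holds after the all-reduce
full_batch = grad_mse(w, xs, ys)       # what a single device would compute

print("local grads:", [round(g, 3) for g in local_grads])
print(f"all-reduced: {synced:.6f}  full batch: {full_batch:.6f}")
```

The per-rank gradients differ, but their average equals the full-batch gradient, so every replica applies the identical update. With unequal shard sizes the plain average would weight examples unevenly, which is one reason DistributedSampler pads the dataset so each rank gets the same number of samples.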
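When to use this: DDP speeds up training when the full training state (weights, gradients and optimizer states) fits on one GPU and you have 2+ GPUs and enough data to keep them busy (as a rough heuristic, more than about 10K examples). Because DDP replicates everything, it adds throughput but does not reduce per-GPU memory. On a single GPU there is nothing to synchronize, and for very small workloads the per-step communication can outweigh the speedup. When the training state does not fit on one GPU, which for full fine-tuning happens at a few billion parameters and certainly at 13B or 70B+, use a sharded approach such as PyTorch FSDP or DeepSpeed ZeRO instead of vanilla DDP.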
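A rough sketch of the DDP-versus-sharding decision, assuming about 16 bytes of training state per parameter (mixed-precision Adam) and an 80 GB per-GPU budget; both numbers are illustrative assumptions, and activations, communication buffers and memory fragmentation are ignored, so real headroom is smaller. "Sharded" here models FSDP FULL_SHARD or DeepSpeed ZeRO stage 3, which split weights, gradients and optimizer state evenly across GPUs. `per_gpu_state_gb` and `fits` are hypothetical helpers.

```python
# Rough per-GPU fit check: replicated DDP vs fully sharded training state
# (FSDP FULL_SHARD / DeepSpeed ZeRO-3 style). Assumptions: ~16 bytes/param of
# training state (mixed-precision Adam), perfectly even sharding, and activations,
# communication buffers and fragmentation ignored.
BYTES_PER_PARAM = 16


def per_gpu_state_gb(num_params: float, world_size: int, sharded: bool) -> float:
    """Training-state GB each GPU must hold (1 GB = 1e9 bytes)."""
    total_gb = num_params * BYTES_PER_PARAM / 1e9
    # DDP keeps a full copy on every GPU; full sharding splits it across GPUs.
    return total_gb / world_size if sharded else total_gb


def fits(num_params: float, world_size: int, sharded: bool, gpu_gb: float = 80.0) -> bool:
    """True if the training state alone fits in a hypothetical per-GPU budget."""
    return per_gpu_state_gb(num_params, world_size, sharded) <= gpu_gb


if __name__ == "__main__":
    for params in (1.5e9, 13e9, 70e9):
        ddp = per_gpu_state_gb(params, 8, sharded=False)
        full_shard = per_gpu_state_gb(params, 8, sharded=True)
        print(f"{params / 1e9:>5.1f}B on 8 GPUs: DDP ~{ddp:,.0f} GB/GPU, sharded ~{full_shard:,.0f} GB/GPU")
```

Under these assumptions a 13B model on 8 GPUs needs about 208 GB per GPU with DDP but about 26 GB fully sharded. A 70B model (about 1,120 GB of state) still needs about 140 GB per GPU when sharded across 8, so such runs add nodes or offload optimizer state to CPU.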
Analogy
Imagine a bakery doubling a recipe. Instead of one baker making all 100 loaves (slow, messy), two bakers split the work: Baker A makes 50 loaves and Baker B makes 50, both following the same recipe. They work in parallel, and after every round they compare notes and agree on the same adjustments to the recipe. If they drift out of sync, one baker's loaves come out wrong. Distributed training is that synchronization problem, with NCCL carrying the notes between GPUs.
Code
```python
import os

import torch
import torch.distributed as dist
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


def setup_distributed():
    """Initialize the process group if a launcher (torchrun, accelerate launch) set the env vars."""
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])                     # global rank across all nodes
        world_size = int(os.environ["WORLD_SIZE"])         # total number of processes
        local_rank = int(os.environ.get("LOCAL_RANK", 0))  # GPU index on this node
        if torch.cuda.is_available():
            # Bind by LOCAL_RANK, not RANK: on a second node RANK exceeds the local GPU count.
            torch.cuda.set_device(local_rank)
        if not dist.is_initialized():
            # The default "env://" init reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE.
            dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
        return rank, world_size
    return 0, 1  # single-process fallback for local testing


def main():
    rank, world_size = setup_distributed()

    model_name = "gpt2"
    # No manual .to(device): Trainer moves the model to this process's device and,
    # when world_size > 1, wraps it in DistributedDataParallel.
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

    texts = [
        "The future of AI is distributed training across multiple nodes.",
        "Scaling large language models requires careful synchronization.",
        "Each GPU processes a batch and syncs gradients via all-reduce.",
    ] * 100  # 300 toy examples
    dataset = Dataset.from_dict({"text": texts})

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=128,
            padding="max_length",
        )

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )

    training_args = TrainingArguments(
        output_dir="./distributed_ft_output",
        num_train_epochs=2,
        per_device_train_batch_size=8,  # per GPU; effective batch = 8 * 2 * world_size
        gradient_accumulation_steps=2,
        learning_rate=5e-5,
        logging_steps=10,
        save_steps=50,
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        dataloader_num_workers=4,
        ddp_find_unused_parameters=False,  # GPT-2 uses every parameter in forward
    )

    # mlm=False: causal-LM labels are copies of input_ids, with pad positions set to -100.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    if rank == 0:
        print(f"Starting distributed training on {world_size} processes")
    trainer.train()
    # This short run never reaches save_steps=50, so save explicitly.
    # Call on every rank; Trainer writes files only from the main process.
    trainer.save_model()
    if rank == 0:
        print("Training complete. Model checkpoint saved.")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Example output when run as a single process with plain python (Trainer's progress bar and loss logs omitted):

```text
Starting distributed training on 1 processes
Training complete. Model checkpoint saved.
```
What just happened?
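The script first checked for launcher-provided environment variables. When it is run with plain python none are set, so it skipped process-group setup entirely and trained as a single process, which is the fallback for local testing. It then loaded the small GPT-2 model, tokenized a toy dataset of 300 examples, and ran Trainer. Intermediate checkpoints go to ./distributed_ft_output every save_steps=50 optimizer steps, which this short run never reaches; calling trainer.save_model() after trainer.train() is what persists the final weights there. On a multi-GPU machine, a launcher (torchrun or accelerate launch) sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for each process; Trainer then wraps the model in DistributedDataParallel, gives each process its own shard of the data, and all-reduces gradients over NCCL once per optimizer step (accumulation micro-steps skip the sync), so every replica applies the same update.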
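On a cluster, the launch typically looks like `torchrun --nnodes=2 --nproc_per_node=4 train.py` run on each node (`train.py` is a hypothetical name for the script above; rendezvous settings are omitted). The pure-Python sketch below models how the variables torchrun sets relate to each other, assuming nodes are numbered 0..nnodes-1 and every node runs the same number of processes. It shows why the GPU to bind is LOCAL_RANK rather than RANK: on the second node, RANK runs from 4 to 7 while the GPUs are still cuda:0 through cuda:3. `ProcessInfo`, `torchrun_layout` and `from_env` are hypothetical helpers.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ProcessInfo:
    rank: int        # global index across all nodes: identifies the process
    local_rank: int  # index within its node: which GPU to use (cuda:<local_rank>)
    world_size: int  # total number of processes


def torchrun_layout(nnodes: int, nproc_per_node: int) -> list[ProcessInfo]:
    """Assumed layout: node n hosts global ranks n*nproc_per_node .. n*nproc_per_node + nproc_per_node - 1."""
    world_size = nnodes * nproc_per_node
    return [
        ProcessInfo(rank=node * nproc_per_node + local, local_rank=local, world_size=world_size)
        for node in range(nnodes)
        for local in range(nproc_per_node)
    ]


def from_env(env: dict[str, str]) -> ProcessInfo:
    """Read launcher-set variables; fall back to a single process when they are absent."""
    return ProcessInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


if __name__ == "__main__":
    for p in torchrun_layout(nnodes=2, nproc_per_node=4):
        print(p)
```

In a real script you would call `from_env(dict(os.environ))`; the plain dict parameter keeps the helper easy to test without a launcher.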
Common gotcha
Setting ddp_find_unused_parameters=True is tempting when your model has conditional forward paths, but it makes DDP traverse the autograd graph on every iteration to find parameters that received no gradient, which slows every step. Instead, refactor the model so every parameter participates in the forward pass, or keep conditionally used modules in a separate wrapper outside DDP. Also, if the per-device batch is small, each step does little computation relative to the fixed cost of synchronizing gradients, so communication dominates: use gradient accumulation so each GPU processes several local batches before the gradients are synchronized.
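The arithmetic behind that advice, as a sketch: each optimizer step costs one gradient synchronization when accumulation micro-steps skip the all-reduce, which is what DDP's `no_sync()` context manager is for and what Trainer does on accumulation micro-steps. `ddp_schedule` is a hypothetical helper; it drops the last partial batch and counts each synchronization once, even though DDP actually issues several bucketed all-reduce calls per sync.

```python
def ddp_schedule(num_examples: int, per_device_batch: int, accum_steps: int, world_size: int) -> dict[str, int]:
    """Per-epoch step and gradient-sync counts for data-parallel training (last partial batch dropped)."""
    effective_batch = per_device_batch * accum_steps * world_size
    micro_batches_per_gpu = num_examples // (per_device_batch * world_size)
    optimizer_steps = micro_batches_per_gpu // accum_steps
    return {
        "effective_batch": effective_batch,
        "micro_batches_per_gpu": micro_batches_per_gpu,
        "optimizer_steps": optimizer_steps,
        # One sync per optimizer step when accumulation micro-steps run under no_sync().
        "syncs_with_no_sync": optimizer_steps,
        # One sync per backward pass if every micro-batch were synchronized.
        "syncs_every_backward": micro_batches_per_gpu,
    }


if __name__ == "__main__":
    for accum in (1, 4):
        print(accum, ddp_schedule(num_examples=10_000, per_device_batch=8, accum_steps=accum, world_size=4))
```

With 10,000 examples, 8 per GPU and 4 GPUs, accumulating 4 micro-batches cuts the synchronizations per epoch from 312 to 78, but it also raises the effective batch from 32 to 128, so the learning rate may need retuning.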
Error recovery
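CUDA out of memory during distributed training: DDP keeps a full copy of the weights, gradients and optimizer state on every GPU, so adding GPUs does not lower per-GPU memory. Reduce per_device_train_batch_size and raise gradient_accumulation_steps to keep the effective batch size, or switch to a sharded setup (FSDP or DeepSpeed ZeRO).
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one: some parameters did not contribute to the loss. Make every parameter participate in the forward pass, or set ddp_find_unused_parameters=True as a last resort.
Connection refused during dist.init_process_group: every process must see the same MASTER_ADDR and MASTER_PORT, the address must be reachable from all nodes, and the port must be open and not already in use.
Loss spikes or diverges after the first step: this is usually a learning-rate or data problem rather than the distributed setup. Lower learning_rate and inspect the data before debugging communication.
Experienced dev note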
Most engineers start by wiring up torch.distributed.init_process_group and DDP by hand, and that code tends to break when they move to multi-node or change hardware. First: let Trainer and TrainingArguments handle the boilerplate, and start the script with a launcher such as torchrun or accelerate launch; that path is widely used at scale. Second: monitor closely. Set logging_steps to a small value and watch for loss plateaus or communication bottlenecks. If the loss doesn't decrease, the cause is usually dataset quality or the learning rate, not the distributed setup itself. Third: always run a single-epoch smoke test on a single GPU first with the exact same code path: if it fails there, distributed training will fail too, just more slowly and more expensively.
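A minimal sketch of automating the loss-plateau check, assuming you have the logged training losses as a list. With Trainer they can be pulled from `trainer.state.log_history`, for example `[e["loss"] for e in trainer.state.log_history if "loss" in e]`. `is_plateau` and its thresholds are hypothetical, illustrative choices, not a Trainer feature.

```python
def is_plateau(losses: list[float], window: int = 5, min_rel_improvement: float = 0.01) -> bool:
    """Flag a plateau: the best loss in the last `window` logs failed to beat the best
    earlier loss by at least `min_rel_improvement` (relative; assumes positive losses)."""
    if len(losses) <= window:
        return False  # not enough history to compare against yet
    best_before = min(losses[:-window])
    best_recent = min(losses[-window:])
    return best_recent > best_before * (1 - min_rel_improvement)


if __name__ == "__main__":
    improving = [3.0, 2.5, 2.2, 2.0, 1.9, 1.8, 1.7]
    flat = [2.0, 1.9, 1.85, 1.85, 1.86, 1.85, 1.85, 1.86]
    print(is_plateau(improving, window=3), is_plateau(flat, window=4))  # prints: False True
```

Comparing the best value in each window, rather than the last value, keeps ordinary step-to-step noise from triggering false alarms.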
Check your understanding
You're running full fine-tuning with DDP on 4 GPUs with per_device_train_batch_size=4 and gradient_accumulation_steps=2. How many examples does each GPU's model see in a single backward pass, what is the effective batch size per optimizer step, and how often are gradients synchronized across GPUs, and why at that granularity?
Answer hint
A correct answer explains that each GPU runs forward and backward on its own local micro-batch of 4 examples (the per-device batch size), so each backward pass sees 4 examples. Gradients from the 2 micro-batches accumulate locally in each replica's .grad buffers, with each micro-batch loss divided by 2 so the sum is an average. Trainer skips the all-reduce on the first micro-batch (DDP's no_sync) and synchronizes once, during the backward pass of the last micro-batch in the cycle, averaging the accumulated gradients across the 4 GPUs before the optimizer steps. The effective batch is 4 × 2 × 4 = 32 examples per optimizer step, and there is one gradient synchronization per optimizer step rather than one per micro-batch. Because averaging is linear, syncing once gives the same result as syncing after every micro-batch, which is exactly why gradient accumulation reduces communication cost.
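The equivalence that makes one synchronization per optimizer step safe can be checked in pure Python with the numbers from the question (4 GPUs, 4 examples per micro-batch, 2 accumulation steps). The sketch assumes a mean-reduced loss and the usual convention of dividing each micro-batch loss by the number of accumulation steps; the per-example "gradients" are made-up scalars chosen only for readability.

```python
# Does syncing once per accumulation cycle change the result vs syncing every micro-batch?
# Assumptions: mean-reduced loss, equal-sized micro-batches, each micro-batch loss
# divided by accum_steps; scalar per-example "gradients" are made-up values.
world_size, accum_steps, per_device = 4, 2, 4

# grads[rank][micro_step] holds per_device per-example gradient values.
grads = [
    [[float(r * 100 + m * 10 + i) for i in range(per_device)] for m in range(accum_steps)]
    for r in range(world_size)
]


def micro_grad(rank: int, micro: int) -> float:
    """Gradient from one local backward pass: batch mean, pre-divided by accum_steps."""
    g = grads[rank][micro]
    return sum(g) / len(g) / accum_steps


# A) Accumulate locally, then one all-reduce (mean across ranks) per optimizer step.
local_accumulated = [sum(micro_grad(r, m) for m in range(accum_steps)) for r in range(world_size)]
sync_once = sum(local_accumulated) / world_size

# B) All-reduce after every micro-batch, then accumulate the synced values.
sync_every = sum(
    sum(micro_grad(r, m) for r in range(world_size)) / world_size for m in range(accum_steps)
)

# Reference: plain mean over every example in the effective batch.
all_examples = [x for per_rank in grads for micro in per_rank for x in micro]
reference = sum(all_examples) / len(all_examples)

print(f"effective batch: {len(all_examples)} examples")  # 4 GPUs x 2 steps x 4 examples
print(f"sync once: {sync_once:.4f}  sync every micro-batch: {sync_every:.4f}  reference: {reference:.4f}")
```

All three values agree because averaging is linear, so skipping the intermediate all-reduce loses nothing while halving the number of synchronizations in this configuration.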