
Traffic routing strategy

What you will learn
Route training samples intelligently across multiple GPUs or model variants to maximize throughput and minimize memory contention during fine-tuning.

Why this matters

In production fine-tuning pipelines, a single GPU often becomes a bottleneck when you run multiple concurrent training jobs or A/B test different LoRA configurations. Load-aware routing helps prevent OOM errors, reduces idle GPU time, and lets you train multiple model variants simultaneously on shared hardware.

Skip if: you have dedicated single-GPU training, ample VRAM, or are fine-tuning a single small model. Routing adds complexity; it's only justified when hardware is the constraint and you need to orchestrate multiple training streams.

Explanation

Traffic routing is the strategy of distributing training samples (or mini-batches) across available compute resources (multiple GPUs, different model instances, or quantized variants) based on real-time resource availability and sample characteristics. Mechanically, a router sits between your data loader and trainer. It observes GPU memory, current batch sizes, and model state, then directs each batch to the least-loaded device or to the model variant best suited to that batch's properties (e.g., longer sequences go to unquantized models, shorter sequences to quantized ones). The router only decides where a batch goes; each target device must already hold the model (or variant) that will process it.

This avoids static batch assignment, which either wastes capacity on underutilized GPUs or crashes under memory pressure. Use this when you're running multiple LoRA fine-tuning jobs on shared infrastructure, training ensemble variants in parallel, or need to squeeze maximum throughput from limited VRAM on heterogeneous hardware.
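A minimal sketch of the sequence-length affinity rule described above, assuming batches are dicts with an `attention_mask` tensor (1 for real tokens, 0 for padding) and that you have already decided which device hosts which model variant. The 1024-token cutoff and the device strings are illustrative assumptions, not tuned values. Counting unpadded tokens keeps padding from inflating a batch's length.

```python
import torch


def route_by_length(batch: dict, long_device: str, short_device: str, cutoff: int = 1024) -> str:
    """Affinity routing: batches whose longest real sequence exceeds `cutoff` go to
    `long_device` (e.g. where the unquantized model lives), the rest to `short_device`."""
    # attention_mask is 1 for real tokens, so row sums ignore padding
    longest = int(batch["attention_mask"].sum(dim=1).max())
    return long_device if longest > cutoff else short_device


short_batch = {"attention_mask": torch.ones(4, 256, dtype=torch.long)}
long_batch = {"attention_mask": torch.ones(2, 2048, dtype=torch.long)}
print(route_by_length(short_batch, "cuda:0", "cuda:1"))  # cuda:1
print(route_by_length(long_batch, "cuda:0", "cuda:1"))   # cuda:0
```

The function returns a device name only; moving the batch and running the matching model on that device is left to the training loop.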

Analogy

Think of an airport's ground crew assigning arriving aircraft to gates based on real-time runway and terminal availability. Instead of pre-assigning every plane to gate 5, the crew sends each arriving plane to the least congested gate with compatible equipment. Similarly, a traffic router sends each training batch to the GPU or model instance that can accept it soonest without blocking others.

Code

Illustrative only: routes dummy batches and reports memory usage; no model is trained.
python
import torch
from torch.utils.data import DataLoader, Dataset
from dataclasses import dataclass

class SimpleDataset(Dataset):
    """Dummy token sequences; no model is trained in this example."""

    def __init__(self, size: int = 100, seq_len: int = 512):
        self.size = size
        self.seq_len = seq_len

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        input_ids = torch.randint(0, 32000, (self.seq_len,))
        return {
            'input_ids': input_ids,
            'labels': input_ids.clone(),  # causal-LM style: targets are the input tokens
        }

@dataclass
class DeviceLoad:
    device_id: int
    memory_used_gb: float      # torch.cuda.memory_allocated(): tensors held by this process
    memory_reserved_gb: float  # torch.cuda.memory_reserved(): held by PyTorch's caching allocator
    memory_total_gb: float
    utilization_percent: float

class TrafficRouter:
    def __init__(self, num_devices: int = 2, mock_total_gb: float = 16.0):
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda and num_devices > torch.cuda.device_count():
            raise ValueError(
                f"num_devices={num_devices}, but only {torch.cuda.device_count()} GPU(s) are visible"
            )
        self.num_devices = num_devices
        self.mock_total_gb = mock_total_gb  # capacity reported for mock devices when CUDA is absent
        self.routed_counts = [0] * num_devices
        self.routing_history = []

    def get_device_loads(self) -> list[DeviceLoad]:
        """Poll current GPU memory usage; called fresh for every routing decision."""
        loads = []
        for i in range(self.num_devices):
            if self.use_cuda:
                allocated = torch.cuda.memory_allocated(i) / 1e9
                reserved = torch.cuda.memory_reserved(i) / 1e9
                total = torch.cuda.get_device_properties(i).total_memory / 1e9
            else:
                # No CUDA: report an idle mock device so the routing logic still runs
                allocated, reserved, total = 0.0, 0.0, self.mock_total_gb
            loads.append(DeviceLoad(
                device_id=i,
                memory_used_gb=allocated,
                memory_reserved_gb=reserved,
                memory_total_gb=total,
                utilization_percent=(allocated / total) * 100 if total > 0 else 0.0,
            ))
        return loads

    def route_batch(self, batch: dict, strategy: str = 'least_loaded') -> int:
        """Route batch to device based on strategy.

        Args:
            batch: Training batch dict (as produced by DataLoader's default collate)
            strategy: 'least_loaded' (memory), 'round_robin', or 'affinity' (by sequence length)

        Returns:
            Selected device ID in range(num_devices)
        """
        loads = self.get_device_loads()

        if strategy == 'least_loaded':
            # Break ties (e.g. all GPUs idle) in favor of the device that has received
            # fewer batches; min() on utilization alone would always return device 0.
            selected_device = min(
                loads, key=lambda d: (d.utilization_percent, self.routed_counts[d.device_id])
            ).device_id
        elif strategy == 'round_robin':
            selected_device = len(self.routing_history) % self.num_devices
        elif strategy == 'affinity':
            # After collation, input_ids has shape (batch_size, padded_seq_len)
            seq_len = batch['input_ids'].shape[1]
            # Long sequences go to device 0, short ones to the last device
            selected_device = 0 if seq_len > 1024 else self.num_devices - 1
        else:
            raise ValueError(f"Unknown strategy: {strategy!r}")

        self.routed_counts[selected_device] += 1
        self.routing_history.append({
            'device_id': selected_device,
            'strategy': strategy,
            'loads_snapshot': loads
        })
        return selected_device

    def report_routing_stats(self) -> dict:
        """Summarize routing decisions made."""
        # List every device so that an unused device shows up as 0 instead of being omitted
        distribution = {i: count for i, count in enumerate(self.routed_counts)}
        total = sum(self.routed_counts)
        least = min(self.routed_counts)
        if total == 0:
            balance_ratio = 1.0
        elif least == 0:
            balance_ratio = float('inf')  # at least one device received nothing
        else:
            balance_ratio = max(self.routed_counts) / least  # busiest / least busy device
        return {
            'total_routed': total,
            'device_distribution': distribution,
            'balance_ratio': balance_ratio,
        }

def main():
    # Simulate fine-tuning with traffic routing: batches are routed, never trained on
    num_gpus = 2
    max_batches = 10
    router = TrafficRouter(num_devices=num_gpus)
    dataset = SimpleDataset(size=50)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    print(f"Routing {max_batches} batches of 4 samples across {num_gpus} devices\n")

    for batch_idx, batch in enumerate(dataloader):
        if batch_idx >= max_batches:
            break
        selected_device = router.route_batch(batch, strategy='least_loaded')
        loads = router.get_device_loads()
        gpu_util = [f"GPU{d.device_id}:{d.utilization_percent:.1f}%" for d in loads]
        print(f"Batch {batch_idx:2d} → Device {selected_device} | {' | '.join(gpu_util)}")

    stats = router.report_routing_stats()
    print("\nRouting Summary:")
    print(f"  Total batches routed: {stats['total_routed']}")
    print(f"  Distribution: {stats['device_distribution']}")
    print(f"  Load balance ratio: {stats['balance_ratio']:.2f}")

if __name__ == '__main__':
    main()
Output
Routing 10 batches of 4 samples across 2 devices

Batch  0 → Device 0 | GPU0:0.0% | GPU1:0.0%
Batch  1 → Device 1 | GPU0:0.0% | GPU1:0.0%
Batch  2 → Device 0 | GPU0:0.0% | GPU1:0.0%
Batch  3 → Device 1 | GPU0:0.0% | GPU1:0.0%
Batch  4 → Device 0 | GPU0:0.0% | GPU1:0.0%
Batch  5 → Device 1 | GPU0:0.0% | GPU1:0.0%
Batch  6 → Device 0 | GPU0:0.0% | GPU1:0.0%
Batch  7 → Device 1 | GPU0:0.0% | GPU1:0.0%
Batch  8 → Device 0 | GPU0:0.0% | GPU1:0.0%
Batch  9 → Device 1 | GPU0:0.0% | GPU1:0.0%

Routing Summary:
  Total batches routed: 10
  Distribution: {0: 5, 1: 5}
  Load balance ratio: 1.00

What just happened?

The router iterated through 10 training batches and, immediately before each decision, polled memory on both devices with `torch.cuda.memory_allocated()` (or reported idle mock devices when CUDA is unavailable). Because no model or tensor was ever placed on a GPU, both devices stayed at 0.0% utilization, so every decision was a tie. The tie-breaker sent each batch to the device that had received fewer batches so far, which alternates 0, 1, 0, 1. The summary shows 5 batches per device and a load balance ratio of 1.00 (the busiest device received exactly as many batches as the least busy one). Note that the router only returns a device ID: a real training loop would also move the batch to that device and run it through the model replica that lives there.

Common gotcha

Developers often poll GPU memory *once* at startup and assume it remains valid throughout training. Memory changes dynamically as batches are processed and freed. This code shows the correct pattern: `route_batch()` calls `get_device_loads()` immediately before every routing decision, so each choice reflects the current `torch.cuda.memory_allocated()` reading rather than a cached snapshot. Ignoring this causes the router to send batches to already-full GPUs, defeating the purpose.
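One limit of this polling: `memory_allocated()` and `memory_reserved()` only count memory held by the current process, so on shared hardware another job's usage is invisible to them. A sketch of polling device-wide free memory instead, using `torch.cuda.mem_get_info()` (which wraps CUDA's `cudaMemGetInfo`); the mock value returned on CPU-only machines is a hypothetical default for testing.

```python
import torch


def free_memory_gb(device_id: int, mock_free_gb: float = 16.0) -> float:
    """Free memory on the whole device, including memory held by other processes.

    mock_free_gb is a hypothetical stand-in returned only when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return mock_free_gb
    free_bytes, _total_bytes = torch.cuda.mem_get_info(device_id)
    return free_bytes / 1e9


def pick_most_free(num_devices: int) -> int:
    # Poll every device now, immediately before the routing decision
    free = [free_memory_gb(i) for i in range(num_devices)]
    return max(range(num_devices), key=lambda i: free[i])
```

Routing on free memory rather than on this process's allocations lets the router notice when a neighboring job has filled a GPU.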

Error recovery

RuntimeError: CUDA out of memory
Polling works, but the router has no threshold: `least_loaded` picks the emptiest device even when every device is nearly full. Add a headroom buffer: only route to a GPU whose `utilization_percent` is below 85, and lower that limit for larger models.
device out of range
`num_devices` passed to `TrafficRouter` exceeds the actual GPU count. Verify with `torch.cuda.device_count()` before instantiating the router.
AssertionError: Torch not compiled with CUDA enabled
`torch.cuda.get_device_properties()` was called on a machine without CUDA. Check `torch.cuda.is_available()` first; fall back to mock values for testing.
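The headroom and device-count fixes above can be sketched as a filter in front of least-loaded selection plus a startup check. `DeviceLoad` here is a trimmed stand-in with only the two fields the filter needs, and the 85% default mirrors the suggestion above; both function names are hypothetical.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DeviceLoad:
    # Trimmed stand-in for the DeviceLoad in the main example
    device_id: int
    utilization_percent: float


def pick_with_headroom(loads: list[DeviceLoad], max_util: float = 85.0) -> Optional[int]:
    """Least-loaded device that is still below max_util percent, or None if all are too full."""
    candidates = [d for d in loads if d.utilization_percent < max_util]
    if not candidates:
        return None  # caller should wait, shrink the batch, or retry instead of risking OOM
    return min(candidates, key=lambda d: d.utilization_percent).device_id


def validated_num_devices(requested: int, available: int) -> int:
    """Fail fast at startup, e.g. validated_num_devices(2, torch.cuda.device_count())."""
    if requested > available:
        raise ValueError(f"requested {requested} devices, but only {available} are visible")
    return requested


loads = [DeviceLoad(0, 90.0), DeviceLoad(1, 40.0), DeviceLoad(2, 60.0)]
print(pick_with_headroom(loads))  # 1
```

Returning `None` instead of a device forces the caller to handle the "everything is full" case explicitly rather than silently overloading a GPU.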

Experienced dev note

In production, you want a **load-aware scheduler**, not just least-loaded routing. Least-loaded routing is reactive: it only sees memory pressure *after* it has built up. A better approach is to keep a predictive model of each batch's memory cost. For example, you might measure that a batch of size 8 with seq_len 1024 takes ~2.3 GB on your model, then check whether routing a batch would push a GPU past 85% utilization *before* you send it. This prevents thrashing, where you route to a 'free' GPU that immediately fills up and the next batch has nowhere to go. Also, if you're A/B testing models, route entire training jobs (LoRA configs), not individual batches; batch-level routing adds synchronization overhead that slows distributed training considerably.
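A minimal sketch of the predictive check, assuming memory grows roughly linearly with tokens per batch. The per-token cost is back-calculated from the note's illustrative figure (8 × 1024 tokens ≈ 2.3 GB); real costs depend on model size, precision, LoRA rank, and activation checkpointing, and standard (non-fused) attention adds a term that grows with seq_len², so calibrate on your own runs, for example by reading `torch.cuda.max_memory_allocated()` after a few steps.

```python
def estimate_batch_gb(batch_size: int, seq_len: int, gb_per_token: float) -> float:
    """Linear memory model (an assumption): cost per token times tokens in the batch."""
    return batch_size * seq_len * gb_per_token


def fits_after_routing(used_gb: float, total_gb: float, batch_gb: float,
                       max_util: float = 85.0) -> bool:
    """Check projected utilization *before* the batch is sent."""
    projected = (used_gb + batch_gb) / total_gb * 100
    return projected < max_util


# Illustrative calibration from the note: batch 8 x seq_len 1024 took ~2.3 GB
GB_PER_TOKEN = 2.3 / (8 * 1024)

batch_gb = estimate_batch_gb(8, 1024, GB_PER_TOKEN)          # ~2.3 GB
print(fits_after_routing(10.0, 16.0, batch_gb))  # True  (projected ~76.9%)
print(fits_after_routing(12.0, 16.0, batch_gb))  # False (projected ~89.4%)
```

A router would call `fits_after_routing` for each candidate GPU and only consider those that return `True`.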

Check your understanding

In the code, why does the router poll `torch.cuda.memory_allocated()` inside `get_device_loads()` instead of once at initialization, and what could go wrong if you cached the load snapshot at the start of the epoch?

Answer hint

A correct answer must explain that memory usage changes continuously as batches train and gradients are freed. If you cached memory at epoch start, you'd route batches based on stale information: GPU 0 might have been 10% full then but could be 80% full now. This causes inefficient routing (sending to a full GPU) and can trigger OOM errors. The key insight: routing decisions must reflect *current* state, not historical state.

VERSION This example uses only `torch.cuda` memory-introspection calls (`memory_allocated`, `memory_reserved`, `get_device_properties`) and does not depend on `transformers`. In `trl >= 1.0.0`, the SFTTrainer does *not* expose internal routing hooks; you must implement routing as a wrapper around the DataLoader. A future `trl` release (possibly `>= 1.2.0`) may add a `RouterCallback`; watch the release notes if scaling to multi-GPU fine-tuning.
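Since routing has to live in a wrapper around the DataLoader, here is a minimal sketch of such a wrapper. `pick_device` is a hypothetical callable (for example, a function that calls `TrafficRouter.route_batch` and maps the returned ID to `torch.device(f"cuda:{i}")`); the sketch assumes one model replica per device and does not show how a particular trainer consumes the batches.

```python
from typing import Callable, Iterable, Iterator

import torch


def routed_batches(
    loader: Iterable[dict],
    pick_device: Callable[[dict], torch.device],
) -> Iterator[tuple[torch.device, dict]]:
    """Wrap any iterable of dict batches (e.g. a DataLoader) and route each batch.

    Yields (device, batch) with every tensor already moved to the chosen device;
    non-tensor values pass through unchanged.
    """
    for batch in loader:
        device = pick_device(batch)  # decided per batch, right before it is used
        moved = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
        yield device, moved


# CPU-only usage: a trivial picker that always chooses the CPU
loader = [{"input_ids": torch.zeros(2, 8, dtype=torch.long)} for _ in range(3)]
for device, batch in routed_batches(loader, lambda b: torch.device("cpu")):
    print(device, batch["input_ids"].shape)
```

Because the wrapper is a generator, the routing decision for each batch happens lazily, at the moment the training loop asks for it, which keeps the memory reading fresh.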
NEXT

Next, explore dynamic batch sizing strategies that pair with traffic routing: automatically reduce the batch size sent to a GPU when its utilization exceeds 80%, so a busy GPU can keep accepting work instead of running out of memory.
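A minimal sketch of one such rule: when the target GPU is above 80% utilization, split the batch into two micro-batches. It assumes every value in the batch dict is a tensor whose first dimension is the batch dimension, and `split_if_hot` is a hypothetical name. Running the halves with gradient accumulation keeps the number of samples per optimizer step unchanged.

```python
import torch


def split_if_hot(batch: dict, utilization_percent: float, threshold: float = 80.0) -> list[dict]:
    """Return [batch] unchanged, or two micro-batches if the GPU is above threshold."""
    n = batch["input_ids"].shape[0]
    if utilization_percent <= threshold or n < 2:
        return [batch]
    half = (n + 1) // 2  # first micro-batch gets the extra sample when n is odd
    return [
        {k: v[:half] for k, v in batch.items()},
        {k: v[half:] for k, v in batch.items()},
    ]


batch = {"input_ids": torch.arange(10).reshape(5, 2), "labels": torch.arange(10).reshape(5, 2)}
for micro in split_if_hot(batch, utilization_percent=85.0):
    print(micro["input_ids"].shape)
```

Slicing along the first dimension returns views, so the split itself allocates no new tensor storage.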
