How to use data augmentation in PyTorch
Quick answer
Use torchvision.transforms to apply data augmentation in PyTorch by composing transformations like random flips, rotations, and color jitter. Integrate these transforms into your Dataset pipeline to augment images on the fly during training.
Prerequisites
- Python 3.8+
- pip install torch torchvision
Setup
Install PyTorch and torchvision if you haven't already. These libraries provide the core functionality and data augmentation utilities.
pip install torch torchvision
Step by step
This example shows how to apply common data augmentations like random horizontal flip, random rotation, and color jitter using torchvision.transforms. The transforms are composed and passed to a torchvision.datasets.CIFAR10 dataset for training.
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define data augmentation transforms
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

# Load CIFAR10 training dataset with augmentation
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Iterate over one batch and print shape
images, labels = next(iter(train_loader))
print(f'Batch image tensor shape: {images.shape}')
print(f'Batch labels tensor shape: {labels.shape}')
output
Files already downloaded and verified
Batch image tensor shape: torch.Size([64, 3, 32, 32])
Batch labels tensor shape: torch.Size([64])
Common variations
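You can customize augmentations by adding or removing transforms such as RandomCrop, GaussianBlur, or RandomGrayscale. For validation or test sets, skip the random augmentations and apply only ToTensor() and normalization. You can also write custom transforms: any callable that takes an image and returns the transformed image works inside Compose, for example a class with a __call__ method or a torch.nn.Module subclass.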
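from torchvision import transforms
from torchvision.transforms import RandomCrop, RandomHorizontalFlip, GaussianBlur, RandomGrayscale

# Alternative pipeline: crop, flip, blur, and occasional grayscale
custom_transforms = transforms.Compose([
    RandomCrop(28),
    RandomHorizontalFlip(),
    GaussianBlur(kernel_size=3),
    RandomGrayscale(p=0.1),  # output keeps 3 channels when the input has 3
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
Troubleshooting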
- If images appear distorted or training accuracy is low, check the order of transforms; normalization should come after all augmentations.
- If your dataset is small, aggressive augmentation can help reduce overfitting, but augmentation that is too strong can cause underfitting.
- Ensure ToTensor() is applied before Normalize, which operates on tensors rather than PIL images.
Key Takeaways
- Use torchvision.transforms.Compose to chain multiple augmentations easily.
- Apply data augmentation only on training data, not on validation or test sets.
- Normalize images after converting them to tensors with ToTensor().