How to split a dataset into train and validation sets in PyTorch
Quick answer
Use torch.utils.data.random_split to split a PyTorch Dataset into training and validation subsets by specifying the length of each. Then create a DataLoader for each subset to iterate over batches during training and validation.

Prerequisites
- Python 3.8+
- pip install "torch>=2.0"
Setup
Install PyTorch (and torchvision, which provides the sample dataset used below) if they are not already installed:

pip install torch torchvision

Step by step
This example demonstrates splitting a dataset into 80% training and 20% validation sets using torch.utils.data.random_split. It then creates DataLoader objects for batch iteration.
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor
# Create a sample dataset (e.g., FakeData for demonstration)
dataset = FakeData(size=1000, transform=ToTensor())
# Define split lengths
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# Split dataset
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# Example: iterate one batch from train_loader
for images, labels in train_loader:
    print(f"Batch images shape: {images.shape}")
    print(f"Batch labels shape: {labels.shape}")
    break

Output
Batch images shape: torch.Size([32, 3, 224, 224])
Batch labels shape: torch.Size([32])
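As a variation, recent PyTorch versions (1.13 and later) also let random_split accept fractional lengths that sum to 1, so you can skip computing the sizes by hand. A minimal sketch using a small TensorDataset for illustration:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Small illustrative dataset: 100 samples with one feature each
dataset = TensorDataset(torch.randn(100, 1))

# Fractions summing to 1.0 are expanded to lengths internally (PyTorch 1.13+)
train_dataset, val_dataset = random_split(dataset, [0.8, 0.2])

print(len(train_dataset), len(val_dataset))  # 80 20
```

If the fractions do not divide the dataset size evenly, the leftover samples are distributed among the splits so the total still matches len(dataset).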
Common variations
- Use SubsetRandomSampler with DataLoader for more control over which indices go into each split.
- For custom datasets, ensure your Dataset implements __len__ and __getitem__.
- Shuffle training data but not validation data for reproducibility.
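The SubsetRandomSampler approach from the list above can be sketched as follows (names like train_idx are illustrative, not part of any API):

```python
import torch
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler, TensorDataset

# Toy dataset: 100 samples, 4 features each, binary labels
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))

# Shuffle indices once, then carve out 80/20 index lists
indices = torch.randperm(len(dataset)).tolist()
split = int(0.8 * len(dataset))
train_idx, val_idx = indices[:split], indices[split:]

# The sampler re-shuffles the training indices every epoch.
# Note: sampler= is mutually exclusive with shuffle=True.
train_loader = DataLoader(dataset, batch_size=16,
                          sampler=SubsetRandomSampler(train_idx))
# Validation: wrap the fixed indices in a Subset and keep their order
val_loader = DataLoader(Subset(dataset, val_idx), batch_size=16, shuffle=False)

n_train = sum(x.shape[0] for x, _ in train_loader)
n_val = sum(x.shape[0] for x, _ in val_loader)
print(n_train, n_val)  # 80 20
```

This gives you explicit control over the index lists, which is handy for stratified splits or k-fold cross-validation where random_split is too coarse.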
Troubleshooting
- If you get an error about dataset length mismatch, verify your split lengths sum to the total dataset size.
- If batches are not shuffled during training, ensure shuffle=True is set on the training DataLoader.
- For reproducible splits, set a manual seed before calling random_split using torch.manual_seed(seed).
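A deterministic split can be sketched either with a global torch.manual_seed call, or more locally via the generator argument that random_split accepts, which avoids touching global RNG state:

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))

# A seeded Generator makes the split deterministic without
# affecting the global random state
split_a = random_split(dataset, [8, 2], generator=torch.Generator().manual_seed(42))
split_b = random_split(dataset, [8, 2], generator=torch.Generator().manual_seed(42))

# Both calls pick the same indices
print(list(split_a[0].indices) == list(split_b[0].indices))  # True
```

Each element returned by random_split is a Subset, whose .indices attribute records which positions of the original dataset it draws from.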
Key Takeaways
- Use torch.utils.data.random_split to easily split datasets into train and validation subsets.
- Create separate DataLoader instances for training and validation with appropriate batch sizes and shuffling.
- Set a manual random seed for reproducible dataset splits.
- Shuffle training data but keep validation data order fixed for consistent evaluation.
- Verify split sizes sum exactly to the dataset length to avoid errors.