How to use ImageFolder dataset in PyTorch
Quick answer
Use torchvision.datasets.ImageFolder to load images arranged in folders named after their class labels. Combine it with torch.utils.data.DataLoader to iterate over batches for training or evaluation in PyTorch.
PREREQUISITES
Python 3.8+
pip install torch torchvision
Setup
Install PyTorch and torchvision if they are not already installed. Organize your image dataset as a root directory with one subfolder per class. Each subfolder holds that class's images, and its name becomes the class name.
pip install torch torchvision
Step by step
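If you don't have a dataset in this layout yet, the sketch below builds a tiny synthetic one (root/<class_name>/<image>.png) that you can point ImageFolder at. The class names, image counts, colors, and the use of Pillow to write PNG files are assumptions for illustration only. Set root_dir to your real data instead.

```python
import os
import tempfile

from PIL import Image

# Hypothetical class names and image counts, used only for illustration.
CLASSES = ["cat", "dog", "horse"]
IMAGES_PER_CLASS = 4


def make_toy_dataset(root: str) -> None:
    """Create root/<class_name>/<n>.png with small solid-color RGB images."""
    for class_idx, class_name in enumerate(CLASSES):
        class_dir = os.path.join(root, class_name)
        os.makedirs(class_dir, exist_ok=True)
        for i in range(IMAGES_PER_CLASS):
            # Vary the color per class and per image so the files differ.
            color = (80 * class_idx, 40 * i, 128)
            Image.new("RGB", (160, 120), color).save(os.path.join(class_dir, f"{i}.png"))


root_dir = tempfile.mkdtemp()
make_toy_dataset(root_dir)

# Show the resulting layout: one subfolder per class, images inside each.
for class_name in sorted(os.listdir(root_dir)):
    files = sorted(os.listdir(os.path.join(root_dir, class_name)))
    print(class_name, files)
```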
This example loads an image dataset using ImageFolder, applies basic transforms, and creates a DataLoader for batch iteration.
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Define dataset root directory
root_dir = 'path/to/your/dataset'
# Define transforms for images (resize + tensor conversion)
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor()
])
# Load dataset with ImageFolder
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
# Create DataLoader for batching
loader = DataLoader(dataset, batch_size=16, shuffle=True)
# Iterate one batch and print batch shape and labels
images, labels = next(iter(loader))
print(f'Batch image tensor shape: {images.shape}')
print(f'Batch labels tensor: {labels}')
output
Batch image tensor shape: torch.Size([16, 3, 128, 128])
Batch labels tensor: tensor([0, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2])
Common variations
- Use other transforms for normalization or data augmentation, such as transforms.Normalize and transforms.RandomHorizontalFlip.
- Set shuffle=False in the DataLoader for validation/testing so batches arrive in a fixed order.
- Pass a custom function through ImageFolder's loader argument if your images are not in formats the default loader handles.
Troubleshooting
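- If you get a FileNotFoundError, verify the root_dir path and the folder structure. ImageFolder also raises this error when root_dir contains no class subfolders or when a class folder has no files with a supported extension.
- If images fail to load, check that the files are valid images with supported extensions (for example .jpg, .jpeg, .png).
- ImageFolder takes class names from the subfolder names and assigns integer labels in sorted order of those names. Name each subfolder exactly after its class, and check dataset.class_to_idx to confirm the mapping.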
Key Takeaways
- Use torchvision.datasets.ImageFolder to load image datasets organized by class folders.
- Combine it with torch.utils.data.DataLoader for efficient batch processing.
- Apply transforms such as resizing and tensor conversion to prepare images for models.
- Verify folder structure and image formats to avoid loading errors.