How-to · Beginner · 3 min read

How to create a custom dataset in PyTorch

Quick answer
Create a custom dataset in PyTorch by subclassing torch.utils.data.Dataset and implementing the __len__ and __getitem__ methods. This lets you load and preprocess data on the fly, which keeps memory usage low while training models.

PREREQUISITES

  • Python 3.8+
  • pip install "torch>=2.0" (quote the requirement so the shell doesn't treat >= as a redirect)

Setup

Install PyTorch if you haven't already, using the official command from the PyTorch installation guide. For CPU-only:

bash
pip install torch torchvision

Step by step

Subclass torch.utils.data.Dataset and implement __init__, __len__, and __getitem__. This example creates a dataset from a list of (feature, label) tuples.

python
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data):
        # data is a list of (feature, label) tuples
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        feature, label = self.data[idx]
        # Convert to tensors
        feature_tensor = torch.tensor(feature, dtype=torch.float32)
        label_tensor = torch.tensor(label, dtype=torch.long)
        return feature_tensor, label_tensor

# Example data: list of (features, label)
data = [([1.0, 2.0], 0), ([3.0, 4.0], 1), ([5.0, 6.0], 0)]
dataset = CustomDataset(data)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

for batch_features, batch_labels in dataloader:
    print("Batch features:", batch_features)
    print("Batch labels:", batch_labels)

output
Batch features: tensor([[3., 4.],
        [1., 2.]])
Batch labels: tensor([1, 0])
Batch features: tensor([[5., 6.]])
Batch labels: tensor([0])

Your exact batch order will differ from run to run because shuffle=True.

Common variations

  • Use transforms or preprocessing inside __getitem__ for data augmentation.
  • Load data from files (images, CSVs) by storing file paths in __init__ and reading files in __getitem__.
  • Use torchvision.datasets for many standard datasets to avoid reinventing the wheel.

For example, a minimal image dataset that reads files from a directory:

python
from PIL import Image
import os

class ImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Sort and filter for a deterministic list of actual image files
        self.image_files = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = 0  # placeholder; in practice, derive labels from filenames or folder names
        return image, label
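
The transform hook in __getitem__ doesn't require image files at all; any callable works. Here is a minimal sketch using the tuple-based dataset from earlier with a simple normalization function (the TransformDataset name and the normalization constants are illustrative, not part of any PyTorch API):

python
import torch
from torch.utils.data import Dataset

class TransformDataset(Dataset):
    """Wraps raw (feature, label) tuples and applies an optional transform per item."""
    def __init__(self, data, transform=None):
        self.data = data            # list of (feature, label) tuples
        self.transform = transform  # any callable tensor -> tensor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        feature, label = self.data[idx]
        feature = torch.tensor(feature, dtype=torch.float32)
        if self.transform:
            feature = self.transform(feature)
        return feature, torch.tensor(label, dtype=torch.long)

# Example transform: shift and scale features (constants chosen arbitrarily)
normalize = lambda x: (x - 3.5) / 2.0
dataset = TransformDataset([([1.0, 2.0], 0), ([5.0, 6.0], 1)], transform=normalize)
feature, label = dataset[0]
print(feature)  # first sample after normalization

Because the transform runs per item inside __getitem__, random augmentations (flips, crops, noise) produce a different result each epoch, which is exactly what you want for data augmentation.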

Troubleshooting

  • If you get a NotImplementedError or TypeError when indexing or iterating the dataset, ensure __len__ and __getitem__ are implemented on your subclass.
  • If data loading is slow, use DataLoader with num_workers>0 for parallel loading.
  • Check tensor shapes and types in __getitem__ to avoid model input errors.
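
The num_workers tip can be sketched as follows. The RangeDataset here is a toy class invented for this example, and a worker count of 2 is just a starting point to tune for your machine:

python
import torch
from torch.utils.data import Dataset, DataLoader

class RangeDataset(Dataset):
    """Toy dataset returning one-element tensors 0..n-1 (illustrative only)."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor([float(idx)])

def main():
    dataset = RangeDataset(8)
    # num_workers=2 loads batches in two background processes;
    # keep it at 0 while debugging, since worker errors are harder to trace.
    loader = DataLoader(dataset, batch_size=4, num_workers=2)
    batches = list(loader)
    print(len(batches))  # 2 batches of 4 samples each

if __name__ == "__main__":
    # The guard matters on platforms that spawn worker processes (Windows, macOS)
    main()

Start with a small num_workers and increase it until throughput stops improving; too many workers can oversubscribe CPU cores and slow things down.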

Key Takeaways

  • Subclass torch.utils.data.Dataset and implement __len__ and __getitem__ to create custom datasets.
  • Use DataLoader to batch and shuffle your custom dataset efficiently.
  • Load and preprocess data on-the-fly inside __getitem__ for memory efficiency.
  • Use transforms for data augmentation and preprocessing within your dataset class.
  • Parallelize data loading with DataLoader's num_workers parameter to speed up training.
Verified 2026-04