Dataset class: __len__ and __getitem__
Why this matters
Most real-world datasets don't come in a neat PyTorch tensor. You need to implement Dataset to load images, audio, or database records on-the-fly during training without loading everything into memory at once.
Explanation
What it is: A Dataset subclass defines how to access individual samples from your data. The __len__ method tells the DataLoader how many samples exist. The __getitem__ method returns a single sample (and label) at index idx.
How it works: When you create a DataLoader(dataset, batch_size=32), its sampler uses len(dataset) to learn the total size, and the loader then calls dataset[i] for shuffled or sequential indices i, collating the returned samples into batches. This lazy-loading pattern keeps memory usage roughly independent of dataset size: only the samples for the batches currently being assembled (plus any batches prefetched by worker processes) live in RAM.
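The loop described above can be approximated by hand. The following is a simplified sketch, not the DataLoader's actual implementation: SquaresDataset and manual_batches are hypothetical names introduced for illustration, and the sketch ignores shuffling, worker processes, and pinned memory.

```python
import torch
from torch.utils.data import BatchSampler, Dataset, SequentialSampler
from torch.utils.data.dataloader import default_collate


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (tensor([i, i]), i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor([idx, idx], dtype=torch.float32), idx * idx


def manual_batches(dataset, batch_size):
    # The sampler calls len(dataset) to decide which indices exist.
    sampler = SequentialSampler(dataset)
    for indices in BatchSampler(sampler, batch_size=batch_size, drop_last=False):
        # One __getitem__ call per index; only these samples are materialized.
        samples = [dataset[i] for i in indices]
        # Stack the per-sample tuples into batched tensors.
        yield default_collate(samples)


batches = list(manual_batches(SquaresDataset(10), batch_size=4))
features, targets = batches[0]
print(features.shape)  # torch.Size([4, 2])
print(targets)         # tensor([0, 1, 4, 9])
print(len(batches))    # 3 (batch sizes 4, 4, 2)
```

Swapping SequentialSampler for RandomSampler gives the shuffled case; the dataset code does not change either way.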
When to use it: Use this pattern whenever your data lives on disk (image files, audio files) or requires preprocessing (resizing, augmentation) that shouldn't happen all at once. This is the standard approach for computer vision and NLP pipelines.
Analogy
Think of Dataset like a library catalog system. <code>__len__</code> is the card catalog that says "we have 1,000 books." <code>__getitem__</code> is the request form: you hand it a book number (index), and it retrieves that specific book from the shelves. The librarian (DataLoader) uses the catalog to know how many requests to make, and fetches books one at a time.
Code
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class CustomImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Total number of samples; the DataLoader's sampler relies on this.
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]  # a real dataset would read this file
        label = self.labels[idx]
        # Placeholder: random pixels stand in for the decoded image.
        image = np.random.randn(28, 28, 3).astype(np.float32)
        if self.transform:
            image = self.transform(image)
        # as_tensor accepts both NumPy arrays and tensors returned by a transform.
        return torch.as_tensor(image), label


# Dummy file names and random labels for 100 samples.
image_paths = [f"image_{i}.jpg" for i in range(100)]
labels = np.random.randint(0, 10, 100)

dataset = CustomImageDataset(image_paths, labels)
print(f"Dataset length: {len(dataset)}")

sample_image, sample_label = dataset[0]
print(f"Sample shape: {sample_image.shape}, label: {sample_label}")

dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
batch_images, batch_labels = next(iter(dataloader))
print(f"Batch shape: {batch_images.shape}")
print(f"Batch labels: {batch_labels}")

Output (the labels are random, so your values will differ)
Dataset length: 100
Sample shape: torch.Size([28, 28, 3]), label: 7
Batch shape: torch.Size([16, 28, 28, 3])
Batch labels: tensor([5, 2, 9, 1, 8, 4, 7, 3, 6, 0, 2, 5, 8, 1, 9, 3])
What just happened?
We defined a custom Dataset class with 100 samples. Calling <code>len(dataset)</code> invoked <code>__len__</code> and returned 100. Accessing <code>dataset[0]</code> invoked <code>__getitem__(0)</code>, which loaded the sample at index 0 and returned a tuple of (image tensor, label integer). The DataLoader called <code>__getitem__</code> multiple times internally, collecting samples into batches and returning them as stacked tensors with shape <code>[batch_size, ...]</code>.
Common gotcha
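Forgetting that __getitem__ must return a tuple (data, target) if your downstream code expects both. Many developers write <code>return image</code> and then wonder why unpacking fails: the DataLoader yields a single stacked tensor per batch, so a loop such as <code>for images, labels in loader</code> tries to unpack along the batch dimension and raises a ValueError (or, with a batch size of 2, silently assigns the two images to the two names). Also: everything in __getitem__ runs once per sample per epoch, so expensive work there (decoding large files, heavy preprocessing) can dominate training time. Profile with torch.utils.bottleneck before blaming the DataLoader.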
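A small reproduction makes the first gotcha concrete. This is a sketch with two hypothetical toy datasets (ImagesOnly and ImagesWithLabels), using zero-filled tensors in place of real images.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ImagesOnly(Dataset):
    """Hypothetical dataset whose __getitem__ forgets the label."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.zeros(3, 4)


class ImagesWithLabels(ImagesOnly):
    def __getitem__(self, idx):
        return torch.zeros(3, 4), idx % 2


# Without a label, each batch is one stacked tensor, not a (data, target) pair.
batch = next(iter(DataLoader(ImagesOnly(), batch_size=4)))
print(type(batch).__name__, batch.shape)  # Tensor torch.Size([4, 3, 4])

unpack_error = None
try:
    # Unpacking iterates over the batch dimension: 4 items into 2 names.
    images, labels = batch
except ValueError as err:
    unpack_error = str(err)
    print("unpack failed:", unpack_error)

# Returning a tuple per sample gives one batched tensor per tuple position.
images, labels = next(iter(DataLoader(ImagesWithLabels(), batch_size=4)))
print(images.shape, labels)  # torch.Size([4, 3, 4]) tensor([0, 1, 0, 1])
```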
Error recovery
TypeError: object of type 'int' has no len()
RuntimeError: expected scalar type Long but found Float
IndexError: list index out of range
Experienced dev note
If your __getitem__ does I/O (file reads, database queries), set num_workers > 0 in the DataLoader (e.g., num_workers=4) for training. This starts worker processes that fetch samples in parallel while the main process trains; without it, the GPU can sit idle waiting on data loading. For debugging, use num_workers=0 so stack traces point to your code, not worker processes. Also: treat the Dataset as read-only after construction. Each worker process operates on its own copy of the Dataset, so changes made to lists or dicts inside __getitem__ are not visible to the main process or to other workers. If you need configuration or lookup tables, build them in __init__ and only read them afterwards.
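One way to keep both settings at hand is a small factory function. This is a minimal sketch: make_loader and its debug and workers parameters are hypothetical names, not part of PyTorch, and the right worker count depends on your machine.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size, debug=False, workers=2):
    """Hypothetical helper: single-process, unshuffled loading when debugging."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not debug,
        # 0 means samples are fetched in the main process, so tracebacks
        # from __getitem__ point straight at your code.
        num_workers=0 if debug else workers,
    )


dataset = TensorDataset(torch.arange(12).float().unsqueeze(1), torch.arange(12))

debug_loader = make_loader(dataset, batch_size=4, debug=True)
# Worker processes are only started once this loader is iterated.
train_loader = make_loader(dataset, batch_size=4)

print(debug_loader.num_workers, train_loader.num_workers)  # 0 2
x, y = next(iter(debug_loader))
print(x.shape, y)  # torch.Size([4, 1]) tensor([0, 1, 2, 3])
```

On platforms that start worker processes with the spawn method (Windows, and macOS by default), code that iterates a loader with num_workers > 0 should run under an `if __name__ == "__main__":` guard.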
Check your understanding
If you have a dataset of 10,000 images split across 100 shards (files), how would you modify __getitem__ to load the correct shard without loading all 100 shards at once? What would break if you loaded all shards in __init__ instead?
Answer hint
A correct answer explains that __getitem__ computes which shard holds idx (assuming equal-sized shards, 10,000 images over 100 shards gives 100 per shard, so shard = idx // 100 and the position inside it is idx % 100), opens only that shard file, and retrieves the sample, so memory use stays bounded by a single shard. Loading all shards in __init__ defeats lazy loading: memory grows in proportion to the dataset size, startup becomes slow, and a large enough dataset will not fit in RAM at all.
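The shard arithmetic from the hint can be sketched as follows. The example makes several assumptions that are not in the exercise: shards are equal-sized .npy files, ShardedDataset is a hypothetical class name, and the demo uses 3 small shards written to a temporary directory instead of 100 real ones.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset


class ShardedDataset(Dataset):
    """Sketch: equal-sized .npy shards, opened one at a time on demand."""

    def __init__(self, shard_paths, shard_size):
        # Only lightweight metadata is kept; no shard is loaded here.
        self.shard_paths = shard_paths
        self.shard_size = shard_size

    def __len__(self):
        return len(self.shard_paths) * self.shard_size

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # Which shard holds the sample, and where inside that shard.
        shard_id, offset = divmod(idx, self.shard_size)
        # mmap_mode="r" maps the file instead of reading the whole array.
        shard = np.load(self.shard_paths[shard_id], mmap_mode="r")
        # Copy just the requested row into a regular in-memory array.
        return torch.from_numpy(np.array(shard[offset]))


# Demo data: 3 shards of 4 samples each; sample i is the vector [i, i].
tmp_dir = tempfile.mkdtemp()
paths = []
for s in range(3):
    rows = np.arange(s * 4, (s + 1) * 4, dtype=np.float32)
    path = os.path.join(tmp_dir, f"shard_{s}.npy")
    np.save(path, np.stack([rows, rows], axis=1))
    paths.append(path)

ds = ShardedDataset(paths, shard_size=4)
print(len(ds))  # 12
print(ds[9])    # tensor([9., 9.])
```

Reopening a shard on every call keeps the class stateless at the cost of repeated file opens; whether that cost matters depends on the storage format and access pattern.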
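A note on torch.compile(): it targets tensor computation rather than data loading, so it will not speed up file I/O inside __getitem__ or the DataLoader's iteration itself. If the processing you apply to each batch after loading is a pure tensor function (or the model's forward pass), wrapping that function with @torch.compile may give a speedup.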