nn.MaxPool2d: downsampling
Why this matters
Downsampling is essential in CNNs: it reduces feature map size, lowers the memory footprint, and grows the receptive field. But you need to understand the trade-offs among kernel size, stride, and padding to avoid losing critical spatial information or breaking architecture assumptions.
Explanation
What it is: nn.MaxPool2d is a layer that slides a window (kernel) over the input feature map and outputs the maximum value in each window. It's a form of downsampling that reduces height and width while keeping depth (channels) unchanged.
How it works: At each window position, the layer compares all values under the window and keeps only the largest. With a 2×2 kernel and stride 2, a 4×4 input becomes a 2×2 output: each 2×2 block is reduced to its single maximum value, hence the name max pooling. Padding controls output size as it does in convolutions, with one difference: MaxPool2d pads with negative infinity rather than zeros, so a padded cell can never be selected as the maximum.
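To make the sliding-window description concrete, here is a minimal reference sketch written with explicit loops and compared against nn.MaxPool2d. The function name naive_max_pool2d is hypothetical. The sketch assumes a 4D (N, C, H, W) input, a square kernel, and the default floor rounding. It illustrates the idea and is not how PyTorch implements pooling internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def naive_max_pool2d(x, kernel_size, stride, padding=0):
    """Loop-based max pooling for an (N, C, H, W) tensor (illustrative only)."""
    if padding > 0:
        # Pad with -inf so a padded cell can never be the maximum
        x = F.pad(x, (padding, padding, padding, padding), value=float("-inf"))
    n, c, h, w = x.shape
    out_h = (h - kernel_size) // stride + 1
    out_w = (w - kernel_size) // stride + 1
    out = x.new_empty(n, c, out_h, out_w)
    for i in range(out_h):
        for j in range(out_w):
            top, left = i * stride, j * stride
            window = x[:, :, top:top + kernel_size, left:left + kernel_size]
            # Max over the two spatial dims, independently per sample and channel
            out[:, :, i, j] = window.amax(dim=(-2, -1))
    return out


x = torch.arange(1.0, 17.0).reshape(1, 1, 4, 4)
manual = naive_max_pool2d(x, kernel_size=2, stride=2)
builtin = nn.MaxPool2d(kernel_size=2, stride=2)(x)
print(manual)
print(torch.equal(manual, builtin))
```

The loop version is far slower than the built-in layer. Its only purpose is to show which values each output cell sees.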
When to use it: MaxPool2d typically appears after convolutional blocks to reduce spatial dimensions. It's a common choice in image classification networks because it provides a small degree of translation invariance (a shift that keeps the maximum inside the same window leaves the output unchanged) and is computationally cheap compared with another conv layer.
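As a rough illustration of where the layer usually sits, the sketch below places MaxPool2d after a convolution and activation. The channel counts and input size are arbitrary assumptions chosen for the example, not values from any particular architecture.

```python
import torch
import torch.nn as nn

# Hypothetical conv block: the conv keeps 32x32 (padding=1), the pool halves it
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),  # stride defaults to kernel_size
)

images = torch.randn(2, 3, 32, 32)  # (batch, channels, height, width)
features = block(images)
print(features.shape)
```

The convolution changes the channel count while pooling changes only height and width, so the two steps stay easy to reason about separately.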
Analogy
Imagine a grid of elevation readings from a terrain map. MaxPool2d is like saying: 'divide the map into 2×2 squares and keep only the highest peak in each square.' You lose detail but keep the most important (highest) information and now have a much coarser map.
Code
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Create a simple feature map (batch_size=1, channels=1, height=4, width=4)
x = torch.tensor([[
    [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0],
     [9.0, 10.0, 11.0, 12.0],
     [13.0, 14.0, 15.0, 16.0]]
]]).to(device)
print(f"Input shape: {x.shape}")
print(f"Input:\n{x}")
print()
# 2×2 max pool with stride 2 (no padding)
pool = nn.MaxPool2d(kernel_size=2, stride=2).to(device)
output = pool(x)
print(f"Output shape: {output.shape}")
print(f"Output after MaxPool2d(2, stride=2):\n{output}")
print()
# 2×2 max pool with stride 1 (overlapping windows)
pool_overlap = nn.MaxPool2d(kernel_size=2, stride=1).to(device)
output_overlap = pool_overlap(x)
print(f"Output shape with stride=1: {output_overlap.shape}")
print(f"Output after MaxPool2d(2, stride=1):\n{output_overlap}")
print()
# 2×2 max pool with padding=1
pool_padded = nn.MaxPool2d(kernel_size=2, stride=2, padding=1).to(device)
output_padded = pool_padded(x)
print(f"Output shape with padding=1: {output_padded.shape}")
print(f"Output after MaxPool2d(2, stride=2, padding=1):\n{output_padded}")
# Multi-channel example (realistic)
x_multi = torch.randn(2, 3, 8, 8).to(device)
pool_multi = nn.MaxPool2d(kernel_size=2, stride=2).to(device)
output_multi = pool_multi(x_multi)
print(f"\nMulti-channel input shape: {x_multi.shape}")
print(f"Multi-channel output shape: {output_multi.shape}")
Output
Input shape: torch.Size([1, 1, 4, 4])
Input:
tensor([[[[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 10., 11., 12.],
[13., 14., 15., 16.]]]])
Output shape: torch.Size([1, 1, 2, 2])
Output after MaxPool2d(2, stride=2):
tensor([[[[ 6., 8.],
[14., 16.]]]])
Output shape with stride=1: torch.Size([1, 1, 3, 3])
Output after MaxPool2d(2, stride=1):
tensor([[[[ 6., 7., 8.],
[10., 11., 12.],
[14., 15., 16.]]]])
Output shape with padding=1: torch.Size([1, 1, 3, 3])
Output after MaxPool2d(2, stride=2, padding=1):
tensor([[[[ 1., 3., 4.],
[ 9., 11., 12.],
[13., 15., 16.]]]])
Multi-channel input shape: torch.Size([2, 3, 8, 8])
Multi-channel output shape: torch.Size([2, 3, 4, 4])
What just happened?
The code created a 4×4 feature map and applied MaxPool2d with different configurations. With kernel_size=2, stride=2 (non-overlapping windows), the output became 2×2: each 2×2 block was reduced to its maximum value (6, 8, 14, 16). With stride=1, windows overlapped, producing a 3×3 output instead. With padding=1, a border of negative-infinity values was added around the input before pooling. The window grid shifted by one cell, the output grew to 3×3, and the border windows covered fewer real values (the top-left output is just the corner value 1). Finally, the multi-channel example showed that pooling applies independently to each of the 3 channels, preserving the channel dimension while reducing only the spatial dimensions.
Common gotcha
Developers often forget that stride defaults to kernel_size, not 1. nn.MaxPool2d(2) is equivalent to nn.MaxPool2d(kernel_size=2, stride=2), which gives non-overlapping windows. If you want overlapping windows (less common, but useful for preserving fine-grained features), you must explicitly set a stride smaller than kernel_size, such as stride=1. The second gotcha: padding shifts the window grid, which changes the values that survive at the borders. MaxPool2d pads with negative infinity, not zeros, so padded cells never win the max, but border windows cover fewer real pixels and every window boundary moves by the padding amount.
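Both points can be checked directly. The sketch below inspects the layer's stride attribute and pools an all-negative tensor with padding. The tensor sizes are arbitrary assumptions chosen for the demonstration.

```python
import torch
import torch.nn as nn

# Gotcha 1: stride falls back to kernel_size when it is not given
pool = nn.MaxPool2d(2)
print(pool.kernel_size, pool.stride)

x = torch.randn(1, 1, 6, 6)
default_out = nn.MaxPool2d(2)(x)            # non-overlapping windows
overlap_out = nn.MaxPool2d(2, stride=1)(x)  # overlapping windows
print(default_out.shape, overlap_out.shape)

# Gotcha 2: with all-negative input, zero padding would leak 0 into the
# border outputs; -inf padding cannot, so every output stays at -1
neg = -torch.ones(1, 1, 4, 4)
padded_out = nn.MaxPool2d(2, stride=2, padding=1)(neg)
print(padded_out)
```

If the border outputs of padded_out were 0 instead of -1, the padding would be zero-filled. They are not.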
Error recovery
RuntimeError: Expected 4D input (got 3D input instead)
RuntimeError: input size must be >= kernel size
Shape mismatch after pooling
Experienced dev note
In production CNNs, non-overlapping pooling (stride=kernel_size) is common because it downsamples cheaply without revisiting any input value; VGG uses 2×2 pooling with stride 2 throughout. ResNet's stem, by contrast, uses an overlapping 3×3 max pool with stride 2 and padding 1. When you're debugging a network that loses too much spatial information early on, overlapping pooling (a stride smaller than kernel_size, for example stride=1 or stride=kernel_size//2) can retain some detail at the cost of larger feature maps and more computation downstream. Also, MaxPool2d has no learnable parameters, so if pooling is too aggressive for your task, the fix is usually to use a smaller kernel or stride, or to add a dilated convolution that grows the receptive field without downsampling. Don't swap in nn.AvgPool2d for intermediate downsampling without a reason: max pooling tends to preserve sharp features (edges, corners) better than averaging, although global average pooling before the classifier head is standard in architectures such as ResNet.
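The shape effects of these options can be compared side by side. The sketch below uses arbitrary channel counts and input size; the variable names are illustrative only, and the pooling settings mirror the 2×2/stride-2 and 3×3/stride-2/padding-1 configurations rather than reproducing any full network.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 32, 32)

# Non-overlapping 2x2 pooling
non_overlapping = nn.MaxPool2d(kernel_size=2, stride=2)
# Overlapping 3x3 pooling with stride 2 and padding 1
overlapping = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Dilated conv: wider receptive field, spatial size unchanged
dilated = nn.Conv2d(8, 8, kernel_size=3, padding=2, dilation=2)

print(non_overlapping(x).shape)
print(overlapping(x).shape)
print(dilated(x).shape)

# Pooling layers hold no learnable parameters; the conv does
pool_params = sum(p.numel() for p in non_overlapping.parameters())
conv_params = sum(p.numel() for p in dilated.parameters())
print(pool_params, conv_params)
```

Both pooling variants halve the 32×32 map here, while the dilated convolution keeps it at full resolution and pays for that with parameters and compute.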
Check your understanding
You have a 16×16 input with 3 channels and apply MaxPool2d(kernel_size=4, stride=2, padding=1). What will be the output shape, and how many spatial positions in the output have a window that overlaps the padding (from padding=1)?
Show answer hint
First: calculate the output size using the formula (16 + 2*1 - 4) // 2 + 1 = 8, so the shape is (batch, 3, 8, 8). Second: padding adds a one-cell ring around the 16×16 input (PyTorch fills it with negative infinity, not zeros). The top-left 4×4 window, at the first window position, overlaps this ring. Count the output positions whose window contains at least one padded cell: along each axis, work out which window start positions reach outside the original 16 cells, then combine rows and columns.
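To check your answer, the following sketch implements the output-size formula and flags which windows along one axis reach into the padding. The helper names pool_output_size and touches_padding are hypothetical, and the sketch assumes the defaults dilation=1 and ceil_mode=False.

```python
import torch
import torch.nn as nn


def pool_output_size(size, kernel_size, stride, padding=0):
    """Output length along one axis (floor rounding, no dilation)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def touches_padding(size, kernel_size, stride, padding):
    """For each output index on one axis, does its window overlap the padding?"""
    flags = []
    for i in range(pool_output_size(size, kernel_size, stride, padding)):
        start = i * stride - padding  # first covered index, in unpadded coordinates
        end = start + kernel_size - 1
        flags.append(start < 0 or end >= size)
    return flags


out = pool_output_size(16, kernel_size=4, stride=2, padding=1)
axis_flags = touches_padding(16, kernel_size=4, stride=2, padding=1)

# A 2D output position is affected if its row window or its column window is
clean = axis_flags.count(False)
affected = out * out - clean * clean
print(out, axis_flags, affected)

# Cross-check the shape against the real layer
y = nn.MaxPool2d(kernel_size=4, stride=2, padding=1)(torch.randn(1, 3, 16, 16))
print(y.shape)
```

Work the count out by hand first, then run the sketch to compare.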