Code Intermediate medium · 6 min

Batch normalization: nn.BatchNorm2d

What you will learn
Batch normalization normalizes layer inputs across a batch to stabilize training and allow higher learning rates.

Why this matters

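BN is a standard component of most convolutional architectures, including ResNet and many of its descendants. Without it, deep CNNs typically train more slowly and are more sensitive to learning rate and weight initialization. The original BN paper attributed this to "internal covariate shift" (the distribution of each layer's inputs changing as earlier layers update). The exact mechanism is still debated, but the practical benefits are well established. Understanding when and how to apply BN helps you avoid training failures and usually speeds up convergence noticeably.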

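Skip if: you would be applying it to raw input images (normalize inputs with a fixed per-channel mean and standard deviation during preprocessing instead), or your per-device training batch is 1 or very small (roughly <4). At that size batch statistics are too noisy to be useful, and GroupNorm or LayerNorm are the usual replacements. Placement relative to ReLU is an architecture choice rather than a reason to skip BN. Conv → BN → ReLU, as in the code below, is the most common convention, though some models place BN after the activation. Small batches at inference time are not a problem: after model.eval(), BatchNorm2d uses the running statistics accumulated during training instead of recomputing them from the batch.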
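A minimal sketch of why small inference batches are fine. After model.eval(), BatchNorm2d ignores batch composition and normalizes with its running statistics, so an image gives the same result alone as inside a larger batch. The toy model, the random data and the 20 warm-up training passes are assumptions for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy conv -> BN -> ReLU model (assumed for illustration)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# A few training-mode forward passes populate running_mean / running_var
model.train()
with torch.no_grad():
    for _ in range(20):
        model(torch.randn(16, 3, 32, 32))

# Eval mode: BatchNorm2d normalizes with the running statistics, not the batch's
model.eval()
batch = torch.randn(8, 3, 32, 32)
with torch.no_grad():
    out_batch = model(batch)        # first image evaluated inside a batch of 8
    out_single = model(batch[:1])   # the same image evaluated alone (batch size 1)

same = torch.allclose(out_batch[:1], out_single, atol=1e-5)
print(f"Single-image output matches batched output in eval mode: {same}")
```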

Explanation

Batch normalization is a layer that normalizes the outputs of the previous layer across the batch dimension, then applies a learnable scale and shift. For 2D convolutions (images), nn.BatchNorm2d expects input of shape (N, C, H, W). It normalizes across the batch and spatial dimensions (N, H, W) and computes separate statistics for each channel C.
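To see which dimensions BatchNorm2d reduces over, here is a quick check on a fresh layer in training mode (so gamma = 1 and beta = 0). The input, with a deliberately different offset and scale per channel, is an arbitrary assumption. After the layer, each channel has roughly zero mean and unit variance when measured over dims (0, 2, 3), i.e. batch, height and width, but individual samples are not forced to those values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(4)   # 4 channels -> 4 separate sets of statistics (and 4 gamma/beta pairs)
bn.train()               # training mode: normalize with the current batch's statistics

# Input of shape (N, C, H, W) with a different offset and scale in each channel
scale = torch.tensor([1.0, 2.0, 5.0, 0.5]).view(1, 4, 1, 1)
offset = torch.tensor([0.0, 3.0, -2.0, 10.0]).view(1, 4, 1, 1)
x = torch.randn(16, 4, 8, 8) * scale + offset

with torch.no_grad():
    y = bn(x)

# Reduce over batch (dim 0), height (dim 2) and width (dim 3); keep channels (dim 1)
per_channel_mean = y.mean(dim=(0, 2, 3))
per_channel_var = ((y - per_channel_mean.view(1, 4, 1, 1)) ** 2).mean(dim=(0, 2, 3))
print("per-channel mean:", per_channel_mean)  # close to 0 for every channel
print("per-channel var: ", per_channel_var)   # close to 1 (slightly below, because of eps)

# The statistics of a single sample are not forced to 0 / 1
print("sample 0, per-channel mean:", y[0].mean(dim=(1, 2)))
```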

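Mechanically: for each channel, it computes the mean and variance over the batch and spatial dimensions, subtracts the mean, and divides by sqrt(variance + eps). The small constant eps defaults to 1e-5 and guards against division by zero. It then applies the learnable per-channel parameters gamma (scale, stored as weight) and beta (shift, stored as bias): y = gamma * (x - mean) / sqrt(var + eps) + beta. During training, it normalizes with the current batch's statistics and also updates running estimates of the mean and variance with an exponential moving average. During inference (model.eval()), it normalizes with those running estimates instead. The intent is to keep each layer's input distribution roughly stable as the weights of earlier layers change.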
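The arithmetic above can be reproduced by hand. This sketch assumes a fresh nn.BatchNorm2d with the default eps=1e-5 and momentum=0.1, plus arbitrary gamma/beta values chosen so the affine step is visible. It computes y = gamma * (x - mean) / sqrt(var + eps) + beta per channel, along with the one-step running-statistics update, and compares both with the layer. It also exposes a detail of PyTorch's implementation: the batch is normalized with the biased variance, while running_var is updated with the unbiased one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

C = 3
bn = nn.BatchNorm2d(C)           # eps=1e-5, momentum=0.1 by default
with torch.no_grad():            # arbitrary gamma/beta so the affine step is visible
    bn.weight.copy_(torch.tensor([1.0, 2.0, 0.5]))
    bn.bias.copy_(torch.tensor([0.0, -1.0, 3.0]))

x = torch.randn(8, C, 5, 5) * 3 + 2
bn.train()
y = bn(x)

# Manual computation: per-channel statistics over (N, H, W)
mean = x.mean(dim=(0, 2, 3))
var_biased = ((x - mean.view(1, C, 1, 1)) ** 2).mean(dim=(0, 2, 3))
n = x.numel() // C                      # number of values per channel
var_unbiased = var_biased * n / (n - 1)

gamma = bn.weight.view(1, C, 1, 1)
beta = bn.bias.view(1, C, 1, 1)
x_hat = (x - mean.view(1, C, 1, 1)) / torch.sqrt(var_biased.view(1, C, 1, 1) + bn.eps)
y_manual = gamma * x_hat + beta

# Running statistics after one step, starting from mean 0 and variance 1
m = bn.momentum
expected_running_mean = (1 - m) * torch.zeros(C) + m * mean
expected_running_var = (1 - m) * torch.ones(C) + m * var_unbiased

print(torch.allclose(y, y_manual, atol=1e-4))
print(torch.allclose(bn.running_mean, expected_running_mean, atol=1e-5))
print(torch.allclose(bn.running_var, expected_running_var, atol=1e-4))
```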

Use BatchNorm2d after convolutional layers and before activation functions in standard CNNs. The noise in batch statistics has a mild regularizing effect, so networks with BN often need less Dropout or none at all. BN also usually stabilizes training enough to allow higher learning rates. Not every architecture uses it (vision transformers, for example, typically use LayerNorm), but it remains the default in most convolutional image models, including ResNets.
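A minimal sketch of the conv → BN → ReLU ordering as a reusable block. The helper name conv_bn_relu is hypothetical. It sets bias=False on the convolution, a common convention: in training mode BatchNorm2d subtracts each channel's batch mean, which cancels any constant per-channel conv bias, and BN's beta provides the shift instead. The check at the end confirms that the conv bias has no effect on the BN output.

```python
import torch
import torch.nn as nn

def conv_bn_relu(in_ch: int, out_ch: int) -> nn.Sequential:
    """Hypothetical helper: 3x3 conv -> BatchNorm2d -> ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),  # BN's beta replaces the bias
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

torch.manual_seed(0)
features = nn.Sequential(conv_bn_relu(3, 16), conv_bn_relu(16, 32))
x = torch.randn(4, 3, 32, 32)
out = features(x)
print(out.shape)  # torch.Size([4, 32, 32, 32])

# Why bias=False: in training mode, BN's mean subtraction cancels a conv bias
conv_with_bias = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=True)
conv_no_bias = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    conv_no_bias.weight.copy_(conv_with_bias.weight)  # identical weights, only the bias differs
bn = nn.BatchNorm2d(16).train()
out_with_bias = bn(conv_with_bias(x))
out_no_bias = bn(conv_no_bias(x))
print(torch.allclose(out_with_bias, out_no_bias, atol=1e-5))
```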

Analogy

Think of batch norm as auto-adjusting the volume level of a microphone during a broadcast. The signal (activations) comes in at unpredictable loudness (distribution), the microphone measures the current crowd noise (batch statistics) and normalizes the volume relative to it, then applies its own tone controls (learnable gamma and beta). This lets the next speaker (next layer) work with a consistent signal strength.

Code

python
import torch
import torch.nn as nn
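
# Create a simple CNN with BatchNorm2d.
# bias=False on the convs: BatchNorm2d subtracts the per-channel mean and adds
# its own learnable shift (beta), so a conv bias right before it is redundant.
model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(32),            # one gamma/beta pair per channel (32 channels)
    nn.ReLU(inplace=True),
    nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d((1, 1)),  # global average pool -> (N, 64, 1, 1)
    nn.Flatten(),                  # -> (N, 64)
    nn.Linear(64, 10)
)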

# Create a batch of random images (batch_size=8, channels=3, height=32, width=32)
x = torch.randn(8, 3, 32, 32)

# Forward pass in training mode
model.train()
output_train = model(x)
print(f"Training output shape: {output_train.shape}")
print(f"Training output (first 3 values): {output_train[0, :3]}")

# Access BatchNorm2d layer directly to inspect running statistics
bn_layer = model[1]
print(f"\nBatchNorm2d running mean (first 5 channels): {bn_layer.running_mean[:5]}")
print(f"BatchNorm2d running var (first 5 channels): {bn_layer.running_var[:5]}")
print(f"BatchNorm2d weight shape: {bn_layer.weight.shape}")
print(f"BatchNorm2d bias shape: {bn_layer.bias.shape}")

# Forward pass in evaluation mode
model.eval()
with torch.no_grad():
    output_eval = model(x)
    print(f"\nEval output shape: {output_eval.shape}")
    print(f"Eval output (first 3 values): {output_eval[0, :3]}")

# Compare: training mode normalized with the batch's statistics, eval mode with the running statistics
print(f"\nOutputs differ (training vs eval): {not torch.allclose(output_train, output_eval)}")
Output
Training output shape: torch.Size([8, 10])
Training output (first 3 values): tensor([...], grad_fn=<SliceBackward0>)

BatchNorm2d running mean (first 5 channels): tensor([...])
BatchNorm2d running var (first 5 channels): tensor([...])
BatchNorm2d weight shape: torch.Size([32])
BatchNorm2d bias shape: torch.Size([32])

Eval output shape: torch.Size([8, 10])
Eval output (first 3 values): tensor([...])

Outputs differ (training vs eval): True

What just happened?

We built a 2-layer CNN with BatchNorm2d after each convolution. The elided tensor values (shown as ...) depend on the random weights and input, so they change from run to run; the shapes and the final True do not. In training mode, each BatchNorm2d normalized its input with the current batch's per-channel mean and variance, and that same forward pass also updated its running statistics once. With the default momentum of 0.1, running_mean moved from 0 to 0.1 × the batch mean, and running_var moved from 1 to 0.9 + 0.1 × the (unbiased) batch variance. The printed running statistics are therefore small non-zero means and variances that are no longer exactly 1. In eval mode, model.eval() switched BatchNorm2d to those running statistics. torch.no_grad() only disabled gradient tracking to save memory and compute; it does not change how BatchNorm behaves. After a single batch the running statistics have moved only 10% of the way toward the batch statistics, so the eval outputs differ noticeably from the training outputs. Over many training batches the running statistics converge toward the typical per-channel statistics of the data, and eval outputs then stay close to what training mode would produce for a representative batch.
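To see what the running statistics look like after many batches, this sketch feeds 100 batches of synthetic data through a standalone BatchNorm2d in training mode. The data has a known per-channel mean of 3.0 and standard deviation of 2.0 (variance 4.0), an assumption chosen so the target values are known in advance. With momentum=0.1 the initial values (0 and 1) are forgotten within a few dozen batches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)   # momentum=0.1 by default
bn.train()

with torch.no_grad():
    for step in range(100):
        # synthetic batch: every channel has mean 3.0 and standard deviation 2.0
        batch = torch.randn(8, 3, 16, 16) * 2.0 + 3.0
        bn(batch)        # the training-mode forward pass updates the running statistics
        if step in (0, 9, 99):
            print(f"after {step + 1:3d} batches: running_mean={bn.running_mean}, "
                  f"running_var={bn.running_var}")
```

After the first batch the running mean is only about 0.3. After ten batches it is about two thirds of the way to 3.0, and after a hundred it sits close to the true values.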

Common gotcha

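The most common mistake is forgetting to call model.eval() before inference or validation. In training mode, BatchNorm2d normalizes each input with statistics from the current batch, so the same image can produce different outputs depending on what else is in the batch. Every forward pass also updates the running statistics with validation or test data. For deployment or validation, call model.eval() and wrap inference in torch.no_grad() (or torch.inference_mode()). Another gotcha is tiny training batches. With a batch size of 1, BatchNorm2d computes statistics over that single image's spatial positions, which are noisy and unlike the running statistics used at inference. If there is only one value per channel (for example a 1×1 feature map after global pooling), PyTorch raises a ValueError in training mode. For very small batches, use InstanceNorm (nn.InstanceNorm2d) or GroupNorm instead.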
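A minimal sketch of these gotchas, using a standalone BatchNorm2d and random data as stand-ins for a real model. In training mode the same image produces different outputs when its batch-mates change, while in eval mode it does not. A training-mode input with only one value per channel is rejected with a ValueError.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
image = torch.randn(1, 4, 8, 8)
batch_a = torch.cat([image, torch.randn(7, 4, 8, 8)])          # image + 7 random companions
batch_b = torch.cat([image, torch.randn(7, 4, 8, 8) * 5 + 2])  # image + very different companions

bn.train()
with torch.no_grad():
    train_a, train_b = bn(batch_a)[0], bn(batch_b)[0]
print("train mode, same output:", torch.allclose(train_a, train_b))  # depends on batch-mates

bn.eval()
with torch.no_grad():
    eval_a, eval_b = bn(batch_a)[0], bn(batch_b)[0]
print("eval mode, same output:", torch.allclose(eval_a, eval_b))

# Training mode with a single value per channel (N * H * W == 1) is rejected
bn.train()
try:
    bn(torch.randn(1, 4, 1, 1))
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)
```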

Error recovery

RuntimeError: expected scalar type Double but found Float
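The model's parameters (conv weights, BatchNorm2d weight and bias) are float32 by default, but the input is float64, for example a tensor created from a float64 NumPy array. The exact wording of the error depends on the PyTorch version and on which layer first sees the input. Convert the input with x = x.float(), or convert the whole model with model = model.double() if you really need float64.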
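A quick reproduction, assuming a small float32 conv + BN model like the one above. A float64 input fails at the first layer, and casting the input with .float() fixes it. Because the error text varies between versions, the sketch only relies on the exception type.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8))
x = torch.rand(2, 3, 16, 16, dtype=torch.float64)  # e.g. what torch.from_numpy gives for a float64 array

print(next(model.parameters()).dtype, x.dtype)     # torch.float32 torch.float64
try:
    model(x)
    failed = False
except RuntimeError as err:   # exact message depends on the PyTorch version
    failed = True
    print("RuntimeError:", err)

out = model(x.float())        # cast the input to the parameters' dtype
print(out.dtype)              # torch.float32
```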
No error, but eval-mode outputs are not normalized
PyTorch does not raise an error if you call model.eval() before any training passes. BatchNorm2d simply normalizes with its initial running statistics (mean 0, variance 1), which amounts to almost no normalization. Run training with model.train() on representative data, or load a trained state_dict, before relying on eval-mode outputs.
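To see why eval before training is silent but still a problem, consider a freshly constructed BatchNorm2d. It has running_mean = 0, running_var = 1, gamma = 1 and beta = 0, so in eval mode it only divides by sqrt(1 + eps) and leaves the input essentially unchanged. The sketch uses an arbitrary, deliberately un-normalized input.

```python
import math

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
bn.eval()                                   # no training passes yet
x = torch.randn(2, 3, 4, 4) * 10 + 50       # far from zero mean / unit variance

with torch.no_grad():
    y = bn(x)

# Identity up to the eps term: y = x / sqrt(1 + eps)
print(torch.allclose(y, x / math.sqrt(1 + bn.eps)))
print(y.mean().item())   # still roughly 50: nothing was normalized
```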
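NaN in outputs or loss during training
One possible cause is a very small (<4) or highly skewed batch, which makes BatchNorm's batch statistics unreliable. Increase the batch size or switch to GroupNorm/InstanceNorm for small batches. If the batch size is reasonable, check the learning rate and the input data before blaming BatchNorm.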
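A minimal sketch of the swap, assuming a conv block like the ones above. nn.GroupNorm(num_groups, num_channels) normalizes each sample over groups of channels, so it behaves the same for any batch size, including 1, in both training and eval mode. The choice of 8 groups is an arbitrary assumption; num_channels must be divisible by num_groups.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
block = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
    nn.GroupNorm(num_groups=8, num_channels=32),   # instead of nn.BatchNorm2d(32)
    nn.ReLU(inplace=True),
)

block.train()
single = torch.randn(1, 3, 16, 16)
out_alone = block(single)                                               # batch size 1 works in training
out_in_batch = block(torch.cat([single, torch.randn(3, 3, 16, 16)]))[:1]

# Per-sample normalization: batch-mates do not change the result
print(torch.allclose(out_alone, out_in_batch, atol=1e-5))
```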

Experienced dev note

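The difference between BatchNorm2d's training and eval behavior is easy to underestimate. New developers sometimes treat it like a stateless activation function, but it accumulates running statistics (running_mean, running_var and the num_batches_tracked counter). These define its inference behavior and are saved in the state_dict. The momentum parameter (default 0.1) is the weight given to each new batch when updating those statistics: running = (1 - momentum) × running + momentum × batch_stat. This is not the same notion as optimizer momentum. Set it lower (e.g. 0.01) for a longer, smoother average when training batches are small and noisy, and higher when batches are large and stable; momentum=None switches to a cumulative (simple) average. BatchNorm also makes training more tolerant of learning rate choices, which can mask a poorly tuned learning rate: if you remove it later, training at the same learning rate may diverge. Finally, in distributed training, ordinary BatchNorm2d computes statistics per device. Use nn.SyncBatchNorm to normalize over the global batch across devices.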
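A small sketch of the momentum update, assuming a constant synthetic input so that every batch has a mean of exactly 5.0. With momentum=0.1 the running mean closes 10% of the remaining gap each step, with 0.01 only 1%, and momentum=None gives a simple cumulative average. The last lines use nn.SyncBatchNorm.convert_sync_batchnorm, which swaps every BatchNorm layer in a model. Actually training with SyncBatchNorm requires an initialized torch.distributed process group, which this sketch does not set up.

```python
import torch
import torch.nn as nn

# Synthetic input: one channel whose batch mean is exactly 5.0 on every step
x = torch.full((4, 1, 2, 2), 5.0)

results = {}
for momentum in (0.1, 0.01, None):
    bn = nn.BatchNorm2d(1, momentum=momentum)   # running_mean starts at 0
    bn.train()
    with torch.no_grad():
        for _ in range(10):
            bn(x)
    # momentum=m: running = (1 - m) * running + m * batch_mean after each batch
    # momentum=None: running = simple average of all batch means seen so far
    results[momentum] = bn.running_mean.item()
    print(f"momentum={momentum}: running_mean after 10 batches = {results[momentum]:.4f}")

# Swap BatchNorm layers for SyncBatchNorm (typically before wrapping in DistributedDataParallel)
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, bias=False), nn.BatchNorm2d(8), nn.ReLU())
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(sync_model[1]).__name__)   # SyncBatchNorm
```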

Check your understanding

Why would using BatchNorm2d with a batch size of 1 be problematic, and what layer would you use instead?

Answer hint

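A correct answer explains that with a single sample the batch statistics come from one image's spatial positions only. They are noisy and unrepresentative of the running statistics used at inference, and if there is only one value per channel, PyTorch refuses to compute them in training mode. It also mentions an alternative such as InstanceNorm, GroupNorm or LayerNorm that normalizes per sample rather than per batch.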

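VERSION PyTorch 2.11.x (March 2026). The nn.BatchNorm2d constructor (num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True) has been stable for many releases, and num_features has always been a required argument. The num_batches_tracked buffer counts training batches and drives the cumulative average when momentum=None; check the release notes for your version before depending on it directly.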
NEXT

Explore nn.GroupNorm and nn.LayerNorm as alternatives to BatchNorm2d for small batches or when per-sample normalization is needed.
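As a starting point, here is a sketch of how these per-sample layers relate, assuming default initialization (weights 1, biases 0). nn.GroupNorm with a single group normalizes each sample over all of (C, H, W), which matches nn.LayerNorm([C, H, W]). nn.GroupNorm with one group per channel matches nn.InstanceNorm2d. Their learnable parameters differ in shape (per channel for GroupNorm, per element for LayerNorm), so the equivalence holds only until training changes them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W = 2, 6, 5, 5
x = torch.randn(N, C, H, W)

group_1 = nn.GroupNorm(1, C)          # one group: normalize over (C, H, W) per sample
layer_norm = nn.LayerNorm([C, H, W])  # normalize over the last three dims per sample
group_c = nn.GroupNorm(C, C)          # one group per channel: normalize over (H, W)
instance = nn.InstanceNorm2d(C)       # affine=False and no running statistics by default

with torch.no_grad():
    same_ln = torch.allclose(group_1(x), layer_norm(x), atol=1e-5)
    same_in = torch.allclose(group_c(x), instance(x), atol=1e-5)
    # Unlike BatchNorm2d, none of these use running statistics, so batch size does not matter
    same_batch = torch.allclose(group_1(x[:1]), group_1(x)[:1], atol=1e-6)

print(same_ln, same_in, same_batch)
```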
