How to use batch normalization in PyTorch
Quick answer
Use torch.nn.BatchNorm1d, BatchNorm2d, or BatchNorm3d layers in your PyTorch model to normalize activations across the batch. Insert these layers after convolutional or linear layers to stabilize and accelerate training.

Prerequisites

- Python 3.8+
- pip install torch>=2.0
Setup
Install PyTorch if you haven't already. Use the official command from the PyTorch website, or run:

pip install torch torchvision

Step by step
Here is a complete example showing how to add BatchNorm1d after a linear layer in a simple feedforward network. The batch normalization layer normalizes the output of the linear layer across the batch dimension.
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.bn1 = nn.BatchNorm1d(20)  # BatchNorm for 1D features
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)  # Apply batch normalization
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Create a dummy input batch of size 4 with 10 features
input_tensor = torch.randn(4, 10)
model = SimpleNet()
output = model(input_tensor)
print("Output shape:", output.shape)
print("Output tensor:\n", output)

output
Output shape: torch.Size([4, 5])
Output tensor:
tensor([[ 0.0516, 0.0517, 0.0113, -0.0281, 0.0313],
[-0.0273, 0.0205, 0.0287, -0.0277, 0.0107],
[ 0.0277, 0.0283, 0.0107, -0.0279, 0.0271],
        [-0.0101,  0.0276,  0.0223, -0.0279,  0.0273]], grad_fn=<AddmmBackward0>)

Common variations
Batch normalization layers vary by input dimensionality:
- BatchNorm1d for 2D inputs (batch, features) or 3D inputs (batch, features, length)
- BatchNorm2d for 4D inputs (batch, channels, height, width), common in CNNs
- BatchNorm3d for 5D inputs (batch, channels, depth, height, width)
Example for convolutional layers:
import torch
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)  # one scale/shift pair per channel
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)  # normalize each channel over batch and spatial dims
        x = self.relu(x)
        return x
input_tensor = torch.randn(8, 3, 32, 32) # batch of 8 RGB images 32x32
model = ConvNet()
output = model(input_tensor)
print("Output shape:", output.shape)

output
Output shape: torch.Size([8, 16, 32, 32])
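The same pattern extends to volumetric data. A minimal sketch applying BatchNorm3d to a dummy 5D tensor (the tensor dimensions here are arbitrary, chosen only for illustration):

```python
import torch
import torch.nn as nn

# BatchNorm3d normalizes 5D input (batch, channels, depth, height, width),
# e.g. activations coming out of a Conv3d layer.
bn3d = nn.BatchNorm3d(16)
volumes = torch.randn(2, 16, 8, 32, 32)  # batch of 2 volumes, 16 channels
out = bn3d(volumes)
print("Output shape:", out.shape)  # torch.Size([2, 16, 8, 32, 32])
```

As with the 1D and 2D variants, the output shape matches the input shape; only the per-channel statistics change.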
Troubleshooting
If batch normalization behaves unexpectedly during evaluation, ensure you call model.eval() to switch to evaluation mode, which uses running statistics instead of batch statistics.
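To see the two modes side by side, here is a small sketch (layer sizes are arbitrary): in training mode the layer normalizes with the current batch's statistics and updates its running estimates; in evaluation mode it reuses those running estimates, so even a single sample can be processed deterministically.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(20)

# Training mode: batch statistics are used and running_mean/running_var update.
bn.train()
_ = bn(torch.randn(4, 20))

# Evaluation mode: running statistics are used, so a batch of one is fine.
bn.eval()
with torch.no_grad():
    out = bn(torch.randn(1, 20))
print("Eval output shape:", out.shape)  # torch.Size([1, 20])
```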
Also, the batch size must be greater than 1 during training: with a single sample, the variance across the batch is undefined, and PyTorch raises an error.
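As an illustration of the batch-size requirement, a single-sample batch fed to a BatchNorm layer in training mode raises a ValueError (the layer size here is arbitrary):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(20)
bn.train()  # training mode computes statistics over the batch

try:
    bn(torch.randn(1, 20))  # batch of one: batch variance is undefined
except ValueError as err:
    print("Error:", err)
```

Switching the layer to eval mode avoids this, since running statistics are used instead of batch statistics.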
Key Takeaways
- Use BatchNorm1d, BatchNorm2d, or BatchNorm3d depending on input tensor dimensions.
- Insert batch normalization layers after linear or convolutional layers and before activation functions.
- Call model.train() during training and model.eval() during evaluation to toggle batch norm behavior.
- Batch normalization stabilizes training by normalizing layer inputs across the batch dimension.
- Batch size must be greater than 1 during training for batch normalization to work properly.