How to use MSELoss in PyTorch
Quick answer
Use torch.nn.MSELoss in PyTorch to compute the mean squared error between predicted and target tensors. Instantiate it once as a loss function, then call it with model outputs and ground-truth targets to get a scalar loss tensor.
Prerequisites
- Python 3.8+
- PyTorch 2.0 or later: pip install "torch>=2.0" (the quotes stop the shell from reading > as an output redirect)
Setup
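Install PyTorch if you haven't already, using the command the official PyTorch installation guide generates for your platform. Only torch is required for MSELoss; torchvision and torchaudio are optional. For a CPU-only install: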
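pip install torch torchvision torchaudio
On Linux, this default PyPI install includes CUDA support; append --index-url https://download.pytorch.org/whl/cpu to get the CPU-only wheels instead.
Step by step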
This example shows how to use torch.nn.MSELoss to compute the mean squared error loss between predictions and targets in a regression setting.
import torch
import torch.nn as nn
# Create dummy prediction and target tensors
predictions = torch.tensor([[0.5], [0.8], [0.3]], dtype=torch.float32)
targets = torch.tensor([[0.0], [1.0], [0.0]], dtype=torch.float32)
# Instantiate MSELoss
criterion = nn.MSELoss()
# Compute loss
loss = criterion(predictions, targets)
print(f'MSE Loss: {loss.item():.4f}')
output
MSE Loss: 0.1267
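Worked by hand for the three prediction/target pairs in this example, the default reduction='mean' averages the squared differences over all N elements:

```latex
\mathrm{MSE}(\hat{y}, y) = \frac{1}{N}\sum_{i=1}^{N} (\hat{y}_i - y_i)^2
= \frac{(0.5-0)^2 + (0.8-1)^2 + (0.3-0)^2}{3}
= \frac{0.25 + 0.04 + 0.09}{3} = \frac{0.38}{3} \approx 0.1267
```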
Common variations
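Use the reduction parameter of MSELoss to control how the per-element losses are aggregated: 'mean' (the default) averages them, 'sum' adds them up, and 'none' returns the element-wise loss tensor with the same shape as the inputs. In a training loop you call the same criterion on each batch, then call loss.backward() so an optimizer can update the model.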
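One way to see how the reduction modes relate is to check them against each other and against a hand-computed squared error. This is a minimal sketch reusing the example tensors; it also compares against torch.nn.functional.mse_loss, the functional counterpart of nn.MSELoss, which takes the same reduction argument.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same dummy tensors as in the step-by-step example
predictions = torch.tensor([[0.5], [0.8], [0.3]], dtype=torch.float32)
targets = torch.tensor([[0.0], [1.0], [0.0]], dtype=torch.float32)

# Element-wise squared errors, computed by hand
squared_errors = (predictions - targets) ** 2

loss_none = nn.MSELoss(reduction='none')(predictions, targets)
loss_sum = nn.MSELoss(reduction='sum')(predictions, targets)
loss_mean = nn.MSELoss()(predictions, targets)  # reduction='mean' is the default

# 'none' keeps the input shape; 'sum' and 'mean' reduce to a 0-dim tensor
assert loss_none.shape == predictions.shape
assert loss_sum.dim() == 0 and loss_mean.dim() == 0

# 'sum' and 'mean' are the sum and mean of the 'none' result
assert torch.allclose(loss_none, squared_errors)
assert torch.allclose(loss_sum, loss_none.sum())
assert torch.allclose(loss_mean, loss_none.mean())

# The functional form gives the same value without creating a module
assert torch.allclose(F.mse_loss(predictions, targets), loss_mean)
print(f'mean={loss_mean.item():.4f}, sum={loss_sum.item():.4f}')
```

Since 'mean' divides the 'sum' result by the number of elements, switching between them rescales the gradients by that count, which matters when you tune the learning rate.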
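import torch
import torch.nn as nn
# Same dummy tensors as in the previous example
predictions = torch.tensor([[0.5], [0.8], [0.3]], dtype=torch.float32)
targets = torch.tensor([[0.0], [1.0], [0.0]], dtype=torch.float32)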
# Example with reduction='sum'
criterion_sum = nn.MSELoss(reduction='sum')
loss_sum = criterion_sum(predictions, targets)
print(f'Sum reduction MSE Loss: {loss_sum.item():.4f}')
# Example with reduction='none'
criterion_none = nn.MSELoss(reduction='none')
loss_none = criterion_none(predictions, targets)
print(f'Element-wise MSE Loss:\n{loss_none}')
output
Sum reduction MSE Loss: 0.3800
Element-wise MSE Loss:
tensor([[0.2500],
        [0.0400],
        [0.0900]])
Troubleshooting
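- Make sure predictions and targets have exactly the same shape. If the shapes differ but can be broadcast (for example [N, 1] versus [N]), MSELoss does not raise an error: it broadcasts, emits a UserWarning, and returns a wrong loss. Shapes that cannot be broadcast raise a RuntimeError.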
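- Keep both tensors floating point with the same dtype; convert with predictions.float() and targets.float() if needed, for example when targets come from an integer array.
- For GPU usage, put the model, predictions, and targets on the same device. MSELoss has no parameters, so moving the criterion itself with .to(device) is harmless but not required.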
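The shape and device points above can be sketched as follows. The [3, 1] versus [3] shapes are an illustrative assumption: a model with one output unit usually returns [N, 1], while targets loaded from a flat array are often [N]. The device line falls back to the CPU when CUDA is unavailable.

```python
import warnings
import torch
import torch.nn as nn

criterion = nn.MSELoss()
predictions = torch.tensor([[0.5], [0.8], [0.3]])  # shape [3, 1], like a 1-unit model output
targets = torch.tensor([0.0, 1.0, 0.0])            # shape [3], e.g. from a flat label array

# Mismatched but broadcastable shapes: [3, 1] and [3] broadcast to [3, 3],
# so this runs without an error and only emits a UserWarning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')  # record the warning even if it was shown before
    wrong_loss = criterion(predictions, targets)

# Fix: reshape the targets so both tensors have exactly the same shape
fixed_targets = targets.view_as(predictions)
right_loss = criterion(predictions, fixed_targets)
print(f'broadcast (wrong): {wrong_loss.item():.4f}, matched: {right_loss.item():.4f}')

# Device handling: both tensors must live on the same device;
# the criterion itself holds no parameters, so it needs no .to(device)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loss_on_device = criterion(predictions.to(device), fixed_targets.to(device))
```

The broadcast version averages all nine prediction/target combinations instead of the three matching pairs, which is why its value differs from the matched loss.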
Key Takeaways
- Use torch.nn.MSELoss() to compute the mean squared error between predictions and targets.
- The reduction parameter controls how losses are aggregated: 'mean', 'sum', or 'none'.
- Ensure predictions and targets have matching shapes and data types to avoid errors.
- MSELoss is commonly used for regression problems in PyTorch models.
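Because MSELoss is most often used to train regression models, here is a minimal sketch of a training loop with an optimizer. The synthetic data (y = 2x + 1), the nn.Linear(1, 1) model, plain SGD, the learning rate of 0.1, and the 200 epochs are illustrative assumptions, not recommended settings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # make the random weight initialization reproducible

# Synthetic regression data (illustrative assumption): y = 2x + 1
x = torch.linspace(-1.0, 1.0, steps=64).unsqueeze(1)  # shape [64, 1]
y = 2.0 * x + 1.0                                      # shape [64, 1], same as the model output

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

initial_loss = criterion(model(x), y).item()
for epoch in range(200):
    optimizer.zero_grad()          # clear gradients from the previous step
    loss = criterion(model(x), y)  # scalar mean squared error over the batch
    loss.backward()                # backpropagate through the model
    optimizer.step()               # update the weight and bias
final_loss = loss.item()

print(f'initial loss: {initial_loss:.4f}, final loss: {final_loss:.6f}')
print(f'learned weight: {model.weight.item():.3f}, bias: {model.bias.item():.3f}')
```

Here the model output and y both have shape [64, 1], so MSELoss compares matching elements rather than broadcasting.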