How to use CrossEntropyLoss in PyTorch
Quick answer
Use `torch.nn.CrossEntropyLoss` to compute the loss for multi-class classification tasks. It combines `nn.LogSoftmax` and `nn.NLLLoss` in a single class, and expects raw logits as input and class indices as targets.

Prerequisites
- Python 3.8+
- pip install torch>=2.0
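To see concretely what "combines `nn.LogSoftmax` and `nn.NLLLoss`" means, here is a minimal sketch (with made-up logits and targets) checking that the two formulations produce the same value:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)            # raw model outputs (batch=4, classes=3)
targets = torch.tensor([0, 2, 1, 2])  # class indices

# CrossEntropyLoss applied directly to raw logits
ce = nn.CrossEntropyLoss()(logits, targets)

# Equivalent two-step version: log-softmax, then negative log-likelihood
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)

print(torch.allclose(ce, nll))  # True
```

This is also why you should never apply softmax yourself before the loss: the log-softmax step is already built in.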
Setup
Install PyTorch if you haven't already, using the official command from the PyTorch installation guide. For CPU-only, run:

```shell
pip install torch torchvision
```

Step by step
This example shows how to define a simple neural network, use CrossEntropyLoss for the loss function, and run a single training step with dummy data.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(10, 3)  # 10 features, 3 classes

    def forward(self, x):
        return self.fc(x)

# Instantiate model, loss, optimizer
model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Dummy input (batch_size=5, features=10)
inputs = torch.randn(5, 10)

# Dummy target labels (batch_size=5), class indices in {0, 1, 2}
labels = torch.tensor([0, 2, 1, 0, 2])

# Forward pass
outputs = model(inputs)  # raw logits

# Compute loss
loss = criterion(outputs, labels)
print(f"Loss: {loss.item():.4f}")

# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Output:

```
Loss: 1.2043
```
Common variations
- Input format: `CrossEntropyLoss` expects raw logits, not probabilities. Do not apply softmax before the loss.
- Target format: targets must be class indices (a `LongTensor`), not one-hot encoded.
- Weighting classes: you can pass a `weight` tensor to handle class imbalance.
- Reduction: control the output with the `reduction` parameter: `'mean'` (default), `'sum'`, or `'none'`.
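A short sketch of these variations, using hypothetical class weights and the dummy labels from the example above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(5, 3)
labels = torch.tensor([0, 2, 1, 0, 2])

# Hypothetical class weights: up-weight class 1, which appears least often here
weights = torch.tensor([1.0, 3.0, 1.0])
weighted = nn.CrossEntropyLoss(weight=weights)
print(weighted(logits, labels))  # scalar weighted mean

# reduction='none' returns one loss value per sample instead of a scalar
per_sample = nn.CrossEntropyLoss(reduction='none')
print(per_sample(logits, labels).shape)  # torch.Size([5])

# If labels arrive one-hot encoded, convert them back to class indices first
one_hot = torch.nn.functional.one_hot(labels, num_classes=3).float()
indices = one_hot.argmax(dim=1)  # shape (5,), same as labels
```

Per-sample losses from `reduction='none'` are handy for inspecting which examples the model gets most wrong.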
Troubleshooting
- If you get a shape mismatch error, verify that your model outputs raw logits with shape `(batch_size, num_classes)` and that targets are 1D class indices.
- If the loss is NaN or very large, check that you are not applying softmax before the loss.
- For multi-label classification, use `nn.BCEWithLogitsLoss` instead.
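For the multi-label case, here is a minimal sketch with made-up data: each sample may belong to several classes at once, so targets are float multi-hot vectors rather than class indices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)  # one logit per class, per sample

# Multi-hot float targets: sample 0 belongs to classes 0 and 2, etc.
targets = torch.tensor([[1., 0., 1.],
                        [0., 1., 0.],
                        [1., 1., 0.],
                        [0., 0., 1.]])

# BCEWithLogitsLoss applies a sigmoid per class internally,
# treating each class as an independent yes/no decision
loss = nn.BCEWithLogitsLoss()(logits, targets)
print(loss.item())
```

Note the contrast with `CrossEntropyLoss`: there the classes compete via softmax and exactly one is correct; here each class is scored independently.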
Key Takeaways
- Feed raw logits (not probabilities) into `CrossEntropyLoss`.
- Targets must be class indices, not one-hot vectors.
- You can customize loss behavior with class weights and reduction modes.
- Avoid applying softmax before `CrossEntropyLoss`: the loss applies log-softmax internally, so doing it twice distorts the loss and its gradients.