How to calculate accuracy in PyTorch
Quick answer
In PyTorch, calculate accuracy by comparing predicted labels with true labels using tensor operations such as torch.argmax and torch.eq, then take the mean of the correct predictions.
Prerequisites
Python 3.8+
pip install "torch>=2.0"
Setup
Install PyTorch if not already installed. Use the following command to install the latest stable version:
pip install torch torchvision
Step by step
This example calculates accuracy from a classification model's output logits and a tensor of true labels.
import torch
# Example logits from model (batch_size=4, num_classes=3)
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3],
                       [1.2, 0.7, 1.8],
                       [0.1, 0.2, 0.3]])
# True labels
labels = torch.tensor([0, 1, 2, 2])
# Get predicted class indices
preds = torch.argmax(logits, dim=1)
# Compare predictions with true labels
correct = torch.eq(preds, labels)
# Calculate accuracy as mean of correct predictions
accuracy = correct.float().mean().item()
print(f"Accuracy: {accuracy * 100:.2f}%")
Output
Accuracy: 100.00%
Common variations
You can calculate accuracy inside training or evaluation loops. For multi-label classification, apply a sigmoid and threshold each label's probability instead of using argmax. To compute accuracy on a GPU, make sure the predictions and labels are on the same device.
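As a minimal sketch of the multi-label variation, assuming the model emits one logit per label and the targets are 0/1 tensors of the same shape: the helper name multilabel_accuracy, the 0.5 threshold, and the sample values are illustrative assumptions, not part of any PyTorch API.

```python
import torch

def multilabel_accuracy(logits: torch.Tensor, targets: torch.Tensor,
                        threshold: float = 0.5) -> float:
    """Element-wise accuracy for multi-label outputs (hypothetical helper)."""
    # Sigmoid maps each logit to an independent per-label probability.
    probs = torch.sigmoid(logits)
    # A label counts as predicted when its probability exceeds the threshold.
    preds = (probs > threshold).to(targets.dtype)
    # Fraction of individual label decisions that match the targets.
    return (preds == targets).float().mean().item()

# Two samples with three independent labels each (made-up values).
logits = torch.tensor([[2.0, -1.0, 0.5],
                       [-0.5, 1.5, -2.0]])
targets = torch.tensor([[1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0]])

ml_acc = multilabel_accuracy(logits, targets)
print(f"Multi-label accuracy: {ml_acc * 100:.2f}%")
```

Here five of the six label decisions match the targets, so the script prints 83.33%. This counts each label separately; a stricter variant would require every label of a sample to match.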
# Reusable accuracy function for multi-class logits
def accuracy(logits, labels):
    preds = torch.argmax(logits, dim=1)
    return (preds == labels).float().mean().item()

# Example usage in a training loop (model and dataloader are defined elsewhere)
for batch in dataloader:
    inputs, labels = batch
    outputs = model(inputs)
    acc = accuracy(outputs, labels)
    print(f"Batch accuracy: {acc * 100:.2f}%")
Troubleshooting
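- If accuracy is always zero or unexpectedly low, verify that predictions and labels have the same shape; comparing shapes (N,) and (N, 1) broadcasts to (N, N) without raising an error. Tensors on different devices typically raise a RuntimeError instead.
- For multi-class problems, apply torch.argmax along the class dimension (dim=1). Logits can be passed directly, since softmax does not change which class scores highest.
- Check that labels are integer class indices, not one-hot encoded vectors.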
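The shape and label-format checks above can be illustrated with a short sketch. The tensors are made-up values, and the variable names are chosen only for this example.

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3]])
preds = torch.argmax(logits, dim=1)  # shape (2,), values [0, 1]

# One-hot labels must be converted to class indices before comparing.
one_hot_labels = torch.tensor([[1, 0, 0],
                               [0, 0, 1]])
index_labels = torch.argmax(one_hot_labels, dim=1)  # values [0, 2]

# Labels with a trailing dimension broadcast silently against preds.
column_labels = index_labels.unsqueeze(1)            # shape (2, 1)
broadcast_shape = tuple((preds == column_labels).shape)

# Flattening the labels restores an element-wise comparison.
fixed_acc = (preds == column_labels.view(-1)).float().mean().item()

print(index_labels.tolist(), broadcast_shape, fixed_acc)
```

The broadcast comparison yields a (2, 2) result rather than one value per sample, so its mean is not an accuracy. An explicit check such as `assert preds.shape == labels.shape` before comparing is one way to catch this early.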
Key Takeaways
- Use torch.argmax on model outputs to get predicted classes for accuracy calculation.
- Compare predictions with true labels using torch.eq and average the correct predictions.
- Ensure tensors are on the same device and have compatible shapes to avoid errors.
- For multi-label tasks, adapt accuracy calculation with thresholding instead of argmax.
- Wrap accuracy calculation in a reusable function for clean training and evaluation code.
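As one possible way to apply the last two points across a whole evaluation set, the sketch below accumulates correct and total counts instead of averaging per-batch accuracies, which would weight a smaller final batch too heavily. The helper name evaluate_accuracy, the random dataset, and the nn.Linear stand-in model are assumptions made only so the example runs.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model: nn.Module, dataloader: DataLoader,
                      device: torch.device) -> float:
    """Accuracy over every sample in a dataloader (hypothetical helper)."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for inputs, labels in dataloader:
            # Move both tensors to the model's device before comparing.
            inputs, labels = inputs.to(device), labels.to(device)
            preds = torch.argmax(model(inputs), dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Stand-in data and model, used only to make the sketch runnable.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = TensorDataset(torch.randn(10, 5), torch.randint(0, 3, (10,)))
dataloader = DataLoader(dataset, batch_size=4)  # batches of 4, 4, and 2
model = nn.Linear(5, 3).to(device)

eval_acc = evaluate_accuracy(model, dataloader, device)
print(f"Evaluation accuracy: {eval_acc * 100:.2f}%")
```

Because the model here is untrained and the data is random, the printed value carries no meaning; only the counting pattern is the point.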