How to beginner · 3 min read

How to calculate accuracy in PyTorch

Quick answer
In PyTorch, calculate accuracy by comparing predicted labels with true labels using tensor operations like torch.argmax and torch.eq. Then compute the mean of correct predictions as accuracy.

PREREQUISITES

  • Python 3.8+
  • pip install torch>=2.0

Setup

Install PyTorch if not already installed. Use the following command to install the latest stable version:

bash
pip install torch torchvision

Step by step

This example shows how to calculate accuracy for a classification model's output logits and true labels tensors.

python
import torch

# Example logits from model (batch_size=4, num_classes=3)
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3],
                       [1.2, 0.7, 1.8],
                       [0.1, 0.2, 0.3]])

# True labels
labels = torch.tensor([0, 1, 2, 2])

# Get predicted class indices
preds = torch.argmax(logits, dim=1)

# Compare predictions with true labels
correct = torch.eq(preds, labels)

# Calculate accuracy as mean of correct predictions
accuracy = correct.float().mean().item()

print(f"Accuracy: {accuracy * 100:.2f}%")
output
Accuracy: 75.00%

Common variations

You can calculate accuracy during training or evaluation loops. For multi-label classification, use thresholding instead of argmax. To compute accuracy on GPU, ensure tensors are on the same device.

python
def accuracy(logits, labels):
    preds = torch.argmax(logits, dim=1)
    return (preds == labels).float().mean().item()

# Example usage in training loop
for batch in dataloader:
    inputs, labels = batch
    outputs = model(inputs)
    acc = accuracy(outputs, labels)
    print(f"Batch accuracy: {acc * 100:.2f}%")

Troubleshooting

  • If accuracy is always zero, verify that predictions and labels have matching shapes and are on the same device.
  • For multi-class problems, ensure you use torch.argmax on logits, not probabilities.
  • Check that labels are integer class indices, not one-hot encoded vectors.

Key Takeaways

  • Use torch.argmax on model outputs to get predicted classes for accuracy calculation.
  • Compare predictions with true labels using torch.eq and average correct predictions.
  • Ensure tensors are on the same device and have compatible shapes to avoid errors.
  • For multi-label tasks, adapt accuracy calculation with thresholding instead of argmax.
  • Wrap accuracy calculation in a reusable function for clean training and evaluation code.
Verified 2026-04
Verify ↗