Code Intermediate medium · 5 min

Shape errors: the most common bug

What you will learn
Shape mismatches between tensors cause cryptic errors or, worse, silently wrong results. Learn to debug them systematically.

Why this matters

You'll spend more time debugging shape errors than any other PyTorch bug. Mastering shape inspection prevents hours of wasted debugging and teaches you to think dimensionally about your data.

Skip if: You're using a high-level framework such as Hugging Face Transformers that handles reshaping internally. You'll still need this the moment you drop to raw tensor operations.

Explanation

Shape errors occur when tensor dimensions don't align during operations. PyTorch expects matching (or broadcastable) dimensions for element-wise operations and compatible inner dimensions for matrix multiplication. When shapes are incompatible, PyTorch raises an error; when they happen to be broadcastable, it silently expands them, which can produce unexpected results.

The mechanical problem: when you call tensor_a @ tensor_b, PyTorch checks whether the last dimension of tensor_a matches the second-to-last dimension of tensor_b. If not, you get a RuntimeError. But before that point, you may already have created misshapen intermediate tensors through reshapes, permutes, or batch operations.

How to fix it: print .shape after every transformation, prefer .reshape() over .view() (view raises an error on tensors whose memory layout is incompatible, while reshape copies when needed), and mentally trace dimensions through your pipeline before running code.
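To illustrate the "unexpected results" case, here is a minimal sketch in which an element-wise operation runs without any error but returns a shape nobody asked for. The tensor names (preds, targets) are illustrative assumptions, not part of the lesson's code.

```python
import torch

preds = torch.randn(4)        # shape (4,)
targets = torch.randn(4, 1)   # shape (4, 1): a stray trailing dimension

# No error is raised: the two tensors broadcast to a (4, 4) grid.
diff = preds - targets
print(diff.shape)             # torch.Size([4, 4])

# Making the shapes agree explicitly gives the intended element-wise result.
diff_fixed = preds - targets.squeeze(-1)
print(diff_fixed.shape)       # torch.Size([4])
```

Printing the shape right after the subtraction is what exposes the problem here, since nothing else signals it.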

Analogy

Shape errors are like trying to fit a USB-A cable into a USB-C port. The interface doesn't match. You can't just push harder: you need to inspect both ends and ensure compatibility before attempting the connection.

Code

python
import torch

# Shape error scenario 1: matrix multiplication dimension mismatch
batch_size, seq_len, hidden_dim = 2, 10, 768
query = torch.randn(batch_size, seq_len, hidden_dim)
key = torch.randn(batch_size, seq_len, hidden_dim)

print(f'query shape: {query.shape}')
print(f'key shape: {key.shape}')

# Attempt 1: This FAILS
try:
    attention = query @ key  # (2, 10, 768) @ (2, 10, 768) — invalid
    print('Success')
except RuntimeError as e:
    print(f'Error: {e}')

# Fix: transpose key's last two dimensions
key_transposed = key.transpose(-2, -1)  # (2, 10, 768) -> (2, 768, 10)
print(f'\nkey_transposed shape: {key_transposed.shape}')

attention = query @ key_transposed  # (2, 10, 768) @ (2, 768, 10) = (2, 10, 10)
print(f'attention shape: {attention.shape}')

# Shape error scenario 2: reshape gone wrong
vocab_size = 20  # toy vocabulary size so the printed shapes stay small
logits = torch.randn(batch_size, seq_len, vocab_size)
print(f'\nlogits shape: {logits.shape}')
print(f'logits total elements: {logits.numel()}')

# Attempt: reshape to (2, -1). What is the second dimension?
reshaped = logits.reshape(batch_size, -1)
print(f'reshaped shape: {reshaped.shape}')  # (2, 200): seq_len * vocab_size, not (20, 20)!

# Correct reshape: merge batch and sequence, keep vocab as the last dimension
reshaped_correct = logits.reshape(-1, vocab_size)  # (20, 20)
print(f'reshaped_correct shape: {reshaped_correct.shape}')
print(f'reshaped_correct total elements: {reshaped_correct.numel()}')

# Shape debugging best practice
print(f'\n--- Shape Debug Checklist ---')
x = torch.randn(32, 64, 128)  # batch, time, features
print(f'Input x: {x.shape}')
x = x.reshape(32 * 64, 128)  # flatten batch+time
print(f'After flatten: {x.shape}')
x = torch.nn.functional.linear(x, torch.randn(256, 128))  # apply linear
print(f'After linear: {x.shape}')
x = x.reshape(32, 64, 256)  # unflatten
print(f'After unflatten: {x.shape}')
Output
query shape: torch.Size([2, 10, 768])
key shape: torch.Size([2, 10, 768])
Error: Expected size for first two dimensions of batch2 tensor to be: [2, 768] but got: [2, 10].

key_transposed shape: torch.Size([2, 768, 10])
attention shape: torch.Size([2, 10, 10])

logits shape: torch.Size([2, 10, 20])
logits total elements: 400
reshaped shape: torch.Size([2, 200])
reshaped_correct shape: torch.Size([20, 20])
reshaped_correct total elements: 400

--- Shape Debug Checklist ---
Input x: torch.Size([32, 64, 128])
After flatten: torch.Size([2048, 128])
After linear: torch.Size([2048, 256])
After unflatten: torch.Size([32, 64, 256])

What just happened?

The code demonstrated two real shape bugs: (1) attempting matrix multiplication on incompatible dimensions, where PyTorch rejected (10, 768) @ (10, 768) for each batch element because the inner dimensions must match; (2) reshaping with -1, which produced (2, 200) instead of the intended (20, 20) without raising any error. The fixes used transpose and a reshape that pins the last dimension explicitly, and the debug checklist shows how to print shapes after every step to catch mismatches early.

Common gotcha

The -1 in reshape is powerful but dangerous. It simply means "whatever size makes the total element count work." If you have 400 elements and reshape to (2, -1), you get (2, 200) whether or not that is the layout you intended, and no error is raised. Always verify the inferred dimension by printing the shape after reshape, or be explicit: .reshape(batch, seq_len, hidden) with no -1.
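One way to guard against a wrongly inferred dimension is to assert the result right after the reshape. The sketch below reuses the toy sizes from this lesson (batch 2, sequence 10, and an assumed vocabulary size of 20); the variable names are illustrative.

```python
import torch

batch_size, seq_len, vocab_size = 2, 10, 20
logits = torch.randn(batch_size, seq_len, vocab_size)

# Let -1 infer the leading dimension, then check it instead of assuming it.
flat = logits.reshape(-1, vocab_size)
assert flat.shape == (batch_size * seq_len, vocab_size), flat.shape

# Fully explicit alternative: with no -1, a wrong size raises immediately.
flat_explicit = logits.reshape(batch_size * seq_len, vocab_size)
print(flat.shape, flat_explicit.shape)
```

The assert costs almost nothing and fails at the line where the shape went wrong, rather than several layers later.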

Error recovery

RuntimeError: mat1 and mat2 shapes cannot be multiplied
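The inner dimensions of your two tensors don't match for matrix multiplication. For A @ B, A's last dimension must equal B's second-to-last dimension. Print both shapes. Use .transpose(-2, -1) to swap the last two dimensions (.T is fine for 2-D tensors; for batched tensors use .transpose(-2, -1) or .mT). Example: if A is (2,10,768) and B is also (2,10,768), do A @ B.transpose(-2, -1) to get (2,10,10). For batched (3-D) inputs the message may instead refer to the expected size of the batch2 tensor; the cause is the same.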
RuntimeError: shape ... is invalid for input of size ...
You tried to reshape a tensor to incompatible dimensions. The total number of elements must remain the same. Print tensor.numel() before reshape, then verify your target shape multiplies to the same total. Use tensor.reshape(a, b, c) where a*b*c equals the original element count.
IndexError: index ... is out of bounds
You indexed past the end of a dimension. Print tensor.shape first. Remember PyTorch uses 0-based indexing, so a dimension of size n accepts indices 0 through n-1. tensor[0] gets the first batch element. tensor[:, 5, :] gets index 5 along dimension 1, which requires that dimension to have size at least 6. Use .shape to check sizes before indexing.
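For the reshape error in particular, the element-count check can be sketched as a tiny helper. Note that describe_reshape is a hypothetical function written for this illustration, not a PyTorch API, and it does not handle -1 in the target shape.

```python
import math
import torch

def describe_reshape(t, target_shape):
    """Hypothetical helper: can target_shape hold exactly t's elements?"""
    needed = math.prod(target_shape)
    return t.numel() == needed, t.numel(), needed

x = torch.randn(2, 10, 20)

ok, have, need = describe_reshape(x, (20, 10))
print(ok, have, need)  # False 400 200

# The same mismatch is what makes reshape itself raise.
try:
    x.reshape(20, 10)
except RuntimeError as e:
    print(f"RuntimeError: {e}")
```

Comparing numel() against the product of the target shape tells you which side of the reshape is wrong before you start guessing.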

Experienced dev note

Once you start working with real models (not tutorials), you'll find that a large share of your bugs are shape mismatches. Here's the senior developer trick: add a shape debug line after *every* transformation during development, then delete them later. Don't wait until inference fails. Also: prefer .reshape() over .view(). view requires a compatible (typically contiguous) memory layout and raises a RuntimeError when it doesn't get one; reshape returns a view when it can and copies when it must, so it works whenever the element counts match. And when you get a cryptic error about broadcasting or matmul, your first move should be 'print all shapes' before you dig into the message. The message reports where the mismatch surfaced, which is often several steps downstream of where the wrong shape was created.
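A minimal sketch of the .view() versus .reshape() difference on a non-contiguous tensor (variable names are illustrative). In this case .view() raises a RuntimeError rather than continuing, while .reshape() copies the data and succeeds.

```python
import torch

x = torch.randn(3, 4)
xt = x.transpose(0, 1)       # shape (4, 3); shares memory with x, non-contiguous
print(xt.is_contiguous())    # False

try:
    xt.view(12)              # the strides cannot express a flat view
    view_failed = False
except RuntimeError as e:
    view_failed = True
    print(f"view raised: {e}")

flat = xt.reshape(12)        # succeeds by copying into a new contiguous tensor
print(flat.shape)            # torch.Size([12])
```

The trade-off is that reshape may allocate a copy without telling you, which matters mainly when you rely on in-place writes propagating back to the original tensor.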

Check your understanding

You have a tensor of shape (32, 64, 768). You want to apply a linear layer (input 768, output 256) to every sequence position, producing output shape (32, 64, 256). Without using a loop, how would you reshape this for a single linear layer call, and what would the intermediate shapes be?

Show answer hint

A correct answer will show flattening the batch and sequence dimensions together to (32*64, 768), applying the linear layer to get (2048, 256), then reshaping back to (32, 64, 256). The key insight is recognizing that linear layers operate on the last dimension, so you flatten everything before that dimension into a single batch dimension.
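The flatten, apply, unflatten sequence from the answer can be sketched as follows. The sketch also shows, as a side note, that torch.nn.Linear accepts extra leading dimensions directly, so both routes give the same result; the variable names are illustrative.

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 64, 768)
linear = torch.nn.Linear(768, 256)

# Flatten batch and sequence, apply the layer, then restore the layout.
flat = x.reshape(32 * 64, 768)         # (2048, 768)
out_flat = linear(flat)                # (2048, 256)
out = out_flat.reshape(32, 64, 256)    # (32, 64, 256)

# nn.Linear acts on the last dimension and keeps any leading dimensions.
out_direct = linear(x)                 # (32, 64, 256)

print(out.shape, out_direct.shape)
print(torch.allclose(out, out_direct, atol=1e-5))
```

Writing out the intermediate shapes by hand first, then checking them against the printed ones, is the habit this exercise is meant to build.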

VERSION Shape behavior is identical across PyTorch 2.x versions, but torch.compile() can reveal shape errors earlier, during graph tracing. Unrelated to shapes: .detach() has been the recommended way to detach a tensor since PyTorch 0.4; prefer it over the legacy .data attribute in any current version.
NEXT

Now that you can debug shapes, learn about broadcasting rules: how PyTorch automatically expands tensor dimensions during operations, which can hide shape errors or cause unexpected behavior.
