Code Intermediate medium · 5 min

Flattening before linear layers

What you will learn
nn.Linear transforms only the last dimension of its input, so to feed it every value of a multi-dimensional tensor you first flatten that tensor to 2D (batch, features).

Why this matters

Convolutional networks output 4D tensors (batch, channels, height, width), but a fully connected head needs one feature vector per sample, shaped (batch, features). Flattening is the bridge between feature extraction and classification that every CNN-to-MLP pipeline needs.

Skip if: Your model reduces the spatial dimensions to 1×1 with global average pooling or adaptive pooling and then squeezes them away before the linear layer; no separate flatten step is needed. Flattening is also unnecessary if the model uses only convolutional layers (e.g., fully convolutional networks for dense prediction tasks).
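As a rough illustration of the pooling route mentioned above, the sketch below reduces each channel map to a single value and removes the leftover 1×1 dimensions. The tensor sizes and the variable names (features, pooled, squeezed, head) are assumptions chosen for this example.

```python
import torch
import torch.nn as nn

# Stand-in for a convolutional output: (batch, channels, height, width)
features = torch.randn(32, 64, 7, 7)

# Global average pooling: average every channel map down to 1x1
pool = nn.AdaptiveAvgPool2d(1)
pooled = pool(features)                      # (32, 64, 1, 1)

# Remove the two singleton spatial dimensions, leaving (batch, channels)
squeezed = pooled.squeeze(-1).squeeze(-1)    # (32, 64)

# The linear head now only needs in_features == number of channels
head = nn.Linear(in_features=64, out_features=10)
logits = head(squeezed)

print(pooled.shape, squeezed.shape, logits.shape)
```

Because the head depends only on the channel count, the same layer keeps working if the spatial size of the feature map changes.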

Explanation

What it is: Flattening reshapes a multi-dimensional tensor into a 2D tensor (batch size × total features) by collapsing all non-batch dimensions into a single feature dimension.

How it works: Suppose a convolutional layer outputs shape (batch=32, channels=64, height=7, width=7) and the Linear layer that follows was built with in_features=3136. Calling tensor.flatten(start_dim=1) or tensor.view(tensor.size(0), -1) reshapes (32, 64, 7, 7) → (32, 3136), where 3136 = 64×7×7. Both calls keep the batch dimension (dim 0) and collapse everything after it.

When to use it: Flatten between the last convolutional or pooling layer and the first fully connected layer of a classification or regression head, whenever that head should see every channel and spatial position of each sample.
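The sketch below compares the two calls named above, plus reshape, on the same tensor. It also shows one case where they differ: view needs a compatible memory layout, so it fails on a permuted (non-contiguous) tensor, while flatten copies the data when it has to. The permutation is an assumption added purely to trigger that case.

```python
import torch

x = torch.randn(32, 64, 7, 7)

# Three ways to collapse everything after the batch dimension
a = x.flatten(start_dim=1)
b = x.view(x.size(0), -1)
c = x.reshape(x.size(0), -1)
print(a.shape, b.shape, c.shape)          # all (32, 3136)
print(torch.equal(a, b), torch.equal(a, c))

# A permuted tensor is no longer laid out contiguously in memory
x_t = x.permute(0, 2, 3, 1)               # (32, 7, 7, 64)

try:
    x_t.view(x_t.size(0), -1)
    view_failed = False
except RuntimeError:
    view_failed = True                    # view cannot merge these dims

# flatten still works here because it falls back to copying
d = x_t.flatten(start_dim=1)
print(view_failed, d.shape)
```

For the output of a convolution or pooling layer the tensor is normally contiguous, so either call is fine; flatten(start_dim=1) is simply the more forgiving choice.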

Analogy

Think of it like converting a stack of photographs into rows of pixel values. Each photo is (channels, height, width): multi-dimensional. A Linear layer is like a librarian who only accepts a single list of items per photo, so you unroll each photo into one long row of numbers while keeping every photo in its own row.

Code

python
import torch
import torch.nn as nn

# Random tensor standing in for a convolutional output:
# (batch, channels, height, width)
batch_size = 32
conv_output = torch.randn(batch_size, 64, 7, 7)
print(f"Conv output shape: {conv_output.shape}")

# Keep dim 0 (batch) and collapse channels, height, and width
flattened = conv_output.flatten(start_dim=1)
print(f"After flatten(start_dim=1): {flattened.shape}")
print(f"Total features: {64 * 7 * 7}")

# in_features must equal the flattened size: 64 * 7 * 7 = 3136
linear_layer = nn.Linear(in_features=3136, out_features=10)
output = linear_layer(flattened)
print(f"Linear layer output shape: {output.shape}")

# The same idea inside a model: nn.Flatten is the module form of flatten()
model = nn.Sequential(
    # padding=1 with a 3x3 kernel keeps the 32x32 spatial size
    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
    nn.ReLU(),
    # Halves height and width: 32x32 -> 16x16
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(start_dim=1),
    nn.Linear(in_features=64 * 16 * 16, out_features=10)
)

input_tensor = torch.randn(batch_size, 3, 32, 32)
model_output = model(input_tensor)
print(f"Model output shape: {model_output.shape}")
print(f"Model output:\n{model_output}")
Output
Conv output shape: torch.Size([32, 64, 7, 7])
After flatten(start_dim=1): torch.Size([32, 3136])
Total features: 3136
Linear layer output shape: torch.Size([32, 10])
Model output shape: torch.Size([32, 10])
Model output:
tensor([[...], ..., [...]], grad_fn=<AddmmBackward0>)
(32 rows of 10 logits each; the exact values differ on every run because the weights and input are random)

What just happened?

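The code created a random 4D tensor (batch=32, channels=64, height=7, width=7) standing in for a convolutional layer's output, flattened it to 2D (32, 3136), and passed it through a Linear layer to get (32, 10) logits. It then repeated the process inside a Sequential model, where padding=1 keeps the 32×32 input size through the convolution and MaxPool2d(2) halves it to 16×16, which is why the Linear layer takes 64 * 16 * 16 inputs. The output values differ on every run because the weights and inputs are randomly initialized.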
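To see where each shape in the Sequential model comes from, one option is to push a batch through the layers one at a time and record the shape after each step. This is a small tracing sketch; the shapes list and the loop are additions for illustration, and the model mirrors the one in the lesson.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(start_dim=1),
    nn.Linear(in_features=64 * 16 * 16, out_features=10),
)

x = torch.randn(32, 3, 32, 32)
shapes = []

# No gradients are needed just to inspect shapes
with torch.no_grad():
    for layer in model:
        x = layer(x)
        shapes.append((type(layer).__name__, tuple(x.shape)))
        print(f"{type(layer).__name__:<10} -> {tuple(x.shape)}")
```

The Flatten row is the one to check against the Linear layer's in_features: here it reads (32, 16384), and 16384 = 64 × 16 × 16.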

Common gotcha

Flattening without preserving the batch dimension. If you call tensor.view(-1) instead of tensor.view(batch_size, -1), or tensor.flatten() without start_dim=1, the batch dimension is merged into the features and you get a single 1D tensor. The Linear layer then receives batch × features values where it expects features, and raises a shape mismatch error. With a batch of one the sizes happen to match, so no error is raised and the output silently loses its batch dimension.
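The gotcha can be reproduced in a few lines. The sketch below assumes the same (32, 64, 7, 7) tensor used elsewhere in the lesson; the names wrong, right, single, and silent are made up for the demonstration.

```python
import torch
import torch.nn as nn

linear = nn.Linear(in_features=64 * 7 * 7, out_features=10)
x = torch.randn(32, 64, 7, 7)

wrong = x.flatten()                # start_dim defaults to 0: one long 1D tensor
right = x.flatten(start_dim=1)     # batch dimension preserved
print(wrong.shape, right.shape)    # (100352,) vs (32, 3136)

# The 1D tensor has 32 * 3136 values, not 3136, so the layer rejects it
try:
    linear(wrong)
    raised = False
except RuntimeError as err:
    raised = True
    print(f"Shape error: {err}")

# With a batch of one the sizes line up, so the mistake goes unnoticed
single = torch.randn(1, 64, 7, 7)
silent = linear(single.flatten())
print(silent.shape)                # (10,) instead of (1, 10)
```

The second case is the more dangerous one, since code that later indexes or reduces over the batch dimension will misbehave without any error at the Linear layer.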

Error recovery

RuntimeError: mat1 and mat2 shapes cannot be multiplied
You passed the wrong flattened size to the Linear layer. Calculate the exact feature count: channels × height × width. If your conv output is (batch, 64, 7, 7), you need Linear(64*7*7, out_features), not Linear(64, out_features). Print the flattened shape to verify.
RuntimeError: expected scalar type Float but found BFloat16
The flattened tensor and the Linear layer's weights have different dtypes, which typically happens when one of them was cast by hand (for example with .bfloat16()) while the other stayed in float32. Either cast the input to match the weights, for example flattened = flattened.to(dtype=torch.float32) when the weights are float32, or run the forward pass inside torch.autocast so the casting is handled for you. The exact wording of this message varies between PyTorch versions.
IndexError: too many indices for tensor
You're treating the output as a 4D tensor when it's already 2D after flattening. Don't index with [batch, channel, height, width] after flatten: use [batch, feature_index] instead.
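For the dtype error in particular, a minimal sketch of the mismatch and one way to resolve it is shown below. It assumes a float32 layer and a bfloat16 input created by a manual cast; the small sizes (8 features, 2 outputs) are arbitrary, and the exact error text printed may differ between PyTorch versions.

```python
import torch
import torch.nn as nn

linear = nn.Linear(in_features=8, out_features=2)    # weights are float32
x = torch.randn(4, 8).to(torch.bfloat16)             # input cast by hand

# Mixing dtypes in the matrix multiply is rejected
try:
    linear(x)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print(f"dtype mismatch: {err}")

# Casting the input to the layer's own weight dtype avoids hardcoding float32
out = linear(x.to(linear.weight.dtype))
print(out.dtype, out.shape)
```

Reading the target dtype from linear.weight keeps the fix correct if the model is later moved to a different precision.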

Experienced dev note

In production, derive the flattened size programmatically, or compute it once and store it in a named variable: don't scatter hand-multiplied values like 64*7*7 through the model definition. Better yet, use nn.AdaptiveAvgPool2d(1) before flattening to reduce the spatial dimensions to 1×1, then flatten to (batch, channels). This makes the network robust to input size changes. When debugging shape mismatches, add print(f"Shape before linear: {flattened.shape}") to your forward pass: it saves hours of guessing.
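One common way to derive the size programmatically is to run a single dummy sample through the feature extractor and read the width of the flattened result. The sketch below assumes a 3×32×32 input, as in the lesson's model; the names features, dummy, n_features, and classifier are illustrative.

```python
import torch
import torch.nn as nn

# Feature extractor ending in Flatten, so its output is already 2D
features = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(start_dim=1),
)

# One dummy sample is enough to measure the flattened width
with torch.no_grad():
    dummy = torch.zeros(1, 3, 32, 32)
    n_features = features(dummy).shape[1]
print(f"Flattened features per sample: {n_features}")

# Build the head from the measured size instead of a hand-computed constant
classifier = nn.Linear(in_features=n_features, out_features=10)
model = nn.Sequential(features, classifier)

out = model(torch.randn(32, 3, 32, 32))
print(out.shape)
```

If the convolutional stack changes later (another pooling layer, different padding), the measured n_features follows along and the Linear layer stays in sync.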

Check your understanding

If your convolutional layer outputs (batch=16, channels=32, height=14, width=14) and you want to feed it into a Linear layer, what input size should the Linear layer have, and why does the batch dimension never appear in that calculation?

Show answer hint

The input size is 32×14×14 = 6272. The batch dimension never appears because in_features describes a single sample's feature vector: the Linear layer applies the same weights to every sample in the batch independently, so the batch size can change without changing the layer.

VERSION No breaking changes in PyTorch 2.11.x for flatten or view operations. Both methods are stable, and their behavior has been unchanged since PyTorch 1.0.
NEXT

Now that you can reshape data for fully connected layers, learn about nn.Sequential and how to compose convolutional blocks, flattening, and linear layers into a single reusable model.
