Flattening before linear layers
Why this matters
Convolutional layers output 4D tensors (batch, channels, height, width), but a classification head built from nn.Linear layers expects a 2D (batch, features) input, because nn.Linear multiplies only along the last dimension. Flattening is the bridge between feature extraction and classification that every CNN-to-MLP pipeline needs.
Explanation
What it is: Flattening turns a multi-dimensional tensor into a 2D tensor of shape (batch size, total features) by collapsing every non-batch dimension into a single feature dimension.
How it works: Suppose a convolutional layer outputs shape (batch=32, channels=64, height=7, width=7) and the next Linear layer expects input of shape (32, in_features). Calling tensor.flatten(start_dim=1) or tensor.view(tensor.size(0), -1) reshapes (32, 64, 7, 7) → (32, 3136), where 3136 = 64 × 7 × 7. With start_dim=1, dimension 0 (the batch) is kept and every dimension after it is collapsed.
When to use it: Flatten between the last convolutional or pooling layer and the first fully connected layer of any classification or regression head.
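A minimal sketch comparing the two calls named above, using a random tensor as a stand-in for a real convolution output. The permute step is an added illustration, not part of the original example: it shows one practical difference, namely that view() only works when the tensor's memory layout allows the merge, while flatten() falls back to copying.

```python
import torch

# Stand-in for a conv output: (batch, channels, height, width)
x = torch.randn(32, 64, 7, 7)

# Two equivalent ways to keep dim 0 and collapse the rest
a = x.flatten(start_dim=1)
b = x.view(x.size(0), -1)
print(a.shape, b.shape)   # torch.Size([32, 3136]) torch.Size([32, 3136])
print(torch.equal(a, b))  # True

# permute() returns a non-contiguous view whose strides do not allow
# view() to merge dims 1-3 into one
x_nhwc = x.permute(0, 2, 3, 1)  # shape (32, 7, 7, 64)
view_failed = False
try:
    x_nhwc.view(x_nhwc.size(0), -1)
except RuntimeError:
    view_failed = True
print("view failed:", view_failed)  # view failed: True

# flatten() copies when needed, so it still works
# (elements come out in H, W, C order, so the values differ from `a`)
print(x_nhwc.flatten(start_dim=1).shape)  # torch.Size([32, 3136])
```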
Analogy
Think of it like turning a stack of photographs into lists of pixel values. Each photo is (channels, height, width): multi-dimensional. A Linear layer is like a librarian who only accepts one flat list per item, so you unroll each photo into a single long list of its own pixel values, one list per photo, while keeping the photos in the batch separate.
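To make the "one list per photo" picture concrete, here is a small sketch on a tiny tensor (the sizes are arbitrary, chosen only so the indices are easy to check by hand). It shows that each row of the flattened tensor holds exactly one sample, unrolled in row-major order: channel first, then row, then column.

```python
import torch

# 2 "photos", each of shape (C=2, H=3, W=4), filled with 0..47
x = torch.arange(2 * 2 * 3 * 4).reshape(2, 2, 3, 4)
flat = x.flatten(start_dim=1)  # shape (2, 24): one row per photo

C, H, W = 2, 3, 4
n, c, h, w = 1, 1, 2, 3
# Row-major position of element (c, h, w) inside one sample's row
index = c * H * W + h * W + w
print(index)                                          # 23
print(flat[n, index].item() == x[n, c, h, w].item())  # True

# Each row is that sample's own values unrolled; samples never mix
print(torch.equal(flat[0], x[0].reshape(-1)))         # True
```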
Code
import torch
import torch.nn as nn
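batch_size = 32
# Stand-in for a convolutional layer's output: (batch, channels, height, width)
conv_output = torch.randn(batch_size, 64, 7, 7)
print(f"Conv output shape: {conv_output.shape}")
# Keep dim 0 (the batch) and collapse channels, height and width into one dimension
flattened = conv_output.flatten(start_dim=1)
print(f"After flatten(start_dim=1): {flattened.shape}")
print(f"Total features: {64 * 7 * 7}")
# in_features must equal the flattened size: 64 * 7 * 7 = 3136
linear_layer = nn.Linear(in_features=3136, out_features=10)
output = linear_layer(flattened)
print(f"Linear layer output shape: {output.shape}")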
model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),  # 32x32 stays 32x32
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),  # 32x32 -> 16x16
    nn.Flatten(start_dim=1),      # (batch, 64, 16, 16) -> (batch, 16384)
    nn.Linear(in_features=64 * 16 * 16, out_features=10)
)
input_tensor = torch.randn(batch_size, 3, 32, 32)
model_output = model(input_tensor)
print(f"Model output shape: {model_output.shape}")
print(f"Model output:\n{model_output}")

Output:
Conv output shape: torch.Size([32, 64, 7, 7])
After flatten(start_dim=1): torch.Size([32, 3136])
Total features: 3136
Linear layer output shape: torch.Size([32, 10])
Model output shape: torch.Size([32, 10])
Model output:
tensor([[ ... 10 values ... ],
        ...,
        [ ... 10 values ... ]], grad_fn=<AddmmBackward0>)
(32 rows of 10 logits; the numbers are random and change on every run.)
What just happened?
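The code simulated a convolutional output with torch.randn as a 4D tensor (batch=32, channels=64, height=7, width=7), flattened it to 2D (32, 3136), and passed it through a Linear layer to get (32, 10) logits. It then repeated the process inside a Sequential model: the 3×3 convolution with padding=1 keeps the 32×32 input at 32×32, MaxPool2d(kernel_size=2) halves it to 16×16, and Flatten therefore produces 64 × 16 × 16 = 16384 features, which is why that Linear layer uses in_features=64 * 16 * 16. The output values are random because both the inputs and the initial weights are random.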
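One way to check the shape arithmetic above is to run the Sequential model one layer at a time and record each output shape. This sketch rebuilds the same architecture so it runs on its own; iterating over an nn.Sequential yields its layers in order.

```python
import torch
import torch.nn as nn

# Same architecture as the Sequential example, rebuilt so this runs alone
model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(start_dim=1),
    nn.Linear(in_features=64 * 16 * 16, out_features=10),
)

x = torch.randn(32, 3, 32, 32)
shapes = []
for layer in model:  # apply each layer in order and record the result shape
    x = layer(x)
    shapes.append(tuple(x.shape))
    print(f"{layer.__class__.__name__:>10}: {tuple(x.shape)}")

# Expected trace:
#   Conv2d (32, 64, 32, 32) -> ReLU (32, 64, 32, 32) -> MaxPool2d (32, 64, 16, 16)
#   -> Flatten (32, 16384) -> Linear (32, 10)
```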
Common gotcha
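Forgetting to preserve the batch dimension. tensor.view(-1) and tensor.flatten() (whose default is start_dim=0) fold the batch dimension into the features: for the (32, 64, 7, 7) example you get a 1D tensor of 32 × 3136 = 100352 values instead of a (32, 3136) matrix. That tensor no longer matches the Linear layer's in_features, so the forward pass fails with a shape mismatch error. Use tensor.flatten(start_dim=1) or tensor.view(tensor.size(0), -1); reading the batch size from the tensor rather than hardcoding batch_size also keeps a smaller final batch from being reshaped incorrectly.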
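A short sketch reproducing the gotcha. The smaller final batch of 20 samples is a hypothetical scenario added for illustration: it shows that a hardcoded batch size in view() may not raise at all, because 20 × 3136 values happen to divide evenly into 32 rows, which quietly misaligns rows with samples.

```python
import torch
import torch.nn as nn

conv_output = torch.randn(32, 64, 7, 7)
linear = nn.Linear(in_features=64 * 7 * 7, out_features=10)

# Wrong: flatten() defaults to start_dim=0, merging the batch into the features
wrong = conv_output.flatten()
print(wrong.shape)  # torch.Size([100352]), i.e. 32 * 3136
linear_failed = False
try:
    linear(wrong)
except RuntimeError as err:
    linear_failed = True
    print("Linear failed:", err)

# Right: keep dim 0
right = conv_output.flatten(start_dim=1)
print(linear(right).shape)  # torch.Size([32, 10])

# Hypothetical smaller final batch of 20 samples
last_batch = torch.randn(20, 64, 7, 7)
batch_size = 32
# Hardcoded batch size: no error, but rows no longer line up with samples
print(last_batch.view(batch_size, -1).shape)  # torch.Size([32, 1960])
# Reading the size from the tensor gives the intended shape
print(last_batch.view(last_batch.size(0), -1).shape)  # torch.Size([20, 3136])
```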
Error recovery
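RuntimeError: mat1 and mat2 shapes cannot be multiplied
The flattened feature count does not equal the Linear layer's in_features. Print the shape just before the Linear layer and set in_features to its last dimension.
RuntimeError: expected scalar type Float but found BFloat16
The input and the layer's weights have different dtypes. Cast one to match the other, for example x.to(layer.weight.dtype).
IndexError: too many indices for tensor
The tensor was indexed with more indices than it has dimensions, for example treating a flattened (batch, features) tensor as if it were still 4D.
Experienced dev note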
In production, compute the flattened size once instead of retyping 64*7*7 inline: store it in a named constant, or derive it programmatically by passing a dummy input through the convolutional layers and reading the resulting shape. Better yet, put nn.AdaptiveAvgPool2d(1) before flattening so the spatial dimensions are reduced to 1×1 whatever the input resolution, then flatten to (batch, channels). The Linear layer's in_features then equals the channel count, which makes the network robust to changes in input size. When debugging shape mismatches, add print(f"Shape before linear: {flattened.shape}") in your forward pass: it saves hours of guessing.
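Both suggestions from the note can be sketched as follows, reusing the convolution and pooling layers from the earlier Sequential example. The 32×32 dummy input, the 10 output classes, and the test resolutions 48 and 100 are assumptions chosen for illustration; a real model would use its own input size and class count.

```python
import torch
import torch.nn as nn

features = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
)

# Option 1: derive in_features by running one dummy input (assumed 32x32)
with torch.no_grad():
    n_features = features(torch.zeros(1, 3, 32, 32)).flatten(start_dim=1).size(1)
head = nn.Linear(n_features, 10)
print(n_features)  # 16384 = 64 * 16 * 16

# Option 2: global average pooling makes in_features equal to the channel
# count, so the same head accepts any input resolution
pooled_model = nn.Sequential(
    features,
    nn.AdaptiveAvgPool2d(1),   # (batch, 64, H, W) -> (batch, 64, 1, 1)
    nn.Flatten(start_dim=1),   # -> (batch, 64)
    nn.Linear(64, 10),
)
for size in (32, 48, 100):
    out = pooled_model(torch.randn(4, 3, size, size))
    print(size, tuple(out.shape))  # (4, 10) for every size
```

Option 1 still ties the head to one input size; Option 2 trades away the spatial layout of the features in exchange for size independence.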
Check your understanding
If your convolutional layer outputs (batch=16, channels=32, height=14, width=14) and you want to feed it into a Linear layer, what input size should the Linear layer have, and why does the batch dimension never appear in that calculation?
Answer hint
The answer should identify that in_features is 32 × 14 × 14 = 6272 and explain why the batch dimension is excluded: nn.Linear applies the same weight matrix to every sample, treating all leading dimensions as batch dimensions, so in_features describes only the per-sample feature count (everything except the batch).
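A quick check of the answer, assuming 10 output classes (the question does not specify out_features). The layer's weight has shape (10, 6272) with no batch dimension anywhere, which is why the same layer accepts a batch of 16 or a batch of 1.

```python
import torch
import torch.nn as nn

conv_output = torch.randn(16, 32, 14, 14)
in_features = 32 * 14 * 14  # 6272: channels * height * width, no batch
linear = nn.Linear(in_features, 10)  # 10 outputs is an assumption

print(linear(conv_output.flatten(start_dim=1)).shape)  # torch.Size([16, 10])
# The weight only describes features, so any batch size works
print(linear(torch.randn(1, in_features)).shape)       # torch.Size([1, 10])
print(linear.weight.shape)                             # torch.Size([10, 6272])
```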