How to define the forward pass in PyTorch
Quick answer
In PyTorch, define the forward pass by subclassing torch.nn.Module and overriding its forward method, which specifies how input tensors are transformed through the network's layers. The forward method takes input tensors and returns output tensors; autograd records the operations it performs, so gradients can be computed during training.
Prerequisites
Python 3.8+
pip install "torch>=2.0"
Setup
Install PyTorch if it is not already installed. Use the command from the official PyTorch installation guide, or run:
pip install torch torchvision
Step by step
Define a custom neural network by subclassing torch.nn.Module and implement the forward method to specify the forward pass logic. The example below creates a simple feedforward network with one hidden layer.
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)  # input layer to hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(5, 2)   # hidden layer to output layer

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# Instantiate the model
model = SimpleNet()

# Create a dummy input tensor with batch size 3 and 10 features
input_tensor = torch.randn(3, 10)

# Run the forward pass by calling the model instance
output = model(input_tensor)
print(output)
print(output.shape)
Output (the exact values differ on every run because the weights and the input are random):
tensor([[ 0.1234, -0.5678],
        [ 0.2345, -0.6789],
        [-0.3456,  0.7890]], grad_fn=<AddmmBackward0>)
torch.Size([3, 2])
Common variations
You can define the forward pass with different layer types (convolutional, recurrent), include conditional logic, or use functional API calls. The forward method can accept multiple inputs and return multiple outputs. Avoid calling forward directly; instead, call the model instance to ensure hooks and other PyTorch internals work correctly.
import torch
import torch.nn as nn
import torch.nn.functional as F


class CustomNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 16, 3)
        # A 3x3 convolution without padding turns a 28x28 input into 26x26
        self.fc = nn.Linear(16 * 26 * 26, 10)

    def forward(self, x, apply_relu=True):
        x = self.conv(x)
        if apply_relu:
            x = F.relu(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
        x = self.fc(x)
        return x


model = CustomNet()
input_tensor = torch.randn(4, 1, 28, 28)
output = model(input_tensor, apply_relu=False)
print(output.shape)
Output:
torch.Size([4, 10])
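The note above that forward can accept multiple inputs and return multiple outputs can be sketched as follows. This is a minimal illustration rather than a recommended architecture: the class name TwoInputNet, its layer sizes, and the argument names a and b are all hypothetical.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical module that takes two tensors and returns two tensors."""

    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(10, 8)
        self.branch_b = nn.Linear(4, 8)
        self.head = nn.Linear(16, 2)

    def forward(self, a, b):
        # Encode each input separately, then concatenate along the feature axis
        encoded_a = torch.relu(self.branch_a(a))
        encoded_b = torch.relu(self.branch_b(b))
        features = torch.cat([encoded_a, encoded_b], dim=1)
        logits = self.head(features)
        # Return both the final output and the intermediate features
        return logits, features


model = TwoInputNet()
logits, features = model(torch.randn(3, 10), torch.randn(3, 4))
print(logits.shape)    # torch.Size([3, 2])
print(features.shape)  # torch.Size([3, 16])
```

Returning a tuple keeps the call site simple; a dict or a small dataclass may read better once there are more than two or three outputs.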
Troubleshooting
- If you get shape mismatch errors, verify tensor dimensions match layer expectations.
- Do not call forward directly; always call the model instance.
- Use print(x.shape) inside forward to debug tensor shapes.
- Ensure super().__init__() is called in your __init__ method before any layers are assigned.
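One way to act on the shape advice above is to push a dummy tensor through the early layers and read off the flattened size instead of computing it by hand. The sketch below assumes the same nn.Conv2d(1, 16, 3) layer and 28x28 single-channel input used in the CustomNet example; the variable names are illustrative.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 16, 3)  # same layer as in the CustomNet example

# Run a single dummy image through the conv layer without tracking gradients
with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28)
    conv_out = conv(dummy)

# A 3x3 kernel with no padding shrinks 28x28 to 26x26
print(conv_out.shape)  # torch.Size([1, 16, 26, 26])

# Number of features per sample after flattening everything but the batch axis
flat_features = conv_out.flatten(start_dim=1).shape[1]
print(flat_features)  # 10816

# Size the linear layer from the measured value instead of a hand-computed one
fc = nn.Linear(flat_features, 10)
```

If the input resolution or the convolution settings change later, the measured value changes with them, which avoids one common source of shape mismatch errors.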
Key Takeaways
- Override the forward method in a subclass of torch.nn.Module to define the forward pass.
- The forward method transforms input tensors through layers and returns output tensors.
- Always call the model instance (e.g., model(input)) instead of calling forward directly.
- Use PyTorch layers and functional APIs inside forward for flexible model design.
- Debug tensor shapes inside forward to avoid shape mismatch errors.
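To see why calling the model instance matters, the following sketch registers a forward hook and compares the two call styles. register_forward_hook is a method of nn.Module; the calls list and the record_call function are hypothetical names used only for this demonstration.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
calls = []  # hypothetical log of hook invocations


def record_call(module, inputs, output):
    # A forward hook receives the module, a tuple of positional inputs, and the output
    calls.append(tuple(output.shape))


handle = model.register_forward_hook(record_call)
x = torch.randn(3, 10)

model(x)          # goes through __call__, so the hook runs
model.forward(x)  # bypasses __call__, so the hook is skipped

print(len(calls))  # 1
handle.remove()    # detach the hook when it is no longer needed
```

Both calls return the same tensor values here, which is why the difference is easy to miss until a hook, such as one added by a profiling or debugging tool, silently stops firing.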