
How to define forward pass in PyTorch

Quick answer
In PyTorch, you define the forward pass by subclassing torch.nn.Module and overriding its forward method. forward specifies how input tensors flow through the network's layers and returns the output tensors. You do not write a backward pass yourself: autograd records the tensor operations performed in forward and uses that record to compute gradients during training.
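A minimal sketch of that point follows. The TinyNet class and its layer sizes are illustrative assumptions. The model defines only forward, yet calling backward() on a scalar built from its output fills in .grad for every parameter.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, x):
        # Only the forward computation is written; autograd records these ops.
        return self.fc(x)


model = TinyNet()
x = torch.randn(8, 4)

loss = model(x).pow(2).mean()  # any scalar computed from the output works here
loss.backward()                # gradients come from the recorded forward ops

print(model.fc.weight.grad.shape)  # torch.Size([1, 4])
print(model.fc.bias.grad.shape)    # torch.Size([1])
```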

PREREQUISITES

  • Python 3.8+
  • pip install "torch>=2.0" (quote the version specifier so the shell does not treat > as a redirect)

Setup

Install PyTorch if it is not already installed. Either use the command that the official PyTorch installation guide generates for your platform and CUDA version, or run:

bash
pip install torch torchvision
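To confirm that the install worked and meets the 2.0+ prerequisite, you can run a quick check like the one below. Parsing the major version from torch.__version__ this way is a simple heuristic, not an official API.

```python
import torch

# torch.__version__ looks like "2.3.1" or "2.3.1+cu121"; the first field is the major version.
major = int(torch.__version__.split(".")[0])
print(torch.__version__)
print("CUDA available:", torch.cuda.is_available())
assert major >= 2, "This guide assumes PyTorch 2.0 or newer"
```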

Step by step

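Define a custom neural network by subclassing torch.nn.Module. Create the layers in __init__ so PyTorch registers their parameters, then implement forward to describe how an input moves through them. The example below builds a simple feedforward network with one hidden layer: 10 input features, 5 hidden units, and 2 outputs.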

python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)  # input layer to hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(5, 2)   # hidden layer to output layer

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Instantiate the model
model = SimpleNet()

# Create a dummy input tensor with batch size 3 and 10 features
input_tensor = torch.randn(3, 10)

# Run forward pass
output = model(input_tensor)
print(output)
print(output.shape)
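The exact numbers will differ on every run because the input and the initial weights are random. What to check are the shape and the grad_fn entry.

output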
tensor([[ 0.1234, -0.5678],
        [ 0.2345, -0.6789],
        [-0.3456,  0.7890]], grad_fn=<AddmmBackward0>)
torch.Size([3, 2])
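The grad_fn entry shows that autograd recorded the forward pass so that gradients can be computed later. At inference time that bookkeeping is unnecessary. The sketch below repeats SimpleNet so it runs on its own, then runs the same forward pass in eval mode under torch.no_grad(). Treat it as one common inference pattern rather than the only one.

```python
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


model = SimpleNet()
x = torch.randn(3, 10)

train_out = model(x)       # recorded by autograd
print(train_out.grad_fn)   # e.g. <AddmmBackward0 object at ...>

model.eval()               # affects layers like Dropout/BatchNorm; harmless here
with torch.no_grad():      # no graph is recorded inside this block
    infer_out = model(x)
print(infer_out.grad_fn)        # None
print(infer_out.requires_grad)  # False
```

Both calls run the same forward method with the same weights, so the values match. Only the autograd record differs.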

Common variations

You can build the forward pass from other layer types (convolutional, recurrent), add conditional logic with ordinary Python control flow, or call functions from torch.nn.functional instead of layer modules. The forward method can also accept multiple inputs and return multiple outputs, for example as a tuple. Avoid calling forward directly. Calling the model instance goes through nn.Module.__call__, which runs any registered hooks before and after forward.

python
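import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 16, 3)  # 1 input channel, 16 filters, 3x3 kernel
        # A 3x3 kernel with no padding shrinks a 28x28 input to 26x26
        self.fc = nn.Linear(16 * 26 * 26, 10)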

    def forward(self, x, apply_relu=True):
        x = self.conv(x)
        if apply_relu:
            x = F.relu(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.fc(x)
        return x

model = CustomNet()
input_tensor = torch.randn(4, 1, 28, 28)
output = model(input_tensor, apply_relu=False)
print(output.shape)
output
torch.Size([4, 10])
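The paragraph above notes that forward can take several inputs and return several outputs. Here is a minimal sketch. The TwoInputNet model, its layer sizes, and the returned (logits, features) pair are illustrative choices, not a standard API.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.img_fc = nn.Linear(8, 4)   # branch for the first input
        self.meta_fc = nn.Linear(3, 4)  # branch for the second input
        self.head = nn.Linear(8, 2)     # operates on the concatenated features

    def forward(self, img_feats, meta):
        a = torch.relu(self.img_fc(img_feats))
        b = torch.relu(self.meta_fc(meta))
        features = torch.cat([a, b], dim=1)  # shape (batch, 8)
        logits = self.head(features)         # shape (batch, 2)
        return logits, features              # the caller receives this tuple


model = TwoInputNet()
logits, features = model(torch.randn(5, 8), torch.randn(5, 3))
print(logits.shape, features.shape)  # torch.Size([5, 2]) torch.Size([5, 8])
```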

Troubleshooting

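  • If you get a shape mismatch error, check what each layer expects. nn.Linear(in_features, out_features) requires the input's last dimension to equal in_features. nn.Conv2d expects a batched input shaped (N, C, H, W), with C equal to its in_channels.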
  • Do not call forward directly; always call the model instance.
  • Use print(x.shape) inside forward to debug tensor shapes.
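  • Call super().__init__() at the start of your __init__ method, before assigning any layers. Otherwise PyTorch raises an AttributeError when the first layer is assigned.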
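The sketch below shows why calling the instance matters, and it also gives an alternative to print calls inside forward. It uses register_forward_hook to log output shapes. PyTorch runs hooks registered on a module only when that module is invoked through __call__. The SmallNet model, the make_hook helper, and the shapes list are illustrative.

```python
import torch
import torch.nn as nn


class SmallNet(nn.Module):  # illustrative model
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = SmallNet()
shapes = []  # (name, output shape) pairs collected by the hooks


def make_hook(name):
    # Forward hooks are called as hook(module, inputs, output) after forward runs.
    def hook(module, inputs, output):
        shapes.append((name, tuple(output.shape)))
    return hook


# Hook every child layer, plus the model itself under the name "model".
for name, layer in model.named_children():
    layer.register_forward_hook(make_hook(name))
model.register_forward_hook(make_hook("model"))

x = torch.randn(3, 10)

model(x)  # __call__ runs forward, then the model-level hook
via_call = list(shapes)
print(via_call)  # [('fc1', (3, 5)), ('fc2', (3, 2)), ('model', (3, 2))]

shapes.clear()
model.forward(x)  # bypasses the model-level hook
via_forward = list(shapes)
print(via_forward)  # [('fc1', (3, 5)), ('fc2', (3, 2))]
```

The child-layer hooks still fire in the second call because forward invokes fc1 and fc2 through their own __call__. Only the hook attached to model itself is skipped.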

Key Takeaways

  • Override the forward method in a subclass of torch.nn.Module to define the forward pass.
  • The forward method transforms input tensors through layers and returns output tensors.
  • Always call the model instance (e.g., model(input)) instead of calling forward directly.
  • Use PyTorch layers and functional APIs inside forward for flexible model design.
  • Inspect tensor shapes inside forward to track down shape mismatch errors.
Verified 2026-04