Building Neural Networks with torch.nn in PyTorch


PyTorch is a popular Python library for deep learning. Its torch.nn module is a core component for building and training neural networks.

In this article, we take a close look at the torch.nn module, its key components, and how to use it in Python.

torch.nn provides a wide range of pre-defined layers and loss functions, along with base classes that make it straightforward to define neural network models and prepare them for training.
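As a minimal sketch of how these pre-defined layers fit together, the snippet below chains them with nn.Sequential, a torch.nn container that runs its layers in order so no custom forward() method is needed. The layer sizes (10 inputs, 5 hidden units, 2 outputs) are illustrative choices that happen to match Example 1 later in this article.

```python
import torch
import torch.nn as nn

# A small feed-forward network assembled from pre-defined torch.nn layers.
# nn.Sequential calls each layer in order, so no custom forward() is needed.
model = nn.Sequential(
    nn.Linear(10, 5),  # 10 input features -> 5 hidden units
    nn.ReLU(),         # element-wise non-linearity
    nn.Linear(5, 2),   # 5 hidden units -> 2 output scores (logits)
)

x = torch.rand(4, 10)  # a batch of 4 random samples with 10 features each
logits = model(x)
print(logits.shape)    # torch.Size([4, 2])
```

Subclassing nn.Module, as the examples below do, gives more control over the forward pass; nn.Sequential is a shortcut for the common case where data simply flows through the layers one after another.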


Overview of torch.nn

The torch.nn module in the PyTorch library provides classes and functions for defining layers, activation functions, loss functions, and complete models. Optimization algorithms live in the companion torch.optim module, which is typically used alongside torch.nn to train the models it defines. Let us look at the key components of torch.nn.

Key Components of torch.nn

  • Modules: nn.Module is the base class for all neural network models in PyTorch. Subclassing it provides a convenient way to encapsulate parameters (weights and biases) together with the computation that uses them.
  • Layers: torch.nn includes many pre-defined layers and activations, such as ‘Linear’ and ‘ReLU’, which can be combined as building blocks inside your own models.
  • Loss Functions: torch.nn also offers loss functions such as ‘CrossEntropyLoss’, which measure the difference between predicted and actual values during training.
  • Optimizers: Optimization algorithms such as SGD, Adam, and Adagrad are provided by the separate torch.optim module and work with the parameters of any nn.Module.
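To make this division of labor concrete, here is a small sketch: a hypothetical TinyNet subclass of nn.Module whose nn.Linear layer is registered automatically as parameters, a CrossEntropyLoss from torch.nn, and an SGD optimizer from torch.optim performing one update step. The model, sizes, and targets are arbitrary values chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

class TinyNet(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        # Assigning a layer as an attribute registers its weight and bias
        self.fc = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc(x)

net = TinyNet()

# nn.Module keeps track of every registered parameter
for name, param in net.named_parameters():
    print(name, tuple(param.shape))
# fc.weight (2, 3)
# fc.bias (2,)

criterion = nn.CrossEntropyLoss()                # loss function from torch.nn
optimizer = optim.SGD(net.parameters(), lr=0.1)  # optimizer from torch.optim

x = torch.rand(4, 3)                 # 4 samples, 3 features each
target = torch.tensor([0, 1, 1, 0])  # class index for each sample
loss = criterion(net(x), target)     # scalar loss tensor

optimizer.zero_grad()  # clear any previous gradients
loss.backward()        # populate .grad for every parameter
optimizer.step()       # update the parameters in place
print(f'Loss before the update: {loss.item():.4f}')
```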


Example 1: Simple Neural Network

Let us now look at a simple Python example that uses the torch.nn module.

import torch
import torch.nn as nn
import torch.optim as optim

# Dummy dataset
X = torch.rand((100, 10))
y = torch.randint(0, 2, (100,))

# Define the model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Instantiate the model, loss function, and optimizer
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
epochs = 100
for epoch in range(epochs):
    # Forward pass
    outputs = model(X)
    loss = criterion(outputs, y)

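    # Backward pass and optimization
    optimizer.zero_grad()  # clear gradients from the previous iteration
    loss.backward()        # compute gradients of the loss w.r.t. the parameters
    optimizer.step()       # update the parameters using those gradients
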
    # Print the loss every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

# Evaluate the trained model (here on the same data it was trained on)
with torch.no_grad():
    test_outputs = model(X)
    predicted_classes = torch.argmax(test_outputs, dim=1)

    accuracy = torch.sum(predicted_classes == y).item() / len(y)
    print(f'Test Accuracy: {accuracy * 100:.2f}%')

In the code above, torch.nn supplies the layers (nn.Linear, nn.ReLU) and the loss function (nn.CrossEntropyLoss) used to create, train, and evaluate a simple neural network, while torch.optim provides the SGD optimizer. The same structure can be extended to more complex networks. Let us look at the output.

Figure: torch.nn Output
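Example 1 measures accuracy on the data it was trained on, which shows how well the model fits that data rather than how well it generalizes. One way to get a fairer estimate is to hold out part of the dataset. The sketch below assumes an 80/20 split made with torch.utils.data.random_split and reuses Example 1's layer sizes in an nn.Sequential model; since the labels are random, accuracy should hover around chance level.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, random_split

torch.manual_seed(0)  # reproducible data, split, and initialization

# Dummy data with the same shapes as Example 1
X = torch.rand((100, 10))
y = torch.randint(0, 2, (100,))

# Hold out 20 of the 100 samples for evaluation
train_set, test_set = random_split(TensorDataset(X, y), [80, 20])
X_train, y_train = X[train_set.indices], y[train_set.indices]
X_test, y_test = X[test_set.indices], y[test_set.indices]

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Train only on the training split
for epoch in range(100):
    optimizer.zero_grad()
    loss = criterion(model(X_train), y_train)
    loss.backward()
    optimizer.step()

# Evaluate only on the held-out split
model.eval()           # inference mode (matters for layers like Dropout/BatchNorm)
with torch.no_grad():  # gradients are not needed for evaluation
    preds = model(X_test).argmax(dim=1)
    accuracy = (preds == y_test).float().mean().item()
print(f'Held-out accuracy: {accuracy * 100:.2f}%')
```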

Example 2: Convolutional Neural Network (CNN)

Next, let us look at a convolutional neural network built with torch.nn and trained on the MNIST dataset of handwritten digits.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define a simple Convolutional Neural Network (CNN)
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(-1, 64 * 7 * 7)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Load the MNIST dataset and apply transformations
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Instantiate the CNN, define loss function, and choose an optimizer
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 5
for epoch in range(epochs):
    total_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()  # Zero the gradients
        outputs = model(images)  # Forward pass
        loss = criterion(outputs, labels)  # Compute the loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update the weights

        total_loss += loss.item() * images.size(0)

    average_loss = total_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch + 1}/{epochs}], Loss: {average_loss:.4f}')

print('Training finished.')

# Save the trained model's parameters (optional)
torch.save(model.state_dict(), 'mnist_cnn_model.pth')

In the example above, we construct a convolutional neural network (CNN) with two convolution and pooling stages followed by two fully connected layers, and train it on MNIST for five epochs. Let us look at the output.

Figure: CNN Using torch.nn
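The input size of fc1, 64 * 7 * 7, follows from how each layer changes the spatial size of a 28x28 MNIST image. For a convolution or pooling layer, the output size is floor((in + 2 * padding - kernel_size) / stride) + 1, so each 3x3 convolution with padding 1 keeps the size unchanged and each 2x2 max-pool with stride 2 halves it: 28 → 14 → 7. A quick sketch that checks this by passing a dummy tensor through the same convolution and pooling stack (rebuilt here with nn.Sequential for brevity):

```python
import torch
import torch.nn as nn

# Same convolution/pooling stack as the CNN above, without the fully connected layers
features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),   # 28x28 -> 28x28
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),                  # 28x28 -> 14x14
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 14x14 -> 14x14
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),                  # 14x14 -> 7x7
)

dummy = torch.zeros(1, 1, 28, 28)  # one single-channel 28x28 image, like MNIST
out = features(dummy)
print(out.shape)         # torch.Size([1, 64, 7, 7])
print(out[0].numel())    # 3136, i.e. 64 * 7 * 7, the input size of fc1
```

Running a dummy tensor through the feature layers like this is a cheap way to confirm the flattened size before hard-coding it in nn.Linear.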


torch.nn is a fundamental PyTorch module for building and training neural networks. Its pre-defined layers and loss functions, used together with the optimizers in torch.optim, simplify the process of creating complex models. Whether you are building a simple feed-forward network or a CNN, torch.nn provides the components you need.

From here, experiment with different architectures and hyperparameters to see how they affect your models' performance.
