PyTorch Lightning: How to Train your First Model?

Training Your First Model Using PyTorch Lightning

In this article, we’ll train our first model with PyTorch Lightning. PyTorch has been the go-to choice for many researchers since its inception in 2016. It became popular because of its Pythonic approach and strong CUDA support. However, it requires a lot of boilerplate code, and some features, such as distributed training across multiple GPUs, are mostly accessible only to power users.

PyTorch Lightning is a wrapper around PyTorch that aims to give PyTorch a Keras-like interface without taking away any of its flexibility. If you already use PyTorch as your daily driver, PyTorch Lightning can be a good addition to your toolset.

Getting Started with PyTorch Lightning

We’ll go over the steps to create our first model in an easy-to-follow way. So without any further ado, let’s get right into it!

1. Install PyTorch Lightning

To install PyTorch Lightning, run the following pip command. The Lightning Bolts package will also come in handy if you want to start with some predefined datasets.

pip install pytorch-lightning lightning-bolts

2. Import the modules

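First, we import the PyTorch and PyTorch Lightning modules, along with the standard-library and data-loading utilities used later in this tutorial.

import os

import torch
from torch.nn import functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl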

A common question is: “Why do we need torch when we are already using Lightning?”

Lightning makes coding in PyTorch faster, but it is built on top of torch rather than replacing it. Models are still made of ordinary torch modules, so you can drop down to plain PyTorch and make critical application-specific changes whenever necessary.

3. Setting up the MNIST Dataset

Unlike base PyTorch, Lightning makes the data-loading code more accessible and organized.

A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the matching transforms and data processing/downloads steps required.

https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html#what-is-a-datamodule

In plain PyTorch, the MNIST data pipeline is usually defined like this:

import os

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# Prepare the transforms standard to MNIST:
# convert images to tensors, then normalize with the dataset's mean and std
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
train_loader = DataLoader(mnist_train, batch_size=64)

As you can see, this data code is not organized into a single block. If you want to add more functionality, like a data preparation step or a validation data loader, the code quickly gets messier. Lightning organizes all of this into a LightningDataModule class.
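The Normalize((0.1307,), (0.3081,)) transform in the plain PyTorch snippet subtracts 0.1307 (the commonly used mean of MNIST pixel values) and divides by 0.3081 (their standard deviation), after ToTensor has scaled 8-bit pixels into [0, 1]. The plain-Python sketch below is a simplified stand-in that works on single floats rather than tensors; it shows where the two extreme pixel values end up.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def to_unit_range(raw_pixel):
    """Scale an 8-bit pixel (0-255) into [0, 1], as ToTensor does for uint8 images."""
    return raw_pixel / 255


def normalize(pixel, mean=MNIST_MEAN, std=MNIST_STD):
    """Apply (pixel - mean) / std, as transforms.Normalize does per channel."""
    return (pixel - mean) / std


black = normalize(to_unit_range(0))    # background pixel
white = normalize(to_unit_range(255))  # full-intensity stroke pixel
print(round(black, 4), round(white, 4))  # -0.4242 2.8215
```

After normalization the pixel values are roughly centered around zero, which tends to make optimization better behaved than raw 0-255 intensities.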

Defining a DataModule in PyTorch Lightning

1. Set up the dataset

Let us first load and set up the dataset using the LightningDataModule.

from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # Shape of one MNIST image: (channels, height, width).
        # In older Lightning releases, dm.size() returned self.dims;
        # it could optionally be assigned dynamically in setup()
        self.dims = (1, 28, 28)

    def prepare_data(self):
        # Download the train and test splits
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # 60,000 training images -> 55,000 for training, 5,000 for validation
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

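The prepare_data method downloads the MNIST files to data_dir. The setup method then loads them, splits the 60,000 training images into training (55,000) and validation (5,000) sets, and loads the separate 10,000-image test set. Lightning calls prepare_data only once, from a single process, while setup runs on every process, so dataset attributes should be assigned in setup rather than in prepare_data. Both methods can be as complex as your data’s preprocessing requires.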
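The random_split(mnist_full, [55000, 5000]) call works because MNIST’s training split has exactly 60,000 images: the requested lengths must add up to the dataset size. As a rough mental model, the plain-Python sketch below uses a hypothetical helper, random_split_indices (not PyTorch’s implementation, and it will not reproduce PyTorch’s exact split), that shuffles the indices once with a seeded generator and slices them into disjoint subsets.

```python
import random


def random_split_indices(n_items, lengths, seed=42):
    """Hypothetical helper: split range(n_items) into disjoint random subsets."""
    if sum(lengths) != n_items:
        raise ValueError("lengths must sum to the dataset size")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # one seeded shuffle of all indices
    splits, start = [], 0
    for length in lengths:
        splits.append(indices[start:start + length])  # consecutive, non-overlapping slices
        start += length
    return splits


# MNIST's training split has 60,000 images
train_idx, val_idx = random_split_indices(60000, [55000, 5000])
print(len(train_idx), len(val_idx))        # 55000 5000
print(set(train_idx).isdisjoint(val_idx))  # True
```

Because every index lands in exactly one subset, no validation image is ever seen during training.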

2. Defining the DataLoaders

With the datasets in place, we can add the dataloader methods to the class.

    def train_dataloader(self):
        # Shuffle the training data every epoch
        return DataLoader(self.mnist_train, batch_size=32, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)
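With batch_size=32 and DataLoader’s default drop_last=False, each loader yields ceil(N / 32) batches, and the last batch holds whatever is left over. The hypothetical helper below works this out for the 55,000 / 5,000 split and the 10,000-image MNIST test set.

```python
import math


def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Batches a DataLoader yields per pass over n_samples (drop_last defaults to False)."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


for name, n in [("train", 55000), ("val", 5000), ("test", 10000)]:
    remainder = n % 32
    last = remainder or 32  # a zero remainder means the last batch is full
    print(f"{name}: {batches_per_epoch(n, 32)} batches, last batch has {last} samples")
# train: 1719 batches, last batch has 24 samples
# val: 157 batches, last batch has 8 samples
# test: 313 batches, last batch has 16 samples
```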

3. Final look at the MNIST DataModule

The final LightningDataModule looks like this:

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # Shape of one MNIST image: (channels, height, width).
        # In older Lightning releases, dm.size() returned self.dims
        self.dims = (1, 28, 28)

    def prepare_data(self):
        # Download the train and test splits
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # Shuffle the training data every epoch
        return DataLoader(self.mnist_train, batch_size=32, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

An MNIST DataModule is also predefined in the Lightning Bolts datamodules. If you don’t want to go through the hassle of writing the whole thing yourself, you can simply import that DataModule and start working with it instead.

from pl_bolts.datamodules import MNISTDataModule

# Create MNIST DataModule instance
data_module = MNISTDataModule()

Now that the data is ready, we need a model to train.

4. Creating a Multilayer Perceptron Model

A Lightning model is very similar to a base PyTorch model class, except that it has some special methods that make training easier. The __init__ and forward methods are exactly the same as in PyTorch. We are creating a three-layer perceptron whose layers have 128, 256, and 10 neurons. Its input is a flattened 28×28 MNIST image, that is, a vector of 28 * 28 = 784 values.
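To make these layer sizes concrete: each nn.Linear(in_features, out_features) layer holds in_features × out_features weights plus out_features biases. The short calculation below counts the trainable parameters of this network; for the PyTorch model defined below, sum(p.numel() for p in net.parameters()) should report the same total.

```python
# (in_features, out_features) for layer_1, layer_2 and layer_3
layer_shapes = [(28 * 28, 128), (128, 256), (256, 10)]


def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


per_layer = [linear_params(i, o) for i, o in layer_shapes]
total = sum(per_layer)
print(per_layer, total)  # [100480, 33024, 2570] 136074
```

Most of the parameters sit in the first layer, because every one of the 784 input pixels connects to each of its 128 neurons.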

1. Base PyTorch-like Model

class MyMNISTModel(nn.Module):

    def __init__(self):
        super().__init__()

        # MNIST images are (1, 28, 28): (channels, height, width)
        # First hidden layer: 784 inputs -> 128 neurons
        self.layer_1 = nn.Linear(28 * 28, 128)
        # Second hidden layer of size 256
        self.layer_2 = nn.Linear(128, 256)
        # Output (prediction) layer: one score per digit class (10)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, height, width = x.size()

        # Flatten the image into a linear tensor
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # Pass the tensor through the layers
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        # log_softmax returns log-probabilities (pairs with F.nll_loss)
        x = F.log_softmax(x, dim=1)
        return x

Let us check whether the model works, using a random input tensor of shape (1, 1, 28, 28): a batch containing a single one-channel 28×28 image.

net = MyMNISTModel()

x = torch.randn(1, 1, 28, 28)
print(net(x).shape)

Output:

torch.Size([1, 10])

The 1 is the batch size and the 10 is the number of output classes, so our model is working as expected.
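The forward pass ends with F.log_softmax, so each row of the output holds log-probabilities rather than probabilities: every value is at most 0, exponentiating a row gives probabilities that sum to 1, and the largest entry still marks the predicted class. The plain-Python sketch below reimplements the formula for one list of made-up scores, using the numerically stable max-subtraction form; it is illustrative, not PyTorch’s kernel.

```python
import math


def log_softmax(scores):
    """log_softmax(s)_i = s_i - log(sum_j exp(s_j)), computed stably."""
    m = max(scores)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum_exp for s in scores]


scores = [2.0, 1.0, 0.1]  # made-up raw outputs of the last layer for one image
log_probs = log_softmax(scores)
probs = [math.exp(lp) for lp in log_probs]

print(round(sum(probs), 6))                # 1.0
print(all(lp <= 0 for lp in log_probs))    # True
print(log_probs.index(max(log_probs)))     # 0 (same argmax as the raw scores)
```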

2. Defining the Initialization and forward function

The Lightning model looks exactly the same, except that it inherits from pl.LightningModule instead of nn.Module. The Lightning network will look like this:

class MyMNISTModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        ...

    def forward(self, x):
        ...

In addition to these base torch methods, Lightning provides hooks that let us define what happens inside the training, validation, and test loops.

3. Defining the training, validation, and test steps

Next, we define what the model does with each training, validation, and test batch.

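    def training_step(self, batch, batch_idx):
        x, y = batch

        # Pass through the forward function of the network;
        # the outputs are log-probabilities because forward() ends with log_softmax
        logits = self(x)
        loss = F.nll_loss(logits, y)
        # Lightning runs backward() and the optimizer step with this loss
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        # Log the loss so it is averaged and reported for each validation epoch
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        y_hat = torch.argmax(logits, dim=1)
        accuracy = (y == y_hat).float().mean()
        # Logged metrics are averaged over the test set and reported by trainer.test()
        self.log('test_loss', loss)
        self.log('test_acc', accuracy)
        return {'test_loss': loss, 'test_acc': accuracy}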
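F.nll_loss expects log-probabilities, which is why it pairs with the log_softmax at the end of forward: with its default mean reduction, it averages the negative log-probability assigned to each sample’s true class. The accuracy in test_step is the fraction of samples whose argmax matches the label. A plain-Python sketch of both, using a made-up batch of two images over three classes for brevity:

```python
import math


def nll_loss(batch_log_probs, targets):
    """Mean negative log-probability of the true class (like F.nll_loss's default reduction)."""
    return -sum(row[t] for row, t in zip(batch_log_probs, targets)) / len(targets)


def accuracy(batch_log_probs, targets):
    """Fraction of rows whose most likely class equals the target."""
    preds = [row.index(max(row)) for row in batch_log_probs]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


# Made-up predictions for two images over three classes
batch_log_probs = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],  # predicts class 0
    [math.log(0.1), math.log(0.3), math.log(0.6)],  # predicts class 2
]
targets = [0, 1]

print(round(nll_loss(batch_log_probs, targets), 4))  # 0.7803
print(accuracy(batch_log_probs, targets))            # 0.5
```

The second image is misclassified, so it halves the accuracy and contributes the larger share of the loss (-log 0.3 versus -log 0.7).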


4. Optimizers

A Lightning model lets us define the optimizer for that specific model inside the model definition, through the configure_optimizers method.

    # We are using the Adam optimizer for this tutorial
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
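For intuition about what torch.optim.Adam(self.parameters(), lr=1e-3) does on each optimizer step, here is a minimal single-parameter sketch of the Adam update rule, using PyTorch’s default betas=(0.9, 0.999) and eps=1e-8. The quadratic loss f(x) = x² and the starting point are made up for illustration; the real optimizer applies the same update element-wise to every parameter tensor.

```python
import math


def adam_minimize(grad_fn, x, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, steps=1):
    """Run `steps` Adam updates on a single scalar parameter x."""
    beta1, beta2 = betas
    m = v = 0.0
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g      # running mean of gradients
        v = beta2 * v + (1 - beta2) * g * g  # running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)         # bias correction for the zero start
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


def grad(x):
    return 2 * x  # gradient of f(x) = x**2


after_one = adam_minimize(grad, 1.0, steps=1)
after_many = adam_minimize(grad, 1.0, steps=2000)
print(after_one)   # about 0.999: the first step moves by roughly lr
print(after_many)  # ends well below the starting value of 1.0
```

Because the update divides by the root of the running squared gradient, the step size is governed mostly by lr rather than by the raw gradient magnitude.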

5. Final look at our model

The final Lightning model should look like this:

import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl


class MyMNISTModel(pl.LightningModule):

    def __init__(self):
        super().__init__()

        # MNIST images are (1, 28, 28): (channels, height, width)
        # First hidden layer: 784 inputs -> 128 neurons
        self.layer_1 = nn.Linear(28 * 28, 128)
        # Second hidden layer of size 256
        self.layer_2 = nn.Linear(128, 256)
        # Output (prediction) layer: one score per digit class (10)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, height, width = x.size()

        # Flatten the image into a linear tensor
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # Pass the tensor through the layers
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        # log_softmax returns log-probabilities (pairs with F.nll_loss)
        x = F.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch

        # Pass through the forward function of the network (log-probabilities)
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        y_hat = torch.argmax(logits, dim=1)
        accuracy = (y == y_hat).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', accuracy)
        return {'test_loss': loss, 'test_acc': accuracy}

    def configure_optimizers(self):
        # Adam with a learning rate of 1e-3
        return torch.optim.Adam(self.parameters(), lr=1e-3)

We are now all set with our data and model. Let’s proceed to train the model on the data.

5. Training the model

Instead of writing the traditional boilerplate loop (computing the loss, running the backward pass, and stepping the optimizer), we let the Trainer in PyTorch Lightning do the job for us with very little code.
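For comparison, here is roughly the loop the Trainer writes for us, in plain PyTorch: zero the gradients, run the forward pass, compute the loss, backpropagate, and step the optimizer. This sketch uses made-up random tensors in place of MNIST so it runs without a download, and an nn.Sequential with the same layer sizes as MyMNISTModel; train_one_epoch is a hypothetical helper name. Lightning additionally handles validation loops, device placement, logging, and checkpointing, all of which this loop omits.

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Made-up stand-in data: 256 random "images" with random labels (no download needed)
images = torch.randn(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

# Same layer sizes as MyMNISTModel, written as an nn.Sequential for brevity
model = nn.Sequential(
    nn.Flatten(),                        # (b, 1, 28, 28) -> (b, 784)
    nn.Linear(28 * 28, 128), nn.ReLU(),
    nn.Linear(128, 256), nn.ReLU(),
    nn.Linear(256, 10), nn.LogSoftmax(dim=1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train_one_epoch(model, loader, optimizer):
    """Hypothetical helper: one pass over the data, returning the mean training loss."""
    model.train()
    total_loss = 0.0
    for x, y in loader:
        optimizer.zero_grad()           # clear gradients from the previous step
        loss = F.nll_loss(model(x), y)  # forward pass + loss
        loss.backward()                 # backpropagate
        optimizer.step()                # update the weights
        total_loss += loss.item() * len(y)
    return total_loss / len(loader.dataset)


losses = [train_one_epoch(model, loader, optimizer) for _ in range(5)]
print(losses)  # the training loss should shrink from epoch to epoch
```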

First, we initialize a Lightning Trainer with the parameters we need.

from pytorch_lightning import Trainer

# Set accelerator="cpu" to train on the CPU instead of a GPU
# Set max_epochs to the maximum number of epochs you want
trainer = Trainer(accelerator="gpu", devices=1, max_epochs=20)

Then we create an instance of the Lightning model and fit it on the MNIST DataModule:

model = MyMNISTModel()
trainer.fit(model, datamodule=data_module)
Figure: Training step

6. Results

Let’s check the final accuracy on the training set:

trainer.test(dataloaders=data_module.train_dataloader())

Output:

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(.98), 'test_loss': tensor(0.0017, device='cuda:0')}
--------------------------------------------------------------------------------

High accuracy on the training set alone can hide overfitting, so we also need to evaluate the model on the held-out test set, which it never saw during training. Let’s check the final accuracy of the model on the test set.

# Lightning calls data_module.setup('test') and then uses its test_dataloader()
trainer.test(datamodule=data_module)

Output:

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(.96), 'test_loss': tensor(0.0021, device='cuda:0')}
--------------------------------------------------------------------------------

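The test accuracy (0.96) is close to the training accuracy (0.98), which suggests that the model has learned patterns that generalize to unseen digits rather than just memorizing the training data.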

Conclusion

With this, we come to the end of this tutorial on PyTorch Lightning. PyTorch Lightning is relatively new and developing rapidly, so we can expect more features in the near future. Stay tuned for more articles on machine learning and deep learning.