Training your first GAN in PyTorch


GANs have been the talk of the town since Goodfellow et al. introduced them in 2014. In this tutorial, you’ll learn to train your first GAN in PyTorch. We also explain the inner workings of a GAN and walk through a simple implementation in PyTorch.

Libraries to Import

We first import the libraries and functions that will be used in the implementation.

import torch
from torch import nn

from torchvision import transforms
from torchvision.utils import make_grid

from torchvision.datasets import MNIST 
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from IPython.display import clear_output

What is a GAN?

A generative network can be described simply as a network that learns from the training data and generates new data that resembles it. There are various ways to design a generative model, one of them being adversarial training.

In a generative adversarial network, there are two submodels – the generator and the discriminator. We will look at each of these submodels in more detail:

1. The Generator

The generator, as the name suggests, is assigned the task of generating an image.

The generator takes a small, low-dimensional input (generally a 1-D vector) and outputs image data, for example of dimension 128x128x3 for an RGB image. In this tutorial the output is a single-channel 28x28 image, matching MNIST.

This scaling from a lower to a higher dimension is typically achieved using a series of deconvolution (transposed convolution) and convolution layers. The generator in this tutorial uses only transposed convolutions.

Our generator can be considered a function that takes in low-dimensional data and maps it to high-dimensional image data.

Over the training period, the generator learns to map the low-dimensional data to the high-dimensional data more and more effectively.

The goal of the generator is to generate an image that the discriminator mistakes for a real image.

Generator Logic
Fig 1: Working of the Generator

The Generator Class:

class Generator(nn.Module):
    def __init__(self, z_dim, im_chan, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            # We define the generator as a stack of deconvolution layers
            # with batch normalization and a non-linear activation function.
            # You can try to play with the values of the layers.

            nn.ConvTranspose2d(z_dim, 4 * hidden_dim, 3, 2),
            nn.BatchNorm2d(4 * hidden_dim),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(hidden_dim * 4, hidden_dim * 2, 4, 1),
            nn.BatchNorm2d(hidden_dim * 2),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, 3, 2),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),

            # The last layer maps to the image channels; Tanh bounds the output to [-1, 1]
            nn.ConvTranspose2d(hidden_dim, im_chan, 4, 2),
            nn.Tanh(),
        )

    def forward(self, noise):
        # Reshape the noise to (batch, z_dim, 1, 1) and pass it through the layers
        noise = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(noise)
# We define a generator with latent dimension 100 and a single image channel
gen = Generator(100, 1)
print("Composition of the Generator:", end="\n\n")
print(gen)

Composition of the Generator:

Generator(
  (gen): Sequential(
    (0): ConvTranspose2d(100, 256, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2))
    (10): Tanh()
  )
)
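
The kernel sizes and strides listed above determine how a 1x1 latent input grows into a full image. With no padding, a transposed convolution turns an input of side n into an output of side (n - 1) * stride + kernel_size. The sketch below is an illustration, not part of the original listing: it applies that formula to the four layers and cross-checks the result against a plain stack of PyTorch layers. The helper name `deconv_out` is hypothetical.

```python
import torch
from torch import nn

def deconv_out(size, kernel, stride):
    # Output side of ConvTranspose2d with padding=0, output_padding=0, dilation=1
    return (size - 1) * stride + kernel

# (kernel_size, stride) of the four generator layers, in order
layers = [(3, 2), (4, 1), (3, 2), (4, 2)]

sizes = []
size = 1  # the latent vector is viewed as a 1x1 "image" with z_dim channels
for kernel, stride in layers:
    size = deconv_out(size, kernel, stride)
    sizes.append(size)
print(sizes)  # [3, 6, 13, 28]

# Cross-check with real layers that use the same channel counts
stack = nn.Sequential(
    nn.ConvTranspose2d(100, 256, 3, 2),
    nn.ConvTranspose2d(256, 128, 4, 1),
    nn.ConvTranspose2d(128, 64, 3, 2),
    nn.ConvTranspose2d(64, 1, 4, 2),
)
with torch.no_grad():
    out = stack(torch.randn(2, 100, 1, 1))
print(out.shape)  # torch.Size([2, 1, 28, 28])
```

The final side length of 28 is what lets this generator produce MNIST-sized images.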

Additional Notes: An image is very high-dimensional data. Even an RGB image of dimension 3x128x128 has 49152 values.

The images that we want lie in a subspace, or manifold, of this huge space.

Ideally, the generator should learn where the subspace is located and sample randomly from the learned subspace to produce its output.

Searching for this ideal subspace directly is computationally very expensive. The most common way to deal with this is to map a latent vector space to the data space using a pushforward.
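
As a toy illustration of the pushforward idea (an assumption-laden sketch, not the GAN itself), the snippet below samples from an easy 1-D latent distribution and maps each sample into 2-D data space. The map `push_forward` is a hypothetical hand-written function standing in for a trained generator; every output lands on the unit circle, a 1-D manifold inside the 2-D space.

```python
import torch

torch.manual_seed(0)

def push_forward(z):
    # Toy map from a 1-D latent space into 2-D data space.
    # Every output point lies on the unit circle.
    return torch.cat([torch.cos(z), torch.sin(z)], dim=1)

z = torch.randn(1000, 1)  # latent samples from a standard normal
x = push_forward(z)       # samples from the induced distribution in data space

radii = x.norm(dim=1)     # distance of each sample from the origin
print(x.shape)            # torch.Size([1000, 2])
print(torch.allclose(radii, torch.ones(1000), atol=1e-5))  # True
```

Sampling in data space never requires searching for the manifold: drawing a latent sample and applying the map is enough. A GAN learns such a map instead of writing it by hand.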

2. The Discriminator

Our discriminator D has a simpler, but nonetheless important, task at hand. The discriminator is a binary classifier that indicates whether the input data comes from the original source or from our generator G. An ideal discriminator should classify data from the original distribution as real and data from G as fake.
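
The discriminator in this tutorial returns one raw score (a logit) per image, and the loss used later works directly on logits. To read such a score as a decision, a sigmoid turns it into a probability of being real. The minimal sketch below uses made-up logit values purely for illustration.

```python
import torch

# Hypothetical raw discriminator outputs (logits) for a batch of four images
logits = torch.tensor([[2.0], [-1.5], [0.0], [4.0]])

probs = torch.sigmoid(logits)  # probability that each image is real
is_real = probs > 0.5          # hard decision at the 0.5 threshold

print(probs.squeeze(1))
print(is_real.squeeze(1))      # tensor([ True, False, False,  True])
```

A logit of exactly 0 corresponds to a probability of 0.5, meaning the discriminator is undecided.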

Discriminator Logic
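Fig 2: Working of the discriminator

The Discriminator Class:

class Discriminator(nn.Module):
    def __init__(self, im_chan, hidden_dim=16):
        super().__init__()
        self.disc = nn.Sequential(
            # Discriminator is defined as a stack of
            # convolution layers with batch normalization
            # and non-linear activations.

            nn.Conv2d(im_chan, hidden_dim, 4, 2),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hidden_dim, hidden_dim * 2, 4, 2),
            nn.BatchNorm2d(hidden_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # The last layer outputs one raw score (logit) per image
            nn.Conv2d(hidden_dim * 2, 1, 4, 2),
        )

    def forward(self, image):
        # Flatten the prediction to shape (batch, 1)
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)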
# We define a discriminator for single-channel images
disc = Discriminator(1)
print("Composition of the Discriminator:", end="\n\n")
print(disc)

Composition of the Discriminator:

Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(32, 1, kernel_size=(4, 4), stride=(2, 2))
  )
)

Model Working
Fig 3: Working of the model

Loss Functions in a GAN

Now we define the loss for the generator and the discriminator.

1. Generator Loss

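The generator tries to generate images that fool the discriminator into considering them real.

So the generator tries to maximize the probability that the discriminator assigns the true label to fake images.

In practice, the generator loss is the binary cross-entropy between the discriminator's predictions on the generated images and the true label. It is low when the discriminator classifies the generated images as real and high when it classifies them as fake.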

def gen_loss(gen, disc, num_images, latent_dim, device):
    # Generate the fake images
    noise = random_noise(num_images, latent_dim).to(device)
    gen_img = gen(noise)
    # Pass through discriminator and find the binary cross entropy loss
    # against the true label (all ones)
    disc_gen = disc(gen_img)
    loss = Loss(disc_gen, torch.ones_like(disc_gen))
    return loss
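
To make the behaviour of this loss concrete, the sketch below evaluates binary cross-entropy with logits against an all-ones target for two made-up batches of discriminator scores. It also checks the value against the formula written out by hand, the mean of -log(sigmoid(logit)). The tensors `fooled` and `caught` are hypothetical values chosen for illustration.

```python
import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()

# Hypothetical discriminator logits for generated images:
# positive means "looks real", negative means "looks fake"
fooled = torch.tensor([[3.0], [2.5]])
caught = torch.tensor([[-3.0], [-2.5]])

# The generator's target is the true label (all ones)
loss_fooled = criterion(fooled, torch.ones_like(fooled))
loss_caught = criterion(caught, torch.ones_like(caught))
print(loss_fooled.item(), loss_caught.item())  # a small value, then a large one

# The same quantity written out by hand
manual = -torch.log(torch.sigmoid(caught)).mean()
print(torch.allclose(manual, loss_caught))     # True
```

The generator is rewarded (low loss) when its images receive high scores and penalized when the discriminator catches them.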

2. Discriminator Loss

We want the discriminator to maximize the probability of assigning the true label to real images and the fake label to fake images.

Similar to the generator loss, the discriminator loss is the average of two binary cross-entropy terms: one penalizes real images that are classified as fake, the other penalizes fake images that are classified as real.

Notice how the loss functions of our two models act against each other.

def disc_loss(gen, disc, real_images, num_images, latent_dim, device):
    # Generate the fake images; detach so that no gradient reaches the generator
    noise = random_noise(num_images, latent_dim).to(device)
    img_gen = gen(noise).detach()
    # Pass the real and fake images through discriminator
    disc_gen = disc(img_gen)
    disc_real = disc(real_images)
    # Find the loss on the fake images (target 0) and on the real images (target 1)
    fake_loss = Loss(disc_gen, torch.zeros_like(disc_gen))
    real_loss = Loss(disc_real, torch.ones_like(disc_real))
    # Average over the losses for the discriminator loss
    loss = (fake_loss + real_loss) / 2

    return loss
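
The call to `detach()` in the discriminator loss is what keeps the discriminator update from also changing the generator. The sketch below demonstrates this with two tiny linear layers standing in for the real models (`toy_gen` and `toy_disc` are hypothetical names): after backpropagating through a detached fake batch, only the discriminator has gradients.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Tiny stand-in models, not the convolutional ones from this tutorial
toy_gen = nn.Linear(4, 8)
toy_disc = nn.Linear(8, 1)
criterion = nn.BCEWithLogitsLoss()

noise = torch.randn(16, 4)
fake = toy_gen(noise).detach()  # cut the graph: no gradient flows back into toy_gen

pred = toy_disc(fake)
loss = criterion(pred, torch.zeros_like(pred))
loss.backward()

gen_has_grad = any(p.grad is not None for p in toy_gen.parameters())
disc_has_grad = all(p.grad is not None for p in toy_disc.parameters())
print(gen_has_grad, disc_has_grad)  # False True
```

Detaching also saves a little computation, since the backward pass stops at the fake images instead of continuing through the generator.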

Loading up the MNIST Training Dataset

We load the MNIST training data. We will be using the torchvision package to download the required dataset.

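# Set the batch size (128 is an assumed value; the original listing omits it)
batch_size = 128

# Download the data in the Data folder in the directory above the current folder
data_iter = DataLoader(
    MNIST('../Data', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,  # assumed: reshuffle the data at every epoch
)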
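
A DataLoader yields batches of (images, labels) pairs, and the last batch of an epoch can be smaller than the rest. The sketch below shows this with random stand-in tensors so that it runs without downloading MNIST. The dataset size of 1000 and the batch size of 128 are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Random stand-in for MNIST: 1000 "images" of shape 1x28x28 with integer labels
fake_images = torch.rand(1000, 1, 28, 28)
fake_labels = torch.randint(0, 10, (1000,))
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=128, shuffle=True)

batch_sizes = [len(images) for images, _ in loader]
first_images, first_labels = next(iter(loader))

print(first_images.shape)  # torch.Size([128, 1, 28, 28])
print(batch_sizes[-1])     # 104, because 1000 = 7 * 128 + 104
```

This is why the training loop reads the batch size from `len(images)` instead of relying on a fixed number.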

Initializing the model

Set the hyperparameters of the models.

# Set Loss as Binary CrossEntropy with logits 
Loss = nn.BCEWithLogitsLoss()
# Set the latent dimension
latent_dim = 100
display_step = 500
# Set the learning rate
lr = 0.0002

# Set the beta_1 and beta_2 for the optimizer
beta_1 = 0.5 
beta_2 = 0.999

Set the device to cpu or cuda depending on whether you have hardware acceleration enabled.

device = "cpu"
if torch.cuda.is_available():
  device = "cuda"

Now we initialize the generator, the discriminator and the optimizers. We also initialize the starting weights of the layers.

# Initialize the Generator and the Discriminator along with
# their optimizer gen_opt and disc_opt
# We choose ADAM as the optimizer for both models
gen = Generator(latent_dim, 1).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator(1).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

# Initialize the weights of the various layers
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

# Apply the initial weights on the generator and discriminator 
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)
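
A quick way to confirm that `apply` reached every layer is to inspect the weights afterwards. The sketch below repeats the same initialization rule on a small throwaway network (the layer sizes are arbitrary assumptions) and checks that the convolution weights have a standard deviation close to 0.02 and that the batch-norm bias is zero.

```python
import torch
from torch import nn

torch.manual_seed(0)

def weights_init(m):
    # Same rule as in the tutorial: N(0, 0.02) weights, zero batch-norm bias
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 0.0, 0.02)
        nn.init.constant_(m.bias, 0)

# Throwaway network used only for this check
net = nn.Sequential(
    nn.Conv2d(1, 64, 4, 2),
    nn.BatchNorm2d(64),
    nn.ConvTranspose2d(64, 128, 4, 2),
)
net.apply(weights_init)  # apply() visits every submodule recursively

conv_std = net[2].weight.std().item()
bn_bias_max = net[1].bias.abs().max().item()
print(conv_std)     # approximately 0.02
print(bn_bias_max)  # 0.0
```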

Setting up the Utility Functions

We always need some utility functions that are not specific to our application but make some of our tasks easier. We define a function that displays images in a grid, making use of the torchvision make_grid function.

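def display_images(image_tensor, num_images=25, size=(1, 28, 28)):
    # Move the images to the CPU and reshape them to (N, C, H, W)
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    # Arrange the first num_images images in a grid with 5 images per row
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    # Matplotlib expects (H, W, C), so move the channel axis last
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()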

We define a noise function that generates the random noise used as input to the generator.

def random_noise(n_samples, z_dim):
  return torch.randn(n_samples, z_dim)

Training loop for our GAN in PyTorch

# Set the number of epochs
num_epochs = 100
# Set the interval at which generated images will be displayed
display_step = 100
# Iteration counter
itr = 0

for epoch in range(num_epochs):
    for images, _ in data_iter:
        num_images = len(images)
        # Transfer the images to cuda if hardware acceleration is present
        real_images = images.to(device)

        # Discriminator step
        disc_opt.zero_grad()
        D_loss = disc_loss(gen, disc, real_images, num_images, latent_dim, device)
        D_loss.backward()
        disc_opt.step()

        # Generator step
        gen_opt.zero_grad()
        G_loss = gen_loss(gen, disc, num_images, latent_dim, device)
        G_loss.backward()
        gen_opt.step()

        if itr % display_step == 0:
            with torch.no_grad():
                # Clear the previous output
                clear_output(wait=True)
                noise = random_noise(25, latent_dim).to(device)
                img = gen(noise)
                # Display the generated images
                display_images(img)
        itr += 1

These are some of the results of our GAN.


We have seen how to generate new images from a set of images. GANs are not restricted to images of digits. Modern GANs are powerful enough to generate realistic-looking human faces, and GANs are now also being used to generate music, art, and more. If you want to learn more about how GANs work, you can refer to the original GAN paper by Goodfellow et al.