Generative adversarial networks (GANs) have been the talk of the town since Ian Goodfellow and his colleagues introduced them in 2014. In this tutorial, you'll learn to train your first GAN in PyTorch. We also explain the inner workings of a GAN and walk through a simple PyTorch implementation.
Libraries to Import
We first import the libraries and functions that will be used in the implementation.
    import torch
    from torch import nn
    from torchvision import transforms
    from torchvision.utils import make_grid
    from torchvision.datasets import MNIST
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    from IPython.display import clear_output
What is a GAN?
A generative network can be described simply as a network that learns from the training data and generates new data resembling it. There are various ways to design a generative model, one of them being adversarial.
In a generative adversarial network, there are two submodels: the generator and the discriminator. We will look at each of these submodels in more detail:
1. The Generator
The generator, as the name suggests, is assigned the task of generating an image.
The generator takes a small, low-dimensional input (generally a 1-D noise vector) and outputs image data, for example of dimension 128x128x3 for an RGB image. In this tutorial the output is a 28x28 single-channel image, matching MNIST.
This scaling from a lower dimension to a higher dimension is achieved using a series of transposed convolution (deconvolution) layers.
Our generator can be considered a function that takes low-dimensional data and maps it to high-dimensional image data.
Over the training period, the generator learns to map the low-dimensional data to the high-dimensional data more and more effectively.
The goal of the generator is to generate an image that fools the discriminator into classifying it as a real image.
The Generator Class:
    class Generator(nn.Module):
        def __init__(self, z_dim, im_chan, hidden_dim=64):
            super().__init__()
            self.z_dim = z_dim
            self.gen = nn.Sequential(
                # We define the generator as a stack of transposed convolution
                # (deconvolution) layers with batch normalization and a
                # non-linear activation function.
                # You can try to play with the values of the layers.
                nn.ConvTranspose2d(z_dim, hidden_dim * 4, 3, 2),
                nn.BatchNorm2d(hidden_dim * 4),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(hidden_dim * 4, hidden_dim * 2, 4, 1),
                nn.BatchNorm2d(hidden_dim * 2),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, 3, 2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(hidden_dim, im_chan, 4, 2),
                nn.Tanh()
            )

        def forward(self, noise):
            # Reshape the noise to (batch, z_dim, 1, 1) and pass it through the layers
            noise = noise.view(len(noise), self.z_dim, 1, 1)
            return self.gen(noise)
    # We define a generator with latent dimension 100 and 1 image channel
    gen = Generator(100, 1)
    print("Composition of the Generator:", end="\n\n")
    print(gen)
    Composition of the Generator:

    Generator(
      (gen): Sequential(
        (0): ConvTranspose2d(100, 256, kernel_size=(3, 3), stride=(2, 2))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
        (6): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2))
        (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (8): ReLU(inplace=True)
        (9): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2))
        (10): Tanh()
      )
    )
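To see why this stack produces MNIST-sized images, the spatial size after each transposed convolution can be worked out by hand. The sketch below is plain Python and assumes the defaults used in the generator above (padding 0, dilation 1, output padding 0). The helper name `conv_transpose_out` is introduced here for illustration only.

```python
def conv_transpose_out(size, kernel, stride, padding=0):
    """Output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride) of the four ConvTranspose2d layers in the generator above
layers = [(3, 2), (4, 1), (3, 2), (4, 2)]

size = 1  # the noise vector is reshaped to z_dim x 1 x 1
sizes = []
for kernel, stride in layers:
    size = conv_transpose_out(size, kernel, stride)
    sizes.append(size)

print(sizes)  # [3, 6, 13, 28]
```

The final value of 28 matches the 28x28 MNIST images, so changing a kernel size or stride in the generator also changes the output resolution.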
Additional Notes: An image is very high-dimensional data. Even a modest RGB image of dimension 3x128x128 has 49,152 values.
The images that we want lie in a subspace, or manifold, of this huge space.
Ideally, the generator should learn where this subspace is located and sample randomly from it to produce its output.
Searching for this subspace directly is computationally very expensive. The most common way to deal with this is to map a latent vector space to the data space using a push-forward: noise is sampled from a simple distribution and the generator transforms it into data.
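As a rough illustration of the push-forward idea, the toy sketch below samples 1-D Gaussian noise and maps it through a fixed, hand-written function onto the unit circle, a 1-D manifold inside 2-D space. The function `toy_generator` is hypothetical and is not learned. A real generator plays the same role, but with a trained network and a far higher-dimensional output.

```python
import math
import random

def toy_generator(z):
    """Map a 1-D latent value to a point on the unit circle in 2-D."""
    angle = math.pi * math.tanh(z)  # squash the latent value into (-pi, pi)
    return (math.cos(angle), math.sin(angle))

random.seed(0)
# Sample from a simple latent distribution (standard normal)
latents = [random.gauss(0.0, 1.0) for _ in range(1000)]
# Push the latent samples forward through the generator
samples = [toy_generator(z) for z in latents]

# Every sample lies on the circle: its radius is 1 up to rounding error
radii = [math.hypot(x, y) for x, y in samples]
print(min(radii), max(radii))
```

Sampling is cheap because all the randomness comes from the latent distribution. The mapping itself is deterministic.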
2. The Discriminator
Our discriminator D has a simpler, but nonetheless important, task at hand. The discriminator is a binary classifier that indicates whether the input data comes from the original source or from our generator G. An ideal discriminator should classify data from the original distribution as real and data from G as fake.
    class Discriminator(nn.Module):
        def __init__(self, im_chan, hidden_dim=16):
            super().__init__()
            self.disc = nn.Sequential(
                # Discriminator is defined as a stack of
                # convolution layers with batch normalization
                # and non-linear activations.
                nn.Conv2d(im_chan, hidden_dim, 4, 2),
                nn.BatchNorm2d(hidden_dim),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim * 2, 4, 2),
                nn.BatchNorm2d(hidden_dim * 2),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(hidden_dim * 2, 1, 4, 2)
            )

        def forward(self, image):
            disc_pred = self.disc(image)
            # Flatten to one raw score (logit) per image
            return disc_pred.view(len(disc_pred), -1)
    # We define a discriminator for images with 1 channel
    disc = Discriminator(1)
    print("Composition of the Discriminator:", end="\n\n")
    print(disc)
    Composition of the Discriminator:

    Discriminator(
      (disc): Sequential(
        (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): LeakyReLU(negative_slope=0.2, inplace=True)
        (6): Conv2d(32, 1, kernel_size=(4, 4), stride=(2, 2))
      )
    )
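The discriminator shrinks a 28x28 image down to a single score. A minimal sketch of that arithmetic, assuming the default padding of 0 and dilation of 1 used above, is shown below. The helper name `conv_out` is introduced here for illustration only.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output size of a standard convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride) of the three Conv2d layers in the discriminator above
disc_layers = [(4, 2), (4, 2), (4, 2)]

size = 28  # MNIST images are 28x28
disc_sizes = []
for kernel, stride in disc_layers:
    size = conv_out(size, kernel, stride)
    disc_sizes.append(size)

print(disc_sizes)  # [13, 5, 1]
```

The last layer leaves a 1x1 map with one channel, which the `forward` method flattens into one raw score per image.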
Loss Functions in a GAN
Now we define the loss for the generator and the discriminator.
1. Generator Loss
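The generator tries to generate images that fool the discriminator into considering them real.
In other words, the generator tries to maximize the probability that the discriminator assigns the real label to fake images.
The generator loss is therefore the binary cross-entropy between the discriminator's predictions on the generated images and the real label: it is small when the discriminator classifies the fakes as real and large when it classifies them as fake.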
    def gen_loss(gen, disc, num_images, latent_dim, device):
        # Generate the fake images
        # (random_noise and Loss are defined further below)
        noise = random_noise(num_images, latent_dim).to(device)
        gen_img = gen(noise)
        # Pass them through the discriminator and compute the binary
        # cross-entropy loss against the "real" label (ones)
        disc_gen = disc(gen_img)
        loss = Loss(disc_gen, torch.ones_like(disc_gen))
        return loss
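To make the behaviour of this loss concrete, here is a minimal plain-Python sketch of binary cross-entropy on a single raw score. It is a simplified stand-in for `nn.BCEWithLogitsLoss`, which averages over a batch and uses a more numerically stable formulation. The helper name `bce_with_logits` is an assumption made for this example.

```python
import math

def bce_with_logits(logit, target):
    """Binary cross-entropy for one raw score (logit) and a 0/1 target."""
    prob = 1.0 / (1.0 + math.exp(-logit))  # sigmoid turns the score into a probability
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))

# The generator always uses target 1 ("real") for its own fake images.
fooled = bce_with_logits(3.0, 1.0)   # discriminator thinks the fake is real
caught = bce_with_logits(-3.0, 1.0)  # discriminator thinks the fake is fake

print(round(fooled, 3), round(caught, 3))  # 0.049 3.049
```

The loss is close to zero when the discriminator is fooled and grows when the fake is detected, which is the signal the generator trains on.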
2. Discriminator Loss
We want the discriminator to maximize the probability of assigning the real label to real images and the fake label to fake images.
Accordingly, the discriminator loss is the average of two binary cross-entropy terms: one penalizes classifying real images as fake, and the other penalizes classifying fake images as real.
Notice how the loss functions of our two models act against each other.
    def disc_loss(gen, disc, real_images, num_images, latent_dim, device):
        # Generate the fake images; detach so no gradients flow into the generator
        noise = random_noise(num_images, latent_dim).to(device)
        img_gen = gen(noise).detach()
        # Pass the real and fake images through the discriminator
        disc_gen = disc(img_gen)
        disc_real = disc(real_images)
        # Find the loss on the fake images (label 0) and the real images (label 1)
        fake_loss = Loss(disc_gen, torch.zeros_like(disc_gen))
        real_loss = Loss(disc_real, torch.ones_like(disc_real))
        # Average the two losses to get the discriminator loss
        loss = ((fake_loss + real_loss) / 2).mean()
        return loss
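The opposition between the two losses is usually written as the minimax objective from the original GAN paper. Here D(x) denotes the discriminator's probability that x is real (in this tutorial, the sigmoid of the raw score), and G(z) is the generator's output for noise z:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

Note that the `gen_loss` function above minimizes -log D(G(z)) by using labels of ones, rather than minimizing log(1 - D(G(z))) directly. This is the non-saturating variant, which the original paper also suggests because it gives stronger gradients early in training.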
Loading up the MNIST Training Dataset
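    # Set the batch size
    BATCH_SIZE = 512
    # Convert images to tensors and scale pixels from [0, 1] to [-1, 1]
    # so that real images share the range of the generator's Tanh output
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # Download the data into the Data folder in the directory above the current folder
    data_iter = DataLoader(
        MNIST('../Data', download=True, transform=transform),
        batch_size=BATCH_SIZE,
        shuffle=True)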
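It helps to know how many batches one pass over the data produces. The sketch below assumes the standard MNIST training split of 60,000 images and the `DataLoader` default of keeping the last, smaller batch. The names `TRAIN_IMAGES`, `batches_per_epoch` and `last_batch_size` are introduced here for illustration.

```python
import math

TRAIN_IMAGES = 60_000  # size of the MNIST training split
BATCH_SIZE = 512

# Number of batches per epoch when the last partial batch is kept
batches_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)
# Size of that final, smaller batch
last_batch_size = TRAIN_IMAGES - (batches_per_epoch - 1) * BATCH_SIZE

print(batches_per_epoch, last_batch_size)  # 118 96
```

Because the final batch holds only 96 images, code that consumes the loader should read the batch size from each batch rather than assume it always equals `BATCH_SIZE`.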
Initializing the model
Set the hyperparameters of the models.
    # Set Loss as Binary CrossEntropy with logits
    Loss = nn.BCEWithLogitsLoss()
    # Set the latent dimension
    latent_dim = 100
    # Set the learning rate
    lr = 0.0002
    # Set the beta_1 and beta_2 for the optimizer
    beta_1 = 0.5
    beta_2 = 0.999
Set the device to cpu or cuda depending on whether you have hardware acceleration enabled.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device
Now we initialize the generator, the discriminator and the optimizers. We also initialize the starting weights of the layers.
    # Initialize the Generator and the Discriminator along with
    # their optimizers gen_opt and disc_opt
    # We choose ADAM as the optimizer for both models
    gen = Generator(latent_dim, 1).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
    disc = Discriminator(1).to(device)
    disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

    # Initialize the weights of the various layers
    def weights_init(m):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.BatchNorm2d):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
            torch.nn.init.constant_(m.bias, 0)

    # Apply the initial weights on the generator and discriminator
    gen = gen.apply(weights_init)
    disc = disc.apply(weights_init)
Setting up the Utility Functions
We always need some utility functions that are not specific to our application but make some of our tasks easier. We define a function that displays images in a grid, making use of the torchvision make_grid function.
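    def display_images(image_tensor, num_images=25, size=(1, 28, 28)):
        # Move to the CPU and reshape to (N, C, H, W)
        image_unflat = image_tensor.detach().cpu().view(-1, *size)
        # The generator ends in Tanh, so rescale from [-1, 1] to [0, 1] for plotting
        image_unflat = (image_unflat + 1) / 2
        image_grid = make_grid(image_unflat[:num_images], nrow=5)
        # make_grid returns (C, H, W); matplotlib expects (H, W, C)
        plt.imshow(image_grid.permute(1, 2, 0).squeeze())
        plt.show()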
We define a noise function to generate random noise that will be used as input to the generator.
    def random_noise(n_samples, z_dim):
        # Sample n_samples latent vectors from a standard normal distribution
        return torch.randn(n_samples, z_dim)
Training loop for our GAN in PyTorch
    # Set the number of epochs
    num_epochs = 100
    # Set the interval at which generated images will be displayed
    display_step = 100
    # Iteration counter
    itr = 0

    for epoch in range(num_epochs):
        for images, _ in data_iter:
            num_images = len(images)
            # Transfer the images to cuda if hardware acceleration is present
            real_images = images.to(device)

            # Discriminator step
            disc_opt.zero_grad()
            D_loss = disc_loss(gen, disc, real_images, num_images, latent_dim, device)
            D_loss.backward()
            disc_opt.step()

            # Generator step
            gen_opt.zero_grad()
            G_loss = gen_loss(gen, disc, num_images, latent_dim, device)
            G_loss.backward()
            gen_opt.step()

            if itr % display_step == 0:
                with torch.no_grad():
                    # Clear the previous output
                    clear_output(wait=True)
                    noise = random_noise(25, latent_dim).to(device)
                    img = gen(noise)
                    # Display the generated images
                    display_images(img)
            itr += 1
These are some of the results of our GAN.
We have seen how we can generate new images from a set of images. GANs are not restricted to images of numbers. Modern GANs are powerful enough to generate realistic-looking human faces, and GANs are now being used to generate music and art as well. If you want to learn more about how GANs work, you can refer to the original GAN paper by Goodfellow et al.