Nandhini Swaminathan

Summary

This article is a concise tutorial on building a basic Generative Adversarial Network (GAN) with Python and PyTorch, covering the code, the models, and the training process.

Abstract

The article "Building a Simple Python-Based GAN in 5 minutes" serves as a beginner-level guide to understanding and implementing Generative Adversarial Networks (GANs). It explains the concept of GANs, which are deep learning models capable of generating new, synthetic data that closely resembles real input data. The GAN consists of two competing neural networks: a generator that creates synthetic data and a discriminator that distinguishes between real and synthetic data. The tutorial outlines the steps to set up the GAN environment, including defining the generator and discriminator models, setting up the loss function and optimizers, and executing the training loop where the generator and discriminator are iteratively improved. The code provided uses the PyTorch library, and the training process is described with attention to the loss functions and the optimization of the neural networks' parameters using the Adam optimizer. The article also includes further reading resources for those interested in delving deeper into GANs.

Opinions

  • GANs are recognized for their ability to produce new and inspired works, causing both awe and horror in academic circles.
  • The author suggests that GANs are advantageous due to their ability to generate sharp images and their relative ease of training.
  • The tutorial positions GANs as accessible to beginners, with the potential for creating impressive generative models.
  • The use of binary cross-entropy loss and the Adam optimizer is presented as a common and effective choice for training GANs.
  • The article implies that the quality of the generated images and the success of the GAN can be significantly influenced by the choice of hyperparameters and the architecture of the neural networks.

Building a Simple Python-Based GAN in 5 minutes

A beginner-level tutorial


Generative Adversarial Networks, or GANs, have caused a stir in academic circles. Their ability to produce new, seemingly inspired works has provoked both awe and horror. Naturally, one becomes curious: how do you build one?

A Generative Adversarial Network (GAN) is a deep learning model that generates new, synthetic data resembling some training data. A GAN consists of two neural networks: a generator and a discriminator. The generator is trained to produce synthetic data that is indistinguishable from the real data, while the discriminator is trained to tell synthetic data apart from real data.

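A generative model learns the intrinsic distribution of the input data, f(x), which allows it to generate new synthetic samples x' (and, if it models the joint distribution with the labels, matching outputs y'), typically from some hidden (latent) variables. GANs are popular because they tend to generate sharp images, and a basic GAN like the one below is relatively easy to set up, although training GANs can be unstable in practice.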
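The competition between the two networks is usually written as the minimax objective of the original GAN formulation, where D(x) is the discriminator's estimated probability that x is real, G(z) is the generator's output for a noise vector z, and p_z is the noise distribution (for example a standard normal, as used in the code below):

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

Maximizing V over D is the same as minimizing binary cross-entropy with label 1 for real samples and label 0 for generated ones. In practice the generator usually maximizes log D(G(z)) instead of minimizing log(1 - D(G(z))), because this gives stronger gradients early in training; scoring the fake images against the "real" label in the generator's loss implements exactly that.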

The Code

The code below trains the GAN for a given number of epochs, where an epoch is one pass through the entire dataset. In each epoch, the code iterates over the batches produced by dataloader (a PyTorch DataLoader that wraps your dataset; the article does not define it) and trains both the discriminator and the generator on each batch.
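Since dataloader is never constructed in the article, here is a minimal sketch of a compatible one. It assumes random tensors as a stand-in for real 28x28 single-channel images; the dataset size, batch size, and the name stand_in_dataset are illustrative assumptions, and for real training you would swap in an actual image dataset. The pixel values are scaled to [-1, 1] to match the generator's tanh output.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 1,000 single-channel 28x28 "images" in [-1, 1]
# plus dummy labels, so each batch unpacks as (real_images, _) like in the loop
num_samples = 1000
images = torch.rand(num_samples, 1, 28, 28) * 2 - 1
labels = torch.zeros(num_samples, dtype=torch.long)

stand_in_dataset = TensorDataset(images, labels)
dataloader = DataLoader(stand_in_dataset, batch_size=64, shuffle=True)

# Inspect one batch: the loop flattens these to (batch_size, 784)
real_images, _ = next(iter(dataloader))
print(real_images.shape)  # torch.Size([64, 1, 28, 28])
```

For MNIST, torchvision's ToTensor() followed by Normalize((0.5,), (0.5,)) produces the same [-1, 1] range.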

The generator is trained to fool the discriminator, which in turn is trained to distinguish real images from fake ones. The loss function is binary cross-entropy (BCE), a common choice for GANs because the discriminator solves a binary real-vs-fake classification problem. Both networks are optimized with Adam, a variant of stochastic gradient descent that adapts the learning rate of each parameter.
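For a predicted probability p and a label y, binary cross-entropy is -(y log p + (1 - y) log(1 - p)). PyTorch offers it in two forms: nn.BCELoss expects probabilities (for example after a sigmoid, as in this article's discriminator), while nn.BCEWithLogitsLoss expects raw scores and applies the sigmoid itself. The sketch below uses made-up scores to show that the two agree when each gets the input it expects, and that feeding sigmoid outputs into BCEWithLogitsLoss (a double sigmoid) gives a different, incorrect value.

```python
import torch
import torch.nn as nn

# Made-up raw discriminator scores (logits) and labels: 1 = real, 0 = fake
logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])
probs = torch.sigmoid(logits)

# BCELoss on probabilities and BCEWithLogitsLoss on logits compute the same loss
bce_probs = nn.BCELoss()(probs, labels)
bce_logits = nn.BCEWithLogitsLoss()(logits, labels)

# The formula written out by hand: mean of -(y*log(p) + (1-y)*log(1-p))
manual = -(labels * torch.log(probs) + (1 - labels) * torch.log(1 - probs)).mean()

# Mistake: sigmoid outputs passed to BCEWithLogitsLoss apply the sigmoid twice
double_sigmoid = nn.BCEWithLogitsLoss()(probs, labels)

print(f"{bce_probs.item():.4f} {bce_logits.item():.4f} "
      f"{manual.item():.4f} {double_sigmoid.item():.4f}")
```

Whichever form you choose, the discriminator's final activation and the loss function have to match.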

  1. First, import the necessary libraries and define the generator and discriminator models.
import torch
import torch.nn as nn
import torch.optim as optim
  • The generator is a neural network that takes a random noise vector and produces synthetic data. The discriminator is a neural network that takes real or synthetic data and outputs the probability that its input is real.
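class Generator(nn.Module):
    # Maps a noise vector of length input_size to a sample of length output_size
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        # tanh keeps generated values in [-1, 1], so real images should use the same range
        x = torch.tanh(self.fc2(x))
        return x
class Discriminator(nn.Module):
    # Maps a flattened sample of length input_size to a probability that it is real
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        # sigmoid turns the score into a probability in (0, 1), so pair it with nn.BCELoss
        x = torch.sigmoid(self.fc2(x))
        return x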
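A quick way to sanity-check the architecture is to push a batch of noise through both networks and look at the shapes. The sketch below writes the same two-layer designs with nn.Sequential (an equivalent, more compact style, not the article's code) and assumes MNIST-style sizes: 100-dimensional noise, a 256-unit hidden layer, and 784-pixel flattened images.

```python
import torch
import torch.nn as nn

noise_size, hidden_size, image_size = 100, 256, 784

# Generator: noise vector -> flattened image with values in [-1, 1]
generator = nn.Sequential(
    nn.Linear(noise_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh(),
)

# Discriminator: flattened image -> probability that the image is real
discriminator = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid(),
)

noise = torch.randn(16, noise_size)
fake_images = generator(noise)
scores = discriminator(fake_images)
print(fake_images.shape, scores.shape)  # torch.Size([16, 784]) torch.Size([16, 1])
```

The point to take away is that the generator's input size is the noise size and its output size is the image size, while the discriminator maps the image size down to a single probability.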

2. In the following code block, we set up the environment for the GAN. This includes:

  • Setting the sizes of the input, hidden, and output layers for the discriminator and generator networks.
  • Creating instances of the Generator and Discriminator classes
  • Setting up the loss function and optimizers, then starting the training loop
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set the layer sizes: each real image is a flattened 28x28 image (784 values),
# and the discriminator outputs a single probability
input_size = 784
hidden_size = 256
output_size = 1

# Set the number of epochs and the size of the generator's noise vector
num_epochs = 200
noise_size = 100

# Create the discriminator (image -> probability) and generator (noise -> image)
discriminator = Discriminator(input_size, hidden_size, output_size).to(device)
generator = Generator(noise_size, hidden_size, input_size).to(device)

# Set the loss function and optimizers
# BCELoss expects probabilities, which matches the discriminator's sigmoid output
loss_fn = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)

# Training loop (dataloader is a DataLoader over your image dataset)
for epoch in range(num_epochs):
  for i, (real_images, _) in enumerate(dataloader):
    # Get the batch size and flatten the images to shape (batch_size, 784)
    batch_size = real_images.size(0)
    real_images = real_images.view(batch_size, -1).to(device)

3. In the code below, each batch is used for two updates: first the discriminator learns to distinguish real images from fake ones, then the generator learns to fool it. To do this,

  • We give the generator a batch of noise samples as input and generate a batch of fake images. The discriminator scores both the real and the fake images, its loss rewards predicting 1 for real images and 0 for fake ones, and its parameters are updated with Adam.
  • The fake images are then passed through the updated discriminator again, and the generator's loss is computed against the "real" label, so it is low when the discriminator is fooled. The code back-propagates this loss through the generator and updates the generator's parameters with Adam, which moves them in a direction that reduces the loss and improves the generator's ability to fool the discriminator.
    # Generate fake images from random noise
    noise = torch.randn(batch_size, noise_size).to(device)
    fake_images = generator(noise)

    # Train the discriminator on real and fake images
    # (detach() stops the discriminator update from backpropagating into the generator)
    d_real = discriminator(real_images)
    d_fake = discriminator(fake_images.detach())

    # Calculate the loss: real images are labelled 1, fake images 0
    real_loss = loss_fn(d_real, torch.ones_like(d_real))
    fake_loss = loss_fn(d_fake, torch.zeros_like(d_fake))
    d_loss = real_loss + fake_loss

    # Backpropagate and optimize
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # Train the generator: label the fakes as real, so the loss is low when D is fooled
    d_fake = discriminator(fake_images)
    g_loss = loss_fn(d_fake, torch.ones_like(d_fake))

    # Backpropagate and optimize
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    # Print the losses every 50 batches
    if (i+1) % 50 == 0:
      print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
            .format(epoch+1, num_epochs, i+1, len(dataloader), d_loss.item(), g_loss.item()))
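The two updates inside the loop can also be packaged as a function, which makes them easy to test on their own. The sketch below is a hypothetical train_step helper (not part of the article), exercised on small nn.Sequential stand-in models and one random "real" batch; it shows the detach() in the discriminator update and the generator loss computed against the "real" label.

```python
import torch
import torch.nn as nn
import torch.optim as optim

def train_step(generator, discriminator, g_optimizer, d_optimizer,
               loss_fn, real_images, noise_size):
    """Run one discriminator update and one generator update on a batch."""
    batch_size = real_images.size(0)
    noise = torch.randn(batch_size, noise_size)
    fake_images = generator(noise)

    # Discriminator: real -> 1, fake -> 0; detach() keeps the generator out of this step
    d_real = discriminator(real_images)
    d_fake = discriminator(fake_images.detach())
    d_loss = (loss_fn(d_real, torch.ones_like(d_real))
              + loss_fn(d_fake, torch.zeros_like(d_fake)))
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # Generator: rescore the same fakes with the updated discriminator, target "real"
    d_fake = discriminator(fake_images)
    g_loss = loss_fn(d_fake, torch.ones_like(d_fake))
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()
    return d_loss.item(), g_loss.item()

# Hypothetical tiny setup: 100-d noise, 784-pixel images, random "real" batch
torch.manual_seed(0)
noise_size, image_size = 100, 784
generator = nn.Sequential(nn.Linear(noise_size, 256), nn.ReLU(),
                          nn.Linear(256, image_size), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(image_size, 256), nn.ReLU(),
                              nn.Linear(256, 1), nn.Sigmoid())
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
loss_fn = nn.BCELoss()

real_batch = torch.rand(32, image_size) * 2 - 1
g_before = generator[0].weight.clone()
d_loss, g_loss = train_step(generator, discriminator, g_optimizer, d_optimizer,
                            loss_fn, real_batch, noise_size)
print(f"d_loss={d_loss:.4f} g_loss={g_loss:.4f}")
```

Without detach(), d_loss.backward() would free the generator's part of the graph, and the later g_loss.backward() would fail with an error about backpropagating through the graph a second time.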

And… that’s all. A quick basic GAN model ready to be used.
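Once trained, the generator is used on its own: feed it fresh noise and reshape its flattened output back into images. A minimal sketch, assuming the 100-dimensional noise and 784-value (28x28) outputs used above; an untrained nn.Sequential stand-in replaces the trained generator so the snippet runs by itself.

```python
import torch
import torch.nn as nn

noise_size = 100
# Stand-in for the trained generator (same shape: noise -> 784 values in [-1, 1])
generator = nn.Sequential(nn.Linear(noise_size, 256), nn.ReLU(),
                          nn.Linear(256, 784), nn.Tanh())

generator.eval()
with torch.no_grad():  # gradients are not needed when only sampling
    noise = torch.randn(8, noise_size)
    samples = generator(noise)

# Reshape to (batch, channels, height, width) and map tanh's [-1, 1] to [0, 1]
images = samples.view(-1, 1, 28, 28)
images = (images + 1) / 2
print(images.shape)  # torch.Size([8, 1, 28, 28])
```

The resulting tensor can then be saved or displayed, for example with torchvision.utils.save_image(images, "samples.png"), where the file name is only an example.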

Further Reading
