
Generative Adversarial Networks (GANs)

Generative Adversarial Networks, or GANs, are a class of machine learning frameworks introduced by Ian Goodfellow and his colleagues in 2014. A GAN pits two neural networks against each other in a zero-sum game, where one network's gain is the other's loss.

In this tutorial, we'll learn how to code a GAN with PyTorch. We'll go step by step, so don't worry if you're not familiar with GANs or PyTorch.

## Prerequisites

  • Basic knowledge of Python
  • Understanding of PyTorch basics
  • Familiarity with deep learning concepts

## Contents

  1. Introduction to GANs
  2. Building a GAN in PyTorch
  3. Training a GAN
  4. Visualizing GAN Results
## Introduction to GANs

GANs consist of two main components: a Generator and a Discriminator. The Generator produces new data instances from random noise, while the Discriminator evaluates them for authenticity: it decides whether each instance comes from the real training dataset or was produced by the Generator.
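
To make the two roles concrete, here is a minimal sketch using tiny toy networks. The sizes and the names G and D are illustrative assumptions, not the tutorial's models: the Generator maps a batch of random latent vectors to candidate data, and the Discriminator maps each candidate to a probability of being real.

```python
import torch
from torch import nn

# Toy stand-ins for the two players (all sizes are illustrative assumptions)
z_dim, data_dim = 8, 4
G = nn.Sequential(nn.Linear(z_dim, 16), nn.ReLU(), nn.Linear(16, data_dim))
D = nn.Sequential(nn.Linear(data_dim, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())

z = torch.randn(5, z_dim)   # a batch of 5 random latent vectors
fake = G(z)                 # Generator: noise -> candidate data, shape (5, 4)
p_real = D(fake)            # Discriminator: data -> probability of "real", shape (5, 1)
print(fake.shape, p_real.shape)
```

Training, covered below, is what turns these random mappings into a Generator whose samples the Discriminator can no longer tell apart from real data.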

## Building a GAN in PyTorch

Before we start coding, let's import the necessary libraries.

```python
import torch
from torch import nn
```

### Define the Generator

The Generator is a simple fully connected neural network with one hidden layer of 128 units. It takes a latent vector (random noise of size input_dim) and outputs a data instance of size output_dim.

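```python
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        # Latent vector (input_dim) -> 128 hidden units -> flattened data instance (output_dim)
        self.hidden_layer = nn.Linear(input_dim, 128)
        self.output_layer = nn.Linear(128, output_dim)

    def forward(self, x):
        x = torch.relu(self.hidden_layer(x))  # non-linearity in the hidden layer
        x = self.output_layer(x)              # linear output: values are unbounded
        return x
```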
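
The Generator above ends in a plain linear layer, so its outputs are unbounded, while real images loaded with torchvision's transforms.ToTensor() have pixel values in [0, 1]. A common alternative, shown here as a sketch rather than the tutorial's architecture, ends the Generator with nn.Tanh() and normalizes the real images to [-1, 1], for example with transforms.Normalize((0.5,), (0.5,)). TanhGenerator and its hidden_dim parameter are hypothetical names introduced only for this example.

```python
import torch
from torch import nn

class TanhGenerator(nn.Module):
    """Hypothetical variant of the Generator with a bounded output."""

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh(),  # squashes every output value into [-1, 1]
        )

    def forward(self, z):
        return self.net(z)

gen = TanhGenerator(64, 784)
samples = gen(torch.randn(16, 64))  # 16 fake "images", each flattened to 784 values
print(samples.shape, samples.min().item(), samples.max().item())
```

The point of the design choice is that the Generator's output range matches the range of the real data the Discriminator sees, so it cannot be told apart simply by pixel scale.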

### Define the Discriminator

The Discriminator is also a fully connected neural network that takes a data instance and outputs a probability that the data is real.

```python
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        # Flattened data instance (input_dim) -> 128 hidden units -> one score
        self.hidden_layer = nn.Linear(input_dim, 128)
        self.output_layer = nn.Linear(128, 1)

    def forward(self, x):
        x = torch.relu(self.hidden_layer(x))
        x = torch.sigmoid(self.output_layer(x))  # probability in (0, 1) that x is real
        return x
```
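
Because the Discriminator applies a sigmoid itself, it pairs with nn.BCELoss. The PyTorch documentation describes nn.BCEWithLogitsLoss, which takes raw scores (logits) and applies the sigmoid internally, as more numerically stable than a sigmoid followed by BCELoss. The sketch below, using made-up random logits and labels, checks that the two formulations compute the same loss for moderate inputs; switching to it would mean removing the torch.sigmoid call from forward.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8, 1)                    # raw scores a Discriminator might output before the sigmoid
labels = torch.randint(0, 2, (8, 1)).float()  # 1 = real, 0 = fake

# Tutorial style: sigmoid inside the model, BCELoss on probabilities
loss_sigmoid_bce = nn.BCELoss()(torch.sigmoid(logits), labels)

# Alternative: keep the logit and let the loss apply the sigmoid internally
loss_with_logits = nn.BCEWithLogitsLoss()(logits, labels)

print(loss_sigmoid_bce.item(), loss_with_logits.item())
```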
## Training a GAN

Training a GAN alternates between updating the Discriminator and updating the Generator in a two-player minimax game. The Generator tries to fool the Discriminator by generating realistic images, while the Discriminator tries to distinguish real images from fakes.
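
Written out, the minimax game from the original GAN paper is a value function that the Discriminator D maximizes and the Generator G minimizes, where x is a real sample and z is a latent noise vector:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

With nn.BCELoss, a label of 1 gives a loss of -log p and a label of 0 gives -log(1 - p). So, over a batch of m real samples x_i and m noise vectors z_i, the losses in the training loop below are:

$$
\mathcal{L}_D = -\frac{1}{m}\sum_{i=1}^{m}\left[\log D(x_i) + \log\left(1 - D(G(z_i))\right)\right],
\qquad
\mathcal{L}_G = -\frac{1}{m}\sum_{i=1}^{m}\log D(G(z_i))
$$

Minimizing the Discriminator loss is a batch estimate of maximizing V. The Generator loss, obtained by scoring fakes against real labels, is the "non-saturating" variant suggested in the same paper instead of minimizing log(1 - D(G(z))) directly.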

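```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hyperparameters
z_dim = 64        # size of the latent noise vector
image_dim = 784   # 28*28, a flattened image
lr = 0.00001      # learning rate for both optimizers
epochs = 50       # number of passes over the dataset (illustrative value)

# Data: MNIST is assumed here because image_dim matches its 28x28 images
dataset = datasets.MNIST(root="data", train=True, download=True,
                         transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Initialize Generator and Discriminator
generator = Generator(z_dim, image_dim)
discriminator = Discriminator(image_dim)

# Loss and optimizers
loss_fn = nn.BCELoss()
G_opt = torch.optim.Adam(generator.parameters(), lr=lr)
D_opt = torch.optim.Adam(discriminator.parameters(), lr=lr)

# Training loop
for epoch in range(epochs):
    for real_images, _ in dataloader:
        # Flatten each image from (1, 28, 28) to a 784-value vector
        real_images = real_images.view(real_images.size(0), -1)
        batch_size = real_images.shape[0]

        # Train Discriminator: real images -> 1, generated images -> 0
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        real_outputs = discriminator(real_images)
        real_loss = loss_fn(real_outputs, real_labels)

        z = torch.randn(batch_size, z_dim)
        # detach() keeps the Discriminator's loss from backpropagating into the Generator
        fake_images = generator(z).detach()
        fake_outputs = discriminator(fake_images)
        fake_loss = loss_fn(fake_outputs, fake_labels)

        D_loss = real_loss + fake_loss
        D_opt.zero_grad()
        D_loss.backward()
        D_opt.step()

        # Train Generator: push the Discriminator to label fresh fakes as real
        z = torch.randn(batch_size, z_dim)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        G_loss = loss_fn(outputs, real_labels)

        G_opt.zero_grad()
        G_loss.backward()
        G_opt.step()

    # Report the losses of the last batch in each epoch
    print(f"epoch {epoch + 1}/{epochs}  D_loss={D_loss.item():.4f}  G_loss={G_loss.item():.4f}")
```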
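
Why does the Generator step score its fakes against real_labels instead of minimizing log(1 - D(G(z)))? The sketch below compares the gradient of each Generator loss with respect to the Discriminator's pre-sigmoid score a, where D(G(z)) = sigmoid(a). The score values are made up for illustration; the very negative one stands for a fake the Discriminator confidently rejects, which is typical early in training.

```python
import torch

# Pre-sigmoid scores for fake samples; -6 means "confidently fake", so D(G(z)) is near 0
a = torch.tensor([-6.0, -3.0, 0.0, 3.0], requires_grad=True)

# Minimax Generator loss: minimize log(1 - D(G(z)))
saturating = torch.log(1 - torch.sigmoid(a)).sum()
grad_sat, = torch.autograd.grad(saturating, a)         # equals -sigmoid(a)

# Non-saturating loss (BCELoss with real labels): minimize -log D(G(z))
non_saturating = (-torch.log(torch.sigmoid(a))).sum()
grad_nonsat, = torch.autograd.grad(non_saturating, a)  # equals sigmoid(a) - 1

print(grad_sat)
print(grad_nonsat)
```

For the confidently rejected fake, the minimax gradient is close to zero while the non-saturating gradient is close to -1, so the Generator still receives a useful learning signal exactly when it is doing worst.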
## Visualizing GAN Results

To visualize the Generator's progress, we can plot the images it generates after every few epochs.

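```python
import matplotlib.pyplot as plt

# Generate one image from a random latent vector without tracking gradients
with torch.no_grad():
    z = torch.randn(1, z_dim)
    gen_image = generator(z).view(28, 28)

plt.imshow(gen_image.numpy(), cmap='gray')
plt.show()
```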
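
A single sample says little about variety, so a grid is often more informative when checking progress. The sketch below uses a hypothetical helper, plot_samples, that draws an n_rows x n_cols grid from any module mapping z_dim-sized noise to 784 values. To keep the block runnable on its own it uses an untrained one-layer stand-in; in the tutorial you would pass the trained generator and z_dim=64.

```python
import torch
from torch import nn
import matplotlib.pyplot as plt

def plot_samples(generator, z_dim, n_rows=4, n_cols=4):
    """Hypothetical helper: plot a grid of generated 28x28 images and return the figure."""
    with torch.no_grad():
        z = torch.randn(n_rows * n_cols, z_dim)
        images = generator(z).view(-1, 28, 28)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows))
    for ax, img in zip(axes.flat, images):
        ax.imshow(img.numpy(), cmap="gray")
        ax.axis("off")  # hide ticks so only the images are shown
    return fig

# Untrained stand-in so the example runs by itself
stand_in = nn.Sequential(nn.Linear(64, 784))
fig = plot_samples(stand_in, z_dim=64)
plt.show()
```

To see progress "after every few epochs", call it at the end of the outer training loop, for example inside `if (epoch + 1) % 5 == 0:`.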

## Conclusion

In this tutorial, we introduced GANs and implemented one using PyTorch. We trained the GAN on 28x28 images and visualized the samples it generated. GANs are powerful generative models, but they come with well-known challenges, such as unstable training and mode collapse (the Generator producing only a narrow variety of outputs), and they remain a very active area of research and development.

Next, try experimenting with different types of GAN architectures and loss functions. Happy coding!