- Generative adversarial networks (GANs) are all the buzz in AI right now because of their fantastic ability to create new content.
- Last semester, my final Computer Vision (CSCI-431) research project compared the results of three different GAN architectures on the MNIST dataset.
- I'm writing this post to go over some of the PyTorch code used, because PyTorch makes it easy to write GANs.
-
- <customHTML />
-
- # GAN Background
-
- The goal of a GAN is to generate new realistic content.
- There are three major components of training a GAN: the generator, the discriminator, and the loss function.
- The generator is a neural network that will take in random noise and generate a new image.
- The discriminator is a neural network that decides whether the image it sees is real or fake.
- The discriminator is analogous to a detective trying to identify forgeries.
- The loss function measures how wrong the discriminator and generator are, based on the confidence the discriminator assigns to both real and fake images.
- Once a GAN is fully trained, the accuracy of the discriminator should be around 50% because the generator produces images so convincing that the discriminator can no longer detect the forgeries and is just guessing.
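-
- For reference, the original GAN paper frames this as a minimax game between the discriminator D and the generator G:
-
- $$\min_G \max_D \; \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$
-
- The discriminator tries to push D(x) toward 1 for real images and D(G(z)) toward 0 for generated ones, while the generator tries to do the opposite.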
-
- ![GAN Architecture](media/gan/gan-arch.jpg)
-
- Training a GAN can be tricky for a multitude of reasons.
- First, you want to make sure that the generator and discriminator learn at the same rate.
- If you start with a discriminator that is too good, it will always be correct, and the generator will not be able to learn from it.
- Second, compared to other neural networks, training a GAN requires a lot of data.
- In this project, we used the well-known MNIST dataset of handwritten digits, which contains nearly seventy thousand images.
-
- <youtube src="Sw9r8CL98N0" />
-
- # Vanilla GAN in PyTorch
-
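- The snippets in this post share a few imports and hyperparameters. Below is a minimal setup sketch; `latent_dim` and `img_size` follow from the description of the generator below, while the batch size, epoch count, and Adam settings are typical values I am assuming rather than the exact ones used in the project.
-
- ```python
- import os
- import numpy as np
- import torch
- import torch.nn as nn
- from torch.autograd import Variable
- from torchvision import datasets, transforms
-
- # Hyperparameters referenced by the later snippets.
- latent_dim = 128                    # length of the random noise vector
- img_size = 32                       # MNIST digits resized to 32x32
- channels = 1                        # grayscale images
- img_shape = (channels, img_size, img_size)
- batch_size = 64                     # assumed value
- n_epochs = 200                      # assumed value
- lr = 0.0002                         # assumed Adam learning rate
- b1, b2 = 0.5, 0.999                 # assumed Adam betas
-
- # Use GPU tensors when CUDA is available.
- cuda = torch.cuda.is_available()
- Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
- ```
-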
- Our generator is a PyTorch neural network that takes a random noise vector of length 128 and outputs a vector of size 1024, which is reshaped into our 32x32 image.
-
- ```python
- class Generator(nn.Module):
-     def __init__(self):
-         super(Generator, self).__init__()
-
-         # Fully connected block: linear layer, optional batch norm, LeakyReLU
-         def block(in_feat, out_feat, normalize=True):
-             layers = [nn.Linear(in_feat, out_feat)]
-             if normalize:
-                 layers.append(nn.BatchNorm1d(out_feat, 0.8))
-             layers.append(nn.LeakyReLU(0.2, inplace=True))
-             return layers
-
-         # Upsample the latent vector into a flattened image
-         self.model = nn.Sequential(
-             *block(latent_dim, 128, normalize=False),
-             *block(128, 256),
-             *block(256, 512),
-             *block(512, 1024),
-             nn.Linear(1024, int(np.prod(img_shape))),
-             nn.Tanh()
-         )
-
-     def forward(self, z):
-         img = self.model(z)
-         # Reshape the flat output into an image tensor
-         img = img.view(img.size(0), *img_shape)
-         return img
-
- generator = Generator()
- ```
-
- The discriminator is a neural network that takes in an image and determines whether it is real or fake, similar to the code for a standard binary classifier.
-
- ```python
- class Discriminator(nn.Module):
-     def __init__(self):
-         super(Discriminator, self).__init__()
-
-         # Fully connected binary classifier: flattened image in, real/fake score out
-         self.model = nn.Sequential(
-             nn.Linear(int(np.prod(img_shape)), 512),
-             nn.LeakyReLU(0.2, inplace=True),
-             nn.Linear(512, 256),
-             nn.LeakyReLU(0.2, inplace=True),
-             nn.Linear(256, 1),
-             nn.Sigmoid(),
-         )
-
-     def forward(self, img):
-         # Flatten the image before passing it through the linear layers
-         img_flat = img.view(img.size(0), -1)
-         validity = self.model(img_flat)
-         return validity
-
- discriminator = Discriminator()
- ```
-
- In this example, we are using the PyTorch data loader for the MNIST dataset.
- The built-in data loader makes our lives easier because it allows us to specify our batch size, downloads the data for us, and even normalizes it.
-
- ```python
- os.makedirs("../data/mnist", exist_ok=True)
- dataloader = torch.utils.data.DataLoader(
-     datasets.MNIST(
-         "../data/mnist",
-         train=True,
-         download=True,
-         transform=transforms.Compose(
-             [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
-         ),
-     ),
-     batch_size=batch_size,
-     shuffle=True,
- )
- ```
-
- In this example, we need two optimizers, one for the discriminator and one for the generator.
- We are using Adam, a first-order gradient-based optimizer that ships with PyTorch and works well for GANs.
-
- ```python
- optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
- optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
- adversarial_loss = torch.nn.BCELoss()
- ```
-
-
- The training loop is pretty standard, except that we have two neural networks to optimize on each batch.
-
- ```python
- for epoch in range(n_epochs):
-
-     # iterate over the dataset one batch at a time
-     for i, (imgs, _) in enumerate(dataloader):
-
-         # Adversarial ground truths
-         valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
-         fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)
-
-         # Configure input
-         real_imgs = Variable(imgs.type(Tensor))
-
-         # Train the generator
-         optimizer_G.zero_grad()
-
-         # Sample noise as generator input
-         z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
-
-         # Generate a batch of images
-         gen_imgs = generator(z)
-
-         # Loss measures generator's ability to fool the discriminator
-         g_loss = adversarial_loss(discriminator(gen_imgs), valid)
-
-         g_loss.backward()
-         optimizer_G.step()
-
-         # Train the discriminator
-         optimizer_D.zero_grad()
-
-         # Measure discriminator's ability to classify real from generated samples
-         real_loss = adversarial_loss(discriminator(real_imgs), valid)
-         fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
-         d_loss = (real_loss + fake_loss) / 2
-
-         d_loss.backward()
-         optimizer_D.step()
-
-         # Total batches run so far
-         batches_done = epoch * len(dataloader) + i
-
-         print(
-             "[Epoch %d/%d] [Batch %d/%d] [Batches Done: %d] [D loss: %f] [G loss: %f]"
-             % (epoch, n_epochs, i, len(dataloader), batches_done, d_loss.item(), g_loss.item())
-         )
- ```
-
- Plus or minus a few things, that is a GAN in PyTorch. Pretty easy, right?
- You can find the full code for both the paper and this blog post on my [GitHub](https://github.com/jrtechs/CSCI-431-final-GANs).
-
- ## TensorBoard
-
- TensorBoard is a library for visualizing training progress and other aspects of machine learning experiments.
- Although TensorBoard is primarily associated with the TensorFlow framework, it is a little-known fact that you can use it with PyTorch as well.
-
- TensorBoard can be installed via pip:
-
- ```
- pip install tensorboard
- ```
-
- Making minimal modifications to our PyTorch code, we can add TensorBoard logging.
-
-
- ```python
- # import the TensorBoard summary writer
- from torch.utils.tensorboard import SummaryWriter
-
- # creates a new TensorBoard logger (writes to ./runs by default)
- writer = SummaryWriter()
-
- # add these inside the training loop to log the losses for each batch
- writer.add_scalar('D_Loss', d_loss.item(), batches_done)
- writer.add_scalar('G_Loss', g_loss.item(), batches_done)
-
- # flushes and closes the log file after training
- writer.close()
- ```
-
- After the model finishes training, you can open the TensorBoard logs using the "tensorboard" command in the terminal.
-
- ```
- tensorboard --logdir=runs
- ```
-
- Opening "http://0.0.0.0:6006/" in your browser will give you access to the TensorBoard web UI.
-
- ```
- [TensorBoard screen grab](media/gan/tensorboard.png)
- ```
-
- # Takeaways
-
- Robust, flexible GANs are relatively easy to create in PyTorch.
- For this reason, you will find that many researchers use PyTorch for their experimentation.