Generative adversarial networks (GANs) are all the buzz in AI right now due to their fantastic ability to create new content.
Last semester, my final Computer Vision (CSCI-431) research project compared the results of three different GAN architectures using the MNIST dataset.
I'm writing this post to go over some of the PyTorch code used in that project, because PyTorch makes it easy to write GANs.

<customHTML />

# GAN Background

The goal of a GAN is to generate new, realistic content.
There are three major components of training a GAN: the generator, the discriminator, and the loss function.
The generator is a neural network that takes in random noise and generates a new image.
The discriminator is a neural network that decides whether the image it sees is real or fake.
The discriminator is analogous to a detective trying to identify forgeries.
The loss function measures how wrong the discriminator and generator are, based on the confidence the discriminator reports for both real and fake images.
Once a GAN is fully trained, the accuracy of the discriminator should be about 50%: the generator produces images so good that the discriminator can no longer detect the forgeries and is just guessing.

![GAN Architecture](media/gan/gan-arch.jpg)

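To make the 50% figure concrete, here is a minimal sketch of binary cross-entropy, the loss used later in this post (`BCELoss`), written in plain Python. The helper name `bce` and the sample confidences are illustrative and not taken from the project code.

```python
import math


def bce(confidence, label):
    """Binary cross-entropy for one prediction; label is 1.0 (real) or 0.0 (fake)."""
    return -(label * math.log(confidence) + (1.0 - label) * math.log(1.0 - confidence))


# A confident, correct discriminator has a small loss on a real image...
confident_loss = bce(0.99, 1.0)
# ...while one that can only guess (confidence 0.5) has a loss of ln(2).
guessing_loss = bce(0.5, 1.0)

print(f"confident: {confident_loss:.4f}, guessing: {guessing_loss:.4f}")
```

A discriminator that has been reduced to guessing therefore sits at a loss of ln(2), roughly 0.693, on every image, whether real or fake.
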
Training a GAN can be tricky for several reasons.
First, you want to make sure that the generator and discriminator learn at the same rate.
If you start with a discriminator that is too good, it will always be correct, and the generator will not be able to learn from it.
Second, compared to other neural networks, training a GAN requires a lot of data.
In this project, we used the well-known MNIST dataset, which contains 70,000 handwritten digits (60,000 for training and 10,000 for testing).

<youtube src="Sw9r8CL98N0" />

# Vanilla GAN in PyTorch

Our generator is a PyTorch neural network that takes a random vector of size 128x1 and outputs a new vector of size 1024, which is reshaped into our 32x32 image.

```python
import numpy as np
import torch.nn as nn

# Sizes taken from the description above: 128-element noise vector, 32x32 image.
# A single channel is assumed because MNIST is grayscale.
latent_dim = 128
img_shape = (1, 32, 32)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            # Linear layer, optional batch norm, then LeakyReLU
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # Note: the second positional argument of BatchNorm1d is eps
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        # Reshape the flat output into (batch, channels, height, width)
        img = img.view(img.size(0), *img_shape)
        return img


generator = Generator()
```

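As a quick sanity check on the shapes described above, the sketch below pushes a batch of random noise through a much smaller stand-in network and reshapes the result the same way `forward` does. The stand-in layers and the batch size of 4 are assumptions made for illustration; the noise size of 128 and the 32x32 image follow the sizes mentioned in the text.

```python
import torch
import torch.nn as nn

latent_dim = 128          # size of the noise vector, per the text
img_shape = (1, 32, 32)   # assumed: one grayscale channel, 32x32 pixels

# Hypothetical stand-in for the generator: one linear layer plus Tanh
tiny_generator = nn.Sequential(nn.Linear(latent_dim, 1 * 32 * 32), nn.Tanh())

z = torch.randn(4, latent_dim)               # batch of 4 noise vectors
flat = tiny_generator(z)                     # shape: (4, 1024)
imgs = flat.view(flat.size(0), *img_shape)   # shape: (4, 1, 32, 32)

print(flat.shape, imgs.shape)
```

Because the last layer is `Tanh`, every generated pixel lands in the range [-1, 1].
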
The discriminator is a neural network that takes in an image and determines whether it is real or fake, much like any binary classifier.

```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


discriminator = Discriminator()
```

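Since the final `Sigmoid` squashes the output into a value between 0 and 1, it can be read as the discriminator's confidence that an image is real. A minimal sketch of turning those confidences into decisions, using made-up scores rather than real model output:

```python
import torch

# Hypothetical discriminator outputs for four images (not real model output)
scores = torch.tensor([[0.9], [0.2], [0.6], [0.4]])
# Ground truth: 1.0 = real image, 0.0 = generated image
labels = torch.tensor([[1.0], [0.0], [0.0], [1.0]])

# Call an image "real" when the confidence is above 0.5
predictions = (scores > 0.5).float()
accuracy = (predictions == labels).float().mean().item()

print(f"accuracy: {accuracy:.2f}")
```

In this toy batch the discriminator is right on two of the four images, which is the 50% accuracy you would expect once the generator's forgeries are convincing.
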
In this example, we are using torchvision's MNIST dataset together with the PyTorch data loader.
The built-in data loader makes our lives easier because it lets us specify the batch size, downloads the data for us, and even normalizes it.

```python
import os

import torch
from torchvision import datasets, transforms

# img_size and batch_size are hyperparameters defined elsewhere in the script
os.makedirs("../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)
```

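`transforms.Normalize([0.5], [0.5])` subtracts a mean of 0.5 and divides by a standard deviation of 0.5, so the [0, 1] pixel values produced by `ToTensor` end up in [-1, 1], the same range as the generator's `Tanh` output. The arithmetic, sketched in plain Python with a hypothetical `normalize` helper:

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Same per-channel formula that Normalize applies: (x - mean) / std."""
    return (pixel - mean) / std


# ToTensor scales pixels to [0, 1]; check the endpoints and the midpoint
black, gray, white = normalize(0.0), normalize(0.5), normalize(1.0)
print(black, gray, white)
```

Keeping real and generated images in the same range means the discriminator cannot tell them apart by pixel scale alone.
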
In this example, we need two optimizers, one for the discriminator and one for the generator.
We are using Adam, a first-order gradient-based optimizer that adapts the step size for each parameter and is built into PyTorch as `torch.optim.Adam`.

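```python
# lr, b1, and b2 are hyperparameters defined elsewhere in the script
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Binary cross-entropy between the discriminator's output and the real/fake label
adversarial_loss = torch.nn.BCELoss()
```
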
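Each optimizer only touches the parameters it was constructed with, which is what lets the two networks be trained separately. A minimal sketch with two tiny stand-in layers; the layer sizes and the values for `lr` and `betas` are assumptions for illustration, not the settings used in the project:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

g = nn.Linear(4, 4)   # hypothetical stand-in generator
d = nn.Linear(4, 1)   # hypothetical stand-in discriminator

# Assumed hyperparameters, chosen only for this demo
opt_g = torch.optim.Adam(g.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(d.parameters(), lr=0.0002, betas=(0.5, 0.999))

g_before = g.weight.detach().clone()
d_before = d.weight.detach().clone()

loss = d(g(torch.randn(8, 4))).mean()
loss.backward()   # computes gradients for both networks
opt_g.step()      # updates the generator's parameters only

generator_changed = not torch.equal(g_before, g.weight)
discriminator_changed = not torch.equal(d_before, d.weight)
print(generator_changed, discriminator_changed)
```

`backward()` fills in gradients for both networks, but `opt_g.step()` changes only the generator. This is also why the training loop calls `zero_grad()` on each optimizer before using it.
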
The training loop is pretty standard, except that we have two neural networks to optimize on each batch.

```python
import numpy as np
import torch
from torch.autograd import Variable

# Assumption: Tensor is not defined in this excerpt; a CPU float tensor type is used here
Tensor = torch.FloatTensor

for epoch in range(n_epochs):
    # Iterate over the dataset one batch at a time
    for i, (imgs, _) in enumerate(dataloader):
        # Adversarial ground truths: 1 for real images, 0 for generated ones
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # Train the generator
        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures the generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train the discriminator
        optimizer_D.zero_grad()

        # Measure the discriminator's ability to tell real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Total number of batches processed so far
        batches_done = epoch * len(dataloader) + i
        print(
            "[Epoch %d/%d] [Batch %d/%d] [Batches Done: %d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), batches_done, d_loss.item(), g_loss.item())
        )
```

Plus or minus a few things, that is a GAN in PyTorch. Pretty easy, right?
You can find the full code for both the paper and this blog post on my [GitHub](https://github.com/jrtechs/CSCI-431-final-GANs).

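One detail of the training loop above worth isolating is `gen_imgs.detach()` in the discriminator step. Detaching cuts the generated images out of the autograd graph, so the discriminator's loss does not push gradients back into the generator. A minimal sketch with stand-in layers (the layer sizes are arbitrary and not from the project):

```python
import torch
import torch.nn as nn

g = nn.Linear(4, 4)   # hypothetical stand-in generator
d = nn.Linear(4, 1)   # hypothetical stand-in discriminator

gen_out = g(torch.randn(8, 4))

# Discriminator loss on the detached output: the graph stops at gen_out
d_loss = d(gen_out.detach()).mean()
d_loss.backward()

generator_has_grad = g.weight.grad is not None
discriminator_has_grad = d.weight.grad is not None
print(generator_has_grad, discriminator_has_grad)
```

Only the discriminator ends up with gradients, so the discriminator update cannot accidentally train the generator.
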
## TensorBoard

TensorBoard is a tool for visualizing training progress and other aspects of machine learning experiments.
Because TensorBoard is primarily associated with the TensorFlow framework, it is a little-known fact that you can also use it with PyTorch.
TensorBoard is installed via pip:

```
pip install tensorboard
```

Making minimal modifications to our PyTorch code, we can add TensorBoard logging.

```python
# import
from torch.utils.tensorboard import SummaryWriter

# create a new TensorBoard logger (writes under ./runs by default)
writer = SummaryWriter()

# add these calls inside the training loop so they run for each batch
writer.add_scalar('D_Loss', d_loss.item(), batches_done)
writer.add_scalar('G_Loss', g_loss.item(), batches_done)

# once training is done, flush pending events and close the log file
writer.close()
```

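To show where each call sits relative to the training loop, here is a minimal self-contained sketch that logs a made-up loss curve. The temporary log directory, the loop length, and the fake loss values are assumptions for illustration only.

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

# Assumed log location: a temporary directory instead of the default ./runs
log_dir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir=log_dir)

for step in range(5):
    fake_loss_value = 1.0 / (step + 1)   # made-up number standing in for d_loss.item()
    writer.add_scalar("D_Loss", fake_loss_value, step)

# Close once, after the loop, so all pending events are written to disk
writer.close()

event_files = os.listdir(log_dir)
print(event_files)
```

Pointing `tensorboard --logdir` at that directory would then show the `D_Loss` curve.
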
After the model finishes training, you can open the TensorBoard logs using the `tensorboard` command in the terminal.

```
tensorboard --logdir=runs
```

Opening "http://localhost:6006/" in your browser will give you access to the TensorBoard web UI.

![TensorBoard screen grab](media/gan/tensorboard.png)

# Takeaways

Robust, flexible GANs are relatively easy to create in PyTorch.
For this reason, many researchers use PyTorch in their experiments.