
Tutorial 3: Conditional GANs and Implications of GAN Technology

Week 2, Day 4: Generative Models

By Neuromatch Academy

Content creators: Seungwook Han, Kai Xu, Akash Srivastava

Content reviewers: Polina Turishcheva, Melvin Selim Atay, Hadi Vafaei, Deepak Raya, Kelson Shilling-Scrivo

Content editors: Spiros Chavlis

Production editors: Arush Tagade, Gagana B, Spiros Chavlis


Tutorial Objectives

The goal of this tutorial is to understand conditional GANs. Then you will have the opportunity to experience first-hand how effective GANs are at modeling the data distribution and to question what the consequences of this technology may be.

By the end of this tutorial you will be able to:

  • Understand how conditional GANs differ from unconditional GANs.

  • Generate high-dimensional natural images from a BigGAN.

  • Understand the efficacy of GANs in modeling the data distribution (e.g., faces).

  • Understand the energy inefficiency / environmental impact of training these large generative models.

  • Understand the implications of this technology (ethics, environment, etc.).

Tutorial slides

These are the slides for the videos in this tutorial. If you want to locally download the slides, click here.


Setup

Install dependencies

Install Huggingface BigGAN library

# @title Install dependencies
# @markdown Install *Huggingface BigGAN* library
!pip install pytorch-pretrained-biggan --quiet
!pip install Pillow libsixel-python --quiet
!pip install nltk --quiet

!pip install git+https://github.com/NeuromatchAcademy/evaltools --quiet

from evaltools.airtable import AirtableForm
atform = AirtableForm('appn7VdPRseSoMXEG', 'W2D4_T3','https://portal.neuromatchacademy.org/api/redirect/to/ddf809a2-5590-4d71-a764-1572e85dce27')
# Imports
import torch
import torchvision

import numpy as np
import matplotlib.pyplot as plt

from pytorch_pretrained_biggan import BigGAN
from pytorch_pretrained_biggan import one_hot_from_names
from pytorch_pretrained_biggan import truncated_noise_sample

Figure settings

# @title Figure settings
import ipywidgets as widgets       # Interactive display
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/content-creation/main/nma.mplstyle")

Set random seed

Executing set_seed(seed=seed) you are setting the seed

# @title Set random seed

# @markdown Executing `set_seed(seed=seed)` you are setting the seed

# For DL, it's critical to set the random seed so that students have a
# baseline to compare their results to the expected results.
# Read more here: https://pytorch.org/docs/stable/notes/randomness.html

# Call `set_seed` function in the exercises to ensure reproducibility.
import random
import torch

def set_seed(seed=None, seed_torch=True):
  """
  Function that controls randomness. NumPy and random modules must be imported.

  Args:
    seed : Integer
      A non-negative integer that defines the random state. Default is `None`.
    seed_torch : Boolean
      If `True` sets the random seed for pytorch tensors, so pytorch module
      must be imported. Default is `True`.

  Returns:
    Nothing.
  """
  if seed is None:
    seed = np.random.choice(2 ** 32)
  random.seed(seed)
  np.random.seed(seed)
  if seed_torch:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

  print(f'Random seed {seed} has been set.')


# In case that `DataLoader` is used
def seed_worker(worker_id):
  """
  DataLoader will reseed workers following randomness in
  multi-process data loading algorithm.

  Args:
    worker_id: integer
      ID of subprocess to seed. 0 means that
      the data will be loaded in the main process
      Refer: https://pytorch.org/docs/stable/data.html#data-loading-randomness for more details

  Returns:
    Nothing
  """
  worker_seed = torch.initial_seed() % 2**32
  np.random.seed(worker_seed)
  random.seed(worker_seed)
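
To see what seeding buys you, here is a minimal sketch that uses only Python's standard `random` module. The helper `draw` is a hypothetical name introduced for illustration and is not part of the tutorial code. The same idea applies to the NumPy and PyTorch generators that `set_seed` configures.

```python
import random


def draw(seed, n=3):
  """Draw `n` pseudo-random numbers from a generator seeded with `seed`."""
  rng = random.Random(seed)  # independent generator, does not touch global state
  return [rng.random() for _ in range(n)]


# Same seed -> identical sequence; different seed -> different sequence
print(draw(2021) == draw(2021))
print(draw(2021) == draw(2022))
```

Fixing the seed this way lets two runs of the same notebook be compared number for number.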

Set device (GPU or CPU). Execute set_device()

# @title Set device (GPU or CPU). Execute `set_device()`
# especially if torch modules used.

# Inform the user if the notebook uses GPU or CPU.

def set_device():
  """
  Set the device. CUDA if available, CPU otherwise

  Args:
    None

  Returns:
    device: string
      "cuda" if a GPU is available, "cpu" otherwise
  """
  device = "cuda" if torch.cuda.is_available() else "cpu"
  if device != "cuda":
    print("WARNING: For this notebook to perform best, "
        "if possible, in the menu under `Runtime` -> "
        "`Change runtime type.`  select `GPU` ")
  else:
    print("GPU is enabled in this notebook.")

  return device
SEED = 2021
set_seed(seed=SEED)
DEVICE = set_device()
Random seed 2021 has been set.
WARNING: For this notebook to perform best, if possible, in the menu under `Runtime` -> `Change runtime type.`  select `GPU` 

Download wordnet dataset

# @title Download `wordnet` dataset

# import nltk
# nltk.download('wordnet')

import os, requests, zipfile

os.environ['NLTK_DATA'] = 'nltk_data/'

fnames = ['wordnet.zip', 'omw-1.4.zip']
urls = ['https://osf.io/ekjxy/download', 'https://osf.io/kuwep/download']

for fname, url in zip(fnames, urls):
  r = requests.get(url, allow_redirects=True)

  with open(fname, 'wb') as fd:
    fd.write(r.content)

  with zipfile.ZipFile(fname, 'r') as zip_ref:
    zip_ref.extractall('nltk_data/corpora')

Section 1: Generating with a conditional GAN (BigGAN)

Video 1: Conditional Generative Models

In this section, we will load BigGAN, a pre-trained class-conditional GAN that set the state of the art in high-resolution natural image generation when it was published, and generate samples from it. Because the model is class conditional, we can use the class label to choose which object class the generated images come from.

Read here for more details on BigGAN: Brock et al., 2019.
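
As a rough illustration of what "class conditional" means, the sketch below builds a generator input by appending a one-hot class vector to a noise vector. This is the simplest form of conditioning and is shown only to build intuition. BigGAN itself injects the class information through its normalization layers rather than by plain concatenation. The label set `CLASSES` and the helpers `one_hot` and `conditional_input` are assumptions made for this example.

```python
import random

# Toy label set (an assumption for this sketch; BigGAN uses the ImageNet classes)
CLASSES = ['tench', 'magpie', 'jellyfish', 'German shepherd', 'bee']


def one_hot(class_name, classes):
  """Return a one-hot list with a 1.0 at the position of `class_name`."""
  vec = [0.0] * len(classes)
  vec[classes.index(class_name)] = 1.0
  return vec


def conditional_input(class_name, classes, noise_dim, rng):
  """Concatenate a Gaussian noise vector with the one-hot class vector."""
  noise = [rng.gauss(0.0, 1.0) for _ in range(noise_dim)]
  return noise + one_hot(class_name, classes)


rng = random.Random(2021)
x = conditional_input('German shepherd', CLASSES, noise_dim=8, rng=rng)
print(len(x))   # noise_dim + number of classes
print(x[-5:])   # the one-hot part that selects the class
```

Changing only the class part while keeping the noise fixed asks the generator for a different object class with the same latent code.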

def load_biggan(model_res):
  """
  Load respective BigGAN model for the specified resolution (biggan-deep-128, biggan-deep-256, biggan-deep-512)
  """
  return BigGAN.from_pretrained('biggan-deep-{}'.format(model_res))


def create_class_noise_vectors(class_str, trunc, num_samples):
  """
  Create class and noise vectors for sampling from BigGAN

  Args:
    class_str: string
      Class
    trunc: float
      Truncation factor
    num_samples: int
      Number of samples

  Returns:
    class_vector: np.ndarray
      One-hot class vectors, one per sample
    noise_vector: np.ndarray
      Noise vector
  """
  class_vector = one_hot_from_names([class_str]*num_samples, batch_size=num_samples)
  noise_vector = truncated_noise_sample(truncation=trunc, batch_size=num_samples)

  return class_vector, noise_vector

def generate_biggan_samples(model, class_vector, noise_vector, device,
                            truncation=0.4):
  """
  Generate samples from BigGAN

  Args:
    model: nn.Module
      Pre-trained BigGAN model
    class_vector: np.ndarray
      One-hot class vectors, one per sample
    noise_vector: np.ndarray
      Truncated noise vectors, one per sample
    device: string
      GPU if available. CPU otherwise.
    truncation: float
      Truncation factor

  Returns:
    output_grid: torch.tensor
      Grid of generated samples with pixel values in [0, 1]
  """
  # Convert to tensor
  noise_vector = torch.from_numpy(noise_vector)
  class_vector = torch.from_numpy(class_vector)

  # Move to GPU
  noise_vector = noise_vector.to(device)
  class_vector = class_vector.to(device)
  model.to(device)

  # Generate an image
  with torch.no_grad():
    output = model(noise_vector, class_vector, truncation)

  # Back to CPU
  output = output.to('cpu')

  # BigGAN's output layer is a tanh, so the output image lies in [-1, 1].
  # Rescale it to [0, 1] for display; clipping only guards against
  # small numerical overshoots.

  output = torch.clip(((output.detach().clone() + 1) / 2.0), 0, 1)

  # Make grid and show generated samples
  output_grid = torchvision.utils.make_grid(output,
                                            nrow=min(4, output.shape[0]),
                                            padding=5)
  plt.imshow(output_grid.permute(1, 2, 0))

  return output_grid


def generate(b):
  """
  Generation function

  Args:
    b: widgets.Button
      Button instance passed in by `on_click` (unused)

  Returns:
    Nothing
  """
  # Create BigGAN model
  model = load_biggan(MODEL_RESOLUTION)

  # Use specified parameters (resolution, class, number of samples, etc) to generate from BigGAN
  class_vector, noise_vector = create_class_noise_vectors(CLASS, TRUNCATION,
                                                          NUM_SAMPLES)
  samples_grid = generate_biggan_samples(model, class_vector, noise_vector,
                                         DEVICE, TRUNCATION)
  torchvision.utils.save_image(samples_grid, 'samples.png')
  ### If CUDA out of memory issue, lower NUM_SAMPLES (number of samples)
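
The rescaling step in `generate_biggan_samples` maps the generator's tanh output from [-1, 1] to [0, 1] before plotting. A minimal plain-Python sketch of the same arithmetic follows. The helper name `to_unit_range` is hypothetical, and it operates on single values rather than tensors.

```python
def to_unit_range(x):
  """Map a value from the tanh range [-1, 1] to [0, 1], clipping any overshoot."""
  return min(1.0, max(0.0, (x + 1.0) / 2.0))


# The last value slightly overshoots the tanh range and is clipped
pixels = [-1.0, -0.5, 0.0, 0.5, 1.0, 1.0001]
print([to_unit_range(p) for p in pixels])
```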

Section 1.1: Define configurations

We will now define the configurations (resolution of model, number of samples, class to sample from, truncation level) under which we will sample from BigGAN.

Question: What is the truncation trick employed by BigGAN? How do sample variety and fidelity change as you vary the truncation level? (Hint: play with the truncation slider and try sampling at different levels)
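
One way to build intuition before moving the slider is to look at the noise itself. The sketch below is an assumption-laden illustration, not the library's implementation: it draws standard normal values, rejects any that fall outside a bound (2.0 here, chosen for illustration), and scales the result by the truncation level. The helper `truncated_noise` is a hypothetical name.

```python
import random
import statistics


def truncated_noise(dim, truncation, rng, bound=2.0):
  """Sample `dim` values from N(0, 1) restricted to [-bound, bound],
  then scale them by `truncation`."""
  values = []
  while len(values) < dim:
    z = rng.gauss(0.0, 1.0)
    if abs(z) <= bound:  # reject samples from the tails
      values.append(truncation * z)
  return values


rng = random.Random(2021)
for trunc in (1.0, 0.4, 0.1):
  z = truncated_noise(128, trunc, rng)
  largest = max(abs(v) for v in z)
  spread = statistics.pstdev(z)
  print(f"truncation={trunc}: max |z| = {largest:.2f}, std = {spread:.2f}")
```

Lower truncation keeps the noise closer to zero, so the generator sees a narrower range of inputs. Compare this with how the generated samples change as you lower `TRUNCATION`.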


# @title { run: "auto" }

### RUN THIS BLOCK EVERY TIME YOU CHANGE THE PARAMETERS FOR GENERATION

# Resolution at which to generate
MODEL_RESOLUTION = "128" # @param [128, 256, 512]

# Number of images to generate
NUM_SAMPLES = 4 # @param {type:"slider", min:4, max:12, step:4}

# Class of images to generate
CLASS = 'German shepherd'  # @param ['tench', 'magpie', 'jellyfish', 'German shepherd', 'bee', 'acoustic guitar', 'coffee mug', 'minibus', 'monitor']

# Truncation level of the normal distribution we sample z from
TRUNCATION = 0.4  # @param {type:"slider", min:0.1, max:1, step:0.1}

Generate

# @title Generate
# Create generate button, given parameters specified above
button = widgets.Button(description="GENERATE!",
                        layout=widgets.Layout(width='30%', height='80px'),
                        button_style='danger')
output = widgets.Output()
display(button, output)
button.on_click(generate)

Think! 1: BigGANs

  1. How does BigGAN differ from previous state-of-the-art generative models for high-dimensional natural images? In other words, how does BigGAN solve high-dimensional image generation? (Hint: look into model architecture and training configurations) (BigGAN paper: Brock et al., 2019)

  2. Continuing from Question 1, what are the drawbacks of introducing such techniques into training large models for high-dimensional, diverse datasets?

  3. Play with other pre-trained generative models such as StyleGAN here; code for sampling and interpolating in the latent space is available here.

Student Response

# @title Student Response
from ipywidgets import widgets


text=widgets.Textarea(
   value='Type answer here and Push submit',
   placeholder='Type something',
   description='',
   disabled=False
)

button = widgets.Button(description="Submit!")

display(text,button)

def on_button_clicked(b):
   atform.add_answer('q1' , text.value)
   print("Submission successful!")


button.on_click(on_button_clicked)



Section 2: Ethical issues

Video 2: Ethical Issues

Section 2.1: Faces Quiz

Now it is your turn to test your ability to tell a real image from a fake one!

Real or Fake?

Section 2.2: Energy Efficiency Quiz

Make a guess


Summary

Hooray! You have finished the second week of NMA-DL!!!

In the first section of this tutorial, we have learned:

  • How conditional GANs differ from unconditional models

  • How to use a pre-trained BigGAN model to generate high-dimensional photo-realistic images, and how its truncation trick trades sample diversity against image fidelity

In the second section, we learned about the broader implications of GAN technology for society: the ethical issues raised by deepfakes and the large energy cost of training these models.

On the brighter side, as we learned throughout the week, GANs are very effective in modeling the data distribution and have many practical applications.

For example, as personalized healthcare and applications of AI in healthcare rise, the need to remove any Personally Identifiable Information (PII) becomes more important. As shown in Piacentino and Angulo, 2020, GANs can be leveraged to anonymize healthcare data.

As food for thought, what other practical applications of GANs can you think of? Discuss your ideas with your pod.

(Bonus) Video 3: Recap and advanced topics