Open In Colab   Open in Kaggle

Knowledge Extraction from a Convolutional Neural Network#

By Neuromatch Academy

Content creators: Jan Funke

Production editors: Spiros Chavlis, Konstantine Tsafatinos


Objective#

Train a convolutional neural network to classify images and a CycleGAN to translate between images of different types.

This notebook contains everything to train a VGG network on labelled images and to train a CycleGAN to translate between images.

We will use electron microscopy images of Drosophila synapses for this project. Those images can be classified according to the neurotransmitter type they release.


Setup#

Install dependencies#

Hide code cell source
# @title Install dependencies
!pip install scikit-image --quiet
!pip install pillow --quiet
!pip install scikit-image --quiet
[notice] A new release of pip is available: 24.0 -> 24.1.2
[notice] To update, run: pip install --upgrade pip

[notice] A new release of pip is available: 24.0 -> 24.1.2
[notice] To update, run: pip install --upgrade pip

[notice] A new release of pip is available: 24.0 -> 24.1.2
[notice] To update, run: pip install --upgrade pip
import glob
import json
import torch
import numpy as np

import matplotlib.pyplot as plt
from tqdm import tqdm

from skimage.io import imread
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import WeightedRandomSampler

%matplotlib inline

Project Ideas#

  1. Improve the classifier. This code uses a VGG network for the classification. On the synapse dataset, we will get a validation accuracy of around 80%. Try to see if you can improve the classifier accuracy.

    • (easy) Data augmentation: The training code for the classifier is quite simple in this example. Enlarge the amount of available training data by adding augmentations (transpose and mirror the images, add noise, change the intensity, etc.).

    • (easy) Network architecture: The VGG network has a few parameters that one can tune. Try a few to see what difference it makes.

    • (easy) Inspect the classifier predictions: Take random samples from the test dataset and classify them. Show the images together with their predicted and actual labels.

    • (medium) Other networks: Try different architectures (e.g., a ResNet) and see if the accuracy can be improved.

    • (medium) Inspect errors made by the classifier. Which classes are most accurately predicted? Which classes are confused with each other?

  2. Explore the CycleGAN.

    • (easy) The example code below shows how to translate between GABA and glutamate. Try different combinations, and also in the reverse direction. Can you start to see differences between some pairs of classes? Which are the ones where the differences are the most or the least obvious?

    • (hard) Watching the CycleGAN train can be a bit boring. Find a way to show (periodically) the current image and its translation to see how the network is improving over time. Hint: The cycle_gan module has a Visualizer, which might be helpful.

  3. Try on your own data!

    • Have a look at how the synapse images are organized in data/raw/synapses. Copy the directory structure and use your own images. Depending on your data, you might have to adjust the image size (128x128 for the synapses) and number of channels in the VGG network and CycleGAN code.

Acknowledgments#

This notebook was written by Jan Funke, using code from Nils Eckstein and a modified version of the CycleGAN implementation.


Train an Image Classifier#

In this section, we will implement and train a VGG classifier to classify images of synapses into one of six classes, corresponding to the neurotransmitter type that is released at the synapse: GABA, acethylcholine, glutamate, octopamine, serotonin, and dopamine.

Data Preparation#

Download the data#

Hide code cell source
# @title Download the data
import requests, os
from zipfile import ZipFile

def download_file_parts(urls, output_directory='.'):
    """
    Download file parts from given URLs and save them in the specified directory.

    :param urls: List of URLs to download
    :param output_directory: Directory to save the downloaded parts (default is current directory)
    :return: List of downloaded file paths
    """
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    downloaded_files = []

    for i, url in enumerate(urls, 1):
        try:
            response = requests.get(url, stream=True)
            response.raise_for_status()  # Raises an HTTPError for bad requests

            file_name = f"part{i}"
            file_path = os.path.join(output_directory, file_name)

            with open(file_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=32768):
                    file.write(chunk)

            downloaded_files.append(file_path)
            print(f"Downloaded: {file_path}")


        except requests.RequestException as e:
            print(f"Error downloading {url}: {e}")

    return downloaded_files

def reassemble_file(output_file):
    chunk_number = 1
    with open(output_file, 'wb') as outfile:
        while True:
            chunk_name = f'part{chunk_number}'
            if not os.path.exists(chunk_name):
                break
            with open(chunk_name, 'rb') as infile:
                outfile.write(infile.read())
            chunk_number += 1
    for i in ['part1', 'part2', 'part3']:
        if os.path.exists(i):
            os.remove(i)
    print(f"Downloaded files have been removed!")
    print(f"Reassembled {chunk_number-1} parts into {output_file}")



# @markdown Download the resources for this tutorial (one zip file)
fname = 'resources.zip'
urls = [
    "https://osf.io/download/4x7p3/",
    "https://osf.io/download/fzwea/",
    "https://osf.io/download/qpbcv/"
]

if not os.path.exists('data/'):
    print('Data downloading...')
    output_dir = "."
    downloaded_parts = download_file_parts(urls, output_dir)
    print('Download is completed.')

    print('Reassembling Files...')
    base_name = ''
    reassemble_file(fname)

    # @markdown Unzip the file
    with ZipFile(fname, 'r') as zf:
        # extracting all the files
        print('Extracting all the files now...')
        zf.extractall(path='.')
        print('Done!')

    # # @markdown Extract the data
    fnames = ['data.zip', 'checkpoints.zip']

    for fname in fnames:
        with ZipFile(fname, 'r') as zh:
            # extracting all the files
            print(f"\nArchive: {fname}")
            print(f"\tExtracting data...")
            zh.extractall(path='.')
    print('Done!')

    # @markdown Make sure the order of classes matches the pretrained model
    os.rename('data/raw/synapses/gaba', 'data/raw/synapses/0_gaba')
    os.rename('data/raw/synapses/acetylcholine', 'data/raw/synapses/1_acetylcholine')
    os.rename('data/raw/synapses/glutamate', 'data/raw/synapses/2_glutamate')
    os.rename('data/raw/synapses/serotonin', 'data/raw/synapses/3_serotonin')
    os.rename('data/raw/synapses/octopamine', 'data/raw/synapses/4_octopamine')
    os.rename('data/raw/synapses/dopamine', 'data/raw/synapses/5_dopamine')

    # @markdown Remove the archives
    for i in ['checkpoints.zip', 'experiments.zip', 'data.zip', 'resources.zip']:
        if os.path.exists(i):
            os.remove(i)

else:
    print('Data are already downloaded.')
Data are already downloaded.

Classifier Training#

Create and Inspect Datasets#

First, we create a torch data loaders for training, validation, and testing. We will use weighted sampling to account for the class imbalance during training.

def load_image(filename):

  image = imread(filename)

  # images are grescale, we only need one of the RGB channels
  image = image[:, :, 0]
  # img is uint8 in [0, 255], but we want float32 in [-1, 1]
  image = image.astype(np.float32)/255.0
  image = (image - 0.5)/0.5

  return image


# create a dataset for all images of all classes
full_dataset = ImageFolder(root='data/raw/synapses', loader=load_image)

# randomly split the dataset into train, validation, and test
num_images = len(full_dataset)
# ~70% for training
num_training = int(0.7 * num_images)
# ~15% for validation
num_validation = int(0.15 * num_images)
# ~15% for testing
num_test = num_images - (num_training + num_validation)
# split the data randomly (but with a fixed random seed)
train_dataset, validation_dataset, test_dataset = random_split(
    full_dataset,
    [num_training, num_validation, num_test],
    generator=torch.Generator().manual_seed(23061912))

# compute class weights in training dataset for uniform sampling
ys = np.array([y for _, y in train_dataset])
counts = np.bincount(ys)
label_weights = 1.0 / counts
weights = label_weights[ys]

print("Number of images per class:")
for c, n, w in zip(full_dataset.classes, counts, label_weights):
  print(f"\t{c}:\tn={n}\tweight={w}")

# create a data loader with uniform sampling
sampler = WeightedRandomSampler(weights, len(weights))
# this data loader will serve 8 images in a "mini-batch" at a time
dataloader = DataLoader(train_dataset, batch_size=8, drop_last=True, sampler=sampler)
Number of images per class:
	0_gaba:	n=15945	weight=6.271558482282847e-05
	1_acetylcholine:	n=4852	weight=0.00020610057708161583
	2_glutamate:	n=3556	weight=0.00028121484814398203
	3_serotonin:	n=2316	weight=0.0004317789291882556
	4_octopamine:	n=934	weight=0.0010706638115631692
	5_dopamine:	n=4640	weight=0.00021551724137931034
	cycle_gan:	n=19383	weight=5.159160088737554e-05

The cell below visualizes a single, randomly chosen batch from the training data loader. Feel free to execute this cell multiple times to get a feeling for the dataset. See if you can tell the difference between synapses of different types!

def show_batch(x, y):
  fig, axs = plt.subplots(1, x.shape[0], figsize=(14, 14), sharey=True)
  for i in range(x.shape[0]):
    axs[i].imshow(np.squeeze(x[i]), cmap='gray')
    axs[i].set_title(train_dataset.dataset.classes[y[i].item()])
  plt.show()

# show a random batch from the data loader
# (run this cell repeatedly to see different batches)
for x, y in dataloader:
  show_batch(x, y)
  break
../../_images/8f9075fcda24135cdc6a8bcb404f520111c5a4603215a31f42edb0670dd6b1f1.png

Create a Model, Loss, and Optimizer#

class Vgg2D(torch.nn.Module):

  def __init__(
          self,
          input_size,
          fmaps=12,
          downsample_factors=[(2, 2), (2, 2), (2, 2), (2, 2)],
          output_classes=6):

    super(Vgg2D, self).__init__()

    self.input_size = input_size

    current_fmaps = 1
    current_size = tuple(input_size)

    features = []
    for i in range(len(downsample_factors)):

      features += [
          torch.nn.Conv2d(
              current_fmaps,
              fmaps,
              kernel_size=3,
              padding=1),
          torch.nn.BatchNorm2d(fmaps),
          torch.nn.ReLU(inplace=True),
          torch.nn.Conv2d(
              fmaps,
              fmaps,
              kernel_size=3,
              padding=1),
          torch.nn.BatchNorm2d(fmaps),
          torch.nn.ReLU(inplace=True),
          torch.nn.MaxPool2d(downsample_factors[i])
      ]

      current_fmaps = fmaps
      fmaps *= 2

      size = tuple(
          int(c/d)
          for c, d in zip(current_size, downsample_factors[i]))
      check = (
          s*d == c
          for s, d, c in zip(size, downsample_factors[i], current_size))
      assert all(check), \
          "Can not downsample %s by chosen downsample factor" % \
          (current_size,)
      current_size = size

    self.features = torch.nn.Sequential(*features)

    classifier = [
        torch.nn.Linear(
            current_size[0] *
            current_size[1] *
            current_fmaps,
            4096),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(),
        torch.nn.Linear(
            4096,
            4096),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(),
        torch.nn.Linear(
            4096,
            output_classes)
    ]

    self.classifier = torch.nn.Sequential(*classifier)

  def forward(self, raw):

    # add a channel dimension to raw
    shape = tuple(raw.shape)
    raw = raw.reshape(shape[0], 1, shape[1], shape[2])

    # compute features
    f = self.features(raw)
    f = f.view(f.size(0), -1)

    # classify
    y = self.classifier(f)

    return y
# get the size of our images
for x, y in train_dataset:
  input_size = x.shape
  break

# create the model to train
model = Vgg2D(input_size)

# create a loss
loss = torch.nn.CrossEntropyLoss()

# create an optimzer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

Train the Model#

# use a GPU, if it is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
print(f"Will use device {device} for training")
Will use device cpu for training

The next cell merely defines some convenience functions for training, validation, and testing:

def train(dataloader, optimizer, loss, device):
  '''Train the model for one epoch.'''

  # set the model into train mode
  model.train()

  epoch_loss, num_batches = 0, 0
  for x, y in tqdm(dataloader, 'train'):

    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()

    y_pred = model(x)
    l = loss(y_pred, y)
    l.backward()

    optimizer.step()

    epoch_loss += l
    num_batches += 1

  return epoch_loss/num_batches


def evaluate(dataloader, name, device):

  correct = 0
  total = 0
  for x, y in tqdm(dataloader, name):

    x, y = x.to(device), y.to(device)

    logits = model(x)
    probs = torch.nn.Softmax(dim=1)(logits)
    predictions = torch.argmax(probs, dim=1)

    correct += int(torch.sum(predictions == y).cpu().detach().numpy())
    total += len(y)

  accuracy = correct/total

  return accuracy


def validate(validation_dataset, device):
  '''Evaluate prediction accuracy on the validation dataset.'''

  model.eval()
  dataloader = DataLoader(validation_dataset, batch_size=32)

  return evaluate(dataloader, 'validate', device)


def test(test_dataset, device):
  '''Evaluate prediction accuracy on the test dataset.'''

  model.eval()
  dataloader = DataLoader(test_dataset, batch_size=32)

  return evaluate(dataloader, 'test', device)

We are ready to train. After each epoch (roughly going through each training image once), we report the training loss and the validation accuracy.

def train_from_scratch(dataloader, validation_dataset,
                       optimizer, loss,
                       num_epochs=100, device=device):

  for epoch in range(num_epochs):
    epoch_loss = train(dataloader, optimizer, loss, device=device)
    print(f"epoch {epoch}, training loss={epoch_loss}")

    accuracy = validate(validation_dataset, device=device)
    print(f"epoch {epoch}, validation accuracy={accuracy}")

yes_I_want_the_pretrained_model = True will load a checkpoint that we already prepared, whereas setting it to False will train the model from scratch.

Unceck the box below and run the cell to train a model.

Hide code cell source
# @markdown
yes_I_want_the_pretrained_model =  True # @param {type:"boolean"}
# Load a pretrained model or train the model from scratch

# set this to True and run this cell if you want a shortcut

if yes_I_want_the_pretrained_model:
  checkpoint = torch.load('checkpoints/synapses/classifier/vgg_checkpoint',
                          map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])
else:
  train_from_scratch(dataloader, validation_dataset,
                     optimizer, loss,
                     num_epochs=100, device=device)
accuracy = test(test_dataset, device=device)
print(f"final test accuracy: {accuracy}")
final test accuracy: 0.49737888647866957

This concludes the first section. We now have a classifier that can discriminate between images of different types.

If you used the images we provided, the classifier is not perfect (you should get an accuracy of around 80%), but pretty good considering that there are six different types of images. Furthermore, it is not so clear for humans how the classifier does it. Feel free to explore the data a bit more and see for yourself if you can tell the difference betwee, say, GABAergic and glutamatergic synapses.

So this is an interesting situation: The VGG network knows something we don’t quite know. In the next section, we will see how we can visualize the relevant differences between images of different types.


Train a GAN to Translate Images#

We will train a so-called CycleGAN to translate images from one class to another.

Get the CycleGAN code and dependencies#

GitHub repo: https://github.com/funkey/neuromatch_xai

Hide code cell source
# @title Get the CycleGAN code and dependencies

# @markdown GitHub repo: https://github.com/funkey/neuromatch_xai

import requests, zipfile, io

url = 'https://osf.io/vutn5/download'
r = requests.get(url)
z = zipfile.ZipFile(io.BytesIO(r.content))
z.extractall()

!pip install dominate --quiet
[notice] A new release of pip is available: 24.0 -> 24.1.2
[notice] To update, run: pip install --upgrade pip

In this example, we will translate between GABAergic and glutamatergic synapses.

First, we have to copy images of either type into a format that the CycleGAN library is happy with. Afterwards, we can start training on those images.

import cycle_gan

cycle_gan.prepare_dataset('data/raw/synapses/', ['0_gaba', '2_glutamate'])

## Uncomment if you want to enable the training procedure
# cycle_gan.train('data/raw/synapses/', '0_gaba', '2_glutamate', 128)

Training the CycleGAN takes a lot longer than the VGG we trained above (on the synapse dataset, this will be around 7 days…).

To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset.

# translate images from class A to B, and classify each with the VGG network trained above
# Note: cycle_gan requires CUDA devices
if device == "cuda":
    cycle_gan.test(
        data_dir='data/raw/synapses/',
        class_A='0_gaba',
        class_B='2_glutamate',
        img_size=128,
        checkpoints_dir='checkpoints/synapses/cycle_gan/gaba_glutamate/',
        vgg_checkpoint='checkpoints/synapses/classifier/vgg_checkpoint'
    )

Read all translated images and sort them by how much the translation “fools” the VGG classifier trained above:

class_A_index = 0
class_B_index = 2

result_dir = 'data/raw/synapses/cycle_gan/0_gaba_2_glutamate/results/test_latest/images/'
classification_results = []
for f in glob.glob(result_dir + '/*.json'):
  result = json.load(open(f))
  result['basename'] = f.replace('_aux.json', '')
  classification_results.append(result)
classification_results.sort(
    key=lambda c: c['aux_real'][class_A_index] * c['aux_fake'][class_B_index],
    reverse=True)

Show the top real and fake images that make the classifier change its mind:

def show_pair(a, b, score_a, score_b, class_a, class_b):
  fig, axs = plt.subplots(1, 2, figsize=(20, 20), sharey=True)
  axs[0].imshow(a, cmap='gray')
  axs[0].set_title(f"p({class_a}) = " + str(score_a))
  axs[1].imshow(b, cmap='gray')
  axs[1].set_title(f"p({class_b}) = " + str(score_b))
  plt.show()


# show the top successful translations (according to our VGG classifier)
# Note: only run if cycle_gan ran successfully
if classification_results:
    for i in range(10):
      basename = classification_results[i]['basename']
      score_A = classification_results[i]['aux_real'][class_A_index]
      score_B = classification_results[i]['aux_fake'][class_B_index]
      real_A = imread(basename + '_real.png')
      fake_B = imread(basename + '_fake.png')
      show_pair(real_A, fake_B, score_A, score_B, 'gaba', 'glutamate')