{ "cells": [ { "cell_type": "markdown", "id": "renayVUI7b9x", "metadata": { "execution": {} }, "source": [ "# Knowledge Extraction from a Convolutional Neural Network\n", "\n", "**By Neuromatch Academy**\n", "\n", "__Content creators:__ Jan Funke\n", "\n", "__Production editors:__ Spiros Chavlis, Konstantine Tsafatinos" ] }, { "cell_type": "markdown", "id": "U6wofKujWp6X", "metadata": { "execution": {} }, "source": [ "---\n", "# Objective\n", "\n", "Train a convolutional neural network to classify images and a CycleGAN to translate between images of different types.\n", "\n", "This notebook contains everything to train a VGG network on labelled images and to train a CycleGAN to translate between images.\n", "\n", "We will use electron microscopy images of Drosophila synapses for this project. Those images can be classified according to the neurotransmitter type they release." ] }, { "cell_type": "markdown", "id": "zO4YN6W8W0Cp", "metadata": { "execution": {} }, "source": [ "---\n", "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies\n" ] }, { "cell_type": "code", "execution_count": null, "id": "fO1IZwvkW9Me", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install dependencies\n", "!pip install scikit-image --quiet\n", "!pip install pillow --quiet" ] }, { "cell_type": "code", "execution_count": null, "id": "gKkHjjTGWzUk", "metadata": { "execution": {} }, "outputs": [], "source": [ "import glob\n", "import json\n", "import torch\n", "import numpy as np\n", "\n", "import matplotlib.pyplot as plt\n", "from tqdm import tqdm\n", "\n", "from skimage.io import imread\n", "from torchvision.datasets import ImageFolder\n", "from torch.utils.data import DataLoader, random_split\n", "from torch.utils.data.sampler import WeightedRandomSampler\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "id": "bd7d427d", "metadata": { "execution": {} }, "source": [ "---\n", "# Project Ideas\n", "\n", "1. Improve the classifier. This code uses a VGG network for the classification. On the synapse dataset, we will get a validation accuracy of around 80%. Try to see if you can improve the classifier accuracy.\n", " * (easy) Data augmentation: The training code for the classifier is quite simple in this example. 
Enlarge the amount of available training data by adding augmentations (transpose and mirror the images, add noise, change the intensity, etc.).\n", " * (easy) Network architecture: The VGG network has a few parameters that one can tune. Try a few to see what difference it makes.\n", " * (easy) Inspect the classifier predictions: Take random samples from the test dataset and classify them. Show the images together with their predicted and actual labels.\n", " * (medium) Other networks: Try different architectures (e.g., a ResNet) and see if the accuracy can be improved.\n", " * (medium) Inspect errors made by the classifier. Which classes are most accurately predicted? Which classes are confused with each other?\n", " \n", " \n", "2. Explore the CycleGAN.\n", " * (easy) The example code below shows how to translate between GABA and glutamate. Try different combinations, and also in the reverse direction. Can you start to see differences between some pairs of classes? Which are the ones where the differences are the most or the least obvious?\n", " * (hard) Watching the CycleGAN train can be a bit boring. Find a way to show (periodically) the current image and its translation to see how the network is improving over time. Hint: The `cycle_gan` module has a `Visualizer`, which might be helpful.\n", " \n", "\n", "3. Try on your own data!\n", " * Have a look at how the synapse images are organized in `data/raw/synapses`. Copy the directory structure and use your own images. 
Depending on your data, you might have to adjust the image size (128x128 for the synapses) and number of channels in the VGG network and CycleGAN code.\n", "\n", "### Acknowledgments\n", "\n", "This notebook was written by Jan Funke, using code from Nils Eckstein and a modified version of the [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation.\n" ] }, { "cell_type": "markdown", "id": "5642d709", "metadata": { "execution": {} }, "source": [ "---\n", "# Train an Image Classifier\n", "\n", "In this section, we will implement and train a VGG classifier to classify images of synapses into one of six classes, corresponding to the neurotransmitter type that is released at the synapse: GABA, acetylcholine, glutamate, serotonin, octopamine, and dopamine." ] }, { "cell_type": "markdown", "id": "c61a11c6", "metadata": { "execution": {} }, "source": [ "## Data Preparation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download the data\n" ] }, { "cell_type": "code", "execution_count": null, "id": "821dc497", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data are already downloaded.\n" ] } ], "source": [ "# @title Download the data\n", "import requests, os\n", "from zipfile import ZipFile\n", "\n", "def download_file_parts(urls, output_directory='.'):\n", "    \"\"\"\n", "    Download file parts from given URLs and save them in the specified directory.\n", "\n", "    :param urls: List of URLs to download\n", "    :param output_directory: Directory to save the downloaded parts (default is current directory)\n", "    :return: List of downloaded file paths\n", "    \"\"\"\n", "    if not os.path.exists(output_directory):\n", "        os.makedirs(output_directory)\n", "\n", "    downloaded_files = []\n", "\n", "    for i, url in enumerate(urls, 1):\n", "        try:\n", "            response = requests.get(url, stream=True)\n", "            response.raise_for_status()  # Raises an HTTPError for bad requests\n", "\n", "            file_name = f\"part{i}\"\n", "            file_path = os.path.join(output_directory, file_name)\n", "\n", "            with open(file_path, 'wb') as file:\n", "                for chunk in response.iter_content(chunk_size=32768):\n", "                    file.write(chunk)\n", "\n", "            downloaded_files.append(file_path)\n", "            print(f\"Downloaded: {file_path}\")\n", "\n", "\n", "        except requests.RequestException as e:\n", "            print(f\"Error downloading {url}: {e}\")\n", "\n", "    return downloaded_files\n", "\n", "def reassemble_file(output_file):\n", "    chunk_number = 1\n", "    with open(output_file, 'wb') as outfile:\n", "        while True:\n", "            chunk_name = f'part{chunk_number}'\n", "            if not os.path.exists(chunk_name):\n", "                break\n", "            with open(chunk_name, 'rb') as infile:\n", "                outfile.write(infile.read())\n", "            chunk_number += 1\n", "    for i in ['part1', 'part2', 'part3']:\n", "        if os.path.exists(i):\n", "            os.remove(i)\n", "    print(\"Downloaded files have been removed!\")\n", "    print(f\"Reassembled {chunk_number-1} parts into {output_file}\")\n", "\n", "\n", "\n", "# @markdown Download the resources for this tutorial (one zip file)\n", "fname = 'resources.zip'\n", "urls = [\n", "    \"https://osf.io/download/4x7p3/\",\n", "    \"https://osf.io/download/fzwea/\",\n", "    \"https://osf.io/download/qpbcv/\"\n", "]\n", "\n", "if not os.path.exists('data/'):\n", "    print('Data downloading...')\n", "    output_dir = \".\"\n", "    downloaded_parts = download_file_parts(urls, output_dir)\n", "    print('Download is completed.')\n", "\n", "    print('Reassembling Files...')\n", "    base_name = ''\n", "    reassemble_file(fname)\n", "\n", "    # @markdown Unzip the file\n", "    with ZipFile(fname, 'r') as zf:\n", "        # extracting all the files\n", "        print('Extracting all the files now...')\n", "        zf.extractall(path='.')\n", "    print('Done!')\n", "\n", "    # # @markdown Extract the data\n", "    fnames = ['data.zip', 'checkpoints.zip']\n", "\n", "    for fname in fnames:\n", "        with ZipFile(fname, 'r') as zh:\n", "            # extracting all the files\n", "            print(f\"\\nArchive: {fname}\")\n", "            print(f\"\\tExtracting data...\")\n", "            zh.extractall(path='.')\n", "    print('Done!')\n", "\n", "    # @markdown Make sure the order of classes matches the pretrained model\n", "    os.rename('data/raw/synapses/gaba', 'data/raw/synapses/0_gaba')\n", "    os.rename('data/raw/synapses/acetylcholine', 'data/raw/synapses/1_acetylcholine')\n", "    os.rename('data/raw/synapses/glutamate', 'data/raw/synapses/2_glutamate')\n", "    os.rename('data/raw/synapses/serotonin', 'data/raw/synapses/3_serotonin')\n", "    os.rename('data/raw/synapses/octopamine', 'data/raw/synapses/4_octopamine')\n", "    os.rename('data/raw/synapses/dopamine', 'data/raw/synapses/5_dopamine')\n", "\n", "    # @markdown Remove the archives\n", "    for i in ['checkpoints.zip', 'experiments.zip', 'data.zip', 'resources.zip']:\n", "        if os.path.exists(i):\n", "            os.remove(i)\n", "\n", "else:\n", "    print('Data are already downloaded.')" ] }, { "cell_type": "markdown", "id": "0b84ec7b", "metadata": { "execution": {} }, "source": [ "## Classifier Training" ] }, { "cell_type": "markdown", "id": "a79ab567", "metadata": { "execution": {} }, "source": [ "### Create and Inspect Datasets\n", "\n", "First, we split the images into training, validation, and test datasets and create a `torch` data loader for training. We use weighted sampling to account for the class imbalance during training."
] }, { "cell_type": "code", "execution_count": null, "id": "ae50b16a", "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of images per class:\n", "\t0_gaba:\tn=15945\tweight=6.271558482282847e-05\n", "\t1_acetylcholine:\tn=4852\tweight=0.00020610057708161583\n", "\t2_glutamate:\tn=3556\tweight=0.00028121484814398203\n", "\t3_serotonin:\tn=2316\tweight=0.0004317789291882556\n", "\t4_octopamine:\tn=934\tweight=0.0010706638115631692\n", "\t5_dopamine:\tn=4640\tweight=0.00021551724137931034\n", "\tcycle_gan:\tn=19383\tweight=5.159160088737554e-05\n" ] } ], "source": [ "def load_image(filename):\n", "\n", "    image = imread(filename)\n", "\n", "    # images are grayscale, we only need one of the RGB channels\n", "    image = image[:, :, 0]\n", "    # image is uint8 in [0, 255], but we want float32 in [-1, 1]\n", "    image = image.astype(np.float32)/255.0\n", "    image = (image - 0.5)/0.5\n", "\n", "    return image\n", "\n", "\n", "# create a dataset for all images of all classes\n", "full_dataset = ImageFolder(root='data/raw/synapses', loader=load_image)\n", "\n", "# randomly split the dataset into train, validation, and test\n", "num_images = len(full_dataset)\n", "# ~70% for training\n", "num_training = int(0.7 * num_images)\n", "# ~15% for validation\n", "num_validation = int(0.15 * num_images)\n", "# ~15% for testing\n", "num_test = num_images - (num_training + num_validation)\n", "# split the data randomly (but with a fixed random seed)\n", "train_dataset, validation_dataset, test_dataset = random_split(\n", "    full_dataset,\n", "    [num_training, num_validation, num_test],\n", "    generator=torch.Generator().manual_seed(23061912))\n", "\n", "# compute class weights in training dataset for uniform sampling\n", "ys = np.array([y for _, y in train_dataset])\n", "counts = np.bincount(ys)\n", "label_weights = 1.0 / counts\n", "weights = label_weights[ys]\n", "\n", "print(\"Number of images per class:\")\n", "for c, n, w in 
zip(full_dataset.classes, counts, label_weights):\n", "    print(f\"\\t{c}:\\tn={n}\\tweight={w}\")\n", "\n", "# create a data loader with uniform sampling\n", "sampler = WeightedRandomSampler(weights, len(weights))\n", "# this data loader will serve 8 images in a \"mini-batch\" at a time\n", "dataloader = DataLoader(train_dataset, batch_size=8, drop_last=True, sampler=sampler)" ] }, { "cell_type": "markdown", "id": "e9010bdc", "metadata": { "execution": {} }, "source": [ "The cell below visualizes a single, randomly chosen batch from the training data loader. Feel free to execute this cell multiple times to get a feeling for the dataset. See if you can tell the difference between synapses of different types!" ] }, { "cell_type": "code", "execution_count": null, "id": "3d8c6f3a", "metadata": { "execution": {} }, "outputs": [], "source": [ "def show_batch(x, y):\n", "    fig, axs = plt.subplots(1, x.shape[0], figsize=(14, 14), sharey=True)\n", "    for i in range(x.shape[0]):\n", "        axs[i].imshow(np.squeeze(x[i]), cmap='gray')\n", "        axs[i].set_title(train_dataset.dataset.classes[y[i].item()])\n", "    plt.show()\n", "\n", "# show a random batch from the data loader\n", "# (run this cell repeatedly to see different batches)\n", "for x, y in dataloader:\n", "    show_batch(x, y)\n", "    break" ] }, { "cell_type": "markdown", "id": "f882416f", "metadata": { "execution": {} }, "source": [ "### Create a Model, Loss, and Optimizer" ] }, { "cell_type": "code", "execution_count": null, "id": "54f177cc", "metadata": { "execution": {} }, "outputs": [], "source": [ "class Vgg2D(torch.nn.Module):\n", "\n", "    def __init__(\n", "            self,\n", "            input_size,\n", "            fmaps=12,\n", "            downsample_factors=[(2, 2), (2, 2), (2, 2), (2, 2)],\n", "            output_classes=6):\n", "\n", "        super(Vgg2D, self).__init__()\n", "\n", "        self.input_size = input_size\n", "\n", "        current_fmaps = 1\n", "        current_size = tuple(input_size)\n", "\n", "        features = []\n", "        for i in range(len(downsample_factors)):\n", "\n", "            features += [\n", "                torch.nn.Conv2d(\n", "                    current_fmaps,\n", "                    fmaps,\n", "                    kernel_size=3,\n", "                    padding=1),\n", "                torch.nn.BatchNorm2d(fmaps),\n", "                torch.nn.ReLU(inplace=True),\n", "                torch.nn.Conv2d(\n", "                    fmaps,\n", "                    fmaps,\n", "                    kernel_size=3,\n", "                    padding=1),\n", "                torch.nn.BatchNorm2d(fmaps),\n", "                torch.nn.ReLU(inplace=True),\n", "                torch.nn.MaxPool2d(downsample_factors[i])\n", "            ]\n", "\n", "            current_fmaps = fmaps\n", "            fmaps *= 2\n", "\n", "            size = tuple(\n", "                int(c/d)\n", "                for c, d in zip(current_size, downsample_factors[i]))\n", "            check = (\n", "                s*d == c\n", "                for s, d, c in zip(size, downsample_factors[i], current_size))\n", "            assert all(check), \\\n", "                \"Cannot downsample %s by chosen downsample factor\" % \\\n", "                (current_size,)\n", "            current_size = size\n", "\n", "        self.features = torch.nn.Sequential(*features)\n", "\n", "        classifier = [\n", "            torch.nn.Linear(\n", "                current_size[0] *\n", "                current_size[1] *\n", "                current_fmaps,\n", "                4096),\n", "            torch.nn.ReLU(inplace=True),\n", "            torch.nn.Dropout(),\n", "            torch.nn.Linear(\n", "                4096,\n", "                4096),\n", "            torch.nn.ReLU(inplace=True),\n", "            torch.nn.Dropout(),\n", "            torch.nn.Linear(\n", "                4096,\n", "                output_classes)\n", "        ]\n", "\n", "        self.classifier = torch.nn.Sequential(*classifier)\n", "\n", "    def forward(self, raw):\n", "\n", "        # add a channel dimension to raw\n", "        shape = tuple(raw.shape)\n", "        raw = raw.reshape(shape[0], 1, shape[1], shape[2])\n", "\n", "        # compute features\n", "        f = self.features(raw)\n", "        f = f.view(f.size(0), -1)\n", "\n", "        # classify\n", "        y = self.classifier(f)\n", "\n", "        return y" ] }, { "cell_type": "code", "execution_count": null, "id": "5da43245", "metadata": { "execution": {} }, "outputs": [], "source": [ "# get the size of our images\n", "for x, y in train_dataset:\n", "    input_size = x.shape\n", "    break\n", "\n", "# create the model to train\n", "model = Vgg2D(input_size)\n", "\n", "# create a loss\n", "loss = torch.nn.CrossEntropyLoss()\n", "\n", "# create an optimizer\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)" ] }, { "cell_type": "markdown", "id": "01688095", "metadata": { "execution": {} }, "source": [ "### Train the Model" ] }, { "cell_type": "code", "execution_count": null, "id": "fa65090d", "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Will use device cpu for training\n" ] } ], "source": [ "# use a GPU, if it is available\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "model.to(device)\n", "print(f\"Will use device {device} for training\")" ] }, { "cell_type": "markdown", "id": "ecbab4f7", "metadata": { "execution": {} }, "source": [ "The next cell merely defines some convenience functions for training, validation, and 
testing:" ] }, { "cell_type": "code", "execution_count": null, "id": "1a8c7fe9", "metadata": { "execution": {} }, "outputs": [], "source": [ "def train(dataloader, optimizer, loss, device):\n", "    '''Train the model for one epoch.'''\n", "\n", "    # set the model into train mode\n", "    model.train()\n", "\n", "    epoch_loss, num_batches = 0, 0\n", "    for x, y in tqdm(dataloader, 'train'):\n", "\n", "        x, y = x.to(device), y.to(device)\n", "        optimizer.zero_grad()\n", "\n", "        y_pred = model(x)\n", "        l = loss(y_pred, y)\n", "        l.backward()\n", "\n", "        optimizer.step()\n", "\n", "        # accumulate a Python float so the autograd graph is not kept alive\n", "        epoch_loss += l.item()\n", "        num_batches += 1\n", "\n", "    return epoch_loss/num_batches\n", "\n", "\n", "@torch.no_grad()\n", "def evaluate(dataloader, name, device):\n", "    '''Compute the prediction accuracy over all batches of a data loader.'''\n", "\n", "    correct = 0\n", "    total = 0\n", "    for x, y in tqdm(dataloader, name):\n", "\n", "        x, y = x.to(device), y.to(device)\n", "\n", "        logits = model(x)\n", "        probs = torch.nn.Softmax(dim=1)(logits)\n", "        predictions = torch.argmax(probs, dim=1)\n", "\n", "        correct += int(torch.sum(predictions == y).cpu().detach().numpy())\n", "        total += len(y)\n", "\n", "    accuracy = correct/total\n", "\n", "    return accuracy\n", "\n", "\n", "def validate(validation_dataset, device):\n", "    '''Evaluate prediction accuracy on the validation dataset.'''\n", "\n", "    model.eval()\n", "    dataloader = DataLoader(validation_dataset, batch_size=32)\n", "\n", "    return evaluate(dataloader, 'validate', device)\n", "\n", "\n", "def test(test_dataset, device):\n", "    '''Evaluate prediction accuracy on the test dataset.'''\n", "\n", "    model.eval()\n", "    dataloader = DataLoader(test_dataset, batch_size=32)\n", "\n", "    return evaluate(dataloader, 'test', device)" ] }, { "cell_type": "markdown", "id": "68bcfbbf", "metadata": { "execution": {} }, "source": [ "We are ready to train. After each epoch (roughly going through each training image once), we report the training loss and the validation accuracy."
] }, { "cell_type": "code", "execution_count": null, "id": "d0af7638", "metadata": { "execution": {} }, "outputs": [], "source": [ "def train_from_scratch(dataloader, validation_dataset,\n", "                       optimizer, loss,\n", "                       num_epochs=100, device=device):\n", "\n", "    for epoch in range(num_epochs):\n", "        epoch_loss = train(dataloader, optimizer, loss, device=device)\n", "        print(f\"epoch {epoch}, training loss={epoch_loss}\")\n", "\n", "        accuracy = validate(validation_dataset, device=device)\n", "        print(f\"epoch {epoch}, validation accuracy={accuracy}\")" ] }, { "cell_type": "markdown", "id": "45e31b87", "metadata": { "execution": {} }, "source": [ "`yes_I_want_the_pretrained_model = True` will load a checkpoint that we already prepared, whereas setting it to `False` will train the model from scratch.\n", "\n", "Uncheck the box below and run the cell to train a model." ] }, { "cell_type": "code", "execution_count": null, "id": "W5KA7zDIa3Lw", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @markdown\n", "yes_I_want_the_pretrained_model = True # @param {type:\"boolean\"}" ] }, { "cell_type": "code", "execution_count": null, "id": "53fb8dda", "metadata": { "execution": {} }, "outputs": [], "source": [ "# Load a pretrained model or train the model from scratch\n", "\n", "# keep yes_I_want_the_pretrained_model set to True (see above) if you want a shortcut\n", "\n", "if yes_I_want_the_pretrained_model:\n", "    checkpoint = torch.load('checkpoints/synapses/classifier/vgg_checkpoint',\n", "                            map_location=device)\n", "    model.load_state_dict(checkpoint['model_state_dict'])\n", "else:\n", "    train_from_scratch(dataloader, validation_dataset,\n", "                       optimizer, loss,\n", "                       num_epochs=100, device=device)" ] }, { "cell_type": "code", "execution_count": null, "id": "4f6e3663", "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "final test accuracy: 0.49737888647866957\n" ] } ], "source": [ "accuracy = test(test_dataset, device=device)\n", "print(f\"final test accuracy: {accuracy}\")" ] }, { "cell_type": "markdown", "id": "3f43bba5", "metadata": { "execution": {} }, "source": [ "This concludes the first section. We now have a classifier that can discriminate between images of different types.\n", "\n", "If you used the images we provided, the classifier is not perfect (you should get an accuracy of around 80%), but pretty good considering that there are six different types of images. Furthermore, it is not so clear for humans how the classifier does it. Feel free to explore the data a bit more and see for yourself if you can tell the difference between, say, GABAergic and glutamatergic synapses.\n", "\n", "So this is an interesting situation: The VGG network knows something we don't quite know. In the next section, we will see how we can visualize the relevant differences between images of different types." ] }, { "cell_type": "markdown", "id": "72b5240c", "metadata": { "execution": {} }, "source": [ "---\n", "# Train a GAN to Translate Images\n", "\n", "We will train a so-called CycleGAN to translate images from one class to another."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get the CycleGAN code and dependencies\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "GitHub repo: https://github.com/funkey/neuromatch_xai\n" ] }, { "cell_type": "code", "execution_count": null, "id": "41c9e63b", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Get the CycleGAN code and dependencies\n", "\n", "# @markdown GitHub repo: https://github.com/funkey/neuromatch_xai\n", "\n", "import requests, zipfile, io\n", "\n", "url = 'https://osf.io/vutn5/download'\n", "r = requests.get(url)\n", "z = zipfile.ZipFile(io.BytesIO(r.content))\n", "z.extractall()\n", "\n", "!pip install dominate --quiet" ] }, { "cell_type": "markdown", "id": "e5da5c01", "metadata": { "execution": {} }, "source": [ "In this example, we will translate between GABAergic and glutamatergic synapses.\n", "\n", "First, we have to copy images of either type into a format that the CycleGAN library is happy with. Afterwards, we can start training on those images."
] }, { "cell_type": "code", "execution_count": null, "id": "2b2519c4", "metadata": { "execution": {} }, "outputs": [], "source": [ "import cycle_gan\n", "\n", "cycle_gan.prepare_dataset('data/raw/synapses/', ['0_gaba', '2_glutamate'])\n", "\n", "## Uncomment if you want to enable the training procedure\n", "# cycle_gan.train('data/raw/synapses/', '0_gaba', '2_glutamate', 128)" ] }, { "cell_type": "markdown", "id": "0d328904", "metadata": { "execution": {} }, "source": [ "Training the CycleGAN takes a lot longer than the VGG we trained above (on the synapse dataset, this will be around 7 days...).\n", "\n", "If you started the training, interrupt the kernel and continue with the next cell, which uses one of the pretrained CycleGAN models for the synapse dataset."
] }, { "cell_type": "code", "execution_count": null, "id": "a182c3bc", "metadata": { "execution": {} }, "outputs": [], "source": [ "# translate images from class A to B, and classify each with the VGG network trained above\n", "# Note: cycle_gan requires CUDA devices\n", "# (compare device.type: a torch.device never compares equal to a plain string)\n", "if device.type == \"cuda\":\n", "    cycle_gan.test(\n", "        data_dir='data/raw/synapses/',\n", "        class_A='0_gaba',\n", "        class_B='2_glutamate',\n", "        img_size=128,\n", "        checkpoints_dir='checkpoints/synapses/cycle_gan/gaba_glutamate/',\n", "        vgg_checkpoint='checkpoints/synapses/classifier/vgg_checkpoint'\n", "    )" ] }, { "cell_type": "markdown", "id": "17fc1703", "metadata": { "execution": {} }, "source": [ "Read all translated images and sort them by how much the translation \"fools\" the VGG classifier trained above:" ] }, { "cell_type": "code", "execution_count": null, "id": "2a582ba6", "metadata": { "execution": {} }, "outputs": [], "source": [ "class_A_index = 0\n", "class_B_index = 2\n", "\n", "result_dir = 'data/raw/synapses/cycle_gan/0_gaba_2_glutamate/results/test_latest/images/'\n", "classification_results = []\n", "for f in glob.glob(result_dir + '/*.json'):\n", "    with open(f) as json_file:\n", "        result = json.load(json_file)\n", "    result['basename'] = f.replace('_aux.json', '')\n", "    classification_results.append(result)\n", "classification_results.sort(\n", "    key=lambda c: c['aux_real'][class_A_index] * c['aux_fake'][class_B_index],\n", "    reverse=True)" ] }, { "cell_type": "markdown", "id": "2cc0d486", "metadata": { "execution": {} }, "source": [ "Show the top real and fake images that make the classifier change its mind:" ] }, { "cell_type": "code", "execution_count": null, "id": "1567b00e", "metadata": { "execution": {} }, "outputs": [], "source": [ "def show_pair(a, b, score_a, score_b, class_a, class_b):\n", "    fig, axs = plt.subplots(1, 2, figsize=(20, 20), sharey=True)\n", "    axs[0].imshow(a, cmap='gray')\n", "    axs[0].set_title(f\"p({class_a}) = \" + str(score_a))\n", "    axs[1].imshow(b, cmap='gray')\n", "    axs[1].set_title(f\"p({class_b}) = \" + str(score_b))\n", "    plt.show()\n", "\n", "\n", "# show the top successful translations (according to our VGG classifier)\n", "# Note: only run if cycle_gan ran successfully\n", "if classification_results:\n", "    # show at most 10 pairs, fewer if fewer results are available\n", "    for i in range(min(10, len(classification_results))):\n", "        basename = classification_results[i]['basename']\n", "        score_A = classification_results[i]['aux_real'][class_A_index]\n", "        score_B = classification_results[i]['aux_fake'][class_B_index]\n", "        real_A = imread(basename + '_real.png')\n", "        fake_B = imread(basename + '_fake.png')\n", "        show_pair(real_A, fake_B, score_A, score_B, 'gaba', 'glutamate')" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "gpuType": "T4", "include_colab_link": true, "name": "em_synapses", "provenance": [], "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", "name": "python3" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }