{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {}, "id": "view-in-github" }, "source": [ "\"Open   \"Open" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Data Augmentation in image classification models\n", "\n", "**By Neuromatch Academy**\n", "\n", "__Content creators:__ [Jama Hussein Mohamud](https://engmubarak48.github.io/jmohamud/index.html), [Alex Hernandez-Garcia](https://alexhernandezgarcia.github.io/)\n", "\n", "__Production editors:__ Spiros Chavlis, Saeed Salehi\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "----\n", "# Objective\n", "\n", "Data augmentation refers to synthetically increasing the amount of training data by transforming the existing training examples. Data augmentation has been shown to be a very useful technique, especially in computer vision applications. However, there are multiple ways of performing data augmentation and it is yet to be understood which transformations are more effective and why, and how data augmentation interacts with other techniques. In fact, it is common to see different augmentation schemes and setups in different papers. For example, there are perceptually possible image transformations (related to human visual perception), simple synthetic transformations such as cutout, more artificial transformations such as mixup that even transform the class labels, among many others. \n", "\n", "In this notebook, we will show how to train deep neural networks for image classification with data augmentation and analyse the results." 
] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Setup\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install dependencies\n", "!pip install pandas --quiet" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# imports\n", "import os\n", "import csv\n", "import multiprocessing\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "import torch.backends.cudnn as cudnn\n", "from torch.autograd import Variable\n", "\n", "import torchvision\n", "import torchvision.transforms as transforms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set random seed\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Executing `set_seed(seed=seed)` you are setting the seed\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Set random seed\n", "\n", "# @markdown Executing `set_seed(seed=seed)` you are setting the seed\n", "\n", "# for DL its critical to set the random seed so that students can have a\n", "# baseline to compare their results to expected results.\n", "# Read more here: https://pytorch.org/docs/stable/notes/randomness.html\n", "\n", "# Call `set_seed` function in the exercises to ensure reproducibility.\n", "import random\n", "import torch\n", "\n", "def set_seed(seed=None, seed_torch=True):\n", " if seed is None:\n", " seed = np.random.choice(2 ** 32)\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " if seed_torch:\n", " 
 torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", " torch.cuda.manual_seed(seed)\n", " torch.backends.cudnn.benchmark = False\n", " torch.backends.cudnn.deterministic = True\n", "\n", " print(f'Random seed {seed} has been set.')\n", "\n", "# In case that `DataLoader` is used\n", "def seed_worker(worker_id):\n", "  worker_seed = torch.initial_seed() % 2**32\n", "  np.random.seed(worker_seed)\n", "  random.seed(worker_seed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set device (GPU or CPU)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Set device (GPU or CPU)\n", "\n", "# inform the user if the notebook uses GPU or CPU.\n", "\n", "def set_device():\n", "  device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "  if device != \"cuda\":\n", "    print(\"WARNING: For this notebook to perform best, \"\n", "          \"if possible, in the menu under `Runtime` -> \"\n", "          \"`Change runtime type.` select `GPU` \")\n", "  else:\n", "    print(\"GPU is enabled in this notebook.\")\n", "\n", "  return device" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random seed 2021 has been set.\n", "GPU is enabled in this notebook.\n" ] }, { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'cuda'" ] }, "execution_count": 5, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "set_seed(seed=2021)\n", "set_device()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Training hyperparameters\n", "\n", "**Note:** We have reduced the number of epochs, `end_apochs`, from its original value of 200. Change it back to 200 and rerun the code for a full training run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# hyper-parameters\n", "use_cuda = torch.cuda.is_available()\n", "alpha = 1 # alpha for mixup augmentation\n", "best_acc = 0 # best test accuracy\n", "start_epoch = 0 # start from epoch 0 or last checkpoint epoch\n", "batch_size = 128\n", "end_apochs = 15 # Please change this to 200\n", "base_learning_rate = 0.1\n", "cutout = True # True/False if you want to use cutout augmentation\n", "mixup = False # True/False if you want to use mixup augmentation\n", "n_holes = 1 # number of holes to cut out from image for cutout\n", "length = 16 # length of the holes for cutout augmentation\n", "torchvision_transforms = False # True/False if you want use torchvision augmentations" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Augmentation" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Cutout\n", "Randomly mask out one or more patches from an image." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ " `Cutout` Augmentation class\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @markdown `Cutout` Augmentation class\n", "\n", "class Cutout(object):\n", " \"\"\"\n", " code from: https://github.com/uoguelph-mlrg/Cutout\n", "\n", " Randomly mask out one or more patches from an image.\n", " Args:\n", " n_holes (int): Number of patches to cut out of each image.\n", " length (int): The length (in pixels) of each square patch.\n", " \"\"\"\n", " def __init__(self, n_holes, length):\n", " self.n_holes = n_holes\n", " self.length = length\n", "\n", " def __call__(self, img):\n", " \"\"\"\n", " Args:\n", " img (Tensor): Tensor image of size (C, H, W).\n", " Returns:\n", " Tensor: Image with n_holes of dimension length x length cut out of it.\n", " \"\"\"\n", " h = img.size(1)\n", " w = img.size(2)\n", "\n", " mask = np.ones((h, w), np.float32)\n", "\n", " for n in range(self.n_holes):\n", " y = np.random.randint(h)\n", " x = np.random.randint(w)\n", "\n", " y1 = np.clip(y - self.length // 2, 0, h)\n", " y2 = np.clip(y + self.length // 2, 0, h)\n", " x1 = np.clip(x - self.length // 2, 0, w)\n", " x2 = np.clip(x + self.length // 2, 0, w)\n", "\n", " mask[y1: y2, x1: x2] = 0.\n", "\n", " mask = torch.from_numpy(mask)\n", " mask = mask.expand_as(img)\n", " img = img * mask\n", "\n", " return img" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Mixup\n", "\n", "Mixup is a data augmentation technique that combines pairs of examples via a convex combination of the images and the labels. 
Given images $x_i$ and $x_j$ with labels $y_i$ and $y_j$, respectively, and $\\lambda \\in [0, 1]$, mixup creates a new image $\\hat{x}$ with label $\\hat{y}$ the following way:\n", "\n", "\\begin{align}\n", "\\hat{x} &= \\lambda x_i + (1 - \\lambda) x_j \\\\\n", "\\hat{y} &= \\lambda y_i + (1 - \\lambda) y_j\n", "\\end{align}\n", "\n", "You may check the [original paper](https://arxiv.org/abs/1710.09412) and [code repository](https://github.com/hongyi-zhang/mixup)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " `mixup_data` Augmentation function\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @markdown `mixup_data` Augmentation function\n", "\n", "def mixup_data(x, y, alpha=1.0, use_cuda=True):\n", " '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda\n", " - https://github.com/hongyi-zhang/mixup\n", " '''\n", " if alpha > 0.:\n", " lam = np.random.beta(alpha, alpha)\n", " else:\n", " lam = 1.\n", " batch_size = x.size()[0]\n", " if use_cuda:\n", " index = torch.randperm(batch_size).cuda()\n", " else:\n", " index = torch.randperm(batch_size)\n", "\n", " mixed_x = lam * x + (1 - lam) * x[index, :]\n", " y_a, y_b = y, y[index]\n", "\n", " return mixed_x, y_a, y_b, lam" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Data" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Datasets\n", "\n", "We will start using CIFAR-10 data set from PyTorch, but with small tweaks we can get any other data we are interested in. 
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Download and prepare Data\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==> Preparing data...\n", "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./CIFAR10/cifar-10-python.tar.gz\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3a3a8eff5d924100831e7e7cd0202361", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Extracting ./CIFAR10/cifar-10-python.tar.gz to ./CIFAR10\n", "Files already downloaded and verified\n" ] } ], "source": [ "# @markdown Download and prepare Data\n", "print('==> Preparing data...')\n", "def percentageSplit(full_dataset, percent=0.0):\n", " set1_size = int(percent * len(full_dataset))\n", " set2_size = len(full_dataset) - set1_size\n", " final_dataset, _ = torch.utils.data.random_split(full_dataset,\n", " [set1_size, set2_size])\n", " return final_dataset\n", "\n", "# CIFAR100 normalizing\n", "# mean = [0.5071, 0.4866, 0.4409]\n", "# std = [0.2673, 0.2564, 0.2762]\n", "\n", "# CIFAR10 normalizing\n", "mean = (0.4914, 0.4822, 0.4465)\n", "std = (0.2023, 0.1994, 0.2010)\n", "\n", "# torchvision transforms\n", "transform_train = transforms.Compose([])\n", "if torchvision_transforms:\n", " transform_train.transforms.append(transforms.RandomCrop(32, padding=4))\n", " transform_train.transforms.append(transforms.RandomHorizontalFlip())\n", "\n", "transform_train.transforms.append(transforms.ToTensor())\n", "transform_train.transforms.append(transforms.Normalize(mean, std))\n", "if cutout:\n", " transform_train.transforms.append(Cutout(n_holes=n_holes, 
 length=length))\n", "\n", "transform_test = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize(mean, std),\n", "])\n", "\n", "trainset = torchvision.datasets.CIFAR10(\n", " root='./CIFAR10', train=True, download=True,\n", " transform=transform_train)\n", "\n", "testset = torchvision.datasets.CIFAR10(\n", " root='./CIFAR10', train=False, download=True,\n", " transform=transform_test)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "#### CIFAR-10\n", "\n", "CIFAR-10 is a data set of 50,000 colour (RGB) training images and 10,000 test images, each of size 32 x 32 pixels. Each image is labelled with one of 10 possible classes:\n", "```\n", "'plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'\n", "```\n", "The data set is stored as a `torchvision.datasets.cifar.CIFAR10` object. You can check some of its properties with the following code:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Object type: <class 'torchvision.datasets.cifar.CIFAR10'>\n", "Training data shape: (50000, 32, 32, 3)\n", "Test data shape: (10000, 32, 32, 3)\n", "Number of classes: 10\n" ] } ], "source": [ "print(f\"Object type: {type(trainset)}\")\n", "print(f\"Training data shape: {trainset.data.shape}\")\n", "print(f\"Test data shape: {testset.data.shape}\")\n", "print(f\"Number of classes: {np.unique(trainset.targets).shape[0]}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "size of the new trainset: 50000\n" ] } ], "source": [ "# choose percentage from the trainset. 
set percent = 1.0 to use the whole train data\n", "percent = 1.0\n", "trainset = percentageSplit(trainset, percent = percent)\n", "print(f\"size of the new trainset: {len(trainset)}\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Data loaders\n", "\n", "A dataloader is an optimized data iterator that provides functionality for efficient shuffling, transformation and batching of the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----> number of workers: 4\n" ] } ], "source": [ "# Dataloader\n", "num_workers = multiprocessing.cpu_count()\n", "\n", "print(f'----> number of workers: {num_workers}')\n", "\n", "trainloader = torch.utils.data.DataLoader(\n", " trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)\n", "testloader = torch.utils.data.DataLoader(\n", " testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Visualization\n", "\n", "To visualize some of the augmentations, make sure you set to ```True``` their corresponding flags in the hyperparameters section" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# get batch of data\n", "batch_X, batch_Y = next(iter(trainloader))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "def plot_mixed_images(images):\n", " inv_normalize = transforms.Normalize(\n", " mean= [-m/s for m, s in zip(mean, std)],\n", " std= [1/s for s in std]\n", " )\n", " inv_PIL = transforms.ToPILImage()\n", " fig = plt.figure(figsize=(10, 8))\n", " for i in range(1, len(images) + 1):\n", " image = images[i-1]\n", " ax = fig.add_subplot(1, 4, i)\n", " inv_tensor = inv_normalize(image).cpu()\n", " ax.imshow(inv_PIL(inv_tensor))\n", " plt.show()" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Mixup Visualization\n", "if mixup:\n", " alpha = 0.9\n", " mixed_x, y_a, y_b, lam = mixup_data(batch_X, batch_Y,\n", " alpha=alpha, use_cuda=use_cuda)\n", " plot_mixed_images(mixed_x[:4])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "# Mixup Visualization\n", "if cutout:\n", " plot_mixed_images(batch_X[:4])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Torchvision Visualization\n", "if torchvision_transforms:\n", " plot_mixed_images(batch_X[:4])" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Model" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Architecture: ResNet\n", "\n", "ResNet is a family of network architectures whose main property is that the network is organised as a stack of _residual blocks_. Residual blocks consist of a stack of layers whose output is added the input, making a _shortcut connection_.\n", "\n", "See the [original paper](https://arxiv.org/abs/1512.03385) for more details.\n", "\n", "ResNet is just a popular choice out of many others, but data augmentation works well in general. We just picked ResNet for illustration purposes." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ " ResNet model in PyTorch\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @markdown ResNet model in PyTorch\n", "\n", "class BasicBlock(nn.Module):\n", " \"\"\"ResNet in PyTorch.\n", " Reference:\n", " [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun\n", " Deep Residual Learning for Image Recognition.\n", " arXiv:1512.03385\n", " \"\"\"\n", " expansion = 1\n", "\n", " def __init__(self, in_planes, planes, stride=1):\n", " super(BasicBlock, self).__init__()\n", " self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", " self.bn1 = nn.BatchNorm2d(planes)\n", " self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n", " self.bn2 = nn.BatchNorm2d(planes)\n", "\n", " self.shortcut = nn.Sequential()\n", " if stride != 1 or in_planes != self.expansion*planes:\n", " self.shortcut = nn.Sequential(\n", " nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n", " nn.BatchNorm2d(self.expansion*planes)\n", " )\n", "\n", " def forward(self, x):\n", " out = F.relu(self.bn1(self.conv1(x)))\n", " out = self.bn2(self.conv2(out))\n", " out += self.shortcut(x)\n", " out = F.relu(out)\n", " return out\n", "\n", "class Bottleneck(nn.Module):\n", " expansion = 4\n", "\n", " def __init__(self, in_planes, planes, stride=1):\n", " super(Bottleneck, self).__init__()\n", " self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n", " self.bn1 = nn.BatchNorm2d(planes)\n", " self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", " self.bn2 = nn.BatchNorm2d(planes)\n", " self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)\n", " self.bn3 = nn.BatchNorm2d(self.expansion*planes)\n", "\n", " self.shortcut = nn.Sequential()\n", " if 
stride != 1 or in_planes != self.expansion*planes:\n", " self.shortcut = nn.Sequential(\n", " nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n", " nn.BatchNorm2d(self.expansion*planes)\n", " )\n", "\n", " def forward(self, x):\n", " out = F.relu(self.bn1(self.conv1(x)))\n", " out = F.relu(self.bn2(self.conv2(out)))\n", " out = self.bn3(self.conv3(out))\n", " out += self.shortcut(x)\n", " out = F.relu(out)\n", " return out\n", "\n", "\n", "class ResNet(nn.Module):\n", " def __init__(self, block, num_blocks, num_classes=10):\n", " super(ResNet, self).__init__()\n", " self.in_planes = 64\n", "\n", " self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n", " self.bn1 = nn.BatchNorm2d(64)\n", " self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n", " self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n", " self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n", " self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n", " self.linear = nn.Linear(512*block.expansion, num_classes)\n", "\n", " def _make_layer(self, block, planes, num_blocks, stride):\n", " strides = [stride] + [1]*(num_blocks-1)\n", " layers = []\n", " for stride in strides:\n", " layers.append(block(self.in_planes, planes, stride))\n", " self.in_planes = planes * block.expansion\n", " return nn.Sequential(*layers)\n", "\n", " def forward(self, x):\n", " out = F.relu(self.bn1(self.conv1(x)))\n", " out = self.layer1(out)\n", " out = self.layer2(out)\n", " out = self.layer3(out)\n", " out = self.layer4(out)\n", " out = F.avg_pool2d(out, 4)\n", " out = out.view(out.size(0), -1)\n", " out = self.linear(out)\n", " return out\n", "\n", "\n", "def ResNet18():\n", " return ResNet(BasicBlock, [2, 2, 2, 2])\n", "\n", "\n", "def ResNet34():\n", " return ResNet(BasicBlock, [3, 4, 6, 3])\n", "\n", "\n", "def ResNet50():\n", " return ResNet(Bottleneck, [3, 4, 6, 3])" ] }, { "cell_type": 
"markdown", "metadata": { "execution": {} }, "source": [ "## Model setup and test" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-----> verify if model is run on random data\n", "model loaded\n", "Using 1 GPUs.\n", "Using CUDA..\n" ] } ], "source": [ "# load the Model\n", "net = ResNet18()\n", "print('-----> verify if model is run on random data')\n", "y = net(Variable(torch.randn(1,3,32,32)))\n", "print('model loaded')\n", "\n", "result_folder = './results/'\n", "if not os.path.exists(result_folder):\n", " os.makedirs(result_folder)\n", "\n", "logname = result_folder + net.__class__.__name__ + '_' + '.csv'\n", "\n", "if use_cuda:\n", " net.cuda()\n", " net = torch.nn.DataParallel(net)\n", " print('Using', torch.cuda.device_count(), 'GPUs.')\n", " cudnn.benchmark = True\n", " print('Using CUDA..')" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Training" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Loss function and Optimizer\n", "\n", "We use the cross entropy loss, commonly used for classification, and stochastic gradient descent (SGD) as optimizer, with momentum and weight decay." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# optimizer and criterion\n", "\n", "def mixup_criterion(y_a, y_b, lam):\n", " '''\n", " - Mixup criterion\n", " - https://github.com/hongyi-zhang/mixup\n", " '''\n", " return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)\n", "\n", "criterion = nn.CrossEntropyLoss() # only for test data\n", "optimizer = optim.SGD(net.parameters(), lr=base_learning_rate, momentum=0.9, weight_decay=1e-4)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Train and test loops" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Training & Test functions\n", "def train(epoch, alpha, use_cuda=False):\n", " print('\\nEpoch: %d' % epoch)\n", " net.train()\n", " train_loss = 0\n", " correct = 0\n", " total = 0\n", " for batch_idx, (inputs, targets) in enumerate(trainloader):\n", " if use_cuda:\n", " inputs, targets = inputs.cuda(), targets.cuda()\n", " optimizer.zero_grad()\n", " if mixup:\n", " # generate mixed inputs, two one-hot label vectors and mixing coefficient\n", " inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, alpha, use_cuda)\n", " inputs, targets_a, targets_b = Variable(inputs), Variable(targets_a), Variable(targets_b)\n", " outputs = net(inputs)\n", " loss_func = mixup_criterion(targets_a, targets_b, lam)\n", " loss = loss_func(criterion, outputs)\n", " else:\n", " inputs, targets = Variable(inputs), Variable(targets)\n", " outputs = net(inputs)\n", " loss = criterion(outputs, targets)\n", "\n", " loss.backward()\n", " optimizer.step()\n", "\n", " train_loss += loss.item()\n", " _, predicted = torch.max(outputs.data, 1)\n", " total += targets.size(0)\n", " if mixup:\n", " correct += lam * predicted.eq(targets_a.data).cpu().sum() + (1 - lam) * predicted.eq(targets_b.data).cpu().sum()\n", " else:\n", " correct += 
predicted.eq(targets.data).cpu().sum()\n", "\n", " if batch_idx % 500 == 0:\n", " print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'\n", " % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))\n", " return (train_loss/batch_idx, 100.*correct/total)\n", "\n", "\n", "def test(epoch, use_cuda=False):\n", " global best_acc\n", " net.eval()\n", " test_loss = 0\n", " correct = 0\n", " total = 0\n", " with torch.no_grad():\n", " for batch_idx, (inputs, targets) in enumerate(testloader):\n", " if use_cuda:\n", " inputs, targets = inputs.cuda(), targets.cuda()\n", " # inputs, targets = Variable(inputs, volatile=True), Variable(targets)\n", " outputs = net(inputs)\n", " loss = criterion(outputs, targets)\n", "\n", " test_loss += loss.item()\n", " _, predicted = torch.max(outputs.data, 1)\n", " total += targets.size(0)\n", " correct += predicted.eq(targets.data).cpu().sum()\n", "\n", " if batch_idx % 200 == 0:\n", " print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'\n", " % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))\n", "\n", " # Save checkpoint.\n", " acc = 100.*correct/total\n", " if acc > best_acc:\n", " best_acc = acc\n", " checkpoint(acc, epoch)\n", " return (test_loss/batch_idx, 100.*correct/total)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Auxiliary functions\n", "\n", "* `checkpoint()`: Store checkpoints of the model\n", "* `adjust_learning_rate()`: Decreases the learning rate (learning rate decay) at certain epochs of training." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ " `checkpoint` and `adjust_learning_rate` functions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @markdown `checkpoint` and `adjust_learning_rate` functions\n", "def checkpoint(acc, epoch):\n", " # Save checkpoint.\n", " print('Saving..')\n", " state = {\n", " 'net': net.state_dict(),\n", " 'acc': acc,\n", " 'epoch': epoch,\n", " 'rng_state': torch.get_rng_state()\n", " }\n", " if not os.path.isdir('checkpoint'):\n", " os.mkdir('checkpoint')\n", " torch.save(state, './checkpoint/ckpt.t7')\n", "\n", "\n", "def adjust_learning_rate(optimizer, epoch):\n", " \"\"\"decrease the learning rate at 100 and 150 epoch\"\"\"\n", " lr = base_learning_rate\n", " if epoch <= 9 and lr > 0.1:\n", " # warm-up training for large minibatch\n", " lr = 0.1 + (base_learning_rate - 0.1) * epoch / 10.\n", " if epoch >= 100:\n", " lr /= 10\n", " if epoch >= 150:\n", " lr /= 10\n", " for param_group in optimizer.param_groups:\n", " param_group['lr'] = lr" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch: 0\n", "0 391 Loss: 2.443 | Acc: 10.938% (14/128)\n", "0 79 Loss: 1.531 | Acc: 46.094% (59/128)\n", "Saving..\n", "Epoch: 0 | train acc: 31.604000091552734 | test acc: 44.2599983215332\n", "\n", "Epoch: 1\n", "0 391 Loss: 1.619 | Acc: 39.844% (51/128)\n", "0 79 Loss: 1.199 | Acc: 60.156% (77/128)\n", "Saving..\n", "Epoch: 1 | train acc: 47.03200149536133 | test acc: 54.41999816894531\n", "\n", "Epoch: 2\n", "0 391 Loss: 1.301 | Acc: 53.906% (69/128)\n", "0 79 Loss: 1.013 | Acc: 61.719% (79/128)\n", "Saving..\n", "Epoch: 2 | train acc: 56.257999420166016 | test acc: 62.599998474121094\n", "\n", "Epoch: 3\n", "0 391 Loss: 1.036 | Acc: 64.062% (82/128)\n", "0 79 Loss: 0.909 | Acc: 69.531% (89/128)\n", 
"Saving..\n", "Epoch: 3 | train acc: 62.43199920654297 | test acc: 65.6500015258789\n", "\n", "Epoch: 4\n", "0 391 Loss: 0.839 | Acc: 68.750% (88/128)\n", "0 79 Loss: 0.859 | Acc: 70.312% (90/128)\n", "Saving..\n", "Epoch: 4 | train acc: 67.0 | test acc: 69.08999633789062\n", "\n", "Epoch: 5\n", "0 391 Loss: 0.922 | Acc: 64.844% (83/128)\n", "0 79 Loss: 0.660 | Acc: 76.562% (98/128)\n", "Saving..\n", "Epoch: 5 | train acc: 70.1259994506836 | test acc: 72.52999877929688\n", "\n", "Epoch: 6\n", "0 391 Loss: 0.833 | Acc: 65.625% (84/128)\n", "0 79 Loss: 0.616 | Acc: 78.125% (100/128)\n", "Saving..\n", "Epoch: 6 | train acc: 73.45999908447266 | test acc: 73.45999908447266\n", "\n", "Epoch: 7\n", "0 391 Loss: 0.686 | Acc: 75.000% (96/128)\n", "0 79 Loss: 0.533 | Acc: 81.250% (104/128)\n", "Saving..\n", "Epoch: 7 | train acc: 75.99600219726562 | test acc: 75.91000366210938\n", "\n", "Epoch: 8\n", "0 391 Loss: 0.626 | Acc: 78.125% (100/128)\n", "0 79 Loss: 0.458 | Acc: 82.031% (105/128)\n", "Saving..\n", "Epoch: 8 | train acc: 78.42400360107422 | test acc: 79.11000061035156\n", "\n", "Epoch: 9\n", "0 391 Loss: 0.465 | Acc: 85.938% (110/128)\n", "0 79 Loss: 0.465 | Acc: 87.500% (112/128)\n", "Saving..\n", "Epoch: 9 | train acc: 80.72599792480469 | test acc: 80.37000274658203\n", "\n", "Epoch: 10\n", "0 391 Loss: 0.509 | Acc: 81.250% (104/128)\n", "0 79 Loss: 0.523 | Acc: 79.688% (102/128)\n", "Epoch: 10 | train acc: 82.16400146484375 | test acc: 79.25\n", "\n", "Epoch: 11\n", "0 391 Loss: 0.423 | Acc: 82.031% (105/128)\n", "0 79 Loss: 0.610 | Acc: 78.125% (100/128)\n", "Epoch: 11 | train acc: 83.96199798583984 | test acc: 79.68000030517578\n", "\n", "Epoch: 12\n", "0 391 Loss: 0.221 | Acc: 89.844% (115/128)\n", "0 79 Loss: 0.467 | Acc: 82.812% (106/128)\n", "Saving..\n", "Epoch: 12 | train acc: 85.61799621582031 | test acc: 80.88999938964844\n", "\n", "Epoch: 13\n", "0 391 Loss: 0.427 | Acc: 85.938% (110/128)\n", "0 79 Loss: 0.522 | Acc: 82.812% (106/128)\n", "Saving..\n", 
"Epoch: 13 | train acc: 87.21199798583984 | test acc: 81.54000091552734\n", "\n", "Epoch: 14\n", "0 391 Loss: 0.216 | Acc: 93.750% (120/128)\n", "0 79 Loss: 0.386 | Acc: 86.719% (111/128)\n", "Epoch: 14 | train acc: 88.08000183105469 | test acc: 81.44999694824219\n" ] } ], "source": [ "# start training\n", "if not os.path.exists(logname):\n", " with open(logname, 'w') as logfile:\n", " logwriter = csv.writer(logfile, delimiter=',')\n", " logwriter.writerow(['epoch', 'train loss', 'train acc',\n", " 'test loss', 'test acc'])\n", "\n", "for epoch in range(start_epoch, end_apochs):\n", " adjust_learning_rate(optimizer, epoch)\n", " train_loss, train_acc = train(epoch, alpha, use_cuda=use_cuda)\n", " test_loss, test_acc = test(epoch, use_cuda=use_cuda)\n", " with open(logname, 'a') as logfile:\n", " logwriter = csv.writer(logfile, delimiter=',')\n", " logwriter.writerow([epoch, train_loss, train_acc.item(),\n", " test_loss, test_acc.item()])\n", " print(f'Epoch: {epoch} | train acc: {train_acc} | test acc: {test_acc}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr><th></th><th>epoch</th><th>train loss</th><th>train acc</th><th>test loss</th><th>test acc</th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>0</th><td>0</td><td>1.932130</td><td>31.604000</td><td>1.535233</td><td>44.259998</td></tr>\n",
 "    <tr><th>1</th><td>1</td><td>1.446863</td><td>47.032001</td><td>1.262779</td><td>54.419998</td></tr>\n",
 "    <tr><th>2</th><td>2</td><td>1.212518</td><td>56.257999</td><td>1.069593</td><td>62.599998</td></tr>\n",
 "    <tr><th>3</th><td>3</td><td>1.051850</td><td>62.431999</td><td>0.996476</td><td>65.650002</td></tr>\n",
 "    <tr><th>4</th><td>4</td><td>0.928131</td><td>67.000000</td><td>0.898354</td><td>69.089996</td></tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "</div>
" ], "text/plain": [ " epoch train loss train acc test loss test acc\n", "0 0 1.932130 31.604000 1.535233 44.259998\n", "1 1 1.446863 47.032001 1.262779 54.419998\n", "2 2 1.212518 56.257999 1.069593 62.599998\n", "3 3 1.051850 62.431999 0.996476 65.650002\n", "4 4 0.928131 67.000000 0.898354 69.089996" ] }, "execution_count": 24, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "# plot results\n", "results = pd.read_csv('/content/results/ResNet_.csv', sep=',')\n", "results.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average test Accuracy over 15 epochs: 72.0\n", "best test accuraccy over 15 epochs: 81.54000091552734\n" ] } ], "source": [ "train_accuracy = results['train acc'].values\n", "test_accuracy = results['test acc'].values\n", "\n", "print(f\"Average test Accuracy over {end_apochs} epochs: {sum(test_accuracy)//len(test_accuracy)}\")\n", "print(f\"best test accuraccy over {end_apochs} epochs: {max(test_accuracy)}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "figureName = 'WithMixUp' # change figure name\n", "\n", "plt.figure(figsize=(9, 6))\n", "plt.plot(results['epoch'].values, train_accuracy, label='train')\n", "plt.plot(results['epoch'].values, test_accuracy, label='test')\n", "plt.xlabel('Number of epochs')\n", "plt.ylabel('Accuracy')\n", "plt.title(f'Train/Test Accuracy curve for {end_apochs} epochs')\n", "plt.savefig(f'/content/results/{figureName}.png')\n", "plt.legend()\n", "plt.show()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "MHZXiV5ZsWXq" ], "include_colab_link": true, "machine_shape": "hm", "name": "data_augmentation", "provenance": [], "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", "name": "python3" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }