{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {}, "id": "view-in-github" }, "source": [ "\"Open   \"Open" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Transfer Learning \n", "\n", "**By Neuromatch Academy**\n", "\n", "__Content creators:__ [Jama Hussein Mohamud](https://engmubarak48.github.io/jmohamud/index.html) & [Alex Hernandez-Garcia](https://alexhernandezgarcia.github.io/)\n", "\n", "__Production editors:__ Saeed Salehi, Spiros Chavlis\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Objective\n", "\n", "One desired capability for machines is the ability to transfer the knowledge (features) learned on one domain to another This can potentially save compute time, enable training when data is scarce, and even improve performance. Unfortunately, there is no single recipe for transfer learning and instead multiple options are possible and much remains to be well understood. In this project, you will explore how transfer learning works in different scenarios. 
" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# imports\n", "import os\n", "import gc\n", "import csv\n", "import glob\n", "import torch\n", "import multiprocessing\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import torch.nn as nn\n", "import matplotlib.pyplot as plt\n", "\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "import torch.backends.cudnn as cudnn\n", "from torch.autograd import Variable\n", "\n", "import torchvision\n", "import torchvision.transforms as transforms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set random seed\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Executing `set_seed(seed=seed)` you are setting the seed\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Set random seed\n", "\n", "# @markdown Executing `set_seed(seed=seed)` you are setting the seed\n", "\n", "# for DL its critical to set the random seed so that students can have a\n", "# baseline to compare their results to expected results.\n", "# Read more here: https://pytorch.org/docs/stable/notes/randomness.html\n", "\n", "# Call `set_seed` function in the exercises to ensure reproducibility.\n", "import random\n", "import torch\n", "\n", "def set_seed(seed=None, seed_torch=True):\n", " if seed is None:\n", " seed = np.random.choice(2 ** 32)\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " if seed_torch:\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", " torch.cuda.manual_seed(seed)\n", " torch.backends.cudnn.benchmark = False\n", " torch.backends.cudnn.deterministic = True\n", "\n", " print(f'Random seed {seed} has been set.')\n", "\n", "# In case that `DataLoader` is used\n", "def 
seed_worker(worker_id):\n", "  worker_seed = torch.initial_seed() % 2**32\n", "  np.random.seed(worker_seed)\n", "  random.seed(worker_seed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set device (GPU or CPU)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Set device (GPU or CPU)\n", "\n", "# inform the user if the notebook uses GPU or CPU.\n", "\n", "def set_device():\n", "  device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "  if device != \"cuda\":\n", "    print(\"WARNING: For this notebook to perform best, \"\n", "          \"if possible, in the menu under `Runtime` -> \"\n", "          \"`Change runtime type.` select `GPU` \")\n", "  else:\n", "    print(\"GPU is enabled in this notebook.\")\n", "\n", "  return device" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Random seeds\n", "\n", "If you want to obtain reproducible results, it is good practice to set the seeds of the random number generators of the various libraries." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random seed 2021 has been set.\n", "GPU is enabled in this notebook.\n" ] } ], "source": [ "set_seed(seed=2021)\n", "device = set_device()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Training hyperparameters\n", "\n", "Here we set some general training hyperparameters such as the learning rate, batch size, etc. as well as other training options such as including data augmentation (`torchvision_transforms`)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# hyper-parameters\n", "use_cuda = torch.cuda.is_available()\n", "best_acc = 0 # best test accuracy\n", "start_epoch = 0 # start from epoch 0 or last checkpoint epoch\n", "batch_size = 128\n", "max_epochs = 15 # Please change this to 200\n", "max_epochs_target = 10\n", "base_learning_rate = 0.1\n", "torchvision_transforms = True # set to True to use torchvision augmentations, False otherwise" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Data" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Source dataset\n", "\n", "We will train the source model on the CIFAR-100 data set from torchvision, but with small tweaks you can use any other data set you are interested in.\n", "\n", "Note that the data set is normalised by subtracting the mean and dividing by the standard deviation (pre-computed) of the training set. Also, if `torchvision_transforms` is `True`, data augmentation will be applied during training."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Download and prepare Data\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==> Preparing data..\n", "Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./CIFAR100/cifar-100-python.tar.gz\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3dd2d265bd4a44de80738813df1a1b1e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=169001437.0), HTML(value='')))" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Extracting ./CIFAR100/cifar-100-python.tar.gz to ./CIFAR100\n", "Files already downloaded and verified\n" ] } ], "source": [ "# @markdown Download and prepare Data\n", "print('==> Preparing data..')\n", "def percentageSplit(full_dataset, percent = 0.0):\n", "  set1_size = int(percent * len(full_dataset))\n", "  set2_size = len(full_dataset) - set1_size\n", "  final_dataset, _ = torch.utils.data.random_split(full_dataset, [set1_size, set2_size])\n", "  return final_dataset\n", "\n", "\n", "# CIFAR100 normalizing\n", "mean = [0.5071, 0.4866, 0.4409]\n", "std = [0.2673, 0.2564, 0.2762]\n", "\n", "# CIFAR10 normalizing\n", "# mean = (0.4914, 0.4822, 0.4465)\n", "# std = (0.2023, 0.1994, 0.2010)\n", "\n", "# torchvision transforms\n", "transform_train = transforms.Compose([])\n", "if torchvision_transforms:\n", "  transform_train.transforms.append(transforms.RandomCrop(32, padding=4))\n", "  transform_train.transforms.append(transforms.RandomHorizontalFlip())\n", "\n", "transform_train.transforms.append(transforms.ToTensor())\n", "transform_train.transforms.append(transforms.Normalize(mean, std))\n", "\n", "transform_test = transforms.Compose([\n", "  transforms.ToTensor(),\n", "  
transforms.Normalize(mean, std),\n", "])\n", "\n", "trainset = torchvision.datasets.CIFAR100(\n", "  root='./CIFAR100', train=True, download=True, transform=transform_train)\n", "\n", "testset = torchvision.datasets.CIFAR100(\n", "  root='./CIFAR100', train=False, download=True, transform=transform_test)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### CIFAR-100\n", "\n", "CIFAR-100 is a data set of 50,000 colour (RGB) training images and 10,000 test images, of size 32 x 32 pixels. Each image is labelled as 1 of 100 possible classes. \n", "\n", "The data set is stored as a `torchvision.datasets.cifar.CIFAR100` object. You can check some of its properties with the following code:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Object type: <class 'torchvision.datasets.cifar.CIFAR100'>\n", "Training data shape: (50000, 32, 32, 3)\n", "Test data shape: (10000, 32, 32, 3)\n", "Number of classes: 100\n" ] } ], "source": [ "print(f\"Object type: {type(trainset)}\")\n", "print(f\"Training data shape: {trainset.data.shape}\")\n", "print(f\"Test data shape: {testset.data.shape}\")\n", "print(f\"Number of classes: {np.unique(trainset.targets).shape[0]}\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Data loaders\n", "\n", "A dataloader is an optimized data iterator that provides functionality for efficient shuffling, transformation and batching of the data."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dataloader\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----> number of workers: 2\n" ] } ], "source": [ "##@title Dataloader\n", "num_workers = multiprocessing.cpu_count()\n", "\n", "print(f'----> number of workers: {num_workers}')\n", "\n", "trainloader = torch.utils.data.DataLoader(\n", "  trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)\n", "testloader = torch.utils.data.DataLoader(\n", "  testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Architecture: ResNet\n", "\n", "ResNet is a family of network architectures whose main property is that the network is organised as a stack of _residual blocks_. A residual block consists of a stack of layers whose output is added to the block's input, making a _shortcut connection_.\n", "\n", "See the [original paper](https://arxiv.org/abs/1512.03385) for more details.\n", "\n", "ResNet is just one popular choice among many architectures; we picked it for illustration purposes. Data augmentation generally works well with other architectures too." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ResNet model in PyTorch\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title ResNet model in PyTorch\n", "\n", "class BasicBlock(nn.Module):\n", "  \"\"\"ResNet in PyTorch.\n", "  Reference:\n", "  [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun\n", "    Deep Residual Learning for Image Recognition. 
arXiv:1512.03385\n", "  \"\"\"\n", "\n", "  expansion = 1\n", "\n", "  def __init__(self, in_planes, planes, stride=1):\n", "    super(BasicBlock, self).__init__()\n", "    self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", "    self.bn1 = nn.BatchNorm2d(planes)\n", "    self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n", "    self.bn2 = nn.BatchNorm2d(planes)\n", "\n", "    self.shortcut = nn.Sequential()\n", "    if stride != 1 or in_planes != self.expansion*planes:\n", "      self.shortcut = nn.Sequential(\n", "          nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n", "          nn.BatchNorm2d(self.expansion*planes)\n", "      )\n", "\n", "  def forward(self, x):\n", "    out = F.relu(self.bn1(self.conv1(x)))\n", "    out = self.bn2(self.conv2(out))\n", "    out += self.shortcut(x)\n", "    out = F.relu(out)\n", "    return out\n", "\n", "\n", "class Bottleneck(nn.Module):\n", "  expansion = 4\n", "\n", "  def __init__(self, in_planes, planes, stride=1):\n", "    super(Bottleneck, self).__init__()\n", "    self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n", "    self.bn1 = nn.BatchNorm2d(planes)\n", "    self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", "    self.bn2 = nn.BatchNorm2d(planes)\n", "    self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)\n", "    self.bn3 = nn.BatchNorm2d(self.expansion*planes)\n", "\n", "    self.shortcut = nn.Sequential()\n", "    if stride != 1 or in_planes != self.expansion*planes:\n", "      self.shortcut = nn.Sequential(\n", "          nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n", "          nn.BatchNorm2d(self.expansion*planes)\n", "      )\n", "\n", "  def forward(self, x):\n", "    out = F.relu(self.bn1(self.conv1(x)))\n", "    out = F.relu(self.bn2(self.conv2(out)))\n", "    out = self.bn3(self.conv3(out))\n", "    out += self.shortcut(x)\n", "    out = F.relu(out)\n", "    return out\n", "\n", 
"\n", "class ResNet(nn.Module):\n", " def __init__(self, block, num_blocks, num_classes=100):\n", " super(ResNet, self).__init__()\n", " self.in_planes = 64\n", "\n", " self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n", " self.bn1 = nn.BatchNorm2d(64)\n", " self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n", " self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n", " self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n", " self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n", " self.linear = nn.Linear(512*block.expansion, num_classes)\n", "\n", " def _make_layer(self, block, planes, num_blocks, stride):\n", " strides = [stride] + [1]*(num_blocks-1)\n", " layers = []\n", " for stride in strides:\n", " layers.append(block(self.in_planes, planes, stride))\n", " self.in_planes = planes * block.expansion\n", " return nn.Sequential(*layers)\n", "\n", " def forward(self, x):\n", " out = F.relu(self.bn1(self.conv1(x)))\n", " out = self.layer1(out)\n", " out = self.layer2(out)\n", " out = self.layer3(out)\n", " out = self.layer4(out)\n", " out = F.avg_pool2d(out, 4)\n", " out = out.view(out.size(0), -1)\n", " out = self.linear(out)\n", " return out\n", "\n", "\n", "def ResNet18():\n", " return ResNet(BasicBlock, [2, 2, 2, 2])\n", "\n", "\n", "def ResNet34():\n", " return ResNet(BasicBlock, [3, 4, 6, 3])\n", "\n", "\n", "def ResNet50():\n", " return ResNet(Bottleneck, [3, 4, 6, 3])" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "#### Test on random data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-----> verify if model is run on random data\n", "model loaded\n", "Using 1 GPUs.\n", "Using CUDA..\n" ] } ], "source": [ "# Load the Model\n", "net = ResNet18()\n", "print('-----> verify if model is run on random data')\n", "y = 
net(Variable(torch.randn(1,3,32,32)))\n", "print('model loaded')\n", "\n", "result_folder = './results/'\n", "if not os.path.exists(result_folder):\n", "  os.makedirs(result_folder)\n", "\n", "logname = result_folder + net.__class__.__name__ + '_pretrain' + '.csv'\n", "\n", "if use_cuda:\n", "  net.cuda()\n", "  net = torch.nn.DataParallel(net)\n", "  print('Using', torch.cuda.device_count(), 'GPUs.')\n", "  cudnn.benchmark = True\n", "  print('Using CUDA..')" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Set up training" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Set loss function and optimizer\n", "\n", "We use the cross entropy loss, commonly used for classification, and stochastic gradient descent (SGD) as optimizer, with momentum and weight decay." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Optimizer and criterion\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = optim.SGD(net.parameters(), lr=base_learning_rate, momentum=0.9, weight_decay=1e-4)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Train and test loops" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Training & Test functions\n", "\n", "def train(net, epoch, use_cuda=True):\n", "  print('\\nEpoch: %d' % epoch)\n", "  net.train()\n", "  train_loss = 0\n", "  correct = 0\n", "  total = 0\n", "  for batch_idx, (inputs, targets) in enumerate(trainloader):\n", "    if use_cuda:\n", "      inputs, targets = inputs.cuda(), targets.cuda()\n", "\n", "    optimizer.zero_grad()\n", "    inputs, targets = Variable(inputs), Variable(targets)\n", "    outputs = net(inputs)\n", "    loss = criterion(outputs, targets)\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    train_loss += loss.item()\n", "    _, predicted = torch.max(outputs.data, 1)\n", "    total += targets.size(0)\n", "    
correct += predicted.eq(targets.data).cpu().sum()\n", "\n", "    if batch_idx % 500 == 0:\n", "      print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'\n", "            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))\n", "  # average the loss over the number of batches (batch_idx is zero-based)\n", "  return (train_loss/(batch_idx+1), 100.*correct/total)\n", "\n", "\n", "def test(net, epoch, outModelName, use_cuda=True):\n", "  global best_acc\n", "  net.eval()\n", "  test_loss, correct, total = 0, 0, 0\n", "  with torch.no_grad():\n", "    for batch_idx, (inputs, targets) in enumerate(testloader):\n", "      if use_cuda:\n", "        inputs, targets = inputs.cuda(), targets.cuda()\n", "\n", "      outputs = net(inputs)\n", "      loss = criterion(outputs, targets)\n", "\n", "      test_loss += loss.item()\n", "      _, predicted = torch.max(outputs.data, 1)\n", "      total += targets.size(0)\n", "      correct += predicted.eq(targets.data).cpu().sum()\n", "\n", "      if batch_idx % 200 == 0:\n", "        print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'\n", "              % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))\n", "\n", "  # Save checkpoint.\n", "  acc = 100.*correct/total\n", "  if acc > best_acc:\n", "    best_acc = acc\n", "    checkpoint(net, acc, epoch, outModelName)\n", "  # average the loss over the number of batches (batch_idx is zero-based)\n", "  return (test_loss/(batch_idx+1), 100.*correct/total)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Auxiliary functions\n", "\n", "* `checkpoint()`: Stores checkpoints of the model.\n", "* `adjust_learning_rate()`: Decreases the learning rate (learning rate decay) at certain epochs of training."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# checkpoint & adjust_learning_rate\n", "def checkpoint(model, acc, epoch, outModelName):\n", "  # Save checkpoint.\n", "  print('Saving..')\n", "  state = {\n", "      'state_dict': model.state_dict(),\n", "      'acc': acc,\n", "      'epoch': epoch,\n", "      'rng_state': torch.get_rng_state()\n", "  }\n", "  if not os.path.isdir('checkpoint'):\n", "    os.mkdir('checkpoint')\n", "  torch.save(state, f'./checkpoint/{outModelName}.t7')\n", "\n", "def adjust_learning_rate(optimizer, epoch):\n", "  \"\"\"Decrease the learning rate by a factor of 10 at epochs 100 and 150\"\"\"\n", "  lr = base_learning_rate\n", "  if epoch <= 9 and lr > 0.1:\n", "    # warm-up training for large minibatch\n", "    lr = 0.1 + (base_learning_rate - 0.1) * epoch / 10.\n", "  if epoch >= 100:\n", "    lr /= 10\n", "  if epoch >= 150:\n", "    lr /= 10\n", "  for param_group in optimizer.param_groups:\n", "    param_group['lr'] = lr" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Train the model\n", "\n", "This is the loop where the model is trained for `max_epochs` epochs."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch: 0\n", "0 391 Loss: 4.748 | Acc: 0.781% (1/128)\n", "0 79 Loss: 3.545 | Acc: 12.500% (16/128)\n", "Saving..\n", "Epoch: 0 | train acc: 8.527999877929688 | test acc: 13.369999885559082\n", "\n", "Epoch: 1\n", "0 391 Loss: 3.597 | Acc: 16.406% (21/128)\n", "0 79 Loss: 3.157 | Acc: 23.438% (30/128)\n", "Saving..\n", "Epoch: 1 | train acc: 18.392000198364258 | test acc: 21.829999923706055\n", "\n", "Epoch: 2\n", "0 391 Loss: 2.932 | Acc: 26.562% (34/128)\n", "0 79 Loss: 2.450 | Acc: 39.844% (51/128)\n", "Saving..\n", "Epoch: 2 | train acc: 27.016000747680664 | test acc: 31.079999923706055\n", "\n", "Epoch: 3\n", "0 391 Loss: 2.649 | Acc: 35.938% (46/128)\n", "0 79 Loss: 2.134 | Acc: 39.844% (51/128)\n", "Saving..\n", "Epoch: 3 | train acc: 35.84000015258789 | test acc: 35.70000076293945\n", "\n", "Epoch: 4\n", "0 391 Loss: 2.153 | Acc: 41.406% (53/128)\n", "0 79 Loss: 1.911 | Acc: 49.219% (63/128)\n", "Saving..\n", "Epoch: 4 | train acc: 42.827999114990234 | test acc: 43.619998931884766\n", "\n", "Epoch: 5\n", "0 391 Loss: 1.878 | Acc: 50.000% (64/128)\n", "0 79 Loss: 2.149 | Acc: 43.750% (56/128)\n", "Saving..\n", "Epoch: 5 | train acc: 48.87200164794922 | test acc: 45.380001068115234\n", "\n", "Epoch: 6\n", "0 391 Loss: 1.814 | Acc: 51.562% (66/128)\n", "0 79 Loss: 1.847 | Acc: 46.875% (60/128)\n", "Saving..\n", "Epoch: 6 | train acc: 53.59000015258789 | test acc: 50.310001373291016\n", "\n", "Epoch: 7\n", "0 391 Loss: 1.514 | Acc: 56.250% (72/128)\n", "0 79 Loss: 1.568 | Acc: 51.562% (66/128)\n", "Saving..\n", "Epoch: 7 | train acc: 57.35200119018555 | test acc: 54.209999084472656\n", "\n", "Epoch: 8\n", "0 391 Loss: 1.194 | Acc: 62.500% (80/128)\n", "0 79 Loss: 1.403 | Acc: 59.375% (76/128)\n", "Saving..\n", "Epoch: 8 | train acc: 60.61600112915039 | test acc: 57.20000076293945\n", "\n", "Epoch: 9\n", 
"0 391 Loss: 1.124 | Acc: 69.531% (89/128)\n", "0 79 Loss: 1.339 | Acc: 64.844% (83/128)\n", "Saving..\n", "Epoch: 9 | train acc: 63.55400085449219 | test acc: 58.900001525878906\n", "\n", "Epoch: 10\n", "0 391 Loss: 1.013 | Acc: 72.656% (93/128)\n", "0 79 Loss: 1.225 | Acc: 66.406% (85/128)\n", "Epoch: 10 | train acc: 65.91999816894531 | test acc: 58.83000183105469\n", "\n", "Epoch: 11\n", "0 391 Loss: 0.971 | Acc: 64.844% (83/128)\n", "0 79 Loss: 1.491 | Acc: 63.281% (81/128)\n", "Epoch: 11 | train acc: 68.05000305175781 | test acc: 57.560001373291016\n", "\n", "Epoch: 12\n", "0 391 Loss: 1.028 | Acc: 70.312% (90/128)\n", "0 79 Loss: 1.358 | Acc: 63.281% (81/128)\n", "Saving..\n", "Epoch: 12 | train acc: 69.99600219726562 | test acc: 60.099998474121094\n", "\n", "Epoch: 13\n", "0 391 Loss: 0.699 | Acc: 82.812% (106/128)\n", "0 79 Loss: 1.299 | Acc: 63.281% (81/128)\n", "Saving..\n", "Epoch: 13 | train acc: 71.9739990234375 | test acc: 60.220001220703125\n", "\n", "Epoch: 14\n", "0 391 Loss: 0.768 | Acc: 75.781% (97/128)\n", "0 79 Loss: 1.182 | Acc: 74.219% (95/128)\n", "Saving..\n", "Epoch: 14 | train acc: 73.69400024414062 | test acc: 62.90999984741211\n" ] } ], "source": [ "# Start training\n", "outModelName = 'pretrain'\n", "if not os.path.exists(logname):\n", " with open(logname, 'w') as logfile:\n", " logwriter = csv.writer(logfile, delimiter=',')\n", " logwriter.writerow(['epoch', 'train loss', 'train acc', 'test loss', 'test acc'])\n", "\n", "for epoch in range(start_epoch, max_epochs):\n", " adjust_learning_rate(optimizer, epoch)\n", " train_loss, train_acc = train(net, epoch, use_cuda=use_cuda)\n", " test_loss, test_acc = test(net, epoch, outModelName, use_cuda=use_cuda)\n", " with open(logname, 'a') as logfile:\n", " logwriter = csv.writer(logfile, delimiter=',')\n", " logwriter.writerow([epoch, train_loss, train_acc.item(), test_loss, test_acc.item()])\n", " print(f'Epoch: {epoch} | train acc: {train_acc} | test acc: {test_acc}')" ] }, { "cell_type": 
"markdown", "metadata": { "execution": {} }, "source": [ "## Transfer learning\n", "### Re-use the trained model to improve training on a different data set" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Delete variables from the previous model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# delete the backbone network\n", "delete = True\n", "if delete:\n", " del net\n", " del trainset\n", " del testset\n", " del trainloader\n", " del testloader\n", " gc.collect()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "#### Target dataset\n", "\n", "We will now use CIFAR-10 as _target_ data set. Again, with small tweaks we can get any other data we are interested in.\n", "\n", "CIFAR-10 is very similar to CIFAR-100, but it contains only 10 classes instead of 100." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==> Preparing target domain data..\n", "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./CIFAR10/cifar-10-python.tar.gz\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ef171f0d74cc467ab4b78cd32bba5071", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Extracting ./CIFAR10/cifar-10-python.tar.gz to ./CIFAR10\n", "Files already downloaded and verified\n" ] } ], "source": [ "# Target domain Data\n", "print('==> Preparing target domain data..')\n", "\n", "# CIFAR10 normalizing\n", "mean = (0.4914, 0.4822, 0.4465)\n", "std = (0.2023, 0.1994, 0.2010)\n", "num_classes = 10\n", "lr = 0.0001\n", "\n", "# torchvision transforms\n", "transform_train = 
transforms.Compose([])\n", "if torchvision_transforms:\n", "  transform_train.transforms.append(transforms.RandomCrop(32, padding=4))\n", "  transform_train.transforms.append(transforms.RandomHorizontalFlip())\n", "\n", "transform_train.transforms.append(transforms.ToTensor())\n", "transform_train.transforms.append(transforms.Normalize(mean, std))\n", "\n", "transform_test = transforms.Compose([\n", "  transforms.ToTensor(),\n", "  transforms.Normalize(mean, std),\n", "])\n", "\n", "trainset = torchvision.datasets.CIFAR10(\n", "  root='./CIFAR10', train=True, download=True, transform=transform_train)\n", "\n", "testset = torchvision.datasets.CIFAR10(\n", "  root='./CIFAR10', train=False, download=True, transform=transform_test)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "#### Select a subset of the data\n", "\n", "We keep only a subset of the training data to simulate a low-data regime, where transfer learning can be especially useful.\n", "\n", "Choose the fraction of the training set to keep. Set `percent = 1.0` to use the whole training set." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "size of the new trainset: 30000\n" ] } ], "source": [ "percent = 0.6\n", "\n", "trainset = percentageSplit(trainset, percent = percent)\n", "print('size of the new trainset: ', len(trainset))" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "#### Dataloaders\n", "\n", "As before" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----> number of workers: 2\n" ] } ], "source": [ "# Dataloader\n", "num_workers = multiprocessing.cpu_count()\n", "\n", "print(f'----> number of workers: {num_workers}')\n", "\n", "trainloader = torch.utils.data.DataLoader(\n", "  trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)\n", "testloader = torch.utils.data.DataLoader(\n", "  
testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Load pre-trained model\n", "\n", "Load the checkpoint of the model previously trained on CIFAR-100." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " ===> loading pretrained model from: /content/checkpoint/pretrain.t7\n", "Best Accuracy: tensor(62.9100)\n", "Load pretrained model with msg: \n" ] } ], "source": [ "model = ResNet18()\n", "\n", "checkpointPath = '/content/checkpoint/pretrain.t7'\n", "\n", "print(' ===> loading pretrained model from: ', checkpointPath)\n", "if os.path.isfile(checkpointPath):\n", "  state_dict = torch.load(checkpointPath)\n", "  best_acc = state_dict['acc']\n", "  print('Best Accuracy:', best_acc)\n", "  if \"state_dict\" in state_dict:\n", "    state_dict = state_dict[\"state_dict\"]\n", "  # remove the prefix \"module.\" added by DataParallel\n", "  state_dict = {k.replace(\"module.\", \"\"): v for k, v in state_dict.items()}\n", "  for k, v in model.state_dict().items():\n", "    if k not in list(state_dict):\n", "      print('key \"{}\" could not be found in provided state dict'.format(k))\n", "    elif state_dict[k].shape != v.shape:\n", "      print('key \"{}\" is of different shape in model and provided state dict'.format(k))\n", "      state_dict[k] = v\n", "  msg = model.load_state_dict(state_dict, strict=False)\n", "  print(\"Load pretrained model with msg: {}\".format(msg))\n", "else:\n", "  raise Exception('No pretrained weights found')" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Freeze model parameters\n", "\n", "In transfer learning, we usually do not re-train all the weights of the model, but only a subset of them, for instance the last layer. Here we first _freeze_ all the parameters of the model, and we will _unfreeze_ one layer below."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Freeze the model parameters, you can also freeze some layers only\n", "\n", "for param in model.parameters():\n", "  param.requires_grad = False" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Loss function, optimizer and _unfreeze_ last layer" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "num_ftrs = model.linear.in_features\n", "model.linear = nn.Linear(num_ftrs, num_classes)\n", "\n", "model.to(device)\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(\n", "  model.linear.parameters(),\n", "  lr=lr,\n", "  momentum=0.9,\n", "  weight_decay=1e-4,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "#### Check number of parameters\n", "\n", "We can calculate the total number of parameters and the number of trainable parameters, that is, those that will be updated during training. Since we have frozen most of the parameters, the number of trainable parameters should be much smaller."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total Parameters: 11173962 Trainable parameters: 5130\n" ] } ], "source": [ "total_params = sum(p.numel() for p in model.parameters())\n", "trainable_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "print('Total Parameters:', total_params, 'Trainable parameters: ', trainable_total_params)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Train the target model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch: 0\n", "0 235 Loss: 2.302 | Acc: 16.406% (21/128)\n", "0 79 Loss: 0.630 | Acc: 79.688% (102/128)\n", "Saving..\n", "Epoch: 0 | train acc: 71.086669921875 | test acc: 75.08999633789062\n", "\n", "Epoch: 1\n", "0 235 Loss: 0.757 | Acc: 72.656% (93/128)\n", "0 79 Loss: 0.619 | Acc: 79.688% (102/128)\n", "Saving..\n", "Epoch: 1 | train acc: 75.86000061035156 | test acc: 76.54000091552734\n", "\n", "Epoch: 2\n", "0 235 Loss: 0.666 | Acc: 75.781% (97/128)\n", "0 79 Loss: 0.640 | Acc: 78.125% (100/128)\n", "Saving..\n", "Epoch: 2 | train acc: 77.04000091552734 | test acc: 76.55000305175781\n", "\n", "Epoch: 3\n", "0 235 Loss: 0.579 | Acc: 81.250% (104/128)\n", "0 79 Loss: 0.577 | Acc: 79.688% (102/128)\n", "Saving..\n", "Epoch: 3 | train acc: 77.56999969482422 | test acc: 77.2300033569336\n", "\n", "Epoch: 4\n", "0 235 Loss: 0.661 | Acc: 78.125% (100/128)\n", "0 79 Loss: 0.613 | Acc: 76.562% (98/128)\n", "Saving..\n", "Epoch: 4 | train acc: 77.6866683959961 | test acc: 77.44999694824219\n", "\n", "Epoch: 5\n", "0 235 Loss: 0.627 | Acc: 80.469% (103/128)\n", "0 79 Loss: 0.626 | Acc: 80.469% (103/128)\n", "Epoch: 5 | train acc: 78.163330078125 | test acc: 77.37999725341797\n", "\n", "Epoch: 6\n", "0 235 Loss: 0.602 | Acc: 77.344% 
(99/128)\n", "0 79 Loss: 0.607 | Acc: 78.125% (100/128)\n", "Saving..\n", "Epoch: 6 | train acc: 78.42333221435547 | test acc: 78.02999877929688\n", "\n", "Epoch: 7\n", "0 235 Loss: 0.537 | Acc: 75.781% (97/128)\n", "0 79 Loss: 0.608 | Acc: 79.688% (102/128)\n", "Saving..\n", "Epoch: 7 | train acc: 78.49333190917969 | test acc: 78.1500015258789\n", "\n", "Epoch: 8\n", "0 235 Loss: 0.578 | Acc: 75.781% (97/128)\n", "0 79 Loss: 0.650 | Acc: 76.562% (98/128)\n", "Epoch: 8 | train acc: 78.15333557128906 | test acc: 76.83000183105469\n", "\n", "Epoch: 9\n", "0 235 Loss: 0.583 | Acc: 77.344% (99/128)\n", "0 79 Loss: 0.616 | Acc: 77.344% (99/128)\n", "Saving..\n", "Epoch: 9 | train acc: 78.66999816894531 | test acc: 78.20999908447266\n" ] } ], "source": [ "outModelName = 'finetuned'\n", "logname = result_folder + model.__class__.__name__ + f'_{outModelName}.csv'\n", "\n", "if not os.path.exists(logname):\n", " with open(logname, 'w') as logfile:\n", " logwriter = csv.writer(logfile, delimiter=',')\n", " logwriter.writerow(['epoch', 'train loss', 'train acc', 'test loss', 'test acc'])\n", "\n", "for epoch in range(start_epoch, max_epochs_target):\n", " adjust_learning_rate(optimizer, epoch)\n", " train_loss, train_acc = train(model, epoch, use_cuda=use_cuda)\n", " test_loss, test_acc = test(model, epoch, outModelName, use_cuda=use_cuda)\n", " with open(logname, 'a') as logfile:\n", " logwriter = csv.writer(logfile, delimiter=',')\n", " logwriter.writerow([epoch, train_loss, train_acc.item(), test_loss, test_acc.item()])\n", " print(f'Epoch: {epoch} | train acc: {train_acc} | test acc: {test_acc}')" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Plot results" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "data": { "text/html": [ "
<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>epoch</th><th>train loss</th><th>train acc</th><th>test loss</th><th>test acc</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>0</td><td>0.819690</td><td>71.086670</td><td>0.713260</td><td>75.089996</td></tr>\n",
"    <tr><th>1</th><td>1</td><td>0.680940</td><td>75.860001</td><td>0.674431</td><td>76.540001</td></tr>\n",
"    <tr><th>2</th><td>2</td><td>0.650245</td><td>77.040001</td><td>0.675883</td><td>76.550003</td></tr>\n",
"    <tr><th>3</th><td>3</td><td>0.638555</td><td>77.570000</td><td>0.652776</td><td>77.230003</td></tr>\n",
"    <tr><th>4</th><td>4</td><td>0.630500</td><td>77.686668</td><td>0.666428</td><td>77.449997</td></tr>\n",
"  </tbody>\n",
"</table>\n", "
" ], "text/plain": [ " epoch train loss train acc test loss test acc\n", "0 0 0.819690 71.086670 0.713260 75.089996\n", "1 1 0.680940 75.860001 0.674431 76.540001\n", "2 2 0.650245 77.040001 0.675883 76.550003\n", "3 3 0.638555 77.570000 0.652776 77.230003\n", "4 4 0.630500 77.686668 0.666428 77.449997" ] }, "execution_count": 24, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "# Load the training log written by the loop above and preview the first rows\n", "results = pd.read_csv(logname, sep=',')\n", "results.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average Accuracy over 10 epochs: 77.0\n", "best accuracy over 10 epochs: 78.20999908447266\n" ] } ], "source": [ "train_accuracy = results['train acc'].values\n", "test_accuracy = results['test acc'].values\n", "\n", "# Note: `//` floors the mean, so the average is reported as a whole number\n", "print(f'Average Accuracy over {max_epochs_target} epochs:', sum(test_accuracy)//len(test_accuracy))\n", "print(f'best accuracy over {max_epochs_target} epochs:', max(test_accuracy))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "figureName = 'figure' # change figure name\n", "\n", "plt.plot(results['epoch'].values, train_accuracy, label='train')\n", "plt.plot(results['epoch'].values, test_accuracy, label='test')\n", "plt.xlabel('Number of epochs')\n", "plt.ylabel('Accuracy (%)')\n", "plt.title(f'Train/Test Accuracy curve for {max_epochs_target} epochs')\n", "plt.legend() # draw the legend before saving so it appears in the saved file\n", "plt.savefig(f'/content/results/{figureName}.png')\n", "plt.show()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "include_colab_link": true, "name": "transfer_learning", "provenance": [], "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", "name": "python3" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }