{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {}, "id": "view-in-github" }, "source": [ "\"Open   \"Open" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Tutorial 3: Image, Conditional Diffusion and Beyond\n", "\n", "**Week 2, Day 4: Name of the day**\n", "\n", "**By Neuromatch Academy**\n", "\n", "__Content creators:__ Binxu Wang\n", "\n", "__Content reviewers:__ Shaonan Wang, Dongrui Deng, Dora Zhiyu Yang, Adrita Das\n", "\n", "__Content editors:__ Shaonan Wang\n", "\n", "__Production editors:__ Spiros Chavlis" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Tutorial Objectives\n", "\n", "* Understand the idea behind Diffusion generative models: score and reversal of diffusion process.\n", "* Learn the score function by denoising data.\n", "* Hands-on experience in learning the score to generate certain distributions." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @markdown\n", "from IPython.display import IFrame\n", "from ipywidgets import widgets\n", "out = widgets.Output()\n", "with out:\n", " print(f\"If you want to download the slides: https://osf.io/download/j89qg/\")\n", " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/j89qg/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", "display(out)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " **WARNING**: There may be *errors* and/or *warnings* reported during the installation. However, they are to be ignored.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install dependencies\n", "# @markdown **WARNING**: There may be *errors* and/or *warnings* reported during the installation. 
However, they can be safely ignored.\n", "!pip install pillow --quiet\n", "!pip install diffusers transformers tokenizers --quiet\n", "!pip install accelerate --quiet" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install and import feedback gadget\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install and import feedback gadget\n", "\n", "!pip3 install vibecheck datatops --quiet\n", "\n", "from vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", " return DatatopsContentReviewContainer(\n", " \"\", # No text prompt\n", " notebook_section,\n", " {\n", " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", " \"name\": \"neuromatch_dl\",\n", " \"user_key\": \"f379rz8y\",\n", " },\n", " ).render()\n", "\n", "\n", "feedback_prefix = \"W2D4_T3\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "both", "execution": {} }, "outputs": [], "source": [ "# Imports\n", "import random\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "from torch.optim import Adam\n", "from torch.utils.data import DataLoader\n", "import torchvision.transforms as transforms\n", "from torchvision.datasets import MNIST\n", "from tqdm.notebook import trange, tqdm\n", "from torch.optim.lr_scheduler import LambdaLR\n", "from torchvision.utils import make_grid" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Figure settings\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Figure settings\n", "import ipywidgets as widgets # interactive display\n", "%config InlineBackend.figure_format = 'retina'\n", "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/content-creation/main/nma.mplstyle\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set random seed\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Executing `set_seed(seed=seed)` sets the seed\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Set random seed\n", "\n", "# @markdown Executing `set_seed(seed=seed)` sets the seed\n", "\n", "# For DL it's critical to set the random seed so that students can have a\n", "# baseline to compare their results to expected results.\n", "# Read more here: https://pytorch.org/docs/stable/notes/randomness.html\n", "\n", "# Call the `set_seed` function in the exercises to ensure reproducibility.\n", "import random\n", "import torch\n", "\n", "def set_seed(seed=None, seed_torch=True):\n", " \"\"\"\n", " Function that controls randomness.\n", " NumPy and random modules must be imported.\n", "\n", " Args:\n", " seed : Integer\n", " A non-negative integer that defines the random state. Default is `None`.\n", " seed_torch : Boolean\n", " If `True` sets the random seed for pytorch tensors, so pytorch module\n", " must be imported. 
Default is `True`.\n", "\n", " Returns:\n", " Nothing.\n", " \"\"\"\n", " if seed is None:\n", " seed = np.random.choice(2 ** 32)\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " if seed_torch:\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", " torch.cuda.manual_seed(seed)\n", " torch.backends.cudnn.benchmark = False\n", " torch.backends.cudnn.deterministic = True\n", "\n", " print(f'Random seed {seed} has been set.')\n", "\n", "# In case that `DataLoader` is used\n", "def seed_worker(worker_id):\n", " \"\"\"\n", " DataLoader will reseed workers following randomness in\n", " multi-process data loading algorithm.\n", "\n", " Args:\n", " worker_id: integer\n", " ID of subprocess to seed. 0 means that\n", " the data will be loaded in the main process\n", " Refer: https://pytorch.org/docs/stable/data.html#data-loading-randomness for more details\n", "\n", " Returns:\n", " Nothing\n", " \"\"\"\n", " worker_seed = torch.initial_seed() % 2**32\n", " np.random.seed(worker_seed)\n", " random.seed(worker_seed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set device (GPU or CPU). Execute `set_device()`\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Set device (GPU or CPU). Execute `set_device()`\n", "\n", "# Inform the user if the notebook uses GPU or CPU.\n", "\n", "def set_device():\n", " \"\"\"\n", " Set the device. CUDA if available, CPU otherwise\n", "\n", " Args:\n", " None\n", "\n", " Returns:\n", " device : string\n", " \"cuda\" if a GPU is available, otherwise \"cpu\"\n", " \"\"\"\n", " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " if device != \"cuda\":\n", " print(\"WARNING: For this notebook to perform best, \"\n", " \"if possible, in the menu under `Runtime` -> \"\n", " \"`Change runtime type`, select `GPU`.\")\n", " else:\n", " print(\"GPU is enabled in this notebook.\")\n", "\n", " return device" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "DEVICE = set_device()\n", "SEED = 2021\n", "set_seed(seed=SEED)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Neural Network Architecture" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "We just learned the basic principles of diffusion models, with the takeaway that the score function allows us to turn pure noise into an interesting data distribution. Further, we will approximate the score function with a neural network via denoising score matching. But when working with images, we need our neural network to 'play nice' with them and to reflect the inductive biases we associate with images.\n", "\n", "A reasonable choice is the **[U-Net](https://en.wikipedia.org/wiki/U-Net)**, a CNN-like architecture with:\n", "\n", "* downscaling/upscaling operations that help the network process features of images at different spatial scales.\n", "* skip connections that act as information highways.\n", "\n", "Since the score function we're trying to learn is a function of time, we also need to devise a way to ensure our neural network properly responds to changes in time. For this purpose, we can use a **time embedding**."
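, "\n", "Before diving into the full model, the next cell is a small self-contained sketch (added for illustration; the layer names are ours, not part of the tutorial's model) of these two ingredients in isolation: a strided convolution halves the spatial resolution, a transposed convolution restores it, and a skip connection concatenates encoder features into the decoder path." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Illustrative sketch (not the tutorial's U-Net): downsampling, upsampling, and a skip connection.\n", "import torch\n", "import torch.nn as nn\n", "\n", "x = torch.randn(1, 8, 28, 28)  # a batch with one 8-channel feature map\n", "down = nn.Conv2d(8, 16, 3, stride=2, padding=1)  # downsample: 28x28 -> 14x14\n", "up = nn.ConvTranspose2d(16, 8, 4, stride=2, padding=1)  # upsample: 14x14 -> 28x28\n", "\n", "h = down(x)\n", "print(\"after down:\", h.shape)  # torch.Size([1, 16, 14, 14])\n", "u = up(h)\n", "print(\"after up:\", u.shape)  # torch.Size([1, 8, 28, 28])\n", "# Skip connection: concatenate encoder features along the channel dimension.\n", "skip = torch.cat([u, x], dim=1)\n", "print(\"with skip:\", skip.shape)  # torch.Size([1, 16, 28, 28])"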
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 1: Network architecture\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 1: Network architecture\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'sV-ROEAZaO0'), ('Bilibili', 'BV1Yk4y1N7Ai')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Network_Architecture_Video\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Coding Exercise 1: Train Diffusion for MNIST\n", "\n", "Finally, let's implement and train an actual image diffusion model for the MNIST dataset.\n", "\n", "By examining the neural network architecture of the score approximator, you will understand the inductive biases we built in." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "In the next cell, you will implement the helper functions for the forward process.\n", "* `marginal_prob_std` for $\\sigma_t$ (note, it's standard deviation, not the variance)\n", "* `diffusion_coeff` for $g(t)$\n", "\n", "**Math Recap for Forward Processes**:\n", "\n", "We will use the same forward process (variance exploding SDE) as in the last tutorial, which reads:\n", "\n", "\\begin{equation}\n", "d\\mathbf x=g(t)d\\mathbf w\n", "\\end{equation}\n", "\n", "and we let the diffusion coefficient $g(t)=\\lambda^t$, with $\\lambda > 1$.\n", "\n", "If so, the marginal distribution of state $\\mathbf x_t$ at time t given an initial state $\\mathbf x_0$ will be a Gaussian $\\mathcal N(\\mathbf x_t|\\mathbf x_0,\\sigma_t^2 I)$. 
The variance is the integration of the squared diffusion coefficient.\n", "\n", "\\begin{equation}\n", "\\sigma_t^2 =\\int_0^tg(\\tau)^2d\\tau=\\frac{\\lambda^{2t}-1}{2\\log\\lambda}\n", "\\end{equation}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "def marginal_prob_std(t, Lambda, device='cpu'):\n", " \"\"\"Compute the standard deviation of $p_{0t}(x(t) | x(0))$.\n", "\n", " Args:\n", " t: A vector of time steps.\n", " Lambda: The $\\lambda$ in our SDE.\n", "\n", " Returns:\n", " std : The standard deviation.\n", " \"\"\"\n", " t = t.to(device)\n", " #################################################\n", " ## TODO for students: Implement the standard deviation\n", " raise NotImplementedError(\"Student exercise: Implement the standard deviation\")\n", " #################################################\n", " std = ...\n", " return std\n", "\n", "\n", "def diffusion_coeff(t, Lambda, device='cpu'):\n", " \"\"\"Compute the diffusion coefficient of our SDE.\n", "\n", " Args:\n", " t: A vector of time steps.\n", " Lambda: The $\\lambda$ in our SDE.\n", "\n", " Returns:\n", " diff_coeff : The vector of diffusion coefficients.\n", " \"\"\"\n", " #################################################\n", " ## TODO for students: Implement the diffusion coefficients\n", " raise NotImplementedError(\"Student exercise: Implement the diffusion coefficients\")\n", " #################################################\n", " diff_coeff = ...\n", " return diff_coeff.to(device)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W2D4_GenerativeModels/solutions/W2D4_Tutorial3_Solution_514ccad3.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Train_Diffusion_for_MNIST_Exercise\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Network architecture\n", "\n", "Below is code for a simple time embedding and modulation layer. Basically, time $t$ is multiplexed as sine and cosine basis, then a linear readout creates the time modulation signal." 
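, "\n", "For intuition about this sine/cosine multiplexing, the optional sketch below (ours, not part of the original tutorial) builds random-frequency features of $t$ and shows that nearby times receive similar feature vectors while distant times decorrelate -- a smooth code that a linear layer can read out." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Optional intuition check: random-frequency sine/cosine features of time.\n", "import math\n", "import torch\n", "\n", "torch.manual_seed(0)\n", "freqs = torch.randn(128) * 3.0  # random frequencies (modest scale, purely illustrative)\n", "\n", "def t_feats(t):\n", "    ang = 2 * math.pi * freqs * t\n", "    return torch.cat([torch.sin(ang), torch.cos(ang)])\n", "\n", "f0 = t_feats(torch.tensor(0.5))\n", "for t in [0.5, 0.51, 0.6, 0.9]:\n", "    f = t_feats(torch.tensor(t))\n", "    cos_sim = torch.dot(f0, f) / (f0.norm() * f.norm())\n", "    print(f\"t={t:.2f}  cosine similarity to t=0.50: {cos_sim:.3f}\")"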
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Time embedding and modulation\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Time embedding and modulation\n", "\n", "class GaussianFourierProjection(nn.Module):\n", " \"\"\"Gaussian random features for encoding time steps.\"\"\"\n", " def __init__(self, embed_dim, scale=30.):\n", " super().__init__()\n", " # Randomly sample weights (frequencies) during initialization.\n", " # These weights (frequencies) are fixed during optimization and are not trainable.\n", " self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)\n", " def forward(self, x):\n", " # Cosine(2 pi freq x), Sine(2 pi freq x)\n", " x_proj = x[:, None] * self.W[None, :] * 2 * np.pi\n", " return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)\n", "\n", "\n", "class Dense(nn.Module):\n", " \"\"\"A fully connected layer that reshapes outputs to feature maps.\n", " Allow time repr to input additively from the side of a convolution layer.\n", " \"\"\"\n", " def __init__(self, input_dim, output_dim):\n", " super().__init__()\n", " self.dense = nn.Linear(input_dim, output_dim)\n", " def forward(self, x):\n", " # this broadcast the 2d tensor to 4d, add the same value across space.\n", " return self.dense(x)[..., None, None]" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Below is code for a simple U-Net architecture. Apparently, diffusion models can be more or less successful with different architectural details. So this example is mainly for illustrative purposes." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Time-dependent UNet score model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Time-dependent UNet score model\n", "\n", "class UNet(nn.Module):\n", " \"\"\"A time-dependent score-based model built upon U-Net architecture.\"\"\"\n", "\n", " def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):\n", " \"\"\"Initialize a time-dependent score-based network.\n", "\n", " Args:\n", " marginal_prob_std: A function that takes time t and gives the standard\n", " deviation of the perturbation kernel p_{0t}(x(t) | x(0)).\n", " channels: The number of channels for feature maps of each resolution.\n", " embed_dim: The dimensionality of Gaussian random feature embeddings.\n", " \"\"\"\n", " super().__init__()\n", " # Gaussian random feature embedding layer for time\n", " self.time_embed = nn.Sequential(\n", " GaussianFourierProjection(embed_dim=embed_dim),\n", " nn.Linear(embed_dim, embed_dim)\n", " )\n", " # Encoding layers where the resolution decreases\n", " self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)\n", " self.t_mod1 = Dense(embed_dim, channels[0])\n", " self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])\n", "\n", " self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)\n", " self.t_mod2 = Dense(embed_dim, channels[1])\n", " self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])\n", "\n", " self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)\n", " self.t_mod3 = Dense(embed_dim, channels[2])\n", " self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])\n", "\n", " self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)\n", " 
self.t_mod4 = Dense(embed_dim, channels[3])\n", " self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])\n", "\n", "\n", " # Decoding layers where the resolution increases\n", " self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)\n", " self.t_mod5 = Dense(embed_dim, channels[2])\n", " self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])\n", " self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)\n", " self.t_mod6 = Dense(embed_dim, channels[1])\n", " self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])\n", " self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)\n", " self.t_mod7 = Dense(embed_dim, channels[0])\n", " self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])\n", " self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)\n", "\n", " # The swish activation function\n", " self.act = lambda x: x * torch.sigmoid(x)\n", " # A restricted version of the `marginal_prob_std` function, after specifying a Lambda.\n", " self.marginal_prob_std = marginal_prob_std\n", "\n", " def forward(self, x, t, y=None):\n", " # Obtain the Gaussian random feature embedding for t\n", " embed = self.act(self.time_embed(t))\n", " # Encoding path, downsampling\n", " ## Incorporate information from t\n", " h1 = self.conv1(x) + self.t_mod1(embed)\n", " ## Group normalization and apply activation function\n", " h1 = self.act(self.gnorm1(h1))\n", " # 2nd conv\n", " h2 = self.conv2(h1) + self.t_mod2(embed)\n", " h2 = self.act(self.gnorm2(h2))\n", " # 3rd conv\n", " h3 = self.conv3(h2) + self.t_mod3(embed)\n", " h3 = self.act(self.gnorm3(h3))\n", " # 4th conv\n", " h4 = self.conv4(h3) + self.t_mod4(embed)\n", " h4 = self.act(self.gnorm4(h4))\n", "\n", " # Decoding path, upsampling\n", " h = self.tconv4(h4) + self.t_mod5(embed)\n", " h = self.act(self.tgnorm4(h))\n", " ## Skip connections from the encoding path, via channel concatenation\n", " h = self.tconv3(torch.cat([h, h3], dim=1)) + self.t_mod6(embed)\n", " h = self.act(self.tgnorm3(h))\n", " h = self.tconv2(torch.cat([h, h2], dim=1)) + self.t_mod7(embed)\n", " h = self.act(self.tgnorm2(h))\n", " h = self.tconv1(torch.cat([h, h1], dim=1))\n", "\n", " # Normalize output\n", " h = h / self.marginal_prob_std(t)[:, None, None, None]\n", " return h" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Think! 1: U-Net Architecture\n", "\n", "Looking at the U-Net architecture, can you find the module(s) corresponding to the following operations?\n", "1. Downsampling the spatial features?\n", "2. Upsampling the spatial features?\n", "3. How is the skip connection from the down branch to the up branch implemented?\n", "4. How is time modulation implemented?\n", "5. Why is the output divided by `self.marginal_prob_std(t)` before it is returned? How might this help or harm the score learning?\n", "\n", "Take 2 minutes to think in silence, then discuss as a group (~10 minutes)."
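, "\n", "To ground the discussion, the optional cell below (added here as a sanity check, not part of the exercise) pushes a dummy MNIST-sized batch through the network; a placeholder standard-deviation function is used since only the shapes matter." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Optional sanity check: trace shapes through the U-Net with a dummy batch.\n", "tmp_model = UNet(marginal_prob_std=lambda t: torch.ones_like(t))  # placeholder std\n", "x_dummy = torch.randn(4, 1, 28, 28)  # MNIST-sized batch\n", "t_dummy = torch.rand(4)\n", "out = tmp_model(x_dummy, t_dummy)\n", "print(\"input :\", x_dummy.shape)\n", "print(\"output:\", out.shape)  # same shape as the input: one score value per pixel\n", "del tmp_model"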
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W2D4_GenerativeModels/solutions/W2D4_Tutorial3_Solution_3fb699e4.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_UNet_Architecture_Discussion\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Coding Exercise 2: Defining the loss function\n", "\n", "In the next cell, you will implement the denoising score matching (DSM) objective as we used in the last tutorial.\n", "\n", "\\begin{equation}\n", "\\mathcal L=\\int_\\epsilon^1dt \\mathbb E_{x\\sim p_0(x)}\\mathbb E_{z\\sim \\mathcal N(0,I)}\\|\\sigma_t s_\\theta(x+\\sigma_t z, t)+z\\|^2\n", "\\end{equation}\n", "\n", "where the time weighting is chosen as $\\gamma_t=\\sigma_t^2$, which emphasizes the high noise period ($t\\sim 1$) more than the low noise period ($t\\sim 0$).\n", "\n", "**Tips**:\n", "\n", "* The major difference from the last tutorial is that the score $s$, noise $z$, and states $x$ are all batch image-shaped tensor, so remember to broadcast the $\\sigma_t$ properly. e.g. this `std[:, None, None, None]` will be helpful.\n", "* `eps` is set at a small number to stop the model from learning the score function of a very small noise scale, which is highly irregular.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "def loss_fn(model, x, marginal_prob_std, eps=1e-3, device='cpu'):\n", " \"\"\"The loss function for training score-based generative models.\n", "\n", " Args:\n", " model: A PyTorch model instance that represents a\n", " time-dependent score-based model.\n", " Note, it takes two inputs in its forward function model(x, t)\n", " $s_\\theta(x,t)$ in the equation\n", " x: A mini-batch of training data.\n", " marginal_prob_std: A function that gives the standard deviation of\n", " the perturbation kernel, takes `t` as input.\n", " $\\sigma_t$ in the equation.\n", " eps: A tolerance value for numerical stability.\n", " \"\"\"\n", " # Sample time uniformly in eps, 1\n", " random_t = torch.rand(x.shape[0], device=device) * (1. - eps) + eps\n", " # Find the noise std at the time `t`\n", " std = marginal_prob_std(random_t).to(device)\n", " #################################################\n", " ## TODO for students: Implement the denoising score matching eq.\n", " raise NotImplementedError(\"Student exercise: Implement the denoising score matching eq. 
\")\n", " #################################################\n", " # get normally distributed noise N(0, I)\n", " z = ...\n", " # compute the perturbed x = x + z * \\sigma_t\n", " perturbed_x = ...\n", " # predict score with the model at (perturbed x, t)\n", " score = ...\n", " # compute distance between the score and noise \\| score * sigma_t + z \\|_2^2\n", " loss = ...\n", " ##############\n", " return loss" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W2D4_GenerativeModels/solutions/W2D4_Tutorial3_Solution_0ded5c2e.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "A correctly implemented loss function shall pass the test below.\n", "\n", "For a dataset with a single `0` datapoint, we have the analytical score $\\mathbf s(\\mathbf x,t)=-\\mathbf x/\\sigma_t^2$. We test that, for this case, the analytical has zero loss." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test loss function\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Test loss function\n", "marginal_prob_std_test = lambda t: marginal_prob_std(t, Lambda=10, device='cpu')\n", "score_analyt_test = lambda x_t, t: - x_t / marginal_prob_std_test(t)[:,None,None,None]**2\n", "x_test = torch.zeros(10, 3, 64, 64)\n", "loss = loss_fn(score_analyt_test, x_test, marginal_prob_std_test, eps=1e-3, device='cpu')\n", "assert torch.allclose(loss,torch.zeros(1)), \"the loss should be zero in this case\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Defining_the_loss_function_Exercise\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Train and Test the Diffusion Model\n", "\n", "**Note:** We have reduced the `n_epochs` to 12, but feel free to increase and use a larger value. The original value was set to 100, but if the training takes too long, `n_epochs=50` with `batch_size=1024` also suffice. An average loss of around ~30 can generate acceptable digits." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training the model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Training the model\n", "Lambda = 25.0 # @param {'type':'number'}\n", "\n", "marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=Lambda, device=DEVICE)\n", "diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=Lambda, device=DEVICE)\n", "score_model = UNet(marginal_prob_std=marginal_prob_std_fn)\n", "score_model = score_model.to(DEVICE)\n", "\n", "n_epochs = 12 # @param {'type':'integer'}\n", "# size of a mini-batch\n", "batch_size = 1024 # @param {'type':'integer'}\n", "# learning rate\n", "lr = 10e-4 # @param {'type':'number'}\n", "\n", "set_seed(SEED)\n", "dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)\n", "g = torch.Generator()\n", "g.manual_seed(SEED)\n", "data_loader = DataLoader(dataset, batch_size=batch_size,\n", " shuffle=True, num_workers=2,\n", " worker_init_fn=seed_worker,\n", " generator=g,)\n", "\n", "optimizer = Adam(score_model.parameters(), lr=lr)\n", "scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 1 - epoch / n_epochs))\n", "tqdm_epoch = trange(n_epochs)\n", "\n", "for epoch in tqdm_epoch:\n", " avg_loss = 0.\n", " num_items = 0\n", " pbar = tqdm(data_loader)\n", " for x, y in pbar:\n", " x = x.to(DEVICE)\n", " loss = loss_fn(score_model, x, marginal_prob_std_fn, eps=0.01, device=DEVICE)\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " avg_loss += loss.item() * x.shape[0]\n", " num_items += x.shape[0]\n", " scheduler.step()\n", " print(f\"Average Loss: {(avg_loss / num_items):5f} lr {scheduler.get_last_lr()[0]:.1e}\")\n", " # Print the averaged training loss so far.\n", " tqdm_epoch.set_description(f'Average Loss: {(avg_loss / num_items):.5f}')\n", " # Update the checkpoint after each epoch of training.\n", " torch.save(score_model.state_dict(), 'ckpt.pth')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the Sampler\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Define the Sampler\n", "def Euler_Maruyama_sampler(score_model,\n", " marginal_prob_std,\n", " diffusion_coeff,\n", " batch_size=64,\n", " x_shape=(1, 28, 28),\n", " num_steps=500,\n", " device='cuda',\n", " eps=1e-3, y=None):\n", " \"\"\"Generate samples from score-based models with the Euler-Maruyama solver.\n", "\n", " Args:\n", " score_model: A PyTorch model that represents the time-dependent score-based model.\n", " marginal_prob_std: A function that gives the standard deviation of\n", " the perturbation kernel.\n", " diffusion_coeff: A function that gives the diffusion coefficient of the SDE.\n", " batch_size: The number of samplers to generate by calling this function once.\n", " num_steps: The number of sampling steps.\n", " Equivalent to the number of discretized time steps.\n", " device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.\n", " eps: The smallest time step for numerical stability.\n", "\n", " Returns:\n", " Samples.\n", " \"\"\"\n", " t = torch.ones(batch_size).to(device)\n", " r = torch.randn(batch_size, *x_shape).to(device)\n", " init_x = r * marginal_prob_std(t)[:, None, None, None]\n", " init_x = init_x.to(device)\n", " time_steps = torch.linspace(1., eps, num_steps).to(device)\n", " 
step_size = time_steps[0] - time_steps[1]\n", " x = init_x\n", " with torch.no_grad():\n", " for time_step in tqdm(time_steps):\n", " batch_time_step = torch.ones(batch_size, device=device) * time_step\n", " g = diffusion_coeff(batch_time_step)\n", " mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step, y=y) * step_size\n", " x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)\n", " # Do not include any noise in the last sampling step.\n", " return mean_x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sampling\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Sampling\n", "def save_samples_uncond(score_model, suffix=\"\", device='cpu'):\n", " score_model.eval()\n", " ## Generate samples using the specified sampler.\n", " sample_batch_size = 64 # @param {'type':'integer'}\n", " num_steps = 250 # @param {'type':'integer'}\n", " samples = Euler_Maruyama_sampler(score_model,\n", " marginal_prob_std_fn,\n", " diffusion_coeff_fn,\n", " sample_batch_size,\n", " num_steps=num_steps,\n", " device=DEVICE,\n", " eps=0.001)\n", "\n", " # Sample visualization.\n", " samples = samples.clamp(0.0, 1.0)\n", " sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))\n", " sample_np = sample_grid.permute(1, 2, 0).cpu().numpy()\n", " plt.imsave(f\"unconditional_diffusion{suffix}.png\", sample_np)\n", " plt.figure(figsize=(6,6))\n", " plt.axis('off')\n", " plt.imshow(sample_np, vmin=0., vmax=1.)\n", " plt.show()\n", "\n", "\n", "marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=Lambda, device=DEVICE)\n", "uncond_score_model = UNet(marginal_prob_std=marginal_prob_std_fn)\n", "uncond_score_model.load_state_dict(torch.load(\"ckpt.pth\"))\n", "uncond_score_model.to(DEVICE)\n", "save_samples_uncond(uncond_score_model, suffix=\"\", device=DEVICE)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Nice job! You have just finished training a diffusion model. As you can see, the result is not ideal; many factors affect it. To name a few:\n", "\n", "* **Better network architecture**: residual connections, attention mechanisms, better upsampling mechanisms\n", "* **Better objective**: a better weighting function $\\gamma_t$\n", "* **Better optimization procedure**: using learning rate decay\n", "* **Better sampling algorithm**: Euler integration is known to have larger errors, so it's advisable to use a more advanced SDE or ODE solver" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 2: Conditional Diffusion Model\n", "\n", "Another way to greatly improve the result is to add a conditioning signal -- for example, telling the score network which digit you want. This makes score modeling considerably easier and gives the user more control. The popular Stable Diffusion model is of this kind: it uses natural-language text as the conditioning signal for images."
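, "\n", "Formally, instead of the unconditional score $\\nabla_{\\mathbf x} \\log p_t(\\mathbf x)$, the network now approximates the conditional score\n", "\n", "\\begin{equation}\n", "s_\\theta(\\mathbf x, t, y) \\approx \\nabla_{\\mathbf x} \\log p_t(\\mathbf x \\mid y),\n", "\\end{equation}\n", "\n", "trained with the same denoising score matching objective, except that the condition $y$ (a digit label here, a text embedding in Stable Diffusion) is fed to the network alongside $\\mathbf x$ and $t$."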
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 2: Conditional Diffusion Model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 2: Conditional Diffusion Model\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'tJDdVN9Fnrs'), ('Bilibili', 'BV1ek4y1N7bs')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Conditional_Diffusion_Model_Video\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "In formulation, the conditional diffusion is highly similar to the unconditional diffusion." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "If you are curious about how to build and train a conditional diffusion model, you are welcome to look at the Bonus exercise at the end." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 3: Advanced Techinque - Stable Diffusion\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 3: Advanced Techinque - Stable Diffusion\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'HBLgRqxgxrY'), ('Bilibili', 'BV1Yh4y1M74g')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Advanced_Techinque_Stable_Diffusion_Video\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Interactive Demo 2: Stable Diffusion\n", "\n", "In this demo, we will play with one of the most potent open-source diffusion models, Stable Diffusion 2.1, and try to connect with what we have learned." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Download the Stable Diffusion models\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "#@title Download the Stable Diffusion models\n", "from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, PNDMScheduler\n", "\n", "model_id = \"stabilityai/stable-diffusion-2-1\"\n", "pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)\n", "# Use the PNDM scheduler as default\n", "# pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)\n", "# Use the DPMSolverMultistepScheduler (DPM-Solver++) scheduler here instead\n", "pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n", "pipe = pipe.to(DEVICE)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now you can let loose your imagination and create artworks from text!\n", "\n", "Example prompts:\n", "\n", "```python\n", "prompt = \"A lovely cat running on the dessert in Van Gogh style, trending art.\"\n", "prompt = \"A ballerina dancing under the starry night in Monet style, trending art.\"\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {} }, "outputs": [], "source": [ "prompt = \"A lovely cat running on the dessert in Van Gogh style, trending art.\" # @param {'type':'string'}\n", "my_seed = 2023 # @param {'type':'integer'}\n", "execute = False # @param {'type':'boolean'}\n", "\n", "if execute:\n", " image = pipe(prompt, num_inference_steps=50,\n", " generator=torch.Generator(\"cuda\").manual_seed(my_seed)).images[0]\n", " image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Stable_Diffusion_Interactive_Demo\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Think! 2: Architecture of Stable Diffusion Model\n", "\n", "Can you see the similarity between the U-Net in Stable Diffusion and the baby UNet we defined up there?\n", "To inspect the architecture, you can use the `recursive_print(pipe.unet,deepest=2)` function with a different `deepest`.\n", "\n", "The text is encoded through the CLIP model, and you can also look at its structure below `recursive_print(pipe.text_encoder,deepest=4)`, which is a large transformer!\n", "\n", "Take 2 minutes to think and play with the code, then discuss as a group (~10 minutes)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Helper function to inspect network\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Helper function to inspect network\n", "def recursive_print(module, prefix=\"\", depth=0, deepest=3):\n", " \"\"\"Simulating print(module) for torch.nn.Modules\n", " but with depth control. Print to the `deepest` level. 
`deepest=0` means no print\n", " \"\"\"\n", " if depth == 0:\n", " print(f\"[{type(module).__name__}]\")\n", " if depth >= deepest:\n", " return\n", " for name, child in module.named_children():\n", " if len([*child.named_children()]) == 0:\n", " print(f\"{prefix}({name}): {child}\")\n", " else:\n", " if isinstance(child, nn.ModuleList):\n", " print(f\"{prefix}({name}): {type(child).__name__} len={len(child)}\")\n", " else:\n", " print(f\"{prefix}({name}): {type(child).__name__}\")\n", " recursive_print(child, prefix=prefix + \" \", depth=depth + 1, deepest=deepest)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "recursive_print(pipe.unet,deepest=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "recursive_print(pipe.text_encoder,deepest=4)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W2D4_GenerativeModels/solutions/W2D4_Tutorial3_Solution_263bbd8a.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Architecture_of_Stable_Diffusion_Model_Discussion\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 3: Ethical Considerations\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 4: Ethical Consideration\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 4: Ethical Consideration\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'Qy8ODZ7TYZg'), ('Bilibili', 'BV1TV4y1a7Qx')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = 
widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Ethical_Consideration_Video\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Think! 3: Copyright of imagery generated by diffusion models\n", "\n", "Suppose you prompt a pretrained diffusion model with the name of an artist and obtain beautiful imagery in that artist's style. Who holds the copyright of the generated image? The company that produced the diffusion model, the original artist, you (the prompter), the random seed and the weights, or the GPU that ran the inference?\n", "\n", "Who do you think deserves the credit and why?\n", "\n", "What if you apply enough post-processing to the generated images, e.g., fine-tuning the prompt and seed, or editing the image?\n", "\n", "Take 2 minutes to think in silence, then discuss as a group (~10 minutes).\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W2D4_GenerativeModels/solutions/W2D4_Tutorial3_Solution_5064d2b9.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Copyrights_Discussion\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Summary\n", "\n", "Today, we learned about:\n", "\n", "* One major application of diffusion modeling, i.e., modeling natural images.\n", "* Inductive biases suitable for image modeling: the U-Net architecture and the time modulation mechanism.\n", "* Introduction to conditional diffusion models, with a demo on Stable Diffusion.\n", "* Ethical considerations related to diffusion models, including copyright, misinformation, and fairness." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Daily survey\n", "\n", "Don't forget to complete your reflections and content check in the daily survey! Please be patient after logging in as there is a small delay before you will be redirected to the survey.\n", "\n", "\"button" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Bonus: Train Conditional Diffusion for MNIST\n", "\n", "In this part, we'd like to train an MNIST generative model conditioned on the digit.\n", "\n", "Here we will use a basic form of conditional modulation, i.e., digit embeddings, to linearly control the relative gain of the features. After learning about the attention mechanism, you could think about better ways of implementing conditional modulation, e.g., using cross-attention to modulate the score model."
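, "\n", "The next cell is a toy sketch of the core trick (added for illustration, separate from the full model that follows): an embedding of the class label is read out into one gain per feature channel, which multiplies the feature map." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Toy illustration of class-conditional gain modulation (not the model below).\n", "import torch\n", "import torch.nn as nn\n", "\n", "embed = nn.Embedding(10, 32)  # 10 digit classes -> 32-dim embedding\n", "to_gain = nn.Linear(32, 64)  # read out one gain per feature channel\n", "\n", "feats = torch.randn(8, 64, 14, 14)  # a batch of feature maps\n", "y = torch.randint(0, 10, (8,))  # class labels\n", "gain = to_gain(embed(y))[..., None, None]  # shape (8, 64, 1, 1), broadcasts over space\n", "modulated = feats * gain\n", "print(modulated.shape)  # torch.Size([8, 64, 14, 14])"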
] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## UNet score model with conditional modulation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "class UNet_Conditional(nn.Module):\n", " \"\"\"A time-dependent score-based model built upon U-Net architecture.\"\"\"\n", "\n", " def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256,\n", " text_dim=256, nClass=10):\n", " \"\"\"Initialize a time-dependent score-based network.\n", "\n", " Args:\n", " marginal_prob_std: A function that takes time t and gives the standard\n", " deviation of the perturbation kernel p_{0t}(x(t) | x(0)).\n", " channels: The number of channels for feature maps of each resolution.\n", " embed_dim: The dimensionality of Gaussian random feature embeddings of time.\n", " text_dim: the embedding dimension of text / digits.\n", " nClass: number of classes you want to model.\n", " \"\"\"\n", " super().__init__()\n", " # random embedding for classes\n", " self.cond_embed = nn.Embedding(nClass, text_dim)\n", " # Gaussian random feature embedding layer for time\n", " self.time_embed = nn.Sequential(\n", " GaussianFourierProjection(embed_dim=embed_dim),\n", " nn.Linear(embed_dim, embed_dim)\n", " )\n", " # Encoding layers where the resolution decreases\n", " self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)\n", " self.t_mod1 = Dense(embed_dim, channels[0])\n", " self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])\n", "\n", " self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)\n", " self.t_mod2 = Dense(embed_dim, channels[1])\n", " self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])\n", " self.y_mod2 = Dense(embed_dim, channels[1])\n", "\n", " self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)\n", " self.t_mod3 = Dense(embed_dim, channels[2])\n", " self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])\n", " self.y_mod3 = Dense(embed_dim, channels[2])\n", "\n", " self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)\n", " self.t_mod4 = Dense(embed_dim, channels[3])\n", " self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])\n", " self.y_mod4 = Dense(embed_dim, channels[3])\n", "\n", " # Decoding layers where the resolution increases\n", " self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)\n", " self.t_mod5 = Dense(embed_dim, channels[2])\n", " self.y_mod5 = Dense(embed_dim, channels[2])\n", " self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])\n", "\n", " self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1) # + channels[2]\n", " self.t_mod6 = Dense(embed_dim, channels[1])\n", " self.y_mod6 = Dense(embed_dim, channels[1])\n", " self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])\n", "\n", " self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, output_padding=1) # + channels[1]\n", " self.t_mod7 = Dense(embed_dim, channels[0])\n", " self.y_mod7 = Dense(embed_dim, channels[0])\n", " self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])\n", " self.tconv1 = nn.ConvTranspose2d(channels[0], 1, 3, stride=1)\n", "\n", " # The swish activation function\n", " self.act = nn.SiLU() # lambda x: x * torch.sigmoid(x)\n", " self.marginal_prob_std = marginal_prob_std\n", " for module in [self.y_mod2,self.y_mod3,self.y_mod4,\n", " self.y_mod5,self.y_mod6,self.y_mod7]:\n", " nn.init.normal_(module.dense.weight, 
mean=0, std=0.0001)\n", " nn.init.constant_(module.dense.bias, 1.0)\n", "\n", " def forward(self, x, t, y=None):\n", " # Obtain the Gaussian random feature embedding for t\n", " embed = self.act(self.time_embed(t))\n", " y_embed = self.cond_embed(y)\n", " # Encoding path\n", " h1 = self.conv1(x) + self.t_mod1(embed)\n", " ## Incorporate information from t\n", " ## Group normalization\n", " h1 = self.act(self.gnorm1(h1))\n", " h2 = self.conv2(h1) + self.t_mod2(embed)\n", " h2 = h2 * self.y_mod2(y_embed)\n", " h2 = self.act(self.gnorm2(h2))\n", " h3 = self.conv3(h2) + self.t_mod3(embed)\n", " h3 = h3 * self.y_mod3(y_embed)\n", " h3 = self.act(self.gnorm3(h3))\n", " h4 = self.conv4(h3) + self.t_mod4(embed)\n", " h4 = h4 * self.y_mod4(y_embed)\n", " h4 = self.act(self.gnorm4(h4))\n", "\n", " # Decoding path\n", " h = self.tconv4(h4) + self.t_mod5(embed)\n", " h = h * self.y_mod5(y_embed)\n", " ## Skip connection from the encoding path\n", " h = self.act(self.tgnorm4(h))\n", " h = self.tconv3(h + h3) + self.t_mod6(embed)\n", " h = h * self.y_mod6(y_embed)\n", " h = self.act(self.tgnorm3(h))\n", " h = self.tconv2(h + h2) + self.t_mod7(embed)\n", " h = h * self.y_mod7(y_embed)\n", " h = self.act(self.tgnorm2(h))\n", " h = self.tconv1(h + h1)\n", "\n", " # Normalize output\n", " h = h / self.marginal_prob_std(t)[:, None, None, None]\n", " return h" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Loss for conditional diffusion" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-3):\n", " \"\"\"The loss function for training score-based generative models.\n", "\n", " Args:\n", " model: A PyTorch model instance that represents a\n", " time-dependent score-based model.\n", " x: A mini-batch of training data.\n", " marginal_prob_std: A function that gives the standard deviation of\n", " the perturbation kernel.\n", " eps: A tolerance value for numerical stability.\n", " \"\"\"\n", " random_t = torch.rand(x.shape[0], device=x.device) * (1. 
- eps) + eps\n", " z = torch.randn_like(x)\n", " std = marginal_prob_std(random_t)\n", " perturbed_x = x + z * std[:, None, None, None]\n", " score = model(perturbed_x, random_t, y=y)\n", " loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2,\n", " dim=(1, 2, 3)))\n", " return loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training conditional diffusion model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Training conditional diffusion model\n", "Lambda = 25 # @param {'type':'number'}\n", "marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=Lambda, device=DEVICE)\n", "diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=Lambda, device=DEVICE)\n", "print(\"initialize new score model...\")\n", "score_model_cond = UNet_Conditional(marginal_prob_std=marginal_prob_std_fn)\n", "score_model_cond = score_model_cond.to(DEVICE)\n", "\n", "n_epochs = 10 # @param {'type':'integer'}\n", "## size of a mini-batch\n", "batch_size = 1024 # @param {'type':'integer'}\n", "## learning rate\n", "lr = 10e-4 # @param {'type':'number'}\n", "\n", "dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)\n", "data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)\n", "\n", "optimizer = Adam(score_model_cond.parameters(), lr=lr)\n", "scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 0.99 ** epoch))\n", "tqdm_epoch = trange(n_epochs)\n", "for epoch in tqdm_epoch:\n", " avg_loss = 0.\n", " num_items = 0\n", " for x, y in tqdm(data_loader):\n", " x = x.to(DEVICE)\n", " loss = loss_fn_cond(score_model_cond, x, y.to(DEVICE), marginal_prob_std_fn)\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " avg_loss += loss.item() * x.shape[0]\n", " num_items += x.shape[0]\n", " scheduler.step()\n", " lr_current = scheduler.get_last_lr()[0]\n", " print('{} Average Loss: {:.5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))\n", " # Print the averaged training loss so far.\n", " tqdm_epoch.set_description('Average Loss: {:.5f}'.format(avg_loss / num_items))\n", " # Update the checkpoint after each epoch of training.\n", " torch.save(score_model_cond.state_dict(), 'ckpt_cond.pth')" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "**Note:** The original value for `n_epochs` was 100."
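, "\n", "If you already trained in a previous session, you can skip the loop above and reload the saved weights instead (optional convenience cell added here; it assumes `ckpt_cond.pth` exists from an earlier run)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Optional: reload a previously trained conditional model instead of retraining.\n", "score_model_cond = UNet_Conditional(marginal_prob_std=marginal_prob_std_fn)\n", "score_model_cond.load_state_dict(torch.load('ckpt_cond.pth'))\n", "score_model_cond = score_model_cond.to(DEVICE)"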
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sample Conditional Diffusion\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Sample Conditional Diffusion\n", "digit = 4 # @param {'type':'integer'}\n", "sample_batch_size = 64 # @param {'type':'integer'}\n", "num_steps = 250 # @param {'type':'integer'}\n", "score_model_cond.eval()\n", "## Generate samples using the specified sampler.\n", "samples = Euler_Maruyama_sampler(\n", " score_model_cond,\n", " marginal_prob_std_fn,\n", " diffusion_coeff_fn,\n", " sample_batch_size,\n", " num_steps=num_steps,\n", " device=DEVICE,\n", " y=digit*torch.ones(sample_batch_size, dtype=torch.long, device=DEVICE))\n", "\n", "## Sample visualization.\n", "samples = samples.clamp(0.0, 1.0)\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))\n", "\n", "plt.figure(figsize=(6, 6))\n", "plt.axis('off')\n", "plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "torch.cuda.empty_cache()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "gpuType": "T4", "include_colab_link": true, "name": "W2D4_Tutorial3", "provenance": [], "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", "name": "python3" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc-autonumbering": true }, "nbformat": 4, "nbformat_minor": 0 }