Tutorial 3: Deep linear neural networks
Week 1, Day 2: Linear Deep Learning
By Neuromatch Academy
Content creators: Saeed Salehi, Spiros Chavlis, Andrew Saxe
Content reviewers: Polina Turishcheva, Antoine De Comite
Content editors: Anoop Kulkarni
Production editors: Khalid Almubarak, Gagana B, Spiros Chavlis
Tutorial Objectives
Deep linear neural networks
Learning dynamics and singular value decomposition (previewed in the sketch after this list)
Representational Similarity Analysis
Illusory correlations & ethics
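As a preview of the learning-dynamics analysis, deep linear networks learn the input-output map one mode at a time, where the modes come from the singular value decomposition of the input-output correlation matrix. A minimal sketch of that decomposition (shapes and variable names here are illustrative, not the tutorial's):

# Preview: SVD of an input-output correlation matrix (illustrative shapes)
import numpy as np

X = np.random.randn(8, 1000)        # inputs: 8 features x 1000 samples
Y = np.random.randn(4, 1000)        # targets: 4 outputs x 1000 samples
sigma_yx = Y @ X.T / X.shape[1]     # input-output correlation matrix
U, S, Vt = np.linalg.svd(sigma_yx)  # one mode per singular value
print(S)                            # singular values, largest first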
Setup
This is a GPU-free tutorial!
Install and import feedback gadget
# @title Install and import feedback gadget
!pip3 install vibecheck datatops --quiet
from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
return DatatopsContentReviewContainer(
"", # No text prompt
notebook_section,
{
"url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
"name": "neuromatch_dl",
"user_key": "f379rz8y",
},
).render()
feedback_prefix = "W1D2_T3"
# Imports
import math
import torch
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
Figure settings
# @title Figure settings
import logging
logging.getLogger('matplotlib.font_manager').disabled = True
from matplotlib import gridspec
from ipywidgets import interact, IntSlider, FloatSlider, fixed
from ipywidgets import FloatLogSlider, Layout, VBox
from ipywidgets import interactive_output
from mpl_toolkits.axes_grid1 import make_axes_locatable
import warnings
warnings.filterwarnings("ignore")
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/content-creation/main/nma.mplstyle")
Plotting functions
# @title Plotting functions
def plot_x_y_hier_data(im1, im2, subplot_ratio=[1, 2]):
"""
Plot hierarchical data of labels vs features
for all samples
Args:
im1: np.ndarray
Input Dataset
im2: np.ndarray
Targets
subplot_ratio: list
Subplot ratios used to create subplots of varying sizes
Returns:
Nothing
"""
fig = plt.figure(figsize=(12, 5))
gs = gridspec.GridSpec(1, 2, width_ratios=subplot_ratio)
ax0 = plt.subplot(gs[0])
ax1 = plt.subplot(gs[1])
ax0.imshow(im1, cmap="cool")
ax1.imshow(im2, cmap="cool")
ax0.set_title("Labels of all samples")
ax1.set_title("Features of all samples")
ax0.set_axis_off()
ax1.set_axis_off()
plt.tight_layout()
plt.show()
def plot_x_y_hier_one(im1, im2, subplot_ratio=[1, 2]):
"""
Plot hierarchical data of labels vs features
for a single sample
Args:
im1: np.ndarray
Input Dataset
im2: np.ndarray
Targets
subplot_ratio: list
Subplot ratios used to create subplots of varying sizes
Returns:
Nothing
"""
fig = plt.figure(figsize=(12, 1))
gs = gridspec.GridSpec(1, 2, width_ratios=subplot_ratio)
ax0 = plt.subplot(gs[0])
ax1 = plt.subplot(gs[1])
ax0.imshow(im1, cmap="cool")
ax1.imshow(im2, cmap="cool")
ax0.set_title("Labels of a single sample")
ax1.set_title("Features of a single sample")
ax0.set_axis_off()
ax1.set_axis_off()
plt.tight_layout()
plt.show()
def plot_tree_data(label_list=None, feature_array=None, new_feature=None):
"""
Plot tree data
Args:
label_list: np.ndarray
List of labels [default: None]
feature_array: np.ndarray
List of features [default: None]
    new_feature: string
      Label for the newly added feature column [default: None]
Returns:
Nothing
"""
cmap = matplotlib.colors.ListedColormap(['cyan', 'magenta'])
n_features = 10
n_labels = 8
im1 = np.eye(n_labels)
if feature_array is None:
im2 = np.array([[1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 1],
[0, 0, 1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 0, 0],
[0, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 0, 1]]).T
im2[im2 == 0] = -1
feature_list = ['can_grow',
'is_mammal',
'has_leaves',
'can_move',
'has_trunk',
'can_fly',
'can_swim',
'has_stem',
'is_warmblooded',
'can_flower']
else:
im2 = feature_array
if label_list is None:
label_list = ['Goldfish', 'Tuna', 'Robin', 'Canary',
'Rose', 'Daisy', 'Pine', 'Oak']
fig = plt.figure(figsize=(12, 7))
gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1.35])
ax1 = plt.subplot(gs[0])
ax2 = plt.subplot(gs[1])
ax1.imshow(im1, cmap=cmap)
if feature_array is None:
implt = ax2.imshow(im2, cmap=cmap, vmin=-1.0, vmax=1.0)
else:
implt = ax2.imshow(im2[:, -n_features:], cmap=cmap, vmin=-1.0, vmax=1.0)
divider = make_axes_locatable(ax2)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = plt.colorbar(implt, cax=cax, ticks=[-0.5, 0.5])
cbar.ax.set_yticklabels(['no', 'yes'])
ax1.set_title("Labels")
ax1.set_yticks(ticks=np.arange(n_labels))
ax1.set_yticklabels(labels=label_list)
ax1.set_xticks(ticks=np.arange(n_labels))
ax1.set_xticklabels(labels=label_list, rotation='vertical')
ax2.set_title("{} random Features".format(n_features))
ax2.set_yticks(ticks=np.arange(n_labels))
ax2.set_yticklabels(labels=label_list)
if feature_array is None:
ax2.set_xticks(ticks=np.arange(n_features))
ax2.set_xticklabels(labels=feature_list, rotation='vertical')
else:
ax2.set_xticks(ticks=[n_features-1])
ax2.set_xticklabels(labels=[new_feature], rotation='vertical')
plt.tight_layout()
plt.show()
def plot_loss(loss_array,
title="Training loss (Mean Squared Error)",
c="r"):
"""
Plot loss function
Args:
c: string
Specifies plot color
title: string
Specifies plot title
loss_array: np.ndarray
Log of MSE loss per epoch
Returns:
Nothing
"""
plt.figure(figsize=(10, 5))
plt.plot(loss_array, color=c)
plt.xlabel("Epoch")
plt.ylabel("MSE")
plt.title(title)
plt.show()
def plot_loss_sv(loss_array, sv_array):
"""
  Plot loss and the evolution of singular values (modes)
Args:
sv_array: np.ndarray
Log of singular values/modes across epochs
loss_array: np.ndarray
Log of MSE loss per epoch
Returns:
Nothing
"""
n_sing_values = sv_array.shape[1]
sv_array = sv_array / np.max(sv_array)
cmap = plt.cm.get_cmap("Set1", n_sing_values)
_, (plot1, plot2) = plt.subplots(2, 1, sharex=True, figsize=(10, 10))
plot1.set_title("Training loss (Mean Squared Error)")
plot1.plot(loss_array, color='r')
plot2.set_title("Evolution of singular values (modes)")
for i in range(n_sing_values):
plot2.plot(sv_array[:, i], c=cmap(i))
plot2.set_xlabel("Epoch")
plt.show()
def plot_loss_sv_twin(loss_array, sv_array):
"""
Plot learning dynamics
Args:
sv_array: np.ndarray
Log of singular values/modes across epochs
loss_array: np.ndarray
Log of MSE loss per epoch
Returns:
Nothing
"""
n_sing_values = sv_array.shape[1]
sv_array = sv_array / np.max(sv_array)
cmap = plt.cm.get_cmap("winter", n_sing_values)
fig = plt.figure(figsize=(10, 5))
ax1 = plt.gca()
ax1.set_title("Learning Dynamics")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Mean Squared Error", c='r')
ax1.tick_params(axis='y', labelcolor='r')
ax1.plot(loss_array, color='r')
ax2 = ax1.twinx()
ax2.set_ylabel("Singular values (modes)", c='b')
ax2.tick_params(axis='y', labelcolor='b')
for i in range(n_sing_values):
ax2.plot(sv_array[:, i], c=cmap(i))
fig.tight_layout()
plt.show()
def plot_ills_sv_twin(ill_array, sv_array, ill_label):
"""
Plot network training evolution
and illusory correlations
Args:
sv_array: np.ndarray
Log of singular values/modes across epochs
ill_array: np.ndarray
Log of illusory correlations per epoch
ill_label: np.ndarray
Log of labels associated with illusory correlations
Returns:
Nothing
"""
n_sing_values = sv_array.shape[1]
sv_array = sv_array / np.max(sv_array)
cmap = plt.cm.get_cmap("winter", n_sing_values)
fig = plt.figure(figsize=(10, 5))
ax1 = plt.gca()
ax1.set_title("Network training and the Illusory Correlations")
ax1.set_xlabel("Epoch")
ax1.set_ylabel(ill_label, c='r')
ax1.tick_params(axis='y', labelcolor='r')
ax1.plot(ill_array, color='r', linewidth=3)
ax1.set_ylim(-1.05, 1.05)
ax2 = ax1.twinx()
ax2.set_ylabel("Singular values (modes)", c='b')
ax2.tick_params(axis='y', labelcolor='b')
for i in range(n_sing_values):
ax2.plot(sv_array[:, i], c=cmap(i))
fig.tight_layout()
plt.show()
def plot_loss_sv_rsm(loss_array, sv_array, rsm_array, i_ep):
"""
  Plot learning dynamics alongside the representational similarity matrix (RSM)
Args:
sv_array: np.ndarray
Log of singular values/modes across epochs
loss_array: np.ndarray
Log of MSE loss per epoch
rsm_array: torch.tensor
Representation similarity matrix
i_ep: int
Which epoch to show
Returns:
Nothing
"""
n_ep = loss_array.shape[0]
rsm_array = rsm_array / np.max(rsm_array)
sv_array = sv_array / np.max(sv_array)
n_sing_values = sv_array.shape[1]
cmap = plt.cm.get_cmap("winter", n_sing_values)
fig = plt.figure(figsize=(14, 5))
gs = gridspec.GridSpec(1, 2, width_ratios=[5, 3])
ax0 = plt.subplot(gs[1])
ax0.yaxis.tick_right()
implot = ax0.imshow(rsm_array[i_ep], cmap="Purples", vmin=0.0, vmax=1.0)
divider = make_axes_locatable(ax0)
cax = divider.append_axes("right", size="5%", pad=0.9)
cbar = plt.colorbar(implot, cax=cax, ticks=[])
cbar.ax.set_ylabel('Similarity', fontsize=12)
ax0.set_title("RSM at epoch {}".format(i_ep), fontsize=16)
ax0.set_yticks(ticks=np.arange(n_sing_values))
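  # NOTE: `item_names` is a global list of item labels defined in a later cell of the tutorial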
ax0.set_yticklabels(labels=item_names)
ax0.set_xticks(ticks=np.arange(n_sing_values))
ax0.set_xticklabels(labels=item_names, rotation='vertical')
ax1 = plt.subplot(gs[0])
ax1.set_title("Learning Dynamics", fontsize=16)
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Mean Squared Error", c='r')
ax1.tick_params(axis='y', labelcolor='r', direction="in")
ax1.plot(np.arange(n_ep), loss_array, color='r')
ax1.axvspan(i_ep-2, i_ep+2, alpha=0.2, color='m')
ax2 = ax1.twinx()
ax2.set_ylabel("Singular values", c='b')
ax2.tick_params(axis='y', labelcolor='b', direction="in")
for i in range(n_sing_values):
ax2.plot(np.arange(n_ep), sv_array[:, i], c=cmap(i))
ax1.set_xlim(-1, n_ep+1)
ax2.set_xlim(-1, n_ep+1)
plt.show()
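For reference, the rsm_array passed to the plotter above holds one representational similarity matrix per logged epoch. A minimal sketch of how such a matrix can be computed from hidden-layer activity (the function name and shapes here are assumptions for illustration, not the tutorial's implementation):

def compute_rsm_sketch(h):
  """
  Illustrative RSM: pairwise inner products of hidden representations

  Args:
    h: np.ndarray
      Hidden representations, shape (n_items, n_hidden)

  Returns:
    rsm: np.ndarray
      Similarity matrix of shape (n_items, n_items)
  """
  return h @ h.T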
Helper functions
# @title Helper functions
def build_tree(n_levels, n_branches, probability,
to_np_array=True):
"""
Builds tree
Args:
n_levels: int
Number of levels in tree
n_branches: int
Number of branches in tree
probability: float
Flipping probability
to_np_array: boolean
If true, represent tree as np.ndarray
Returns:
tree: dict if to_np_array=False
np.ndarray otherwise
Tree
"""
assert 0.0 <= probability <= 1.0
tree = {}
tree["level"] = [0]
for i in range(1, n_levels+1):
tree["level"].extend([i]*(n_branches**i))
tree["pflip"] = [probability]*len(tree["level"])
tree["parent"] = [None]
k = len(tree["level"])-1
for j in range(k//n_branches):
tree["parent"].extend([j]*n_branches)
if to_np_array:
tree["level"] = np.array(tree["level"])
tree["pflip"] = np.array(tree["pflip"])
tree["parent"] = np.array(tree["parent"])
return tree
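# For intuition (illustrative, not part of the tutorial): the flat encoding
# places node i at depth tree["level"][i] with parent index tree["parent"][i].
small_tree = build_tree(n_levels=2, n_branches=2, probability=0.1)
print(small_tree["level"])   # [0 1 1 2 2 2 2]: root, two children, four leaves
print(small_tree["parent"])  # [None 0 0 1 1 2 2]: parent index of each node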
def sample_from_tree(tree, n):
"""
Generates n samples from a tree
Args:
tree: np.ndarray/dictionary
Tree
    n: int
      Number of samples to generate
Returns:
x: np.ndarray
Sample from tree
"""
items = [i for i, v in enumerate(tree["level"]) if v == max(tree["level"])]
n_items = len(items)
x = np.zeros(shape=(n, n_items))
rand_temp = np.random.rand(n, len(tree["pflip"]))
flip_temp = np.repeat(tree["pflip"].reshape(1, -1), n, 0)
samp = (rand_temp > flip_temp) * 2 - 1
for i in range(n_items):
j = items[i]
prop = samp[:, j]
while tree["parent"][j] is not None:
j = tree["parent"][j]
prop = prop * samp[:, j]
x[:, i] = prop.T
return x
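# Each sampled leaf value is the product of +/-1 sign flips along its path to
# the root, so leaves sharing more ancestors are more correlated. A quick
# illustrative check (variable names are mine, not the tutorial's):
demo_tree = build_tree(n_levels=3, n_branches=2, probability=0.15)
x_demo = sample_from_tree(demo_tree, 5000)  # shape (5000, 8): one column per leaf
corr = np.corrcoef(x_demo.T)                # pairwise correlations between leaves
print(corr[0, 1] > corr[0, 7])              # siblings beat distant leaves (typically True)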
def generate_hsd():
"""
  Build the hierarchically structured dataset: tree labels and features
Args:
None
Returns:
tree_labels: np.ndarray
Tree Labels
tree_features: np.ndarray
Sample from tree
"""
n_branches = 2 # 2 branches at each node
probability = .15 # flipping probability
n_levels = 3 # number of levels (depth of tree)
tree = build_tree(n_levels, n_branches, probability, to_np_array=True)
tree["pflip"][0] = 0.5
n_samples = 10000 # Sample this many features
tree_labels = np.eye(n_branches**n_levels)
tree_features = sample_from_tree(tree, n_samples).T
return tree_labels, tree_features
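# Quick illustrative shape check of the generated dataset:
# tree_labels is one-hot per item; tree_features holds the sampled features.
demo_labels, demo_features = generate_hsd()
print(demo_labels.shape)    # (8, 8): one-hot label per item
print(demo_features.shape)  # (8, 10000): 10000 sampled features per item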
def linear_regression(X, Y):
"""
Analytical Linear regression
Args:
X: np.ndarray
Input features
Y: np.ndarray
Targets
Returns:
W: np.ndarray
Analytical solution
W = Y @ X.T @ np.linalg.inv(X @ X.T)
"""
assert isinstance(X, np.ndarray)
assert isinstance(Y, np.ndarray)
  M, Dx = X.shape
  N, Dy = Y.shape
  assert Dx == Dy
  W = Y @ X.T @ np.linalg.inv(X @ X.T)
  return W
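Once completed, the helper can be sanity-checked by recovering a known weight matrix from noiseless data (an illustrative check, not part of the tutorial):

rng = np.random.default_rng(0)
W_true = rng.standard_normal((4, 8))     # 4 outputs, 8 input features
X_demo = rng.standard_normal((8, 1000))  # inputs: features x samples
Y_demo = W_true @ X_demo                 # noiseless linear targets
W_hat = linear_regression(X_demo, Y_demo)
print(np.allclose(W_hat, W_true))        # True, up to numerical error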