{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"execution": {},
"id": "view-in-github"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"# Tutorial 1: Basic Reinforcement Learning\n",
"\n",
"**Week 3, Day 4: Basic Reinforcement Learning**\n",
"\n",
"**By Neuromatch Academy**\n",
"\n",
"__Content creators:__ Pablo Samuel Castro\n",
"\n",
"__Content reviewers:__ Shaonan Wang, Xiaomei Mi, Julia Costacurta, Dora Zhiyu Yang, Adrita Das\n",
"\n",
"__Content editors:__ Shaonan Wang\n",
"\n",
"__Production editors:__ Spiros Chavlis"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Tutorial Objectives\n",
"\n",
"Reinforcement Learning (RL) is a powerful framework for defining and solving problems where an agent learns to take actions that maximize its reward. Essentially, an agent observes the current state of the world, selects an action, receives a reward, and uses this feedback to improve its future actions. RL provides a formal, optimal way of describing this learning process, which was initially derived from studies of animal behavior and later validated by observations of the brain in humans and animals.\n",
"\n",
"This tutorial will introduce you to the basic concepts of RL using a simple example. By the end, you'll have a better understanding of how RL works and how it can be applied to solve a wide range of problems."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @markdown\n",
"from IPython.display import IFrame\n",
"from ipywidgets import widgets\n",
"out = widgets.Output()\n",
"with out:\n",
" print(f\"If you want to download the slides: https://osf.io/download/ztgws/\")\n",
" display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/ztgws/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n",
"display(out)"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Setup\n",
"\n",
"This is a GPU free notebook!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install and import feedback gadget\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Install and import feedback gadget\n",
"\n",
"!pip3 install vibecheck datatops --quiet\n",
"\n",
"from vibecheck import DatatopsContentReviewContainer\n",
"def content_review(notebook_section: str):\n",
" return DatatopsContentReviewContainer(\n",
" \"\", # No text prompt\n",
" notebook_section,\n",
" {\n",
" \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n",
" \"name\": \"neuromatch_dl\",\n",
" \"user_key\": \"f379rz8y\",\n",
" },\n",
" ).render()\n",
"\n",
"\n",
"feedback_prefix = \"W3D4_T1\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "both",
"execution": {}
},
"outputs": [],
"source": [
"# Imports\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from typing import Optional, Tuple"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Section 1: A history of RL\n",
"\n",
"In this section, we will briefly overview the history of reinforcement learning (RL) in reverse chronological order. This will help motivate why RL is an interesting topic to study!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Video 1: Intro to RL\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video 1: Intro to RL\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', 'x-NyX8bRsTQ'), ('Bilibili', 'BV1wk4y1M7ae')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Intro_to_RL_Video\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Section 2: What is RL\n",
"\n",
"Using a very simple problem, we will give a high-level overview of what RL is and what are the main components that define the problem formulation.\n",
"\n",
"**Extra:** If you'd like to read more, the canonical reference for RL is Sutton & Barto's [Reinforcement Learning book](http://incompleteideas.net/book/the-book-2nd.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"## Section 2.1: Grid World\n",
"\n",
"GridWorlds are very simple \"navigation\" problems that can be very useful for motivating RL problems and solutions. They are commonly used in RL research, so it's a good idea to get familiar with them!\n",
"\n",
"We will use a simple GridWorld problem throughout this tutorial: an empty room with a reward at one corner.\n",
"\n",
"An example below defines a second GridWorld that is a little more difficult. Feel free to create your own!\n",
"\n",
"**Extra:** If you'd like to play with RL in GridWorlds on the web, you can check out this [GridWorld playground web app](https://gridworld-playground.glitch.me/)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Video 2: Grid World\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video 2: Grid World\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', '4r400A5GNfE'), ('Bilibili', 'BV1GV411M7hk')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Grid_world_Video\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"### Coding Exercise 1: Code a shortest-path planner for GridWorld\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Create the GridWorldPlanner object (defaults to simple example)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Create the GridWorldPlanner object (defaults to simple example)\n",
"\n",
"ASCII_TO_EMOJI = {\n",
" ' ': '⬜',\n",
" '*': '⬛',\n",
" 'g': '⭐',\n",
" '<': '◀️',\n",
" '>': '▶️',\n",
" 'v': '🔽',\n",
" '^': '🔼',\n",
"}\n",
"\n",
"ACTIONS = ['<', '>', 'v', '^']\n",
"ACTION_EFFECTS = { # Position effects of each action.\n",
" '<': (0, -1),\n",
" '>': (0, 1),\n",
" 'v': (1, 0),\n",
" '^': (-1, 0),\n",
"}\n",
"\n",
"\n",
"def get_emoji(c, policy=None):\n",
" assert c in ASCII_TO_EMOJI\n",
" if policy is not None and c != 'g':\n",
" assert policy in ASCII_TO_EMOJI\n",
" if policy != ' ': # If there is a policy, use this instead.\n",
" c = policy\n",
" return ASCII_TO_EMOJI[c]\n",
"\n",
"\n",
"class GridWorldBase(object):\n",
" \"\"\"Defines a GridWorldPlanner object.\"\"\"\n",
"\n",
" def __init__(self, world_spec: Optional[np.ndarray] = None):\n",
" \"\"\"Creates a GridWorld object with an empty policy.\n",
"\n",
" Args:\n",
" world_spec: Optional array specification of GridWorld. If None, will\n",
" use default square room.\n",
" \"\"\"\n",
" if world_spec is None:\n",
" self.world_spec = np.array(\n",
" [['*', '*', '*', '*', '*', '*'],\n",
" ['*', ' ', ' ', ' ', ' ', '*'],\n",
" ['*', ' ', ' ', ' ', ' ', '*'],\n",
" ['*', ' ', ' ', ' ', ' ', '*'],\n",
" ['*', ' ', ' ', ' ', 'g', '*'],\n",
" ['*', '*', '*', '*', '*', '*']]\n",
" )\n",
" else:\n",
" assert len(world_spec.shape) == 2\n",
" self.world_spec = world_spec\n",
"\n",
" assert len(np.where(self.world_spec == 'g')[0]) == 1 # Only one goal.\n",
" self.policy = np.full_like(self.world_spec, ' ')\n",
" # **Note**: These may be useful for your planner!\n",
" self.goal_cell = [x[0] for x in np.where(self.world_spec == 'g')]\n",
"\n",
" def get_neighbours(self, cell: Tuple[int, int]):\n",
" \"\"\"Get the neighbours of a cell.\n",
"\n",
" **Note**: You should use this when writing your planner!\n",
"\n",
" Args:\n",
" cell: cell position.\n",
"\n",
" Returns:\n",
" Dict containing neighbouring cells for each of the 4 possible directions.\n",
" \"\"\"\n",
" height, width = self.world_spec.shape\n",
" i, j = cell\n",
" if i < 1 or i >= height or j < 1 or j >= width:\n",
" raise ValueError(f'Invalid cell position: {cell}')\n",
" neighbours = {}\n",
" for a in ACTIONS:\n",
" delta = ACTION_EFFECTS[a]\n",
" neighbour_pos = [i + delta[0], j + delta[1]]\n",
" if (neighbour_pos[0] < 0 or neighbour_pos[1] < 0 or\n",
" neighbour_pos[0] >= height or neighbour_pos[1] >= width or\n",
" self.world_spec[neighbour_pos[0], neighbour_pos[1]] == '*'):\n",
" # Remain in same cell\n",
" neighbours[a] = cell\n",
" else:\n",
" neighbours[a] = neighbour_pos\n",
" return neighbours\n",
"\n",
" def plan(self):\n",
" \"\"\"Constructs a random policy.\n",
"\n",
" **Note**: you will make something better further down!\n",
" \"\"\"\n",
" for i in range(self.policy.shape[0]):\n",
" for j in range(self.policy.shape[1]):\n",
" if self.world_spec[i, j] == '*': # Nothing to do for walls.\n",
" continue\n",
" self.policy[i, j] = ACTIONS[np.random.choice(len(ACTIONS))]\n",
"\n",
" def draw(self, include_policy: bool = False):\n",
" \"\"\"Draw the grid, and (optionally) include the policy.\"\"\"\n",
" for i in range(len(self.world_spec)):\n",
" row_range = range(len(self.world_spec[i]))\n",
" if include_policy:\n",
" row_chars = [get_emoji(self.world_spec[i, j], self.policy[i, j]) for j in row_range]\n",
" else:\n",
" row_chars = [get_emoji(self.world_spec[i, j], None) for j in row_range]\n",
" print(''.join(row_chars))\n",
"\n",
"\n",
"gwb = GridWorldBase()\n",
"print('Simple GridWorld:')\n",
"gwb.draw()\n",
"gwb.plan()\n",
"print('Random policy:')\n",
"gwb.draw(True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"class GridWorldPlanner(GridWorldBase):\n",
" \"\"\"A GridWorld that finds a better policy.\"\"\"\n",
"\n",
" def plan(self):\n",
" \"\"\"Define a better planner!\n",
"\n",
" This gives you a starting point by setting the proper actions in the cells\n",
" surrounding the goal cell.\n",
"\n",
" **Assignment:** Do the rest!\n",
" \"\"\"\n",
" super().plan()\n",
" goal_queue = [self.goal_cell]\n",
" goals_done = set()\n",
" goal = goal_queue.pop(0) # pop from front of list\n",
" goal_neighbours = self.get_neighbours(goal)\n",
" goals_done.add(tuple(goal))\n",
"\n",
" for a in goal_neighbours:\n",
" nbr = tuple(goal_neighbours[a])\n",
" if nbr == goal:\n",
" continue\n",
" if nbr not in goals_done:\n",
" if a == '<':\n",
" self.policy[nbr[0], nbr[1]] = '>'\n",
" elif a == '>':\n",
" self.policy[nbr[0], nbr[1]] = '<'\n",
" elif a == '^':\n",
" self.policy[nbr[0], nbr[1]] = 'v'\n",
" else:\n",
" self.policy[nbr[0], nbr[1]] = '^'\n",
" goal_queue.append(nbr)\n",
"\n",
"\n",
"gwp = GridWorldPlanner()\n",
"print('Simple GridWorld:')\n",
"gwp.draw()\n",
"gwp.plan()\n",
"print('Better policy:')\n",
"gwp.draw(True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"Make a better planner!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"class GridWorldPlanner(GridWorldBase):\n",
" \"\"\"A GridWorld that finds a better policy.\"\"\"\n",
"\n",
" def plan(self):\n",
" \"\"\"Define a better planner!\n",
"\n",
" This gives you a starting point by setting the proper actions in the cells\n",
" surrounding the goal cell.\n",
"\n",
" **Assignment:** Do the rest!\n",
" \"\"\"\n",
" super().plan()\n",
" goal_queue = [self.goal_cell]\n",
" goals_done = set()\n",
" #################################################\n",
" # Implement a better planer\n",
" raise NotImplementedError(\"Define a better planner!`\")\n",
" #################################################\n",
" while ...:\n",
" goal = goal_queue.pop(0) # pop from front of list\n",
" goal_neighbours = self.get_neighbours(goal)\n",
" goals_done.add(tuple(goal))\n",
" for a in goal_neighbours:\n",
" nbr = tuple(goal_neighbours[a])\n",
" if nbr == goal:\n",
" continue\n",
" if nbr not in goals_done:\n",
" if a == '<':\n",
" self.policy[nbr[0], nbr[1]] = '>'\n",
" elif a == '>':\n",
" self.policy[nbr[0], nbr[1]] = '<'\n",
" elif a == '^':\n",
" self.policy[nbr[0], nbr[1]] = 'v'\n",
" else:\n",
" self.policy[nbr[0], nbr[1]] = '^'\n",
" goal_queue.append(nbr)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"execution": {}
},
"source": [
"[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W3D4_BasicReinforcementLearning/solutions/W3D4_Tutorial1_Solution_32802337.py)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"gwp = GridWorldPlanner()\n",
"print('Simple GridWorld:')\n",
"gwp.draw()\n",
"gwp.plan()\n",
"print('Better policy:')\n",
"gwp.draw(True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Make_a_better_planner_Exercise\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Try it out in a harder problem.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Try it out in a harder problem.\n",
"harder_grid = np.array(\n",
" [['*', '*', '*', '*', '*', '*', '*', '*', '*'],\n",
" ['*', ' ', ' ', ' ', '*', ' ', ' ', 'g', '*'],\n",
" ['*', ' ', ' ', ' ', '*', ' ', ' ', ' ', '*'],\n",
" ['*', ' ', ' ', ' ', '*', ' ', ' ', ' ', '*'],\n",
" ['*', ' ', ' ', ' ', '*', ' ', ' ', ' ', '*'],\n",
" ['*', ' ', ' ', ' ', ' ', ' ', ' ', ' ', '*'],\n",
" ['*', '*', '*', '*', '*', '*', '*', '*', '*'],\n",
" ]\n",
")\n",
"gwb_2 = GridWorldBase(harder_grid)\n",
"gwp_2 = GridWorldPlanner(harder_grid)\n",
"print('Harder GridWorld:')\n",
"gwb_2.draw()\n",
"gwb_2.plan()\n",
"print('Random policy:')\n",
"gwb_2.draw(True)\n",
"print('Better policy:')\n",
"gwp_2.plan()\n",
"gwp_2.draw(True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"## Section 2.2: Markov Decision Process (MDP)\n",
"\n",
"Formulating RL problems traditionally happens via a Markov decision process (MDP). In this section, we will introduce all the necessary notation and write code to define the MDP corresponding to our simple GridWorld.\n",
"\n",
"**Extra:** Martin Puterman's [book on Markov Decision Processes](https://onlinelibrary.wiley.com/doi/book/10.1002/9780470316887) is an excellent reference if you'd like to read more."
]
},
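{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"As a rough guide to the notation (a sketch using the standard textbook convention; the set symbols $\\mathcal{S}$ and $\\mathcal{A}$ and the discount factor $\\gamma$ are assumptions here and may differ slightly from the video), an MDP is usually written as a tuple:\n",
"\n",
"\\begin{equation}\n",
"\\mathcal{M} = \\langle \\mathcal{S}, \\mathcal{A}, P, R, \\gamma \\rangle, \\qquad P(s' \\mid s, a) = \\Pr(s_{t+1} = s' \\mid s_t = s, a_t = a), \\qquad \\sum_{s' \\in \\mathcal{S}} P(s' \\mid s, a) = 1\n",
"\\end{equation}\n",
"\n",
"Here $\\mathcal{S}$ is the set of states (for our GridWorld, the non-wall cells), $\\mathcal{A}$ is the set of actions (the four moves), $P$ gives the transition probabilities, $R(s, a)$ is the reward for taking action $a$ in state $s$, and $\\gamma \\in [0, 1)$ discounts future rewards."
]
},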
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Video 3: Markov Decision Process\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video 3: Markov Decision Process\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', 'AfIsG1I4MRE'), ('Bilibili', 'BV1VV411M7pX')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Markov_Decision_Process_Video\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"### Coding exercise 2: Create an MDP from the GridWorld specification\n",
"\n",
"Create $P$ and $R$ matrices for MDP."
]
},
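{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"For the deterministic GridWorld above, the two matrices can be summarized as follows. This is a sketch rather than the official solution: $\\mathrm{next}(s, a)$ is a hypothetical name for the state reached from $s$ under action $a$ (the cell returned by `get_neighbours`, mapped through `cell_to_state`), and $s_g$ is a hypothetical symbol for the goal state.\n",
"\n",
"\\begin{equation}\n",
"P(s' \\mid s, a) = \\begin{cases} 1 & \\text{if } s' = \\mathrm{next}(s, a) \\\\ 0 & \\text{otherwise} \\end{cases}\n",
"\\qquad\n",
"R(s, a) = \\begin{cases} 1 & \\text{if } \\mathrm{next}(s, a) = s_g \\\\ 0 & \\text{otherwise} \\end{cases}\n",
"\\end{equation}\n",
"\n",
"In array terms, `P` has shape `(num_states, num_actions, num_states)` and every `P[s, a, :]` sums to one, while `R` has shape `(num_states, num_actions)`."
]
},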
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"class MDPBase(object):\n",
" \"\"\"This object creates a proper MDP from a GridWorld object.\"\"\"\n",
"\n",
" def __init__(self, grid_world: GridWorldBase):\n",
" \"\"\"Constructs an MDP from a GridWorldBase object.\n",
"\n",
" Args:\n",
" grid_world: GridWorld specification.\n",
" \"\"\"\n",
" # Determine how many valid states there are and create empty matrices.\n",
" self.grid_world = grid_world\n",
" self.num_states = np.sum(grid_world.world_spec != '*')\n",
" self.num_actions = len(ACTIONS)\n",
" self.P = np.zeros((self.num_states, self.num_actions, self.num_states))\n",
" self.R = np.zeros((self.num_states, self.num_actions))\n",
" self.pi = np.zeros(self.num_states, dtype=np.int32)\n",
" # Create mapping from cell positions to state ID.\n",
" state_idx = 0\n",
" self.cell_to_state = np.ones(grid_world.world_spec.shape, dtype=np.int32) * -1 # Defaults to -1.\n",
" self.state_to_cell = {}\n",
" for i, row in enumerate(grid_world.world_spec):\n",
" for j, cell in enumerate(row):\n",
" if cell == '*':\n",
" continue\n",
" if cell == 'g':\n",
" self.goal_state = state_idx\n",
" self.cell_to_state[i, j] = state_idx\n",
" self.state_to_cell[state_idx] = (i, j)\n",
" state_idx += 1\n",
" #################################################\n",
" # States should be numbered from left-to-right and from top-to-bottom.\n",
" raise NotImplementedError(\"Calculate P and R\")\n",
" #################################################\n",
" # Assign transition probabilities and rewards accordingly.\n",
" for s in range(...):\n",
" neighbours = ...\n",
" for a, action in enumerate(neighbours):\n",
" nbr = ...\n",
" s2 = self.cell_to_state[..., ...]\n",
" self.P[s, a, s2] = 1.0 # Deterministic transitions\n",
" if s2 == self.goal_state:\n",
" self.R[s, a] = 1.0\n",
"\n",
" def draw(self, include_policy: bool = False):\n",
" # First make sure we convert our MDP policy into the GridWorld policy.\n",
" for s in range(self.num_states):\n",
" r, c = self.state_to_cell[s]\n",
" self.grid_world.policy[r, c] = ACTIONS[self.pi[s]]\n",
" self.grid_world.draw(include_policy)\n",
"\n",
" def plan(self):\n",
" \"\"\"Define a planner\n",
" \"\"\"\n",
" goal_queue = [self.goal_state]\n",
" goals_done = set()\n",
" #################################################\n",
" # Set the proper actions\n",
" raise NotImplementedError(\"Implement `plan` function!\")\n",
" #################################################\n",
" while goal_queue:\n",
" goal = goal_queue.pop(0) # pop from front of list\n",
" nbr_states, nbr_actions = ...\n",
" goals_done.add(goal)\n",
" for s, a in zip(..., ...):\n",
" if s == goal:\n",
" continue\n",
" if s not in goals_done:\n",
" self.pi[s] = ...\n",
" goal_queue.append(s)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"execution": {}
},
"source": [
"[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W3D4_BasicReinforcementLearning/solutions/W3D4_Tutorial1_Solution_0102ee70.py)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"mdpb = MDPBase(gwb)\n",
"\n",
"# Verify the transitions were properly created.\n",
"for i, row in enumerate(mdpb.grid_world.world_spec):\n",
" for j, cell in enumerate(row):\n",
" if cell == '*':\n",
" continue\n",
" neighbours = mdpb.grid_world.get_neighbours((i, j))\n",
" s = mdpb.cell_to_state[i, j]\n",
" for a, action in enumerate(neighbours):\n",
" nbr = neighbours[action]\n",
" s2 = mdpb.cell_to_state[nbr[0], nbr[1]]\n",
" assert np.sum(mdpb.P[s, a, :]) == 1.0\n",
" assert mdpb.P[s, a, s2] == 1.0\n",
" if s2 == mdpb.goal_state:\n",
" assert mdpb.R[s, a] == 1.0\n",
" else:\n",
" assert mdpb.R[s, a] == 0.0\n",
"\n",
"print('P and R matrices successfully created!')\n",
"print('GridWorld:')\n",
"mdpb.draw()\n",
"print('Shortest path policy:')\n",
"mdpb.plan()\n",
"mdpb.draw(True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Create_an_MDP_Exercise\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"## Section 2.3: $Q$-values\n",
"\n",
"$Q$-values are central to RL algorithms, as they quantify the *desirability* of performing an action given a particular state. The agent updates these values throughout training and can use its estimates to decide how to act."
]
},
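{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"For reference, the usual textbook definition of a $Q$-value is given below. The notation is assumed rather than taken from the video: $\\pi$ is the policy being followed, $\\gamma \\in [0, 1)$ is a discount factor, and $R$ is the reward function of the MDP.\n",
"\n",
"\\begin{equation}\n",
"Q^{\\pi}(s, a) = \\mathbb{E}\\left[ \\sum_{t=0}^{\\infty} \\gamma^{t} R(s_t, a_t) \\,\\Big|\\, s_0 = s, \\; a_0 = a, \\; \\pi \\right]\n",
"\\end{equation}\n",
"\n",
"Acting greedily with respect to these values gives the policy $\\pi(s) = \\arg\\max_{a} Q(s, a)$."
]
},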
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Video 4: Q-values\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video 4: Q-values\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', 'u2-uqRiJHuM'), ('Bilibili', 'BV1gk4y1M7iK')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Q_values_Video\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"### Coding exercise 3: Create a steps-to-go solver\n",
"\n",
"Create a new MDP class that holds steps-to-go as Q-values"
]
},
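{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"One way to read steps-to-go as a $Q$-value is sketched below. This is an interpretation under assumed notation, not the official solution: $d(s)$ is a hypothetical symbol for the length of the shortest path from state $s$ to the goal state $s_g$, and $\\mathrm{next}(s, a)$ is a hypothetical name for the state reached from $s$ under action $a$. For a move that actually changes cell:\n",
"\n",
"\\begin{equation}\n",
"Q(s, a) = 1 + d\\big(\\mathrm{next}(s, a)\\big), \\qquad d(s) = \\min_{a} Q(s, a), \\qquad d(s_g) = 0, \\qquad \\pi(s) = \\arg\\min_{a} Q(s, a)\n",
"\\end{equation}\n",
"\n",
"Because these values count a cost rather than a reward, the best action is the one with the smallest value, which is why the policy is extracted with an argmin instead of the usual argmax."
]
},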
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"class MDPToGo(MDPBase):\n",
"\n",
" def __init__(self, grid_world: GridWorldBase):\n",
" \"\"\"Constructs an MDP from a GridWorldBase object.\n",
"\n",
" States should be numbered from left-to-right and from top-to-bottom.\n",
"\n",
" Args:\n",
" grid_world: GridWorld specification.\n",
" \"\"\"\n",
" super().__init__(grid_world)\n",
" self.Q = np.zeros((self.num_states, self.num_actions))\n",
"\n",
" def computeQ(self):\n",
" \"\"\"Store discounted steps-to-go in an SxA matrix called Q.\n",
"\n",
" This matrix will then be used to extract the optimal policy.\n",
" \"\"\"\n",
" #################################################\n",
" # Implement a function to compute Q\n",
" raise NotImplementedError(\"Implement `ComputeQ` function!\")\n",
" #################################################\n",
" goal_queue = [(self.goal_state, 0)]\n",
" goals_done = set()\n",
" while goal_queue:\n",
" goal, steps_to_go = ... # pop from front of list\n",
" steps_to_go += ... # Increase the number of steps to goal.\n",
" nbr_states, nbr_actions = ...\n",
" goals_done.add(...)\n",
" for s, a in zip(..., ...):\n",
" if goal == self.goal_state and s == self.goal_state:\n",
" self.Q[s, a] = ...\n",
" elif s == goal:\n",
" # If (s, a) leads to itself then we have an infinite loop (since\n",
" # we're assuming deterministic transitions).\n",
" self.Q[s, a] = ...\n",
" else:\n",
" self.Q[s, a] = ...\n",
" if s not in goals_done:\n",
" goal_queue.append((..., ...))\n",
"\n",
" def plan(self):\n",
" \"\"\"Now planning is just doing an argmin over the Q-values!\n",
"\n",
" Note that this is a little different than standard Q-learning (where we do\n",
" an argmax), since our Q-values currently store steps-to-go.\n",
" \"\"\"\n",
" self.pi = np.argmin(self.Q, axis=-1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"execution": {}
},
"source": [
"[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W3D4_BasicReinforcementLearning/solutions/W3D4_Tutorial1_Solution_0f69d259.py)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"mdpTg = MDPToGo(gwb)\n",
"print('GridWorld:')\n",
"mdpTg.draw()\n",
"# Compute Q, then extract policy from it.\n",
"mdpTg.computeQ()\n",
"mdpTg.plan()\n",
"print('Optimal policy:')\n",
"mdpTg.draw(True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Create_a_step_to_go_solver_Exercise\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Section 3: Value iteration\n",
"\n",
"Value iteration is an iterative algorithm that continuously improves estimates of $Q$ and $V$ by performing the *Bellman backup*. This assumes access to $P$ and $R$ (not typically accessible in RL) but is the backbone of $Q$-learning, which we will discuss later.\n",
"\n",
"
\n",
"\n",
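 *Did you know?* Richard Bellman">
"> *Did you know?* Richard Bellman developed dynamic programming (a core part of any computer science curriculum) to solve sequential decision problems like this one, and value iteration is one of its classic algorithms."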
]
},
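{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"In its standard textbook form, the Bellman backup can be written as follows. The iteration index $k$ and the discount factor $\\gamma \\in [0, 1)$ are assumed notation and may not match the video exactly.\n",
"\n",
"\\begin{equation}\n",
"V_k(s) = \\max_{a} Q_k(s, a), \\qquad Q_{k+1}(s, a) = R(s, a) + \\gamma \\sum_{s'} P(s' \\mid s, a) \\, V_k(s')\n",
"\\end{equation}\n",
"\n",
"When transitions are deterministic, as in our GridWorld, the sum has a single non-zero term, so the update reduces to $Q_{k+1}(s, a) = R(s, a) + \\gamma V_k(s')$, where $s'$ is the state reached from $s$ under action $a$."
]
},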
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Video 5: Value iteration\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video 5: Value iteration\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', 'XIcX37uaRF0'), ('Bilibili', 'BV1SF41197kC')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Value_Iteration_Video\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"### Coding exercise 4: Implement value iteration\n",
"\n",
"Create a new MDP class that does value iteration"
]
},
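{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"As a reference for the exercise, here is a sketch of the Bellman backup in the indexing convention the tutorial code uses (`R[s, a]` and `P[s, a, s']`). The iteration index $k$, the symbol $\\tau$ (standing for `error_tolerance`), and the max-norm stopping rule are assumptions made for illustration, and the linked solution may measure the error differently.\n",
"\n",
"\\begin{equation}\n",
"Q_{k+1}(s, a) = R(s, a) + \\gamma \\sum_{s'} P(s, a, s') \\max_{a'} Q_k(s', a')\n",
"\\end{equation}\n",
"\n",
"Starting from $Q_0 = 0$, the backup is repeated until successive estimates agree to within the tolerance, after which $V$ is read off from $Q$:\n",
"\n",
"\\begin{equation}\n",
"\\max_{s, a} \\left| Q_{k+1}(s, a) - Q_k(s, a) \\right| \\leq \\tau, \\qquad V(s) = \\max_{a} Q(s, a)\n",
"\\end{equation}\n",
"\n",
"Because the backup is a $\\gamma$-contraction in the max norm, each sweep shrinks the distance to the fixed point by at least a factor of $\\gamma$, which is why the loop terminates for any $\\gamma < 1$."
]
},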
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"class MDPValueIteration(MDPToGo):\n",
"\n",
" def __init__(self, grid_world: GridWorldBase, gamma: float = 0.99):\n",
" \"\"\"Constructs an MDP from a GridWorldBase object.\n",
"\n",
" States should be numbered from left-to-right and from top-to-bottom.\n",
"\n",
" Args:\n",
" grid_world: GridWorld specification.\n",
" gamma: Discount factor.\n",
" \"\"\"\n",
" super().__init__(grid_world)\n",
" self.gamma = gamma\n",
"\n",
" def computeQ(self, error_tolerance : float = 1e-5):\n",
" \"\"\"Compute Q and V vectors via value iteration.\n",
"\n",
" Args:\n",
" error_tolerance: How much error we tolerate between successive Q updates.\n",
" \"\"\"\n",
" self.Q = np.zeros((self.num_states, self.num_actions))\n",
" num_iterations = 0\n",
" error = np.inf\n",
" #################################################\n",
" # Write this method!\n",
" # First find Q, and then extract V from Q.\n",
" # Hint: Use matrix multiplication instead of for loops!\n",
" raise NotImplementedError(\"Implement `computeQ` function!\")\n",
" #################################################\n",
" while error > error_tolerance:\n",
" new_Q = ...\n",
" max_next_Q = ...\n",
" for a in range(self.num_actions):\n",
" new_Q[:, a] = ...\n",
" error = ...\n",
" self.Q = np.copy(new_Q)\n",
" num_iterations += 1\n",
" self.V = np.max(self.Q, axis=-1)\n",
" print(f'Q and V found in {num_iterations} iterations with an error tolerance of {error_tolerance}.')\n",
"\n",
" def plan(self):\n",
" \"\"\"Now planning is just doing an argmax over the Q-values!\n",
" \"\"\"\n",
" #################################################\n",
" # Note that we're going back to argmax, since the Q-values now represent proper\n",
" # \"returns-to-go\", so we want to maximize that.\n",
" # Write this method! It should be a one-liner, and very similar to what you\n",
" # used for extracting V from Q.\n",
" raise NotImplementedError(\"Implement `plan` function!\")\n",
" #################################################\n",
" self.pi = ...\n",
"\n",
" def _draw_v(self):\n",
" \"\"\"Draw the V values.\"\"\"\n",
" min_v = np.min(self.V)\n",
" max_v = np.max(self.V)\n",
" wall_v = 2 * min_v - max_v # Creating a smaller value for walls.\n",
" grid_values = np.ones_like(self.grid_world.world_spec, dtype=np.int32) * wall_v\n",
" # Fill in the V values in grid cells.\n",
" for s in range(self.num_states):\n",
" cell = self.state_to_cell[s]\n",
" grid_values[cell[0], cell[1]] = self.V[s]\n",
"\n",
" fig, ax = plt.subplots()\n",
" ax.grid(False)\n",
" ax.get_xaxis().set_visible(False)\n",
" ax.get_yaxis().set_visible(False)\n",
" grid = ax.matshow(grid_values)\n",
" grid.set_clim(wall_v, max_v)\n",
" fig.colorbar(grid)\n",
"\n",
" def draw(self, draw_mode: str = 'grid'):\n",
" \"\"\"Draw the GridWorld according to specified mode.\n",
"\n",
" Args:\n",
" draw_mode: Specification of what mode to draw. Supported options:\n",
" 'grid': Draw the base GridWorld.\n",
" 'policy': Display the policy.\n",
" 'values': Display the values for each state.\n",
" \"\"\"\n",
" # First make sure we convert our MDP policy into the GridWorld policy.\n",
" if draw_mode == 'values':\n",
" self._draw_v()\n",
" else:\n",
" super().draw(draw_mode == 'policy')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"execution": {}
},
"source": [
"[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W3D4_BasicReinforcementLearning/solutions/W3D4_Tutorial1_Solution_fa9d3f4b.py)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"mdpVi = MDPValueIteration(gwb)\n",
"print('GridWorld:')\n",
"mdpVi.draw()\n",
"# Compute Q, then extract policy from it.\n",
"mdpVi.computeQ()\n",
"mdpVi.plan()\n",
"print('Optimal policy:')\n",
"mdpVi.draw('policy')\n",
"mdpVi.draw('values')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Implement_Value_Iteration_Exercise\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Section 4: Policy iteration\n",
"\n",
"Rather than iterating on estimates of $Q$ and $V$ until we've reached some form of convergence, why not iterate directly on the policy $\\pi$ instead?\n",
"\n",
"Policy iteration does just that and can sometimes lead to solutions in fewer steps than value iteration."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Video 6: Policy iteration\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video 6: Policy iteration\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', 'Z77kVYiRm_4'), ('Bilibili', 'BV17u411b7R1')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Policy_Iteration_Video\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"### Coding exercise 5: Implement policy iteration\n",
"\n",
"Create a new MDP class that does policy iteration."
]
},
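{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"For reference, the textbook form of policy iteration alternates two steps, written here in the same `R[s, a]` / `P[s, a, s']` convention as the tutorial code. The iteration index $k$ is notation introduced for this sketch. The exercise skeleton appears to interleave the two steps with a single backup per pass, so read these equations as the standard formulation, not as a transcription of the linked solution.\n",
"\n",
"Policy evaluation (compute the value of the current policy $\\pi_k$):\n",
"\n",
"\\begin{equation}\n",
"V^{\\pi_k}(s) = R(s, \\pi_k(s)) + \\gamma \\sum_{s'} P(s, \\pi_k(s), s') \\, V^{\\pi_k}(s')\n",
"\\end{equation}\n",
"\n",
"Policy improvement (act greedily with respect to that value):\n",
"\n",
"\\begin{equation}\n",
"\\pi_{k+1}(s) = \\arg\\max_{a} \\left[ R(s, a) + \\gamma \\sum_{s'} P(s, a, s') \\, V^{\\pi_k}(s') \\right]\n",
"\\end{equation}\n",
"\n",
"The loop stops once $\\pi_{k+1} = \\pi_k$. With finitely many states and actions there are only finitely many deterministic policies, and an improvement step never makes the policy worse, so in the textbook version this happens after a finite number of iterations."
]
},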
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"class MDPPolicyIteration(MDPToGo):\n",
"\n",
" def __init__(self, grid_world: GridWorldBase, gamma: float = 0.99):\n",
" \"\"\"Constructs an MDP from a GridWorldBase object.\n",
"\n",
" States should be numbered from left-to-right and from top-to-bottom.\n",
"\n",
" Args:\n",
" grid_world: GridWorld specification.\n",
" gamma: Discount factor.\n",
" \"\"\"\n",
" super().__init__(grid_world)\n",
" self.gamma = gamma\n",
"\n",
" def findPi(self):\n",
" \"\"\"Find π policy.\n",
" \"\"\"\n",
" self.Q = np.zeros((self.num_states, self.num_actions))\n",
" self.pi = np.zeros(self.num_states, dtype=np.int32)\n",
" num_iterations = 0\n",
" #################################################\n",
" # Compute π, which involves computing Q.\n",
" # Once you have π and Q, find V.\n",
" # Hint: Your value iteration solution will be useful here.\n",
" raise NotImplementedError(\"Implement `findPi` function!\")\n",
" #################################################\n",
" new_pi = ... # initialize to ones\n",
" while np.any(new_pi != self.pi):\n",
" new_pi = ...\n",
" new_Q = ... # initialize to zeros\n",
" next_V = ...\n",
" for a in range(self.num_actions):\n",
" new_Q[:, a] = ...\n",
" self.Q = np.copy(new_Q)\n",
" self.pi = ...\n",
" num_iterations += 1\n",
" self.V = ...\n",
" print(f'Q and V found in {num_iterations} iterations.')\n",
"\n",
" def _draw_v(self):\n",
" \"\"\"Draw the V values.\"\"\"\n",
" min_v = np.min(self.V)\n",
" max_v = np.max(self.V)\n",
" wall_v = 2 * min_v - max_v # Creating a smaller value for walls.\n",
" grid_values = np.ones_like(self.grid_world.world_spec, dtype=np.int32) * wall_v\n",
" # Fill in the V values in grid cells.\n",
" for s in range(self.num_states):\n",
" cell = self.state_to_cell[s]\n",
" grid_values[cell[0], cell[1]] = self.V[s]\n",
"\n",
" fig, ax = plt.subplots()\n",
" ax.grid(False)\n",
" ax.get_xaxis().set_visible(False)\n",
" ax.get_yaxis().set_visible(False)\n",
" grid = ax.matshow(grid_values)\n",
" grid.set_clim(wall_v, max_v)\n",
" fig.colorbar(grid)\n",
"\n",
" def draw(self, draw_mode: str = 'grid'):\n",
" \"\"\"Draw the GridWorld according to specified mode.\n",
"\n",
" Args:\n",
" draw_mode: Specification of what mode to draw. Supported options:\n",
" 'grid': Draw the base GridWorld.\n",
" 'policy': Display the policy.\n",
" 'values': Display the values for each state.\n",
" \"\"\"\n",
" # First make sure we convert our MDP policy into the GridWorld policy.\n",
" if draw_mode == 'values':\n",
" self._draw_v()\n",
" else:\n",
" super().draw(draw_mode == 'policy')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"execution": {}
},
"source": [
"[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W3D4_BasicReinforcementLearning/solutions/W3D4_Tutorial1_Solution_63c032bf.py)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"mdpPi = MDPPolicyIteration(gwb)\n",
"print('GridWorld:')\n",
"mdpPi.draw()\n",
"# Compute Q, then extract policy from it.\n",
"mdpPi.findPi()\n",
"print('Optimal policy:')\n",
"mdpPi.draw('policy')\n",
"mdpPi.draw('values')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Implement_Policy_Iteration_Exercise\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Section 5: $Q$-learning algorithm\n",
"\n",
"RL assumes we don't have access to $P$ nor $R$, so we can use neither value nor policy iteration to find an optimal behavior for our agent.\n",
"\n",
"$Q$-learning, however, incorporates the Bellman backup into an online learning algorithm: $Q$-learning, which can be shown to converge to the true $Q$-values (under mild conditions)!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Video 7: Q-learning\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video 7: Q-learning\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', 'HqOdNTdpppE'), ('Bilibili', 'BV1az4y1n7pb')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Q_learning_Video\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"### Coding exercise 6: Implement Q-learning\n",
"\n",
"Create a Q-learning class"
]
},
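{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"A sketch of the tabular $Q$-learning update that the skeleton's `td` and `alpha` variables presumably correspond to. Here $(s, a, r, s')$ denotes the current state, the chosen action, the reward received, and the next state (`new_state`); the symbol $\\delta$ for the temporal-difference error is notation introduced for this sketch.\n",
"\n",
"\\begin{equation}\n",
"\\delta = r + \\gamma \\max_{a'} Q(s', a') - Q(s, a), \\qquad Q(s, a) \\leftarrow Q(s, a) + \\alpha \\, \\delta\n",
"\\end{equation}\n",
"\n",
"Unlike the Bellman backup in value iteration, this update uses a single sampled transition instead of the sum over $P$, which is why it needs neither $P$ nor $R$.\n",
"\n",
"The hint about initialization follows from a geometric series: if every reward is at most $1$, no discounted return can exceed\n",
"\n",
"\\begin{equation}\n",
"\\sum_{t=0}^{\\infty} \\gamma^{t} = \\frac{1}{1 - \\gamma},\n",
"\\end{equation}\n",
"\n",
"which equals $100$ for the default $\\gamma = 0.99$."
]
},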
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"class QLearner(MDPValueIteration):\n",
"\n",
" def __init__(self, grid_world: GridWorldBase, gamma: float = 0.99):\n",
" \"\"\"Constructs an MDP from a GridWorldBase object.\n",
"\n",
" States should be numbered from left-to-right and from top-to-bottom.\n",
"\n",
" Args:\n",
" grid_world: GridWorld specification.\n",
" gamma: Discount factor.\n",
" \"\"\"\n",
" super().__init__(grid_world, gamma)\n",
" self.Q = np.zeros((self.num_states, self.num_actions))\n",
" # Pick an initial state randomly.\n",
" self.current_state = np.random.choice(self.num_states)\n",
"\n",
" def step(self, action: int) -> Tuple[int, float]:\n",
" \"\"\"Take a step in MDP from self.current_state.\n",
"\n",
" Args:\n",
" action: Action to take.\n",
"\n",
" Returns:\n",
" Next state and reward received.\n",
" \"\"\"\n",
" new_state = np.random.choice(self.num_states,\n",
" p=self.P[self.current_state, action, :])\n",
" return (new_state, self.R[self.current_state, action])\n",
"\n",
" def pickAction(self) -> int:\n",
" \"\"\"Pick the best action from the current state and Q-value estimates.\"\"\"\n",
" return np.argmax(self.Q[self.current_state, :])\n",
"\n",
" def maybeReset(self):\n",
" \"\"\"If current_state is goal, reset to a random state.\"\"\"\n",
" if self.current_state == self.goal_state:\n",
" self.current_state = np.random.choice(self.num_states)\n",
"\n",
" def learnQ(self, alpha: float = 0.1, max_steps: int = 10_000):\n",
" \"\"\"Learn the Q-function by interacting with the environment.\n",
"\n",
" Args:\n",
" alpha: Learning rate.\n",
" max_steps: Maximum number of steps to take.\n",
" \"\"\"\n",
" self.Q = np.zeros((self.num_states, self.num_actions))\n",
" num_steps = 0\n",
" #################################################\n",
" # Hint: Use the step(), pickAction(), and maybeReset() functions above.\n",
" # Note: The way you initialize the Q-values is crucial here. Try first with\n",
" # an all-zeros initialization (as is currently coded below). If it doesn't\n",
" # work, try a different initialization.\n",
" # Hint: The maximum possible value (given the rewards are in [0, 1]) is\n",
" # 1 / (1 - gamma).\n",
" raise NotImplementedError(\"Write `learnQ` function!\")\n",
" #################################################\n",
" while num_steps < max_steps:\n",
" a = ...\n",
" new_state, r = ...\n",
" td = ...\n",
" self.Q[self.current_state, a] += ...\n",
" self.current_state = ...\n",
" self.maybeReset()\n",
" num_steps += 1\n",
" self.V = ...\n",
"\n",
"\n",
" def plan(self):\n",
" \"\"\"Now planning is just doing an argmin over the Q-values!\n",
"\n",
" Note that this is a little different than standard Q-learning (where we do\n",
" an argmax), since our Q-values currently store steps-to-go.\n",
" \"\"\"\n",
" self.pi = np.argmax(self.Q, axis=-1)\n",
" self.V = np.max(self.Q, axis=-1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"execution": {}
},
"source": [
"[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W3D4_BasicReinforcementLearning/solutions/W3D4_Tutorial1_Solution_54192b12.py)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"base_q_learner = QLearner(gwb)\n",
"print('GridWorld:')\n",
"base_q_learner.draw()\n",
"# Compute Q, then extract policy from it.\n",
"base_q_learner.learnQ()\n",
"base_q_learner.plan()\n",
"print('Optimal policy:')\n",
"base_q_learner.draw('policy')\n",
"base_q_learner.draw('values')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Implement_Q_learning_Exercise\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Section 6: $\\epsilon$-greedy exploitation\n",
"\n",
"Should an agent *exploit* its current estimates and policy, or should it *explore* the environment in case better policies are out there? This *exploration-exploitation dilemma* is a central problem in RL.\n",
"\n",
"In this section, we explore one of the simplest yet most effective methods for this tradeoff: the so-called $\\epsilon$-greedy exploration."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Video 8: Epsilon-greedy exploration\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video 8: Epsilon-greedy exploration\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', 'NiF7_RK_4M4'), ('Bilibili', 'BV1Jh4y1u78c')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Epsilon_greedy_exploration_Video\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"### Coding exercise 7: Implement epsilon-greedy exploration\n",
"\n",
"Create a Q-learning class with epsilon-greedy exploration"
]
},
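{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"A sketch of the action distribution that $\\epsilon$-greedy induces, assuming the random action is drawn uniformly from all actions (including the greedy one) and writing $|\\mathcal{A}|$ for the number of actions. Both the uniform draw and the notation are assumptions of this sketch.\n",
"\n",
"\\begin{equation}\n",
"\\pi(a \\mid s) =\n",
"\\begin{cases}\n",
"1 - \\epsilon + \\dfrac{\\epsilon}{|\\mathcal{A}|} & \\text{if } a = \\arg\\max_{a'} Q(s, a') \\\\\n",
"\\dfrac{\\epsilon}{|\\mathcal{A}|} & \\text{otherwise}\n",
"\\end{cases}\n",
"\\end{equation}\n",
"\n",
"For example, with the default $\\epsilon = 0.1$ and, hypothetically, $|\\mathcal{A}| = 4$ movement actions, the greedy action is chosen with probability $0.9 + 0.025 = 0.925$ and each other action with probability $0.025$."
]
},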
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"class QLearnerExplorer(QLearner):\n",
"\n",
" def __init__(self, grid_world: GridWorldBase, gamma: float = 0.99,\n",
" epsilon: float = 0.1):\n",
" \"\"\"Constructs an MDP from a GridWorldBase object.\n",
"\n",
" States should be numbered from left-to-right and from top-to-bottom.\n",
"\n",
" Args:\n",
" grid_world: GridWorld specification.\n",
" gamma: Discount factor.\n",
" epsilon: Exploration rate.\n",
" \"\"\"\n",
" super().__init__(grid_world, gamma)\n",
" self.epsilon = epsilon\n",
"\n",
" def pickAction(self):\n",
" \"\"\"Pick the next action from the current state.\n",
" \"\"\"\n",
" #################################################\n",
" # With probability epsilon will pick the next action randomly, otherwise will\n",
" # pick based on the Q-value estimates.\n",
" # Hint: It should only be a few lines of code!\n",
" raise NotImplementedError(\"Write the `pickAction` function!\")\n",
" #################################################\n",
" if ... < ...:\n",
" return ...\n",
" return ..."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"execution": {}
},
"source": [
"[*Click for solution*](https://github.com/NeuromatchAcademy/course-content-dl/tree/main/tutorials/W3D4_BasicReinforcementLearning/solutions/W3D4_Tutorial1_Solution_4743dbd3.py)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"explorer = QLearnerExplorer(gwb)\n",
"print('GridWorld:')\n",
"explorer.draw()\n",
"# Compute Q, then extract policy from it.\n",
"explorer.learnQ()\n",
"explorer.plan()\n",
"print('Optimal policy:')\n",
"explorer.draw('policy')\n",
"explorer.draw('values')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Implement_epsilon_greedy_exploration_Exercise\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Summary\n",
"\n",
"Reinforcement learning is important in artificial intelligence because it allows an agent to learn how to make decisions in complex environments based on trial and error. Understanding how an agent learns and the various algorithms used to facilitate this learning is crucial for developing effective RL systems.\n",
"\n",
"* The Gridworld environment is a commonly used testbed for RL algorithms and provides a simple yet challenging environment for agents to learn in. By mastering the Gridworld environment, researchers can apply the same principles to more complex tasks.\n",
"\n",
"* $Q$-values are a fundamental concept in RL, representing the expected reward an agent will receive by taking a certain action in a given state. Computing $Q$-values is essential for many RL algorithms, including $Q$-learning.\n",
"\n",
"* The value and policy iteration algorithms are important because they provide a way to compute optimal policies for an agent in a given environment. These algorithms help an agent make decisions leading to the greatest cumulative reward over time.\n",
"\n",
"* $Q$-learning is a widely used RL algorithm known for its simplicity and effectiveness. It is essential because it allows an agent to learn how to make decisions in an environment without requiring a model of the environment and can be applied to a wide range of tasks.\n",
"\n",
"* Finally, $\\epsilon$-greedy exploration is an important concept in RL because it helps to balance the exploration-exploitation tradeoff. Occasionally choosing a random action allows an agent to explore new areas of the environment and potentially discover better policies.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Daily survey\n",
"\n",
"Don't forget to complete your reflections and content check in the daily survey! Please be patient after logging in as there is a small delay before you will be redirected to the survey.\n",
"\n",
"
"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"include_colab_link": true,
"name": "W3D4_Tutorial1",
"provenance": [],
"toc_visible": true
},
"kernel": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"toc-autonumbering": true
},
"nbformat": 4,
"nbformat_minor": 0
}