
Tutorial 1: Introduction to Reinforcement Learning

Week 3, Day 4: Basic Reinforcement Learning

By Neuromatch Academy

Content creators: Pablo Samuel Castro

Content reviewers: Shaonan Wang, Xiaomei Mi, Julia Costacurta, Dora Zhiyu Yang, Adrita Das, Jiaxin Cindy Tu

Content editors: Shaonan Wang, Jiaxin Cindy Tu

Production editors: Spiros Chavlis, Konstantine Tsafatinos


Tutorial Objectives

Reinforcement Learning (RL) is a powerful framework for defining and solving problems in which an agent learns to take actions that maximize its cumulative reward. An agent observes the current state of the world, selects an action, receives a reward, and uses this feedback to improve its future actions. RL gives a formal description of this learning process; its core ideas were initially inspired by studies of animal behavior and were later supported by observations of the brain in humans and animals.
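The observe, act, reward loop described above fits in a few lines. This is a minimal sketch on a hypothetical 1-D corridor with a random agent; `step`, `NUM_STATES`, and `GOAL` are illustrative names, not code used later in the tutorial. A real RL agent would use the rewards to change how it picks actions.

```python
import numpy as np
from typing import Tuple

# Hypothetical 1-D corridor (not part of the tutorial code): states 0..4, goal at 4.
# Actions: 0 = move left, 1 = move right.
NUM_STATES, GOAL = 5, 4


def step(state: int, action: int) -> Tuple[int, float]:
  """Environment: move one cell (clipped at the ends); reward 1.0 on reaching the goal."""
  delta = 1 if action == 1 else -1
  next_state = min(max(state + delta, 0), NUM_STATES - 1)
  reward = 1.0 if next_state == GOAL else 0.0
  return next_state, reward


rng = np.random.default_rng(seed=0)
state, total_reward = 0, 0.0
for t in range(20):
  action = int(rng.integers(2))        # agent: picks an action (randomly, no learning yet)
  state, reward = step(state, action)  # environment: returns the next state and a reward
  total_reward += reward               # the feedback a learning agent would use to improve
  if state == GOAL:
    break
print(f'Stopped after {t + 1} steps in state {state} with total reward {total_reward}')
```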

This tutorial will introduce you to the basic concepts of RL using a simple example. By the end, you’ll have a better understanding of how RL works and how it can be applied to solve a wide range of problems.


Setup

This is a GPU-free notebook!

Install and import feedback gadget

# @title Install and import feedback gadget

!pip3 install vibecheck datatops --quiet

from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "neuromatch_dl",
            "user_key": "f379rz8y",
        },
    ).render()


feedback_prefix = "W3D4_T1"
# Imports
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Tuple

Figure settings

# @title Figure settings
import logging
logging.getLogger('matplotlib.font_manager').disabled = True

%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/content-creation/main/nma.mplstyle")

Section 1: A history of RL

In this section, we briefly review the history of reinforcement learning (RL) in reverse chronological order. This will help motivate why RL is an interesting topic to study!

Video 1: Intro to RL

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Intro_to_RL_Video")

Section 2: What is RL?

Using a very simple problem, we give a high-level overview of what RL is and what the main components of its problem formulation are.

Extra: If you’d like to read more, the canonical reference for RL is Sutton & Barto’s Reinforcement Learning book.

Section 2.1: Grid World

GridWorlds are simple “navigation” problems that are very useful for motivating RL problems and solutions. They are commonly used in RL research, so it’s a good idea to get familiar with them!

We will use a simple GridWorld problem throughout this tutorial: an empty room with a reward at one corner.

Later in this tutorial, we define a second, slightly harder GridWorld. Feel free to create your own!

Extra: If you’d like to play with RL in GridWorlds on the web, you can check out this GridWorld playground web app.

Video 2: Grid World

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Grid_world_Video")

Coding Exercise 1: Code a shortest-path planner for GridWorld

Create the GridWorldBase object (defaults to simple example)

# @title Create the GridWorldBase object (defaults to simple example)

ASCII_TO_EMOJI = {
    ' ': '⬜',
    '*': '⬛',
    'g': '⭐',
    '<': '◀️',
    '>': '▶️',
    'v': '🔽',
    '^': '🔼',
}

ACTIONS = ['<', '>', 'v', '^']
ACTION_EFFECTS = {  # Position effects of each action.
    '<': (0, -1),
    '>': (0, 1),
    'v': (1, 0),
    '^': (-1, 0),
}


def get_emoji(c, policy=None):
  assert c in ASCII_TO_EMOJI
  if policy is not None and c != 'g':
    assert policy in ASCII_TO_EMOJI
    if policy != ' ':  # If there is a policy, use this instead.
      c = policy
  return ASCII_TO_EMOJI[c]


class GridWorldBase(object):
  """Defines a basic GridWorld object (with a random-policy planner)."""

  def __init__(self, world_spec: Optional[np.ndarray] = None):
    """Creates a GridWorld object with an empty policy.

    Args:
      world_spec: Optional array specification of GridWorld. If None, will
                  use default square room.
    """
    if world_spec is None:
      self.world_spec = np.array(
          [['*', '*', '*', '*', '*', '*'],
           ['*', ' ', ' ', ' ', ' ', '*'],
           ['*', ' ', ' ', ' ', ' ', '*'],
           ['*', ' ', ' ', ' ', ' ', '*'],
           ['*', ' ', ' ', ' ', 'g', '*'],
           ['*', '*', '*', '*', '*', '*']]
      )
    else:
      assert len(world_spec.shape) == 2
      self.world_spec = world_spec

    assert len(np.where(self.world_spec == 'g')[0]) == 1  # Only one goal.
    self.policy = np.full_like(self.world_spec, ' ')
    # **Note**: These may be useful for your planner!
    self.goal_cell = [x[0] for x in np.where(self.world_spec == 'g')]

  def get_neighbours(self, cell: Tuple[int, int]):
    """Get the neighbours of a cell.

    **Note**: You should use this when writing your planner!

    Args:
      cell: cell position.

    Returns:
      Dict containing neighbouring cells for each of the 4 possible directions.
    """
    height, width = self.world_spec.shape
    i, j = cell
    if i < 1 or i >= height or j < 1 or j >= width:
      raise ValueError(f'Invalid cell position: {cell}')
    neighbours = {}
    for a in ACTIONS:
      delta = ACTION_EFFECTS[a]
      neighbour_pos = [i + delta[0], j + delta[1]]
      if (neighbour_pos[0] < 0 or neighbour_pos[1] < 0 or
          neighbour_pos[0] >= height or neighbour_pos[1] >= width or
          self.world_spec[neighbour_pos[0], neighbour_pos[1]] == '*'):
        # Remain in same cell
        neighbours[a] = cell
      else:
        neighbours[a] = neighbour_pos
    return neighbours

  def plan(self):
    """Constructs a random policy.

    **Note**: you will make something better further down!
    """
    for i in range(self.policy.shape[0]):
      for j in range(self.policy.shape[1]):
        if self.world_spec[i, j] == '*':  # Nothing to do for walls.
          continue
        self.policy[i, j] = ACTIONS[np.random.choice(len(ACTIONS))]

  def draw(self, include_policy: bool = False):
    """Draw the grid, and (optionally) include the policy."""
    for i in range(len(self.world_spec)):
      row_range = range(len(self.world_spec[i]))
      if include_policy:
        row_chars = [get_emoji(self.world_spec[i, j], self.policy[i, j]) for j in row_range]
      else:
        row_chars = [get_emoji(self.world_spec[i, j], None) for j in row_range]
      print(''.join(row_chars))


gwb = GridWorldBase()
print('Simple GridWorld:')
gwb.draw()
gwb.plan()
print('Random policy:')
gwb.draw(True)
Simple GridWorld:
⬛⬛⬛⬛⬛⬛
⬛⬜⬜⬜⬜⬛
⬛⬜⬜⬜⬜⬛
⬛⬜⬜⬜⬜⬛
⬛⬜⬜⬜⭐⬛
⬛⬛⬛⬛⬛⬛
Random policy:
⬛⬛⬛⬛⬛⬛
⬛▶️▶️🔽◀️⬛
⬛▶️▶️▶️◀️⬛
⬛▶️🔽▶️◀️⬛
⬛▶️🔽◀️⭐⬛
⬛⬛⬛⬛⬛⬛
class GridWorldPlanner(GridWorldBase):
  """A GridWorld that finds a better policy."""

  def plan(self):
    """Define a better planner!

    This gives you a starting point by setting the proper actions in the cells
    surrounding the goal cell.

    **Assignment:** Do the rest!
    """
    super().plan()
    goal_queue = [self.goal_cell]
    goals_done = set()
    goal = goal_queue.pop(0)  # pop from front of list
    goal_neighbours = self.get_neighbours(goal)
    goals_done.add(tuple(goal))

    for a in goal_neighbours:
      nbr = tuple(goal_neighbours[a])
      if nbr == goal:
        continue
      if nbr not in goals_done:
        if a == '<':
          self.policy[nbr[0], nbr[1]] = '>'
        elif a == '>':
          self.policy[nbr[0], nbr[1]] = '<'
        elif a == '^':
          self.policy[nbr[0], nbr[1]] = 'v'
        else:
          self.policy[nbr[0], nbr[1]] = '^'
        goal_queue.append(nbr)


gwp = GridWorldPlanner()
print('Simple GridWorld:')
gwp.draw()
gwp.plan()
print('Better policy:')
gwp.draw(True)
Simple GridWorld:
⬛⬛⬛⬛⬛⬛
⬛⬜⬜⬜⬜⬛
⬛⬜⬜⬜⬜⬛
⬛⬜⬜⬜⬜⬛
⬛⬜⬜⬜⭐⬛
⬛⬛⬛⬛⬛⬛
Better policy:
⬛⬛⬛⬛⬛⬛
⬛🔽◀️◀️▶️⬛
⬛🔽🔽▶️▶️⬛
⬛🔼◀️🔽🔽⬛
⬛🔼▶️▶️⭐⬛
⬛⬛⬛⬛⬛⬛

Make a better planner!
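Completing the planner amounts to a breadth-first search (BFS) outward from the goal: keep a queue of cells whose distance is known, expand them in first-in, first-out order, and stop once the queue is empty. Here is a minimal, generic sketch of that idea on a hypothetical 2×2 room; the adjacency dict and `bfs_steps_to_goal` are illustrative and not part of the tutorial code. It assumes moves are reversible, so distances measured outward from the goal equal distances to the goal, which holds for moves between free cells in these GridWorlds.

```python
from collections import deque
from typing import Dict, Hashable, List


def bfs_steps_to_goal(neighbours: Dict[Hashable, List[Hashable]],
                      goal: Hashable) -> Dict[Hashable, int]:
  """Breadth-first search outward from `goal`; returns steps from each reachable node."""
  steps = {goal: 0}
  queue = deque([goal])
  while queue:                    # keep expanding until no frontier nodes remain
    node = queue.popleft()        # first in, first out: closest nodes are expanded first
    for nbr in neighbours[node]:
      if nbr not in steps:        # first visit = shortest distance in an unweighted graph
        steps[nbr] = steps[node] + 1
        queue.append(nbr)
  return steps


# Hypothetical 2x2 room: cells are (row, col); each cell lists the cells it can move to.
room = {
    (0, 0): [(0, 1), (1, 0)],
    (0, 1): [(0, 0), (1, 1)],
    (1, 0): [(0, 0), (1, 1)],
    (1, 1): [(0, 1), (1, 0)],
}
print(bfs_steps_to_goal(room, goal=(1, 1)))
# → {(1, 1): 0, (0, 1): 1, (1, 0): 1, (0, 0): 2}
```

Marking a cell as visited when it is first queued guarantees that each cell is expanded exactly once.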

class GridWorldPlanner(GridWorldBase):
  """A GridWorld that finds a better policy."""

  def plan(self):
    super().plan()
    goal_queue = [self.goal_cell]
    goals_done = set()

    ##############################################################
    ## TODO: replace the '...' with the correct loop condition,
    ## then remove this error.
    raise NotImplementedError("Fill in the while-loop condition!")
    ##############################################################
    while ...:
      goal = goal_queue.pop(0)  # pop from front of list
      goal_neighbours = self.get_neighbours(goal)
      goals_done.add(tuple(goal))
      for a in goal_neighbours:
        nbr = tuple(goal_neighbours[a])
        if nbr == goal:
          continue
        if nbr not in goals_done:
          if a == '<':
            self.policy[nbr[0], nbr[1]] = '>'
          elif a == '>':
            self.policy[nbr[0], nbr[1]] = '<'
          elif a == '^':
            self.policy[nbr[0], nbr[1]] = 'v'
          else:
            self.policy[nbr[0], nbr[1]] = '^'
          goal_queue.append(nbr)


gwp = GridWorldPlanner()
print('Simple GridWorld:')
gwp.draw()
gwp.plan()
print('Better policy:')
gwp.draw(True)

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Make_a_better_planner_Exercise")

Try it out on a harder problem

# @title Try it out on a harder problem
harder_grid = np.array(
    [['*', '*', '*', '*', '*', '*', '*', '*', '*'],
     ['*', ' ', ' ', ' ', '*', ' ', ' ', 'g', '*'],
     ['*', ' ', ' ', ' ', '*', ' ', ' ', ' ', '*'],
     ['*', ' ', ' ', ' ', '*', ' ', ' ', ' ', '*'],
     ['*', ' ', ' ', ' ', '*', ' ', ' ', ' ', '*'],
     ['*', ' ', ' ', ' ', ' ', ' ', ' ', ' ', '*'],
     ['*', '*', '*', '*', '*', '*', '*', '*', '*'],
    ]
)
gwb_2 = GridWorldBase(harder_grid)
gwp_2 = GridWorldPlanner(harder_grid)
print('Harder GridWorld:')
gwb_2.draw()
gwb_2.plan()
print('Random policy:')
gwb_2.draw(True)
print('Better policy:')
gwp_2.plan()
gwp_2.draw(True)

Section 2.2: Markov Decision Process (MDP)

RL problems are traditionally formulated as Markov decision processes (MDPs). In this section, we introduce the necessary notation and write code to define the MDP corresponding to our simple GridWorld.
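As a compact reference, this is how the arrays `P` and `R` in the exercise code encode the MDP. It is a summary of the conventions used here (deterministic GridWorld moves, reward 1 for landing on the goal), not a general definition:

```latex
\begin{aligned}
P[s, a, s'] &= \Pr(s_{t+1} = s' \mid s_t = s,\ a_t = a), \qquad \sum_{s'} P[s, a, s'] = 1, \\
P[s, a, s'] &= \begin{cases} 1 & \text{if action } a \text{ moves the agent from } s \text{ to } s' \text{ (staying put at walls)}, \\ 0 & \text{otherwise,} \end{cases} \\
R[s, a] &= \begin{cases} 1 & \text{if action } a \text{ moves the agent from } s \text{ onto the goal}, \\ 0 & \text{otherwise.} \end{cases}
\end{aligned}
```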

Extra: Martin Puterman’s book on Markov Decision Processes is an excellent reference if you’d like to read more.

Video 3: Markov Decision Process

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Markov_Decision_Process_Video")

Before you start: understanding states, P, and R

Before coding, let’s make sure you understand what \(P\) and \(R\) should look like.
Consider this tiny GridWorld (walls *, empty cells ·, goal g):

* * * * *
* g · · *
* · · · *
* * * * *

States are numbered left-to-right, top-to-bottom (skipping walls):

 g=0   1   2
   3   4   5

For state 1 (one step to the right of the goal), the four actions give:

| Action | Next cell | Next state s' | P[1, a, s'] | R[1, a] |
|---|---|---|---|---|
| < left | goal cell | 0 (goal!) | 1.0 | 1.0 |
| > right | row 1, col 3 | 2 | 1.0 | 0.0 |
| v down | row 2, col 2 | 4 | 1.0 | 0.0 |
| ^ up | wall → stays | 1 | 1.0 | 0.0 |

Key helpers already available to you:

  • self.state_to_cell[s] → (row, col) for state s

  • grid_world.get_neighbours((row, col)) → dict mapping each action to its next (row, col), already handling wall-bouncing

  • self.cell_to_state[row, col] → state index for that cell

Your task: fill in the inner loop so P and R are populated as shown above.
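To check your understanding before the exercise, this standalone sketch rebuilds the table for state 1 of the tiny GridWorld. It re-declares `ACTIONS` and `ACTION_EFFECTS` with the same values as the tutorial so it runs on its own; `tiny_spec` and the list/dict mappings are hypothetical stand-ins for `self.state_to_cell` and `self.cell_to_state`, and wall-bouncing is done inline instead of through `get_neighbours`.

```python
import numpy as np

# The tiny GridWorld from the text: '*' = wall, ' ' = empty, 'g' = goal.
tiny_spec = np.array([['*', '*', '*', '*', '*'],
                      ['*', 'g', ' ', ' ', '*'],
                      ['*', ' ', ' ', ' ', '*'],
                      ['*', '*', '*', '*', '*']])
ACTIONS = ['<', '>', 'v', '^']
ACTION_EFFECTS = {'<': (0, -1), '>': (0, 1), 'v': (1, 0), '^': (-1, 0)}

# Number the non-wall cells left-to-right, top-to-bottom (np.argwhere is row-major).
state_to_cell = [tuple(map(int, c)) for c in np.argwhere(tiny_spec != '*')]
cell_to_state = {cell: s for s, cell in enumerate(state_to_cell)}
goal_state = cell_to_state[tuple(map(int, np.argwhere(tiny_spec == 'g')[0]))]

num_states, num_actions = len(state_to_cell), len(ACTIONS)
P = np.zeros((num_states, num_actions, num_states))
R = np.zeros((num_states, num_actions))
for s, (i, j) in enumerate(state_to_cell):
  for a, action in enumerate(ACTIONS):
    di, dj = ACTION_EFFECTS[action]
    nbr = (i + di, j + dj)
    if tiny_spec[nbr] == '*':   # bumping into a wall leaves the agent where it is
      nbr = (i, j)
    s2 = cell_to_state[nbr]
    P[s, a, s2] = 1.0           # deterministic transition
    R[s, a] = 1.0 if s2 == goal_state else 0.0

print(P[1].argmax(axis=-1))  # next state for <, >, v, ^ from state 1 → [0 2 4 1]
print(R[1])                  # → [1. 0. 0. 0.]
```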

Coding exercise 2: Create an MDP from the GridWorld specification

Create \(P\) and \(R\) matrices for the MDP.

class MDPBase(object):
  """Creates a proper MDP from a GridWorld object."""

  def __init__(self, grid_world: GridWorldBase):
    """Constructs an MDP from a GridWorldBase object.

    Args:
      grid_world: GridWorld specification.
    """
    # Determine how many valid states there are and create empty matrices.
    self.grid_world = grid_world
    self.num_states = np.sum(grid_world.world_spec != '*')
    self.num_actions = len(ACTIONS)
    self.P = np.zeros((self.num_states, self.num_actions, self.num_states))
    self.R = np.zeros((self.num_states, self.num_actions))
    self.pi = np.zeros(self.num_states, dtype=np.int32)

    # Build mapping between cell positions and state IDs (left→right, top→bottom).
    state_idx = 0
    self.cell_to_state = np.ones(grid_world.world_spec.shape, dtype=np.int32) * -1
    self.state_to_cell = {}
    for i, row in enumerate(grid_world.world_spec):
      for j, cell in enumerate(row):
        if cell == '*':
          continue
        if cell == 'g':
          self.goal_state = state_idx
        self.cell_to_state[i, j] = state_idx
        self.state_to_cell[state_idx] = (i, j)
        state_idx += 1

    # Populate P and R.
    for s in range(self.num_states):
      cell = self.state_to_cell[s]                     # (row, col) of state s
      neighbours = grid_world.get_neighbours(cell)      # dict: action → next (row,col)
      for a, action in enumerate(neighbours):
        ##############################################################
        ## TODO: fill in nbr and s2, then remove this error.
        raise NotImplementedError("Populate P and R!")
        ##############################################################
        nbr = ...
        s2  = self.cell_to_state[..., ...]
        self.P[s, a, s2] = 1.0
        if s2 == self.goal_state:
          self.R[s, a] = 1.0

  def draw(self, include_policy: bool = False):
    for s in range(self.num_states):
      r, c = self.state_to_cell[s]
      self.grid_world.policy[r, c] = ACTIONS[self.pi[s]]
    self.grid_world.draw(include_policy)

  def plan(self):
    goal_queue = [self.goal_state]
    goals_done = set()
    while goal_queue:
      goal = goal_queue.pop(0)
      nbr_states, nbr_actions = np.where(self.P[:, :, goal] > 0.)
      goals_done.add(goal)
      for s, a in zip(nbr_states, nbr_actions):
        if s == goal:
          continue
        if s not in goals_done:
          self.pi[s] = a
          goal_queue.append(s)


mdpb = MDPBase(gwb)

# Verify the transitions were properly created.
for i, row in enumerate(mdpb.grid_world.world_spec):
  for j, cell in enumerate(row):
    if cell == '*':
      continue
    neighbours = mdpb.grid_world.get_neighbours((i, j))
    s = mdpb.cell_to_state[i, j]
    for a, action in enumerate(neighbours):
      nbr = neighbours[action]
      s2 = mdpb.cell_to_state[nbr[0], nbr[1]]
      assert np.sum(mdpb.P[s, a, :]) == 1.0
      assert mdpb.P[s, a, s2] == 1.0
      if s2 == mdpb.goal_state:
        assert mdpb.R[s, a] == 1.0
      else:
        assert mdpb.R[s, a] == 0.0

print('P and R matrices successfully created!')
print('GridWorld:')
mdpb.draw()
print('Shortest path policy:')
mdpb.plan()
mdpb.draw(True)

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Create_an_MDP_Exercise")

Section 2.3: \(Q\)-values

\(Q\)-values are central to RL algorithms, as they quantify the desirability of performing an action given a particular state. The agent updates these values throughout training and can use its estimates to decide how to act.
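For reference, the standard definition of the \(Q\)-value of a deterministic policy \(\pi\), with a discount factor \(\gamma \in [0, 1)\) (the discount reappears in value iteration), is the expected discounted return from taking action \(a\) in state \(s\) and following \(\pi\) afterwards. The steps-to-go values in the next exercise are a simpler, undiscounted stand-in that is minimized rather than maximized.

```latex
Q^\pi(s, a) = \mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^t R(s_t, a_t) \,\middle|\, s_0 = s,\ a_0 = a,\ a_t = \pi(s_t) \text{ for } t \ge 1\right],
\qquad V^\pi(s) = Q^\pi\big(s, \pi(s)\big).
```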

Video 4: Q-values

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Q_values_Video")

Coding exercise 3: Create a steps-to-go solver

Create a new MDP class that holds steps-to-go as Q-values.
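To see why an argmin over steps-to-go recovers a shortest-path policy, here is a hand-filled \(Q\) table for a hypothetical 4-state corridor (states 0 to 3, goal at 3, actions `<` and `>`). It follows the conventions hinted at in the exercise: 0 for an action that keeps the agent on the goal, `np.inf` for an action that bumps into a wall elsewhere (an infinite loop), and otherwise one step plus the steps-to-go of the next state.

```python
import numpy as np

# Hypothetical 4-state corridor: states 0..3, goal = 3, actions 0 = '<', 1 = '>'.
# Steps-to-go from each state: V = [3, 2, 1, 0]; Q[s, a] = 1 + V[next state].
Q = np.array([
    [np.inf, 3.0],  # state 0: '<' hits the wall (stuck forever); '>' reaches the goal in 3
    [4.0,    2.0],  # state 1: '<' goes back to state 0 first
    [3.0,    1.0],  # state 2: '>' enters the goal in 1 step
    [2.0,    0.0],  # state 3 (goal): '>' bumps the wall and stays on the goal
])
pi = np.argmin(Q, axis=-1)  # fewest steps-to-go wins
print(pi)  # → [1 1 1 1]
```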

class MDPToGo(MDPBase):

  def __init__(self, grid_world: GridWorldBase):
    """Constructs an MDP from a GridWorldBase object.

    States should be numbered from left-to-right and from top-to-bottom.

    Args:
      grid_world: GridWorld specification.
    """
    super().__init__(grid_world)
    self.Q = np.zeros((self.num_states, self.num_actions))

  def computeQ(self):
    """Store steps-to-go in an SxA matrix called Q.

    This matrix will then be used to extract the optimal policy.
    """
    #################################################
    # Implement a function to compute Q
    raise NotImplementedError("Implement `computeQ` function!")
    #################################################
    goal_queue = [(self.goal_state, 0)]  # (state, steps taken so far)
    goals_done = set()
    while goal_queue:
      goal, steps_to_go = goal_queue.pop(0)
      steps_to_go += 1
      nbr_states, nbr_actions = np.where(self.P[:, :, goal] > 0.)
      goals_done.add(goal)

      for s, a in zip(nbr_states, nbr_actions):
        if goal == self.goal_state and s == self.goal_state:
          self.Q[s, a] = ...        # action at goal that stays at goal
        elif s == goal:
          self.Q[s, a] = ...        # If (s, a) leads to itself then we have an infinite loop
        else:
          self.Q[s, a] = ...        # normal: steps_to_go steps to goal
        if s not in goals_done:
          goal_queue.append((s, steps_to_go))

  def plan(self):
    """Now planning is just doing an argmin over the Q-values!

    Note that this is a little different from standard Q-learning (where we do
    an argmax), since our Q-values currently store steps-to-go.
    """
    self.pi = np.argmin(self.Q, axis=-1)


mdpTg = MDPToGo(gwb)
print('GridWorld:')
mdpTg.draw()
# Compute Q, then extract policy from it.
mdpTg.computeQ()
mdpTg.plan()
print('Optimal policy:')
mdpTg.draw(True)

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Create_a_step_to_go_solver_Exercise")

Section 3: Value iteration

Value iteration is an iterative algorithm that repeatedly improves its estimates of \(Q\) and \(V\) by applying the Bellman backup. It assumes access to \(P\) and \(R\) (which are not typically available in RL), but it is the backbone of \(Q\)-learning, which we will discuss later.
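Concretely, one Bellman optimality backup updates every entry of \(Q\) at once: \(Q_{k+1}(s,a) = R(s,a) + \gamma \sum_{s'} P(s' \mid s,a) \max_{a'} Q_k(s',a')\). The sketch below applies it until successive estimates agree to within a tolerance, on a hypothetical 3-state corridor whose arrays use the same shapes as the tutorial's `P` (S×A×S) and `R` (S×A). The function `value_iteration` and the corridor are illustrative, not part of the tutorial code, and `np.einsum` is just one way to compute the sum over \(s'\).

```python
import numpy as np


def value_iteration(P: np.ndarray, R: np.ndarray, gamma: float = 0.99,
                    tol: float = 1e-5) -> np.ndarray:
  """Repeat the Bellman optimality backup until Q changes by less than `tol`."""
  Q = np.zeros_like(R)
  while True:
    V = Q.max(axis=-1)                                 # max_a' Q_k(s', a')
    new_Q = R + gamma * np.einsum('sat,t->sa', P, V)   # R(s,a) + γ Σ_s' P(s'|s,a) V(s')
    if np.max(np.abs(new_Q - Q)) < tol:
      return new_Q
    Q = new_Q


# Hypothetical 3-state corridor: states 0, 1, 2 with the goal at 2; actions 0 = '<', 1 = '>'.
# As in the GridWorld MDP, reward 1 is given for any transition that lands on the goal.
next_state = np.array([[0, 1], [0, 2], [1, 2]])  # next_state[s, a]; walls keep the agent in place
P = np.zeros((3, 2, 3))
for s in range(3):
  for a in range(2):
    P[s, a, next_state[s, a]] = 1.0
R = (next_state == 2).astype(float)

Q = value_iteration(P, R)
print(Q.max(axis=-1).round(2))  # V: about 99 for state 0, about 100 for states 1 and 2
print(Q.argmax(axis=-1))        # greedy policy: '>' everywhere
```

The goal and its neighbour both end up with a value of about 100: the goal's wall-bumping action keeps the agent on the goal, collecting reward 1 every step, so \(V(\text{goal}) = 1/(1-\gamma) = 100\) for \(\gamma = 0.99\).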


Did you know? Richard Bellman developed dynamic programming (a core part of any computer science curriculum) to solve multistage decision problems like this one; value iteration is one of its classic algorithms.

Video 5: Value iteration

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Value_Iteration_Video")

Coding exercise 4: Implement value iteration

Create a new MDP class that does value iteration.

class MDPValueIteration(MDPToGo):

  def __init__(self, grid_world: GridWorldBase, gamma: float = 0.99):
    """Constructs an MDP from a GridWorldBase object.

    States should be numbered from left-to-right and from top-to-bottom.

    Args:
      grid_world: GridWorld specification.
      gamma: Discount factor.
    """
    super().__init__(grid_world)
    self.gamma = gamma

  def computeQ(self, error_tolerance: float = 1e-5):
    """Compute optimal Q and V via value iteration.
       Args:
          error_tolerance: How much error we tolerate between successive Q updates.
    """
    self.Q = np.zeros((self.num_states, self.num_actions))
    num_iterations = 0
    error = np.inf
    #################################################
    # Write this method!
    # First find Q, and then extract V from Q.
    # Hint: Use matrix multiplication instead of for loops!
    raise NotImplementedError("Implement `computeQ` function!")
    #################################################
    while error > error_tolerance:
      new_Q = np.zeros_like(self.Q)
      max_next_Q = ...
      for a in range(self.num_actions):
        new_Q[:, a] = ...
      error = np.max(abs(new_Q - self.Q))
      self.Q = np.copy(new_Q)
      num_iterations += 1

    self.V = np.max(self.Q, axis=-1)
    print(f'Converged in {num_iterations} iterations (tolerance {error_tolerance}).')

  def plan(self):
    """Now planning is just doing an argmax over the Q-values!
    """
    #################################################
    # Note that we're going back to argmax, since the Q-values now represent proper
    # "returns-to-go", so we want to maximize that.
    # Write this method! It should be a one-liner, and very similar to what you
    # used for extracting V from Q.
    raise NotImplementedError("Implement `plan` function!")
    #################################################
    self.pi = ...

  def _draw_v(self):
    """Draw the V values."""
    min_v = np.min(self.V)
    max_v = np.max(self.V)
    wall_v = 2 * min_v - max_v  # Creating a smaller value for walls.
    grid_values = np.ones_like(self.grid_world.world_spec, dtype=np.int32) * wall_v
    # Fill in the V values in grid cells.
    for s in range(self.num_states):
      cell = self.state_to_cell[s]
      grid_values[cell[0], cell[1]] = self.V[s]

    fig, ax = plt.subplots()
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    grid = ax.matshow(grid_values)
    grid.set_clim(wall_v, max_v)
    fig.colorbar(grid)

  def draw(self, draw_mode: str = 'grid'):
    """Draw the GridWorld according to specified mode.

    Args:
      draw_mode: Specification of what mode to draw. Supported options:
                 'grid': Draw the base GridWorld.
                 'policy': Display the policy.
                 'values': Display the values for each state.
    """
    # First make sure we convert our MDP policy into the GridWorld policy.
    if draw_mode == 'values':
      self._draw_v()
    else:
      super().draw(draw_mode == 'policy')


mdpVi = MDPValueIteration(gwb)
print('GridWorld:')
mdpVi.draw()
# Compute Q, then extract policy from it.
mdpVi.computeQ()
mdpVi.plan()
print('Optimal policy:')
mdpVi.draw('policy')
mdpVi.draw('values')

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Implement_Value_Iteration_Exercise")

Section 4: Policy Iteration

Rather than iterating on value estimates until convergence, policy iteration iterates directly on the policy \(\pi\):

  1. Policy evaluation: given \(\pi\), compute \(Q^\pi\) via the Bellman equation \(Q^\pi(s,a) = R(s,a) + \gamma \sum_{s'} P(s' \mid s,a)\,Q^\pi(s',\pi(s'))\)

  2. Policy improvement: extract the greedy policy \(\pi'(s) = \arg\max_a Q^\pi(s,a)\)

  3. Repeat until \(\pi' = \pi\).

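Policy iteration often needs fewer iterations than value iteration, because each iteration evaluates the current policy before improving it instead of performing a single Bellman backup (each iteration is more expensive, though). The exercise below simplifies the evaluation step to one Bellman evaluation backup per iteration, starting from the previous \(Q\), and stops as soon as the greedy policy no longer changes.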
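For comparison, here is a sketch of policy iteration with exact policy evaluation: instead of iterating the Bellman equation, it solves the linear system \(V^\pi = R_\pi + \gamma P_\pi V^\pi\) with `np.linalg.solve`, where \(P_\pi[s, s'] = P[s, \pi(s), s']\) and \(R_\pi[s] = R[s, \pi(s)]\). The function name and the 3-state corridor are hypothetical, chosen only to keep the example self-contained.

```python
import numpy as np


def policy_iteration(P: np.ndarray, R: np.ndarray, gamma: float = 0.99):
  """Policy iteration with exact evaluation. P: (S, A, S), R: (S, A).

  Returns (pi, Q, number of evaluation/improvement rounds).
  """
  num_states = R.shape[0]
  states = np.arange(num_states)
  pi = np.zeros(num_states, dtype=int)
  num_iterations = 0
  while True:
    num_iterations += 1
    # Policy evaluation: solve (I - gamma * P_pi) V = R_pi exactly.
    P_pi = P[states, pi]   # (S, S): transition matrix under pi
    R_pi = R[states, pi]   # (S,): reward under pi
    V = np.linalg.solve(np.eye(num_states) - gamma * P_pi, R_pi)
    Q = R + gamma * np.einsum('sat,t->sa', P, V)  # Q^pi(s,a) = R(s,a) + γ Σ_s' P(s'|s,a) V^pi(s')
    # Policy improvement: act greedily with respect to Q^pi.
    new_pi = Q.argmax(axis=-1)
    if np.array_equal(new_pi, pi):
      return pi, Q, num_iterations
    pi = new_pi


# Hypothetical 3-state corridor: goal at state 2; actions 0 = '<', 1 = '>'.
next_state = np.array([[0, 1], [0, 2], [1, 2]])  # walls keep the agent in place
P = np.zeros((3, 2, 3))
for s in range(3):
  for a in range(2):
    P[s, a, next_state[s, a]] = 1.0
R = (next_state == 2).astype(float)  # reward 1 for landing on the goal

pi, Q, num_iterations = policy_iteration(P, R)
print(pi, num_iterations)
```

On this corridor it should return the all-`>` policy after three evaluation/improvement rounds. Solving the linear system costs roughly \(O(|S|^3)\) per evaluation, which is cheap for small GridWorlds; for large state spaces, iterative evaluation is the usual alternative.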


Video 6: Policy iteration

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Policy_Iteration_Video")

Coding exercise 5: Implement policy iteration

Create a new MDP class that does policy iteration.

class MDPPolicyIteration(MDPToGo):

  def __init__(self, grid_world: GridWorldBase, gamma: float = 0.99):
    """Constructs an MDP from a GridWorldBase object.

    States should be numbered from left-to-right and from top-to-bottom.

    Args:
      grid_world: GridWorld specification.
      gamma: Discount factor.
    """
    super().__init__(grid_world)
    self.gamma = gamma

  def findPi(self):
    """Find the optimal policy via policy iteration."""
    self.Q  = np.zeros((self.num_states, self.num_actions))
    self.pi = np.zeros(self.num_states, dtype=np.int32)
    num_iterations = 0
    #################################################
    # Compute π, which involves computing Q.
    # Once you have π and Q, find V.
    # Hint: Your value iteration solution will be useful here.
    raise NotImplementedError("Implement `findPi` function!")
    #################################################
    new_pi = np.ones_like(self.pi)
    while np.any(new_pi != self.pi):
      new_pi  = np.copy(self.pi)   # save the current policy before updating
      new_Q   = ...                # initialize to zeros
      next_V  = self.Q[np.arange(self.num_states), self.pi]  # V under the current policy: V(s') = Q(s', pi(s'))
      for a in range(self.num_actions):
        new_Q[:, a] = ...          # Bellman evaluation: R + γ P V
      self.Q  = np.copy(new_Q)
      self.pi = ...                # greedy improvement
      num_iterations += 1

    self.V = np.max(self.Q, axis=-1)
    print(f'Policy iteration converged in {num_iterations} iteration(s).')

  def _draw_v(self):
    """Draw the V values."""
    min_v = np.min(self.V)
    max_v = np.max(self.V)
    wall_v = 2 * min_v - max_v  # Creating a smaller value for walls.
    grid_values = np.ones_like(self.grid_world.world_spec, dtype=np.int32) * wall_v
    # Fill in the V values in grid cells.
    for s in range(self.num_states):
      cell = self.state_to_cell[s]
      grid_values[cell[0], cell[1]] = self.V[s]

    fig, ax = plt.subplots()
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    grid = ax.matshow(grid_values)
    grid.set_clim(wall_v, max_v)
    fig.colorbar(grid)

  def draw(self, draw_mode: str = 'grid'):
    """Draw the GridWorld according to specified mode.

    Args:
      draw_mode: Specification of what mode to draw. Supported options:
                 'grid': Draw the base GridWorld.
                 'policy': Display the policy.
                 'values': Display the values for each state.
    """
    # First make sure we convert our MDP policy into the GridWorld policy.
    if draw_mode == 'values':
      self._draw_v()
    else:
      super().draw(draw_mode == 'policy')


mdpPi = MDPPolicyIteration(gwb)
print('GridWorld:')
mdpPi.draw()
mdpPi.findPi()
print('Optimal policy (policy iteration):')
mdpPi.draw('policy')
mdpPi.draw('values')

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Implement_Policy_Iteration_Exercise")

Summary

Reinforcement learning is important in artificial intelligence because it allows an agent to learn how to make decisions in complex environments through trial and error. Understanding how an agent learns, and the algorithms that support this learning, is crucial for developing effective RL systems. In this tutorial, you built the foundations of reinforcement learning from the ground up:

  • GridWorld gives a concrete environment in which an agent navigates by taking actions and receiving rewards. The same principles carry over to much more complex tasks.

  • A Markov Decision Process (MDP) formally describes that environment: states \(S\), actions \(A\), transition probabilities \(P\), and a reward function \(R\).

  • \(Q\)-values quantify how desirable each \((s, a)\) pair is. When they store steps-to-go, fewer steps is better (argmin); when they store discounted returns, as in value and policy iteration, higher is better (argmax).

  • Value iteration and policy iteration compute optimal policies when \(P\) and \(R\) are known, leading the agent to the greatest cumulative (discounted) reward over time.