
Using RL to Model Cognitive Tasks#

By Neuromatch Academy

Content creators: Morteza Ansarinia, Yamil Vidal

Production editor: Spiros Chavlis


  • This project aims to use behavioral data to train an agent and then use the agent to investigate data produced by human subjects. Having a computational agent that mimics humans in such tests, we will be able to compare its mechanics with human data.

  • In another conception, we could fit an agent that learns many cognitive tasks that require abstract-level constructs such as executive functions. This is a multi-task control problem.


Install dependencies#

# @title Install dependencies
!pip install jedi --quiet
!pip install --upgrade pip setuptools wheel --quiet
!pip install dm-acme[jax] --quiet
!pip install dm-sonnet --quiet
!pip install trfl --quiet
!pip install numpy==1.23.3 --quiet --ignore-installed
!pip uninstall seaborn -y --quiet
!pip install seaborn --quiet
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
numba 0.56.4 requires numpy<1.24,>=1.18, but you have numpy 1.25.1 which is incompatible.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.81 requires numpy>=1.25.0, but you have numpy 1.23.3 which is incompatible.

# Imports
import time
import numpy as np
import pandas as pd
import sonnet as snt
import seaborn as sns
import matplotlib.pyplot as plt

import dm_env

import acme
from acme import specs
from acme import wrappers
from acme import EnvironmentLoop
from acme.agents.tf import dqn
from acme.utils import loggers

Figure settings#

# @title Figure settings
from IPython.display import clear_output, display, HTML
%matplotlib inline


  • Cognitive scientists use standard lab tests to tap into specific processes in the brain and behavior. Some examples of those tests are Stroop, N-back, Digit Span, TMT (Trail Making Test), and WCST (Wisconsin Card Sorting Test).

  • Despite an extensive body of research that explains human performance using descriptive what-models, we still need a more sophisticated approach to gain a better understanding of the underlying processes (i.e., a how-model).

  • Interestingly, many such tests can be thought of as a continuous stream of stimuli and corresponding actions, which is consonant with the RL formulation. In fact, RL itself is in part motivated by how the brain enables goal-directed behaviors using reward systems, making it a good choice to explain human performance.

  • One behavioral test example would be the N-back task.

    • In the N-back, participants view a sequence of stimuli, one by one, and are asked to categorize each stimulus as either a match or a non-match. Stimuli are usually numbers, and feedback is given at both the timestep and trajectory levels.

    • The agent is rewarded when its response matches the stimulus that was shown N steps back in the episode. A simpler version of the N-back uses a two-choice action schema, that is, match vs. non-match. When the present stimulus matches the one presented N steps back, the agent is expected to respond with a match.

  • Given a trained RL agent, we can then look for correlates of its fitted parameters in brain mechanisms. The most straightforward approach would be to correlate the model parameters with brain activity.


Any dataset that used cognitive tests would work.

  • Question: should we limit ourselves to behavioral data, or also use fMRI?

  • Question: which stimuli and actions should we use? Classic tests can be modeled using bounded symbolic stimuli/actions (e.g., A, B, C), but more sophisticated ones would require text or images (e.g., face vs. neutral images in the social Stroop dataset).

The HCP dataset from NMA-CN contains behavioral and imaging data for 7 cognitive tests, including various versions of the N-back.

N-back task#

In the N-back task, participants view a sequence of stimuli, one at a time, and are asked to categorize each stimulus as either a match or a non-match. Stimuli are usually numbers, and feedback is given at both the timestep and trajectory levels.

In a typical neuro setup, both accuracy and response time are measured, but here, for the sake of brevity, we focus only on accuracy of responses.
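The matching rule described above can be sketched in a few lines of plain Python. This is only an illustration, not the environment used later; the helper name `expected_responses` is hypothetical, and labeling the first N trials as `None` is an assumption made for this sketch.

```python
from typing import List, Optional, Sequence


def expected_responses(stimuli: Sequence[str], n: int = 2) -> List[Optional[str]]:
    """Label each trial as 'match' or 'non-match' under the N-back rule.

    The first `n` trials have nothing to compare against, so they are
    labeled None (an assumption made for this sketch).
    """
    labels: List[Optional[str]] = []
    for i, stimulus in enumerate(stimuli):
        if i < n:
            labels.append(None)
        elif stimulus == stimuli[i - n]:
            labels.append('match')
        else:
            labels.append('non-match')
    return labels


# 2-back on the sequence A, D, A, F, D
labels = expected_responses(list('ADAFD'), n=2)
print(labels)
```

Only the third trial is a match here, because its stimulus (A) equals the one shown two trials earlier.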

Cognitive Tests Environment#

First, we develop an environment in which agents perform a cognitive test, here the N-back.

Human dataset#

We need a dataset of humans performing an N-back test, with the following features:

  • participant_id: following the BIDS format, it contains a unique identifier for each participant.

  • trial_index: same as time_step.

  • stimulus: same as observation.

  • response: same as action, recorded response by the human subject.

  • expected_response: correct response.

  • is_correct: same as reward, whether the human subject responded correctly.

  • response_time: won’t be used here.

Here we generate a mock dataset with those features, but remember to replace this with real human data.

def generate_mock_nback_dataset(N=2,
                                n_participants=10,  # assumed default; not shown in the source
                                n_trials=32,
                                stimulus_choices=list('ABCDEF'),
                                response_choices=['match', 'non-match']):
  """Generate a mock dataset for the N-back task."""

  n_rows = n_participants * n_trials

  participant_ids = sorted([f'sub-{pid}' for pid in range(1, n_participants + 1)] * n_trials)
  trial_indices = list(range(1, n_trials + 1)) * n_participants
  stimulus_sequence = np.random.choice(stimulus_choices, n_rows)

  responses = np.random.choice(response_choices, n_rows)
  response_times = np.random.exponential(size=n_rows)

  df = pd.DataFrame({
      'participant_id': participant_ids,
      'trial_index': trial_indices,
      'stimulus': stimulus_sequence,
      'response': responses,
      'response_time': response_times
  })

  # mark matching stimuli
  _nback_stim = df['stimulus'].shift(N)
  df['expected_response'] = (df['stimulus'] == _nback_stim).map({True: 'match', False: 'non-match'})

  df['is_correct'] = (df['response'] == df['expected_response'])

  # we don't care about burn-in trials (trial_index <= N)
  df.loc[df['trial_index'] <= N, 'is_correct'] = True
  df.loc[df['trial_index'] <= N, ['response', 'response_time', 'expected_response']] = None

  return df

# ========
# now generate the actual data with the provided function and plot some of its features
mock_nback_data = generate_mock_nback_dataset()
mock_nback_data['is_correct'] = mock_nback_data['is_correct'].astype(int)

sns.displot(data=mock_nback_data, x='response_time')
plt.suptitle('response time distribution of the mock N-back dataset', y=1.01)

sns.displot(data=mock_nback_data, x='is_correct')
plt.suptitle('Accuracy distribution of the mock N-back dataset', y=1.06)

plt.figure()
sns.barplot(data=mock_nback_data, y='is_correct', x='participant_id')
plt.suptitle('Accuracy per participant in the mock N-back dataset', y=1.06)
plt.show()

mock_nback_data.head()

[Figures: response time distribution, accuracy distribution, and per-participant accuracy of the mock N-back dataset]
  participant_id  trial_index stimulus   response  response_time expected_response  is_correct
0          sub-1            1        A       None            NaN              None           1
1          sub-1            2        D       None            NaN              None           1
2          sub-1            3        A  non-match       1.697356             match           0
3          sub-1            4        F  non-match       0.149110         non-match           1
4          sub-1            5        D  non-match       0.277760         non-match           1
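Because the burn-in trials are marked as correct by construction, averaging `is_correct` over all trials overstates accuracy. A minimal sketch of a per-participant accuracy that skips those trials is shown below. It uses plain dictionaries in the same layout as the table above; the rows for `sub-2` are made-up illustrative values, and the helper name `accuracy_per_participant` is hypothetical.

```python
from collections import defaultdict

# Rows in the same layout as the dataset (sub-1 follows the table above,
# sub-2 is made up for illustration).
rows = [
    {'participant_id': 'sub-1', 'trial_index': 1, 'response': None, 'expected_response': None},
    {'participant_id': 'sub-1', 'trial_index': 2, 'response': None, 'expected_response': None},
    {'participant_id': 'sub-1', 'trial_index': 3, 'response': 'non-match', 'expected_response': 'match'},
    {'participant_id': 'sub-1', 'trial_index': 4, 'response': 'non-match', 'expected_response': 'non-match'},
    {'participant_id': 'sub-1', 'trial_index': 5, 'response': 'non-match', 'expected_response': 'non-match'},
    {'participant_id': 'sub-2', 'trial_index': 1, 'response': None, 'expected_response': None},
    {'participant_id': 'sub-2', 'trial_index': 2, 'response': None, 'expected_response': None},
    {'participant_id': 'sub-2', 'trial_index': 3, 'response': 'match', 'expected_response': 'match'},
    {'participant_id': 'sub-2', 'trial_index': 4, 'response': 'match', 'expected_response': 'non-match'},
]


def accuracy_per_participant(rows, n=2):
    """Mean accuracy per participant, skipping the first `n` burn-in trials."""
    hits = defaultdict(int)
    counts = defaultdict(int)
    for row in rows:
        if row['trial_index'] <= n:
            continue  # burn-in trial: no valid response to score
        counts[row['participant_id']] += 1
        hits[row['participant_id']] += int(row['response'] == row['expected_response'])
    return {pid: hits[pid] / counts[pid] for pid in counts}


accuracy = accuracy_per_participant(rows, n=2)
print(accuracy)
```

With a DataFrame, the same idea would be a filter on `trial_index` followed by a `groupby('participant_id')` mean of `is_correct`.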

Implementation scheme#


The following cell implements the N-back environment, which we later use to train an RL agent on human data. It is capable of performing two kinds of simulation:

  • rewards the agent once the action was correct (i.e., a normative model of the environment).

  • receives human data (or mock data if you prefer), and returns what participants performed as the observation. This is more useful for preference-based RL.
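The difference between the two modes comes down to which response the agent is scored against. The sketch below isolates that choice in a standalone function; the name `step_reward` and its signature are hypothetical, and the 0 / -1 reward values are taken from the environment code in this section.

```python
def step_reward(agent_action, expected_action, human_action=None):
    """Score one response: 0.0 if it hits the target, -1.0 otherwise.

    If `human_action` is given, the target is the human's response
    (imitation mode); otherwise it is the correct response (normative mode).
    """
    target = human_action if human_action is not None else expected_action
    return 0.0 if agent_action == target else -1.0


# Normative mode: the agent is correct, so it is not penalized.
normative = step_reward('match', expected_action='match')

# Imitation mode: the same response is penalized because the human erred.
imitation = step_reward('match', expected_action='match', human_action='non-match')

print(normative, imitation)
```

An agent trained in imitation mode is therefore pushed to reproduce human errors as well as human successes.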

class NBack(dm_env.Environment):

  ACTIONS = ['match', 'non-match']

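  def __init__(self,
               N=2,
               episode_steps=32,
               stimuli_choices=list('ABCDEF'),
               human_data=None):
    """
    Args:
      N: Number of steps to look back for the matched stimuli. Defaults to 2 (as in 2-back).
      episode_steps: Number of trials in each episode.
      stimuli_choices: Symbols from which the stimuli are drawn.
      human_data: Optional DataFrame of human trials. If given, the agent is rewarded for imitating it.
    """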

    self.N = N
    self.episode_steps = episode_steps
    self.stimuli_choices = stimuli_choices
    self.stimuli = np.empty(shape=episode_steps)  # will be filled in the `reset()`

    self._reset_next_step = True

    # whether mimic humans or reward the agent once it responds optimally.
    if human_data is None:
      self._imitate_human = False
      self.human_data = None
      self.human_subject_data = None
    else:
      self._imitate_human = True
      self.human_data = human_data
      self.human_subject_data = None

    self._action_history = []

  def reset(self):
    self._reset_next_step = False
    self._current_step = 0
    self._action_history = []  # start each episode with an empty response history

    # generate a random sequence instead of relying on human data
    if self.human_data is None:
      # self.stimuli = np.random.choice(self.stimuli_choices, self.episode_steps)
      # FIXME This is a fix for acme & reverb issue with string observation. Agent should be able to handle strings
      self.stimuli = np.random.choice(len(self.stimuli_choices), self.episode_steps).astype(np.float32)
    else:
      # randomly choose a subject from the human data and follow her trials and responses.
      # FIXME should we always use one specific human subject or randomly select one in each episode?
      subject_id = np.random.choice(self.human_data['participant_id'].unique())
      self.human_subject_data = self.human_data[
          self.human_data['participant_id'] == subject_id].sort_values('trial_index')
      self.stimuli = self.human_subject_data['stimulus'].to_list()
      self.stimuli = np.array([ord(s) - ord('A') + 1 for s in self.stimuli]).astype(np.float32)

    return dm_env.restart(self._observation())

  def _episode_return(self):
    if self._imitate_human:
      # fraction of trials on which the agent reproduced the human response
      return np.mean(self.human_subject_data['response'] == self._action_history)
    else:
      return 0.0

  def step(self, action: int):
    if self._reset_next_step:
      return self.reset()

    agent_action = NBack.ACTIONS[action]

    if self._imitate_human:
      # if it was the same action as the human subject, then reward the agent
      human_action = self.human_subject_data['response'].iloc[self._current_step]
      step_reward = 0. if (agent_action == human_action) else -1.
    else:
      # assume the agent is rational and doesn't want to reproduce human, reward once the response is correct
      # (during the first N trials the index below is negative and wraps around to the end of the sequence)
      expected_action = 'match' if (self.stimuli[self._current_step] == self.stimuli[self._current_step - self.N]) else 'non-match'
      step_reward = 0. if (agent_action == expected_action) else -1.

    self._action_history.append(agent_action)

    self._current_step += 1

    # Check for termination.
    if self._current_step == self.stimuli.shape[0]:
      self._reset_next_step = True
      # we are using the mean of total time step rewards as the episode return
      return dm_env.termination(reward=self._episode_return(),
                                observation=self._observation())
    else:
      return dm_env.transition(reward=step_reward,
                               observation=self._observation())

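  def observation_spec(self):
    return dm_env.specs.BoundedArray(
        shape=self.stimuli.shape,
        dtype=self.stimuli.dtype,
        name='nback_stimuli', minimum=0, maximum=len(self.stimuli_choices) + 1)

  def action_spec(self):
    return dm_env.specs.DiscreteArray(
        num_values=len(NBack.ACTIONS),
        dtype=np.int32,
        name='action')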

  def _observation(self):

    # agent observes only the current trial
    # obs = self.stimuli[self._current_step - 1]

    # agents observe stimuli up to the current trial
    obs = self.stimuli[:self._current_step+1].copy()
    obs = np.pad(obs,(0, len(self.stimuli) - len(obs)))

    return obs

  def plot_state(self):
    """Display current state of the environment.

     Note: `M` means `match`, and `.` is a `non-match`.
    """
    stimuli = self.stimuli[:self._current_step - 1]
    actions = ['M' if a=='match' else '.' for a in self._action_history[:self._current_step - 1]]
    return HTML(
        f'<b>Environment ({self.N}-back):</b><br />'
        f'<pre><b>Stimuli:</b> {"".join(map(str,map(int,stimuli)))}</pre>'
        f'<pre><b>Actions:</b> {"".join(actions)}</pre>'
    )

  @staticmethod
  def create_environment():
    """Utility function to create a N-back environment and its spec."""

    # Make sure the environment outputs single-precision floats.
    environment = wrappers.SinglePrecisionWrapper(NBack())

    # Grab the spec of the environment.
    environment_spec = specs.make_environment_spec(environment)

    return environment, environment_spec
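The observation returned by `_observation()` is the stimulus history up to the current trial, right-padded with zeros to a fixed length. The standalone sketch below reproduces that padding with NumPy; the function name `padded_observation` is hypothetical. It also illustrates one consequence: when stimuli are coded from 0 (as in the randomly generated sequences), a stimulus coded 0 looks the same as padding, whereas the human-data branch codes stimuli from 1.

```python
import numpy as np


def padded_observation(stimuli: np.ndarray, current_step: int) -> np.ndarray:
    """Return stimuli seen so far, zero-padded to the full episode length."""
    obs = stimuli[:current_step + 1].copy()
    # np.pad defaults to constant padding with zeros
    return np.pad(obs, (0, len(stimuli) - len(obs)))


stimuli = np.array([3., 0., 5., 2.], dtype=np.float32)

# After the second trial the agent has seen [3, 0]; the rest is padding.
obs = padded_observation(stimuli, current_step=1)
print(obs)
```

In this example the second entry is a real stimulus and the last two are padding, yet all three are zero, so the agent cannot tell them apart from the observation alone.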

Define a random agent#

For more information you can refer to NMA-DL W3D2 Basic Reinforcement learning.

class RandomAgent(acme.Actor):

  def __init__(self, environment_spec):
    """Gets the number of available actions from the environment spec."""
    self._num_actions = environment_spec.actions.num_values

  def select_action(self, observation):
    """Selects an action uniformly at random."""
    action = np.random.randint(self._num_actions)
    return action

  def observe_first(self, timestep):
    """Does not record as the RandomAgent has no use for data."""

  def observe(self, action, next_timestep):
    """Does not record as the RandomAgent has no use for data."""

  def update(self):
    """Does not update as the RandomAgent does not learn from data."""

Initialize the environment and the agent#

env, env_spec = NBack.create_environment()
agent = RandomAgent(env_spec)

print('actions:\n', env_spec.actions)
print('observations:\n', env_spec.observations)
print('rewards:\n', env_spec.rewards)
actions:
 DiscreteArray(shape=(), dtype=int32, name=action, minimum=0, maximum=1, num_values=2)
observations:
 BoundedArray(shape=(32,), dtype=dtype('float32'), name='nback_stimuli', minimum=0.0, maximum=7.0)
rewards:
 Array(shape=(), dtype=dtype('float32'), name='reward')

Run the loop#

# fitting parameters
n_episodes = 1_000
n_total_steps = 0
log_loss = False
n_steps = n_episodes * 32
all_returns = []

# main loop
for episode in range(n_episodes):
  episode_steps = 0
  episode_return = 0
  episode_loss = 0

  start_time = time.time()

  timestep = env.reset()

  # Make the first observation.
  agent.observe_first(timestep)

  # Run an episode
  while not timestep.last():

    # DEBUG
    # print(timestep)

    # Generate an action from the agent's policy and step the environment.
    action = agent.select_action(timestep.observation)
    timestep = env.step(action)

    # Have the agent observe the timestep and let the agent update itself.
    agent.observe(action, next_timestep=timestep)
    agent.update()

    # Book-keeping.
    episode_steps += 1
    n_total_steps += 1
    episode_return += timestep.reward

    if log_loss:
      episode_loss += agent.last_loss

    if n_steps is not None and n_total_steps >= n_steps:
      break

  # Collect the results and combine with counts.
  steps_per_second = episode_steps / (time.time() - start_time)
  result = {
      'episode': episode,
      'episode_length': episode_steps,
      'episode_return': episode_return,
  }
  if log_loss:
    result['loss_avg'] = episode_loss/episode_steps

  all_returns.append(episode_return)

  # Log the given results.
  print(result)

  if n_steps is not None and n_total_steps >= n_steps:
    break


# Histogram of all returns
sns.histplot(all_returns, stat="density", kde=True, bins=12)
plt.xlabel('Return [a.u.]')
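As a sanity check on the histogram, the chance-level return can be estimated without running the environment. The sketch below assumes the reward scheme of the normative mode in this section: each step yields 0 or -1, a uniformly random responder is right half the time, and the reward of the terminal step is replaced by the episode return of 0.0. The function names are hypothetical.

```python
import random


def expected_random_return(episode_steps=32, p_correct=0.5):
    """Expected return of a random responder; the terminal step contributes 0."""
    return -(episode_steps - 1) * (1.0 - p_correct)


def simulate_random_return(rng, episode_steps=32):
    """Sample one episode return: each non-terminal step costs 1 with probability 0.5."""
    errors = sum(rng.random() < 0.5 for _ in range(episode_steps - 1))
    return -float(errors)


rng = random.Random(0)
returns = [simulate_random_return(rng) for _ in range(2000)]
mean_return = sum(returns) / len(returns)

print(expected_random_return(), round(mean_return, 2))
```

Under these assumptions the histogram of the random agent should be centered near -15.5, which gives a baseline for judging a trained agent.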

Note: You can simplify the environment loop using DeepMind Acme.

# init a new N-back environment
env, env_spec = NBack.create_environment()

# DEBUG fake testing environment.
# Uncomment this to debug your agent without using the N-back environment.
# env = fakes.DiscreteEnvironment(
#     num_actions=2,
#     num_observations=1000,
#     obs_dtype=np.float32,
#     episode_length=32)
# env_spec = specs.make_environment_spec(env)
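
def dqn_make_network(action_spec: specs.DiscreteArray) -> snt.Module:
  return snt.Sequential([
      snt.nets.MLP([50, 50, action_spec.num_values]),
  ])

# construct a DQN agent
# (only the two required arguments are set here; the other hyperparameters
# are not shown in the source and are left at the library defaults)
agent = dqn.DQN(
    environment_spec=env_spec,
    network=dqn_make_network(env_spec.actions),
)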
Now, we run the environment loop with the DQN agent and print the training log.

# training loop
loop = EnvironmentLoop(env, agent, logger=loggers.InMemoryLogger())
loop.run(num_episodes=n_episodes)

# print logs
logs = pd.DataFrame(loop._logger._data)
logs.tail()
     episode_length  episode_return  steps_per_second  episodes  steps
995              32           -10.0        329.379165       996  31872
996              32           -11.0        326.324034       997  31904
997              32            -9.0        373.017676       998  31936
998              32           -11.0        309.737031       999  31968
999              32            -9.0        405.329983      1000  32000
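Episode returns fluctuate from one episode to the next, so a trend is easier to read after smoothing. The sketch below computes a simple moving average in plain Python; the helper name `moving_average` is hypothetical, and the input is just the five returns printed in the log above, not a full training curve.

```python
def moving_average(values, window):
    """Moving average over full windows only (output is shorter than input)."""
    if window < 1:
        raise ValueError('window must be at least 1')
    averages = []
    running = 0.0
    for i, value in enumerate(values):
        running += value
        if i >= window:
            running -= values[i - window]  # drop the value that left the window
        if i >= window - 1:
            averages.append(running / window)
    return averages


# The last five episode returns from the training log above.
episode_returns = [-10.0, -11.0, -9.0, -11.0, -9.0]
smoothed = moving_average(episode_returns, window=3)
print(smoothed)
```

Applied to the whole `episode_return` column with a larger window, the same function would show whether the agent improves over training.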