Q-Learning vs. SARSA
Two fundamental RL algorithms, both remarkably useful even today. One of the primary reasons for their popularity is their simplicity: by default they only work with discrete state and action spaces. It is possible to extend them to continuous state/action spaces, but consider discretizing first to keep things ridiculously simple.
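As a rough illustration of the discretizing idea, here is a minimal sketch that maps a continuous observation onto a single integer index that a tabular agent can use as a row of its Q-table. The discretize helper, the bin edges and the two-dimensional state are all assumptions made for this example; none of it comes from simple_rl.

```python
import numpy as np


def discretize(observation, bin_edges):
    """Map a continuous observation to one integer index (hypothetical helper).

    bin_edges holds one array of interior edges per dimension; a dimension
    with k interior edges is split into k + 1 bins.
    """
    index = 0
    for value, edges in zip(observation, bin_edges):
        # np.digitize returns the bin the value falls into, from 0 to len(edges).
        bin_id = int(np.digitize(value, edges))
        # Combine the per-dimension bins into a single mixed-radix index.
        index = index * (len(edges) + 1) + bin_id
    return index


# Assumed example: a 2-D state in [0, 1) x [-1, 1), split into 4 bins per dimension.
edges = [np.linspace(0.0, 1.0, 5)[1:-1], np.linspace(-1.0, 1.0, 5)[1:-1]]
q_table = np.zeros((4 * 4, 2))  # 16 discrete states, 2 discrete actions

print(discretize([0.1, -0.9], edges))  # -> 0 (lowest bin in both dimensions)
print(discretize([0.9, 0.9], edges))   # -> 15 (highest bin in both dimensions)
```

Each discrete index then selects a row of q_table, so the tabular updates used below apply unchanged.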
In this workshop I’m going to reproduce the cliff world example from the book. In future I will expand on this so you can develop your own algorithms and environments.
A note on usage
Note that this notebook might not work on your machine because simple_rl forces TkAgg on some machines. See https://github.com/david-abel/simple_rl/issues/40
Also, Pygame is notoriously picky and expects loads of compiler/system related libraries.
I managed to get this working in the following Docker image:
docker run -it -p 8888:8888 jupyter/scipy-notebook:54462805efcb
This code is untested in any other environment.
TODO: migrate away from simple_rl and pygame.
TODO: create dedicated Q-learning and SARSA notebooks.
!pip install pygame==1.9.6 pandas==1.0.5 matplotlib==3.2.1 > /dev/null
!pip install --upgrade git+https://github.com/david-abel/simple_rl.git@77c0d6b910efbe8bdd5f4f87337c5bc4aed0d79c > /dev/null
import matplotlib
matplotlib.use("agg", force=True)
SARSA Agent
Now that lot is sorted, the next issue is that simple_rl doesn’t have a SARSA agent, so I had to implement one. The abstractions enforced by simple_rl complicate this a little, but the key section is the update function. It chooses the next action and then uses that same action in the update, which is what makes SARSA on-policy. Q-learning differs here: it updates towards the greedy (highest-valued) next action, which is independent of the action the agent actually takes next.
I’ll make this clearer in future updates to these notebooks.
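For reference, here are the two one-step updates in standard notation, with step size α (alpha) and discount γ (gamma) as in the code. The only difference is the bootstrap term: SARSA uses the action A_{t+1} that the behaviour policy actually selects in S_{t+1}, while Q-learning uses the greedy maximum over actions.

```latex
\begin{aligned}
\text{SARSA:}\quad & Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha\,\big[R_{t+1} + \gamma\, Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t)\big] \\
\text{Q-learning:}\quad & Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha\,\big[R_{t+1} + \gamma\, \max_{a} Q(S_{t+1}, a) - Q(S_t, A_t)\big]
\end{aligned}
```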
'''
SARSAAgent.py
Implementation of a SARSA agent for simple-rl
'''
# Other imports.
from simple_rl.agents import QLearningAgent


class SARSAAgent(QLearningAgent):
    def __init__(self, actions, goal_reward, name="SARSA",
                 alpha=0.1, gamma=0.99, epsilon=0.1, explore="uniform", anneal=False):
        # Stored for reference only; the SARSA update below does not use it.
        self.goal_reward = goal_reward
        QLearningAgent.__init__(
            self,
            actions=list(actions),
            name=name,
            alpha=alpha,
            gamma=gamma,
            epsilon=epsilon,
            explore=explore,
            anneal=anneal)

    def policy(self, state):
        # Greedy policy, used when evaluating or visualizing the agent.
        return self.get_max_q_action(state)

    def act(self, state, reward, learning=True):
        '''
        This is mostly the same as the base QLearningAgent class, except that
        the update procedure now also chooses the next action.
        '''
        if learning:
            action = self.update(
                self.prev_state, self.prev_action, reward, state)
        else:
            if self.explore == "softmax":
                # Softmax exploration
                action = self.soft_max_policy(state)
            else:
                # Uniform (epsilon-greedy) exploration
                action = self.epsilon_greedy_q_policy(state)

        self.prev_state = state
        self.prev_action = action
        self.step_number += 1

        # Anneal params.
        if learning and self.anneal:
            self._anneal()

        return action

    def update(self, state, action, reward, next_state):
        '''
        Args:
            state (State)
            action (str)
            reward (float)
            next_state (State)

        Returns:
            (str): the next action, chosen by the behaviour policy.

        Summary:
            Updates the internal Q Function according to the Bellman Equation
            using a SARSA update.
        '''
        # Choose the next action with the behaviour policy *before* updating:
        # SARSA bootstraps from the action it is actually going to take.
        if self.explore == "softmax":
            # Softmax exploration
            next_action = self.soft_max_policy(next_state)
        else:
            # Uniform (epsilon-greedy) exploration
            next_action = self.epsilon_greedy_q_policy(next_state)

        # First step of an episode: there is no previous (state, action)
        # pair yet, so there is nothing to update.
        if state is None:
            return next_action

        # Update the Q Function.
        prev_q_val = self.get_q_value(state, action)
        next_q_val = self.get_q_value(next_state, next_action)
        self.q_func[state][action] = prev_q_val + self.alpha * \
            (reward + self.gamma * next_q_val - prev_q_val)

        return next_action
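Since migrating away from simple_rl is on the TODO list, here is a minimal, framework-free sketch of the same two updates on a plain nested-dictionary Q-table. The function names, the defaultdict table and the epsilon-greedy helper are assumptions for illustration, not simple_rl’s API. The two update functions differ only in the bootstrap term.

```python
import random
from collections import defaultdict


def epsilon_greedy(q, state, actions, epsilon, rng=random):
    """With probability epsilon pick a random action, otherwise a greedy one."""
    if rng.random() < epsilon:
        return rng.choice(actions)
    # Break ties randomly so an all-zero table does not always pick actions[0].
    best = max(q[state][a] for a in actions)
    return rng.choice([a for a in actions if q[state][a] == best])


def sarsa_update(q, s, a, r, s_next, a_next, alpha, gamma, terminal=False):
    """On-policy: bootstrap from the action the agent will actually take next."""
    target = r if terminal else r + gamma * q[s_next][a_next]
    q[s][a] += alpha * (target - q[s][a])


def q_learning_update(q, s, a, r, s_next, actions, alpha, gamma, terminal=False):
    """Off-policy: bootstrap from the greedy action, whatever is taken next."""
    target = r if terminal else r + gamma * max(q[s_next][b] for b in actions)
    q[s][a] += alpha * (target - q[s][a])


# Tiny worked example with hand-set values (hypothetical states and actions).
actions = ["left", "right"]
q_sarsa = defaultdict(lambda: defaultdict(float))
q_ql = defaultdict(lambda: defaultdict(float))
for q in (q_sarsa, q_ql):
    q["s1"]["left"], q["s1"]["right"] = -10.0, 2.0

# Suppose exploration makes the agent pick "left" in s1 next.
sarsa_update(q_sarsa, "s0", "right", -1.0, "s1", "left", alpha=0.5, gamma=1.0)
q_learning_update(q_ql, "s0", "right", -1.0, "s1", actions, alpha=0.5, gamma=1.0)
print(q_sarsa["s0"]["right"])  # -> -5.5, i.e. 0.5 * (-1 + -10)
print(q_ql["s0"]["right"])     # -> 0.5, i.e. 0.5 * (-1 + 2)
```

The same transition pulls the two estimates in opposite directions: SARSA is penalized for the exploratory "left", Q-learning is not. This is the usual explanation for why SARSA tends to prefer the safer route away from the cliff under epsilon-greedy exploration.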
Experiment
Now I’m ready to run the experiment with the helpers from simple_rl. Basically, I’m training each agent for 500 episodes of at most 100 steps each, and averaging the per-episode rewards over 100 independent repeats.
Feel free to tinker with the settings.
import pandas as pd
import numpy as np
from simple_rl.agents import QLearningAgent
from simple_rl.tasks import GridWorldMDP
from simple_rl.run_experiments import run_single_agent_on_mdp

np.random.seed(42)

instances = 100   # independent runs to average over
n_episodes = 500  # episodes per run
alpha = 0.1       # learning rate
epsilon = 0.1     # exploration rate

# Setup MDP, Agents.
# A 10 x 4 grid: start at (1, 1), goal at (10, 1) and a row of lava (the cliff)
# between them along the bottom row. Stepping into lava ends the episode.
mdp = GridWorldMDP(
    width=10, height=4, init_loc=(1, 1), goal_locs=[(10, 1)],
    lava_locs=[(x, 1) for x in range(2, 10)], is_lava_terminal=True,
    gamma=1.0, walls=[], slip_prob=0.0, step_cost=1.0, lava_cost=100.0)

print("Q-Learning")
rewards = np.zeros((n_episodes, instances))
for instance in range(instances):
    ql_agent = QLearningAgent(
        mdp.get_actions(),
        epsilon=epsilon,
        alpha=alpha)
    # mdp.visualize_learning(ql_agent, delay=0.0001)
    print(f" Instance {instance} of {instances}.")
    # reward holds one value per episode for this run.
    terminal, num_steps, reward = run_single_agent_on_mdp(
        ql_agent, mdp, episodes=n_episodes, steps=100)
    rewards[:, instance] = reward
df = pd.DataFrame(rewards.mean(axis=1))
df.to_json("q_learning_cliff_rewards.json")

print("SARSA")
rewards = np.zeros((n_episodes, instances))
for instance in range(instances):
    sarsa_agent = SARSAAgent(
        mdp.get_actions(),
        goal_reward=0,
        epsilon=epsilon,
        alpha=alpha)
    # mdp.visualize_learning(sarsa_agent, delay=0.0001)
    print(f" Instance {instance} of {instances}.")
    terminal, num_steps, reward = run_single_agent_on_mdp(
        sarsa_agent, mdp, episodes=n_episodes, steps=100)
    rewards[:, instance] = reward
df = pd.DataFrame(rewards.mean(axis=1))
df.to_json("sarsa_cliff_rewards.json")
Q-Learning
Instance 0 of 100.
Instance 1 of 100.
Instance 2 of 100.
Instance 3 of 100.
Instance 4 of 100.
Instance 5 of 100.
Instance 6 of 100.
Instance 7 of 100.
Instance 8 of 100.
Instance 9 of 100.
Instance 10 of 100.
Instance 11 of 100.
Instance 12 of 100.
Instance 13 of 100.
Instance 14 of 100.
Instance 15 of 100.
Instance 16 of 100.
Instance 17 of 100.
Instance 18 of 100.
Instance 19 of 100.
Instance 20 of 100.
Instance 21 of 100.
Instance 22 of 100.
Instance 23 of 100.
Instance 24 of 100.
Instance 25 of 100.
Instance 26 of 100.
Instance 27 of 100.
Instance 28 of 100.
Instance 29 of 100.
Instance 30 of 100.
Instance 31 of 100.
Instance 32 of 100.
Instance 33 of 100.
Instance 34 of 100.
Instance 35 of 100.
Instance 36 of 100.
Instance 37 of 100.
Instance 38 of 100.
Instance 39 of 100.
Instance 40 of 100.
Instance 41 of 100.
Instance 42 of 100.
Instance 43 of 100.
Instance 44 of 100.
Instance 45 of 100.
Instance 46 of 100.
Instance 47 of 100.
Instance 48 of 100.
Instance 49 of 100.
Instance 50 of 100.
Instance 51 of 100.
Instance 52 of 100.
Instance 53 of 100.
Instance 54 of 100.
Instance 55 of 100.
Instance 56 of 100.
Instance 57 of 100.
Instance 58 of 100.
Instance 59 of 100.
Instance 60 of 100.
Instance 61 of 100.
Instance 62 of 100.
Instance 63 of 100.
Instance 64 of 100.
Instance 65 of 100.
Instance 66 of 100.
Instance 67 of 100.
Instance 68 of 100.
Instance 69 of 100.
Instance 70 of 100.
Instance 71 of 100.
Instance 72 of 100.
Instance 73 of 100.
Instance 74 of 100.
Instance 75 of 100.
Instance 76 of 100.
Instance 77 of 100.
Instance 78 of 100.
Instance 79 of 100.
Instance 80 of 100.
Instance 81 of 100.
Instance 82 of 100.
Instance 83 of 100.
Instance 84 of 100.
Instance 85 of 100.
Instance 86 of 100.
Instance 87 of 100.
Instance 88 of 100.
Instance 89 of 100.
Instance 90 of 100.
Instance 91 of 100.
Instance 92 of 100.
Instance 93 of 100.
Instance 94 of 100.
Instance 95 of 100.
Instance 96 of 100.
Instance 97 of 100.
Instance 98 of 100.
Instance 99 of 100.
SARSA
Instance 0 of 100.
Instance 1 of 100.
Instance 2 of 100.
Instance 3 of 100.
Instance 4 of 100.
Instance 5 of 100.
Instance 6 of 100.
Instance 7 of 100.
Instance 8 of 100.
Instance 9 of 100.
Instance 10 of 100.
Instance 11 of 100.
Instance 12 of 100.
Instance 13 of 100.
Instance 14 of 100.
Instance 15 of 100.
Instance 16 of 100.
Instance 17 of 100.
Instance 18 of 100.
Instance 19 of 100.
Instance 20 of 100.
Instance 21 of 100.
Instance 22 of 100.
Instance 23 of 100.
Instance 24 of 100.
Instance 25 of 100.
Instance 26 of 100.
Instance 27 of 100.
Instance 28 of 100.
Instance 29 of 100.
Instance 30 of 100.
Instance 31 of 100.
Instance 32 of 100.
Instance 33 of 100.
Instance 34 of 100.
Instance 35 of 100.
Instance 36 of 100.
Instance 37 of 100.
Instance 38 of 100.
Instance 39 of 100.
Instance 40 of 100.
Instance 41 of 100.
Instance 42 of 100.
Instance 43 of 100.
Instance 44 of 100.
Instance 45 of 100.
Instance 46 of 100.
Instance 47 of 100.
Instance 48 of 100.
Instance 49 of 100.
Instance 50 of 100.
Instance 51 of 100.
Instance 52 of 100.
Instance 53 of 100.
Instance 54 of 100.
Instance 55 of 100.
Instance 56 of 100.
Instance 57 of 100.
Instance 58 of 100.
Instance 59 of 100.
Instance 60 of 100.
Instance 61 of 100.
Instance 62 of 100.
Instance 63 of 100.
Instance 64 of 100.
Instance 65 of 100.
Instance 66 of 100.
Instance 67 of 100.
Instance 68 of 100.
Instance 69 of 100.
Instance 70 of 100.
Instance 71 of 100.
Instance 72 of 100.
Instance 73 of 100.
Instance 74 of 100.
Instance 75 of 100.
Instance 76 of 100.
Instance 77 of 100.
Instance 78 of 100.
Instance 79 of 100.
Instance 80 of 100.
Instance 81 of 100.
Instance 82 of 100.
Instance 83 of 100.
Instance 84 of 100.
Instance 85 of 100.
Instance 86 of 100.
Instance 87 of 100.
Instance 88 of 100.
Instance 89 of 100.
Instance 90 of 100.
Instance 91 of 100.
Instance 92 of 100.
Instance 93 of 100.
Instance 94 of 100.
Instance 95 of 100.
Instance 96 of 100.
Instance 97 of 100.
Instance 98 of 100.
Instance 99 of 100.
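If simple_rl or pygame will not install, the experiment above can be approximated without them. The sketch below is a hypothetical stand-in, not simple_rl’s GridWorldMDP: it copies the grid layout, the step cost of 1 and the lava cost of 100, but its reward details are my own choice (for example, reaching the goal simply costs the usual -1), so its numbers are not directly comparable with the plots. It also uses fewer instances and episodes than the notebook, to keep it quick.

```python
import random

# Hypothetical standalone cliff grid chosen to mirror the notebook's settings;
# this is not simple_rl's GridWorldMDP and its rewards may differ in detail.
WIDTH, HEIGHT = 10, 4
START, GOAL = (1, 1), (10, 1)
LAVA = {(x, 1) for x in range(2, 10)}
ACTIONS = ["up", "down", "left", "right"]
MOVES = {"up": (0, 1), "down": (0, -1), "left": (-1, 0), "right": (1, 0)}
STEP_COST, LAVA_COST = 1.0, 100.0


def step(state, action):
    """Return (reward, next_state, done). Moving off the grid leaves the agent in place."""
    dx, dy = MOVES[action]
    nxt = (min(max(state[0] + dx, 1), WIDTH), min(max(state[1] + dy, 1), HEIGHT))
    if nxt in LAVA:
        return -LAVA_COST, nxt, True  # lava ends the episode
    return -STEP_COST, nxt, nxt == GOAL


def choose(q, s, epsilon, rng):
    """Epsilon-greedy over a dict keyed by (state, action); ties broken randomly."""
    if rng.random() < epsilon:
        return rng.choice(ACTIONS)
    best = max(q.get((s, a), 0.0) for a in ACTIONS)
    return rng.choice([a for a in ACTIONS if q.get((s, a), 0.0) == best])


def run_episode(q, method, alpha, epsilon, max_steps, rng, gamma=1.0):
    """Run one episode of tabular SARSA or Q-learning and return its sum of rewards."""
    s, total = START, 0.0
    a = choose(q, s, epsilon, rng)
    for _ in range(max_steps):
        r, s2, done = step(s, a)
        total += r
        if done:
            target = r  # terminal states are worth 0
        elif method == "sarsa":
            a2 = choose(q, s2, epsilon, rng)  # pick A' first, then bootstrap from it
            target = r + gamma * q.get((s2, a2), 0.0)
        else:  # "q_learning": bootstrap from the greedy value instead
            target = r + gamma * max(q.get((s2, b), 0.0) for b in ACTIONS)
        old = q.get((s, a), 0.0)
        q[(s, a)] = old + alpha * (target - old)
        if done:
            break
        if method != "sarsa":
            a2 = choose(q, s2, epsilon, rng)  # behaviour action, chosen after the update
        s, a = s2, a2
    return total


def experiment(method, instances=10, n_episodes=200, alpha=0.1, epsilon=0.1, seed=42):
    """Average per-episode returns over independent runs, each with a fresh Q-table."""
    rng = random.Random(seed)
    sums = [0.0] * n_episodes
    for _ in range(instances):
        q = {}
        for ep in range(n_episodes):
            sums[ep] += run_episode(q, method, alpha, epsilon, 100, rng)
    return [total / instances for total in sums]


sarsa_curve = experiment("sarsa")
ql_curve = experiment("q_learning")
print("SARSA, mean return over last 50 episodes:", sum(sarsa_curve[-50:]) / 50)
print("Q-learning, mean return over last 50 episodes:", sum(ql_curve[-50:]) / 50)
```

Raising instances and n_episodes towards the notebook’s 100 and 500 gives smoother curves at the cost of run time.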
Results
Now you can plot the results for each of the agents.
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd

data_files = [("Q-Learning", "q_learning_cliff_rewards.json"),
              ("SARSA", "sarsa_cliff_rewards.json")]

fig, ax = plt.subplots()
for name, data_file in data_files:
    df = pd.read_json(data_file)
    x = range(len(df))
    y = df.sort_index().values
    ax.plot(x,
            y,
            linestyle='solid',
            label=name)
ax.set_xlabel('Episode')
ax.set_ylabel('Averaged Sum of Rewards')
ax.legend(loc='lower right')
plt.show()
Policy Results
If you’re not in a notebook, then you can use simple_rl to visualize your policy below.
# mdp.visualize_agent(sarsa_agent)
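If you are in a notebook, where the pygame window may not open, a text rendering of the greedy policy can stand in. This sketch assumes a flat Q-table keyed by ((x, y), action) tuples, which is a hypothetical layout: simple_rl’s q_func is a nested dictionary keyed by its own state objects, so it would need converting first. The action names and arrow symbols are likewise assumptions.

```python
# Arrow for each action name (assumed action names).
ARROWS = {"up": "^", "down": "v", "left": "<", "right": ">"}


def render_policy(q, width, height, actions, special=None):
    """Draw the greedy action of every cell as text, top row (y = height) first.

    q maps ((x, y), action) -> value (a hypothetical layout); missing entries
    count as 0 and ties go to the first action in `actions`. Cells listed in
    `special` (e.g. goal or lava) are drawn with the given character instead.
    """
    special = special or {}
    rows = []
    for y in range(height, 0, -1):
        row = []
        for x in range(1, width + 1):
            if (x, y) in special:
                row.append(special[(x, y)])
            else:
                best = max(actions, key=lambda a: q.get(((x, y), a), 0.0))
                row.append(ARROWS[best])
        rows.append(" ".join(row))
    return "\n".join(rows)


# Hand-made example: go right along both rows, then down into the goal at (3, 1).
actions = ["up", "down", "left", "right"]
q = {((x, y), "right"): 1.0 for x in (1, 2) for y in (1, 2)}
q[((3, 2), "down")] = 1.0
print(render_policy(q, 3, 2, actions, special={(3, 1): "G"}))
# > > v
# > > G
```

For the cliff grid in this notebook, special could mark (10, 1) as "G" and the lava cells (2, 1) to (9, 1) as "x".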