'''
RMaxAgentClass.py: Class for an RMaxAgent from [Brafman and Tennenholtz 2003].
Notes:
- Assumes WLOG reward function codomain is [0,1] (so RMAX is 1.0)
'''
# Python imports.
import random
from collections import defaultdict
# Local classes.
from simple_rl.agents.AgentClass import Agent
class RMaxAgent(Agent):
'''
Implementation for an R-Max Agent [Brafman and Tennenholtz 2003]
'''
def __init__(self, actions, gamma=0.95, horizon=4, s_a_threshold=1, name="RMax-h"):
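        '''
        Args:
            actions (list)
            gamma (float): Discount factor.
            horizon (int): Lookahead depth for planning with the learned model.
            s_a_threshold (int): Samples needed before an (s, a) pair counts as known.
            name (str): If it ends in "-h", the horizon is appended (e.g. "RMax-h4").
        '''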
name = name + str(horizon) if name[-2:] == "-h" else name
Agent.__init__(self, name=name, actions=actions, gamma=gamma)
self.rmax = 1.0
self.horizon = horizon
self.s_a_threshold = s_a_threshold
self.reset()
    def reset(self):
'''
Summary:
Resets the agent back to its tabula rasa config.
'''
self.rewards = defaultdict(lambda : defaultdict(list)) # S --> A --> [r_1, ...]
self.transitions = defaultdict(lambda : defaultdict(lambda : defaultdict(int))) # S --> A --> S' --> counts
self.r_s_a_counts = defaultdict(lambda : defaultdict(int)) # S --> A --> #rs
self.t_s_a_counts = defaultdict(lambda : defaultdict(int)) # S --> A --> #ts
self.prev_state = None
self.prev_action = None
    def get_num_known_sa(self):
        return sum([self.is_known(s, a) for s in self.r_s_a_counts.keys() for a in self.r_s_a_counts[s].keys()])
    def is_known(self, s, a):
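        # A pair is "known" once both its reward and transition sample counts reach
        # the threshold; e.g. with s_a_threshold=2, two recorded visits to (s, a) suffice.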
return self.r_s_a_counts[s][a] >= self.s_a_threshold and self.t_s_a_counts[s][a] >= self.s_a_threshold
    def act(self, state, reward):
# Update given s, a, r, s' : self.prev_state, self.prev_action, reward, state
self.update(self.prev_state, self.prev_action, reward, state)
# Compute best action.
action = self.get_max_q_action(state)
# Update pointers.
self.prev_action = action
self.prev_state = state
return action
    def update(self, state, action, reward, next_state):
'''
Args:
state (State)
action (str)
reward (float)
next_state (State)
Summary:
Updates T and R.
'''
        if state is not None and action is not None:
            if self.r_s_a_counts[state][action] < self.s_a_threshold:
                # Add new data points if we haven't seen this s-a pair enough.
                self.rewards[state][action] += [reward]
                self.r_s_a_counts[state][action] += 1
            if self.t_s_a_counts[state][action] < self.s_a_threshold:
                self.transitions[state][action][next_state] += 1
                self.t_s_a_counts[state][action] += 1
def _compute_max_qval_action_pair(self, state, horizon=None):
'''
Args:
state (State)
horizon (int): Indicates the level of recursion depth for computing Q.
Returns:
(tuple) --> (float, str): where the float is the Qval, str is the action.
'''
# If this is the first call, use the default horizon.
if horizon is None:
horizon = self.horizon
        # Grab a random initial action in case all Q-values are equal.
best_action = random.choice(self.actions)
max_q_val = self.get_q_value(state, best_action, horizon)
# Find best action (action w/ current max predicted Q value)
for action in self.actions:
q_s_a = self.get_q_value(state, action, horizon)
if q_s_a > max_q_val:
max_q_val = q_s_a
best_action = action
return max_q_val, best_action
    def get_max_q_action(self, state, horizon=None):
'''
Args:
state (State)
horizon (int): Indicates the level of recursion depth for computing Q.
Returns:
(str): The string associated with the action with highest Q value.
'''
# If this is the first call, use the default horizon.
if horizon is None:
horizon = self.horizon
return self._compute_max_qval_action_pair(state, horizon)[1]
    def get_max_q_value(self, state, horizon=None):
'''
Args:
state (State)
horizon (int): Indicates the level of recursion depth for computing Q.
Returns:
(float): The Q value of the best action in this state.
'''
# If this is the first call, use the default horizon.
if horizon is None:
horizon = self.horizon
return self._compute_max_qval_action_pair(state, horizon)[0]
    def get_q_value(self, state, action, horizon=None):
'''
Args:
state (State)
action (str)
horizon (int): Indicates the level of recursion depth for computing Q.
Returns:
(float)
'''
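        # Depth-limited Bellman backup over the learned model:
        #   Q_h(s, a) = R_hat(s, a) + gamma * sum_{s'} T_hat(s' | s, a) * max_{a'} Q_{h-1}(s', a'),
        # with base case Q_0(s, a) = R_hat(s, a). Unknown (s, a) pairs receive the
        # optimistic reward rmax via _get_reward, which is what drives exploration.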
# If this is the first call, use the default horizon.
if horizon is None:
horizon = self.horizon
if horizon <= 0 or state.is_terminal():
# If we're not looking any further.
return self._get_reward(state, action)
# Compute future return.
expected_future_return = self.gamma*self._compute_exp_future_return(state, action, horizon)
        q_val = self._get_reward(state, action) + expected_future_return
return q_val
def _compute_exp_future_return(self, state, action, horizon=None):
'''
Args:
state (State)
action (str)
horizon (int): Recursion depth to compute Q
        Returns:
            (float): Expected future return from applying @action in @state,
                before the one-step gamma discount applied by the caller.
'''
# If this is the first call, use the default horizon.
if horizon is None:
horizon = self.horizon
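        # Expectation under the empirical transition model; e.g. if (s, a) has been
        # observed reaching s1 twice and s2 once, this returns
        #   (2/3) * max_a' Q_{h-1}(s1, a') + (1/3) * max_a' Q_{h-1}(s2, a').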
next_state_dict = self.transitions[state][action]
denominator = float(sum(next_state_dict.values()))
state_weights = defaultdict(float)
for next_state in next_state_dict.keys():
count = next_state_dict[next_state]
state_weights[next_state] = (count / denominator)
weighted_future_returns = [self.get_max_q_value(next_state, horizon-1) * state_weights[next_state] for next_state in next_state_dict.keys()]
return sum(weighted_future_returns)
def _get_reward(self, state, action):
'''
Args:
state (State)
action (str)
Returns:
Believed reward of executing @action in @state. If R(s,a) is unknown
for this s,a pair, return self.rmax. Otherwise, return the MLE.
'''
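        # e.g. with s_a_threshold=3 and observed rewards [0.0, 1.0, 1.0], the MLE is 2/3.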
if self.r_s_a_counts[state][action] >= self.s_a_threshold:
# Compute MLE if we've seen this s,a pair enough.
rewards_s_a = self.rewards[state][action]
return float(sum(rewards_s_a)) / len(rewards_s_a)
else:
# Otherwise return rmax.
return self.rmax
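
if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of simple_rl): drives
    # the agent on a hand-rolled two-state chain. _ToyState below is a hypothetical
    # stand-in for simple_rl's State, exposing just what RMaxAgent touches.
    class _ToyState(object):
        '''Hashable stand-in for a simple_rl State; only is_terminal() is required.'''
        def __init__(self, idx):
            self.idx = idx
        def is_terminal(self):
            return False
        def __hash__(self):
            return hash(self.idx)
        def __eq__(self, other):
            return isinstance(other, _ToyState) and self.idx == other.idx

    left, right = _ToyState(0), _ToyState(1)
    agent = RMaxAgent(actions=["stay", "go"], horizon=3, s_a_threshold=2)
    state, reward, total = left, 0.0, 0.0
    for _ in range(50):
        action = agent.act(state, reward)
        # Toy dynamics: "go" from the left state pays 1 and moves right;
        # anything else pays 0 and resets to the left state.
        if state == left and action == "go":
            state, reward = right, 1.0
        else:
            state, reward = left, 0.0
        total += reward
    print("known (s, a) pairs: " + str(agent.get_num_known_sa()))
    print("total reward: " + str(total))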