Source code for simple_rl.planning.ValueIterationClass
# Python imports.
from __future__ import print_function
from collections import defaultdict
import random
import sys

# Check python version for queue module.
if sys.version_info[0] < 3:
    import Queue as queue
else:
    import queue

# Other imports.
from simple_rl.planning.PlannerClass import Planner

class ValueIteration(Planner):

    def __init__(self, mdp, name="value_iter", delta=0.0001, max_iterations=500, sample_rate=3):
        '''
        Args:
            mdp (MDP)
            delta (float): After an iteration of VI, terminate if no state's value changed by more than @delta.
            max_iterations (int): Hard limit on the number of iterations.
            sample_rate (int): Number of samples drawn from @mdp per (state, action) pair to estimate T(s' | s, a).
        '''
        Planner.__init__(self, mdp, name=name)

        self.delta = delta
        self.max_iterations = max_iterations
        self.sample_rate = sample_rate
        self.value_func = defaultdict(float)
        self.reachability_done = False
        self.has_computed_matrix = False
        self.bellman_backups = 0

    def _compute_matrix_from_trans_func(self):
        if self.has_computed_matrix:
            # We've already run this; make sure reachability is done, then return.
            self._compute_reachable_state_space()
            return

        # trans_dict: state --> action --> next state --> empirical probability.
        self.trans_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))

        for s in self.get_states():
            for a in self.actions:
                for sample in range(self.sample_rate):
                    s_prime = self.transition_func(s, a)
                    self.trans_dict[s][a][s_prime] += 1.0 / self.sample_rate

        self.has_computed_matrix = True
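        # Note: with sample_rate = N, each sampled successor contributes 1/N, so
        # trans_dict[s][a][s_prime] is the empirical estimate count(s_prime) / N
        # of T(s' | s, a); successors never sampled get probability 0.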

    def get_gamma(self):
        return self.mdp.get_gamma()

    def get_num_states(self):
        if not self.reachability_done:
            self._compute_reachable_state_space()
        return len(self.states)

    def get_states(self):
        if not self.reachability_done:
            self._compute_reachable_state_space()
        return list(self.states)

    def get_value(self, s):
        '''
        Args:
            s (State)

        Returns:
            (float): The value of @s under the current value function.
        '''
        return self._compute_max_qval_action_pair(s)[0]

    def get_q_value(self, s, a):
        '''
        Args:
            s (State)
            a (str): action

        Returns:
            (float): The Q estimate of (@s, @a) given the current value function @self.value_func,
                that is, R(s, a) + gamma * sum_{s'} T(s' | s, a) * V(s').
        '''
        # Compute the expected value of the next state.
        expected_future_val = 0
        for s_prime in self.trans_dict[s][a].keys():
            expected_future_val += self.trans_dict[s][a][s_prime] * self.value_func[s_prime]

        return self.reward_func(s, a) + self.gamma * expected_future_val

    def _compute_reachable_state_space(self):
        '''
        Summary:
            Starting with @self.init_state, determines all reachable states
            and stores them in self.states.
        '''
        if self.reachability_done:
            return

        state_queue = queue.Queue()
        state_queue.put(self.init_state)
        self.states.add(self.init_state)

        while not state_queue.empty():
            s = state_queue.get()
            for a in self.actions:
                # Take @self.sample_rate samples to estimate the reachable successors.
                for samples in range(self.sample_rate):
                    next_state = self.transition_func(s, a)
                    if next_state not in self.states:
                        self.states.add(next_state)
                        state_queue.put(next_state)

        self.reachability_done = True
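        # Note: because reachability is estimated by sampling, successors that occur
        # with low probability under a stochastic transition_func may be missed when
        # @self.sample_rate is small, leaving them out of self.states.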

    def run_vi(self):
        '''
        Returns:
            (tuple):
                1. (int): Number of iterations taken.
                2. (float): The value of the initial state under the computed value function.

        Summary:
            Runs ValueIteration and fills in @self.value_func.
        '''
        # Algorithm bookkeeping params.
        iterations = 0
        max_diff = float("inf")
        self._compute_matrix_from_trans_func()
        state_space = self.get_states()
        self.bellman_backups = 0

        # Main loop: sweep until values change by at most @self.delta or we hit @self.max_iterations.
        while max_diff > self.delta and iterations < self.max_iterations:
            max_diff = 0
            for s in state_space:
                self.bellman_backups += 1
                if s.is_terminal():
                    continue

                # Bellman backup: V(s) <- max_a [ R(s, a) + gamma * sum_{s'} T(s' | s, a) * V(s') ].
                max_q = float("-inf")
                for a in self.actions:
                    q_s_a = self.get_q_value(s, a)
                    max_q = max(max_q, q_s_a)

                # Track the largest change for the terminating condition.
                max_diff = max(abs(self.value_func[s] - max_q), max_diff)

                # Update value.
                self.value_func[s] = max_q
            iterations += 1

        value_of_init_state = self._compute_max_qval_action_pair(self.init_state)[0]
        self.has_planned = True

        return iterations, value_of_init_state
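        # With gamma < 1, the Bellman backup is a gamma-contraction in the max norm on
        # the sampled model, so each sweep shrinks the worst-case change and the
        # @self.delta stopping test eventually fires (with @self.max_iterations as a hard cap).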

    def get_num_backups_in_recent_run(self):
        if self.has_planned:
            return self.bellman_backups
        else:
            print("Warning: asking for num Bellman backups, but VI has not been run.")
            return 0

    def print_value_func(self):
        for key in self.value_func.keys():
            print(key, ":", self.value_func[key])

    def plan(self, state=None, horizon=100):
        '''
        Args:
            state (State): The state from which to begin planning (defaults to the MDP's initial state).
            horizon (int): Maximum number of steps in the rollout.

        Returns:
            (tuple):
                1. (list): The sequence of actions chosen greedily w.r.t. @self.value_func.
                2. (list): The corresponding sequence of states, beginning with @state.
        '''
        state = self.mdp.get_init_state() if state is None else state

        if self.has_planned is False:
            print("Warning: VI has not been run. Plan will be random.")

        action_seq = []
        state_seq = [state]
        steps = 0

        while (not state.is_terminal()) and steps < horizon:
            next_action = self._get_max_q_action(state)
            action_seq.append(next_action)
            state = self.transition_func(state, next_action)
            state_seq.append(state)
            steps += 1

        return action_seq, state_seq

    def _get_max_q_action(self, state):
        '''
        Args:
            state (State)

        Returns:
            (str): The action with the max Q-value in the given @state.
        '''
        return self._compute_max_qval_action_pair(state)[1]

    def get_max_q_actions(self, state):
        '''
        Args:
            state (State)

        Returns:
            (list): List of actions with the max Q-value in the given @state.
        '''
        max_q_val = self.get_value(state)
        best_action_list = []

        # Collect every action whose Q-value ties the max. Exact float equality is safe
        # here because @self.get_value takes the max over these same @self.get_q_value results.
        for action in self.actions:
            q_s_a = self.get_q_value(state, action)
            if q_s_a == max_q_val:
                best_action_list.append(action)

        return best_action_list

    def policy(self, state):
        '''
        Args:
            state (State)

        Returns:
            (str): Action

        Summary:
            For use in a FixedPolicyAgent.
        '''
        return self._get_max_q_action(state)

    def _compute_max_qval_action_pair(self, state):
        '''
        Args:
            state (State)

        Returns:
            (tuple) --> (float, str): where the float is the Q-value and the str is the action.
        '''
        # Initialize to the first action as a default, in case all Q-values are equal.
        max_q_val = float("-inf")
        best_action = self.actions[0]

        # Find the best action (the action with the current max predicted Q-value).
        for action in self.actions:
            q_s_a = self.get_q_value(state, action)
            if q_s_a > max_q_val:
                max_q_val = q_s_a
                best_action = action

        return max_q_val, best_action
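
# A minimal usage sketch (kept as comments so this listing stays valid source). It
# assumes simple_rl's GridWorldMDP is importable and treats the exact constructor
# arguments as illustrative:
#
#     from simple_rl.tasks import GridWorldMDP
#     from simple_rl.planning import ValueIteration
#
#     mdp = GridWorldMDP(width=4, height=3, init_loc=(1, 1), goal_locs=[(4, 3)])
#     vi = ValueIteration(mdp, delta=0.0001, max_iterations=500, sample_rate=5)
#     iters, init_value = vi.run_vi()
#     print("Converged after", iters, "iterations; V(s0) =", init_value)
#
#     # Greedy rollout from the initial state using the computed value function.
#     action_seq, state_seq = vi.plan()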