Source code for simple_rl.mdp.MDPDistributionClass

''' MDPDistributionClass.py: Contains the MDP Distribution Class. '''

# Python imports.
from __future__ import print_function
import numpy as np
from collections import defaultdict

class MDPDistribution(object):
    ''' Class for distributions over MDPs. '''

    def __init__(self, mdp_prob_dict, horizon=0):
        '''
        Args:
            mdp_prob_dict (dict): Key (MDP) Val (float): Represents the
                probability with which the MDP is sampled.
            horizon (int): Stored as-is and returned by get_horizon().

        Raises:
            (ValueError): If @mdp_prob_dict contains no MDPs.

        Notes:
            @mdp_prob_dict can also be a list (or tuple), in which case the
            uniform distribution is used.
        '''
        # len() works for both dicts and sequences; calling .keys() here
        # would fail on the list input this constructor promises to accept.
        if len(mdp_prob_dict) == 0:
            raise ValueError("(simple-rl Error): Cannot create an MDP Distribution over zero MDPs.")

        if isinstance(mdp_prob_dict, (list, tuple)):
            # Assume uniform if no probabilities are provided.
            mdp_prob = 1.0 / len(mdp_prob_dict)
            new_dict = defaultdict(float)
            for mdp in mdp_prob_dict:
                new_dict[mdp] = mdp_prob
            mdp_prob_dict = new_dict

        self.horizon = horizon
        self.mdp_prob_dict = mdp_prob_dict

    def get_parameters(self):
        '''
        Returns:
            (dict) key=param_name (str) --> val=param_val (object).
        '''
        return {
            "mdp_prob_dict": self.mdp_prob_dict,
            "horizon": self.horizon,
        }

    def remove_mdps(self, mdp_list):
        '''
        Args:
            mdp_list (list): Contains MDP instances.

        Raises:
            (ValueError): If some MDP in @mdp_list is not in the distribution.

        Summary:
            Removes each mdp in @mdp_list from self.mdp_prob_dict and
            recomputes the distribution.
        '''
        for mdp in mdp_list:
            self._pop_mdp(mdp)
        self._normalize()

    def remove_mdp(self, mdp):
        '''
        Args:
            mdp (MDP)

        Raises:
            (ValueError): If @mdp is not in the distribution.

        Summary:
            Removes @mdp from self.mdp_prob_dict and recomputes the distribution.
        '''
        self._pop_mdp(mdp)
        self._normalize()

    def _pop_mdp(self, mdp):
        ''' Drops @mdp from self.mdp_prob_dict; ValueError if it is absent. '''
        try:
            self.mdp_prob_dict.pop(mdp)
        except KeyError:
            raise ValueError("(simple-rl Error): Trying to remove MDP (" + str(mdp) + ") from MDP Distribution that doesn't contain it.")

    def _normalize(self):
        ''' Rescales the stored probabilities so that they sum to one. '''
        total = sum(self.mdp_prob_dict.values())
        if not total:
            # Nothing left, or every remaining MDP has probability zero:
            # there is no mass to rescale and dividing would raise.
            return
        for mdp in self.mdp_prob_dict:
            self.mdp_prob_dict[mdp] = self.mdp_prob_dict[mdp] / total

    def _first_mdp(self):
        ''' Returns the first MDP in iteration order of self.mdp_prob_dict. '''
        return list(self.mdp_prob_dict.keys())[0]

    def get_all_mdps(self, prob_threshold=0):
        '''
        Args:
            prob_threshold (float)

        Returns:
            (list): Contains all mdps in the distribution with Pr. > @prob_threshold.
        '''
        return [mdp for mdp, prob in self.mdp_prob_dict.items() if prob > prob_threshold]

    def get_horizon(self):
        ''' Returns the horizon value given at construction. '''
        return self.horizon

    def get_actions(self):
        ''' Returns the actions of the first MDP in the distribution. '''
        return self._first_mdp().get_actions()

    def get_gamma(self):
        '''
        Notes:
            Not all MDPs in the distribution are guaranteed to share gamma.
        '''
        return self._first_mdp().get_gamma()

    def get_reward_func(self, avg=True):
        '''
        Args:
            avg (bool): If True, use the probability-weighted average over
                all MDPs; otherwise use the first MDP with non-zero probability.

        Returns:
            (callable): The selected reward function.
        '''
        if avg:
            return self.get_average_reward_func()
        return self.get_all_mdps()[0].get_reward_func()

    def get_average_reward_func(self):
        '''
        Returns:
            (lambda): (s, a) --> sum over MDPs of Pr(mdp) * mdp.reward_func(s, a).
        '''
        def _avg_r_func(s, a):
            r = 0.0
            for m, prob in self.mdp_prob_dict.items():
                r += m.reward_func(s, a) * prob
            return r

        return _avg_r_func

    def get_init_state(self):
        '''
        Notes:
            Not all MDPs in the distribution are guaranteed to share init states.
        '''
        return self._first_mdp().get_init_state()

    def get_num_mdps(self):
        ''' Returns the number of MDPs in the distribution. '''
        return len(self.mdp_prob_dict)

    def get_mdps(self):
        ''' Returns the keys of self.mdp_prob_dict (the MDPs). '''
        return self.mdp_prob_dict.keys()

    def get_prob_of_mdp(self, mdp):
        ''' Returns the probability of @mdp, or 0.0 if it is not in the distribution. '''
        # .get() never inserts, so a defaultdict is not grown by lookups.
        return self.mdp_prob_dict.get(mdp, 0.0)

    def set_gamma(self, new_gamma):
        ''' Calls set_gamma(@new_gamma) on every MDP in the distribution. '''
        for mdp in self.mdp_prob_dict:
            mdp.set_gamma(new_gamma)

    def sample(self, k=1):
        '''
        Args:
            k (int)

        Returns:
            (MDP) if @k is 1, otherwise (list of MDP) of length @k.

        Notes:
            Draws come from a multinomial, so they are made with replacement:
            the same MDP can appear several times. The returned list is
            grouped by MDP rather than shuffled.
        '''
        mdps = list(self.mdp_prob_dict.keys())
        counts = np.random.multinomial(k, list(self.mdp_prob_dict.values())).tolist()

        if k == 1:
            # Exactly one entry of counts is 1; return that MDP unwrapped.
            return mdps[counts.index(1)]

        mdps_to_return = []
        for mdp, count in zip(mdps, counts):
            mdps_to_return.extend([mdp] * count)
        return mdps_to_return

    def __str__(self):
        '''
        Notes:
            Not all MDPs are guaranteed to share a name (for instance, might
            include dimensions).
        '''
        return "lifelong-" + str(self._first_mdp())
def main():
    ''' Simple test code: builds a distribution over GridWorldMDPs and samples one. '''
    # Imported here because this module's top-level imports do not provide
    # `random`, and GridWorldMDP is only needed for this demo.
    import random as r
    from simple_rl.tasks import GridWorldMDP

    mdp_distr = {}
    height, width = 8, 8
    prob_list = [0.0, 0.1, 0.2, 0.3, 0.4]

    # Candidate goal cells: (x, height) for every x in 1..width. Materialized
    # as a list because random.sample requires a sequence, not a zip iterator.
    goal_candidates = list(zip(range(1, width + 1), [height] * width))

    for prob in prob_list:
        next_mdp = GridWorldMDP(width=width, height=height, init_loc=(1, 1), goal_locs=r.sample(goal_candidates, 2), is_goal_terminal=True)
        mdp_distr[next_mdp] = prob

    m = MDPDistribution(mdp_distr)
    m.sample()
if __name__ == "__main__":
    # Run the simple test code only when executed as a script.
    main()