# Source code for q_derived.base_q_derived_soft_policy

from abc import abstractmethod
import random
from typing import Any, Mapping, Union

from gridmind.policies.soft.base_soft_policy import BaseSoftPolicy
import numpy as np


class BaseQDerivedSoftPolicy(BaseSoftPolicy):
    """Base class for epsilon-soft policies whose greedy choice comes from ``Q``.

    With probability ``epsilon`` an action is drawn at random (restricted to
    the valid actions when an ``action_mask`` is supplied); otherwise the
    greedy action returned by :meth:`_get_greedy_action` is taken.

    Subclasses must implement :meth:`_get_greedy_action` and
    :meth:`decay_epsilon`; :meth:`update` raises ``NotImplementedError``
    unless overridden.

    NOTE(review): ``@abstractmethod`` is only enforced at instantiation time
    if ``BaseSoftPolicy`` uses ``ABCMeta`` -- confirm against the base class.
    """

    def __init__(
        self, Q: Union[Any, Mapping], epsilon: float, num_actions: int
    ) -> None:
        """Store the policy parameters.

        Args:
            Q: Object stored as ``self.Q``. This base class never reads it;
                presumably subclasses use it as the action-value estimates
                in ``_get_greedy_action`` -- verify against subclasses.
            epsilon: Probability of taking a random (exploratory) action.
                Not range-checked here; see :meth:`set_epsilon`.
            num_actions: Number of discrete actions; actions are indexed
                ``0 .. num_actions - 1``.
        """
        super().__init__()
        self.Q = Q
        self.epsilon = epsilon
        self.num_actions = num_actions
        # Optional sampler: when set (not None), ``action_space.sample()`` is
        # used for unmasked random actions instead of ``random.randint``.
        self.action_space = None

    def update(self, state, action):
        """Update the policy for ``(state, action)``; not implemented here.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def get_epsilon(self) -> float:
        """Return the current exploration probability."""
        return self.epsilon

    def set_epsilon(self, value: float) -> None:
        """Set the exploration probability.

        Args:
            value: New epsilon; must lie in ``[0, 1]``.

        Raises:
            AssertionError: If ``value`` is outside ``[0, 1]`` (or is NaN).
        """
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # the same as before.
        if not 0.0 <= value <= 1.0:
            raise AssertionError("epsilon must be in the range [0,1]")
        self.epsilon = value

    @abstractmethod
    def decay_epsilon(self):
        """Decay ``epsilon`` according to the subclass's schedule."""
        raise NotImplementedError()

    def _get_random_action(self, action_mask=None):
        """Sample an action uniformly at random.

        Args:
            action_mask: Optional 1-D mask over actions; truthy entries mark
                valid actions. When given, only valid indices are sampled.

        Returns:
            The sampled action: an index drawn with ``np.random.choice`` when
            masked, else ``self.action_space.sample()`` if an action space is
            set, else an int from ``random.randint(0, num_actions - 1)``.
        """
        if action_mask is not None:
            valid_actions = np.where(action_mask)[0]
            return np.random.choice(valid_actions)

        # Explicit None check: a space object that defines ``__len__`` could
        # be falsy even though it is set.
        if self.action_space is not None:
            return self.action_space.sample()

        return random.randint(0, self.num_actions - 1)

    def get_action(self, state, action_mask=None):
        """Select an action epsilon-greedily.

        Args:
            state: State passed through to ``_get_greedy_action``.
            action_mask: Optional mask of valid actions, forwarded to both
                the random and the greedy branch.

        Returns:
            A random action with probability ``epsilon``, otherwise the
            greedy action.
        """
        # ``random.random()`` is in [0, 1), so strict ``<`` explores with
        # probability exactly epsilon and never when epsilon == 0.
        if random.random() < self.epsilon:
            return self._get_random_action(action_mask=action_mask)
        return self._get_greedy_action(state, action_mask=action_mask)

    def get_action_prob(self, state, action, action_mask=None):
        """Return the probability of selecting ``action`` in ``state``.

        Args:
            state: State passed through to ``_get_greedy_action``.
            action: Action index whose probability is requested.
            action_mask: Optional 1-D mask over actions; truthy entries mark
                valid actions.

        Returns:
            ``0.0`` for an action excluded by ``action_mask``;
            ``1 - epsilon + epsilon / k`` for the greedy action; and
            ``epsilon / k`` for any other valid action, where ``k`` is the
            number of valid actions (``num_actions`` when unmasked).
        """
        if action_mask is not None:
            mask = np.ravel(action_mask)
            # A masked-out action can never be sampled (the random branch
            # draws only from valid indices), so its probability is zero.
            if not mask[action]:
                return 0.0
            num_valid_actions = np.sum(mask)
        else:
            num_valid_actions = self.num_actions

        greedy_action = self._get_greedy_action(state, action_mask=action_mask)
        each_random_action_prob = self.epsilon / num_valid_actions
        if action == greedy_action:
            # The greedy action is also reachable through the random branch.
            return 1.0 - self.epsilon + each_random_action_prob
        return each_random_action_prob

    def get_all_action_probabilities(self, states, action_mask=None):
        """Return the action distribution for each state in ``states``.

        Args:
            states: Iterable of states.
            action_mask: Optional mask of valid actions; the same mask is
                applied to every state.

        Returns:
            ``np.ndarray`` built from one row of ``num_actions``
            probabilities per state, with singleton dimensions removed by
            ``squeeze()`` (so a single state yields a 1-D array).
        """
        action_probs = [
            [
                self.get_action_prob(state, action, action_mask=action_mask)
                for action in range(self.num_actions)
            ]
            for state in states
        ]
        return np.array(action_probs).squeeze()

    def get_action_deterministic(self, state, action_mask=None):
        """Return the greedy action for ``state`` with no exploration."""
        return self._get_greedy_action(state=state, action_mask=action_mask)

    @abstractmethod
    def _get_greedy_action(self, state, action_mask=None):
        """Return the greedy action for ``state``; implemented by subclasses.

        Args:
            state: State to act in.
            action_mask: Optional mask of valid actions.
        """
        raise NotImplementedError()