Source code for src.gridmind.wrappers.policy_wrappers.epsilon_randomized_policy_wrapper

import random
from gridmind.policies.base_policy import BasePolicy
from gridmind.wrappers.policy_wrappers.base_policy_wrapper import BasePolicyWrapper


class EpsilonRandomizedPolicyWrapper(BasePolicyWrapper):
    """Mix a wrapped policy with uniform random action selection.

    With probability ``epsilon`` an action index is drawn uniformly from
    ``0 .. num_actions - 1``; otherwise the wrapped policy picks the action.
    """

    def __init__(self, policy: BasePolicy, num_actions: int, epsilon: float = 0.2):
        """Store the randomization settings and wrap ``policy``.

        Args:
            policy: The policy whose decisions are randomized.
            num_actions: Number of discrete actions; random actions are
                integer indices in ``[0, num_actions - 1]``.
            epsilon: Probability of taking a uniformly random action instead
                of asking the wrapped policy.

        Raises:
            ValueError: If ``num_actions`` is less than 1 or ``epsilon`` is
                not within ``[0, 1]``.
        """
        # Fail early: num_actions == 0 would otherwise surface later as a
        # ZeroDivisionError in get_action_prob or a ValueError from the random
        # draw, and an out-of-range epsilon would yield invalid probabilities.
        if num_actions < 1:
            raise ValueError(f"num_actions must be at least 1, got {num_actions!r}")
        if not 0.0 <= epsilon <= 1.0:
            raise ValueError(f"epsilon must be within [0, 1], got {epsilon!r}")

        # NOTE(review): self.policy is read below but never assigned here;
        # presumably BasePolicyWrapper.__init__ stores it -- confirm.
        super().__init__(policy)
        self.epsilon = epsilon
        self.num_actions = num_actions

    def get_action(self, state):
        """Return a random action with probability ``epsilon``.

        Otherwise the call is delegated to the wrapped policy's
        ``get_action(state)``.
        """
        if random.random() < self.epsilon:
            # Uniform draw over the action indices 0 .. num_actions - 1.
            return random.randrange(self.num_actions)
        return self.policy.get_action(state)

    def get_action_prob(self, state, action):
        """Return the probability of ``action`` in ``state`` under the mixture.

        Computed as ``(1 - epsilon) * p + epsilon / num_actions``, where ``p``
        is the wrapped policy's ``get_action_prob(state, action)``.
        """
        policy_action_prob = self.policy.get_action_prob(state, action)
        # Share contributed by the uniform random branch of get_action.
        uniform_share = self.epsilon / self.num_actions
        return (1 - self.epsilon) * policy_action_prob + uniform_share