# Source code for q_derived.q_network_derived_epsilon_greedy_policy

from gridmind.policies.soft.q_derived.base_q_derived_soft_policy import (
    BaseQDerivedSoftPolicy,
)
import torch


class QNetworkDerivedEpsilonGreedyPolicy(BaseQDerivedSoftPolicy):
    """Soft policy whose greedy action is the argmax of a Q-network's output.

    The Q-network is handed to ``BaseQDerivedSoftPolicy`` as ``Q`` together
    with ``epsilon`` and ``num_actions``; the exploration logic itself lives
    in that base class (not shown in this file). This subclass adds:

    * greedy-action selection via a forward pass through the network,
    * an ``epsilon_min`` floor enforced in :meth:`set_epsilon`,
    * optional linear epsilon decay via :meth:`decay_epsilon`.

    Direct tabular-style updates (:meth:`update`, :meth:`update_q`) are not
    supported and always raise.
    """

    def __init__(
        self,
        q_network: torch.nn.Module,
        num_actions: int,
        action_space=None,
        epsilon=0.1,
        allow_decay=True,
        epsilon_min=0.001,
        decay_rate=0.01,
    ):
        """Create the policy.

        Args:
            q_network: Module called as ``q_network(state)``; the argmax of
                its output is used as the greedy action.
            num_actions: Number of actions; forwarded to the base class.
            action_space: Optional object exposing ``.n`` and supporting the
                ``in`` operator. When given, ``num_actions`` must equal
                ``action_space.n`` and greedy actions are checked against it.
            epsilon: Initial exploration rate, must lie in ``[0, 1]``.
            allow_decay: Whether :meth:`decay_epsilon` may lower epsilon.
            epsilon_min: Lower bound enforced by :meth:`set_epsilon` and
                :meth:`decay_epsilon`.
            decay_rate: Amount subtracted from epsilon per decay step.

        Raises:
            AssertionError: If ``epsilon`` is outside ``[0, 1]`` or
                ``num_actions`` disagrees with ``action_space.n``.
        """
        super().__init__(Q=q_network, epsilon=epsilon, num_actions=num_actions)

        self.action_space = action_space
        self.allow_decay = allow_decay
        self.epsilon_min = epsilon_min
        self.decay_rate = decay_rate
        self.device = self._infer_device(self.Q)

        # NOTE: kept as assertions so the raised type (AssertionError) stays
        # what existing callers see; be aware they are stripped under ``-O``.
        assert 0 <= epsilon <= 1, "epsilon must be in range 0 to 1."
        if self.action_space is not None:
            assert (
                num_actions == self.action_space.n
            ), "Provided num_actions does not match with number of actions in the provided action_space."

    @staticmethod
    def _infer_device(network):
        """Return the device of ``network``'s first parameter.

        Falls back to the first buffer, then to CPU, so that a network
        without parameters does not abort with a bare ``StopIteration``.
        """
        anchor = next(network.parameters(), None)
        if anchor is None:
            anchor = next(network.buffers(), None)
        return torch.device("cpu") if anchor is None else anchor.device

    def get_network(self):
        """Return the Q-network currently backing this policy."""
        return self.Q

    def set_network(self, network):
        """Replace the Q-network and refresh the cached device to match it."""
        self.Q = network
        self.device = self._infer_device(self.Q)

    def update(self, state, action):
        """Unsupported: the greedy action is derived from the Q-network.

        Note that :meth:`update_q`, which the message points to, is also
        unsupported on this class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "This policy is derived from q_network. Instead of directly updating the action to take in a state, please update the state-action value. Use update_q method instead."
        )

    def update_q(self, state, action, value: float):
        """Unsupported: Q values cannot be written directly on this policy.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support updating Q values directly."
        )

    def _get_greedy_action(self, state):
        """Return the index of the largest Q-network output for ``state``.

        ``state`` is moved to the network's device before the forward pass.
        ``torch.argmax`` is applied without a ``dim`` argument, i.e. over the
        flattened network output.

        Raises:
            AssertionError: If an ``action_space`` was provided and the
                selected action is not contained in it.
        """
        state = state.to(self.device)
        # Only the argmax index is needed, so skip building an autograd graph.
        with torch.no_grad():
            q_values = self.Q(state)
        action = torch.argmax(q_values).item()

        if self.action_space is not None:
            assert action in self.action_space, "Action not in action space!!"
        return action

    def set_epsilon(self, value: float):
        """Set epsilon, clamping values below ``epsilon_min`` up to the floor.

        A warning is logged through ``self.logger`` when clamping occurs.
        """
        if value < self.epsilon_min:
            self.logger.warning(
                "Trying to set epsilon value less than epsilon_min. Setting epsilon=epsilon_min"
            )
            value = self.epsilon_min
        super().set_epsilon(value)

    def decay_epsilon(self):
        """Lower epsilon by ``decay_rate``, never going below ``epsilon_min``.

        Does nothing except log a warning when decay is disabled. Epsilon is
        never increased by this method, even if it currently sits below
        ``epsilon_min``.
        """
        if not self.allow_decay:
            self.logger.warning("Epsilon decay is not allowed.")
            return

        # Clamp to the floor rather than skipping the step; otherwise epsilon
        # stalls above epsilon_min whenever the remaining gap is smaller than
        # decay_rate.
        decayed_epsilon = max(self.epsilon - self.decay_rate, self.epsilon_min)
        if decayed_epsilon < self.epsilon:
            self.set_epsilon(value=decayed_epsilon)