import random
from typing import Optional
from gymnasium.spaces.space import Space
from gridmind.policies.base_policy import BasePolicy
class RandomPolicy(BasePolicy):
    """Policy that selects every action uniformly at random.

    Intended for prediction (value estimation) only: it has nothing to learn,
    so :meth:`update` always raises.

    Args:
        num_actions: Number of discrete actions to choose from. Must be >= 1.
        action_space: Optional action space. When given, its ``n`` attribute
            must equal ``num_actions`` (e.g. a ``gymnasium.spaces.Discrete``).
            If it also has a ``start`` attribute, sampled actions are shifted
            by that offset so they lie inside the space.
    """

    def __init__(self, num_actions: int, action_space: Optional[Space] = None) -> None:
        super().__init__()
        self.action_space = action_space
        self.num_actions = num_actions
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if num_actions < 1:
            raise ValueError(f"num_actions must be >= 1, got {num_actions}.")
        if action_space is not None and num_actions != action_space.n:
            raise ValueError(
                "Provided num_actions does not match with number of actions in the provided action_space."
            )

    def get_action(self, state) -> int:
        """Sample an action uniformly at random.

        Args:
            state: Ignored; accepted for interface compatibility with other policies.

        Returns:
            An int in ``[start, start + num_actions - 1]``, where ``start`` is the
            action space's ``start`` attribute (0 when absent or when no space
            was provided).
        """
        # A gymnasium Discrete(n, start=s) contains s..s+n-1, not 0..n-1;
        # honour the offset so sampled actions are actually members of the space.
        start = int(getattr(self.action_space, "start", 0))
        action = random.randint(start, start + self.num_actions - 1)
        # Internal sanity check (not input validation), so ``assert`` is appropriate.
        assert (
            self.action_space is None or action in self.action_space
        ), "Action not in action space!!"
        return action

    def get_action_prob(self, state, action) -> float:
        """Return the probability of choosing ``action`` in ``state``.

        The policy is uniform, so this is ``1 / num_actions``; neither
        ``state`` nor ``action`` is inspected.
        """
        return 1 / self.num_actions

    def update(self, state, action):
        """Unsupported: a random policy cannot be improved.

        Raises:
            NotImplementedError: Always. This subclasses ``Exception``, so
                existing ``except Exception`` handlers still catch it.
        """
        raise NotImplementedError("This policy is for prediction (value estimation only).")