from collections import defaultdict
import random
from typing import Optional
from gymnasium.spaces.space import Space
from gridmind.policies.base_policy import BasePolicy
[docs]class StochasticStartGreedyPolicy(BasePolicy):
def __init__(self, num_actions: int, action_space: Optional[Space] = None) -> None:
super().__init__()
[docs] self.action_space = action_space
[docs] self.num_actions = num_actions
assert (
num_actions == self.action_space.n
if self.action_space is not None
else True
), "Provided num_actions does not match with number of actions in the provided action_space."
[docs] self.policy_dict = defaultdict(lambda: random.randint(0, self.num_actions - 1))
[docs] def get_action(self, state):
action = self.policy_dict[state]
assert (
action in self.action_space if self.action_space is not None else True
), "Action not in action space!!"
return action
[docs] def get_action_prob(self, state, action):
policy_act = self.get_action(state)
action_probs = 1.0 if action == policy_act else 0
return action_probs
[docs] def update(self, state, action):
assert (
action in self.action_space if self.action_space is not None else True
), "Action not in action space!!"
self.policy_dict[state] = action