Source code for gridmind.policies.base_policy

from abc import ABC, abstractmethod
import logging


class BasePolicy(ABC):
    """Abstract interface that every policy implementation derives from.

    Concrete subclasses must override ``get_action``, ``get_action_prob``,
    ``get_all_action_probabilities`` and ``update``; the class cannot be
    instantiated until all four are provided.  Each instance carries a
    ``logger`` attribute named after its concrete class.
    """

    def __init__(self) -> None:
        """Attach a logger whose name is the concrete class name."""
        # Resolved through the instance so subclasses get their own name,
        # not "BasePolicy".
        logger_name = self.__class__.__name__
        self.logger = logging.getLogger(logger_name)

    @abstractmethod
    def get_action(self, state):
        """Return the action for ``state``.

        Must be overridden; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def get_action_prob(self, state, action):
        """Return the probability associated with ``action`` in ``state``.

        Must be overridden; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def get_all_action_probabilities(self, states):
        """Return the action probabilities for ``states``.

        The shape of the argument and of the result is defined by the
        subclass.  Must be overridden; the base implementation always
        raises ``NotImplementedError``.
        """
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def update(self, state, action):
        """Update the policy for the given ``state`` / ``action`` pair.

        The exact semantics are defined by the subclass.  Must be
        overridden; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError("This method must be overridden")

    def __call__(self, state, *args, **kwds):
        """Make the policy callable: ``policy(state)`` delegates to ``get_action``.

        Extra positional and keyword arguments are accepted but are not
        forwarded; only ``state`` is passed on, by keyword.
        """
        chosen_action = self.get_action(state=state)
        return chosen_action