from abc import ABC, abstractmethod
import logging
class BasePolicy(ABC):
    """Abstract base class for a policy that maps states to actions.

    Concrete subclasses must implement ``get_action``, ``get_action_prob``,
    ``get_all_action_probabilities`` and ``update``; ``ABC`` prevents
    instantiating a subclass that leaves any of them unimplemented.
    Calling an instance (``policy(state)``) is shorthand for
    ``policy.get_action(state=state)``.
    """

    def __init__(self) -> None:
        # Cooperative init so mixins further along the MRO are initialised too.
        super().__init__()
        # One logger per concrete subclass, named after that subclass.
        self.logger = logging.getLogger(self.__class__.__name__)

    @abstractmethod
    def get_action(self, state):
        """Return the action this policy selects for ``state``."""
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def get_action_prob(self, state, action):
        """Return the probability of choosing ``action`` in ``state``."""
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def get_all_action_probabilities(self, states):
        """Return action probabilities for every state in ``states``.

        The exact container/shape of the result is left to subclasses.
        """
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def update(self, state, action):
        """Update the policy given an observed ``state``/``action`` pair."""
        raise NotImplementedError("This method must be overridden")

    def __call__(self, state, *args, **kwds):
        """Select an action for ``state``; alias for ``get_action``.

        Extra positional and keyword arguments are accepted for call-site
        compatibility but are ignored — they are NOT forwarded to
        ``get_action``.
        """
        return self.get_action(state=state)