from abc import ABC, abstractmethod
import logging
class BasePolicy(ABC):
    """Abstract interface for a policy that selects actions for states.

    Subclasses must implement ``get_action``, ``get_action_prob``,
    ``get_all_action_probabilities`` and ``update``; ``ABC`` prevents
    instantiating a class that leaves any of them abstract.  Instances are
    callable: ``policy(state)`` is shorthand for ``policy.get_action(state=state)``.
    """

    def __init__(self) -> None:
        # One logger per concrete subclass, named after that subclass so log
        # records show which policy emitted them.
        self.logger = logging.getLogger(self.__class__.__name__)

    @abstractmethod
    def get_action(self, state):
        """Return the action chosen for *state*.

        The action type and any sampling behavior are defined by subclasses.

        Raises:
            NotImplementedError: if a subclass calls this base implementation.
        """
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def get_action_prob(self, state, action):
        """Return the probability of taking *action* in *state*.

        (Contract inferred from the method name; subclasses define the exact
        semantics.)

        Raises:
            NotImplementedError: if a subclass calls this base implementation.
        """
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def get_all_action_probabilities(self, states):
        """Return action probabilities for every state in *states*.

        The container type and layout of the result are defined by subclasses.

        Raises:
            NotImplementedError: if a subclass calls this base implementation.
        """
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def update(self, state, action):
        """Update the policy using *state* and *action*.

        What "update" means (e.g. a learning step) is up to the subclass.

        Raises:
            NotImplementedError: if a subclass calls this base implementation.
        """
        raise NotImplementedError("This method must be overridden")

    def __call__(self, state, *args, **kwds):
        """Return ``self.get_action(state=state)``.

        Any extra positional or keyword arguments are accepted for call-site
        compatibility but are deliberately ignored (not forwarded).
        """
        return self.get_action(state=state)