# Source code for src.gridmind.policies.base_policy

from abc import ABC, abstractmethod
import logging


class BasePolicy(ABC):
    """Abstract interface that every policy implementation derives from.

    Subclasses must override ``get_action``, ``get_action_prob``,
    ``get_all_action_probabilities`` and ``update``; the class cannot be
    instantiated until all four are provided. Instances are callable:
    ``policy(state)`` is shorthand for ``policy.get_action(state=state)``.
    """

    def __init__(self) -> None:
        # Each concrete policy gets a logger named after its own class.
        policy_name = type(self).__name__
        self.logger = logging.getLogger(policy_name)

    @abstractmethod
    def get_action(self, state):
        """Return the action selected for ``state``.

        Raises:
            NotImplementedError: if reached via ``super()`` without an override.
        """
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def get_action_prob(self, state, action):
        """Return the probability of ``action`` in ``state``.

        NOTE(review): the semantics are inferred from the name; the base class
        defines no contract beyond the signature.

        Raises:
            NotImplementedError: if reached via ``super()`` without an override.
        """
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def get_all_action_probabilities(self, states):
        """Return action probabilities for every entry in ``states``.

        NOTE(review): the return structure is left entirely to subclasses.

        Raises:
            NotImplementedError: if reached via ``super()`` without an override.
        """
        raise NotImplementedError("This method must be overridden")

    @abstractmethod
    def update(self, state, action):
        """Update the policy for the given ``state``/``action`` pair.

        Raises:
            NotImplementedError: if reached via ``super()`` without an override.
        """
        raise NotImplementedError("This method must be overridden")

    def __call__(self, state, *args, **kwds):
        """Delegate to ``get_action(state=state)``.

        Extra positional and keyword arguments are accepted but ignored.
        """
        chosen = self.get_action(state=state)
        return chosen