Source code for src.gridmind.wrappers.policy_wrappers.preprocessed_observation_policy_wrapper

from gridmind.wrappers.policy_wrappers.base_policy_wrapper import BasePolicyWrapper


class PreprocessedObservationPolicyWrapper(BasePolicyWrapper):
    """Policy wrapper that transforms every incoming state before delegating.

    Each call to ``get_action``, ``get_action_prob`` and ``update`` first
    passes ``state`` through ``preprocess_fn`` and then forwards the result,
    together with the remaining arguments, to the wrapped policy.

    Args:
        policy: The policy being wrapped; handed to ``BasePolicyWrapper``.
            Calls are forwarded to ``self.policy`` (presumably set by the
            base class constructor -- confirm in ``BasePolicyWrapper``).
        preprocess_fn: Callable taking a raw state and returning the
            representation the wrapped policy expects.

    Raises:
        TypeError: If ``preprocess_fn`` is not callable.
    """

    def __init__(self, policy, preprocess_fn):
        # Fail fast: a non-callable would otherwise only surface as an
        # opaque "object is not callable" error on the first delegated call,
        # far away from where the wrapper was configured.
        if not callable(preprocess_fn):
            raise TypeError(
                "preprocess_fn must be callable, got "
                f"{type(preprocess_fn).__name__}"
            )
        super().__init__(policy)
        self.preprocess_fn = preprocess_fn

    def _preprocess(self, state):
        """Return ``state`` transformed by ``self.preprocess_fn``."""
        # Looked up on each call so a reassigned ``preprocess_fn`` attribute
        # takes effect, matching the original per-call attribute access.
        return self.preprocess_fn(state)

    def get_action(self, state):
        """Return the wrapped policy's action for the preprocessed ``state``."""
        return self.policy.get_action(self._preprocess(state))

    def get_action_prob(self, state, action):
        """Return the wrapped policy's probability of ``action`` in the
        preprocessed ``state``."""
        return self.policy.get_action_prob(self._preprocess(state), action)

    def update(self, state, action):
        """Forward ``update`` to the wrapped policy with the preprocessed
        ``state`` and return whatever the wrapped policy returns."""
        return self.policy.update(self._preprocess(state), action)