Source code for prediction.td_0_prediction

from collections import defaultdict
from typing import Optional
from gridmind.algorithms.base_learning_algorithm import BaseLearningAlgorithm
from gridmind.policies.base_policy import BasePolicy
import gymnasium as gym


[docs]class TD0Prediction(BaseLearningAlgorithm): """ Tabular TD(0) for estimating V_pi. Input: policy to be evaluated. The policy is supposed to be a function whose input is observation and output is action. """ def __init__( self, env: gym.Env, policy: BasePolicy, step_size: float = 0.1, discount_factor: float = 0.9, summary_dir: Optional[str] = None, write_summary: bool = True, ) -> None: super().__init__( name="TD-0-Prediction", env=env, summary_dir=summary_dir, write_summary=write_summary, )
[docs] self.step_size = step_size
[docs] self.V = defaultdict(int)
[docs] self.policy = policy
[docs] self.discount_factor = discount_factor
[docs] def _get_state_value_fn(self, force_functional_interface: bool = True): if not force_functional_interface: return self.V return lambda s: self.V[s]
[docs] def _get_state_action_value_fn(self, force_functional_interface: bool = True): raise Exception( f"{self.name} computes only the state values. Use get_state_value_fn() method to get state values." )
[docs] def _get_policy(self): return self.policy
[docs] def _train_steps(self, num_steps: int, prediction_only: bool, *args, **kwargs): raise NotImplementedError()
[docs] def _train_episodes(self, num_episodes: int, prediction_only: bool = True): if prediction_only == False: raise Exception("This is a prediction/evaluation only implementation.") for i in range(num_episodes): obs, info = self.env.reset() done = False while not done: action = self.policy.get_action(obs) next_obs, reward, terminated, truncated, _ = self.env.step(action) self.V[obs] = self.V[obs] + self.step_size * ( reward + self.discount_factor * self.V[next_obs] - self.V[obs] ) obs = next_obs done = terminated or truncated return self.V
[docs] def set_policy(self, policy: BasePolicy, **kwargs): raise NotImplementedError