Source code for trajectory

from typing import Optional


class Trajectory:
    """Record of (state, action) pairs, rewards and extra info, one entry per step.

    Three parallel lists are kept in lockstep:

    * ``state_actions`` -- ``(state, action)`` tuples,
    * ``rewards`` -- the reward stored alongside each step,
    * ``additional_info`` -- a dict of the extra keyword arguments given
      when the step was recorded.

    Steps are addressed by a 0-based ``timestep``; the one exception is
    :meth:`get_reward`, which is 1-based (see its docstring).
    """

    def __init__(self) -> None:
        self.state_actions: list = []
        self.rewards: list = []
        self.additional_info: list = []

    def update_step(
        self, state, action, reward, timestep: Optional[int] = None, **kwargs
    ):
        """Append a new step or overwrite an existing one.

        If ``timestep`` is ``None`` or equals the current trajectory length,
        the step is appended via :meth:`record_step`. Otherwise the entry at
        ``timestep`` is replaced in all three lists.

        Args:
            state: State of the step.
            action: Action of the step.
            reward: Reward stored with the step.
            timestep: 0-based index of the step to overwrite, or ``None`` to
                append.
            **kwargs: Extra information stored as a dict for this step.

        Raises:
            AssertionError: If ``timestep`` is negative or greater than the
                current trajectory length.
        """
        if timestep is None or timestep == len(self.state_actions):
            # Forward the extras with ** so the stored info dict has the same
            # shape as in the overwrite branch below (not nested under a
            # literal "kwargs" key).
            return self.record_step(
                state=state, action=action, reward=reward, **kwargs
            )

        assert 0 <= timestep < len(self.state_actions)
        self.state_actions[timestep] = (state, action)
        self.rewards[timestep] = reward
        self.additional_info[timestep] = kwargs

    def record_step(self, state, action, reward, **kwargs):
        """Append one step; ``kwargs`` is stored as that step's info dict."""
        self.state_actions.append((state, action))
        self.rewards.append(reward)
        self.additional_info.append(kwargs)

    def get_step_with_info(self, timestep: int):
        """Return ``(state, action, reward, additional_info)`` for ``timestep``.

        Raises:
            AssertionError: If ``timestep`` is out of range.
        """
        state, action, reward = self.get_step(timestep)
        return state, action, reward, self.additional_info[timestep]

    def get_step(self, timestep: int):
        """Return ``(state, action, reward)`` for the 0-based ``timestep``.

        Raises:
            AssertionError: If ``timestep`` is out of range.
        """
        state, action = self.get_state_action(timestep)
        # get_reward is 1-based, so the reward stored with step `timestep`
        # is requested as `timestep + 1`.
        reward = self.get_reward(timestep + 1)
        return state, action, reward

    def get_state_action(self, timestep: int):
        """Return the ``(state, action)`` tuple stored at ``timestep``.

        Raises:
            AssertionError: If ``timestep`` is negative or not smaller than
                the trajectory length.
        """
        assert 0 <= timestep < len(self.state_actions)
        return self.state_actions[timestep]

    def get_state(self, timestep: int):
        """Return the state stored at ``timestep``.

        Raises:
            AssertionError: If ``timestep`` is out of range.
        """
        # Bounds are validated inside get_state_action.
        return self.get_state_action(timestep=timestep)[0]

    def get_reward(self, timestep: int):
        """Return the reward at the 1-based ``timestep``.

        ``get_reward(t)`` returns ``rewards[t - 1]``, i.e. the reward stored
        with the step at 0-based index ``t - 1``. Valid values are
        ``1 .. len(rewards)``.

        Raises:
            AssertionError: If ``timestep`` is outside ``1 .. len(rewards)``.
        """
        assert 0 < timestep <= len(self.rewards)
        return self.rewards[timestep - 1]

    def check_state_action_appearance_before_timestep(self, state_action, timestep):
        """Return whether ``state_action`` occurs in ``state_actions[:timestep]``."""
        return state_action in self.state_actions[:timestep]

    def get_trajectory_length(self):
        """Return the number of recorded steps."""
        return len(self.state_actions)

    def clear(self):
        """Drop all recorded steps by rebinding the three lists to new empty lists."""
        self.state_actions = []
        self.rewards = []
        self.additional_info = []