from typing import Optional
class Trajectory:
    """Ordered record of (state, action) pairs, rewards and per-step extras.

    Three lists are kept in lockstep: index ``t`` of ``state_actions``,
    ``rewards`` and ``additional_info`` all describe step ``t``. Rewards are
    read back through :meth:`get_reward` with a 1-based index.
    """

    def __init__(self) -> None:
        # (state, action) tuples, one per recorded step.
        self.state_actions = []
        # Reward stored with each step. Must exist from construction:
        # record_step, update_step and get_reward all read/write it.
        self.rewards = []
        # Extra keyword arguments supplied when each step was recorded.
        self.additional_info = []

    def update_step(
        self, state, action, reward, timestep: Optional[int] = None, **kwargs
    ):
        """Overwrite the step at ``timestep``, or append a new step.

        Args:
            state: State for this step.
            action: Action for this step.
            reward: Reward to store with this step.
            timestep: Index of the step to overwrite. ``None``, or a value
                equal to the current length, appends a new step instead.
            **kwargs: Extra information stored for this step.

        Raises:
            AssertionError: (when assertions are enabled) if ``timestep`` is
                negative or greater than the current length.
        """
        if timestep is None or timestep == len(self.state_actions):
            # Unpack the extras so they are stored exactly as a direct
            # record_step call would store them (not nested under "kwargs").
            return self.record_step(
                state=state, action=action, reward=reward, **kwargs
            )
        assert 0 <= timestep < len(self.state_actions)
        self.state_actions[timestep] = (state, action)
        self.rewards[timestep] = reward
        self.additional_info[timestep] = kwargs

    def record_step(self, state, action, reward, **kwargs):
        """Append a new step with its reward and any extra keyword info."""
        self.state_actions.append((state, action))
        self.rewards.append(reward)
        self.additional_info.append(kwargs)

    def get_step_with_info(self, timestep: int):
        """Return ``(state, action, reward, additional_info)`` for ``timestep``."""
        state, action, reward = self.get_step(timestep)
        additional_info = self.additional_info[timestep]
        return state, action, reward, additional_info

    def get_step(self, timestep: int):
        """Return ``(state, action, reward)`` for the 0-based ``timestep``.

        The reward is the one recorded together with this step; it is fetched
        via ``get_reward(timestep + 1)`` because rewards use 1-based indexing.
        """
        state, action = self.get_state_action(timestep)
        reward = self.get_reward(timestep + 1)
        return state, action, reward

    def get_state_action(self, timestep: int):
        """Return the ``(state, action)`` tuple stored at 0-based ``timestep``.

        Raises:
            AssertionError: (when assertions are enabled) if ``timestep`` is
                out of range.
        """
        assert 0 <= timestep < len(self.state_actions)
        return self.state_actions[timestep]

    def get_state(self, timestep: int):
        """Return the state stored at 0-based ``timestep``."""
        assert 0 <= timestep < len(self.state_actions)
        return self.get_state_action(timestep=timestep)[0]

    def get_reward(self, timestep: int):
        """Return the reward for 1-based ``timestep``.

        ``get_reward(t)`` returns the reward recorded with step ``t - 1``.
        NOTE(review): presumably this follows the RL convention that reward
        R_{t+1} follows action A_t -- confirm with callers.

        Raises:
            AssertionError: (when assertions are enabled) if ``timestep`` is
                not in ``[1, len(rewards)]``.
        """
        assert 0 < timestep <= len(self.rewards)
        return self.rewards[timestep - 1]

    def check_state_action_appearance_before_timestep(self, state_action, timestep):
        """Return True if ``state_action`` occurs in ``state_actions[:timestep]``.

        Comparison uses ``==`` (list membership). For a non-negative
        ``timestep``, this checks the steps strictly before it.
        """
        return state_action in self.state_actions[:timestep]

    def get_trajectory_length(self):
        """Return the number of recorded steps."""
        return len(self.state_actions)

    def clear(self):
        """Discard all recorded steps, rewards and extra info."""
        self.state_actions = []
        self.rewards = []
        self.additional_info = []