Source code for episode_collector

from typing import Callable, Optional
from gridmind.policies.base_policy import BasePolicy
from gridmind.utils.algorithm_util.trajectory import Trajectory
from gymnasium import Env


[docs]def collect_episode( env: Env, policy: BasePolicy, trajectory: Trajectory, obs_preprocessor: Optional[Callable] = None, record_action_prob: bool = False, ): trajectory.clear() obs, info = env.reset() done = False while not done: obs_raw = obs if obs_preprocessor is not None: obs = obs_preprocessor(obs) action = policy.get_action(state=obs) next_obs, reward, terminated, truncated, info = env.step(action) if record_action_prob: action_prob = policy.get_action_prob(state=obs, action=action) trajectory.record_step( state=obs_raw, action=action, reward=reward, action_prob=action_prob ) else: trajectory.record_step(state=obs_raw, action=action, reward=reward) done = terminated or truncated obs = next_obs