from collections import defaultdict
from typing import Callable, Optional
from gridmind.algorithms.base_learning_algorithm import BaseLearningAlgorithm
from gridmind.policies.soft.q_derived.base_q_derived_soft_policy import (
BaseQDerivedSoftPolicy,
)
from gridmind.policies.soft.q_derived.q_table_derived_epsilon_greedy_policy import (
QTableDerivedEpsilonGreedyPolicy,
)
from gymnasium import Env
import numpy as np
from tqdm import tqdm
[docs]class SARSA(BaseLearningAlgorithm):
def __init__(
self,
env: Env,
policy: Optional[BaseQDerivedSoftPolicy] = None,
step_size: float = 0.5,
discount_factor: float = 0.9,
q_initializer: str = "zero",
epsilon_decay: bool = False,
feature_constructor: Callable = None,
summary_dir: Optional[str] = None,
write_summary: bool = True,
) -> None:
super().__init__(
"SARSA", env=env, summary_dir=summary_dir, write_summary=write_summary
)
[docs] self.num_actions = self.env.action_space.n
[docs] self.feature_constructor = feature_constructor
assert q_initializer in [
"zero",
"random",
], "q_initializer may only take the value 'zero' or 'random'"
if q_initializer == "zero":
self.q_values = defaultdict(lambda: np.zeros(self.num_actions))
else:
self.q_values = defaultdict(lambda: np.random.rand(self.num_actions))
[docs] self.policy = (
policy
if policy is not None
else QTableDerivedEpsilonGreedyPolicy(
q_table=self.q_values, num_actions=self.num_actions
)
)
[docs] self.step_size = step_size
[docs] self.discount_factor = discount_factor
[docs] self.epsilon_decay = epsilon_decay
[docs] def _get_state_value_fn(self, force_functional_interface: bool = True):
raise Exception(
f"{self.name} computes only state-action values. Use get_state_action_values() to get state-action values."
)
[docs] def _get_state_action_value_fn(self, force_functional_interface: bool = True):
if not force_functional_interface:
return self.q_values
return lambda s, a: self.q_values[s][a]
[docs] def _get_policy(self):
return self.policy
[docs] def _train_steps(self, num_steps: int, prediction_only: bool, *args, **kwargs):
raise NotImplementedError()
[docs] def _train_episodes(self, num_episodes: int, prediction_only: bool = False):
if prediction_only:
raise Exception("This is a control-only implementation.")
for i in tqdm(range(num_episodes)):
obs, info = self.env.reset()
if self.feature_constructor is not None:
obs = self.feature_constructor(obs)
done = False
action = self.policy.get_action(obs)
while not done:
next_obs, reward, terminated, truncated, _ = self.env.step(action)
if self.feature_constructor is not None:
next_obs = self.feature_constructor(next_obs)
next_action = self.policy.get_action(next_obs)
self.q_values[obs][action] = self.q_values[obs][
action
] + self.step_size * (
reward
+ self.discount_factor * self.q_values[next_obs][next_action]
- self.q_values[obs][action]
)
self.policy.update_q(
state=obs, action=action, value=self.q_values[obs][action]
)
obs = next_obs
action = next_action
done = terminated or truncated
if self.epsilon_decay:
self.policy.decay_epsilon()
[docs] def set_policy(self, policy: BaseQDerivedSoftPolicy):
self.policy = policy