from collections import defaultdict
from typing import Optional
from gridmind.algorithms.base_learning_algorithm import BaseLearningAlgorithm
from gridmind.policies.soft.q_derived.base_q_derived_soft_policy import (
BaseQDerivedSoftPolicy,
)
from gridmind.policies.soft.q_derived.q_table_derived_epsilon_greedy_policy import (
QTableDerivedEpsilonGreedyPolicy,
)
from gymnasium import Env
import numpy as np
from tqdm import tqdm
class QLearning(BaseLearningAlgorithm):
    """Tabular Q-learning (off-policy TD control).

    State-action values live in ``self.q_values``, a ``defaultdict`` keyed by
    observation whose rows are NumPy arrays of length ``env.action_space.n``.
    Actions are chosen by a Q-derived soft policy. By default this is an
    epsilon-greedy policy built over the very same table.

    Args:
        env: Gymnasium environment. ``env.action_space.n`` is read, so the
            action space is presumably ``Discrete`` (confirm for new envs).
        policy: Soft policy used to select actions. If None, a
            ``QTableDerivedEpsilonGreedyPolicy`` over ``self.q_values`` is
            created.
        step_size: Learning rate (alpha) of the TD update.
        discount_factor: Discount (gamma) applied to the bootstrap value.
        q_initializer: ``"zero"`` or ``"random"`` (case-insensitive). Controls
            how the row for a previously unseen state is filled: zeros, or
            ``np.random.rand`` samples in [0, 1).
        epsilon_decay: If True, ``policy.decay_epsilon()`` is called at the
            end of every training episode.
        epsilon: Exploration rate handed to ``policy.set_epsilon``.
        summary_dir: Forwarded to ``BaseLearningAlgorithm``.
        write_summary: Forwarded to ``BaseLearningAlgorithm``.

    Raises:
        ValueError: If ``q_initializer`` is not ``"zero"`` or ``"random"``.
    """

    def __init__(
        self,
        env: Env,
        policy: Optional[BaseQDerivedSoftPolicy] = None,
        step_size: float = 0.1,
        discount_factor: float = 0.9,
        q_initializer: str = "zero",
        epsilon_decay: bool = False,
        epsilon: float = 0.1,
        summary_dir: Optional[str] = None,
        write_summary: bool = True,
    ) -> None:
        super().__init__(
            "Q-Learning", env=env, summary_dir=summary_dir, write_summary=write_summary
        )
        self.num_actions = self.env.action_space.n
        self.epsilon_decay = epsilon_decay
        # Previously the argument was never stored, yet self.epsilon was read
        # below, so construction failed (or the caller's value was ignored).
        self.epsilon = epsilon

        q_initializer = q_initializer.lower()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if q_initializer not in ("zero", "random"):
            raise ValueError(
                "q_initializer may only take the value 'zero' or 'random'"
            )
        if q_initializer == "zero":
            self.q_values = defaultdict(lambda: np.zeros(self.num_actions))
        else:
            self.q_values = defaultdict(lambda: np.random.rand(self.num_actions))

        self.policy = (
            policy
            if policy is not None
            else QTableDerivedEpsilonGreedyPolicy(
                q_table=self.q_values, num_actions=self.num_actions
            )
        )
        self.policy.set_epsilon(self.epsilon)
        self.step_size = step_size
        self.discount_factor = discount_factor

    def _get_state_value_fn(self, force_functional_interface: bool = True):
        """Not supported: this algorithm learns only state-action values.

        Raises:
            Exception: Always.
        """
        raise Exception(
            f"{self.name} computes only state-action values. Use get_state_action_values() to get state-action values."
        )

    def _get_state_action_value_fn(self, force_functional_interface: bool = True):
        """Expose the learned Q-values.

        Args:
            force_functional_interface: If True, return a callable
                ``(state, action) -> value``. Otherwise return the underlying
                ``defaultdict`` table itself (not a copy).
        """
        if not force_functional_interface:
            return self.q_values
        return lambda s, a: self.q_values[s][a]

    def _get_policy(self):
        """Return the policy currently used for action selection."""
        return self.policy

    def _train_steps(self, num_steps: int, prediction_only: bool, *args, **kwargs):
        """Step-based training is not implemented; use episodes instead."""
        raise NotImplementedError()

    def _max_next_q(self, next_obs, action_mask) -> float:
        """Return max_a Q(next_obs, a), restricted to admissible actions.

        The max is taken over the actions allowed by ``action_mask`` when a
        mask is given and allows at least one action. Otherwise it is taken
        over all actions, which matches the unmasked behaviour.
        """
        q_next = self.q_values[next_obs]
        if action_mask is None:
            return float(np.max(q_next))
        # NOTE(review): assumes the mask has one entry per action, truthy =
        # allowed (Gymnasium convention) — confirm for custom environments.
        allowed = np.asarray(action_mask, dtype=bool)
        if not allowed.any():
            return float(np.max(q_next))
        return float(np.max(q_next[allowed]))

    def _train_episodes(self, num_episodes: int, prediction_only: bool = False):
        """Run ``num_episodes`` episodes of Q-learning control.

        After every environment step, Q(s, a) is moved toward
        ``r + gamma * max_a' Q(s', a')``. The updated value is then pushed to
        the policy via ``policy.update_q``.

        Raises:
            Exception: If ``prediction_only`` is True (control-only algorithm).
        """
        if prediction_only:
            raise Exception("This is a control-only implementation.")

        for _ in tqdm(range(num_episodes)):
            obs, info = self.env.reset()
            done = False
            while not done:
                action = self.policy.get_action(
                    obs, action_mask=info.get("action_mask", None)
                )
                next_obs, reward, terminated, truncated, info = self.env.step(action)

                # Only true termination zeroes the bootstrap. Truncation
                # (e.g. a time limit) still bootstraps from next_obs.
                if terminated:
                    target = reward
                else:
                    target = reward + self.discount_factor * self._max_next_q(
                        next_obs, info.get("action_mask", None)
                    )

                td_error = target - self.q_values[obs][action]
                self.q_values[obs][action] = (
                    self.q_values[obs][action] + self.step_size * td_error
                )
                self.policy.update_q(
                    state=obs, action=action, value=self.q_values[obs][action]
                )

                obs = next_obs
                done = terminated or truncated

            if self.epsilon_decay:
                self.policy.decay_epsilon()

    def set_policy(self, policy: BaseQDerivedSoftPolicy):
        """Replace the action-selection policy (its epsilon is left untouched)."""
        self.policy = policy