Source code for q_learning

from collections import defaultdict
from typing import Optional
from gridmind.algorithms.base_learning_algorithm import BaseLearningAlgorithm

from gridmind.policies.soft.q_derived.base_q_derived_soft_policy import (
    BaseQDerivedSoftPolicy,
)
from gridmind.policies.soft.q_derived.q_table_derived_epsilon_greedy_policy import (
    QTableDerivedEpsilonGreedyPolicy,
)
from gymnasium import Env
import numpy as np
from tqdm import tqdm


class QLearning(BaseLearningAlgorithm):
    """Tabular Q-learning control algorithm.

    Action values are kept in ``self.q_values``, a ``defaultdict`` keyed by
    observation whose values are NumPy arrays of length ``num_actions``.
    Observations are used directly as dictionary keys, so they must be
    hashable. Entries are created lazily the first time a state is looked up.

    Args:
        env: Environment to learn in. ``env.action_space.n`` is read to size
            the per-state action-value arrays.
        policy: Behaviour policy used to pick actions. When ``None``, a
            ``QTableDerivedEpsilonGreedyPolicy`` sharing ``self.q_values`` is
            created.
        step_size: Learning rate applied to the TD error.
        discount_factor: Discount applied to the bootstrapped next-state value.
        q_initializer: ``"zero"`` or ``"random"`` (case-insensitive); selects
            how the action values of a newly seen state are initialised.
        epsilon_decay: When true, ``policy.decay_epsilon()`` is called during
            training.
        epsilon: Value passed to ``policy.set_epsilon`` at construction. Note
            that this is applied to a caller-supplied ``policy`` as well.
        summary_dir: Forwarded to ``BaseLearningAlgorithm``.
        write_summary: Forwarded to ``BaseLearningAlgorithm``.

    Raises:
        AssertionError: If ``q_initializer`` is neither ``"zero"`` nor
            ``"random"``.
    """

    def __init__(
        self,
        env: Env,
        policy: Optional[BaseQDerivedSoftPolicy] = None,
        step_size: float = 0.1,
        discount_factor: float = 0.9,
        q_initializer: str = "zero",
        epsilon_decay: bool = False,
        epsilon: float = 0.1,
        summary_dir: Optional[str] = None,
        write_summary: bool = True,
    ) -> None:
        super().__init__(
            "Q-Learning", env=env, summary_dir=summary_dir, write_summary=write_summary
        )
        # self.env is presumably set by the base-class constructor -- verify
        # against BaseLearningAlgorithm.
        self.num_actions = self.env.action_space.n
        self.epsilon_decay = epsilon_decay
        self.epsilon = epsilon

        q_initializer = q_initializer.lower()
        # Raised explicitly instead of using an ``assert`` statement: asserts
        # are stripped under ``python -O``, which would let an invalid value
        # silently select random initialisation. The exception type and
        # message are kept so existing callers see the same error.
        if q_initializer not in ("zero", "random"):
            raise AssertionError(
                "q_initializer may only take the value 'zero' or 'random'"
            )

        if q_initializer == "zero":
            self.q_values = defaultdict(lambda: np.zeros(self.num_actions))
        else:
            self.q_values = defaultdict(lambda: np.random.rand(self.num_actions))

        self.policy = (
            policy
            if policy is not None
            else QTableDerivedEpsilonGreedyPolicy(
                q_table=self.q_values, num_actions=self.num_actions
            )
        )
        self.policy.set_epsilon(self.epsilon)

        self.step_size = step_size
        self.discount_factor = discount_factor

    def _get_state_value_fn(self, force_functional_interface: bool = True):
        """Always raise: this algorithm does not maintain state values.

        ``force_functional_interface`` is accepted but not used.
        """
        raise Exception(
            f"{self.name} computes only state-action values. Use get_state_action_values() to get state-action values."
        )

    def _get_state_action_value_fn(self, force_functional_interface: bool = True):
        """Return the learned action values.

        Returns the raw ``self.q_values`` table when
        ``force_functional_interface`` is false; otherwise a callable
        ``(state, action) -> value`` backed by that table.
        """
        if not force_functional_interface:
            return self.q_values

        return lambda s, a: self.q_values[s][a]

    def _get_policy(self):
        """Return the policy currently used to select actions."""
        return self.policy

    def _train_steps(self, num_steps: int, prediction_only: bool, *args, **kwargs):
        """Step-based training is not implemented for this algorithm."""
        raise NotImplementedError()

    def _train_episodes(self, num_episodes: int, prediction_only: bool = False):
        """Run ``num_episodes`` episodes of Q-learning on ``self.env``.

        Each step applies the update
        ``Q[s][a] += step_size * (r + discount_factor * max(Q[s']) * (1 - terminated) - Q[s][a])``
        and then pushes the new value to the policy through ``update_q``.

        Args:
            num_episodes: Number of episodes to run.
            prediction_only: Must be false; a true value raises.

        Raises:
            Exception: If ``prediction_only`` is true.
        """
        if prediction_only:
            raise Exception("This is a control-only implementation.")

        for _ in tqdm(range(num_episodes)):
            obs, info = self.env.reset()
            done = False

            while not done:
                # An optional "action_mask" entry in the env's info dict is
                # forwarded to the policy; None when the env provides none.
                action_mask = info.get("action_mask", None)
                action = self.policy.get_action(obs, action_mask=action_mask)
                next_obs, reward, terminated, truncated, info = self.env.step(action)

                current_q = self.q_values[obs][action]
                # (1 - terminated) zeroes the bootstrap term when the episode
                # terminated; on truncation the next-state value is still used.
                td_target = reward + self.discount_factor * np.max(
                    self.q_values[next_obs]
                ) * (1 - terminated)
                self.q_values[obs][action] = current_q + self.step_size * (
                    td_target - current_q
                )

                self.policy.update_q(
                    state=obs, action=action, value=self.q_values[obs][action]
                )

                obs = next_obs
                done = terminated or truncated

            # NOTE(review): the original listing lost its indentation; decay is
            # reconstructed here as once per episode -- confirm against the
            # upstream file that it is not meant to run once per step.
            if self.epsilon_decay:
                self.policy.decay_epsilon()

    def set_policy(self, policy: BaseQDerivedSoftPolicy):
        """Replace the policy used for action selection with ``policy``."""
        self.policy = policy