src.gridmind.algorithms.base_learning_algorithm

Attributes

SAVE_DATA_DIR

Classes

BaseLearningAlgorithm

Helper class that provides a standard way to create an ABC using inheritance.

Module Contents

src.gridmind.algorithms.base_learning_algorithm.SAVE_DATA_DIR = None[source]
class src.gridmind.algorithms.base_learning_algorithm.BaseLearningAlgorithm(name: str, env: gymnasium.Env | None = None, summary_dir: str | None = None, write_summary: bool = True)[source]

Bases: abc.ABC

Helper class that provides a standard way to create an ABC using inheritance.

name[source]
logger = None[source]
env = None[source]
epoch_eval_interval = None[source]
perform_evaluation = False[source]
monitor_divergence = False[source]
stop_on_divergence = False[source]
write_summary = True[source]
_initialize_summary_writer(summary_dir, env_name, extra_info: str = '', use_async_writer: bool = False)[source]
register_performance_evaluator(evaluator: gridmind.utils.performance_evaluation.base_performance_evaluator.BasePerformanceEvaluator)[source]
register_divergence_detector(detector: gridmind.utils.divergence.base_divergence_detector.BaseDivergenceDetector)[source]
report_policy()[source]
report_state_values()[source]
report_state_action_values()[source]
_preprocess(observation)[source]
speculate_divergence()[source]
abstract _get_state_value_fn(force_functional_interface: bool = True)[source]
abstract _get_state_action_value_fn(force_functional_interface: bool = True)[source]
abstract _get_policy()[source]
get_state_value_fn(force_functional_interface: bool = True, autopreprocess: bool = False)[source]
get_state_action_value_fn(force_functional_interface: bool = True, autopreprocess: bool = False)[source]
get_policy(autopreprocess: bool = False)[source]
abstract set_policy(policy: gridmind.policies.base_policy.BasePolicy, **kwargs)[source]
abstract _train_episodes(num_episodes: int, prediction_only: bool, *args, **kwargs)[source]
get_policy_cloned()[source]
train(num_episodes: int | None = None, num_steps: int | None = None, prediction_only: bool = False, save_policy: bool = True, *args, **kwargs)[source]
train_steps(num_steps: int, prediction_only: bool, save_policy: bool = True, *args, **kwargs)[source]
abstract _train_steps(num_steps: int, prediction_only: bool, *args, **kwargs)[source]
train_episodes(num_episodes: int, prediction_only: bool, save_policy: bool = True, *args, **kwargs)[source]
_training_wrapper(num_iter: int, prediction_only: bool, save_policy: bool, training_fn: Callable, *args, **kwargs)[source]
_report_all_metrics()[source]
evaluate_policy(num_episodes: int)[source]
optimize_policy(num_episodes: int)[source]
save_policy(path: str)[source]
static load_policy(saved_policy_path: str)[source]