Source code for src.gridmind.policies.lookup.deterministic_lookup_policy

from typing import Dict
from gridmind.policies.base_policy import BasePolicy
import torch


class DeterministicLookupPolicy(BasePolicy):
    """Deterministic policy that reads its action from a state -> action table.

    Each state maps to at most one action through ``lookup_table``. States
    absent from the table produce ``None`` from :meth:`get_action`.
    """

    def __init__(self, lookup_table: Dict[int, int]):
        """Store the state -> action mapping used for every decision.

        Args:
            lookup_table: Mapping from state to action. It is stored by
                reference, not copied, so later changes the caller makes to
                this dict are seen by the policy.
        """
        self.lookup_table = lookup_table

    def get_action(self, state):
        """Return the action stored for ``state``, or ``None`` if it is absent.

        A ``torch.Tensor`` state is first converted to a Python ``int`` via
        ``.item()``, which only accepts single-element tensors. Any other
        state is used as the lookup key unchanged.
        """
        key = int(state.item()) if isinstance(state, torch.Tensor) else state
        return self.lookup_table.get(key, None)

    def get_action_prob(self, state, action):
        """Return 1.0 if ``action`` equals the table's action for ``state``, else 0.0.

        Because the policy is deterministic, every other action gets
        probability 0.0. For an unknown state the stored action is ``None``,
        so only ``action=None`` compares equal.
        """
        chosen = self.get_action(state)
        if chosen == action:
            return 1.0
        return 0.0

    def update(self, state, action):
        """Not implemented for this policy; always raises NotImplementedError."""
        raise NotImplementedError