from typing import Dict
from gridmind.policies.base_policy import BasePolicy
import torch
class DeterministicLookupPolicy(BasePolicy):
    """Deterministic policy backed by a fixed state -> action lookup table.

    Each state maps to at most one action. States that are absent from the
    table have no defined action: ``get_action`` returns ``None`` for them
    and ``get_action_prob`` returns ``0.0``.
    """

    # NOTE(review): BasePolicy.__init__ is not called here (the original did
    # not call it either) -- confirm the base class needs no initialisation.
    def __init__(self, lookup_table: Dict[int, int]):
        """Store the table used to resolve states to actions.

        Args:
            lookup_table: Mapping from integer state to integer action. The
                mapping is stored by reference, not copied.
        """
        self.lookup_table = lookup_table

    def get_action(self, state):
        """Return the action mapped to ``state``, or ``None`` if unmapped.

        Args:
            state: Key into the lookup table. A ``torch.Tensor`` is first
                converted to a Python ``int`` via ``Tensor.item()``.

        Returns:
            The action stored for ``state``, or ``None`` when the state is
            not present in the table.
        """
        # Tensors cannot be used as keys for an int-keyed dict; unwrap them.
        if isinstance(state, torch.Tensor):
            state = int(state.item())
        return self.lookup_table.get(state)

    def get_action_prob(self, state, action) -> float:
        """Return the probability that the policy picks ``action`` in ``state``.

        Args:
            state: State to evaluate; handled the same way as in
                ``get_action``.
            action: Candidate action.

        Returns:
            ``1.0`` if ``action`` equals the action mapped to ``state``,
            otherwise ``0.0``. A state with no mapped action yields ``0.0``
            for every ``action``.
        """
        chosen = self.get_action(state)
        # An unmapped state resolves to None; without this guard a caller
        # passing action=None would be told the probability is 1.0.
        if chosen is None:
            return 0.0
        return 1.0 if chosen == action else 0.0

    def update(self, state, action):
        """Not supported; always raises.

        Raises:
            NotImplementedError: On every call.
        """
        raise NotImplementedError(
            "DeterministicLookupPolicy does not implement update()"
        )