from gridmind.policies.parameterized.base_parameterized_policy import (
BaseParameterizedPolicy,
)
from torch import nn
import math
import torch
import torch.nn.functional as F
class DiscreteActionMLPPolicy(BaseParameterizedPolicy):
    """Softmax policy over a discrete action space, parameterized by an MLP.

    The network maps an observation to one logit per action; the probability
    helpers apply a softmax over the last dimension of those logits.

    Architecture:
      * ``num_hidden_layers <= 0``: a single
        ``Linear(prod(observation_shape), num_actions)``.
      * otherwise: ``Linear(prod(observation_shape), out_features) + ReLU``,
        then ``num_hidden_layers - 1`` blocks of
        ``Linear(in_features, out_features) + ReLU``, then
        ``Linear(in_features, num_actions)``.

    Each hidden block feeds the next (and the last feeds the output layer), so
    the hidden stack is only shape-consistent when
    ``in_features == out_features``; this is checked at construction time.
    """

    def __init__(
        self,
        observation_shape: tuple,
        num_actions: int,
        num_hidden_layers: int = 0,
        in_features: int = 16,
        out_features: int = 16,
        use_bias: bool = True,
    ):
        """Build the layer stack.

        Args:
            observation_shape: Shape of one observation. The first linear layer
                expects ``math.prod(observation_shape)`` input features.
            num_actions: Number of discrete actions (width of the output layer).
            num_hidden_layers: Number of Linear+ReLU hidden blocks. Values
                ``<= 0`` produce a single linear layer.
            in_features: Input width of every hidden block after the first and
                of the output layer. Only used when ``num_hidden_layers > 0``.
            out_features: Output width of every hidden block. Only used when
                ``num_hidden_layers > 0``.
            use_bias: Whether every linear layer has a bias term.

        Raises:
            ValueError: If ``num_hidden_layers > 0`` and
                ``in_features != out_features``. Such a network could be built
                but would fail with a shape mismatch on the first forward pass.
        """
        # Fail fast, before any layers (and their random init) are created.
        if num_hidden_layers > 0 and in_features != out_features:
            raise ValueError(
                "in_features must equal out_features when num_hidden_layers > 0 "
                f"(got in_features={in_features}, out_features={out_features}): "
                "each hidden layer outputs out_features, which the next hidden "
                "layer and the output layer consume as in_features"
            )

        super().__init__(observation_shape=observation_shape, num_actions=num_actions)
        num_input_features = math.prod(observation_shape)
        self.num_hidden_layers = num_hidden_layers
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_layers = nn.ModuleList()

        if self.num_hidden_layers <= 0:
            # No hidden layers: observation features map straight to logits.
            self.linear_out = nn.Linear(
                in_features=num_input_features, out_features=num_actions, bias=use_bias
            )
        else:
            # The first hidden block adapts the observation size to the hidden width.
            self.hidden_layers.append(
                nn.Sequential(
                    nn.Linear(
                        in_features=num_input_features,
                        out_features=self.out_features,
                        bias=use_bias,
                    ),
                    nn.ReLU(),
                )
            )
            for _ in range(self.num_hidden_layers - 1):
                self.hidden_layers.append(self._create_hidden_layer(use_bias=use_bias))
            self.linear_out = nn.Linear(
                in_features=self.in_features, out_features=num_actions, bias=use_bias
            )

    def _create_hidden_layer(self, use_bias: bool):
        """Return one ``Linear(in_features, out_features) + ReLU`` hidden block."""
        return nn.Sequential(
            nn.Linear(self.in_features, self.out_features, bias=use_bias), nn.ReLU()
        )

    def forward(self, x):
        """Return unnormalized action logits for ``x``.

        ``x`` is not reshaped here. ``nn.Linear`` acts on the last dimension, so
        ``x.shape[-1]`` must already equal ``math.prod(observation_shape)``.
        Any leading (batch) dimensions are carried through to the output, whose
        last dimension is ``num_actions``.
        """
        for hidden_layer in self.hidden_layers:
            x = hidden_layer(x)
        return self.linear_out(x)

    def _action_probabilities(self, x):
        """Return the softmax of ``forward(x)`` over the last (action) dimension."""
        return F.softmax(self.forward(x), dim=-1)

    def get_action(self, state):
        """Sample one action index for ``state`` and return it as a Python number.

        The action is drawn from the softmax distribution with
        ``torch.multinomial``. ``.item()`` requires exactly one sampled value,
        so ``state`` must describe a single observation (or a batch of one).
        """
        action_probs = self._action_probabilities(state)
        action = torch.multinomial(action_probs, num_samples=1)
        return action.detach().cpu().item()

    def get_actions(self, states):
        """Sample one action index per probability row for ``states``.

        Returns the tensor from ``torch.multinomial(..., num_samples=1)``, for
        example shape ``(batch, 1)`` for 2-D logits. The result is not detached.
        """
        action_probs = self._action_probabilities(states)
        return torch.multinomial(action_probs, num_samples=1)

    def get_action_prob(self, state, action):
        """Return the probability of ``action`` under the policy at ``state``.

        The probability tensor is indexed along its first dimension. That is
        the action dimension only when ``state`` yields 1-D logits (a single,
        unbatched observation).
        NOTE(review): for batched input this selects a whole row rather than a
        single probability — confirm callers only pass single states.
        """
        return self._action_probabilities(state)[action]

    def get_all_action_probabilities(self, states):
        """Return softmax action probabilities for ``states`` (last dim = actions)."""
        return self._action_probabilities(states)

    def update(self, state, action, value):
        """Do nothing: this policy performs no update of its own.

        Presumably training happens externally (e.g. an optimizer over this
        module's parameters) — TODO confirm against callers.
        """
        pass