Source code for src.gridmind.policies.parameterized.discrete_action_cnn_policy

from gridmind.policies.parameterized.base_parameterized_policy import (
    BaseParameterizedPolicy,
)
from torch import nn
import torch
import torch.nn.functional as F


class DiscreteActionCNNPolicy(BaseParameterizedPolicy):
    """Convolutional policy over a discrete action set.

    Three convolutional layers followed by two fully connected layers map a
    batch of channels-last observations to one logit per action. The
    probability/sampling helpers apply a softmax over those logits.
    """

    def __init__(
        self,
        observation_shape: tuple,
        num_actions: int,
    ):
        """Build the network.

        Args:
            observation_shape: ``(H, W, C)`` shape of a single observation
                (channels last); it is unpacked in exactly that order.
            num_actions: Number of discrete actions, i.e. the width of the
                output layer.
        """
        super().__init__(observation_shape=observation_shape, num_actions=num_actions)

        H, W, C = observation_shape

        # Convolutional layers
        self.conv1 = nn.Conv2d(in_channels=C, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)

        # Calculate the flattened size of the output after the convolutional
        # layers. The probe must be channels-first (N, C, H, W), the layout
        # ``forward`` feeds to the convolutions after its permute; a
        # channels-last probe would hand ``H`` channels to ``conv1``.
        test_input = torch.zeros(1, C, H, W)  # Batch size 1
        conv_out_size = self._get_conv_output_size(test_input)

        # Fully connected layers
        self.fc1 = nn.Linear(conv_out_size, 512)
        self.fc2 = nn.Linear(512, num_actions)

    def _get_conv_output_size(self, x):
        """Return the flattened feature count produced by the conv stack.

        Args:
            x: Probe tensor in channels-first layout ``(1, C, H, W)``.

        Returns:
            Number of features per sample after ``conv3``, as an ``int``.
        """
        # Only the output shape is needed, so skip autograd bookkeeping.
        # The activations are omitted because ReLU does not change shapes.
        with torch.no_grad():
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
        return x.view(1, -1).size(1)

    def forward(self, x):
        """Compute action logits for a batch of observations.

        Args:
            x: Tensor of shape ``(batch, H, W, C)`` (channels last).

        Returns:
            Tensor of shape ``(batch, num_actions)`` holding unnormalized
            logits; no softmax is applied here.
        """
        # nn.Conv2d expects channels-first input: (batch, C, H, W).
        x = x.permute(0, 3, 1, 2)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # Flatten
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def get_action(self, state):
        """Sample one action from the policy distribution for ``state``.

        Args:
            state: Observation batch accepted by ``forward``. The trailing
                ``.item()`` call requires a single sampled element, so the
                batch must contain exactly one observation.

        Returns:
            The sampled action index as a Python ``int``.
        """
        action_probs = self.forward(state)
        action_probs = F.softmax(action_probs, dim=-1)
        action = torch.multinomial(action_probs, num_samples=1).detach().cpu().item()
        return action

    def get_action_prob(self, state, action):
        """Return the probability the policy assigns to ``action``.

        Args:
            state: Observation batch accepted by ``forward``.
            action: Index of the action whose probability is requested.

        Returns:
            Tensor of shape ``(batch,)`` with the probability of ``action``
            for each observation in ``state``. It stays attached to the
            autograd graph.
        """
        action_probs = self.forward(state)
        action_probs = F.softmax(action_probs, dim=-1)
        # ``forward`` returns (batch, num_actions), so the action lives on
        # the last axis; indexing the first axis would select a batch row.
        return action_probs[..., action]

    def update(self, state, action, value):
        """Do nothing; this class implements no update step of its own."""
        pass