# Source code for src.gridmind.policies.parameterized.atari.atari_policy

from gridmind.policies.parameterized.base_parameterized_policy import (
    BaseParameterizedPolicy,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical


class AtariPolicy(BaseParameterizedPolicy):
    """Convolutional policy that maps image observations to action logits.

    The network is three convolutional layers followed by a 512-unit hidden
    layer and a linear head with one output per action. Observations may be
    laid out channel-first or channel-last; channel-last inputs are permuted
    to channel-first before the convolutions.

    NOTE(review): this class relies on ``BaseParameterizedPolicy`` providing
    ``self.num_actions`` and ``nn.Module`` behaviour (parameter registration,
    ``__call__``) -- confirm against the base class.
    """

    def __init__(self, observation_shape, num_actions, channel_first: bool = True):
        """Build the convolutional trunk and the policy head.

        Args:
            observation_shape: Shape of a single observation without the batch
                dimension: ``(channels, height, width)`` when ``channel_first``
                is true, otherwise ``(height, width, channels)``.
            num_actions: Number of discrete actions (size of the logit vector).
            channel_first: Whether observations are laid out channel-first.
        """
        super().__init__(observation_shape=observation_shape, num_actions=num_actions)

        self.channel_first = channel_first

        # Only the channel count is needed to build the first convolution.
        # Both branches read the constructor argument, the same value that is
        # used for the shape probe below.
        if self.channel_first:
            channels, _height, _width = observation_shape
        else:
            _height, _width, channels = observation_shape

        self.conv1 = nn.Conv2d(channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Calculate the flattened size of the output after the convolutional
        # layers by pushing a dummy batch of size 1 through them.
        test_input = torch.zeros(1, *observation_shape)
        conv_out_size = self._get_conv_output_size(test_input)

        self.fc1 = nn.Linear(conv_out_size, 512)
        self.policy_logits = nn.Linear(512, self.num_actions)

    def _get_conv_output_size(self, x):
        """Return the flattened feature count produced by the conv layers.

        Args:
            x: Dummy input of batch size 1 in the configured channel layout.

        Returns:
            Number of features per sample after ``conv3``, as an ``int``.
        """
        # This is a pure shape probe, so there is no need to record autograd
        # history for it.
        with torch.no_grad():
            if not self.channel_first:
                # e.g. from [1, 210, 160, 3] to [1, 3, 210, 160]
                x = x.permute(0, 3, 1, 2)
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            return x.reshape(1, -1).size(1)

    def forward(self, x):
        """Compute raw (unnormalised) action logits.

        Args:
            x: Tensor with 3 dimensions (single observation) or 4 dimensions
                (batch of observations) in the configured channel layout.

        Returns:
            Tensor of shape ``[batch, num_actions]``; a single observation
            yields a batch of size 1.

        Raises:
            ValueError: If ``x`` has neither 3 nor 4 dimensions.
        """
        x = self.add_batch_dim_if_necessary(x)

        if not self.channel_first:
            # e.g. from [1, 210, 160, 3] to [1, 3, 210, 160]
            x = x.permute(0, 3, 1, 2)

        # Scale to [0, 1]; assumes raw pixel intensities in 0-255.
        x = x / 255.0

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        x = x.reshape(x.size(0), -1)  # flatten everything but the batch axis
        x = F.relu(self.fc1(x))

        return self.policy_logits(x)  # raw action logits

    def add_batch_dim_if_necessary(self, state):
        """Return ``state`` with a leading batch axis.

        A 3-dimensional tensor gets a batch axis of size 1 prepended; a
        4-dimensional tensor is returned unchanged.

        Raises:
            ValueError: If ``state`` has neither 3 nor 4 dimensions.
        """
        if state.ndim == 3:
            return state.unsqueeze(0)
        if state.ndim != 4:
            raise ValueError(
                f"Expected state to have 3 or 4 dimensions, but got {state.ndim} dimensions."
            )
        return state

    def get_actions(self, states):
        """Sample one action per observation.

        Returns:
            Tensor of sampled action indices with shape ``[batch, 1]``.
        """
        logits = self.forward(states)
        dist = Categorical(logits=logits)
        return dist.sample().unsqueeze(-1)

    def get_action(self, state):
        """Sample an action for a single observation.

        Returns:
            The sampled action index as a plain Python number. ``.item()``
            requires exactly one sampled element, so ``state`` must be a
            single observation (or a batch of size 1).
        """
        logits = self.forward(state)
        dist = Categorical(logits=logits)
        return dist.sample().detach().cpu().item()

    def get_action_prob(self, state, action):
        """Return the probability the policy assigns to ``action`` in ``state``.

        Args:
            state: A single observation (3 dimensions) or a batch (4 dimensions).
            action: Integer index of the action.

        Returns:
            A 0-dimensional tensor for a single observation (or batch of size
            1); for a larger batch, a tensor of shape ``[batch]`` holding the
            probability of ``action`` for each observation.
        """
        logits = self.forward(state)
        dist = Categorical(logits=logits)
        # ``forward`` always returns [batch, num_actions], so the action must
        # be selected on the last axis; ``probs[action]`` would index the
        # batch axis instead. ``squeeze(0)`` drops a batch axis of size 1.
        return dist.probs[..., action].squeeze(0)

    def update(self, state, action):
        """No-op; this method performs no update."""
        pass
if __name__ == "__main__":
    # Smoke test: build a channel-first policy for 4 stacked 84x84 frames and
    # 6 possible actions, then run one all-zero batch through it.
    model = AtariPolicy(
        observation_shape=(4, 84, 84),
        num_actions=6,
        channel_first=True,
    )

    batch_shape = (1, 4, 84, 84)  # batch of 1, 4 channels, 84x84 image
    logits = model(torch.zeros(batch_shape))
    print(logits.shape)  # torch.Size([1, 6])