Source code for src.gridmind.policies.parameterized.atari.atari_actor_critic_policy

from gridmind.policies.parameterized.actor_critic_policy import ActorCriticPolicy
from gridmind.policies.parameterized.atari.atari_policy import AtariPolicy


class AtaricActorCriticPolicy(ActorCriticPolicy):
    """Actor-critic policy whose actor and critic are both ``AtariPolicy`` networks.

    The actor is built with ``num_actions`` outputs and the critic with a
    single output; both receive the same observation shape and channel
    layout flag.
    """

    def __init__(self, observation_shape, num_actions, channel_first: bool = True):
        """Record the observation layout, then initialise the base policy.

        Args:
            observation_shape: Three-element shape. Unpacked as
                ``(channels, height, width)`` when ``channel_first`` is true,
                otherwise as ``(height, width, channels)``.
            num_actions: Number of outputs requested from the actor network.
            channel_first: Whether the channel dimension comes first in
                ``observation_shape``.

        Raises:
            ValueError: If ``observation_shape`` does not unpack into exactly
                three values.
        """
        # The layout flag is stored before the base constructor runs.
        # NOTE(review): presumably the base __init__ calls
        # construct_actor_critic_networks(), which reads this attribute --
        # confirm before reordering.
        self.channel_first = channel_first

        if channel_first:
            channels, height, width = observation_shape
        else:
            height, width, channels = observation_shape
        self.channels = channels
        self.height = height
        self.width = width

        super().__init__(
            observation_shape=observation_shape,
            num_actions=num_actions,
        )

    def construct_actor_critic_networks(self):  # type: ignore
        """Build and return the ``(actor, critic)`` network pair.

        Returns:
            A tuple ``(actor, critic)`` of ``AtariPolicy`` instances. The actor
            is created with ``self.num_actions`` outputs, the critic with one.
        """
        # Both networks share the observation shape and channel layout.
        # NOTE(review): self.observation_shape / self.num_actions are not set
        # in this class -- presumably the base class assigns them; verify.
        shared_kwargs = dict(
            observation_shape=self.observation_shape,
            channel_first=self.channel_first,
        )
        actor_net = AtariPolicy(num_actions=self.num_actions, **shared_kwargs)
        critic_net = AtariPolicy(num_actions=1, **shared_kwargs)
        return actor_net, critic_net

    def get_value(self, x):
        """Return the critic's output for ``x`` flattened to one dimension.

        Args:
            x: Input passed unchanged to ``self.critic``.

        Returns:
            The critic output reshaped with ``view(-1)``.
        """
        return self.critic(x).view(-1)