Source code for src.gridmind.policies.parameterized.actor_critic_policy

from gridmind.policies.parameterized.base_parameterized_policy import (
    BaseParameterizedPolicy,
)
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialize a layer in place with orthogonal weights and a constant bias.

    Args:
        layer: Module exposing a ``weight`` tensor and, optionally, a
            ``bias`` tensor (for example ``nn.Linear``).
        std: Gain passed to ``torch.nn.init.orthogonal_`` for the weights.
        bias_const: Value every bias element is set to.

    Returns:
        The same ``layer`` object, so the call can be nested inside a
        container such as ``nn.Sequential``.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers built with ``bias=False`` expose ``bias`` as None; there is
    # nothing to fill in that case, so skip instead of failing.
    if getattr(layer, "bias", None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer
class ActorCriticPolicy(BaseParameterizedPolicy):
    """Categorical policy backed by separate actor and critic MLPs.

    The actor maps a flattened observation to one logit per action; the
    critic maps the same input to a single value estimate.
    """

    def __init__(self, observation_shape, num_actions):
        """Build the policy and its two networks.

        Args:
            observation_shape: Shape of one observation; its element count
                is used as the input width of both networks.
            num_actions: Number of discrete actions (actor output width).
        """
        # NOTE(review): the base class is presumed to store both arguments
        # as ``self.observation_shape`` / ``self.num_actions``, which
        # ``construct_actor_critic_networks`` reads — confirm in the base.
        super().__init__(
            observation_shape=observation_shape, num_actions=num_actions
        )
        self.actor, self.critic = self.construct_actor_critic_networks()

    def construct_actor_critic_networks(self):
        """Create the actor and critic networks.

        Both are three-layer MLPs with two hidden layers of 64 units and
        ReLU activations.

        Returns:
            A ``(actor, critic)`` tuple of ``nn.Sequential`` modules.
        """
        input_size = int(np.prod(self.observation_shape))

        # The critic is built before the actor; keep this order so the
        # default parameter initialization consumes the RNG identically.
        critic = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        actor = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, self.num_actions),
        )
        return actor, critic

    def get_value(self, x):
        """Return the critic's output for ``x`` (last dimension has size 1)."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and the critic on ``x``.

        Args:
            x: Input tensor fed to both networks.
            action: Optional action tensor to evaluate. When ``None`` an
                action is sampled from the current policy distribution.

        Returns:
            A tuple ``(action, log_prob, entropy, value)`` where
            ``log_prob`` is the log-probability of ``action`` (with
            size-1 dimensions squeezed out before evaluation), ``entropy``
            is the distribution's entropy and ``value`` is the critic
            output for ``x``.
        """
        logits = self.actor(x)
        dist = Categorical(logits=logits)
        if action is None:
            action = dist.sample()
        # squeeze() drops singleton dims so the action lines up with the
        # distribution's batch shape.
        return action, dist.log_prob(action.squeeze()), dist.entropy(), self.critic(x)

    def get_action(self, state):
        """Sample one action for ``state`` and return it as a Python number.

        ``state`` must yield exactly one sampled action, because the result
        is converted with ``.item()``.
        """
        logits = self.actor(state)
        dist = Categorical(logits=logits)
        action = dist.sample()
        return action.detach().cpu().item()

    def get_action_prob(self, state, action):
        """Return the policy's probability of ``action`` in ``state``.

        Args:
            state: Input tensor fed to the actor.
            action: Index of the action whose probability is requested.

        Returns:
            The probability tensor for ``action``; it keeps any leading
            batch dimensions of ``state``.
        """
        logits = self.actor(state)
        dist = Categorical(logits=logits)
        # Index the last (action) axis. For an unbatched state this equals
        # ``probs[action]``; with a leading batch dimension it selects the
        # action's probability instead of indexing into the batch.
        action_prob = dist.probs[..., action]
        return action_prob

    def update(self, state, action):
        """Not implemented for this policy; always raises."""
        raise NotImplementedError