# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils
from encoder import make_encoder

# Default logging period (in steps) used by the ``log`` methods below.
LOG_FREQ = 10000


def gaussian_logprob(noise, log_std):
    """Compute Gaussian log probability.

    Evaluates the log-density of a diagonal Gaussian expressed through its
    standardized noise: ``sum(-0.5 * noise^2 - log_std)`` over the last
    dimension, minus the ``0.5 * log(2 * pi)`` normalizer for each of the
    ``noise.size(-1)`` components.

    Args:
        noise: tensor of standardized samples.
        log_std: tensor of log standard deviations, broadcastable against
            ``noise``.

    Returns:
        Tensor with the last dimension reduced to size 1 (``keepdim=True``).
    """
    residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
    return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)


def squash(mu, pi, log_pi):
    """Apply squashing function.

    See appendix C from https://arxiv.org/pdf/1812.05905.pdf.

    ``mu`` and ``pi`` are passed through ``tanh``; ``log_pi`` receives the
    matching change-of-variables correction computed from the squashed ``pi``.

    Args:
        mu: pre-squash mean tensor.
        pi: pre-squash sample tensor, or ``None`` when no sample was drawn.
        log_pi: log-probability of ``pi`` before squashing, or ``None``.

    Returns:
        Tuple ``(mu, pi, log_pi)`` after squashing. ``None`` inputs are
        returned as ``None``.

    Raises:
        ValueError: if ``log_pi`` is given but ``pi`` is ``None``; the
            correction term cannot be computed without the sample.
    """
    mu = torch.tanh(mu)
    if pi is not None:
        pi = torch.tanh(pi)
    if log_pi is not None:
        if pi is None:
            raise ValueError('squash() needs pi to correct log_pi')
        # relu + epsilon keep the log argument strictly positive even when
        # rounding pushes 1 - pi^2 slightly below zero.
        correction = torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
        # Out-of-place on purpose: do not mutate the caller's tensor.
        log_pi = log_pi - correction.sum(-1, keepdim=True)
    return mu, pi, log_pi


def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    Linear layers get orthogonal weights; Conv2d / ConvTranspose2d layers get
    a zero kernel whose central spatial tap is orthogonal (scaled by the ReLU
    gain). Biases are zeroed when present. Other module types are left
    untouched.

    Args:
        m: module to initialize; intended to be used with ``nn.Module.apply``.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        # Layers built with bias=False have ``bias is None``.
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


class Actor(nn.Module):
    """MLP actor network.

    An encoder built by ``make_encoder`` feeds a three-layer MLP trunk that
    outputs ``2 * action_shape[0]`` values, split into a mean and a
    log standard deviation.
    """

    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters,
        stride
    ):
        """Build the encoder and the MLP trunk, then apply ``weight_init``.

        ``encoder_type``, ``obs_shape``, ``encoder_feature_dim``,
        ``num_layers``, ``num_filters`` and ``stride`` are forwarded to
        ``make_encoder`` unchanged. ``log_std_min`` / ``log_std_max`` bound
        the log standard deviation produced in ``forward``.
        """
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, stride
        )

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.trunk = nn.Sequential(
            nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2 * action_shape[0])
        )

        # Latest intermediate tensors, kept for ``log``.
        self.outputs = dict()
        self.apply(weight_init)

    def forward(
        self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False
    ):
        """Compute the squashed mean, an optional sample and its log-prob.

        Args:
            obs: observation passed to the encoder.
            compute_pi: when true, draw a reparameterized sample ``pi``.
            compute_log_pi: when true, also return the log-probability of the
                sample. Requires ``compute_pi``.
            detach_encoder: forwarded to the encoder as ``detach``.

        Returns:
            Tuple ``(mu, pi, log_pi, log_std)``. ``mu`` and ``pi`` are
            tanh-squashed; ``pi`` and ``log_pi`` are ``None`` when not
            requested.

        Raises:
            ValueError: if ``compute_log_pi`` is true while ``compute_pi`` is
                false; the log-probability is defined on the sampled noise.
        """
        if compute_log_pi and not compute_pi:
            raise ValueError('compute_log_pi=True requires compute_pi=True')

        obs = self.encoder(obs, detach=detach_encoder)

        mu, log_std = self.trunk(obs).chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std = self.log_std_min + 0.5 * (
            self.log_std_max - self.log_std_min
        ) * (log_std + 1)

        std = log_std.exp()
        self.outputs['mu'] = mu
        self.outputs['std'] = std

        if compute_pi:
            noise = torch.randn_like(mu)
            pi = mu + noise * std
        else:
            pi = None

        if compute_log_pi:
            log_pi = gaussian_logprob(noise, log_std)
        else:
            log_pi = None

        mu, pi, log_pi = squash(mu, pi, log_pi)

        return mu, pi, log_pi, log_std

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log stored outputs and trunk parameters every ``log_freq`` steps.

        Args:
            L: logger exposing ``log_histogram`` and ``log_param``.
            step: current step; nothing is logged unless it is a multiple of
                ``log_freq``.
            log_freq: logging period.
        """
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % k, v, step)

        # Indices 0, 2 and 4 are the Linear layers of the trunk.
        L.log_param('train_actor/fc1', self.trunk[0], step)
        L.log_param('train_actor/fc2', self.trunk[2], step)
        L.log_param('train_actor/fc3', self.trunk[4], step)


class QFunction(nn.Module):
    """MLP for q-function."""

    def __init__(self, obs_dim, action_dim, hidden_dim):
        """Build a three-layer MLP mapping ``obs_dim + action_dim`` to 1."""
        super().__init__()

        self.trunk = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, obs, action):
        """Return the trunk output for ``obs`` and ``action`` joined on dim 1.

        Both inputs must share the same size along dimension 0.
        """
        assert obs.size(0) == action.size(0)

        obs_action = torch.cat([obs, action], dim=1)
        return self.trunk(obs_action)


class Critic(nn.Module):
    """Critic network, employs two q-functions."""

    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, num_layers, num_filters, stride
    ):
        """Build the encoder and both Q-functions, then apply ``weight_init``.

        ``encoder_type``, ``obs_shape``, ``encoder_feature_dim``,
        ``num_layers``, ``num_filters`` and ``stride`` are forwarded to
        ``make_encoder`` unchanged.
        """
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, stride
        )

        self.Q1 = QFunction(
            self.encoder.feature_dim, action_shape[0], hidden_dim
        )
        self.Q2 = QFunction(
            self.encoder.feature_dim, action_shape[0], hidden_dim
        )

        # Latest Q-value tensors, kept for ``log``.
        self.outputs = dict()
        self.apply(weight_init)

    def forward(self, obs, action, detach_encoder=False):
        """Return ``(q1, q2)`` for the encoded observation and the action.

        Args:
            obs: observation passed to the encoder.
            action: action tensor concatenated with the encoded observation.
            detach_encoder: forwarded to the encoder as ``detach``.
        """
        # detach_encoder allows to stop gradient propagation to encoder
        obs = self.encoder(obs, detach=detach_encoder)

        q1 = self.Q1(obs, action)
        q2 = self.Q2(obs, action)

        self.outputs['q1'] = q1
        self.outputs['q2'] = q2

        return q1, q2

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log encoder, stored Q-values and Q-function parameters.

        Args:
            L: logger exposing ``log_histogram`` and ``log_param``.
            step: current step; nothing is logged unless it is a multiple of
                ``log_freq``.
            log_freq: logging period, also forwarded to ``encoder.log``.
        """
        if step % log_freq != 0:
            return

        self.encoder.log(L, step, log_freq)

        for k, v in self.outputs.items():
            L.log_histogram('train_critic/%s_hist' % k, v, step)

        # Trunk indices 0, 2 and 4 are the Linear layers of each Q-function.
        for i in range(3):
            L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step)
            L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step)