# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils
from encoder import make_encoder

LOG_FREQ = 10000

def gaussian_logprob(noise, log_std):
    """Compute Gaussian log probability.

    `noise` is the standard-normal sample used to reparameterize the action
    (a = mu + exp(log_std) * noise), so the per-dimension log density reduces
    to -0.5 * noise^2 - log_std, plus the 0.5 * log(2 * pi) normalizer.
    """
    residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
    return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)

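# Quick self-contained check of the formula above against
# torch.distributions.Normal (a sketch; runs only when this file is
# executed directly, nothing executes on import):
if __name__ == '__main__':
    _mu = torch.zeros(4, 3)
    _log_std = torch.full((4, 3), -0.5)
    _noise = torch.randn(4, 3)
    _ref = torch.distributions.Normal(_mu, _log_std.exp()).log_prob(
        _mu + _noise * _log_std.exp()
    ).sum(-1, keepdim=True)
    assert torch.allclose(gaussian_logprob(_noise, _log_std), _ref, atol=1e-5)
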
def squash(mu, pi, log_pi):
    """Apply squashing function.
    See appendix C from https://arxiv.org/pdf/1812.05905.pdf.
    """
    mu = torch.tanh(mu)
    if pi is not None:
        pi = torch.tanh(pi)
    if log_pi is not None:
        # change-of-variables correction for the tanh squash; the relu and
        # 1e-6 guard against taking the log of a non-positive number
        log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
    return mu, pi, log_pi

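# Sanity sketch for the correction term (assumes a PyTorch version that
# provides TanhTransform, i.e. 1.8+); inputs are clamped to stay away from
# the numerically delicate tails of atanh:
if __name__ == '__main__':
    _u = torch.randn(5, 3).clamp(-2, 2)
    _log_pi = gaussian_logprob(_u, torch.zeros(5, 3))
    _, _, _log_pi = squash(_u, _u, _log_pi)
    _dist = torch.distributions.TransformedDistribution(
        torch.distributions.Normal(torch.zeros(3), torch.ones(3)),
        [torch.distributions.transforms.TanhTransform()],
    )
    _ref = _dist.log_prob(torch.tanh(_u)).sum(-1, keepdim=True)
    assert torch.allclose(_log_pi, _ref, atol=1e-3)
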
def weight_init(m):
    """Custom weight init for Conv2D and Linear layers."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        m.bias.data.fill_(0.0)
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf:
        # zero the kernel except for an orthogonal matrix at its center
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)

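# Usage sketch: apply via nn.Module.apply. With only the center tap nonzero,
# a square-kernel conv initialized this way acts as an orthogonal channel
# mixing (scaled by the 'relu' gain of sqrt(2)) at every spatial location:
if __name__ == '__main__':
    _conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)
    _conv.apply(weight_init)
    _x = torch.randn(2, 8, 16, 16)
    with torch.no_grad():
        _y = _conv(_x)
    assert torch.allclose(_y.norm(), 2 ** 0.5 * _x.norm(), rtol=1e-4)
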
class Actor(nn.Module):
    """MLP actor network."""
    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, log_std_min, log_std_max, num_layers,
        num_filters, stride
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, stride
        )

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        # the trunk predicts 2 * action_dim values: the mean and the log-std
        self.trunk = nn.Sequential(
            nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2 * action_shape[0])
        )

        self.outputs = dict()
        self.apply(weight_init)

    def forward(
        self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False
    ):
        obs = self.encoder(obs, detach=detach_encoder)

        mu, log_std = self.trunk(obs).chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std = self.log_std_min + 0.5 * (
            self.log_std_max - self.log_std_min
        ) * (log_std + 1)

        self.outputs['mu'] = mu
        self.outputs['std'] = log_std.exp()

        if compute_pi:
            std = log_std.exp()
            noise = torch.randn_like(mu)
            pi = mu + noise * std
        else:
            pi = None

        if compute_log_pi:
            # the log probability is evaluated at the sample drawn above,
            # so it is only defined when compute_pi is True
            assert compute_pi, 'compute_log_pi requires compute_pi'
            log_pi = gaussian_logprob(noise, log_std)
        else:
            log_pi = None

        mu, pi, log_pi = squash(mu, pi, log_pi)

        return mu, pi, log_pi, log_std

    def log(self, L, step, log_freq=LOG_FREQ):
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % k, v, step)

        L.log_param('train_actor/fc1', self.trunk[0], step)
        L.log_param('train_actor/fc2', self.trunk[2], step)
        L.log_param('train_actor/fc3', self.trunk[4], step)

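# Usage sketch (shapes and encoder arguments are illustrative; assumes the
# 'pixel' encoder from encoder.py as in the CURL/SAC+AE codebases):
if __name__ == '__main__':
    _actor = Actor(
        obs_shape=(9, 84, 84), action_shape=(6,), hidden_dim=1024,
        encoder_type='pixel', encoder_feature_dim=50, log_std_min=-10,
        log_std_max=2, num_layers=4, num_filters=32, stride=2,
    )
    _mu, _pi, _log_pi, _log_std = _actor(torch.randn(1, 9, 84, 84))
    # mu is the deterministic (already squashed) action, pi a reparameterized
    # sample; log_pi feeds the entropy terms of the SAC actor and alpha losses
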
class QFunction(nn.Module):
    """MLP for q-function."""
    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        self.trunk = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, obs, action):
        assert obs.size(0) == action.size(0)

        obs_action = torch.cat([obs, action], dim=1)
        return self.trunk(obs_action)

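# Minimal usage sketch: the q-function consumes concatenated features and
# actions and returns one scalar per batch element:
if __name__ == '__main__':
    _qf = QFunction(obs_dim=50, action_dim=6, hidden_dim=1024)
    assert _qf(torch.randn(32, 50), torch.randn(32, 6)).shape == (32, 1)
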
class Critic(nn.Module):
    """Critic network; employs two q-functions."""
    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, num_layers, num_filters, stride
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, stride
        )

        self.Q1 = QFunction(
            self.encoder.feature_dim, action_shape[0], hidden_dim
        )
        self.Q2 = QFunction(
            self.encoder.feature_dim, action_shape[0], hidden_dim
        )

        self.outputs = dict()
        self.apply(weight_init)

    def forward(self, obs, action, detach_encoder=False):
        # detach_encoder stops gradients from propagating into the encoder
        obs = self.encoder(obs, detach=detach_encoder)

        q1 = self.Q1(obs, action)
        q2 = self.Q2(obs, action)

        self.outputs['q1'] = q1
        self.outputs['q2'] = q2

        return q1, q2

    def log(self, L, step, log_freq=LOG_FREQ):
        if step % log_freq != 0:
            return

        self.encoder.log(L, step, log_freq)

        for k, v in self.outputs.items():
            L.log_histogram('train_critic/%s_hist' % k, v, step)

        for i in range(3):
            L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step)
            L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step)

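# Sketch of how the two heads support clipped double-Q learning in SAC: the
# Bellman target uses the elementwise minimum of the (target) critic's heads.
# Encoder arguments are illustrative, as in the Actor sketch above:
if __name__ == '__main__':
    _critic = Critic(
        obs_shape=(9, 84, 84), action_shape=(6,), hidden_dim=1024,
        encoder_type='pixel', encoder_feature_dim=50, num_layers=4,
        num_filters=32, stride=2,
    )
    with torch.no_grad():
        _q1, _q2 = _critic(torch.randn(1, 9, 84, 84), torch.randn(1, 6))
        _target_v = torch.min(_q1, _q2)  # minus alpha * log_pi in full SAC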