Adding action decoder
This commit is contained in:
parent 47a0772c9d
commit 13765c2f9e

@@ -1,3 +1,5 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

@@ -91,6 +93,46 @@ class ObservationDecoder(nn.Module):
        return out_dist


class ActionDecoder(nn.Module):
    def __init__(self, state_size, hidden_size, action_size, num_layers=5):
        super().__init__()
        self.state_size = state_size
        self.hidden_size = hidden_size
        self.action_size = action_size
        self.num_layers = num_layers

        self._min_std = torch.tensor(1e-4)
        self._init_std = torch.tensor(5.0)
        self._mean_scale = torch.tensor(5.0)

        # MLP mapping latent features to the concatenated (mean, std) of the
        # action distribution. The final layer gets no activation: a trailing
        # ReLU would clip the raw mean to be non-negative.
        layers = []
        for i in range(self.num_layers):
            input_channels = state_size if i == 0 else self.hidden_size
            output_channels = self.hidden_size if i != self.num_layers - 1 else 2 * action_size
            layers.append(nn.Linear(input_channels, output_channels))
            if i != self.num_layers - 1:
                layers.append(nn.ReLU())
        self.action_model = nn.Sequential(*layers)

    def get_dist(self, mean, std):
        # Tanh-squashed diagonal Gaussian over actions.
        distribution = torch.distributions.Normal(mean, std)
        distribution = torch.distributions.transformed_distribution.TransformedDistribution(distribution, TanhBijector())
        distribution = torch.distributions.independent.Independent(distribution, 1)
        return distribution
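
# --- Editor's note (not part of the commit): Independent(dist, 1) turns the
# last axis into an event dimension, so log_prob sums over action components
# and returns one value per batch element, e.g. for a (16, 6) action batch
# log_prob yields shape (16,) rather than (16, 6).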

    def forward(self, features):
        out = self.action_model(features)
        mean, std = torch.chunk(out, 2, dim=-1)

        # Shift the raw std so that softplus(0 + raw_init_std) equals
        # init_std, and squash the mean into (-mean_scale, mean_scale).
        raw_init_std = torch.log(torch.exp(self._init_std) - 1)
        action_mean = self._mean_scale * torch.tanh(mean / self._mean_scale)
        action_std = F.softplus(std + raw_init_std) + self._min_std

        dist = self.get_dist(action_mean, action_std)
        sample = dist.rsample()
        return sample
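
# --- Editor's sketch (not part of the commit): how ActionDecoder might be
# called; the sizes below are hypothetical placeholders.
#
#   decoder = ActionDecoder(state_size=230, hidden_size=200, action_size=6)
#   features = torch.randn(16, 230)   # batch of latent features
#   action = decoder(features)        # shape (16, 6), each entry in (-1, 1)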


class TransitionModel(nn.Module):
    def __init__(self, state_size, hidden_size, action_size, history_size):
        super().__init__()

@@ -135,7 +177,35 @@ class TransitionModel(nn.Module):
    def reparemeterize(self, mean, std):
        eps = torch.randn_like(std)
        return mean + eps * std

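# --- Editor's note (not part of the commit): reparemeterize implements the
# reparameterization trick: z ~ N(mean, std^2) is drawn as mean + std * eps
# with eps ~ N(0, 1), keeping the sample differentiable w.r.t. mean and std.
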
class TanhBijector(torch.distributions.Transform):
    def __init__(self):
        super().__init__()
        self.bijective = True
        self.domain = torch.distributions.constraints.real
        self.codomain = torch.distributions.constraints.interval(-1.0, 1.0)

    @property
    def sign(self):
        return 1.0

    def _call(self, x):
        return torch.tanh(x)

    def atanh(self, x):
        return 0.5 * torch.log((1 + x) / (1 - x))

    def _inverse(self, y: torch.Tensor):
        # Clamp values at the boundary into the open interval (-1, 1) so
        # atanh does not return +/-inf.
        y = torch.where(torch.abs(y) <= 1.0,
                        torch.clamp(y, -0.99999997, 0.99999997),
                        y)
        y = self.atanh(y)
        return y

    def log_abs_det_jacobian(self, x, y):
        # log |d tanh(x)/dx| = log(1 - tanh(x)^2), written in a numerically
        # stable softplus form. Using the Python scalar np.log(2.0) avoids
        # allocating a CPU tensor that would clash with CUDA inputs.
        return 2.0 * (np.log(2.0) - x - F.softplus(-2.0 * x))

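# --- Editor's sketch (not part of the commit): a quick consistency check for
# TanhBijector; _inverse undoes _call up to the clamp, and the log-det matches
# log(1 - tanh(x)^2) via the identity 2 * (log 2 - x - softplus(-2x)).
#
#   bij = TanhBijector()
#   x = torch.randn(5)
#   y = bij._call(x)
#   assert torch.allclose(bij._inverse(y), x, atol=1e-4)
#   assert torch.allclose(bij.log_abs_det_jacobian(x, y),
#                         torch.log(1 - y.pow(2)), atol=1e-5)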

class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator
    def __init__(self, x_dim, y_dim, hidden_size):