Adding action decoder
This commit is contained in:
parent
47a0772c9d
commit
13765c2f9e
@ -1,3 +1,5 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -91,6 +93,46 @@ class ObservationDecoder(nn.Module):
|
|||||||
return out_dist
|
return out_dist
|
||||||
|
|
||||||
|
|
||||||
|
class ActionDecoder(nn.Module):
|
||||||
|
def __init__(self, state_size, hidden_size, action_size, num_layers=5):
|
||||||
|
super().__init__()
|
||||||
|
self.state_size = state_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.action_size = action_size
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
self._min_std=torch.Tensor([1e-4])[0]
|
||||||
|
self._init_std=torch.Tensor([5])[0]
|
||||||
|
self._mean_scale=torch.Tensor([5])[0]
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
input_channels = state_size if i == 0 else self.hidden_size
|
||||||
|
output_channels = self.hidden_size if i!= self.num_layers-1 else 2*action_size
|
||||||
|
layers.append(nn.Linear(input_channels, output_channels))
|
||||||
|
layers.append(nn.ReLU())
|
||||||
|
self.action_model = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def get_dist(self, mean, std):
|
||||||
|
distribution = torch.distributions.Normal(mean, std)
|
||||||
|
distribution = torch.distributions.transformed_distribution.TransformedDistribution(distribution, TanhBijector())
|
||||||
|
distribution = torch.distributions.independent.Independent(distribution, 1)
|
||||||
|
return distribution
|
||||||
|
|
||||||
|
def forward(self, features):
|
||||||
|
out = self.action_model(features)
|
||||||
|
mean, std = torch.chunk(out, 2, dim=-1)
|
||||||
|
|
||||||
|
raw_init_std = torch.log(torch.exp(self._init_std) - 1)
|
||||||
|
action_mean = self._mean_scale * torch.tanh(mean / self._mean_scale)
|
||||||
|
action_std = F.softplus(std + raw_init_std) + self._min_std
|
||||||
|
|
||||||
|
dist = self.get_dist(action_mean, action_std)
|
||||||
|
sample = dist.rsample()
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TransitionModel(nn.Module):
|
class TransitionModel(nn.Module):
|
||||||
def __init__(self, state_size, hidden_size, action_size, history_size):
|
def __init__(self, state_size, hidden_size, action_size, history_size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -137,6 +179,34 @@ class TransitionModel(nn.Module):
|
|||||||
return mean + eps * std
|
return mean + eps * std
|
||||||
|
|
||||||
|
|
||||||
|
class TanhBijector(torch.distributions.Transform):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.bijective = True
|
||||||
|
self.domain = torch.distributions.constraints.real
|
||||||
|
self.codomain = torch.distributions.constraints.interval(-1.0, 1.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sign(self): return 1.
|
||||||
|
|
||||||
|
def _call(self, x): return torch.tanh(x)
|
||||||
|
|
||||||
|
def atanh(self, x):
|
||||||
|
return 0.5 * torch.log((1 + x) / (1 - x))
|
||||||
|
|
||||||
|
def _inverse(self, y: torch.Tensor):
|
||||||
|
y = torch.where(
|
||||||
|
(torch.abs(y) <= 1.),
|
||||||
|
torch.clamp(y, -0.99999997, 0.99999997),
|
||||||
|
y)
|
||||||
|
y = self.atanh(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
def log_abs_det_jacobian(self, x, y):
|
||||||
|
#return 2. * (np.log(2) - x - F.softplus(-2. * x))
|
||||||
|
return 2.0 * (torch.log(torch.tensor([2.0])) - x - F.softplus(-2.0 * x))
|
||||||
|
|
||||||
|
|
||||||
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
||||||
def __init__(self, x_dim, y_dim, hidden_size):
|
def __init__(self, x_dim, y_dim, hidden_size):
|
||||||
super(CLUBSample, self).__init__()
|
super(CLUBSample, self).__init__()
|
||||||
|
Loading…
Reference in New Issue
Block a user