Adding action decoder

This commit is contained in:
Vedant Dave 2023-03-31 17:59:42 +02:00
parent 47a0772c9d
commit 13765c2f9e

View File

@ -1,3 +1,5 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -91,6 +93,46 @@ class ObservationDecoder(nn.Module):
return out_dist
class ActionDecoder(nn.Module):
def __init__(self, state_size, hidden_size, action_size, num_layers=5):
super().__init__()
self.state_size = state_size
self.hidden_size = hidden_size
self.action_size = action_size
self.num_layers = num_layers
self._min_std=torch.Tensor([1e-4])[0]
self._init_std=torch.Tensor([5])[0]
self._mean_scale=torch.Tensor([5])[0]
layers = []
for i in range(self.num_layers):
input_channels = state_size if i == 0 else self.hidden_size
output_channels = self.hidden_size if i!= self.num_layers-1 else 2*action_size
layers.append(nn.Linear(input_channels, output_channels))
layers.append(nn.ReLU())
self.action_model = nn.Sequential(*layers)
def get_dist(self, mean, std):
distribution = torch.distributions.Normal(mean, std)
distribution = torch.distributions.transformed_distribution.TransformedDistribution(distribution, TanhBijector())
distribution = torch.distributions.independent.Independent(distribution, 1)
return distribution
def forward(self, features):
out = self.action_model(features)
mean, std = torch.chunk(out, 2, dim=-1)
raw_init_std = torch.log(torch.exp(self._init_std) - 1)
action_mean = self._mean_scale * torch.tanh(mean / self._mean_scale)
action_std = F.softplus(std + raw_init_std) + self._min_std
dist = self.get_dist(action_mean, action_std)
sample = dist.rsample()
return sample
class TransitionModel(nn.Module):
def __init__(self, state_size, hidden_size, action_size, history_size):
super().__init__()
@ -137,6 +179,34 @@ class TransitionModel(nn.Module):
return mean + eps * std
class TanhBijector(torch.distributions.Transform):
def __init__(self):
super().__init__()
self.bijective = True
self.domain = torch.distributions.constraints.real
self.codomain = torch.distributions.constraints.interval(-1.0, 1.0)
@property
def sign(self): return 1.
def _call(self, x): return torch.tanh(x)
def atanh(self, x):
return 0.5 * torch.log((1 + x) / (1 - x))
def _inverse(self, y: torch.Tensor):
y = torch.where(
(torch.abs(y) <= 1.),
torch.clamp(y, -0.99999997, 0.99999997),
y)
y = self.atanh(y)
return y
def log_abs_det_jacobian(self, x, y):
#return 2. * (np.log(2) - x - F.softplus(-2. * x))
return 2.0 * (torch.log(torch.tensor([2.0])) - x - F.softplus(-2.0 * x))
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def __init__(self, x_dim, y_dim, hidden_size):
super(CLUBSample, self).__init__()