From 13765c2f9ea1ae3faa7345f5bb01c029c135bb67 Mon Sep 17 00:00:00 2001 From: VedantDave Date: Fri, 31 Mar 2023 17:59:42 +0200 Subject: [PATCH] Adding action decoder --- DPI/models.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/DPI/models.py b/DPI/models.py index 6d6b6c0..e6adcf4 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -1,3 +1,5 @@ +import numpy as np + import torch import torch.nn as nn import torch.nn.functional as F @@ -91,6 +93,46 @@ class ObservationDecoder(nn.Module): return out_dist +class ActionDecoder(nn.Module): + def __init__(self, state_size, hidden_size, action_size, num_layers=5): + super().__init__() + self.state_size = state_size + self.hidden_size = hidden_size + self.action_size = action_size + self.num_layers = num_layers + + self._min_std=torch.Tensor([1e-4])[0] + self._init_std=torch.Tensor([5])[0] + self._mean_scale=torch.Tensor([5])[0] + + layers = [] + for i in range(self.num_layers): + input_channels = state_size if i == 0 else self.hidden_size + output_channels = self.hidden_size if i!= self.num_layers-1 else 2*action_size + layers.append(nn.Linear(input_channels, output_channels)) + layers.append(nn.ReLU()) + self.action_model = nn.Sequential(*layers) + + def get_dist(self, mean, std): + distribution = torch.distributions.Normal(mean, std) + distribution = torch.distributions.transformed_distribution.TransformedDistribution(distribution, TanhBijector()) + distribution = torch.distributions.independent.Independent(distribution, 1) + return distribution + + def forward(self, features): + out = self.action_model(features) + mean, std = torch.chunk(out, 2, dim=-1) + + raw_init_std = torch.log(torch.exp(self._init_std) - 1) + action_mean = self._mean_scale * torch.tanh(mean / self._mean_scale) + action_std = F.softplus(std + raw_init_std) + self._min_std + + dist = self.get_dist(action_mean, action_std) + sample = dist.rsample() + return sample + + + class TransitionModel(nn.Module): def __init__(self, state_size, hidden_size, action_size, history_size): super().__init__() @@ -135,7 +177,35 @@ class TransitionModel(nn.Module): def reparemeterize(self, mean, std): eps = torch.randn_like(std) return mean + eps * std - + + +class TanhBijector(torch.distributions.Transform): + def __init__(self): + super().__init__() + self.bijective = True + self.domain = torch.distributions.constraints.real + self.codomain = torch.distributions.constraints.interval(-1.0, 1.0) + + @property + def sign(self): return 1. + + def _call(self, x): return torch.tanh(x) + + def atanh(self, x): + return 0.5 * torch.log((1 + x) / (1 - x)) + + def _inverse(self, y: torch.Tensor): + y = torch.where( + (torch.abs(y) <= 1.), + torch.clamp(y, -0.99999997, 0.99999997), + y) + y = self.atanh(y) + return y + + def log_abs_det_jacobian(self, x, y): + #return 2. * (np.log(2) - x - F.softplus(-2. * x)) + return 2.0 * (torch.log(torch.tensor([2.0])) - x - F.softplus(-2.0 * x)) + class CLUBSample(nn.Module): # Sampled version of the CLUB estimator def __init__(self, x_dim, y_dim, hidden_size):