diff --git a/DPI/models.py b/DPI/models.py index d05e7ec..6d6b6c0 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -34,12 +34,22 @@ class ObservationEncoder(nn.Module): std = torch.clamp(std, min=0.0, max=1e5) # Normal Distribution + dist = self.get_dist(mean, std) + + # Sampling via reparameterization Trick x = self.reparameterize(mean, std) - return x + + encoded_output = {"sample": x, "distribution": dist} + return encoded_output def reparameterize(self, mu, std): eps = torch.randn_like(std) return mu + eps * std + + def get_dist(self, mean, std): + distribution = torch.distributions.Normal(mean, std) + distribution = torch.distributions.independent.Independent(distribution, 1) + return distribution class ObservationDecoder(nn.Module): @@ -114,8 +124,12 @@ class TransitionModel(nn.Module): state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1) state_prior_std = F.softplus(state_prior_std) + # Normal Distribution + state_prior_dist = self.get_dist(state_prior_mean, state_prior_std) + + # Sampling via reparameterization Trick sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std) - prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history} + prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist} return prior def reparemeterize(self, mean, std): @@ -154,15 +168,4 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator def forward(self, x_samples, y_samples): mu, logvar = self.get_mu_logvar(x_samples) - - sample_size = x_samples.shape[0] - #random_index = torch.randint(sample_size, (sample_size,)).long() - random_index = torch.randperm(sample_size).long() - - positive = - (mu - y_samples)**2 / logvar.exp() - negative = - (mu - y_samples[random_index])**2 / logvar.exp() - upper_bound = (positive.sum(dim = -1) - negative.sum(dim = -1)).mean() - return upper_bound/2. - - def learning_loss(self, x_samples, y_samples): return - self.loglikeli(x_samples, y_samples) \ No newline at end of file