From 9085abe6842d0ae75fc0d9b55224dd6d1060f25c Mon Sep 17 00:00:00 2001
From: VedantDave
Date: Wed, 12 Apr 2023 09:34:11 +0200
Subject: [PATCH] Correcting UB loss

---
 DPI/models.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/DPI/models.py b/DPI/models.py
index a1c0d28..ec45b81 100644
--- a/DPI/models.py
+++ b/DPI/models.py
@@ -39,6 +39,7 @@ class ObservationEncoder(nn.Module):
         dist = self.get_dist(mean, std)
 
         # Sampling via reparameterization Trick
+        #x = dist.rsample()
         x = self.reparameterize(mean, std)
 
         encoded_output = {"sample": x, "distribution": dist}
@@ -165,9 +166,9 @@ class RewardModel(nn.Module):
         )
 
     def forward(self, state):
-        reward = self.reward_model(state).squeeze(dim=1)
-        return reward
-
+        reward = self.reward_model(state)
+        return torch.distributions.independent.Independent(
+            torch.distributions.Normal(reward, 1), 1)
 
 class TransitionModel(nn.Module):
     def __init__(self, state_size, hidden_size, action_size, history_size):
@@ -310,7 +311,7 @@ class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator
 
     def get_mu_var_samples(self, state_dict):
         dist = state_dict["distribution"]
-        sample = dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses
+        sample = state_dict["sample"] #dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses
         mu = dist.mean
         var = dist.variance
         return mu, var, sample
@@ -318,7 +319,7 @@ class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator
     def loglikeli(self):
        _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
        mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
-       logvar_curr = torch.log(var_curr) 
+       logvar_curr = torch.log(var_curr)
        return (-(mu_curr - pred_sample)**2 /var_curr-logvar_curr).sum(dim=1).mean(dim=0)
 
     def forward(self):
@@ -326,8 +327,12 @@ class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator
         mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
         mu_neg, var_neg, _ = self.get_mu_var_samples(self.negative_current_states)
 
+        sample_size = pred_sample.shape[0]
+        random_index = torch.randperm(sample_size).long()
+
         pos = (-(mu_curr - pred_sample)**2 /var_curr).sum(dim=1).mean(dim=0)
-        neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0)
+        neg = (-(mu_curr - pred_sample[random_index])**2 /var_curr).sum(dim=1).mean(dim=0)
+        #neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0)
 
         upper_bound = pos - neg
         return upper_bound/2