Correct the CLUB upper-bound (UB) loss computation

This commit is contained in:
Vedant Dave 2023-04-12 09:34:11 +02:00
parent 5ded7bc8f1
commit 9085abe684

View File

@ -39,6 +39,7 @@ class ObservationEncoder(nn.Module):
dist = self.get_dist(mean, std)
# Sampling via reparameterization Trick
#x = dist.rsample()
x = self.reparameterize(mean, std)
encoded_output = {"sample": x, "distribution": dist}
@ -165,9 +166,9 @@ class RewardModel(nn.Module):
)
def forward(self, state):
    """Map a batch of states to a reward distribution.

    Parameters
    ----------
    state : torch.Tensor
        Batch of state features fed to ``self.reward_model``.

    Returns
    -------
    torch.distributions.Independent
        Unit-scale Normal over the reward, with the last dimension
        treated as the event dimension so ``log_prob`` sums over it.
    """
    reward = self.reward_model(state)
    # Independent(..., 1) reinterprets the last batch dim as the event
    # dim, giving a joint log-density over the reward vector.
    return torch.distributions.independent.Independent(
        torch.distributions.Normal(reward, 1), 1)
class TransitionModel(nn.Module):
def __init__(self, state_size, hidden_size, action_size, history_size):
@ -310,7 +311,7 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def get_mu_var_samples(self, state_dict):
    """Return (mean, variance, sample) for an encoded-state dict.

    Parameters
    ----------
    state_dict : dict
        Must hold a ``"distribution"`` (a torch distribution) and a
        pre-drawn ``"sample"`` from it.

    Returns
    -------
    tuple
        ``(mu, var, sample)`` of the distribution.
    """
    dist = state_dict["distribution"]
    # Reuse the stored sample so every loss term sees the same draw,
    # rather than consuming fresh randomness with dist.sample().
    sample = state_dict["sample"]
    mu = dist.mean
    var = dist.variance
    return mu, var, sample
@ -318,7 +319,7 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def loglikeli(self):
    """Log-likelihood (up to additive constants) of the predicted
    sample under the current-state Gaussian.

    Computes ``-(mu - x)^2 / var - log var`` per feature, summed over
    the feature dimension and averaged over the batch.
    """
    _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
    mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
    logvar_curr = torch.log(var_curr)
    return (-(mu_curr - pred_sample)**2 / var_curr - logvar_curr).sum(dim=1).mean(dim=0)
def forward(self):
    """Sampled CLUB upper bound on mutual information.

    Positive term: log-density (up to constants) of each predicted
    sample under its own current-state Gaussian. Negative term: the
    same samples rescored against a random in-batch permutation of
    the Gaussians (shuffled negatives).

    Returns
    -------
    torch.Tensor
        Scalar ``(pos - neg) / 2`` upper-bound estimate.
    """
    _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
    mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
    # NOTE(review): negative states are fetched but unused now that
    # negatives come from the in-batch permutation below — confirm
    # whether this read can be dropped.
    mu_neg, var_neg, _ = self.get_mu_var_samples(self.negative_current_states)

    sample_size = pred_sample.shape[0]
    random_index = torch.randperm(sample_size).long()

    pos = (-(mu_curr - pred_sample)**2 / var_curr).sum(dim=1).mean(dim=0)
    # Negatives: pair each (mu, var) with a permuted sample from the batch.
    neg = (-(mu_curr - pred_sample[random_index])**2 / var_curr).sum(dim=1).mean(dim=0)
    upper_bound = pos - neg
    return upper_bound / 2