Correcting the upper-bound (UB) loss
This commit is contained in:
parent
5ded7bc8f1
commit
9085abe684
@ -39,6 +39,7 @@ class ObservationEncoder(nn.Module):
|
||||
dist = self.get_dist(mean, std)
|
||||
|
||||
# Sampling via reparameterization Trick
|
||||
#x = dist.rsample()
|
||||
x = self.reparameterize(mean, std)
|
||||
|
||||
encoded_output = {"sample": x, "distribution": dist}
|
||||
@ -165,9 +166,9 @@ class RewardModel(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, state):
|
||||
reward = self.reward_model(state).squeeze(dim=1)
|
||||
return reward
|
||||
|
||||
reward = self.reward_model(state)
|
||||
return torch.distributions.independent.Independent(
|
||||
torch.distributions.Normal(reward, 1), 1)
|
||||
|
||||
class TransitionModel(nn.Module):
|
||||
def __init__(self, state_size, hidden_size, action_size, history_size):
|
||||
@ -310,7 +311,7 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
||||
|
||||
def get_mu_var_samples(self, state_dict):
|
||||
dist = state_dict["distribution"]
|
||||
sample = dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses
|
||||
sample = state_dict["sample"] #dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses
|
||||
mu = dist.mean
|
||||
var = dist.variance
|
||||
return mu, var, sample
|
||||
@ -326,8 +327,12 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
||||
mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
|
||||
mu_neg, var_neg, _ = self.get_mu_var_samples(self.negative_current_states)
|
||||
|
||||
sample_size = pred_sample.shape[0]
|
||||
random_index = torch.randperm(sample_size).long()
|
||||
|
||||
pos = (-(mu_curr - pred_sample)**2 /var_curr).sum(dim=1).mean(dim=0)
|
||||
neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0)
|
||||
neg = (-(mu_curr - pred_sample[random_index])**2 /var_curr).sum(dim=1).mean(dim=0)
|
||||
#neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0)
|
||||
upper_bound = pos - neg
|
||||
|
||||
return upper_bound/2
|
||||
|
Loading…
Reference in New Issue
Block a user