Correcting UB loss
parent 5ded7bc8f1
commit 9085abe684
@@ -39,6 +39,7 @@ class ObservationEncoder(nn.Module):
         dist = self.get_dist(mean, std)

         # Sampling via reparameterization Trick
+        #x = dist.rsample()
         x = self.reparameterize(mean, std)

         encoded_output = {"sample": x, "distribution": dist}
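The new comment line records the equivalent one-liner. For reference, a minimal sketch of the standard reparameterization trick that self.reparameterize presumably implements (the body below is an assumption, not this repo's code):

import torch

def reparameterize(mean, std):
    # eps ~ N(0, I); scaling and shifting keeps the sample differentiable
    # with respect to mean and std, matching dist.rsample() for a Normal.
    eps = torch.randn_like(std)
    return mean + eps * std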
@@ -165,9 +166,9 @@ class RewardModel(nn.Module):
         )

     def forward(self, state):
-        reward = self.reward_model(state).squeeze(dim=1)
-        return reward
+        reward = self.reward_model(state)
+        return torch.distributions.independent.Independent(
+            torch.distributions.Normal(reward, 1), 1)

 class TransitionModel(nn.Module):
     def __init__(self, state_size, hidden_size, action_size, history_size):
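With this change RewardModel.forward returns a unit-variance Normal distribution rather than a squeezed tensor, so the reward loss can be computed as a negative log-likelihood. A hedged sketch of how the new return value would typically be consumed (tensor shapes and the training-loop names here are assumptions for illustration, not code from this repo):

import torch
import torch.nn as nn

# Hypothetical stand-in for self.reward_model, which is built elsewhere in the file.
reward_head = nn.Sequential(nn.Linear(30, 64), nn.ELU(), nn.Linear(64, 1))

state = torch.randn(16, 30)          # batch of latent states
target_reward = torch.randn(16, 1)   # observed rewards

mean = reward_head(state)
reward_dist = torch.distributions.independent.Independent(
    torch.distributions.Normal(mean, 1), 1)

# Independent sums log_prob over the event dimension, yielding one
# log-likelihood per batch element.
reward_loss = -reward_dist.log_prob(target_reward).mean()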
@@ -310,7 +311,7 @@ class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator

     def get_mu_var_samples(self, state_dict):
         dist = state_dict["distribution"]
-        sample = dist.sample()  # Use state_dict["sample"] if you want to use the same sample for all the losses
+        sample = state_dict["sample"]  #dist.sample()  # Use state_dict["sample"] if you want to use the same sample for all the losses
         mu = dist.mean
         var = dist.variance
         return mu, var, sample
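Switching from dist.sample() to state_dict["sample"] reuses the draw produced by the encoder, so the CLUB terms see the same sample as the other losses; since that sample comes from reparameterize, it also carries gradients, which a fresh dist.sample() would not. A small illustration of the expected input, assuming get_dist returns an Independent diagonal Normal (shapes chosen for the example):

import torch

mean, std = torch.zeros(16, 30), torch.ones(16, 30)
dist = torch.distributions.independent.Independent(
    torch.distributions.Normal(mean, std), 1)

# Mirrors the encoded_output dict built in ObservationEncoder.forward.
state_dict = {"sample": mean + torch.randn_like(std) * std,
              "distribution": dist}

mu, var, sample = dist.mean, dist.variance, state_dict["sample"]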
@@ -326,8 +327,12 @@ class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator
         mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
         mu_neg, var_neg, _ = self.get_mu_var_samples(self.negative_current_states)

+        sample_size = pred_sample.shape[0]
+        random_index = torch.randperm(sample_size).long()
+
         pos = (-(mu_curr - pred_sample)**2 /var_curr).sum(dim=1).mean(dim=0)
-        neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0)
+        neg = (-(mu_curr - pred_sample[random_index])**2 /var_curr).sum(dim=1).mean(dim=0)
+        #neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0)
         upper_bound = pos - neg

         return upper_bound/2
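This hunk is the substantive fix: negatives are now formed by shuffling the positive samples within the batch, which is how the sampled CLUB estimator defines its negative pairs, instead of scoring a separate negative_current_states batch under its own mean and variance. A self-contained sketch of the corrected computation (the shapes are assumptions for illustration):

import torch

batch, dim = 16, 30
mu_curr = torch.randn(batch, dim)        # predicted means for current states
var_curr = torch.rand(batch, dim) + 0.1  # predicted variances (positive)
pred_sample = torch.randn(batch, dim)    # samples from the predicted distribution

# Pair each mean with a randomly permuted sample from the same batch.
random_index = torch.randperm(batch).long()

pos = (-(mu_curr - pred_sample) ** 2 / var_curr).sum(dim=1).mean(dim=0)
neg = (-(mu_curr - pred_sample[random_index]) ** 2 / var_curr).sum(dim=1).mean(dim=0)

# CLUB upper bound on the mutual information; the estimator halves the gap.
upper_bound = (pos - neg) / 2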