Update models to give distribution as well in the output
This commit is contained in:
parent
a351134f08
commit
38cc645253
@ -34,12 +34,22 @@ class ObservationEncoder(nn.Module):
|
|||||||
std = torch.clamp(std, min=0.0, max=1e5)
|
std = torch.clamp(std, min=0.0, max=1e5)
|
||||||
|
|
||||||
# Normal Distribution
|
# Normal Distribution
|
||||||
|
dist = self.get_dist(mean, std)
|
||||||
|
|
||||||
|
# Sampling via reparameterization Trick
|
||||||
x = self.reparameterize(mean, std)
|
x = self.reparameterize(mean, std)
|
||||||
return x
|
|
||||||
|
encoded_output = {"sample": x, "distribution": dist}
|
||||||
|
return encoded_output
|
||||||
|
|
||||||
def reparameterize(self, mu, std):
|
def reparameterize(self, mu, std):
|
||||||
eps = torch.randn_like(std)
|
eps = torch.randn_like(std)
|
||||||
return mu + eps * std
|
return mu + eps * std
|
||||||
|
|
||||||
|
def get_dist(self, mean, std):
|
||||||
|
distribution = torch.distributions.Normal(mean, std)
|
||||||
|
distribution = torch.distributions.independent.Independent(distribution, 1)
|
||||||
|
return distribution
|
||||||
|
|
||||||
|
|
||||||
class ObservationDecoder(nn.Module):
|
class ObservationDecoder(nn.Module):
|
||||||
@ -114,8 +124,12 @@ class TransitionModel(nn.Module):
|
|||||||
state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)
|
state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)
|
||||||
state_prior_std = F.softplus(state_prior_std)
|
state_prior_std = F.softplus(state_prior_std)
|
||||||
|
|
||||||
|
# Normal Distribution
|
||||||
|
state_prior_dist = self.get_dist(state_prior_mean, state_prior_std)
|
||||||
|
|
||||||
|
# Sampling via reparameterization Trick
|
||||||
sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std)
|
sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std)
|
||||||
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history}
|
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
|
||||||
return prior
|
return prior
|
||||||
|
|
||||||
def reparemeterize(self, mean, std):
|
def reparemeterize(self, mean, std):
|
||||||
@ -154,15 +168,4 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
|||||||
|
|
||||||
def forward(self, x_samples, y_samples):
|
def forward(self, x_samples, y_samples):
|
||||||
mu, logvar = self.get_mu_logvar(x_samples)
|
mu, logvar = self.get_mu_logvar(x_samples)
|
||||||
|
|
||||||
sample_size = x_samples.shape[0]
|
|
||||||
#random_index = torch.randint(sample_size, (sample_size,)).long()
|
|
||||||
random_index = torch.randperm(sample_size).long()
|
|
||||||
|
|
||||||
positive = - (mu - y_samples)**2 / logvar.exp()
|
|
||||||
negative = - (mu - y_samples[random_index])**2 / logvar.exp()
|
|
||||||
upper_bound = (positive.sum(dim = -1) - negative.sum(dim = -1)).mean()
|
|
||||||
return upper_bound/2.
|
|
||||||
|
|
||||||
def learning_loss(self, x_samples, y_samples):
|
|
||||||
return - self.loglikeli(x_samples, y_samples)
|
return - self.loglikeli(x_samples, y_samples)
|
Loading…
Reference in New Issue
Block a user