Change CLUB loss and tensor stacking
This commit is contained in:
parent
6b4762d5fc
commit
c4283ced6f
@ -194,6 +194,27 @@ class TransitionModel(nn.Module):
|
|||||||
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
|
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
|
||||||
return prior
|
return prior
|
||||||
|
|
||||||
|
def stack_states(self, states, dim=0):
    """Collapse a list of per-step state dicts into one dict of stacked tensors.

    The tensor-valued entries ('mean', 'std', 'sample', 'history') are stacked
    along *dim*; the 'distribution' entry is kept as a plain list, since
    distribution objects cannot be stacked as tensors.
    """
    stacked = {
        key: torch.stack([st[key] for st in states], dim=dim)
        for key in ('mean', 'std', 'sample', 'history')
    }
    stacked['distribution'] = [st['distribution'] for st in states]
    return stacked
|
|
||||||
|
def imagine_rollout(self, state, action, history, horizon):
    """Roll the transition model forward *horizon* steps in imagination.

    Starting from (state, history), repeatedly applies ``imagine_step``,
    feeding each step's sampled state and history back in as the next input,
    then returns the collected priors stacked along dim 0 via ``stack_states``.
    """
    priors = []
    for _ in range(horizon):
        step_prior = self.imagine_step(state, action, history)
        state, history = step_prior["sample"], step_prior["history"]
        priors.append(step_prior)
    return self.stack_states(priors, dim=0)
|
|
||||||
def reparemeterize(self, mean, std):
    """Sample ``mean + std * eps`` with ``eps ~ N(0, I)`` (reparameterization trick).

    Keeps the sample differentiable w.r.t. *mean* and *std*.
    NOTE(review): the name is a typo for "reparameterize" but is kept
    unchanged for API compatibility with existing callers.
    """
    return mean + std * torch.randn_like(std)
@ -227,40 +248,6 @@ class TanhBijector(torch.distributions.Transform):
|
|||||||
return 2.0 * (torch.log(torch.tensor([2.0])) - x - F.softplus(-2.0 * x))
|
return 2.0 * (torch.log(torch.tensor([2.0])) - x - F.softplus(-2.0 * x))
|
||||||
|
|
||||||
|
|
||||||
class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator
    """Sampled CLUB mutual-information estimator.

    Learns a variational Gaussian approximation q(y|x) with diagonal
    covariance: ``p_mu`` predicts the mean and ``p_logvar`` the per-dimension
    log-variance of y given x.

    Args:
        x_dim: dimensionality of the conditioning variable x.
        y_dim: dimensionality of the target variable y.
        hidden_size: width budget for the two MLPs (each uses hidden_size//2).
    """

    def __init__(self, x_dim, y_dim, hidden_size):
        super(CLUBSample, self).__init__()
        # Mean network mu(x).
        self.p_mu = nn.Sequential(
            nn.Linear(x_dim, hidden_size//2),
            nn.ReLU(),
            nn.Linear(hidden_size//2, hidden_size//2),
            nn.ReLU(),
            nn.Linear(hidden_size//2, y_dim)
        )
        # Log-variance network; the final Tanh bounds logvar to [-1, 1],
        # keeping exp(logvar) in a numerically safe range.
        self.p_logvar = nn.Sequential(
            nn.Linear(x_dim, hidden_size//2),
            nn.ReLU(),
            nn.Linear(hidden_size//2, hidden_size//2),
            nn.ReLU(),
            nn.Linear(hidden_size//2, y_dim),
            nn.Tanh()
        )

    def get_mu_logvar(self, x_samples):
        """Return (mu, logvar) of q(y|x) for a batch of x samples."""
        mu = self.p_mu(x_samples)
        logvar = self.p_logvar(x_samples)
        return mu, logvar

    def loglikeli(self, x_samples, y_samples):
        """Unnormalized Gaussian log-likelihood of y under q(y|x).

        Summed over feature dimensions, averaged over the batch.
        """
        mu, logvar = self.get_mu_logvar(x_samples)
        return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0)

    def forward(self, x_samples, y_samples):
        """Negative log-likelihood; minimizing this fits q(y|x).

        Bug fix: the original body also called ``self.get_mu_logvar(x_samples)``
        and discarded the result — a wasted extra forward pass with no effect
        on the returned value. That dead computation is removed.
        """
        return - self.loglikeli(x_samples, y_samples)
|
|
||||||
class ProjectionHead(nn.Module):
|
class ProjectionHead(nn.Module):
|
||||||
def __init__(self, state_size, action_size, hidden_size):
|
def __init__(self, state_size, action_size, hidden_size):
|
||||||
super(ProjectionHead, self).__init__()
|
super(ProjectionHead, self).__init__()
|
||||||
@ -294,4 +281,44 @@ class ContrastiveHead(nn.Module):
|
|||||||
logits = torch.matmul(z_a, Wz) # (B,B)
|
logits = torch.matmul(z_a, Wz) # (B,B)
|
||||||
logits = logits - torch.max(logits, 1)[0][:, None]
|
logits = logits - torch.max(logits, 1)[0][:, None]
|
||||||
logits = logits * self.temperature
|
logits = logits * self.temperature
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
||||||
|
def __init__(self, last_states, current_states, negative_current_states, predicted_current_states):
|
||||||
|
super(CLUBSample, self).__init__()
|
||||||
|
self.last_states = last_states
|
||||||
|
self.current_states = current_states
|
||||||
|
self.negative_current_states = negative_current_states
|
||||||
|
self.predicted_current_states = predicted_current_states
|
||||||
|
|
||||||
|
def get_mu_var_samples(self, state_dict):
|
||||||
|
dist = state_dict["distribution"]
|
||||||
|
sample = dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses
|
||||||
|
mu = dist.mean
|
||||||
|
var = dist.variance
|
||||||
|
return mu, var, sample
|
||||||
|
|
||||||
|
def loglikeli(self):
|
||||||
|
_, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
|
||||||
|
mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
|
||||||
|
logvar_curr = torch.log(var_curr)
|
||||||
|
return (-(mu_curr - pred_sample)**2 /var_curr-logvar_curr).sum(dim=1).mean(dim=0)
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
_, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
|
||||||
|
mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
|
||||||
|
mu_neg, var_neg, _ = self.get_mu_var_samples(self.negative_current_states)
|
||||||
|
|
||||||
|
pos = (-(mu_curr - pred_sample)**2 /var_curr).sum(dim=1).mean(dim=0)
|
||||||
|
neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0)
|
||||||
|
upper_bound = pos - neg
|
||||||
|
|
||||||
|
return upper_bound/2
|
||||||
|
|
||||||
|
def learning_loss(self):
|
||||||
|
return - self.loglikeli()
|
||||||
|
|
||||||
|
|
||||||
|
# Bug fix: the original guard was `if "__name__ == __main__":` — a non-empty
# string literal, which is always truthy, so the condition never actually
# compared __name__ at all. Use the standard entry-point guard instead.
if __name__ == "__main__":
    pass
|
Loading…
Reference in New Issue
Block a user