Changing CLUB loss and tensor stacking
parent 6b4762d5fc
commit c4283ced6f
@@ -194,6 +194,27 @@ class TransitionModel(nn.Module):
        prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
        return prior

    def stack_states(self, states, dim=0):
        # Stack per-step state tensors into single tensors along `dim`;
        # distribution objects cannot be stacked, so they stay in a list.
        s = dict(
            mean=torch.stack([state['mean'] for state in states], dim=dim),
            std=torch.stack([state['std'] for state in states], dim=dim),
            sample=torch.stack([state['sample'] for state in states], dim=dim),
            history=torch.stack([state['history'] for state in states], dim=dim),
        )
        dist = dict(distribution=[state['distribution'] for state in states])
        s.update(dist)
        return s

    def imagine_rollout(self, state, action, history, horizon):
        # Roll the transition model forward `horizon` steps in imagination,
        # feeding each prior's sample and history back in as the next input.
        imagined_priors = []
        for i in range(horizon):
            prior = self.imagine_step(state, action, history)
            state = prior["sample"]
            history = prior["history"]
            imagined_priors.append(prior)
        imagined_priors = self.stack_states(imagined_priors, dim=0)
        return imagined_priors

    def reparameterize(self, mean, std):
        # Reparameterization trick: sample = mean + eps * std, eps ~ N(0, I).
        eps = torch.randn_like(std)
        return mean + eps * std
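For context, a minimal sketch of how the new rollout-and-stack path might be driven; the tensor sizes and the availability of an initialized `model` are assumptions for illustration, not part of this commit:

# Hypothetical usage; assumes `model` is an initialized TransitionModel and
# that imagine_step takes (state, action, history) as used in imagine_rollout.
state = torch.zeros(16, 30)      # (batch, state_size), made-up sizes
action = torch.zeros(16, 4)      # (batch, action_size)
history = torch.zeros(16, 200)   # (batch, history_size)
priors = model.imagine_rollout(state, action, history, horizon=10)
# priors["mean"], ["std"], ["sample"], ["history"] are stacked to
# (horizon, batch, ...), while priors["distribution"] remains a list of
# 10 distribution objects.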
@@ -227,40 +248,6 @@ class TanhBijector(torch.distributions.Transform):
        return 2.0 * (torch.log(torch.tensor([2.0])) - x - F.softplus(-2.0 * x))


class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator
    def __init__(self, x_dim, y_dim, hidden_size):
        super(CLUBSample, self).__init__()
        self.p_mu = nn.Sequential(
            nn.Linear(x_dim, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, y_dim)
        )

        self.p_logvar = nn.Sequential(
            nn.Linear(x_dim, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, y_dim),
            nn.Tanh()
        )

    def get_mu_logvar(self, x_samples):
        mu = self.p_mu(x_samples)
        logvar = self.p_logvar(x_samples)
        return mu, logvar

    def loglikeli(self, x_samples, y_samples):
        # Gaussian log-likelihood of y under q(y|x), up to an additive constant.
        mu, logvar = self.get_mu_logvar(x_samples)
        return (-(mu - y_samples)**2 / logvar.exp() - logvar).sum(dim=1).mean(dim=0)

    def forward(self, x_samples, y_samples):
        return -self.loglikeli(x_samples, y_samples)
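For contrast with the state-dict version added below, a sketch of how this removed, dimension-based estimator might be driven (the sizes here are made up):

club_old = CLUBSample(x_dim=30, y_dim=30, hidden_size=200)
x, y = torch.randn(64, 30), torch.randn(64, 30)  # paired (B, dim) batches
fit_loss = club_old(x, y)  # equals -loglikeli(x, y); minimize to fit q(y|x)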
class ProjectionHead(nn.Module):
    def __init__(self, state_size, action_size, hidden_size):
        super(ProjectionHead, self).__init__()
@@ -294,4 +281,44 @@ class ContrastiveHead(nn.Module):
        logits = torch.matmul(z_a, Wz)  # (B, B)
        logits = logits - torch.max(logits, 1)[0][:, None]
        logits = logits * self.temperature
        return logits

class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator
    def __init__(self, last_states, current_states, negative_current_states, predicted_current_states):
        super(CLUBSample, self).__init__()
        self.last_states = last_states  # stored, but not used by the losses below
        self.current_states = current_states
        self.negative_current_states = negative_current_states
        self.predicted_current_states = predicted_current_states

    def get_mu_var_samples(self, state_dict):
        dist = state_dict["distribution"]
        sample = dist.sample()  # Use state_dict["sample"] if you want to use the same sample for all the losses
        mu = dist.mean
        var = dist.variance
        return mu, var, sample

    def loglikeli(self):
        # Gaussian log-likelihood of the predicted sample under the
        # current-state distribution, up to an additive constant.
        _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
        mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
        logvar_curr = torch.log(var_curr)
        return (-(mu_curr - pred_sample)**2 / var_curr - logvar_curr).sum(dim=1).mean(dim=0)

    def forward(self):
        # CLUB upper bound: positive-pair log-density minus negative-pair log-density.
        _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
        mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
        mu_neg, var_neg, _ = self.get_mu_var_samples(self.negative_current_states)

        pos = (-(mu_curr - pred_sample)**2 / var_curr).sum(dim=1).mean(dim=0)
        neg = (-(mu_neg - pred_sample)**2 / var_neg).sum(dim=1).mean(dim=0)
        upper_bound = pos - neg

        return upper_bound / 2

    def learning_loss(self):
        return -self.loglikeli()
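A minimal, self-contained sketch of the new interface; the diagonal-Gaussian state dicts and sizes are assumptions for illustration only:

import torch.distributions as td

def make_state(mean):
    # Hypothetical helper: wrap a diagonal Gaussian in the dict layout that
    # stack_states and CLUBSample expect (a "distribution" entry).
    return {"distribution": td.Normal(mean, torch.ones_like(mean))}

B, D = 16, 30                             # made-up batch and state sizes
curr = make_state(torch.randn(B, D))      # posterior current states
neg = make_state(torch.randn(B, D))       # e.g. states from other episodes
pred = make_state(torch.randn(B, D))      # transition-model predictions
club = CLUBSample(None, curr, neg, pred)  # last_states is unused by the losses
q_fit = club.learning_loss()              # minimize to fit the variational q
mi_ub = club()                            # then forward() gives the MI upper bound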
if "__name__ == __main__":
|
||||
pass
|