Changing CLUB loss and Tensor stacking

This commit is contained in:
Vedant Dave 2023-04-09 18:23:16 +02:00
parent 6b4762d5fc
commit c4283ced6f

View File

@ -194,6 +194,27 @@ class TransitionModel(nn.Module):
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
return prior
def stack_states(self, states, dim=0):
s = dict(
mean = torch.stack([state['mean'] for state in states], dim=dim),
std = torch.stack([state['std'] for state in states], dim=dim),
sample = torch.stack([state['sample'] for state in states], dim=dim),
history = torch.stack([state['history'] for state in states], dim=dim),)
dist = dict(distribution = [state['distribution'] for state in states])
s.update(dist)
return s
def imagine_rollout(self, state, action, history, horizon):
imagined_priors = []
for i in range(horizon):
prior = self.imagine_step(state, action, history)
state = prior["sample"]
history = prior["history"]
imagined_priors.append(prior)
imagined_priors = self.stack_states(imagined_priors, dim=0)
return imagined_priors
def reparemeterize(self, mean, std):
eps = torch.randn_like(std)
return mean + eps * std
@ -227,40 +248,6 @@ class TanhBijector(torch.distributions.Transform):
return 2.0 * (torch.log(torch.tensor([2.0])) - x - F.softplus(-2.0 * x))
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def __init__(self, x_dim, y_dim, hidden_size):
super(CLUBSample, self).__init__()
self.p_mu = nn.Sequential(
nn.Linear(x_dim, hidden_size//2),
nn.ReLU(),
nn.Linear(hidden_size//2, hidden_size//2),
nn.ReLU(),
nn.Linear(hidden_size//2, y_dim)
)
self.p_logvar = nn.Sequential(
nn.Linear(x_dim, hidden_size//2),
nn.ReLU(),
nn.Linear(hidden_size//2, hidden_size//2),
nn.ReLU(),
nn.Linear(hidden_size//2, y_dim),
nn.Tanh()
)
def get_mu_logvar(self, x_samples):
mu = self.p_mu(x_samples)
logvar = self.p_logvar(x_samples)
return mu, logvar
def loglikeli(self, x_samples, y_samples):
mu, logvar = self.get_mu_logvar(x_samples)
return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0)
def forward(self, x_samples, y_samples):
mu, logvar = self.get_mu_logvar(x_samples)
return - self.loglikeli(x_samples, y_samples)
class ProjectionHead(nn.Module):
def __init__(self, state_size, action_size, hidden_size):
super(ProjectionHead, self).__init__()
@ -295,3 +282,43 @@ class ContrastiveHead(nn.Module):
logits = logits - torch.max(logits, 1)[0][:, None]
logits = logits * self.temperature
return logits
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def __init__(self, last_states, current_states, negative_current_states, predicted_current_states):
super(CLUBSample, self).__init__()
self.last_states = last_states
self.current_states = current_states
self.negative_current_states = negative_current_states
self.predicted_current_states = predicted_current_states
def get_mu_var_samples(self, state_dict):
dist = state_dict["distribution"]
sample = dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses
mu = dist.mean
var = dist.variance
return mu, var, sample
def loglikeli(self):
_, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
logvar_curr = torch.log(var_curr)
return (-(mu_curr - pred_sample)**2 /var_curr-logvar_curr).sum(dim=1).mean(dim=0)
def forward(self):
_, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
mu_neg, var_neg, _ = self.get_mu_var_samples(self.negative_current_states)
pos = (-(mu_curr - pred_sample)**2 /var_curr).sum(dim=1).mean(dim=0)
neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0)
upper_bound = pos - neg
return upper_bound/2
def learning_loss(self):
return - self.loglikeli()
if "__name__ == __main__":
pass