diff --git a/DPI/models.py b/DPI/models.py index 72bb7b4..2158ea7 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -194,6 +194,27 @@ class TransitionModel(nn.Module): prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist} return prior + def stack_states(self, states, dim=0): + + s = dict( + mean = torch.stack([state['mean'] for state in states], dim=dim), + std = torch.stack([state['std'] for state in states], dim=dim), + sample = torch.stack([state['sample'] for state in states], dim=dim), + history = torch.stack([state['history'] for state in states], dim=dim),) + dist = dict(distribution = [state['distribution'] for state in states]) + s.update(dist) + return s + + def imagine_rollout(self, state, action, history, horizon): + imagined_priors = [] + for i in range(horizon): + prior = self.imagine_step(state, action, history) + state = prior["sample"] + history = prior["history"] + imagined_priors.append(prior) + imagined_priors = self.stack_states(imagined_priors, dim=0) + return imagined_priors + def reparemeterize(self, mean, std): eps = torch.randn_like(std) return mean + eps * std @@ -227,40 +248,6 @@ class TanhBijector(torch.distributions.Transform): return 2.0 * (torch.log(torch.tensor([2.0])) - x - F.softplus(-2.0 * x)) -class CLUBSample(nn.Module): # Sampled version of the CLUB estimator - def __init__(self, x_dim, y_dim, hidden_size): - super(CLUBSample, self).__init__() - self.p_mu = nn.Sequential( - nn.Linear(x_dim, hidden_size//2), - nn.ReLU(), - nn.Linear(hidden_size//2, hidden_size//2), - nn.ReLU(), - nn.Linear(hidden_size//2, y_dim) - ) - - self.p_logvar = nn.Sequential( - nn.Linear(x_dim, hidden_size//2), - nn.ReLU(), - nn.Linear(hidden_size//2, hidden_size//2), - nn.ReLU(), - nn.Linear(hidden_size//2, y_dim), - nn.Tanh() - ) - - def get_mu_logvar(self, x_samples): - mu = self.p_mu(x_samples) - logvar = self.p_logvar(x_samples) - return mu, logvar - - def loglikeli(self, x_samples, y_samples): - mu, logvar = self.get_mu_logvar(x_samples) - return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0) - - def forward(self, x_samples, y_samples): - mu, logvar = self.get_mu_logvar(x_samples) - return - self.loglikeli(x_samples, y_samples) - - class ProjectionHead(nn.Module): def __init__(self, state_size, action_size, hidden_size): super(ProjectionHead, self).__init__() @@ -294,4 +281,44 @@ class ContrastiveHead(nn.Module): logits = torch.matmul(z_a, Wz) # (B,B) logits = logits - torch.max(logits, 1)[0][:, None] logits = logits * self.temperature - return logits \ No newline at end of file + return logits + + +class CLUBSample(nn.Module): # Sampled version of the CLUB estimator + def __init__(self, last_states, current_states, negative_current_states, predicted_current_states): + super(CLUBSample, self).__init__() + self.last_states = last_states + self.current_states = current_states + self.negative_current_states = negative_current_states + self.predicted_current_states = predicted_current_states + + def get_mu_var_samples(self, state_dict): + dist = state_dict["distribution"] + sample = dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses + mu = dist.mean + var = dist.variance + return mu, var, sample + + def loglikeli(self): + _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states) + mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states) + logvar_curr = torch.log(var_curr) + return (-(mu_curr - pred_sample)**2 /var_curr-logvar_curr).sum(dim=1).mean(dim=0) + + def forward(self): + _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states) + mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states) + mu_neg, var_neg, _ = self.get_mu_var_samples(self.negative_current_states) + + pos = (-(mu_curr - pred_sample)**2 /var_curr).sum(dim=1).mean(dim=0) + neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0) + upper_bound = pos - neg + + return upper_bound/2 + + def learning_loss(self): + return - self.loglikeli() + + +if "__name__ == __main__": + pass \ No newline at end of file