# Definitions added to DPI/models.py (git blobs f2d1d4c..72bb7b4), appended
# after CLUBSample.forward. The same patch touches CLUBSample.forward only to
# add a trailing newline; that method is unchanged and its class is defined
# outside this excerpt, so it is not reproduced here.


class ProjectionHead(nn.Module):
    """Project a concatenated (state, action) pair into a hidden embedding.

    The network is Linear -> LayerNorm -> ReLU -> Linear -> LayerNorm, so the
    returned embedding is layer-normalized over its last dimension.

    Args:
        state_size: Size of the last dimension of the ``state`` input.
        action_size: Size of the last dimension of the ``action`` input.
        hidden_size: Width of the hidden layer and of the output embedding.
    """

    def __init__(self, state_size: int, action_size: int, hidden_size: int) -> None:
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size

        self.projection_model = nn.Sequential(
            nn.Linear(state_size + action_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the projection of ``state`` and ``action``.

        Args:
            state: Tensor whose last dimension is ``state_size``.
            action: Tensor whose last dimension is ``action_size``; it must be
                concatenable with ``state`` along the last dimension.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``hidden_size``.
        """
        # Join along the feature axis so the first Linear layer sees
        # state_size + action_size input features.
        joint = torch.cat([state, action], dim=-1)
        return self.projection_model(joint)


class ContrastiveHead(nn.Module):
    """Bilinear similarity head producing pairwise contrastive logits.

    Computes ``z_a @ W @ z_pos.T`` with a learned square matrix ``W``.

    Args:
        hidden_size: Embedding width; ``W`` has shape
            ``(hidden_size, hidden_size)``.
        temperature: Scalar the logits are *multiplied* by. Note that this is
            a multiplicative scale, not a divisor: values above 1 sharpen a
            downstream softmax and values below 1 flatten it.
    """

    def __init__(self, hidden_size: int, temperature: float = 1) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.temperature = temperature
        # W is initialized uniformly in [0, 1) via torch.rand.
        self.W = nn.Parameter(torch.rand(self.hidden_size, self.hidden_size))

    def forward(self, z_a: torch.Tensor, z_pos: torch.Tensor) -> torch.Tensor:
        """Return the (B, B) logits between anchors and candidate positives.

        Args:
            z_a: Anchor embeddings of shape ``(B, hidden_size)``.
            z_pos: Candidate embeddings of shape ``(B, hidden_size)``.

        Returns:
            Tensor of shape ``(B, B)`` where entry ``[i, j]`` scores anchor
            ``i`` against candidate ``j``. The maximum of every row is 0
            before the temperature scale is applied.
        """
        Wz = torch.matmul(self.W, z_pos.T)  # (z_dim, B)
        logits = torch.matmul(z_a, Wz)  # (B, B)
        # Subtract the per-row maximum for numerical stability; a softmax
        # over each row is unaffected by this shift.
        logits = logits - logits.max(dim=1, keepdim=True).values
        return logits * self.temperature