From d9d350e191ecafde7eea439e3ee4b6ad357cd949 Mon Sep 17 00:00:00 2001 From: VedantDave Date: Sun, 2 Apr 2023 18:52:26 +0200 Subject: [PATCH] Adding Contrastive learning models --- DPI/models.py | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/DPI/models.py b/DPI/models.py index f2d1d4c..72bb7b4 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -258,4 +258,40 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator def forward(self, x_samples, y_samples): mu, logvar = self.get_mu_logvar(x_samples) - return - self.loglikeli(x_samples, y_samples) \ No newline at end of file + return - self.loglikeli(x_samples, y_samples) + + +class ProjectionHead(nn.Module): + def __init__(self, state_size, action_size, hidden_size): + super(ProjectionHead, self).__init__() + self.state_size = state_size + self.action_size = action_size + self.hidden_size = hidden_size + + self.projection_model = nn.Sequential( + nn.Linear(state_size + action_size, hidden_size), + nn.LayerNorm(hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.LayerNorm(hidden_size), + ) + + def forward(self, state, action): + x = torch.cat([state, action], dim=-1) + x = self.projection_model(x) + return x + + +class ContrastiveHead(nn.Module): + def __init__(self, hidden_size, temperature=1): + super(ContrastiveHead, self).__init__() + self.hidden_size = hidden_size + self.temperature = temperature + self.W = nn.Parameter(torch.rand(self.hidden_size, self.hidden_size)) + + def forward(self, z_a, z_pos): + Wz = torch.matmul(self.W, z_pos.T) # (z_dim,B) + logits = torch.matmul(z_a, Wz) # (B,B) + logits = logits - torch.max(logits, 1)[0][:, None] + logits = logits * self.temperature + return logits \ No newline at end of file