Compare commits

..

2 Commits

Author SHA1 Message Date
ada3cadf0c Adding momentum encoder 2023-04-02 18:52:46 +02:00
d9d350e191 Adding Contrastive learning models 2023-04-02 18:52:26 +02:00
2 changed files with 42 additions and 6 deletions

View File

@ -258,4 +258,40 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def forward(self, x_samples, y_samples):
mu, logvar = self.get_mu_logvar(x_samples)
return - self.loglikeli(x_samples, y_samples)
return - self.loglikeli(x_samples, y_samples)
class ProjectionHead(nn.Module):
def __init__(self, state_size, action_size, hidden_size):
super(ProjectionHead, self).__init__()
self.state_size = state_size
self.action_size = action_size
self.hidden_size = hidden_size
self.projection_model = nn.Sequential(
nn.Linear(state_size + action_size, hidden_size),
nn.LayerNorm(hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.LayerNorm(hidden_size),
)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
x = self.projection_model(x)
return x
class ContrastiveHead(nn.Module):
def __init__(self, hidden_size, temperature=1):
super(ContrastiveHead, self).__init__()
self.hidden_size = hidden_size
self.temperature = temperature
self.W = nn.Parameter(torch.rand(self.hidden_size, self.hidden_size))
def forward(self, z_a, z_pos):
Wz = torch.matmul(self.W, z_pos.T) # (z_dim,B)
logits = torch.matmul(z_a, Wz) # (B,B)
logits = logits - torch.max(logits, 1)[0][:, None]
logits = logits * self.temperature
return logits

View File

@ -155,17 +155,17 @@ class DPI:
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
state_size=self.args.state_size # 128
)
self.obs_encoder_momentum = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
state_size=self.args.state_size # 128
)
self.obs_decoder = ObservationDecoder(
state_size=self.args.state_size, # 128
output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (12,84,84)
)
self.obs_encoder_momentum = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
state_size=self.args.state_size # 128
)
self.transition_model = TransitionModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256