Compare commits
No commits in common. "ada3cadf0c1e1242114befb688db8e0d5dc0d4b9" and "7c9e75030b54020ff7fe79b6533580b9983aa04f" have entirely different histories.
ada3cadf0c
...
7c9e75030b
@ -259,39 +259,3 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
|||||||
def forward(self, x_samples, y_samples):
    """Return the sampled-CLUB training objective for a batch.

    Args:
        x_samples: batch of conditioning inputs fed to the variational
            network (shape assumed (B, x_dim) — TODO confirm).
        y_samples: batch of targets whose likelihood is evaluated under
            the variational approximation q(y|x).

    Returns:
        The negative log-likelihood ``- self.loglikeli(x_samples, y_samples)``
        used as the estimator's training loss.

    NOTE(review): reconstructed from merge-conflict residue in the diff.
    ``mu``/``logvar`` are computed but never used in this method —
    presumably leftover from the upper-bound estimator variant of
    CLUBSample; confirm against the original CLUB reference code before
    deleting the call.
    """
    mu, logvar = self.get_mu_logvar(x_samples)
    return - self.loglikeli(x_samples, y_samples)
class ProjectionHead(nn.Module):
    """Project a concatenated (state, action) pair into an embedding space.

    A two-layer MLP with LayerNorm after each linear layer; the final
    LayerNorm output is returned with no trailing non-linearity, as is
    common for contrastive projection heads.

    NOTE(review): reconstructed from merge-conflict residue in the diff;
    structure matches the visible code on both conflict sides.
    """

    def __init__(self, state_size, action_size, hidden_size):
        """Build the projection MLP.

        Args:
            state_size: dimensionality of the state vector.
            action_size: dimensionality of the action vector.
            hidden_size: width of the hidden and output layers.
        """
        super(ProjectionHead, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size

        self.projection_model = nn.Sequential(
            nn.Linear(state_size + action_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )

    def forward(self, state, action):
        """Return the embedding for the concatenated (state, action) input.

        Args:
            state: tensor whose last dimension is ``state_size``.
            action: tensor whose last dimension is ``action_size``.

        Returns:
            Tensor with last dimension ``hidden_size``.
        """
        # Concatenate along the feature axis so leading (batch) dims pass
        # through unchanged.
        x = torch.cat([state, action], dim=-1)
        x = self.projection_model(x)
        return x
class ContrastiveHead(nn.Module):
    """Bilinear contrastive similarity head.

    Computes ``logits[i, j] = z_a[i]^T W z_pos[j]`` for a batch of anchor
    embeddings ``z_a`` and positive embeddings ``z_pos``; the diagonal
    entries correspond to the positive pairs. The row-wise maximum is
    subtracted for numerical stability before temperature scaling.

    NOTE(review): reconstructed from merge-conflict residue in the diff;
    structure matches the visible code on both conflict sides.
    """

    def __init__(self, hidden_size, temperature=1):
        """Create the learnable bilinear weight.

        Args:
            hidden_size: dimensionality of the input embeddings.
            temperature: scale factor applied to the logits (default 1).
        """
        super(ContrastiveHead, self).__init__()
        self.hidden_size = hidden_size
        self.temperature = temperature
        # Learnable bilinear form; uniform [0, 1) init as in the original.
        self.W = nn.Parameter(torch.rand(self.hidden_size, self.hidden_size))

    def forward(self, z_a, z_pos):
        """Return the (B, B) matrix of scaled bilinear similarities.

        Args:
            z_a: anchor embeddings, shape (B, hidden_size).
            z_pos: positive embeddings, shape (B, hidden_size).

        Returns:
            Logits of shape (B, B); each row's maximum is 0 before the
            temperature scaling.
        """
        Wz = torch.matmul(self.W, z_pos.T)  # (z_dim, B)
        logits = torch.matmul(z_a, Wz)      # (B, B)
        # Subtract the per-row max so the largest logit in each row is 0,
        # preventing overflow in a downstream softmax/cross-entropy.
        logits = logits - torch.max(logits, 1)[0][:, None]
        # NOTE(review): logits are *multiplied* by temperature here rather
        # than divided — confirm this matches the intended scaling (CURL
        # divides by temperature). Behavior preserved as written.
        logits = logits * self.temperature
        return logits
10
DPI/train.py
10
DPI/train.py
@ -156,16 +156,16 @@ class DPI:
|
|||||||
state_size=self.args.state_size # 128
|
state_size=self.args.state_size # 128
|
||||||
)
|
)
|
||||||
|
|
||||||
self.obs_encoder_momentum = ObservationEncoder(
|
|
||||||
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
|
|
||||||
state_size=self.args.state_size # 128
|
|
||||||
)
|
|
||||||
|
|
||||||
self.obs_decoder = ObservationDecoder(
|
self.obs_decoder = ObservationDecoder(
|
||||||
state_size=self.args.state_size, # 128
|
state_size=self.args.state_size, # 128
|
||||||
output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (12,84,84)
|
output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (12,84,84)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.obs_encoder_momentum = ObservationEncoder(
|
||||||
|
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
|
||||||
|
state_size=self.args.state_size # 128
|
||||||
|
)
|
||||||
|
|
||||||
self.transition_model = TransitionModel(
|
self.transition_model = TransitionModel(
|
||||||
state_size=self.args.state_size, # 128
|
state_size=self.args.state_size, # 128
|
||||||
hidden_size=self.args.hidden_size, # 256
|
hidden_size=self.args.hidden_size, # 256
|
||||||
|
Loading…
Reference in New Issue
Block a user