File formatting
This commit is contained in:
parent
4515c6a6b7
commit
ab2b6599c1
@ -69,7 +69,6 @@ class ObservationDecoder(nn.Module):
|
|||||||
layers.append(nn.ReLU())
|
layers.append(nn.ReLU())
|
||||||
|
|
||||||
self.convtranspose = nn.Sequential(*layers)
|
self.convtranspose = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, features):
|
def forward(self, features):
|
||||||
out_batch_shape = features.shape[:-1]
|
out_batch_shape = features.shape[:-1]
|
||||||
@ -102,7 +101,6 @@ class TransitionModel(nn.Module):
|
|||||||
self.prev_action = torch.zeros(batch_size, self.action_size).to(device)
|
self.prev_action = torch.zeros(batch_size, self.action_size).to(device)
|
||||||
self.prev_history = torch.zeros(batch_size, self.history_size).to(device)
|
self.prev_history = torch.zeros(batch_size, self.history_size).to(device)
|
||||||
|
|
||||||
|
|
||||||
def get_dist(self, mean, std):
|
def get_dist(self, mean, std):
|
||||||
distribution = torch.distributions.Normal(mean, std)
|
distribution = torch.distributions.Normal(mean, std)
|
||||||
distribution = torch.distributions.independent.Independent(distribution, 1)
|
distribution = torch.distributions.independent.Independent(distribution, 1)
|
||||||
@ -167,23 +165,4 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
|||||||
return upper_bound/2.
|
return upper_bound/2.
|
||||||
|
|
||||||
def learning_loss(self, x_samples, y_samples):
|
def learning_loss(self, x_samples, y_samples):
|
||||||
return - self.loglikeli(x_samples, y_samples)
|
return - self.loglikeli(x_samples, y_samples)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
encoder = ObservationEncoder((12,84,84), 256)
|
|
||||||
x = torch.randn(5000, 12, 84, 84)
|
|
||||||
print(encoder(x).shape)
|
|
||||||
exit()
|
|
||||||
|
|
||||||
club = CLUBSample(256, 256 , 512)
|
|
||||||
x = torch.randn(100, 256)
|
|
||||||
y = torch.randn(100, 256)
|
|
||||||
print(club.learning_loss(x, y))
|
|
||||||
|
|
||||||
x = torch.randn(100, 12, 84, 84)
|
|
||||||
y = torch.randn(100, 12, 84, 84)
|
|
||||||
x_enc = encoder(x)
|
|
||||||
y_enc = encoder(y)
|
|
||||||
print(x_enc.shape)
|
|
||||||
print(y_enc.shape)
|
|
||||||
print(club.learning_loss(x_enc, y_enc))
|
|
Loading…
Reference in New Issue
Block a user