File formatting

This commit is contained in:
Vedant Dave 2023-03-25 17:51:34 +01:00
parent 4515c6a6b7
commit ab2b6599c1

View File

@ -70,7 +70,6 @@ class ObservationDecoder(nn.Module):
self.convtranspose = nn.Sequential(*layers)
def forward(self, features):
out_batch_shape = features.shape[:-1]
out = self.dense(features)
@ -102,7 +101,6 @@ class TransitionModel(nn.Module):
self.prev_action = torch.zeros(batch_size, self.action_size).to(device)
self.prev_history = torch.zeros(batch_size, self.history_size).to(device)
def get_dist(self, mean, std):
distribution = torch.distributions.Normal(mean, std)
distribution = torch.distributions.independent.Independent(distribution, 1)
@ -168,22 +166,3 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def learning_loss(self, x_samples, y_samples):
return - self.loglikeli(x_samples, y_samples)
if __name__ == "__main__":
encoder = ObservationEncoder((12,84,84), 256)
x = torch.randn(5000, 12, 84, 84)
print(encoder(x).shape)
exit()
club = CLUBSample(256, 256 , 512)
x = torch.randn(100, 256)
y = torch.randn(100, 256)
print(club.learning_loss(x, y))
x = torch.randn(100, 12, 84, 84)
y = torch.randn(100, 12, 84, 84)
x_enc = encoder(x)
y_enc = encoder(y)
print(x_enc.shape)
print(y_enc.shape)
print(club.learning_loss(x_enc, y_enc))