From ab2b6599c12b6e8a231a8fb562a6ccb138e485ba Mon Sep 17 00:00:00 2001 From: VedantDave Date: Sat, 25 Mar 2023 17:51:34 +0100 Subject: [PATCH] File formatting --- DPI/models.py | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/DPI/models.py b/DPI/models.py index 012dc4c..d05e7ec 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -69,7 +69,6 @@ class ObservationDecoder(nn.Module): layers.append(nn.ReLU()) self.convtranspose = nn.Sequential(*layers) - def forward(self, features): out_batch_shape = features.shape[:-1] @@ -102,7 +101,6 @@ class TransitionModel(nn.Module): self.prev_action = torch.zeros(batch_size, self.action_size).to(device) self.prev_history = torch.zeros(batch_size, self.history_size).to(device) - def get_dist(self, mean, std): distribution = torch.distributions.Normal(mean, std) distribution = torch.distributions.independent.Independent(distribution, 1) @@ -167,23 +165,4 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator return upper_bound/2. def learning_loss(self, x_samples, y_samples): - return - self.loglikeli(x_samples, y_samples) - -if __name__ == "__main__": - encoder = ObservationEncoder((12,84,84), 256) - x = torch.randn(5000, 12, 84, 84) - print(encoder(x).shape) - exit() - - club = CLUBSample(256, 256 , 512) - x = torch.randn(100, 256) - y = torch.randn(100, 256) - print(club.learning_loss(x, y)) - - x = torch.randn(100, 12, 84, 84) - y = torch.randn(100, 12, 84, 84) - x_enc = encoder(x) - y_enc = encoder(y) - print(x_enc.shape) - print(y_enc.shape) - print(club.learning_loss(x_enc, y_enc)) \ No newline at end of file + return - self.loglikeli(x_samples, y_samples) \ No newline at end of file