From a83149f61ecc9b33a995031ee94e734575546ecd Mon Sep 17 00:00:00 2001 From: VedantDave Date: Wed, 12 Apr 2023 17:29:50 +0200 Subject: [PATCH] Keeping channels as 3 --- DPI/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/DPI/models.py b/DPI/models.py index ec45b81..bb4c391 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -88,7 +88,6 @@ class ObservationDecoder(nn.Module): out = self.dense(features) out = torch.reshape(out, [-1, self.input_size, 1, 1]) out = self.convtranspose(out) - mean = torch.reshape(out, (*out_batch_shape, *self.output_shape)) out_dist = torch.distributions.independent.Independent(torch.distributions.Normal(mean, 1), len(self.output_shape)) return out_dist