diff --git a/DPI/models.py b/DPI/models.py index ec45b81..bb4c391 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -88,7 +88,6 @@ class ObservationDecoder(nn.Module): out = self.dense(features) out = torch.reshape(out, [-1, self.input_size, 1, 1]) out = self.convtranspose(out) - mean = torch.reshape(out, (*out_batch_shape, *self.output_shape)) out_dist = torch.distributions.independent.Independent(torch.distributions.Normal(mean, 1), len(self.output_shape)) return out_dist