Keeping channels as 3

This commit is contained in:
Vedant Dave 2023-04-12 17:29:50 +02:00
parent 1f4667a08d
commit a83149f61e

View File

@ -88,7 +88,6 @@ class ObservationDecoder(nn.Module):
out = self.dense(features) out = self.dense(features)
out = torch.reshape(out, [-1, self.input_size, 1, 1]) out = torch.reshape(out, [-1, self.input_size, 1, 1])
out = self.convtranspose(out) out = self.convtranspose(out)
mean = torch.reshape(out, (*out_batch_shape, *self.output_shape)) mean = torch.reshape(out, (*out_batch_shape, *self.output_shape))
out_dist = torch.distributions.independent.Independent(torch.distributions.Normal(mean, 1), len(self.output_shape)) out_dist = torch.distributions.independent.Independent(torch.distributions.Normal(mean, 1), len(self.output_shape))
return out_dist return out_dist