Keeping channels as 3
This commit is contained in:
parent
1f4667a08d
commit
a83149f61e
@ -88,7 +88,6 @@ class ObservationDecoder(nn.Module):
|
||||
out = self.dense(features)
|
||||
out = torch.reshape(out, [-1, self.input_size, 1, 1])
|
||||
out = self.convtranspose(out)
|
||||
|
||||
mean = torch.reshape(out, (*out_batch_shape, *self.output_shape))
|
||||
out_dist = torch.distributions.independent.Independent(torch.distributions.Normal(mean, 1), len(self.output_shape))
|
||||
return out_dist
|
||||
|
Loading…
Reference in New Issue
Block a user