diff --git a/PolicyModel/GaussianModelMultiDim.py b/PolicyModel/GaussianModelMultiDim.py index 76c2ff0..5d9030c 100644 --- a/PolicyModel/GaussianModelMultiDim.py +++ b/PolicyModel/GaussianModelMultiDim.py @@ -46,4 +46,4 @@ class GaussianPolicy: self.weights = x.reshape(self.nr_weights, self.nr_dims) def get_x(self): - return self.weights.reshape(self.nr_weights * self.nr_dims, 1) + return self.weights.reshape(self.nr_weights * self.nr_dims, )