diff --git a/DPI/models.py b/DPI/models.py index ed337a2..f2d1d4c 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -149,7 +149,8 @@ class ValueModel(nn.Module): def forward(self, state): value = self.value_model(state) - return value + value_dist = torch.distributions.independent.Independent(torch.distributions.Normal(value, 1), 1) + return value_dist class TransitionModel(nn.Module):