From ac714e3495daadeb60ebc61a00bafcacd6c2a453 Mon Sep 17 00:00:00 2001 From: VedantDave Date: Mon, 10 Apr 2023 20:18:39 +0200 Subject: [PATCH] Correct history with detach --- DPI/models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/DPI/models.py b/DPI/models.py index 0efdaeb..a1c0d28 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -19,7 +19,7 @@ class ObservationEncoder(nn.Module): input_channels = obs_shape[0] if i == 0 else output_channels output_channels = num_filters * (2 ** i) layers.append(nn.Conv2d(in_channels=input_channels, out_channels= output_channels, kernel_size=4, stride=2)) - layers.append(nn.ReLU()) + layers.append(nn.LeakyReLU()) self.convs = nn.Sequential(*layers) @@ -196,7 +196,8 @@ class TransitionModel(nn.Module): def imagine_step(self, prev_state, prev_action, prev_history): state_action = self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))) - history = self.history_cell(torch.cat([state_action, prev_history], dim=-1), prev_history) + prev_hist = prev_history.detach() + history = self.history_cell(torch.cat([state_action, prev_hist], dim=-1), prev_hist) state_prior = self.fc_state_prior(torch.cat([history, prev_state, prev_action], dim=-1)) state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)