Correct history with detach

This commit is contained in:
Vedant Dave 2023-04-10 20:18:39 +02:00
parent de17cab9f5
commit ac714e3495

View File

@ -19,7 +19,7 @@ class ObservationEncoder(nn.Module):
input_channels = obs_shape[0] if i == 0 else output_channels
output_channels = num_filters * (2 ** i)
layers.append(nn.Conv2d(in_channels=input_channels, out_channels= output_channels, kernel_size=4, stride=2))
layers.append(nn.ReLU())
layers.append(nn.LeakyReLU())
self.convs = nn.Sequential(*layers)
@ -196,7 +196,8 @@ class TransitionModel(nn.Module):
def imagine_step(self, prev_state, prev_action, prev_history):
state_action = self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1)))
history = self.history_cell(torch.cat([state_action, prev_history], dim=-1), prev_history)
prev_hist = prev_history.detach()
history = self.history_cell(torch.cat([state_action, prev_hist], dim=-1), prev_hist)
state_prior = self.fc_state_prior(torch.cat([history, prev_state, prev_action], dim=-1))
state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)