Correct history with detach
This commit is contained in:
parent de17cab9f5
commit ac714e3495
@@ -19,7 +19,7 @@ class ObservationEncoder(nn.Module):
             input_channels = obs_shape[0] if i == 0 else output_channels
             output_channels = num_filters * (2 ** i)
             layers.append(nn.Conv2d(in_channels=input_channels, out_channels= output_channels, kernel_size=4, stride=2))
-            layers.append(nn.ReLU())
+            layers.append(nn.LeakyReLU())
 
         self.convs = nn.Sequential(*layers)
 
@@ -196,7 +196,8 @@ class TransitionModel(nn.Module):
 
     def imagine_step(self, prev_state, prev_action, prev_history):
         state_action = self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1)))
-        history = self.history_cell(torch.cat([state_action, prev_history], dim=-1), prev_history)
+        prev_hist = prev_history.detach()
+        history = self.history_cell(torch.cat([state_action, prev_hist], dim=-1), prev_hist)
 
         state_prior = self.fc_state_prior(torch.cat([history, prev_state, prev_action], dim=-1))
         state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)
Loading…
Reference in New Issue
Block a user