diff --git a/DPI/models.py b/DPI/models.py index bb4c391..4729787 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -79,7 +79,7 @@ class ObservationDecoder(nn.Module): layers.append(nn.ConvTranspose2d(in_channels=self.in_channels[i], out_channels=self.out_channels[i], kernel_size=self.kernels[i], stride=2, output_padding=self.output_padding[i])) if i!=len(self.kernels)-1: - layers.append(nn.ReLU()) + layers.append(nn.LeakyReLU()) self.convtranspose = nn.Sequential(*layers) @@ -110,7 +110,7 @@ class Actor(nn.Module): input_channels = state_size if i == 0 else self.hidden_size output_channels = self.hidden_size if i!= self.num_layers-1 else 2*action_size layers.append(nn.Linear(input_channels, output_channels)) - layers.append(nn.ReLU()) + layers.append(nn.LeakyReLU()) self.action_model = nn.Sequential(*layers) def get_dist(self, mean, std): @@ -144,7 +144,7 @@ class ValueModel(nn.Module): input_channels = state_size if i == 0 else self.hidden_size output_channels = self.hidden_size if i!= self.num_layers-1 else 1 layers.append(nn.Linear(input_channels, output_channels)) - layers.append(nn.ReLU()) + layers.append(nn.LeakyReLU()) self.value_model = nn.Sequential(*layers) def forward(self, state): @@ -158,9 +158,9 @@ class RewardModel(nn.Module): super().__init__() self.reward_model = nn.Sequential( nn.Linear(state_size, hidden_size), - nn.ReLU(), + nn.LeakyReLU(), nn.Linear(hidden_size, hidden_size), - nn.ReLU(), + nn.LeakyReLU(), nn.Linear(hidden_size, 1) ) @@ -177,7 +177,7 @@ class TransitionModel(nn.Module): self.hidden_size = hidden_size self.action_size = action_size self.history_size = history_size - self.act_fn = nn.ReLU() + self.act_fn = nn.LeakyReLU() self.fc_state_action = nn.Linear(state_size + action_size, hidden_size) self.history_cell = nn.GRUCell(hidden_size + history_size, history_size) @@ -274,7 +274,7 @@ class ProjectionHead(nn.Module): self.projection_model = nn.Sequential( nn.Linear(state_size + action_size, hidden_size), nn.LayerNorm(hidden_size), - nn.ReLU(), + nn.LeakyReLU(), nn.Linear(hidden_size, hidden_size), nn.LayerNorm(hidden_size), )