diff --git a/encoder.py b/encoder.py
index fd4f7c9..0c4507f 100644
--- a/encoder.py
+++ b/encoder.py
@@ -109,10 +109,7 @@ class PixelEncoder(nn.Module):
         out_dim = OUT_DIM[num_layers]
         self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim * 2)
         self.ln = nn.LayerNorm(self.feature_dim * 2)
-<<<<<<< HEAD
         self.combine = nn.Linear(self.feature_dim + 6, self.feature_dim)
-=======
->>>>>>> origin/tester_1

         self.outputs = dict()

@@ -157,12 +154,8 @@ class PixelEncoder(nn.Module):
         out = self.reparameterize(mu, logstd)

         self.outputs['tanh'] = out
-<<<<<<< HEAD
         return out, mu, logstd
-=======
-        return out
->>>>>>> origin/tester_1
-
+
     def copy_conv_weights_from(self, source):
         """Tie convolutional layers"""
         # only tie conv layers
diff --git a/sac_ae.py b/sac_ae.py
index eae526f..36b152f 100644
--- a/sac_ae.py
+++ b/sac_ae.py
@@ -416,15 +416,14 @@ class SacAeAgent(object):
         h_dist_enc = torch.distributions.Normal(h_mu, h_logvar.exp())
         h_dist_pred = torch.distributions.Normal(mean, std)
         enc_loss = torch.distributions.kl.kl_divergence(h_dist_enc, h_dist_pred).mean() * 1e-2
-
-        """
+
         with torch.no_grad():
             z_pos, _ , _ = self.critic_target.encoder(next_obs_list[-1])
             z_out = self.critic_target.encoder.combine(torch.concat((z_pos, action), dim=-1))
         logits = self.lb_loss.compute_logits(h, z_out)
         labels = torch.arange(logits.shape[0]).long().to(self.device)
         lb_loss = nn.CrossEntropyLoss()(logits, labels) * 1e-2
-        """
+
         #with torch.no_grad():
         #    z_pos, _ , _ = self.critic.encoder(next_obs_list[-1])
         #ub_loss = club_loss(state_enc["sample"], mean, state_enc["logvar"], h) * 1e-1
@@ -437,7 +436,7 @@ class SacAeAgent(object):

         ub_loss = torch.tensor(0.0)
         #enc_loss = torch.tensor(0.0)
-        lb_loss = torch.tensor(0.0)
+        #lb_loss = torch.tensor(0.0)
         #rec_loss = torch.tensor(0.0)
         loss = rec_loss + enc_loss + lb_loss + ub_loss
         self.encoder_optimizer.zero_grad()
diff --git a/train.py b/train.py
index 22d01a7..75314b5 100644
--- a/train.py
+++ b/train.py
@@ -28,10 +28,7 @@ def parse_args():
     parser.add_argument('--frame_stack', default=3, type=int)
     parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
     parser.add_argument('--resource_files', type=str)
-<<<<<<< HEAD
     parser.add_argument('--resource_files_test', type=str)
-=======
->>>>>>> origin/tester_1
     parser.add_argument('--total_frames', default=10000, type=int)
     # replay buffer
     parser.add_argument('--replay_buffer_capacity', default=100000, type=int)
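
Note for review: the re-enabled lb_loss block in sac_ae.py calls self.lb_loss.compute_logits, which is not part of this diff. As a point of reference only, below is a minimal sketch of what that module could look like if it follows a CURL-style bilinear contrastive critic; this is an assumption, not this repo's actual implementation, and the class name ContrastiveLB is made up here.

import torch
import torch.nn as nn

class ContrastiveLB(nn.Module):
    """Hypothetical bilinear critic matching the compute_logits call above."""

    def __init__(self, feature_dim):
        super().__init__()
        # learnable bilinear weight: score(anchor, positive) = a @ W @ pos.T
        self.W = nn.Parameter(torch.rand(feature_dim, feature_dim))

    def compute_logits(self, z_a, z_pos):
        # (B, D) @ (D, D) @ (D, B) -> (B, B) similarity matrix; row i's
        # positive is column i, so CrossEntropyLoss with labels=arange(B)
        # yields an InfoNCE-style lower bound with in-batch negatives.
        Wz = torch.matmul(self.W, z_pos.T)   # (D, B)
        logits = torch.matmul(z_a, Wz)       # (B, B)
        # subtract each row's max for numerical stability
        return logits - torch.max(logits, dim=1, keepdim=True)[0]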