Adding files

This commit is contained in:
Vedant Dave 2023-05-25 17:51:31 +02:00
parent ca334452a0
commit 82e8a23918
3 changed files with 4 additions and 15 deletions

View File

@ -109,10 +109,7 @@ class PixelEncoder(nn.Module):
out_dim = OUT_DIM[num_layers]
self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim * 2)
self.ln = nn.LayerNorm(self.feature_dim * 2)
<<<<<<< HEAD
self.combine = nn.Linear(self.feature_dim + 6, self.feature_dim)
=======
>>>>>>> origin/tester_1
self.outputs = dict()
@ -157,12 +154,8 @@ class PixelEncoder(nn.Module):
out = self.reparameterize(mu, logstd)
self.outputs['tanh'] = out
<<<<<<< HEAD
return out, mu, logstd
=======
return out
>>>>>>> origin/tester_1
def copy_conv_weights_from(self, source):
"""Tie convolutional layers"""
# only tie conv layers

View File

@ -416,15 +416,14 @@ class SacAeAgent(object):
h_dist_enc = torch.distributions.Normal(h_mu, h_logvar.exp())
h_dist_pred = torch.distributions.Normal(mean, std)
enc_loss = torch.distributions.kl.kl_divergence(h_dist_enc, h_dist_pred).mean() * 1e-2
"""
with torch.no_grad():
z_pos, _ , _ = self.critic_target.encoder(next_obs_list[-1])
z_out = self.critic_target.encoder.combine(torch.concat((z_pos, action), dim=-1))
logits = self.lb_loss.compute_logits(h, z_out)
labels = torch.arange(logits.shape[0]).long().to(self.device)
lb_loss = nn.CrossEntropyLoss()(logits, labels) * 1e-2
"""
#with torch.no_grad():
# z_pos, _ , _ = self.critic.encoder(next_obs_list[-1])
#ub_loss = club_loss(state_enc["sample"], mean, state_enc["logvar"], h) * 1e-1
@ -437,7 +436,7 @@ class SacAeAgent(object):
ub_loss = torch.tensor(0.0)
#enc_loss = torch.tensor(0.0)
lb_loss = torch.tensor(0.0)
#lb_loss = torch.tensor(0.0)
#rec_loss = torch.tensor(0.0)
loss = rec_loss + enc_loss + lb_loss + ub_loss
self.encoder_optimizer.zero_grad()

View File

@ -28,10 +28,7 @@ def parse_args():
parser.add_argument('--frame_stack', default=3, type=int)
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
parser.add_argument('--resource_files', type=str)
<<<<<<< HEAD
parser.add_argument('--resource_files_test', type=str)
=======
>>>>>>> origin/tester_1
parser.add_argument('--total_frames', default=10000, type=int)
# replay buffer
parser.add_argument('--replay_buffer_capacity', default=100000, type=int)