Adding files
This commit is contained in:
parent
ca334452a0
commit
82e8a23918
@ -109,10 +109,7 @@ class PixelEncoder(nn.Module):
|
|||||||
out_dim = OUT_DIM[num_layers]
|
out_dim = OUT_DIM[num_layers]
|
||||||
self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim * 2)
|
self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim * 2)
|
||||||
self.ln = nn.LayerNorm(self.feature_dim * 2)
|
self.ln = nn.LayerNorm(self.feature_dim * 2)
|
||||||
<<<<<<< HEAD
|
|
||||||
self.combine = nn.Linear(self.feature_dim + 6, self.feature_dim)
|
self.combine = nn.Linear(self.feature_dim + 6, self.feature_dim)
|
||||||
=======
|
|
||||||
>>>>>>> origin/tester_1
|
|
||||||
|
|
||||||
self.outputs = dict()
|
self.outputs = dict()
|
||||||
|
|
||||||
@ -157,12 +154,8 @@ class PixelEncoder(nn.Module):
|
|||||||
|
|
||||||
out = self.reparameterize(mu, logstd)
|
out = self.reparameterize(mu, logstd)
|
||||||
self.outputs['tanh'] = out
|
self.outputs['tanh'] = out
|
||||||
<<<<<<< HEAD
|
|
||||||
return out, mu, logstd
|
return out, mu, logstd
|
||||||
=======
|
|
||||||
return out
|
|
||||||
>>>>>>> origin/tester_1
|
|
||||||
|
|
||||||
def copy_conv_weights_from(self, source):
|
def copy_conv_weights_from(self, source):
|
||||||
"""Tie convolutional layers"""
|
"""Tie convolutional layers"""
|
||||||
# only tie conv layers
|
# only tie conv layers
|
||||||
|
@ -416,15 +416,14 @@ class SacAeAgent(object):
|
|||||||
h_dist_enc = torch.distributions.Normal(h_mu, h_logvar.exp())
|
h_dist_enc = torch.distributions.Normal(h_mu, h_logvar.exp())
|
||||||
h_dist_pred = torch.distributions.Normal(mean, std)
|
h_dist_pred = torch.distributions.Normal(mean, std)
|
||||||
enc_loss = torch.distributions.kl.kl_divergence(h_dist_enc, h_dist_pred).mean() * 1e-2
|
enc_loss = torch.distributions.kl.kl_divergence(h_dist_enc, h_dist_pred).mean() * 1e-2
|
||||||
|
|
||||||
"""
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
z_pos, _ , _ = self.critic_target.encoder(next_obs_list[-1])
|
z_pos, _ , _ = self.critic_target.encoder(next_obs_list[-1])
|
||||||
z_out = self.critic_target.encoder.combine(torch.concat((z_pos, action), dim=-1))
|
z_out = self.critic_target.encoder.combine(torch.concat((z_pos, action), dim=-1))
|
||||||
logits = self.lb_loss.compute_logits(h, z_out)
|
logits = self.lb_loss.compute_logits(h, z_out)
|
||||||
labels = torch.arange(logits.shape[0]).long().to(self.device)
|
labels = torch.arange(logits.shape[0]).long().to(self.device)
|
||||||
lb_loss = nn.CrossEntropyLoss()(logits, labels) * 1e-2
|
lb_loss = nn.CrossEntropyLoss()(logits, labels) * 1e-2
|
||||||
"""
|
|
||||||
#with torch.no_grad():
|
#with torch.no_grad():
|
||||||
# z_pos, _ , _ = self.critic.encoder(next_obs_list[-1])
|
# z_pos, _ , _ = self.critic.encoder(next_obs_list[-1])
|
||||||
#ub_loss = club_loss(state_enc["sample"], mean, state_enc["logvar"], h) * 1e-1
|
#ub_loss = club_loss(state_enc["sample"], mean, state_enc["logvar"], h) * 1e-1
|
||||||
@ -437,7 +436,7 @@ class SacAeAgent(object):
|
|||||||
|
|
||||||
ub_loss = torch.tensor(0.0)
|
ub_loss = torch.tensor(0.0)
|
||||||
#enc_loss = torch.tensor(0.0)
|
#enc_loss = torch.tensor(0.0)
|
||||||
lb_loss = torch.tensor(0.0)
|
#lb_loss = torch.tensor(0.0)
|
||||||
#rec_loss = torch.tensor(0.0)
|
#rec_loss = torch.tensor(0.0)
|
||||||
loss = rec_loss + enc_loss + lb_loss + ub_loss
|
loss = rec_loss + enc_loss + lb_loss + ub_loss
|
||||||
self.encoder_optimizer.zero_grad()
|
self.encoder_optimizer.zero_grad()
|
||||||
|
3
train.py
3
train.py
@ -28,10 +28,7 @@ def parse_args():
|
|||||||
parser.add_argument('--frame_stack', default=3, type=int)
|
parser.add_argument('--frame_stack', default=3, type=int)
|
||||||
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
||||||
parser.add_argument('--resource_files', type=str)
|
parser.add_argument('--resource_files', type=str)
|
||||||
<<<<<<< HEAD
|
|
||||||
parser.add_argument('--resource_files_test', type=str)
|
parser.add_argument('--resource_files_test', type=str)
|
||||||
=======
|
|
||||||
>>>>>>> origin/tester_1
|
|
||||||
parser.add_argument('--total_frames', default=10000, type=int)
|
parser.add_argument('--total_frames', default=10000, type=int)
|
||||||
# replay buffer
|
# replay buffer
|
||||||
parser.add_argument('--replay_buffer_capacity', default=100000, type=int)
|
parser.add_argument('--replay_buffer_capacity', default=100000, type=int)
|
||||||
|
Loading…
Reference in New Issue
Block a user