This commit is contained in:
Denis Yarats 2019-09-23 18:22:49 -07:00
parent 623dd6241a
commit 38341e3d44
2 changed files with 10 additions and 10 deletions

View File

@ -412,7 +412,7 @@ class SacAeAgent(object):
self.encoder_tau self.encoder_tau
) )
if self.decoder is None and step % self.decoder_update_freq == 0: if self.decoder is not None and step % self.decoder_update_freq == 0:
self.update_decoder(obs, obs, L, step) self.update_decoder(obs, obs, L, step)
def save(self, model_dir, step): def save(self, model_dir, step):

View File

@ -32,15 +32,15 @@ def parse_args():
parser.add_argument('--agent', default='sac_ae', type=str) parser.add_argument('--agent', default='sac_ae', type=str)
parser.add_argument('--init_steps', default=1000, type=int) parser.add_argument('--init_steps', default=1000, type=int)
parser.add_argument('--num_train_steps', default=1000000, type=int) parser.add_argument('--num_train_steps', default=1000000, type=int)
parser.add_argument('--batch_size', default=512, type=int) parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--hidden_dim', default=256, type=int) parser.add_argument('--hidden_dim', default=1024, type=int)
# eval # eval
parser.add_argument('--eval_freq', default=10000, type=int) parser.add_argument('--eval_freq', default=10000, type=int)
parser.add_argument('--num_eval_episodes', default=10, type=int) parser.add_argument('--num_eval_episodes', default=10, type=int)
# critic # critic
parser.add_argument('--critic_lr', default=1e-3, type=float) parser.add_argument('--critic_lr', default=1e-3, type=float)
parser.add_argument('--critic_beta', default=0.9, type=float) parser.add_argument('--critic_beta', default=0.9, type=float)
parser.add_argument('--critic_tau', default=0.005, type=float) parser.add_argument('--critic_tau', default=0.01, type=float)
parser.add_argument('--critic_target_update_freq', default=2, type=int) parser.add_argument('--critic_target_update_freq', default=2, type=int)
# actor # actor
parser.add_argument('--actor_lr', default=1e-3, type=float) parser.add_argument('--actor_lr', default=1e-3, type=float)
@ -52,19 +52,19 @@ def parse_args():
parser.add_argument('--encoder_type', default='pixel', type=str) parser.add_argument('--encoder_type', default='pixel', type=str)
parser.add_argument('--encoder_feature_dim', default=50, type=int) parser.add_argument('--encoder_feature_dim', default=50, type=int)
parser.add_argument('--encoder_lr', default=1e-3, type=float) parser.add_argument('--encoder_lr', default=1e-3, type=float)
parser.add_argument('--encoder_tau', default=0.005, type=float) parser.add_argument('--encoder_tau', default=0.05, type=float)
parser.add_argument('--decoder_type', default='pixel', type=str) parser.add_argument('--decoder_type', default='pixel', type=str)
parser.add_argument('--decoder_lr', default=1e-3, type=float) parser.add_argument('--decoder_lr', default=1e-3, type=float)
parser.add_argument('--decoder_update_freq', default=1, type=int) parser.add_argument('--decoder_update_freq', default=1, type=int)
parser.add_argument('--decoder_latent_lambda', default=0.0, type=float) parser.add_argument('--decoder_latent_lambda', default=1e-6, type=float)
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float) parser.add_argument('--decoder_weight_lambda', default=1e-7, type=float)
parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_layers', default=4, type=int)
parser.add_argument('--num_filters', default=32, type=int) parser.add_argument('--num_filters', default=32, type=int)
# sac # sac
parser.add_argument('--discount', default=0.99, type=float) parser.add_argument('--discount', default=0.99, type=float)
parser.add_argument('--init_temperature', default=0.01, type=float) parser.add_argument('--init_temperature', default=0.1, type=float)
parser.add_argument('--alpha_lr', default=1e-3, type=float) parser.add_argument('--alpha_lr', default=1e-4, type=float)
parser.add_argument('--alpha_beta', default=0.9, type=float) parser.add_argument('--alpha_beta', default=0.5, type=float)
# misc # misc
parser.add_argument('--seed', default=1, type=int) parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--work_dir', default='.', type=str) parser.add_argument('--work_dir', default='.', type=str)