fix
This commit is contained in:
parent
623dd6241a
commit
38341e3d44
@ -412,7 +412,7 @@ class SacAeAgent(object):
|
||||
self.encoder_tau
|
||||
)
|
||||
|
||||
if self.decoder is None and step % self.decoder_update_freq == 0:
|
||||
if self.decoder is not None and step % self.decoder_update_freq == 0:
|
||||
self.update_decoder(obs, obs, L, step)
|
||||
|
||||
def save(self, model_dir, step):
|
||||
|
18
train.py
18
train.py
@ -32,15 +32,15 @@ def parse_args():
|
||||
parser.add_argument('--agent', default='sac_ae', type=str)
|
||||
parser.add_argument('--init_steps', default=1000, type=int)
|
||||
parser.add_argument('--num_train_steps', default=1000000, type=int)
|
||||
parser.add_argument('--batch_size', default=512, type=int)
|
||||
parser.add_argument('--hidden_dim', default=256, type=int)
|
||||
parser.add_argument('--batch_size', default=128, type=int)
|
||||
parser.add_argument('--hidden_dim', default=1024, type=int)
|
||||
# eval
|
||||
parser.add_argument('--eval_freq', default=10000, type=int)
|
||||
parser.add_argument('--num_eval_episodes', default=10, type=int)
|
||||
# critic
|
||||
parser.add_argument('--critic_lr', default=1e-3, type=float)
|
||||
parser.add_argument('--critic_beta', default=0.9, type=float)
|
||||
parser.add_argument('--critic_tau', default=0.005, type=float)
|
||||
parser.add_argument('--critic_tau', default=0.01, type=float)
|
||||
parser.add_argument('--critic_target_update_freq', default=2, type=int)
|
||||
# actor
|
||||
parser.add_argument('--actor_lr', default=1e-3, type=float)
|
||||
@ -52,19 +52,19 @@ def parse_args():
|
||||
parser.add_argument('--encoder_type', default='pixel', type=str)
|
||||
parser.add_argument('--encoder_feature_dim', default=50, type=int)
|
||||
parser.add_argument('--encoder_lr', default=1e-3, type=float)
|
||||
parser.add_argument('--encoder_tau', default=0.005, type=float)
|
||||
parser.add_argument('--encoder_tau', default=0.05, type=float)
|
||||
parser.add_argument('--decoder_type', default='pixel', type=str)
|
||||
parser.add_argument('--decoder_lr', default=1e-3, type=float)
|
||||
parser.add_argument('--decoder_update_freq', default=1, type=int)
|
||||
parser.add_argument('--decoder_latent_lambda', default=0.0, type=float)
|
||||
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
|
||||
parser.add_argument('--decoder_latent_lambda', default=1e-6, type=float)
|
||||
parser.add_argument('--decoder_weight_lambda', default=1e-7, type=float)
|
||||
parser.add_argument('--num_layers', default=4, type=int)
|
||||
parser.add_argument('--num_filters', default=32, type=int)
|
||||
# sac
|
||||
parser.add_argument('--discount', default=0.99, type=float)
|
||||
parser.add_argument('--init_temperature', default=0.01, type=float)
|
||||
parser.add_argument('--alpha_lr', default=1e-3, type=float)
|
||||
parser.add_argument('--alpha_beta', default=0.9, type=float)
|
||||
parser.add_argument('--init_temperature', default=0.1, type=float)
|
||||
parser.add_argument('--alpha_lr', default=1e-4, type=float)
|
||||
parser.add_argument('--alpha_beta', default=0.5, type=float)
|
||||
# misc
|
||||
parser.add_argument('--seed', default=1, type=int)
|
||||
parser.add_argument('--work_dir', default='.', type=str)
|
||||
|
Loading…
Reference in New Issue
Block a user