fix
This commit is contained in:
parent
623dd6241a
commit
38341e3d44
@ -412,7 +412,7 @@ class SacAeAgent(object):
|
|||||||
self.encoder_tau
|
self.encoder_tau
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.decoder is None and step % self.decoder_update_freq == 0:
|
if self.decoder is not None and step % self.decoder_update_freq == 0:
|
||||||
self.update_decoder(obs, obs, L, step)
|
self.update_decoder(obs, obs, L, step)
|
||||||
|
|
||||||
def save(self, model_dir, step):
|
def save(self, model_dir, step):
|
||||||
|
18
train.py
18
train.py
@ -32,15 +32,15 @@ def parse_args():
|
|||||||
parser.add_argument('--agent', default='sac_ae', type=str)
|
parser.add_argument('--agent', default='sac_ae', type=str)
|
||||||
parser.add_argument('--init_steps', default=1000, type=int)
|
parser.add_argument('--init_steps', default=1000, type=int)
|
||||||
parser.add_argument('--num_train_steps', default=1000000, type=int)
|
parser.add_argument('--num_train_steps', default=1000000, type=int)
|
||||||
parser.add_argument('--batch_size', default=512, type=int)
|
parser.add_argument('--batch_size', default=128, type=int)
|
||||||
parser.add_argument('--hidden_dim', default=256, type=int)
|
parser.add_argument('--hidden_dim', default=1024, type=int)
|
||||||
# eval
|
# eval
|
||||||
parser.add_argument('--eval_freq', default=10000, type=int)
|
parser.add_argument('--eval_freq', default=10000, type=int)
|
||||||
parser.add_argument('--num_eval_episodes', default=10, type=int)
|
parser.add_argument('--num_eval_episodes', default=10, type=int)
|
||||||
# critic
|
# critic
|
||||||
parser.add_argument('--critic_lr', default=1e-3, type=float)
|
parser.add_argument('--critic_lr', default=1e-3, type=float)
|
||||||
parser.add_argument('--critic_beta', default=0.9, type=float)
|
parser.add_argument('--critic_beta', default=0.9, type=float)
|
||||||
parser.add_argument('--critic_tau', default=0.005, type=float)
|
parser.add_argument('--critic_tau', default=0.01, type=float)
|
||||||
parser.add_argument('--critic_target_update_freq', default=2, type=int)
|
parser.add_argument('--critic_target_update_freq', default=2, type=int)
|
||||||
# actor
|
# actor
|
||||||
parser.add_argument('--actor_lr', default=1e-3, type=float)
|
parser.add_argument('--actor_lr', default=1e-3, type=float)
|
||||||
@ -52,19 +52,19 @@ def parse_args():
|
|||||||
parser.add_argument('--encoder_type', default='pixel', type=str)
|
parser.add_argument('--encoder_type', default='pixel', type=str)
|
||||||
parser.add_argument('--encoder_feature_dim', default=50, type=int)
|
parser.add_argument('--encoder_feature_dim', default=50, type=int)
|
||||||
parser.add_argument('--encoder_lr', default=1e-3, type=float)
|
parser.add_argument('--encoder_lr', default=1e-3, type=float)
|
||||||
parser.add_argument('--encoder_tau', default=0.005, type=float)
|
parser.add_argument('--encoder_tau', default=0.05, type=float)
|
||||||
parser.add_argument('--decoder_type', default='pixel', type=str)
|
parser.add_argument('--decoder_type', default='pixel', type=str)
|
||||||
parser.add_argument('--decoder_lr', default=1e-3, type=float)
|
parser.add_argument('--decoder_lr', default=1e-3, type=float)
|
||||||
parser.add_argument('--decoder_update_freq', default=1, type=int)
|
parser.add_argument('--decoder_update_freq', default=1, type=int)
|
||||||
parser.add_argument('--decoder_latent_lambda', default=0.0, type=float)
|
parser.add_argument('--decoder_latent_lambda', default=1e-6, type=float)
|
||||||
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
|
parser.add_argument('--decoder_weight_lambda', default=1e-7, type=float)
|
||||||
parser.add_argument('--num_layers', default=4, type=int)
|
parser.add_argument('--num_layers', default=4, type=int)
|
||||||
parser.add_argument('--num_filters', default=32, type=int)
|
parser.add_argument('--num_filters', default=32, type=int)
|
||||||
# sac
|
# sac
|
||||||
parser.add_argument('--discount', default=0.99, type=float)
|
parser.add_argument('--discount', default=0.99, type=float)
|
||||||
parser.add_argument('--init_temperature', default=0.01, type=float)
|
parser.add_argument('--init_temperature', default=0.1, type=float)
|
||||||
parser.add_argument('--alpha_lr', default=1e-3, type=float)
|
parser.add_argument('--alpha_lr', default=1e-4, type=float)
|
||||||
parser.add_argument('--alpha_beta', default=0.9, type=float)
|
parser.add_argument('--alpha_beta', default=0.5, type=float)
|
||||||
# misc
|
# misc
|
||||||
parser.add_argument('--seed', default=1, type=int)
|
parser.add_argument('--seed', default=1, type=int)
|
||||||
parser.add_argument('--work_dir', default='.', type=str)
|
parser.add_argument('--work_dir', default='.', type=str)
|
||||||
|
Loading…
Reference in New Issue
Block a user