diff --git a/train.py b/train.py index e69f81a..81c9de2 100644 --- a/train.py +++ b/train.py @@ -20,7 +20,7 @@ from video import VideoRecorder from agent.baseline_agent import BaselineAgent from agent.bisim_agent import BisimAgent from agent.deepmdp_agent import DeepMDPAgent -from agents.navigation.carla_env import CarlaEnv +#from agents.navigation.carla_env import CarlaEnv def parse_args(): @@ -34,14 +34,15 @@ def parse_args(): parser.add_argument('--resource_files', type=str) parser.add_argument('--eval_resource_files', type=str) parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none']) - parser.add_argument('--total_frames', default=1000, type=int) + parser.add_argument('--total_frames', default=100, type=int) + parser.add_argument('--high_noise', action='store_true') # replay buffer - parser.add_argument('--replay_buffer_capacity', default=1000000, type=int) + parser.add_argument('--replay_buffer_capacity', default=50000, type=int) # train parser.add_argument('--agent', default='bisim', type=str, choices=['baseline', 'bisim', 'deepmdp']) parser.add_argument('--init_steps', default=1000, type=int) - parser.add_argument('--num_train_steps', default=1000000, type=int) - parser.add_argument('--batch_size', default=512, type=int) + parser.add_argument('--num_train_steps', default=1000050, type=int) + parser.add_argument('--batch_size', default=128, type=int) # 512 parser.add_argument('--hidden_dim', default=256, type=int) parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model') parser.add_argument('--bisim_coef', default=0.5, type=float, help='coefficient for bisim terms') @@ -50,12 +51,12 @@ def parse_args(): parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000 parser.add_argument('--num_eval_episodes', default=20, type=int) # critic - parser.add_argument('--critic_lr', default=1e-3, type=float) + parser.add_argument('--critic_lr', default=1e-5, type=float) parser.add_argument('--critic_beta', default=0.9, type=float) parser.add_argument('--critic_tau', default=0.005, type=float) parser.add_argument('--critic_target_update_freq', default=2, type=int) # actor - parser.add_argument('--actor_lr', default=1e-3, type=float) + parser.add_argument('--actor_lr', default=1e-5, type=float) parser.add_argument('--actor_beta', default=0.9, type=float) parser.add_argument('--actor_log_std_min', default=-10, type=float) parser.add_argument('--actor_log_std_max', default=2, type=float) @@ -63,19 +64,19 @@ def parse_args(): # encoder/decoder parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity']) parser.add_argument('--encoder_feature_dim', default=50, type=int) - parser.add_argument('--encoder_lr', default=1e-3, type=float) + parser.add_argument('--encoder_lr', default=1e-5, type=float) parser.add_argument('--encoder_tau', default=0.005, type=float) parser.add_argument('--encoder_stride', default=1, type=int) parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction']) - parser.add_argument('--decoder_lr', default=1e-3, type=float) + parser.add_argument('--decoder_lr', default=1e-5, type=float) parser.add_argument('--decoder_update_freq', default=1, type=int) parser.add_argument('--decoder_weight_lambda', default=0.0, type=float) parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_filters', default=32, type=int) # sac parser.add_argument('--discount', default=0.99, type=float) - parser.add_argument('--init_temperature', default=0.01, type=float) - parser.add_argument('--alpha_lr', default=1e-3, type=float) + parser.add_argument('--init_temperature', default=0.1, type=float) + parser.add_argument('--alpha_lr', default=1e-4, type=float) parser.add_argument('--alpha_beta', default=0.9, type=float) # misc parser.add_argument('--seed', default=1, type=int) @@ -88,6 +89,9 @@ def parse_args(): parser.add_argument('--render', default=False, action='store_true') parser.add_argument('--port', default=2000, type=int) args = parser.parse_args() + #from dmc2gym.wrappers import set_global_var + #set_global_var(args.high_noise) + return args diff --git a/video.py b/video.py index 58211b6..3df34ee 100644 --- a/video.py +++ b/video.py @@ -22,7 +22,7 @@ class VideoRecorder(object): self.frames = [] if resource_files: files = glob.glob(os.path.expanduser(resource_files)) - self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000) + self._bg_source = RandomVideoSource((height, width), files, grayscale=False, max_videos=50, random_bg=False) else: self._bg_source = None