diff --git a/DPI/train.py b/DPI/train.py index 8e2a867..29d0cab 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -26,6 +26,7 @@ def parse_args(): parser = argparse.ArgumentParser() # environment parser.add_argument('--domain_name', default='cheetah') + parser.add_argument('--version', default=1, type=int) parser.add_argument('--task_name', default='run') parser.add_argument('--image_size', default=84, type=int) parser.add_argument('--channels', default=3, type=int) @@ -113,9 +114,17 @@ class DPI: self.env = make_env(self.args) self.env.seed(self.args.seed) + # noiseless environment setup + self.args.version = 2 # env_id changes to v2 + self.args.img_source = None # no image noise + self.args.resource_files = None + self.env_clean = make_env(self.args) + self.env_clean.seed(self.args.seed) + # stack several consecutive frames together if self.args.encoder_type.startswith('pixel'): self.env = utils.FrameStack(self.env, k=self.args.frame_stack) + self.env_clean = utils.FrameStack(self.env_clean, k=self.args.frame_stack) # create replay buffer self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity, @@ -124,6 +133,12 @@ class DPI: seq_len=self.args.episode_length, batch_size=args.batch_size, args=self.args) + self.data_buffer_clean = ReplayBuffer(size=self.args.replay_buffer_capacity, + obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), + action_size=self.env.action_space.shape[0], + seq_len=self.args.episode_length, + batch_size=args.batch_size, + args=self.args) # create work directory utils.make_dir(self.args.work_dir) @@ -145,6 +160,11 @@ class DPI: output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (12,84,84) ) + self.obs_encoder_momentum = ObservationEncoder( + obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) + state_size=self.args.state_size # 128 + ) + self.transition_model = TransitionModel( state_size=self.args.state_size, # 128 hidden_size=self.args.hidden_size, # 256 @@ -153,7 +173,8 @@ class DPI: ) # model parameters - self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_decoder.parameters()) + list(self.transition_model.parameters()) + self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \ + list(self.obs_decoder.parameters()) + list(self.transition_model.parameters()) # optimizer self.optimizer = torch.optim.Adam(self.model_parameters, lr=self.args.encoder_lr) @@ -166,33 +187,45 @@ class DPI: self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt'))) self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt'))) - def collect_random_episodes(self, episodes): + def collect_sequences(self, episodes): obs = self.env.reset() + obs_clean = self.env_clean.reset() done = False #video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files) for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'): - #self.env.video.init(enabled=True) + if args.save_video: + self.env.video.init(enabled=True) + self.env_clean.video.init(enabled=True) + for i in range(self.args.episode_length): action = self.env.action_space.sample() + next_obs, _, done, _ = self.env.step(action) + next_obs_clean, _, done, _ = self.env_clean.step(action) self.data_buffer.add(obs, action, next_obs, episode_count+1, done) + self.data_buffer_clean.add(obs_clean, action, next_obs_clean, episode_count+1, done) - #if args.save_video: - # self.env.video.record(self.env) + if args.save_video: + self.env.video.record(self.env_clean) + self.env_clean.video.record(self.env_clean) if done: obs = self.env.reset() + obs_clean = self.env_clean.reset() done=False else: obs = next_obs - #self.env.video.save('%d.mp4' % episode_count) + obs_clean = next_obs_clean + if args.save_video: + self.env.video.save('noisy/%d.mp4' % episode_count) + self.env_clean.video.save('clean/%d.mp4' % episode_count) print("Collected {} random episodes".format(episode_count+1)) def train(self): # collect experience - self.collect_random_episodes(self.args.batch_size) + self.collect_sequences(self.args.batch_size) # Group observations and next_observations by steps observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float() @@ -223,6 +256,9 @@ class DPI: past_encoder_loss = previous_encoder_loss + self._past_encoder_loss(self.states, self.next_states, self.states_dist, self.next_states_dist, self.actions, self.history, i) + + + previous_information_loss = past_latent_loss previous_encoder_loss = past_encoder_loss