From 5ded7bc8f1ecfedf71a8f6699d85d0a9f6716076 Mon Sep 17 00:00:00 2001
From: VedantDave
Date: Wed, 12 Apr 2023 09:33:42 +0200
Subject: [PATCH] Adding actor and value learners

---
 DPI/train.py | 127 +++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 98 insertions(+), 29 deletions(-)

diff --git a/DPI/train.py b/DPI/train.py
index a23b5cb..5988fc8 100644
--- a/DPI/train.py
+++ b/DPI/train.py
@@ -65,6 +65,7 @@ def parse_args():
     parser.add_argument('--value_beta', default=0.9, type=float)
     parser.add_argument('--value_tau', default=0.005, type=float)
     parser.add_argument('--value_target_update_freq', default=2, type=int)
+    parser.add_argument('--td_lambda', default=0.95, type=float)
     # reward
     parser.add_argument('--reward_lr', default=1e-4, type=float)
     # actor
@@ -77,6 +78,7 @@
     parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
     parser.add_argument('--encoder_feature_dim', default=50, type=int)
     parser.add_argument('--world_model_lr', default=1e-3, type=float)
+    parser.add_argument('--past_transition_lr', default=1e-3, type=float)
     parser.add_argument('--encoder_lr', default=1e-3, type=float)
     parser.add_argument('--encoder_tau', default=0.005, type=float)
     parser.add_argument('--encoder_stride', default=1, type=int)
@@ -94,6 +96,7 @@
     parser.add_argument('--alpha_beta', default=0.9, type=float)
     # misc
     parser.add_argument('--seed', default=1, type=int)
+    parser.add_argument('--logging_freq', default=100, type=int)
     parser.add_argument('--work_dir', default='.', type=str)
     parser.add_argument('--save_tb', default=False, action='store_true')
     parser.add_argument('--save_model', default=False, action='store_true')
@@ -102,6 +105,7 @@
     parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble'])
     parser.add_argument('--render', default=False, action='store_true')
     parser.add_argument('--port', default=2000, type=int)
+    parser.add_argument('--num_likelihood_updates', default=5, type=int)
     args = parser.parse_args()
     return args
 
@@ -145,13 +149,7 @@ class DPI:
                                         seq_len=self.args.episode_length,
                                         batch_size=args.batch_size,
                                         args=self.args)
-        self.data_buffer_clean = ReplayBuffer(size=self.args.replay_buffer_capacity,
-                                              obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size),
-                                              action_size=self.env.action_space.shape[0],
-                                              seq_len=self.args.episode_length,
-                                              batch_size=args.batch_size,
-                                              args=self.args)
-
+
         # create work directory
         utils.make_dir(self.args.work_dir)
         self.video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video'))
@@ -230,11 +228,13 @@ class DPI:
         self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_decoder.parameters()) + \
                                       list(self.value_model.parameters()) + list(self.transition_model.parameters()) + \
                                       list(self.prjoection_head.parameters())
+        self.past_transition_parameters = self.transition_model.parameters()
 
         # optimizers
         self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr)
         self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr)
         self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr)
+        self.past_transition_opt = torch.optim.Adam(self.past_transition_parameters, self.args.past_transition_lr)
 
         # Create Modules
         self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.value_model, self.transition_model,
                                     self.prjoection_head]
@@ -269,7 +269,7 @@ class DPI:
                 next_obs, rew, done, _ = self.env.step(action)
                 #next_obs_clean, _, done, _ = self.env_clean.step(action)
-                self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
+                self.data_buffer.add(obs, action, next_obs, rew, episode_count+1, done)
                 #self.data_buffer_clean.add(obs_clean, action, next_obs_clean, episode_count+1, done)
 
                 if args.save_video:
@@ -293,11 +293,12 @@ class DPI:
         self.collect_sequences(self.args.batch_size)
 
         # Group observations and next_observations by steps from past to present
-        last_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()[:self.args.episode_length-1]
+        last_observations = torch.tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()[:self.args.episode_length-1]
         current_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()[:self.args.episode_length-1]
         next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()[1:]
         actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[:self.args.episode_length-1]
         next_actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[1:]
+        rewards = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"rewards",obs=False)).float()[1:]
 
         # Initialize transition model states
         self.transition_model.init_states(self.args.batch_size, device="cpu") # (N,128)
@@ -306,6 +307,7 @@
         # Train encoder
         step = 0
         total_steps = 10000
+        metrics = {}
         while step < total_steps:
             for i in range(self.args.episode_length-1):
                 if i > 0:
@@ -314,6 +316,7 @@
                 self.current_states_dict = self.get_features(current_observations[i])
                 self.next_states_dict = self.get_features(next_observations[i], momentum=True)
                 self.action = actions[i] # (N,6)
+                self.next_action = next_actions[i] # (N,6)
                 history = self.transition_model.prev_history
 
                 # Encode negative observations
@@ -327,14 +330,14 @@
                 predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
                 self.history = predicted_current_state_dict["history"]
-
-                # Calculate upper bound loss
-                ub_loss = self._upper_bound_minimization(self.last_states_dict,
-                                                         self.current_states_dict,
-                                                         self.negative_current_states_dict,
-                                                         predicted_current_state_dict
-                                                         )
+                likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict,
+                                                                      self.current_states_dict,
+                                                                      self.negative_current_states_dict,
+                                                                      predicted_current_state_dict
+                                                                      )
+                #likeli_loss = torch.tensor(likeli_loss.numpy(),dtype=torch.float32, requires_grad=True)
+                #likeli_loss = likeli_loss.mean()
 
                 # Calculate encoder loss
                 encoder_loss = self._past_encoder_loss(self.current_states_dict,
@@ -356,31 +359,83 @@
 
                 # update models
-                world_model_loss = encoder_loss + 1e-1 * ub_loss + lb_loss #1e-1 * ub_loss + 1e-5 * encoder_loss + 1e-1 * lb_loss
-                print("ub_loss: {:.4f}, encoder_loss: {:.4f}, lb_loss: {:.4f}".format(ub_loss, encoder_loss, lb_loss))
-                print("world_model_loss: {:.4f}".format(world_model_loss))
+                """
+                print(likeli_loss)
+                for i in range(self.args.num_likelihood_updates):
+                    self.past_transition_opt.zero_grad()
+                    print(likeli_loss)
+                    likeli_loss.backward()
+                    nn.utils.clip_grad_norm_(self.past_transition_parameters, self.args.grad_clip_norm)
+                    self.past_transition_opt.step()
+                print(encoder_loss, ub_loss, lb_loss, step)
+                """
+
+                world_model_loss = encoder_loss + ub_loss + lb_loss
                 self.world_model_opt.zero_grad()
                 world_model_loss.backward()
                 nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
                 self.world_model_opt.step()
-
+                """
+                if step % self.args.logging_freq == 0:
+                    metrics['Upper Bound Loss'] = ub_loss.item()
+                    metrics['Encoder Loss'] = encoder_loss.item()
+                    metrics["Lower Bound Loss"] = lb_loss.item()
+                    metrics["World Model Loss"] = world_model_loss.item()
+                    wandb.log(metrics)
+                """
+
                 # behaviour learning
                 with FreezeParameters(self.world_model_modules):
-                    imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
+                    imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
                     imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(),
-                                                                             self.action, self.history.detach(),
+                                                                             self.next_action, self.history.detach(),
                                                                              imagine_horizon)
-                print(imagined_rollout["sample"].shape, imagined_rollout["distribution"][0].sample().shape)
-                #exit()
+                #print(imagined_rollout["sample"].shape, imagined_rollout["distribution"][0].sample().shape)
+
+                # actor loss
+                with FreezeParameters(self.world_model_modules + self.value_modules):
+                    imag_rew_dist = self.reward_model(imagined_rollout["sample"])
+                    target_imag_val_dist = self.target_value_model(imagined_rollout["sample"])
+
+                    imag_rews = imag_rew_dist.mean
+                    target_imag_vals = target_imag_val_dist.mean
+
+                    discounts = self.args.discount * torch.ones_like(imag_rews).detach()
+
+                    self.target_returns = self._compute_lambda_return(imag_rews[:-1],
+                                                                      target_imag_vals[:-1],
+                                                                      discounts[:-1] ,
+                                                                      self.args.td_lambda,
+                                                                      target_imag_vals[-1])
+
+                discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
+                self.discounts = torch.cumprod(discounts, 0).detach()
+                actor_loss = -torch.mean(self.discounts * self.target_returns)
+
+                self.actor_opt.zero_grad()
+                actor_loss.backward()
+                nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
+                self.actor_opt.step()
+
+                # value loss
+                with torch.no_grad():
+                    value_feat = imagined_rollout["sample"][:-1].detach()
+                    value_targ = self.target_returns.detach()
+
+                value_dist = self.value_model(value_feat)
+                value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
+
+                self.value_opt.zero_grad()
+                value_loss.backward()
+                nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
+                self.value_opt.step()
+
                 step += 1
                 if step>total_steps:
                     print("Training finished")
                     break
-                #exit()
-                #print(total_ub_loss, total_encoder_loss)
@@ -390,8 +445,9 @@ class DPI:
                                  current_states,
                                  negative_current_states,
                                  predicted_current_states)
-        club_loss = club_sample.loglikeli()
-        return club_loss
+        likelihood_loss = club_sample.learning_loss()
+        club_loss = club_sample()
+        return likelihood_loss, club_loss
 
     def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict):
         # current state distribution
@@ -410,7 +466,7 @@
     def get_features(self, x, momentum=False):
         import torchvision.transforms.functional as fn
         x = x/255.0 - 0.5 # Preprocessing
-        
+
         if self.args.aug:
             x = T.RandomCrop((80, 80))(x) # (None,80,80,4)
             x = T.functional.pad(x, (4, 4, 4, 4), "symmetric") # (None,88,88,4)
@@ -422,6 +478,19 @@
         else:
             x = self.obs_encoder(x)
         return x
+
+    def _compute_lambda_return(self, rewards, values, discounts, td_lam, last_value):
+        next_values = torch.cat([values[1:], last_value.unsqueeze(0)],0)
+        targets = rewards + discounts * next_values * (1-td_lam)
+        rets =[]
+        last_rew = last_value
+
+        for t in range(rewards.shape[0]-1, -1, -1):
+            last_rew = targets[t] + discounts[t] * td_lam *(last_rew)
+            rets.append(last_rew)
+
+        returns = torch.flip(torch.stack(rets), [0])
+        return returns
 
 if __name__ == '__main__':
     args = parse_args()
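
For reference, the added _compute_lambda_return implements the standard Dreamer-style TD(lambda) return,
G_t = r_t + gamma_t * ((1 - lambda) * v_{t+1} + lambda * G_{t+1}), evaluated backwards and bootstrapped from
the final target value. Below is a minimal standalone sketch of the same recursion with dummy tensors; the
helper name and the shapes are illustrative only and are not part of the patch.

    import torch

    def lambda_return(rewards, values, discounts, td_lam, last_value):
        # G_t = r_t + gamma_t * ((1 - lam) * v_{t+1} + lam * G_{t+1}),
        # computed backwards from G_T = last_value, mirroring _compute_lambda_return.
        next_values = torch.cat([values[1:], last_value.unsqueeze(0)], 0)
        targets = rewards + discounts * next_values * (1 - td_lam)
        returns, last = [], last_value
        for t in reversed(range(rewards.shape[0])):
            last = targets[t] + discounts[t] * td_lam * last
            returns.append(last)
        return torch.flip(torch.stack(returns), [0])

    # dummy imagined rollout: horizon 3, batch 4, shapes (T, N, 1)
    rews = torch.rand(3, 4, 1)
    vals = torch.rand(3, 4, 1)
    gammas = 0.99 * torch.ones_like(rews)
    rets = lambda_return(rews, vals, gammas, td_lam=0.95, last_value=vals[-1])
    print(rets.shape)  # torch.Size([3, 4, 1])

In the training loop above, the actor is updated to maximize these returns weighted by the cumulative product
of discounts, while the value model regresses onto the same detached returns, which is the usual actor/value
split in Dreamer-style agents.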