From bb0265846e962d74224008616c22347c65356e2d Mon Sep 17 00:00:00 2001
From: VedantDave
Date: Wed, 26 Apr 2023 09:43:28 +0200
Subject: [PATCH] Update train model

---
 DPI/train.py | 34 ++++++++++++++++------------------
 1 file changed, 16 insertions(+), 18 deletions(-)

diff --git a/DPI/train.py b/DPI/train.py
index e804ac9..f70417d 100644
--- a/DPI/train.py
+++ b/DPI/train.py
@@ -50,9 +50,9 @@ def parse_args():
     parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
     parser.add_argument('--init_steps', default=5000, type=int)
     parser.add_argument('--num_train_steps', default=100000, type=int)
-    parser.add_argument('--update_steps', default=100, type=int)
+    parser.add_argument('--update_steps', default=10, type=int)
     parser.add_argument('--batch_size', default=64, type=int)
-    parser.add_argument('--state_size', default=50, type=int)
+    parser.add_argument('--state_size', default=100, type=int)
     parser.add_argument('--hidden_size', default=512, type=int)
     parser.add_argument('--history_size', default=256, type=int)
     parser.add_argument('--episode_collection', default=5, type=int)
@@ -66,21 +66,21 @@ def parse_args():
     parser.add_argument('--num_eval_episodes', default=20, type=int)
     parser.add_argument('--evaluation_interval', default=10000, type=int) # TODO: master had 10000
     # value
-    parser.add_argument('--value_lr', default=8e-5, type=float)
+    parser.add_argument('--value_lr', default=8e-6, type=float)
     parser.add_argument('--value_target_update_freq', default=100, type=int)
     parser.add_argument('--td_lambda', default=0.95, type=int)
     # actor
-    parser.add_argument('--actor_lr', default=8e-5, type=float)
+    parser.add_argument('--actor_lr', default=8e-6, type=float)
     parser.add_argument('--actor_beta', default=0.9, type=float)
     parser.add_argument('--actor_log_std_min', default=-10, type=float)
     parser.add_argument('--actor_log_std_max', default=2, type=float)
     parser.add_argument('--actor_update_freq', default=2, type=int)
     # world/encoder/decoder
     parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
-    parser.add_argument('--world_model_lr', default=6e-5, type=float)
-    parser.add_argument('--decoder_lr', default=6e-4, type=float)
-    parser.add_argument('--reward_lr', default=6e-5, type=float)
-    parser.add_argument('--encoder_tau', default=0.001, type=float)
+    parser.add_argument('--world_model_lr', default=1e-6, type=float)
+    parser.add_argument('--decoder_lr', default=6e-6, type=float)
+    parser.add_argument('--reward_lr', default=8e-6, type=float)
+    parser.add_argument('--encoder_tau', default=0.005, type=float)
     parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
     parser.add_argument('--num_layers', default=4, type=int)
     parser.add_argument('--num_filters', default=32, type=int)
@@ -238,10 +238,10 @@ class DPI:
 
         # optimizers
         self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr,eps=1e-6)
-        self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6, weight_decay=1e-5)
+        self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6)
         self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr,eps=1e-6)
         self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), self.args.decoder_lr,eps=1e-6)
-        self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6, weight_decay=1e-5)
+        self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6)
 
         # Create Modules
         self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.club_sample, self.contrastive_head,
@@ -461,19 +461,18 @@ class DPI:
         self.transition_model.init_states(self.args.batch_size, device) # (N,128)
         self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history, nonterms)
         self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"])
-        #print(self.observed_rollout["mean"][0][0])
         self.pred_curr_state_enc = self.pred_curr_state_dist.rsample() #self.observed_rollout["sample"]
 
         # encoder loss
         enc_loss = self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist)
 
         # reward loss
-        rew_dist = self.reward_model(self.curr_state_enc)
+        rew_dist = self.reward_model(self.curr_state_enc.detach())
         #print(torch.cat([rew_dist.mean[0], rewards[0]],dim=-1))
         rew_loss = -torch.mean(rew_dist.log_prob(rewards))
 
         # decoder loss
-        dec_dist = self.obs_decoder(self.nxt_state_enc)
+        dec_dist = self.obs_decoder(self.nxt_state_enc.detach())
         dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs))
 
         # upper bound loss
@@ -484,7 +483,7 @@ class DPI:
             ub_loss = ub_loss + past_ub_loss
             past_ub_loss = ub_loss
         ub_loss = ub_loss / self.curr_state_enc.shape[0]
-        ub_loss = 0.01 * ub_loss
+        ub_loss = 1 * ub_loss
 
         # lower bound loss
         # contrastive projection
@@ -520,14 +519,14 @@ class DPI:
                                                                action, curr_state_hist.detach(), imagine_horizon)
 
         self.pred_nxt_state_dist = self.transition_model.get_dist(self.imagined_rollout["mean"], self.imagined_rollout["std"])
-        #print(self.imagined_rollout["mean"][0][0])
         self.pred_nxt_state_enc = self.pred_nxt_state_dist.rsample() #self.transition_model.reparemeterize(self.imagined_rollout["mean"], self.imagined_rollout["std"])
 
         with FreezeParameters(self.world_model_modules + self.value_modules + self.decoder_modules + self.reward_modules):
             imag_rewards_dist = self.reward_model(self.pred_nxt_state_enc)
             imag_values_dist = self.value_model(self.pred_nxt_state_enc)
-            imag_rewards = imag_rewards_dist.mean 
+            imag_rewards = imag_rewards_dist.mean
             imag_values = imag_values_dist.mean
+        #print(torch.cat([imag_rewards[0], imag_values[0]],dim=-1))
 
         discounts = self.args.discount * torch.ones_like(imag_rewards).detach()
         self.returns = self._compute_lambda_return(imag_rewards[:-1],
@@ -541,7 +540,6 @@ class DPI:
         return actor_loss
 
     def value_model_losses(self):
-        # value loss
        with torch.no_grad():
            value_feat = self.pred_nxt_state_enc[:-1].detach()
            value_targ = self.returns.detach()
@@ -771,7 +769,7 @@ if __name__ == '__main__':
     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
     step = 0
-    total_steps = 1000000
+    total_steps = 2000000
     dpi = DPI(args)
     dpi.train(step,total_steps)
     dpi.evaluate()
\ No newline at end of file
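
Reviewer note (not part of the patch): the substantive change in world_model_losses is that self.reward_model and self.obs_decoder now receive detached latents, so the reward and reconstruction losses train only those heads and no longer push gradients into the encoder/transition model. A minimal sketch of that stop-gradient pattern, using hypothetical stand-in modules and squared-error losses in place of the repository's Gaussian log-likelihoods:

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for self.reward_model and self.obs_decoder.
    reward_head = nn.Linear(100, 1)            # latent (state_size=100) -> reward estimate
    obs_decoder = nn.Linear(100, 3 * 84 * 84)  # latent -> flattened reconstruction (assumed obs shape)

    latent = torch.randn(64, 100, requires_grad=True)  # pretend world-model output
    rewards = torch.randn(64, 1)
    frames = torch.randn(64, 3 * 84 * 84)

    # As in the patch: detach the latent before the heads, so these two losses
    # update only the heads, not whatever produced `latent`.
    rew_loss = nn.functional.mse_loss(reward_head(latent.detach()), rewards)
    dec_loss = nn.functional.mse_loss(obs_decoder(latent.detach()), frames)
    (rew_loss + dec_loss).backward()

    assert latent.grad is None  # no gradient reached the latent through the detached path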
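Also for context: _compute_lambda_return itself is outside this diff. Assuming it follows the usual Dreamer-style TD(lambda) recursion, with --td_lambda as lambda and --discount supplying the per-step discounts, the bootstrapped targets it feeds the actor loss look roughly like the sketch below; the function name and signature are illustrative, not the repository's:

    import torch

    def lambda_return(rewards, values, discounts, bootstrap, lam=0.95):
        # rewards, values, discounts: (horizon, batch); bootstrap: (batch,) value
        # estimate for the step just past the horizon.
        returns = torch.zeros_like(rewards)
        next_return = bootstrap
        for t in reversed(range(rewards.shape[0])):
            next_value = values[t + 1] if t + 1 < values.shape[0] else bootstrap
            # G_t = r_t + gamma_t * ((1 - lam) * V_{t+1} + lam * G_{t+1})
            next_return = rewards[t] + discounts[t] * ((1 - lam) * next_value + lam * next_return)
            returns[t] = next_return
        return returns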