Update train model

This commit is contained in:
Vedant Dave 2023-04-26 09:43:28 +02:00
parent ab4e7b9a22
commit bb0265846e

View File

@ -50,9 +50,9 @@ def parse_args():
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
parser.add_argument('--init_steps', default=5000, type=int)
parser.add_argument('--num_train_steps', default=100000, type=int)
parser.add_argument('--update_steps', default=100, type=int)
parser.add_argument('--update_steps', default=10, type=int)
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--state_size', default=50, type=int)
parser.add_argument('--state_size', default=100, type=int)
parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--history_size', default=256, type=int)
parser.add_argument('--episode_collection', default=5, type=int)
@ -66,21 +66,21 @@ def parse_args():
parser.add_argument('--num_eval_episodes', default=20, type=int)
parser.add_argument('--evaluation_interval', default=10000, type=int) # TODO: master had 10000
# value
parser.add_argument('--value_lr', default=8e-5, type=float)
parser.add_argument('--value_lr', default=8e-6, type=float)
parser.add_argument('--value_target_update_freq', default=100, type=int)
parser.add_argument('--td_lambda', default=0.95, type=int)
# actor
parser.add_argument('--actor_lr', default=8e-5, type=float)
parser.add_argument('--actor_lr', default=8e-6, type=float)
parser.add_argument('--actor_beta', default=0.9, type=float)
parser.add_argument('--actor_log_std_min', default=-10, type=float)
parser.add_argument('--actor_log_std_max', default=2, type=float)
parser.add_argument('--actor_update_freq', default=2, type=int)
# world/encoder/decoder
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
parser.add_argument('--world_model_lr', default=6e-5, type=float)
parser.add_argument('--decoder_lr', default=6e-4, type=float)
parser.add_argument('--reward_lr', default=6e-5, type=float)
parser.add_argument('--encoder_tau', default=0.001, type=float)
parser.add_argument('--world_model_lr', default=1e-6, type=float)
parser.add_argument('--decoder_lr', default=6e-6, type=float)
parser.add_argument('--reward_lr', default=8e-6, type=float)
parser.add_argument('--encoder_tau', default=0.005, type=float)
parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
parser.add_argument('--num_layers', default=4, type=int)
parser.add_argument('--num_filters', default=32, type=int)
@ -238,10 +238,10 @@ class DPI:
# optimizers
self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr,eps=1e-6)
self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6, weight_decay=1e-5)
self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6)
self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr,eps=1e-6)
self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), self.args.decoder_lr,eps=1e-6)
self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6, weight_decay=1e-5)
self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6)
# Create Modules
self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.club_sample, self.contrastive_head,
@ -461,19 +461,18 @@ class DPI:
self.transition_model.init_states(self.args.batch_size, device) # (N,128)
self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history, nonterms)
self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"])
#print(self.observed_rollout["mean"][0][0])
self.pred_curr_state_enc = self.pred_curr_state_dist.rsample() #self.observed_rollout["sample"]
# encoder loss
enc_loss = self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist)
# reward loss
rew_dist = self.reward_model(self.curr_state_enc)
rew_dist = self.reward_model(self.curr_state_enc.detach())
#print(torch.cat([rew_dist.mean[0], rewards[0]],dim=-1))
rew_loss = -torch.mean(rew_dist.log_prob(rewards))
# decoder loss
dec_dist = self.obs_decoder(self.nxt_state_enc)
dec_dist = self.obs_decoder(self.nxt_state_enc.detach())
dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs))
# upper bound loss
@ -484,7 +483,7 @@ class DPI:
ub_loss = ub_loss + past_ub_loss
past_ub_loss = ub_loss
ub_loss = ub_loss / self.curr_state_enc.shape[0]
ub_loss = 0.01 * ub_loss
ub_loss = 1 * ub_loss
# lower bound loss
# contrastive projection
@ -520,7 +519,6 @@ class DPI:
action, curr_state_hist.detach(),
imagine_horizon)
self.pred_nxt_state_dist = self.transition_model.get_dist(self.imagined_rollout["mean"], self.imagined_rollout["std"])
#print(self.imagined_rollout["mean"][0][0])
self.pred_nxt_state_enc = self.pred_nxt_state_dist.rsample() #self.transition_model.reparemeterize(self.imagined_rollout["mean"], self.imagined_rollout["std"])
with FreezeParameters(self.world_model_modules + self.value_modules + self.decoder_modules + self.reward_modules):
@ -528,6 +526,7 @@ class DPI:
imag_values_dist = self.value_model(self.pred_nxt_state_enc)
imag_rewards = imag_rewards_dist.mean
imag_values = imag_values_dist.mean
#print(torch.cat([imag_rewards[0], imag_values[0]],dim=-1))
discounts = self.args.discount * torch.ones_like(imag_rewards).detach()
self.returns = self._compute_lambda_return(imag_rewards[:-1],
@ -541,7 +540,6 @@ class DPI:
return actor_loss
def value_model_losses(self):
# value loss
with torch.no_grad():
value_feat = self.pred_nxt_state_enc[:-1].detach()
value_targ = self.returns.detach()
@ -771,7 +769,7 @@ if __name__ == '__main__':
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
step = 0
total_steps = 1000000
total_steps = 2000000
dpi = DPI(args)
dpi.train(step,total_steps)
dpi.evaluate()