Update train model
This commit is contained in:
parent
ab4e7b9a22
commit
bb0265846e
34
DPI/train.py
34
DPI/train.py
@ -50,9 +50,9 @@ def parse_args():
|
|||||||
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
|
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
|
||||||
parser.add_argument('--init_steps', default=5000, type=int)
|
parser.add_argument('--init_steps', default=5000, type=int)
|
||||||
parser.add_argument('--num_train_steps', default=100000, type=int)
|
parser.add_argument('--num_train_steps', default=100000, type=int)
|
||||||
parser.add_argument('--update_steps', default=100, type=int)
|
parser.add_argument('--update_steps', default=10, type=int)
|
||||||
parser.add_argument('--batch_size', default=64, type=int)
|
parser.add_argument('--batch_size', default=64, type=int)
|
||||||
parser.add_argument('--state_size', default=50, type=int)
|
parser.add_argument('--state_size', default=100, type=int)
|
||||||
parser.add_argument('--hidden_size', default=512, type=int)
|
parser.add_argument('--hidden_size', default=512, type=int)
|
||||||
parser.add_argument('--history_size', default=256, type=int)
|
parser.add_argument('--history_size', default=256, type=int)
|
||||||
parser.add_argument('--episode_collection', default=5, type=int)
|
parser.add_argument('--episode_collection', default=5, type=int)
|
||||||
@ -66,21 +66,21 @@ def parse_args():
|
|||||||
parser.add_argument('--num_eval_episodes', default=20, type=int)
|
parser.add_argument('--num_eval_episodes', default=20, type=int)
|
||||||
parser.add_argument('--evaluation_interval', default=10000, type=int) # TODO: master had 10000
|
parser.add_argument('--evaluation_interval', default=10000, type=int) # TODO: master had 10000
|
||||||
# value
|
# value
|
||||||
parser.add_argument('--value_lr', default=8e-5, type=float)
|
parser.add_argument('--value_lr', default=8e-6, type=float)
|
||||||
parser.add_argument('--value_target_update_freq', default=100, type=int)
|
parser.add_argument('--value_target_update_freq', default=100, type=int)
|
||||||
parser.add_argument('--td_lambda', default=0.95, type=int)
|
parser.add_argument('--td_lambda', default=0.95, type=int)
|
||||||
# actor
|
# actor
|
||||||
parser.add_argument('--actor_lr', default=8e-5, type=float)
|
parser.add_argument('--actor_lr', default=8e-6, type=float)
|
||||||
parser.add_argument('--actor_beta', default=0.9, type=float)
|
parser.add_argument('--actor_beta', default=0.9, type=float)
|
||||||
parser.add_argument('--actor_log_std_min', default=-10, type=float)
|
parser.add_argument('--actor_log_std_min', default=-10, type=float)
|
||||||
parser.add_argument('--actor_log_std_max', default=2, type=float)
|
parser.add_argument('--actor_log_std_max', default=2, type=float)
|
||||||
parser.add_argument('--actor_update_freq', default=2, type=int)
|
parser.add_argument('--actor_update_freq', default=2, type=int)
|
||||||
# world/encoder/decoder
|
# world/encoder/decoder
|
||||||
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
|
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
|
||||||
parser.add_argument('--world_model_lr', default=6e-5, type=float)
|
parser.add_argument('--world_model_lr', default=1e-6, type=float)
|
||||||
parser.add_argument('--decoder_lr', default=6e-4, type=float)
|
parser.add_argument('--decoder_lr', default=6e-6, type=float)
|
||||||
parser.add_argument('--reward_lr', default=6e-5, type=float)
|
parser.add_argument('--reward_lr', default=8e-6, type=float)
|
||||||
parser.add_argument('--encoder_tau', default=0.001, type=float)
|
parser.add_argument('--encoder_tau', default=0.005, type=float)
|
||||||
parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
|
parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
|
||||||
parser.add_argument('--num_layers', default=4, type=int)
|
parser.add_argument('--num_layers', default=4, type=int)
|
||||||
parser.add_argument('--num_filters', default=32, type=int)
|
parser.add_argument('--num_filters', default=32, type=int)
|
||||||
@ -238,10 +238,10 @@ class DPI:
|
|||||||
|
|
||||||
# optimizers
|
# optimizers
|
||||||
self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr,eps=1e-6)
|
self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr,eps=1e-6)
|
||||||
self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6, weight_decay=1e-5)
|
self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6)
|
||||||
self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr,eps=1e-6)
|
self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr,eps=1e-6)
|
||||||
self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), self.args.decoder_lr,eps=1e-6)
|
self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), self.args.decoder_lr,eps=1e-6)
|
||||||
self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6, weight_decay=1e-5)
|
self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6)
|
||||||
|
|
||||||
# Create Modules
|
# Create Modules
|
||||||
self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.club_sample, self.contrastive_head,
|
self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.club_sample, self.contrastive_head,
|
||||||
@ -461,19 +461,18 @@ class DPI:
|
|||||||
self.transition_model.init_states(self.args.batch_size, device) # (N,128)
|
self.transition_model.init_states(self.args.batch_size, device) # (N,128)
|
||||||
self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history, nonterms)
|
self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history, nonterms)
|
||||||
self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"])
|
self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"])
|
||||||
#print(self.observed_rollout["mean"][0][0])
|
|
||||||
self.pred_curr_state_enc = self.pred_curr_state_dist.rsample() #self.observed_rollout["sample"]
|
self.pred_curr_state_enc = self.pred_curr_state_dist.rsample() #self.observed_rollout["sample"]
|
||||||
|
|
||||||
# encoder loss
|
# encoder loss
|
||||||
enc_loss = self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist)
|
enc_loss = self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist)
|
||||||
|
|
||||||
# reward loss
|
# reward loss
|
||||||
rew_dist = self.reward_model(self.curr_state_enc)
|
rew_dist = self.reward_model(self.curr_state_enc.detach())
|
||||||
#print(torch.cat([rew_dist.mean[0], rewards[0]],dim=-1))
|
#print(torch.cat([rew_dist.mean[0], rewards[0]],dim=-1))
|
||||||
rew_loss = -torch.mean(rew_dist.log_prob(rewards))
|
rew_loss = -torch.mean(rew_dist.log_prob(rewards))
|
||||||
|
|
||||||
# decoder loss
|
# decoder loss
|
||||||
dec_dist = self.obs_decoder(self.nxt_state_enc)
|
dec_dist = self.obs_decoder(self.nxt_state_enc.detach())
|
||||||
dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs))
|
dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs))
|
||||||
|
|
||||||
# upper bound loss
|
# upper bound loss
|
||||||
@ -484,7 +483,7 @@ class DPI:
|
|||||||
ub_loss = ub_loss + past_ub_loss
|
ub_loss = ub_loss + past_ub_loss
|
||||||
past_ub_loss = ub_loss
|
past_ub_loss = ub_loss
|
||||||
ub_loss = ub_loss / self.curr_state_enc.shape[0]
|
ub_loss = ub_loss / self.curr_state_enc.shape[0]
|
||||||
ub_loss = 0.01 * ub_loss
|
ub_loss = 1 * ub_loss
|
||||||
|
|
||||||
# lower bound loss
|
# lower bound loss
|
||||||
# contrastive projection
|
# contrastive projection
|
||||||
@ -520,14 +519,14 @@ class DPI:
|
|||||||
action, curr_state_hist.detach(),
|
action, curr_state_hist.detach(),
|
||||||
imagine_horizon)
|
imagine_horizon)
|
||||||
self.pred_nxt_state_dist = self.transition_model.get_dist(self.imagined_rollout["mean"], self.imagined_rollout["std"])
|
self.pred_nxt_state_dist = self.transition_model.get_dist(self.imagined_rollout["mean"], self.imagined_rollout["std"])
|
||||||
#print(self.imagined_rollout["mean"][0][0])
|
|
||||||
self.pred_nxt_state_enc = self.pred_nxt_state_dist.rsample() #self.transition_model.reparemeterize(self.imagined_rollout["mean"], self.imagined_rollout["std"])
|
self.pred_nxt_state_enc = self.pred_nxt_state_dist.rsample() #self.transition_model.reparemeterize(self.imagined_rollout["mean"], self.imagined_rollout["std"])
|
||||||
|
|
||||||
with FreezeParameters(self.world_model_modules + self.value_modules + self.decoder_modules + self.reward_modules):
|
with FreezeParameters(self.world_model_modules + self.value_modules + self.decoder_modules + self.reward_modules):
|
||||||
imag_rewards_dist = self.reward_model(self.pred_nxt_state_enc)
|
imag_rewards_dist = self.reward_model(self.pred_nxt_state_enc)
|
||||||
imag_values_dist = self.value_model(self.pred_nxt_state_enc)
|
imag_values_dist = self.value_model(self.pred_nxt_state_enc)
|
||||||
imag_rewards = imag_rewards_dist.mean
|
imag_rewards = imag_rewards_dist.mean
|
||||||
imag_values = imag_values_dist.mean
|
imag_values = imag_values_dist.mean
|
||||||
|
#print(torch.cat([imag_rewards[0], imag_values[0]],dim=-1))
|
||||||
discounts = self.args.discount * torch.ones_like(imag_rewards).detach()
|
discounts = self.args.discount * torch.ones_like(imag_rewards).detach()
|
||||||
|
|
||||||
self.returns = self._compute_lambda_return(imag_rewards[:-1],
|
self.returns = self._compute_lambda_return(imag_rewards[:-1],
|
||||||
@ -541,7 +540,6 @@ class DPI:
|
|||||||
return actor_loss
|
return actor_loss
|
||||||
|
|
||||||
def value_model_losses(self):
|
def value_model_losses(self):
|
||||||
# value loss
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
value_feat = self.pred_nxt_state_enc[:-1].detach()
|
value_feat = self.pred_nxt_state_enc[:-1].detach()
|
||||||
value_targ = self.returns.detach()
|
value_targ = self.returns.detach()
|
||||||
@ -771,7 +769,7 @@ if __name__ == '__main__':
|
|||||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
step = 0
|
step = 0
|
||||||
total_steps = 1000000
|
total_steps = 2000000
|
||||||
dpi = DPI(args)
|
dpi = DPI(args)
|
||||||
dpi.train(step,total_steps)
|
dpi.train(step,total_steps)
|
||||||
dpi.evaluate()
|
dpi.evaluate()
|
Loading…
Reference in New Issue
Block a user