diff --git a/DPI/train.py b/DPI/train.py
index cf26d3e..4138b9c 100644
--- a/DPI/train.py
+++ b/DPI/train.py
@@ -6,11 +6,12 @@ import wandb
 import random
 import argparse
 import numpy as np
+from collections import OrderedDict
 
 import utils
 from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image
+from replay_buffer import ReplayBuffer
 from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
-from logger import Logger
 from video import VideoRecorder
 from dmc2gym.wrappers import set_global_var
@@ -40,24 +41,25 @@ def parse_args():
     parser.add_argument('--resource_files', type=str)
     parser.add_argument('--eval_resource_files', type=str)
     parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
-    parser.add_argument('--total_frames', default=1000, type=int) # 10000
+    parser.add_argument('--total_frames', default=5000, type=int) # 10000
     parser.add_argument('--high_noise', action='store_true')
     # replay buffer
     parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #50000
-    parser.add_argument('--episode_length', default=21, type=int)
+    parser.add_argument('--episode_length', default=51, type=int)
     # train
     parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
     parser.add_argument('--init_steps', default=10000, type=int)
     parser.add_argument('--num_train_steps', default=100000, type=int)
-    parser.add_argument('--batch_size', default=128, type=int) #512
-    parser.add_argument('--state_size', default=30, type=int)
-    parser.add_argument('--hidden_size', default=256, type=int)
+    parser.add_argument('--update_steps', default=1, type=int)
+    parser.add_argument('--batch_size', default=64, type=int) #512
+    parser.add_argument('--state_size', default=50, type=int)
+    parser.add_argument('--hidden_size', default=512, type=int)
     parser.add_argument('--history_size', default=128, type=int)
     parser.add_argument('--episode_collection', default=5, type=int)
-    parser.add_argument('--episodes_buffer', default=20, type=int)
+    parser.add_argument('--episodes_buffer', default=5, type=int, help='Initial number of episodes to store in the buffer')
     parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models')
     parser.add_argument('--load_encoder', default=None, type=str)
-    parser.add_argument('--imagine_horizon', default=10, type=str)
+    parser.add_argument('--imagine_horizon', default=15, type=int)
     parser.add_argument('--grad_clip_norm', type=float, default=100.0, help='Gradient clipping norm')
     # eval
     parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
@@ -65,8 +67,6 @@ def parse_args():
     parser.add_argument('--evaluation_interval', default=10000, type=int)  # TODO: master had 10000
     # value
     parser.add_argument('--value_lr', default=8e-6, type=float)
-    parser.add_argument('--value_beta', default=0.9, type=float)
-    parser.add_argument('--value_tau', default=0.005, type=float)
     parser.add_argument('--value_target_update_freq', default=100, type=int)
     parser.add_argument('--td_lambda', default=0.95, type=int)
     # actor
@@ -78,13 +78,13 @@ def parse_args():
     # world/encoder/decoder
     parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
     parser.add_argument('--world_model_lr', default=1e-5, type=float)
-    parser.add_argument('--encoder_tau', default=0.005, type=float)
+    parser.add_argument('--encoder_tau', default=0.001, type=float)
     parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
     parser.add_argument('--num_layers', default=4, type=int)
     parser.add_argument('--num_filters', default=32, type=int)
     parser.add_argument('--aug', action='store_true')
     # sac
-    parser.add_argument('--discount', default=0.99, type=float)
+    parser.add_argument('--discount', default=0.95, type=float)
     # misc
     parser.add_argument('--seed', default=1, type=int)
     parser.add_argument('--logging_freq', default=100, type=int)
@@ -131,6 +131,7 @@ class DPI:
         self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
         self.env = utils.ActionRepeat(self.env, self.args.action_repeat)
         self.env = utils.NormalizeActions(self.env)
+        self.env = utils.TimeLimit(self.env, 1000 / args.action_repeat)
 
         # create replay buffer
         self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity,
@@ -250,7 +251,7 @@ class DPI:
             self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
             self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))
 
-    def collect_sequences(self, episodes, random=True, actor_model=None, encoder_model=None):
+    def collect_random_sequences(self, episodes):
         obs = self.env.reset()
         done = False
 
@@ -259,239 +260,278 @@ class DPI:
             self.global_episodes += 1
             epi_reward = 0
             while not done:
-                if random:
-                    action = self.env.action_space.sample()
-                else:
-                    with torch.no_grad():
-                        obs = torch.tensor(obs.copy(), dtype=torch.float32).unsqueeze(0)
-                        obs_processed = preprocess_obs(obs).to(device)
-                        state = self.obs_encoder(obs_processed)["distribution"].sample()
-                        action = self.actor_model(state).cpu().numpy().squeeze()
-                        #action = self.env.action_space.sample()
-
+                action = self.env.action_space.sample()
+                next_obs, rew, done, _ = self.env.step(action)
+                self.data_buffer.add(obs, action, next_obs, rew, done, self.global_episodes)
+                obs = next_obs
+                epi_reward += rew
+            obs = self.env.reset()
+            done = False
+            all_rews.append(epi_reward)
+        return all_rews
+
+    def collect_sequences(self, episodes, actor_model):
+        obs = self.env.reset()
+        done = False
+        all_rews = []
+        for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
+            self.global_episodes += 1
+            epi_reward = 0
+            while not done:
+                with torch.no_grad():
+                    obs = torch.tensor(obs.copy(), dtype=torch.float32).to(device).unsqueeze(0)
+                    state = self.get_features(obs)["distribution"].rsample()
+                    action = self.actor_model(state)
+                    action = actor_model.add_exploration(action).cpu().numpy()[0]
+                obs = obs.cpu().numpy()[0]
                 next_obs, rew, done, _ = self.env.step(action)
                 self.data_buffer.add(obs, action, next_obs, rew, done, self.global_episodes)
                 obs = next_obs
                 epi_reward += rew
-            obs = self.env.reset()
             done=False
             all_rews.append(epi_reward)
         return all_rews
 
     def train(self, step, total_steps):
-        counter = 0
-        import matplotlib.pyplot as plt
-        fig, ax = plt.subplots()
-        while step < total_steps:
-
-            # collect experience
-            if step !=0:
-                encoder = self.obs_encoder
-                actor = self.actor_model
-                all_rews = self.collect_sequences(self.args.episode_collection, random=False, actor_model=actor, encoder_model=encoder)
-            else:
-                all_rews = self.collect_sequences(self.args.episodes_buffer, random=True)
+        episodic_rews = self.collect_random_sequences(self.args.episodes_buffer)
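+
+        # environment steps collected so far; the training loop below runs until
+        # this counter reaches total_steps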
+        global_step = self.data_buffer.steps
+
+        # logger
+        logs = OrderedDict()
-            # collect sequences
-            non_zero_indices = np.nonzero(self.data_buffer.episode_count)[0]
-            current_obs = self.data_buffer.observations[non_zero_indices]
-            next_obs = self.data_buffer.next_observations[non_zero_indices]
-            actions_raw = self.data_buffer.actions[non_zero_indices]
-            rewards = self.data_buffer.rewards[non_zero_indices]
-            self.terms = np.where(self.data_buffer.terminals[non_zero_indices]!=0)[0]
-
-            # Group by episodes
-            current_obs = self.grouped_arrays(current_obs)
-            next_obs = self.grouped_arrays(next_obs)
-            actions_raw = self.grouped_arrays(actions_raw)
-            rewards_ = self.grouped_arrays(rewards)
+        while global_step < total_steps:
+            step += 1
+            for update_steps in range(self.args.update_steps):
+                model_loss, actor_loss, value_loss = self.update((step-1)*self.args.update_steps + update_steps)
+            episodic_rews = self.collect_sequences(self.args.episode_collection, actor_model=self.actor_model)
-            # Train encoder
-            if step == 0:
-                step += 1
-
-            update_steps = 1 if step > 1 else 1
-            #for _ in range(self.args.collection_interval // self.args.episode_length+1):
-            for _ in range(update_steps):
-                counter += 1
+            logs.update({
+                'model_loss': model_loss,
+                'actor_loss': actor_loss,
+                'value_loss': value_loss,
+                'train_avg_reward': np.mean(episodic_rews),
+                'train_max_reward': np.max(episodic_rews),
+                'train_min_reward': np.min(episodic_rews),
+                'train_std_reward': np.std(episodic_rews),
+            })
-                # Select random chunks of episodes
-                if current_obs.shape[0] < self.args.batch_size:
-                    random_episode_number = np.random.randint(0, current_obs.shape[0], self.args.batch_size)
-                else:
-                    random_episode_number = random.sample(range(current_obs.shape[0]), self.args.batch_size)
-
-                if current_obs[0].shape[0]-self.args.episode_length < self.args.batch_size:
-                    init_index = np.random.randint(0, current_obs[0].shape[0]-self.args.episode_length-2, self.args.batch_size)
-                else:
-                    init_index = np.asarray(random.sample(range(current_obs[0].shape[0]-self.args.episode_length), self.args.batch_size))
-                random.shuffle(random_episode_number)
-                random.shuffle(init_index)
-                last_observations = self.select_first_k(current_obs, init_index, random_episode_number)[:-1]
-                current_observations = self.select_first_k(current_obs, init_index, random_episode_number)[1:]
-                next_observations = self.select_first_k(next_obs, init_index, random_episode_number)[:-1]
-                actions = self.select_first_k(actions_raw, init_index, random_episode_number)[:-1].to(device)
-                next_actions = self.select_first_k(actions_raw, init_index, random_episode_number)[1:].to(device)
-                rewards = self.select_first_k(rewards_, init_index, random_episode_number)[1:].to(device)
-
-                # Preprocessing
-                last_observations = preprocess_obs(last_observations).to(device)
-                current_observations = preprocess_obs(current_observations).to(device)
-                next_observations = preprocess_obs(next_observations).to(device)
+            print("########## Global Step: ", global_step, " ##########")
+            for key, value in logs.items():
+                print(key, " : ", value)
-                # Initialize transition model states
-                self.transition_model.init_states(self.args.batch_size, device)  # (N,128)
-                self.history = self.transition_model.prev_history  # (N,128)
-
-                past_world_model_loss = 0
-                past_action_loss = 0
-                past_value_loss = 0
-                for i in range(self.args.episode_length-1):
-                    if i > 0:
-                        # Encode observations and next_observations
-                        self.last_states_dict = self.get_features(last_observations[i])
-                        self.current_states_dict = self.get_features(current_observations[i])
-                        self.next_states_dict = self.get_features(next_observations[i], momentum=True)
-                        self.action = actions[i] # (N,6)
-                        self.next_action = next_actions[i] # (N,6)
-                        history = self.transition_model.prev_history
-
-                        # Encode negative observations fro upper bound loss
-                        idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch
-                        random_time_index = torch.randint(0, current_observations.shape[0]-2, (1,)).item() # random time index
-                        negative_current_observations = current_observations[random_time_index][idx]
-                        self.negative_current_states_dict = self.get_features(negative_current_observations)
-
-                        # Predict current state from past state with transition model
-                        last_states_sample = self.last_states_dict["sample"]
-                        predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
-                        self.history = predicted_current_state_dict["history"]
-
-                        # Calculate upper bound loss
-                        likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict["sample"].detach(),
-                                                                              self.current_states_dict["sample"].detach(),
-                                                                              self.negative_current_states_dict["sample"].detach(),
-                                                                              predicted_current_state_dict["sample"].detach(),
-                                                                              )
-
-                        # Calculate encoder loss
-                        encoder_loss = self._past_encoder_loss(self.current_states_dict, predicted_current_state_dict)
-
-                        # decoder loss
-                        horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
-                        nxt_obs = next_observations[i:i+horizon].reshape(-1,9,84,84)
-                        next_states_encodings = self.get_features(nxt_obs)["sample"].view(horizon, self.args.batch_size, -1)
-                        obs_dist = self.obs_decoder(next_states_encodings)
-                        decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon]))
-
-                        # contrastive projection
-                        vec_anchor = predicted_current_state_dict["sample"].detach()
-                        vec_positive = self.next_states_dict["sample"].detach()
-                        z_anchor = self.prjoection_head(vec_anchor, self.action)
-                        z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach()
-
-                        # contrastive loss
-                        logits = self.contrastive_head(z_anchor, z_positive)
-                        labels = torch.arange(logits.shape[0]).long().to(device)
-                        lb_loss = F.cross_entropy(logits, labels)
-
-                        # reward loss
-                        reward_dist = self.reward_model(self.current_states_dict["sample"])
-                        reward_loss = -torch.mean(reward_dist.log_prob(rewards[i]))
-
-                        # world model loss
-                        world_model_loss = (10*encoder_loss + 10*ub_loss + 1e-1*lb_loss + reward_loss + 1e-3*decoder_loss + past_world_model_loss) * 1e-3
-                        past_world_model_loss = world_model_loss.item()
-
-                        # actor loss
-                        with FreezeParameters(self.world_model_modules):
-                            imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
-                            action = self.actor_model(self.current_states_dict["sample"])
-                            imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"],
-                                                                                     action, self.history,
-                                                                                     imagine_horizon)
-
-                        with FreezeParameters(self.world_model_modules + self.value_modules):
-                            imag_rewards = self.reward_model(imagined_rollout["sample"]).mean
-                            imag_values = self.target_value_model(imagined_rollout["sample"]).mean
-
-                        discounts = self.args.discount * torch.ones_like(imag_rewards).detach()
-
-                        self.returns = self._compute_lambda_return(imag_rewards[:-1],
                                                                   imag_values[:-1],
-                                                                   discounts[:-1],
-                                                                   self.args.td_lambda,
-                                                                   imag_values[-1])
-
-                        discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
-                        self.discounts = torch.cumprod(discounts, 0).detach()
-                        actor_loss = -torch.mean(self.discounts * self.returns) + past_action_loss
-                        past_action_loss = actor_loss.item()
-
-                        # value loss
-                        with torch.no_grad():
-                            value_feat = imagined_rollout["sample"][:-1].detach()
-                            value_targ = self.returns.detach()
-
-                        value_dist = self.value_model(value_feat)
-
-                        value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1)) + past_value_loss
-                        past_value_loss = value_loss.item()
-
-                        # update target value
-                        if step % self.args.value_target_update_freq == 0:
-                            self.target_value_model = copy.deepcopy(self.value_model)
-
-                        # counter for reward
-                        #count = np.arange((counter-1) * (self.args.batch_size), (counter) * (self.args.batch_size))
-                        count = (counter-1) * (self.args.batch_size)
-                        if step % self.args.logging_freq:
-                            writer.add_scalar('World Loss/World Loss', world_model_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, self.data_buffer.steps)
-                            writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), self.data_buffer.steps)
-                        step += 1
-
-                        print(world_model_loss, actor_loss, value_loss)
-
-                        # update actor model
-                        self.actor_opt.zero_grad()
-                        actor_loss.backward()
-                        nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
-                        self.actor_opt.step()
-
-                        # update world model
-                        self.world_model_opt.zero_grad()
-                        world_model_loss.backward()
-                        nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
-                        self.world_model_opt.step()
-
-                        # update value model
-                        self.value_opt.zero_grad()
-                        value_loss.backward()
-                        nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
-                        self.value_opt.step()
-
-                        # update momentum encoder and projection head
-                        soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
-                        soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
-
-                rew_len = np.arange(count, count+self.args.episode_collection) if count != 0 else np.arange(0, self.args.batch_size)
-                for j in range(len(all_rews)):
-                    writer.add_scalar('Rewards/Rewards', all_rews[j], rew_len[j])
-
-            print(step)
-            if step % 2850 == 0 and self.data_buffer.steps!=0: #self.args.evaluation_interval == 0:
+            print(global_step)
+            if global_step % 3150 == 0 and self.data_buffer.steps!=0: #self.args.evaluation_interval == 0:
                 print("Saving model")
                 path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
                 self.save_models(path)
                 self.evaluate()
-
+
+            global_step = self.data_buffer.steps
+
+            # collect experience
+            if step != 0:
+                actor = self.actor_model
+                all_rews = self.collect_sequences(self.args.episode_collection, actor_model=actor)
+
+    def update(self, step):
+        last_observations, current_observations, next_observations, actions, next_actions, rewards = self.select_one_batch()
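+
+        # one gradient step each on the world model, the actor, and the value model,
+        # in that order, followed by the momentum-encoder and target-network updates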
+        world_loss, enc_loss, rew_loss, dec_loss, ub_loss, lb_loss = self.world_model_losses(last_observations,
+                                                                                             current_observations,
+                                                                                             next_observations,
+                                                                                             actions,
+                                                                                             next_actions,
+                                                                                             rewards)
+        self.world_model_opt.zero_grad()
+        world_loss.backward()
+        nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
+        self.world_model_opt.step()
+
+        actor_loss = self.actor_model_losses()
+        self.actor_opt.zero_grad()
+        actor_loss.backward()
+        nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
+        self.actor_opt.step()
+
+        value_loss = self.value_model_losses()
+        self.value_opt.zero_grad()
+        value_loss.backward()
+        nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
+        self.value_opt.step()
+
+        # update momentum encoder and projection head
+        soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
+        soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
+
+        # update target value networks
+        if step % self.args.value_target_update_freq == 0:
+            self.target_value_model = copy.deepcopy(self.value_model)
+
+        if step % self.args.logging_freq == 0:
+            writer.add_scalar('World Loss/World Loss', world_loss.detach().item(), step)
+            writer.add_scalar('Main Models Loss/Encoder Loss', enc_loss.detach().item(), step)
+            writer.add_scalar('Main Models Loss/Decoder Loss', dec_loss, step)
+            writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), step)
+            writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step)
+            writer.add_scalar('Actor Critic Loss/Reward Loss', rew_loss.detach().item(), step)
+            writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step)
+            writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), step)
+
+        return world_loss.item(), actor_loss.item(), value_loss.item()
+
+    def world_model_losses(self, last_obs, curr_obs, nxt_obs, actions, nxt_actions, rewards):
+        self.last_state_feat = self.get_features(last_obs)
+        self.curr_state_feat = self.get_features(curr_obs)
+        self.nxt_state_feat = self.get_features(nxt_obs)
+
+        # states
+        self.last_state_enc = self.last_state_feat["sample"]
+        self.curr_state_enc = self.curr_state_feat["sample"]
+        self.nxt_state_enc = self.nxt_state_feat["sample"]
+
+        # predict next states
+        self.transition_model.init_states(self.args.batch_size, device)  # (N,128)
+        self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history)
+        self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"])
+        self.pred_curr_state_enc = self.pred_curr_state_dist.mean
+
+        # encoder loss
+        enc_loss = torch.nn.MSELoss()(self.curr_state_enc, self.pred_curr_state_enc)
+        #self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist)
+
+        # reward loss
+        rew_dist = self.reward_model(self.curr_state_enc)
+        rew_loss = -torch.mean(rew_dist.log_prob(rewards.unsqueeze(-1)))
+
+        # decoder loss
+        dec_dist = self.obs_decoder(self.nxt_state_enc)
+        dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs))
+
+        # upper bound loss
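+        # CLUB-style sample estimate between the encoder's current state and the
+        # transition model's prediction (see _upper_bound_minimization / CLUBSample)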
+        likelihood_loss, ub_loss = self._upper_bound_minimization(self.curr_state_enc,
+                                                                   self.pred_curr_state_enc)
+
+        # lower bound loss
+        # contrastive projection
+        vec_anchor = self.pred_curr_state_enc
+        vec_positive = self.nxt_state_enc
+        z_anchor = self.prjoection_head(vec_anchor, nxt_actions)
+        z_positive = self.prjoection_head_momentum(vec_positive, nxt_actions)
+
+        # contrastive loss
+        past_lb_loss = 0
+        for i in range(z_anchor.shape[0]):
+            logits = self.contrastive_head(z_anchor[i], z_positive[i])
+            labels = torch.arange(logits.shape[0]).long().to(device)
+            lb_loss = F.cross_entropy(logits, labels) + past_lb_loss
+            past_lb_loss = lb_loss.detach().item()
+        lb_loss = lb_loss/(z_anchor.shape[0])
+
+        world_loss = enc_loss + rew_loss + dec_loss * 1e-4 + ub_loss * 10 + lb_loss
+
+        return world_loss, enc_loss, rew_loss, dec_loss * 1e-4, ub_loss * 10, lb_loss
+
+    def actor_model_losses(self):
+        with torch.no_grad():
+            curr_state_enc = self.transition_model.seq_to_batch(self.curr_state_feat, "sample")["sample"]
+            curr_state_hist = self.transition_model.seq_to_batch(self.observed_rollout, "history")["sample"]
+
+        with FreezeParameters(self.world_model_modules):
+            imagine_horizon = self.args.imagine_horizon
+            action = self.actor_model(curr_state_enc)
+            self.imagined_rollout = self.transition_model.imagine_rollout(curr_state_enc,
+                                                                          action, curr_state_hist,
+                                                                          imagine_horizon)
+
+        with FreezeParameters(self.world_model_modules + self.value_modules):
+            imag_rewards = self.reward_model(self.imagined_rollout["sample"]).mean
+            imag_values = self.target_value_model(self.imagined_rollout["sample"]).mean
+            discounts = self.args.discount * torch.ones_like(imag_rewards).detach()
+            self.returns = self._compute_lambda_return(imag_rewards[:-1],
+                                                       imag_values[:-1],
+                                                       discounts[:-1],
+                                                       self.args.td_lambda,
+                                                       imag_values[-1])
+
+        discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
+        self.discounts = torch.cumprod(discounts, 0).detach()
+        actor_loss = -torch.mean(self.discounts * self.returns)
+        return actor_loss
+
+    def value_model_losses(self):
+        # value loss
+        with torch.no_grad():
+            value_feat = self.imagined_rollout["sample"][:-1].detach()
+            value_targ = self.returns.detach()
+        value_dist = self.value_model(value_feat)
+        value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
+        return value_loss
+
+    def select_one_batch(self):
+        # collect sequences
+        non_zero_indices = np.nonzero(self.data_buffer.episode_count)[0]
+        current_obs = self.data_buffer.observations[non_zero_indices]
+        next_obs = self.data_buffer.next_observations[non_zero_indices]
+        actions_raw = self.data_buffer.actions[non_zero_indices]
+        rewards = self.data_buffer.rewards[non_zero_indices]
+        self.terms = np.where(self.data_buffer.terminals[non_zero_indices]!=False)[0]
+
+        # group by episodes
+        current_obs = self.grouped_arrays(current_obs)
+        next_obs = self.grouped_arrays(next_obs)
+        actions_raw = self.grouped_arrays(actions_raw)
+        rewards_ = self.grouped_arrays(rewards)
+
+        # select random chunks of episodes
+        if current_obs.shape[0] < self.args.batch_size:
+            random_episode_number = np.random.randint(0, current_obs.shape[0], self.args.batch_size)
+        else:
+            random_episode_number = random.sample(range(current_obs.shape[0]), self.args.batch_size)
+
+        # select random starting points
+        if current_obs[0].shape[0]-self.args.episode_length < self.args.batch_size:
+            init_index = np.random.randint(0, current_obs[0].shape[0]-self.args.episode_length-2, self.args.batch_size)
+        else:
+            init_index = np.asarray(random.sample(range(current_obs[0].shape[0]-self.args.episode_length), self.args.batch_size))
+
+        # shuffle
+        random.shuffle(random_episode_number)
+        random.shuffle(init_index)
+
+        # select first k elements
+        last_observations = self.select_first_k(current_obs, init_index, random_episode_number)[:-1]
+        current_observations = self.select_first_k(current_obs, init_index, random_episode_number)[1:]
+        next_observations = self.select_first_k(next_obs, init_index, random_episode_number)[:-1]
+        actions = self.select_first_k(actions_raw, init_index, random_episode_number)[:-1].to(device)
+        next_actions = self.select_first_k(actions_raw, init_index, random_episode_number)[1:].to(device)
+        rewards = self.select_first_k(rewards_, init_index, random_episode_number)[:-1].to(device)
+
+        # preprocessing
+        last_observations = preprocess_obs(last_observations).to(device)
+        current_observations = preprocess_obs(current_observations).to(device)
+        next_observations = preprocess_obs(next_observations).to(device)
+
+        return last_observations, current_observations, next_observations, actions, next_actions, rewards
+
     def evaluate(self, eval_episodes=10):
         path = path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
@@ -511,7 +551,7 @@ class DPI:
             with torch.no_grad():
                 obs = torch.tensor(obs.copy(), dtype=torch.float32).unsqueeze(0)
                 obs_processed = preprocess_obs(obs).to(device)
-                state = self.obs_encoder(obs_processed)["distribution"].sample()
+                state = self.get_features(obs_processed)["distribution"].rsample()
                 action = self.actor_model(state).cpu().detach().numpy().squeeze()
 
             next_obs, rew, done, _ = self.env.step(action)
@@ -534,11 +574,9 @@ class DPI:
     def grouped_arrays(self,array):
         indices = [0] + self.terms.tolist()
-
         def subarrays():
             for start, end in zip(indices[:-1], indices[1:]):
                 yield array[start:end]
-
         try:
             subarrays = np.stack(list(subarrays()), axis=0)
         except ValueError:
@@ -548,13 +586,13 @@ class DPI:
     def select_first_k(self, array, init_index, episode_number):
         term_index = init_index + self.args.episode_length
+        array = array[episode_number]
         array_list = []
         for i in range(array.shape[0]):
             array_list.append(array[i][init_index[i]:term_index[i]])
         array = np.asarray(array_list)
-
         if array.ndim == 5:
             transposed_array = np.transpose(array, (1, 0, 2, 3, 4))
         elif array.ndim == 4:
@@ -565,20 +603,16 @@ class DPI:
             transposed_array = np.transpose(array, (1, 0))
         else:
             transposed_array = np.expand_dims(array, axis=0)
+
+        #return torch.tensor(array).float()
         return torch.tensor(transposed_array).float()
 
-    def _upper_bound_minimization(self, last_states, current_states, negative_current_states, predicted_current_states):
-        club_loss = self.club_sample(current_states, predicted_current_states, negative_current_states)
+    def _upper_bound_minimization(self, current_states, predicted_current_states):
+        club_loss = self.club_sample(current_states, predicted_current_states, current_states)
         likelihood_loss = 0
         return likelihood_loss, club_loss
 
-    def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict):
-        # current state distribution
-        curr_states_dist = curr_states_dict["distribution"]
-
-        # predicted current state distribution
-        predicted_curr_states_dist = predicted_curr_states_dict["distribution"]
-
+    def _encoder_loss(self, curr_states_dist, predicted_curr_states_dist):
         # KL divergence loss
         loss = torch.mean(torch.distributions.kl.kl_divergence(curr_states_dist,predicted_curr_states_dist))