From 6b4762d5fc0d9e45cfcb9cda009794e3f89ffe4c Mon Sep 17 00:00:00 2001
From: VedantDave
Date: Sun, 9 Apr 2023 18:22:41 +0200
Subject: [PATCH] Changing Upper Bound loss

---
 DPI/train.py | 148 +++++++++++++++++++++++++++++++++++----------------
 1 file changed, 101 insertions(+), 47 deletions(-)

diff --git a/DPI/train.py b/DPI/train.py
index 219583c..db8c027 100644
--- a/DPI/train.py
+++ b/DPI/train.py
@@ -16,6 +16,8 @@ from logger import Logger
 from video import VideoRecorder
 from dmc2gym.wrappers import set_global_var
 
+import torchvision.transforms as T
+
 #from agent.baseline_agent import BaselineAgent
 #from agent.bisim_agent import BisimAgent
 #from agent.deepmdp_agent import DeepMDPAgent
@@ -31,7 +33,7 @@ def parse_args():
     parser.add_argument('--image_size', default=84, type=int)
     parser.add_argument('--channels', default=3, type=int)
     parser.add_argument('--action_repeat', default=1, type=int)
-    parser.add_argument('--frame_stack', default=4, type=int)
+    parser.add_argument('--frame_stack', default=3, type=int)
     parser.add_argument('--resource_files', type=str)
     parser.add_argument('--eval_resource_files', type=str)
     parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
@@ -39,18 +41,18 @@ def parse_args():
     parser.add_argument('--high_noise', action='store_true')
     # replay buffer
     parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #50000
-    parser.add_argument('--episode_length', default=50, type=int)
+    parser.add_argument('--episode_length', default=51, type=int)
     # train
     parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
-    parser.add_argument('--init_steps', default=1000, type=int)
-    parser.add_argument('--num_train_steps', default=1000, type=int)
-    parser.add_argument('--batch_size', default=200, type=int) #512
+    parser.add_argument('--init_steps', default=10000, type=int)
+    parser.add_argument('--num_train_steps', default=10000, type=int)
+    parser.add_argument('--batch_size', default=20, type=int) #512
     parser.add_argument('--state_size', default=256, type=int)
     parser.add_argument('--hidden_size', default=128, type=int)
     parser.add_argument('--history_size', default=128, type=int)
     parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models')
     parser.add_argument('--load_encoder', default=None, type=str)
-    parser.add_argument('--imagination_horizon', default=15, type=str)
+    parser.add_argument('--imagine_horizon', default=15, type=str)
     # eval
     parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
     parser.add_argument('--num_eval_episodes', default=20, type=int)
@@ -113,6 +115,7 @@ class DPI:
 
         # environment setup
         self.env = make_env(self.args)
+        #self.args.seed = np.random.randint(0, 1000)
         self.env.seed(self.args.seed)
 
         # noiseless environment setup
@@ -190,87 +193,123 @@ class DPI:
 
     def collect_sequences(self, episodes):
         obs = self.env.reset()
-        obs_clean = self.env_clean.reset()
+        #obs_clean = self.env_clean.reset()
         done = False
 
         #video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
         for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
             if args.save_video:
                 self.env.video.init(enabled=True)
-                self.env_clean.video.init(enabled=True)
+                #self.env_clean.video.init(enabled=True)
 
             for i in range(self.args.episode_length):
+
                 action = self.env.action_space.sample()
-                next_obs, _, done, _ = self.env.step(action)
-                next_obs_clean, _, done, _ = self.env_clean.step(action)
+                next_obs, rew, done, _ = self.env.step(action)
+                #next_obs_clean, _, done, _ = self.env_clean.step(action)
 
                 self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
-                self.data_buffer_clean.add(obs_clean, action, next_obs_clean, episode_count+1, done)
+                #self.data_buffer_clean.add(obs_clean, action, next_obs_clean, episode_count+1, done)
 
                 if args.save_video:
-                    self.env.video.record(self.env_clean)
-                    self.env_clean.video.record(self.env_clean)
+                    self.env.video.record(self.env)
+                    #self.env_clean.video.record(self.env_clean)
 
-                if done:
+                if done or i == self.args.episode_length-1:
                     obs = self.env.reset()
-                    obs_clean = self.env_clean.reset()
-                    done=False
+                    #obs_clean = self.env_clean.reset()
+                    done=False
                 else:
                     obs = next_obs
-                    obs_clean = next_obs_clean
+                    #obs_clean = next_obs_clean
 
             if args.save_video:
                 self.env.video.save('noisy/%d.mp4' % episode_count)
-                self.env_clean.video.save('clean/%d.mp4' % episode_count)
+                #self.env_clean.video.save('clean/%d.mp4' % episode_count)
 
         print("Collected {} random episodes".format(episode_count+1))
 
     def train(self):
         # collect experience
         self.collect_sequences(self.args.batch_size)
 
-        # Group observations and next_observations by steps
-        observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()
-        next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()
-        actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()
+        # Group observations and next_observations by steps from past to present
+        last_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()[:self.args.episode_length-1]
+        current_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()[:self.args.episode_length-1]
+        next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()[1:]
+        actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[:self.args.episode_length-1]
+        next_actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[1:]
 
         # Initialize transition model states
         self.transition_model.init_states(self.args.batch_size, device="cpu")  # (N,128)
         self.history = self.transition_model.prev_history  # (N,128)
 
         # Train encoder
-        previous_information_loss = 0
-        previous_encoder_loss = 0
-        for i in range(self.args.episode_length):
-            # Encode observations and next_observations
-            self.states_dist = self.obs_encoder(observations[i])
-            self.next_states_dist = self.obs_encoder(next_observations[i])
+        total_ub_loss = 0
+        total_encoder_loss = 0
+        for i in range(self.args.episode_length-1):
+            if i > 0:
+                # Encode observations and next_observations
+                self.last_states_dict = self.obs_encoder(last_observations[i])
+                self.current_states_dict = self.obs_encoder(current_observations[i])
+                self.next_states_dict = self.obs_encoder_momentum(next_observations[i])
+                self.action = actions[i]  # (N,6)
+                history = self.transition_model.prev_history
+
+                # Encode negative observations
+                idx = torch.randperm(current_observations[i].shape[0])  # random permutation on batch
+                random_time_index = torch.randint(0, self.args.episode_length-2, (1,)).item()  # random time index
+                negative_current_observations = current_observations[random_time_index][idx]
+                self.negative_current_states_dict = self.obs_encoder(negative_current_observations)
 
-            # Sample states and next_states
-            self.states = self.states_dist["sample"]  # (N,128)
-            self.next_states = self.next_states_dist["sample"]  # (N,128)
-            self.actions = actions[i]  # (N,6)
+                # Predict current state from past state with transition model
+                last_states_sample = self.last_states_dict["sample"]
+                predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
+                self.history = predicted_current_state_dict["history"]
+
+                # Calculate upper bound loss
+                ub_loss = self._upper_bound_minimization(self.last_states_dict,
+                                                         self.current_states_dict,
+                                                         self.negative_current_states_dict,
+                                                         predicted_current_state_dict
+                                                        )
+
+                # Calculate encoder loss
+                encoder_loss = self._past_encoder_loss(self.current_states_dict,
+                                                       predicted_current_state_dict)
 
-            # Calculate upper bound loss
-            past_latent_loss = previous_information_loss + self._upper_bound_minimization(self.states, self.next_states)
+                total_ub_loss += ub_loss
+                total_encoder_loss += encoder_loss
 
-            # Calculate encoder loss
-            past_encoder_loss = previous_encoder_loss + self._past_encoder_loss(self.states, self.next_states,
-                                                                                self.states_dist, self.next_states_dist,
-                                                                                self.actions, self.history, i)
-
+                imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
+                imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"], self.action, self.history, imagine_horizon)
+                print(imagine_horizon)
+                #exit()
 
-            print(past_encoder_loss, past_latent_loss)
+            #print(total_ub_loss, total_encoder_loss)
 
-            previous_information_loss = past_latent_loss
-            previous_encoder_loss = past_encoder_loss
 
-    def _upper_bound_minimization(self, states, next_states):
-        club_sample = CLUBSample(self.args.state_size,
-                                 self.args.state_size,
-                                 self.args.hidden_size)
-        club_loss = club_sample(states, next_states)
+
+    def _upper_bound_minimization(self, last_states, current_states, negative_current_states, predicted_current_states):
+        club_sample = CLUBSample(last_states,
+                                 current_states,
+                                 negative_current_states,
+                                 predicted_current_states)
+        club_loss = club_sample()
         return club_loss
+
+    def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict):
+        # current state distribution
+        curr_states_dist = curr_states_dict["distribution"]
+        # predicted current state distribution
+        predicted_curr_states_dist = predicted_curr_states_dict["distribution"]
+
+        # KL divergence loss
+        loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean()
+
+        return loss
 
+    """
     def _past_encoder_loss(self, states, next_states, states_dist, next_states_dist, actions, history, step):
         # Imagine next state
         if step == 0:
@@ -287,6 +326,21 @@ class DPI:
         loss = torch.distributions.kl.kl_divergence(imagined_next_states_dist, next_states_dist["distribution"]).mean()
 
         return loss
+    """
+
+    def get_features(self, x, momentum=False):
+        if self.aug:
+            x = T.RandomCrop((80, 80))(x)                       # (None,80,80,4)
+            x = T.functional.pad(x, (4, 4, 4, 4), "symmetric")  # (None,88,88,4)
+            x = T.RandomCrop((84, 84))(x)                       # (None,84,84,4)
+
+        with torch.no_grad():
+            x = (x.float() - self.ob_mean) / self.ob_std
+            if momentum:
+                x = self.obs_encoder(x).detach()
+            else:
+                x = self.obs_encoder_momentum(x)
+        return x
 
 if __name__ == '__main__':
     args = parse_args()
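
Note on the reworked upper-bound loss: _upper_bound_minimization now hands the encoded state dictionaries straight to CLUBSample and calls the object with no arguments, but CLUBSample itself is defined outside this file and is not shown in the diff. The sketch below is only an illustration of a sample-based CLUB estimator with that calling convention; the "sample" and "distribution" keys mirror the dictionaries built in train.py, while the class name, shapes, and internals are assumptions, not the repository's actual implementation.

    import torch

    class CLUBSampleSketch:
        """Sample-based CLUB upper bound (Cheng et al., 2020):
        I(s_prev; s_curr) <= E[log q(s_curr | s_prev)] - E[log q(s_neg | s_prev)]."""

        def __init__(self, last_states, current_states, negative_current_states, predicted_current_states):
            # last_states is accepted to mirror the call site; its information is already
            # folded into the predicted distribution via transition_model.imagine_step.
            self.positive = current_states["sample"]                 # matched current states, (N, state_size)
            self.negative = negative_current_states["sample"]        # shuffled batch at a random time step
            self.q_dist = predicted_current_states["distribution"]   # q(s_curr | s_prev, action, history)

        def __call__(self):
            positive_log_prob = self.q_dist.log_prob(self.positive).mean()
            negative_log_prob = self.q_dist.log_prob(self.negative).mean()
            return positive_log_prob - negative_log_prob             # scalar upper-bound estimate

Under this reading, the per-step value returned by club_sample() is what the training loop accumulates into total_ub_loss.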