diff --git a/DPI/train.py b/DPI/train.py index 622c0ad..8e2a867 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -7,6 +7,7 @@ import time import json import dmc2gym +import tqdm import wandb import utils from utils import ReplayBuffer, make_env, save_image @@ -33,10 +34,10 @@ def parse_args(): parser.add_argument('--resource_files', type=str) parser.add_argument('--eval_resource_files', type=str) parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none']) - parser.add_argument('--total_frames', default=10000, type=int) + parser.add_argument('--total_frames', default=1000, type=int) # 10000 parser.add_argument('--high_noise', action='store_true') # replay buffer - parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000 + parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #50000 parser.add_argument('--episode_length', default=50, type=int) # train parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc']) @@ -130,10 +131,6 @@ class DPI: self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model')) self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer')) - # create video recorder - #video = VideoRecorder(video_dir if args.save_video else None, resource_files=args.resource_files) - #video.init(enabled=True) - # create models self.build_models(use_saved=False, saved_model_dir=self.model_dir) @@ -174,28 +171,24 @@ class DPI: done = False #video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files) - for episode_count in range(episodes): - self.env.video.init(enabled=True) + for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'): + #self.env.video.init(enabled=True) for i in range(self.args.episode_length): action = self.env.action_space.sample() next_obs, _, done, _ = self.env.step(action) self.data_buffer.add(obs, action, next_obs, episode_count+1, done) - if args.save_video: - self.env.video.record(self.env) + #if args.save_video: + # self.env.video.record(self.env) if done: obs = self.env.reset() done=False else: obs = next_obs - self.env.video.save('%d.mp4' % episode_count) + #self.env.video.save('%d.mp4' % episode_count) print("Collected {} random episodes".format(episode_count+1)) - #if args.save_video: - # video.record(self.env) - #video.save('%d.mp4' % step) - #video.close() def train(self): # collect experience @@ -204,26 +197,59 @@ class DPI: # Group observations and next_observations by steps observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float() next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float() - + actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float() + + # Initialize transition model states + self.transition_model.init_states(self.args.batch_size, device="cpu") # (N,128) + self.history = self.transition_model.prev_history # (N,128) + # Train encoder previous_information_loss = 0 + previous_encoder_loss = 0 for i in range(self.args.episode_length): # Encode observations and next_observations - self.features = self.obs_encoder(observations[i]) # (N,128) - self.next_features = self.obs_encoder(next_observations[i]) # (N,128) + self.states_dist = self.obs_encoder(observations[i]) + self.next_states_dist = self.obs_encoder(next_observations[i]) + + # Sample states and next_states + self.states = self.states_dist["sample"] # (N,128) + self.next_states = self.next_states_dist["sample"] # (N,128) + self.actions = actions[i] # (N,6) # Calculate upper bound loss - past_loss = previous_information_loss + self.upper_bound_minimization(self.features, self.next_features) - previous_information_loss = past_loss - print("past_loss: ", past_loss) - - def upper_bound_minimization(self, features, next_features): - club_sample = CLUBSample(self.args.state_size, + past_latent_loss = previous_information_loss + self._upper_bound_minimization(self.states, self.next_states) + + # Calculate encoder loss + past_encoder_loss = previous_encoder_loss + self._past_encoder_loss(self.states, self.next_states, + self.states_dist, self.next_states_dist, + self.actions, self.history, i) + + previous_information_loss = past_latent_loss + previous_encoder_loss = past_encoder_loss + + def _upper_bound_minimization(self, states, next_states): + club_sample = CLUBSample(self.args.state_size, self.args.state_size, self.args.hidden_size) - club_loss = club_sample(features, next_features) - return club_loss + club_loss = club_sample(states, next_states) + return club_loss + def _past_encoder_loss(self, states, next_states, states_dist, next_states_dist, actions, history, step): + # Imagine next state + if step == 0: + actions = torch.zeros(self.args.batch_size, self.env.action_space.shape[0]).float() # Zero action for first step + imagined_next_states = self.transition_model.imagine_step(states, actions, history) + self.history = imagined_next_states["history"] + else: + imagined_next_states = self.transition_model.imagine_step(states, actions, self.history) # (N,128) + + # State Distribution + imagined_next_states_dist = imagined_next_states["distribution"] + + # KL divergence loss + loss = torch.distributions.kl.kl_divergence(imagined_next_states_dist, next_states_dist["distribution"]).mean() + + return loss if __name__ == '__main__': args = parse_args()