diff --git a/DPI/models.py b/DPI/models.py index e3ae29f..012dc4c 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -131,12 +131,16 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator self.p_mu = nn.Sequential( nn.Linear(x_dim, hidden_size//2), nn.ReLU(), + nn.Linear(hidden_size//2, hidden_size//2), + nn.ReLU(), nn.Linear(hidden_size//2, y_dim) ) self.p_logvar = nn.Sequential( nn.Linear(x_dim, hidden_size//2), nn.ReLU(), + nn.Linear(hidden_size//2, hidden_size//2), + nn.ReLU(), nn.Linear(hidden_size//2, y_dim), nn.Tanh() ) @@ -182,5 +186,4 @@ if __name__ == "__main__": y_enc = encoder(y) print(x_enc.shape) print(y_enc.shape) - print(club.learning_loss(x_enc, y_enc)) - + print(club.learning_loss(x_enc, y_enc)) \ No newline at end of file diff --git a/DPI/train.py b/DPI/train.py index 693d372..1b9c16f 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -169,20 +169,26 @@ class DPI: done = False for episode_count in range(episodes): + video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files) + video.init(enabled=True) for i in range(self.args.episode_length): action = self.env.action_space.sample() next_obs, _, done, _ = self.env.step(action) self.data_buffer.add(obs, action, next_obs, episode_count+1, done) - + + if args.save_video: + video.record(self.env) + if done: obs = self.env.reset() done=False else: obs = next_obs + video.save('%d.mp4' % episode_count) print("Collected {} random episodes".format(episode_count+1)) #if args.save_video: - # video.record(env) + # video.record(self.env) #video.save('%d.mp4' % step) #video.close() @@ -195,13 +201,16 @@ class DPI: next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float() # Train encoder + previous_information_loss = 0 for i in range(self.args.episode_length): # Encode observations and next_observations self.features = self.obs_encoder(observations[i]) # (N,128) self.next_features = self.obs_encoder(next_observations[i]) # (N,128) # Calculate upper bound loss - past_loss = self.upper_bound_minimization(self.features, self.next_features) + past_loss = previous_information_loss + self.upper_bound_minimization(self.features, self.next_features) + previous_information_loss = past_loss + print("past_loss: ", past_loss) def upper_bound_minimization(self, features, next_features): club_sample = CLUBSample(self.args.state_size, diff --git a/DPI/utils.py b/DPI/utils.py index e9ec827..de28442 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -1,3 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + import os import torch import numpy as np @@ -166,6 +172,9 @@ class ReplayBuffer: def make_env(args): + # For making ground plane transparent, change rgba to (0, 0, 0, 0) in local_dm_control_suite/{domain_name}.xml, + # else change to (0.5, 0.5, 0.5, 1.0) for default ground plane color + # https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-geom env = dmc2gym.make( domain_name=args.domain_name, task_name=args.task_name, diff --git a/DPI/video.py b/DPI/video.py index 964f3e1..c372207 100644 --- a/DPI/video.py +++ b/DPI/video.py @@ -1,3 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + import imageio import os import numpy as np @@ -51,4 +57,4 @@ class VideoRecorder(object): def save(self, file_name): if self.enabled: path = os.path.join(self.dir_name, file_name) - imageio.mimsave(path, self.frames, fps=self.fps) + imageio.mimsave(path, self.frames, fps=self.fps) \ No newline at end of file