diff --git a/DPI/models.py b/DPI/models.py index a047331..e3ae29f 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -128,25 +128,27 @@ class TransitionModel(nn.Module): class CLUBSample(nn.Module): # Sampled version of the CLUB estimator def __init__(self, x_dim, y_dim, hidden_size): super(CLUBSample, self).__init__() - self.p_mu = nn.Sequential(nn.Linear(x_dim, hidden_size//2), - nn.ReLU(), - nn.Linear(hidden_size//2, y_dim)) + self.p_mu = nn.Sequential( + nn.Linear(x_dim, hidden_size//2), + nn.ReLU(), + nn.Linear(hidden_size//2, y_dim) + ) - self.p_logvar = nn.Sequential(nn.Linear(x_dim, hidden_size//2), - nn.ReLU(), - nn.Linear(hidden_size//2, y_dim), - nn.Tanh()) + self.p_logvar = nn.Sequential( + nn.Linear(x_dim, hidden_size//2), + nn.ReLU(), + nn.Linear(hidden_size//2, y_dim), + nn.Tanh() + ) def get_mu_logvar(self, x_samples): mu = self.p_mu(x_samples) logvar = self.p_logvar(x_samples) return mu, logvar - - + def loglikeli(self, x_samples, y_samples): mu, logvar = self.get_mu_logvar(x_samples) return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0) - def forward(self, x_samples, y_samples): mu, logvar = self.get_mu_logvar(x_samples) @@ -165,8 +167,9 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator if __name__ == "__main__": encoder = ObservationEncoder((12,84,84), 256) - x = torch.randn(100, 12, 84, 84) + x = torch.randn(5000, 12, 84, 84) print(encoder(x).shape) + exit() club = CLUBSample(256, 256 , 512) x = torch.randn(100, 256) diff --git a/DPI/train.py b/DPI/train.py index 7597171..693d372 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -9,7 +9,7 @@ import dmc2gym import wandb import utils -from utils import ReplayBuffer, make_env +from utils import ReplayBuffer, make_env, save_image from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample from logger import Logger from video import VideoRecorder @@ -34,18 +34,18 @@ def parse_args(): parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none']) parser.add_argument('--total_frames', default=1000, type=int) # replay buffer - parser.add_argument('--replay_buffer_capacity', default=100000, type=int) - parser.add_argument('--episode_length', default=1000, type=int) + parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000 + parser.add_argument('--episode_length', default=50, type=int) # train parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc']) parser.add_argument('--init_steps', default=1000, type=int) parser.add_argument('--num_train_steps', default=1000, type=int) - parser.add_argument('--batch_size', default=512, type=int) + parser.add_argument('--batch_size', default=200, type=int) #512 parser.add_argument('--state_size', default=256, type=int) parser.add_argument('--hidden_size', default=128, type=int) parser.add_argument('--history_size', default=128, type=int) - parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model') parser.add_argument('--load_encoder', default=None, type=str) + parser.add_argument('--imagination_horizon', default=15, type=str) # eval parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000 parser.add_argument('--num_eval_episodes', default=20, type=int) @@ -79,7 +79,6 @@ def parse_args(): parser.add_argument('--alpha_beta', default=0.9, type=float) # misc parser.add_argument('--seed', default=1, type=int) - parser.add_argument('--seed_steps', default=5000, type=int) parser.add_argument('--work_dir', default='.', type=str) parser.add_argument('--save_tb', default=False, action='store_true') parser.add_argument('--save_model', default=False, action='store_true') @@ -117,20 +116,21 @@ class DPI: obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), action_size=self.env.action_space.shape[0], seq_len=self.args.episode_length, - batch_size=args.batch_size) + batch_size=args.batch_size, + args=self.args) # create work directory utils.make_dir(self.args.work_dir) - video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video')) - model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model')) - buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer')) + self.video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video')) + self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model')) + self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer')) # create video recorder #video = VideoRecorder(video_dir if args.save_video else None, resource_files=args.resource_files) #video.init(enabled=True) # create models - self.build_models(use_saved=False, saved_model_dir=model_dir) + self.build_models(use_saved=False, saved_model_dir=self.model_dir) def build_models(self, use_saved, saved_model_dir=None): self.obs_encoder = ObservationEncoder( @@ -167,13 +167,14 @@ class DPI: def collect_random_episodes(self, episodes): obs = self.env.reset() done = False - + for episode_count in range(episodes): for i in range(self.args.episode_length): action = self.env.action_space.sample() next_obs, _, done, _ = self.env.step(action) self.data_buffer.add(obs, action, next_obs, episode_count+1, done) + if done: obs = self.env.reset() done=False @@ -185,12 +186,33 @@ class DPI: #video.save('%d.mp4' % step) #video.close() - def upper_bound_minimization(self): - pass + def train(self): + # collect experience + self.collect_random_episodes(self.args.batch_size) + + # Group observations and next_observations by steps + observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float() + next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float() + + # Train encoder + for i in range(self.args.episode_length): + # Encode observations and next_observations + self.features = self.obs_encoder(observations[i]) # (N,128) + self.next_features = self.obs_encoder(next_observations[i]) # (N,128) + + # Calculate upper bound loss + past_loss = self.upper_bound_minimization(self.features, self.next_features) + + def upper_bound_minimization(self, features, next_features): + club_sample = CLUBSample(self.args.state_size, + self.args.state_size, + self.args.hidden_size) + club_loss = club_sample(features, next_features) + return club_loss if __name__ == '__main__': args = parse_args() dpi = DPI(args) - dpi.collect_random_episodes(episodes=5) \ No newline at end of file + dpi.train() \ No newline at end of file diff --git a/DPI/utils.py b/DPI/utils.py index 4a6df6e..a7a5612 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -13,6 +13,7 @@ import gym import dmc2gym import random +from PIL import Image from collections import deque @@ -105,7 +106,7 @@ class FrameStack(gym.Wrapper): class ReplayBuffer: - def __init__(self, size, obs_shape, action_size, seq_len, batch_size): + def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args): self.size = size self.obs_shape = obs_shape self.action_size = action_size @@ -113,6 +114,7 @@ class ReplayBuffer: self.batch_size = batch_size self.idx = 0 self.full = False + self.args = args self.observations = np.empty((size, *obs_shape), dtype=np.uint8) self.actions = np.empty((size, action_size), dtype=np.float32) self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8) @@ -152,6 +154,22 @@ class ReplayBuffer: obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l) return obs,acs,rews,terms + def group_steps(self, buffer, variable): + variable = getattr(buffer, variable) + non_zero_indices = np.nonzero(buffer.episode_count)[0] + variable = variable[non_zero_indices] + + variable = variable.reshape(self.args.episode_length, self.args.batch_size, + self.args.frame_stack*self.args.channels, + self.args.image_size,self.args.image_size) + return variable + + def transform_grouped_steps(self, variable): + variable = variable.transpose((1, 0, 2, 3, 4)) + variable = variable.reshape(self.args.batch_size*self.args.episode_length,self.args.frame_stack*self.args.channels, + self.args.image_size,self.args.image_size) + return variable + def make_env(args): env = dmc2gym.make( @@ -167,4 +185,10 @@ def make_env(args): width=args.image_size, frame_skip=args.action_repeat ) - return env \ No newline at end of file + return env + +def save_image(array, filename): + array = array.transpose(1, 2, 0) + array = (array * 255).astype(np.uint8) + image = Image.fromarray(array) + image.save(filename) \ No newline at end of file