From 02a66cfb3317a4ec8be1480094dfbdf5cdd5fa88 Mon Sep 17 00:00:00 2001 From: VedantDave Date: Sat, 22 Apr 2023 13:07:22 +0200 Subject: [PATCH] Adding after some changes --- DPI/train.py | 306 +++++++++++++++++++++++++++++++-------------------- DPI/utils.py | 155 ++++++++++++++++++++++++-- 2 files changed, 333 insertions(+), 128 deletions(-) diff --git a/DPI/train.py b/DPI/train.py index 4138b9c..ed09de3 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -9,7 +9,7 @@ import numpy as np from collections import OrderedDict import utils -from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image +from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image, shuffle_along_axis, Logger from replay_buffer import ReplayBuffer from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample from video import VideoRecorder @@ -50,11 +50,11 @@ def parse_args(): parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc']) parser.add_argument('--init_steps', default=10000, type=int) parser.add_argument('--num_train_steps', default=100000, type=int) - parser.add_argument('--update_steps', default=1, type=int) + parser.add_argument('--update_steps', default=100, type=int) parser.add_argument('--batch_size', default=64, type=int) #512 parser.add_argument('--state_size', default=50, type=int) parser.add_argument('--hidden_size', default=512, type=int) - parser.add_argument('--history_size', default=128, type=int) + parser.add_argument('--history_size', default=256, type=int) parser.add_argument('--episode_collection', default=5, type=int) parser.add_argument('--episodes_buffer', default=5, type=int, help='Initial number of episodes to store in the buffer') parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models') @@ -66,11 +66,11 @@ def parse_args(): parser.add_argument('--num_eval_episodes', default=20, type=int) parser.add_argument('--evaluation_interval', default=10000, type=int) # TODO: master had 10000 # value - parser.add_argument('--value_lr', default=8e-6, type=float) + parser.add_argument('--value_lr', default=1e-6, type=float) parser.add_argument('--value_target_update_freq', default=100, type=int) parser.add_argument('--td_lambda', default=0.95, type=int) # actor - parser.add_argument('--actor_lr', default=8e-6, type=float) + parser.add_argument('--actor_lr', default=1e-6, type=float) parser.add_argument('--actor_beta', default=0.9, type=float) parser.add_argument('--actor_log_std_min', default=-10, type=float) parser.add_argument('--actor_log_std_max', default=2, type=float) @@ -78,13 +78,15 @@ def parse_args(): # world/encoder/decoder parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity']) parser.add_argument('--world_model_lr', default=1e-5, type=float) - parser.add_argument('--encoder_tau', default=0.001 , type=float) + parser.add_argument('--decoder_lr', default=1e-5, type=float) + parser.add_argument('--reward_lr', default=1e-5, type=float) + parser.add_argument('--encoder_tau', default=0.001, type=float) parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction']) parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_filters', default=32, type=int) parser.add_argument('--aug', action='store_true') # sac - parser.add_argument('--discount', default=0.95, type=float) + parser.add_argument('--discount', default=0.99, type=float) # misc parser.add_argument('--seed', default=1, type=int) parser.add_argument('--logging_freq', default=100, type=int) @@ -131,15 +133,14 @@ class DPI: self.env = utils.FrameStack(self.env, k=self.args.frame_stack) self.env = utils.ActionRepeat(self.env, self.args.action_repeat) self.env = utils.NormalizeActions(self.env) - self.env = utils.TimeLimit(self.env, 1000 / args.action_repeat) + self.env = utils.TimeLimit(self.env, 1000 // args.action_repeat) # create replay buffer - self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity, - obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), - action_size=self.env.action_space.shape[0], - seq_len=self.args.episode_length, - batch_size=args.batch_size, - args=self.args) + self.data_buffer = ReplayBuffer(self.args.replay_buffer_capacity, + self.env.observation_space.shape, + self.env.action_space.shape[0], + self.args.episode_length, + self.args.batch_size) # create work directory utils.make_dir(self.args.work_dir) @@ -180,7 +181,7 @@ class DPI: hidden_size=self.args.hidden_size, # 256, action_size=self.env.action_space.shape[0], # 6 ).to(device) - self.actor_model.apply(self.init_weights) + #self.actor_model.apply(self.init_weights) # Value Models @@ -225,22 +226,23 @@ class DPI: # model parameters self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.prjoection_head.parameters()) + \ - list(self.transition_model.parameters()) + list(self.obs_decoder.parameters()) + \ - list(self.reward_model.parameters()) + list(self.club_sample.parameters()) + list(self.transition_model.parameters()) + list(self.club_sample.parameters()) + \ + list(self.contrastive_head.parameters()) self.past_transition_parameters = self.transition_model.parameters() # optimizers - self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr) - self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr) - self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr) - #self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), 1e-5) - #self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), 1e-4) + self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr,eps=1e-6) + self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6) + self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr,eps=1e-6) + self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), self.args.decoder_lr,eps=1e-6) + self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6) # Create Modules - self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.obs_decoder, self.reward_model, self.club_sample] + self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.club_sample, self.contrastive_head] self.value_modules = [self.value_model] self.actor_modules = [self.actor_model] - #self.reward_modules = [self.reward_model] + self.decoder_modules = [self.obs_decoder] + self.reward_modules = [self.reward_model] #self.decoder_modules = [self.obs_decoder] if use_saved: @@ -251,105 +253,156 @@ class DPI: self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt'))) self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt'))) - def collect_random_sequences(self, episodes): + def collect_random_sequences(self, seed_steps): obs = self.env.reset() done = False - all_rews = [] - for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'): - self.global_episodes += 1 - epi_reward = 0 - while not done: - action = self.env.action_space.sample() - next_obs, rew, done, _ = self.env.step(action) - self.data_buffer.add(obs, action, next_obs, rew, done, self.global_episodes) - obs = next_obs - epi_reward += rew - obs = self.env.reset() - done=False - all_rews.append(epi_reward) + self.global_episodes += 1 + epi_reward = 0 + for _ in tqdm.tqdm(range(seed_steps), desc='Collecting episodes'): + action = self.env.action_space.sample() + next_obs, rew, done, _ = self.env.step(action) + self.data_buffer.add(obs, action, next_obs, rew, done) + obs = next_obs + epi_reward += rew + if done: + obs = self.env.reset() + done=False + all_rews.append(epi_reward) + epi_reward = 0 return all_rews - def collect_sequences(self, episodes, actor_model): + def collect_sequences(self, collect_steps, actor_model): obs = self.env.reset() done = False all_rews = [] - for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'): - self.global_episodes += 1 - epi_reward = 0 - while not done: - with torch.no_grad(): - obs = torch.tensor(obs.copy(), dtype=torch.float32).to(device).unsqueeze(0) - state = self.get_features(obs)["distribution"].rsample() - action = self.actor_model(state) - action = actor_model.add_exploration(action).cpu().numpy()[0] - print(action) - obs = obs.cpu().numpy()[0] - next_obs, rew, done, _ = self.env.step(action) - self.data_buffer.add(obs, action, next_obs, rew, done, self.global_episodes) - obs = next_obs + self.global_episodes += 1 + epi_reward = 0 + for episode_count in tqdm.tqdm(range(collect_steps), desc='Collecting episodes'): + with torch.no_grad(): + obs_ = torch.tensor(obs.copy(), dtype=torch.float32) + obs_ = preprocess_obs(obs_).to(device) + state = self.get_features(obs_)["distribution"].rsample().unsqueeze(0) + action = actor_model(state) + action = actor_model.add_exploration(action) + action = action.cpu().numpy()[0] + next_obs, rew, done, _ = self.env.step(action) + self.data_buffer.add(obs, action, next_obs, rew, done) + + if done: + obs = self.env.reset() + done = False + all_rews.append(epi_reward) + epi_reward = 0 + else: + obs = next_obs epi_reward += rew - obs = self.env.reset() - done=False - all_rews.append(epi_reward) return all_rews def train(self, step, total_steps): - episodic_rews = self.collect_random_sequences(self.args.episodes_buffer) - global_step = self.data_buffer.steps - # logger - logs = OrderedDict() + logdir = os.path.dirname(os.path.realpath(__file__)) + "/log/logs/" + if not(os.path.exists(logdir)): + os.makedirs(logdir) + initial_logs = OrderedDict() + logger = Logger(logdir) - while global_step < total_steps: - step += 1 - for update_steps in range(self.args.update_steps): - model_loss, actor_loss, value_loss = self.update((step-1)*args.update_steps + update_steps) - episodic_rews = self.collect_sequences(self.args.episode_collection, actor_model=self.actor_model, encoder_model=self.obs_encoder) - - logs.update({ - 'model_loss' : model_loss, - 'actor_loss': actor_loss, - 'value_loss': value_loss, - 'train_avg_reward':np.mean(episodic_rews), + episodic_rews = self.collect_random_sequences(5000//args.action_repeat) + self.global_step = self.data_buffer.steps + + initial_logs.update({ + 'train_avg_reward':np.mean(episodic_rews), 'train_max_reward': np.max(episodic_rews), 'train_min_reward': np.min(episodic_rews), 'train_std_reward':np.std(episodic_rews), }) + logger.log_scalars(initial_logs, step=0) + logger.flush() + - print("########## Global Step: ", global_step, " ##########") - for key, value in logs.items(): + while self.global_step < total_steps: + logs = OrderedDict() + step += 1 + for update_steps in range(self.args.update_steps): + model_loss, actor_loss, value_loss, actor_model = self.update((step-1)*args.update_steps + update_steps) + + initial_logs.update({ + 'model_loss' : model_loss, + 'actor_loss': actor_loss, + 'value_loss': value_loss, + 'train_avg_reward':np.mean(episodic_rews), + 'train_max_reward': np.max(episodic_rews), + 'train_min_reward': np.min(episodic_rews), + 'train_std_reward':np.std(episodic_rews), + }) + logger.log_scalars(logs, self.global_step) + + + print("########## Global Step:", self.global_step, " ##########") + for key, value in initial_logs.items(): print(key, " : ", value) + + episodic_rews = self.collect_sequences(1000//self.args.action_repeat, actor_model) - print(global_step) - if global_step % 3150 == 0 and self.data_buffer.steps!=0: #self.args.evaluation_interval == 0: + if self.global_step % 3150 == 0 and self.data_buffer.steps!=0: #self.args.evaluation_interval == 0: print("Saving model") path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth" self.save_models(path) self.evaluate() - global_step = self.data_buffer.steps + self.global_step = self.data_buffer.steps * self.args.action_repeat + """ # collect experience if step !=0: encoder = self.obs_encoder actor = self.actor_model all_rews = self.collect_sequences(self.args.episode_collection, actor_model=actor, encoder_model=encoder) + """ + + def collect_batch(self): + obs_, acs_, nxt_obs_, rews_, terms_ = self.data_buffer.sample() + obs = torch.tensor(obs_, dtype=torch.float32)[1:] + last_obs = torch.tensor(obs_, dtype=torch.float32)[:-1] + nxt_obs = torch.tensor(nxt_obs_, dtype=torch.float32)[1:] + acs = torch.tensor(acs_, dtype=torch.float32)[:-1].to(device) + nxt_acs = torch.tensor(acs_, dtype=torch.float32)[1:].to(device) + rews = torch.tensor(rews_, dtype=torch.float32)[:-1].to(device).unsqueeze(-1) + nonterms = torch.tensor((1.0-terms_), dtype=torch.float32)[:-1].to(device).unsqueeze(-1) + + last_obs = preprocess_obs(last_obs).to(device) + obs = preprocess_obs(obs).to(device) + nxt_obs = preprocess_obs(nxt_obs).to(device) + + return last_obs, obs, nxt_obs, acs, rews, nxt_acs, nonterms def update(self, step): - last_observations, current_observations, next_observations, actions, next_actions, rewards = self.select_one_batch() + last_observations, current_observations, next_observations, actions, rewards, next_actions, nonterms = self.collect_batch() + + #last_observations, current_observations, next_observations, actions, next_actions, rewards = self.select_one_batch() world_loss, enc_loss, rew_loss, dec_loss, ub_loss, lb_loss = self.world_model_losses(last_observations, - current_observations, - next_observations, - actions, - next_actions, - rewards) + current_observations, + next_observations, + actions, + next_actions, + rewards, + nonterms) self.world_model_opt.zero_grad() world_loss.backward() nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm) self.world_model_opt.step() + + self.decoder_opt.zero_grad() + dec_loss.backward() + nn.utils.clip_grad_norm_(self.obs_decoder.parameters(), self.args.grad_clip_norm) + self.decoder_opt.step() + + self.reward_opt.zero_grad() + rew_loss.backward() + nn.utils.clip_grad_norm_(self.reward_model.parameters(), self.args.grad_clip_norm) + self.reward_opt.step() actor_loss = self.actor_model_losses() self.actor_opt.zero_grad() @@ -368,8 +421,8 @@ class DPI: soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau) # update target value networks - if step % self.args.value_target_update_freq == 0: - self.target_value_model = copy.deepcopy(self.value_model) + #if step % self.args.value_target_update_freq == 0: + # self.target_value_model = copy.deepcopy(self.value_model) if step % self.args.logging_freq: writer.add_scalar('World Loss/World Loss', world_loss.detach().item(), step) @@ -379,14 +432,15 @@ class DPI: writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step) writer.add_scalar('Actor Critic Loss/Reward Loss', rew_loss.detach().item(), step) writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step) - writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), step) + writer.add_scalar('Bound Loss/Lower Bound Loss', -lb_loss.detach().item(), step) - return world_loss.item(), actor_loss.item(), value_loss.item() + return world_loss.item(), actor_loss.item(), value_loss.item(), self.actor_model - def world_model_losses(self, last_obs, curr_obs, nxt_obs, actions, nxt_actions, rewards): + def world_model_losses(self, last_obs, curr_obs, nxt_obs, actions, nxt_actions, rewards, nonterms): + # get features self.last_state_feat = self.get_features(last_obs) self.curr_state_feat = self.get_features(curr_obs) - self.nxt_state_feat = self.get_features(nxt_obs) + self.nxt_state_feat = self.get_features(nxt_obs, momentum=True) # states self.last_state_enc = self.last_state_feat["sample"] @@ -394,42 +448,37 @@ class DPI: self.nxt_state_enc = self.nxt_state_feat["sample"] # actions - actions = actions - nxt_actions = nxt_actions + actions = actions.clone() + nxt_actions = nxt_actions.clone() # rewards - rewards = rewards + rewards = rewards.clone() # predict next states self.transition_model.init_states(self.args.batch_size, device) # (N,128) - self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history) + self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history, nonterms) self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"]) self.pred_curr_state_enc = self.pred_curr_state_dist.mean - #print(torch.nn.MSELoss()(self.curr_state_enc, self.pred_curr_state_enc)) - #print(torch.distributions.kl_divergence(self.curr_state_feat["distribution"], self.pred_curr_state_dist).mean(),0) - - # encoder loss - enc_loss = torch.nn.MSELoss()(self.curr_state_enc, self.pred_curr_state_enc) - #self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist) + enc_loss = self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist) # reward loss rew_dist = self.reward_model(self.curr_state_enc) - rew_loss = -torch.mean(rew_dist.log_prob(rewards.unsqueeze(-1))) + rew_loss = -torch.mean(rew_dist.log_prob(rewards)) # decoder loss dec_dist = self.obs_decoder(self.nxt_state_enc) dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs)) # upper bound loss - likelihood_loss, ub_loss = self._upper_bound_minimization(self.curr_state_enc, + _, ub_loss = self._upper_bound_minimization(self.curr_state_enc, self.pred_curr_state_enc) - + # lower bound loss # contrastive projection - vec_anchor = self.pred_curr_state_enc - vec_positive = self.nxt_state_enc + vec_anchor = self.pred_curr_state_enc.detach() + vec_positive = self.nxt_state_enc.detach() z_anchor = self.prjoection_head(vec_anchor, nxt_actions) z_positive = self.prjoection_head_momentum(vec_positive, nxt_actions) @@ -440,40 +489,39 @@ class DPI: labels = torch.arange(logits.shape[0]).long().to(device) lb_loss = F.cross_entropy(logits, labels) + past_lb_loss past_lb_loss = lb_loss.detach().item() - lb_loss = lb_loss/(z_anchor.shape[0]) + lb_loss = -0.1 * lb_loss/(z_anchor.shape[0]) - world_loss = enc_loss + rew_loss + dec_loss * 1e-4 + ub_loss * 10 + lb_loss + world_loss = enc_loss + ub_loss + lb_loss - return world_loss, enc_loss , rew_loss, dec_loss * 1e-4, ub_loss * 10, lb_loss + return world_loss, enc_loss , rew_loss, dec_loss, ub_loss, lb_loss def actor_model_losses(self): with torch.no_grad(): curr_state_enc = self.transition_model.seq_to_batch(self.curr_state_feat, "sample")["sample"] curr_state_hist = self.transition_model.seq_to_batch(self.observed_rollout, "history")["sample"] - with FreezeParameters(self.world_model_modules): + with FreezeParameters(self.world_model_modules + self.decoder_modules + self.reward_modules): imagine_horizon = self.args.imagine_horizon action = self.actor_model(curr_state_enc) self.imagined_rollout = self.transition_model.imagine_rollout(curr_state_enc, action, curr_state_hist, imagine_horizon) - with FreezeParameters(self.world_model_modules + self.value_modules): + with FreezeParameters(self.world_model_modules + self.value_modules + self.decoder_modules + self.reward_modules): imag_rewards = self.reward_model(self.imagined_rollout["sample"]).mean - imag_values = self.target_value_model(self.imagined_rollout["sample"]).mean - discounts = self.args.discount * torch.ones_like(imag_rewards).detach() + imag_values = self.value_model(self.imagined_rollout["sample"]).mean + discounts = self.args.discount * torch.ones_like(imag_rewards).detach() + self.returns = self._compute_lambda_return(imag_rewards[:-1], - imag_values[:-1], - discounts[:-1] , - self.args.td_lambda, - imag_values[-1]) - + imag_values[:-1], + discounts[:-1] , + self.args.td_lambda, + imag_values[-1]) discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0) self.discounts = torch.cumprod(discounts, 0).detach() actor_loss = -torch.mean(self.discounts * self.returns) return actor_loss - def value_model_losses(self): # value loss with torch.no_grad(): @@ -483,7 +531,6 @@ class DPI: value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1)) return value_loss - def select_one_batch(self): # collect sequences non_zero_indices = np.nonzero(self.data_buffer.episode_count)[0] @@ -530,9 +577,6 @@ class DPI: return last_observations, current_observations, next_observations, actions, next_actions, rewards - - - def evaluate(self, eval_episodes=10): path = path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth" self.restore_checkpoint(path) @@ -608,7 +652,9 @@ class DPI: return torch.tensor(transposed_array).float() def _upper_bound_minimization(self, current_states, predicted_current_states): - club_loss = self.club_sample(current_states, predicted_current_states, current_states) + current_negative_states = shuffle_along_axis(current_states.clone(), axis=0) + current_negative_states = shuffle_along_axis(current_negative_states, axis=1) + club_loss = self.club_sample(current_states, predicted_current_states, current_negative_states) likelihood_loss = 0 return likelihood_loss, club_loss @@ -645,6 +691,24 @@ class DPI: returns = torch.flip(torch.stack(rets), [0]) return returns + def lambda_return(self,imged_reward, value_pred, bootstrap, discount=0.99, lambda_=0.95): + # Setting lambda=1 gives a discounted Monte Carlo return. + # Setting lambda=0 gives a fixed 1-step return. + next_values = torch.cat([value_pred[1:], bootstrap[None]], 0) + discount_tensor = discount * torch.ones_like(imged_reward) # pcont + inputs = imged_reward + discount_tensor * next_values * (1 - lambda_) + last = bootstrap + indices = reversed(range(len(inputs))) + outputs = [] + for index in indices: + inp, disc = inputs[index], discount_tensor[index] + last = inp + disc * lambda_ * last + outputs.append(last) + outputs = list(reversed(outputs)) + outputs = torch.stack(outputs, 0) + returns = outputs + return returns + def save_models(self, save_path): torch.save( {'rssm' : self.transition_model.state_dict(), diff --git a/DPI/utils.py b/DPI/utils.py index c4f2208..3b61229 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -1,10 +1,13 @@ import os import random +import pickle import numpy as np from collections import deque import torch import torch.nn as nn +from torch.utils.tensorboard import SummaryWriter + import gym import dmc2gym @@ -144,8 +147,90 @@ class NormalizeActions: original = np.where(self._mask, original, action) return self._env.step(original) +class TimeLimit: + + def __init__(self, env, duration): + self._env = env + self._duration = duration + self._step = None + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + assert self._step is not None, 'Must reset environment.' + obs, reward, done, info = self._env.step(action) + self._step += 1 + if self._step >= self._duration: + done = True + if 'discount' not in info: + info['discount'] = np.array(1.0).astype(np.float32) + self._step = None + return obs, reward, done, info + + def reset(self): + self._step = 0 + return self._env.reset() + class ReplayBuffer: + + def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args): + + self.size = size + self.obs_shape = obs_shape + self.action_size = action_size + self.seq_len = seq_len + self.batch_size = batch_size + self.idx = 0 + self.full = False + self.observations = np.empty((size, *obs_shape), dtype=np.uint8) + self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8) + self.actions = np.empty((size, action_size), dtype=np.float32) + self.rewards = np.empty((size,), dtype=np.float32) + self.terminals = np.empty((size,), dtype=np.float32) + self.steps, self.episodes = 0, 0 + self.episode_count = np.zeros((size,), dtype=np.int32) + + def add(self, obs, ac, next_obs, rew, done, episode_count): + self.observations[self.idx] = obs + self.next_observations[self.idx] = next_obs + self.actions[self.idx] = ac + self.rewards[self.idx] = rew + self.terminals[self.idx] = done + self.full = self.full or self.idx == 0 + self.steps += 1 + self.episodes = self.episodes + (1 if done else 0) + self.episode_count[self.idx] = episode_count + self.idx = (self.idx + 1) % self.size + + def _sample_idx(self, L): + valid_idx = False + while not valid_idx: + idx = np.random.randint(0, self.size if self.full else self.idx - L) + idxs = np.arange(idx, idx + L) % self.size + valid_idx = not self.idx in idxs[1:] + return idxs + + def _retrieve_batch(self, idxs, n, L): + vec_idxs = idxs.transpose().reshape(-1) # Unroll indices + observations = self.observations[vec_idxs] + next_obs = self.next_observations[vec_idxs] + obs = observations.reshape(L, n, *observations.shape[1:]) + next_obs = next_obs.reshape(L, n, *next_obs.shape[1:]) + acs = self.actions[vec_idxs].reshape(L, n, -1) + rew = self.rewards[vec_idxs].reshape(L, n) + term = self.terminals[vec_idxs].reshape(L, n) + return obs, acs, next_obs, rew, term + + def sample(self): + n = self.batch_size + l = self.seq_len + obs,acs,next_obs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l) + return obs,acs,next_obs,rews,terms + + +class ReplayBuffer1: def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args): self.size = size self.obs_shape = obs_shape @@ -199,8 +284,11 @@ class ReplayBuffer: def group_steps(self, buffer, variable, obs=True): variable = getattr(buffer, variable) non_zero_indices = np.nonzero(buffer.episode_count)[0] + print(buffer.episode_count) variable = variable[non_zero_indices] - + print(variable.shape) + exit() + if obs: variable = variable.reshape(-1, self.args.episode_length, self.args.frame_stack*self.args.channels, @@ -215,8 +303,9 @@ class ReplayBuffer: self.args.image_size,self.args.image_size) return variable - def sample_random_idx(self, buffer_length): - random_indices = random.sample(range(0, buffer_length), self.args.batch_size) + def sample_random_idx(self, buffer_length, last=False): + init = 0 if last else buffer_length - self.args.batch_size + random_indices = random.sample(range(init, buffer_length), self.args.batch_size) return random_indices def group_and_sample_random_batch(self, buffer, variable_name, device, random_indices, is_obs=True, offset=0): @@ -248,19 +337,23 @@ def make_env(args): ) return env +def shuffle_along_axis(a, axis): + idx = np.random.rand(*a.shape).argsort(axis=axis) + return np.take_along_axis(a,idx,axis=axis) + def preprocess_obs(obs): - obs = obs/255.0 - 0.5 + obs = (obs/255.0) - 0.5 return obs def soft_update_params(net, target_net, tau): for param, target_param in zip(net.parameters(), target_net.parameters()): target_param.data.copy_( - tau * param.data + (1 - tau) * target_param.data + tau * param.detach().data + (1 - tau) * target_param.data ) def save_image(array, filename): array = array.transpose(1, 2, 0) - array = (array * 255).astype(np.uint8) + array = ((array+0.5) * 255).astype(np.uint8) image = Image.fromarray(array) image.save(filename) @@ -353,4 +446,52 @@ class FreezeParameters: def __exit__(self, exc_type, exc_val, exc_tb): for i, param in enumerate(get_parameters(self.modules)): - param.requires_grad = self.param_states[i] \ No newline at end of file + param.requires_grad = self.param_states[i] + +class Logger: + + def __init__(self, log_dir, n_logged_samples=10, summary_writer=None): + self._log_dir = log_dir + print('########################') + print('logging outputs to ', log_dir) + print('########################') + self._n_logged_samples = n_logged_samples + self._summ_writer = SummaryWriter(log_dir, flush_secs=1, max_queue=1) + + def log_scalar(self, scalar, name, step_): + self._summ_writer.add_scalar('{}'.format(name), scalar, step_) + + def log_scalars(self, scalar_dict, step): + for key, value in scalar_dict.items(): + print('{} : {}'.format(key, value)) + self.log_scalar(value, key, step) + self.dump_scalars_to_pickle(scalar_dict, step) + + def log_videos(self, videos, step, max_videos_to_save=1, fps=20, video_title='video'): + + # max rollout length + max_videos_to_save = np.min([max_videos_to_save, videos.shape[0]]) + max_length = videos[0].shape[0] + for i in range(max_videos_to_save): + if videos[i].shape[0]>max_length: + max_length = videos[i].shape[0] + + # pad rollouts to all be same length + for i in range(max_videos_to_save): + if videos[i].shape[0]