diff --git a/DPI/models.py b/DPI/models.py index 4729787..c8d1e30 100644 --- a/DPI/models.py +++ b/DPI/models.py @@ -13,7 +13,7 @@ class ObservationEncoder(nn.Module): assert len(obs_shape) == 3 self.state_size = state_size - + layers = [] for i in range(num_layers): input_channels = obs_shape[0] if i == 0 else output_channels @@ -23,23 +23,24 @@ class ObservationEncoder(nn.Module): self.convs = nn.Sequential(*layers) - self.fc = nn.Linear(256 * 3 * 3, 2 * state_size) + self.fc = nn.Linear(256 * obs_shape[0], 2 * state_size) # 9 if 3 frames stacked def forward(self, x): - x = self.convs(x) - x = x.view(x.size(0), -1) - x = self.fc(x) + x_reshaped = x.reshape(-1, *x.shape[-3:]) + x_embed = self.convs(x_reshaped) + x_embed = torch.reshape(x_embed, (*x.shape[:-3], -1)) + x = self.fc(x_embed) # Mean and standard deviation mean, std = torch.chunk(x, 2, dim=-1) + mean = nn.ELU()(mean) std = F.softplus(std) - std = torch.clamp(std, min=0.0, max=1e5) + std = torch.clamp(std, min=0.0, max=1e1) # Normal Distribution dist = self.get_dist(mean, std) # Sampling via reparameterization Trick - #x = dist.rsample() x = self.reparameterize(mean, std) encoded_output = {"sample": x, "distribution": dist} @@ -63,7 +64,7 @@ class ObservationDecoder(nn.Module): self.output_shape = output_shape self.input_size = 256 * 3 * 3 self.in_channels = [self.input_size, 256, 128, 64] - self.out_channels = [256, 128, 64, 3] + self.out_channels = [256, 128, 64, 9] if output_shape[1] == 84: self.kernels = [5, 7, 5, 6] @@ -94,43 +95,50 @@ class ObservationDecoder(nn.Module): class Actor(nn.Module): - def __init__(self, state_size, hidden_size, action_size, num_layers=5): + def __init__(self, state_size, hidden_size, action_size, num_layers=4, min_std=1e-4, init_std=5, mean_scale=5): super().__init__() self.state_size = state_size self.hidden_size = hidden_size self.action_size = action_size self.num_layers = num_layers - self._min_std=torch.Tensor([1e-4])[0] - self._init_std=torch.Tensor([5])[0] - self._mean_scale=torch.Tensor([5])[0] + self._min_std = min_std + self._init_std = init_std + self._mean_scale = mean_scale layers = [] for i in range(self.num_layers): input_channels = state_size if i == 0 else self.hidden_size - output_channels = self.hidden_size if i!= self.num_layers-1 else 2*action_size - layers.append(nn.Linear(input_channels, output_channels)) - layers.append(nn.LeakyReLU()) + layers.append(nn.Linear(input_channels, self.hidden_size)) + layers.append(nn.ReLU()) + layers.append(nn.Linear(self.hidden_size, 2*self.action_size)) self.action_model = nn.Sequential(*layers) def get_dist(self, mean, std): distribution = torch.distributions.Normal(mean, std) distribution = torch.distributions.transformed_distribution.TransformedDistribution(distribution, TanhBijector()) distribution = torch.distributions.independent.Independent(distribution, 1) + distribution = SampleDist(distribution) return distribution + def add_exploration(self, action, action_noise=0.3): + return torch.clamp(torch.distributions.Normal(action, action_noise).rsample(), -1, 1) + def forward(self, features): out = self.action_model(features) mean, std = torch.chunk(out, 2, dim=-1) - raw_init_std = torch.log(torch.exp(self._init_std) - 1) + raw_init_std = np.log(np.exp(self._init_std) - 1) action_mean = self._mean_scale * torch.tanh(mean / self._mean_scale) action_std = F.softplus(std + raw_init_std) + self._min_std dist = self.get_dist(action_mean, action_std) - sample = dist.rsample() + sample = dist.rsample() #self.reparameterize(action_mean, action_std) return sample - + + def reparameterize(self, mu, std): + eps = torch.randn_like(std) + return mu + eps * std class ValueModel(nn.Module): def __init__(self, state_size, hidden_size, num_layers=4): @@ -140,11 +148,12 @@ class ValueModel(nn.Module): self.num_layers = num_layers layers = [] - for i in range(self.num_layers): + for i in range(self.num_layers-1): input_channels = state_size if i == 0 else self.hidden_size - output_channels = self.hidden_size if i!= self.num_layers-1 else 1 + output_channels = self.hidden_size layers.append(nn.Linear(input_channels, output_channels)) layers.append(nn.LeakyReLU()) + layers.append(nn.Linear(self.hidden_size, int(np.prod(1)))) self.value_model = nn.Sequential(*layers) def forward(self, state): @@ -169,6 +178,7 @@ class RewardModel(nn.Module): return torch.distributions.independent.Independent( torch.distributions.Normal(reward, 1), 1) +""" class TransitionModel(nn.Module): def __init__(self, state_size, hidden_size, action_size, history_size): super().__init__() @@ -180,6 +190,7 @@ class TransitionModel(nn.Module): self.act_fn = nn.LeakyReLU() self.fc_state_action = nn.Linear(state_size + action_size, hidden_size) + self.ln = nn.LayerNorm(hidden_size) self.history_cell = nn.GRUCell(hidden_size + history_size, history_size) self.fc_state_prior = nn.Linear(history_size + state_size + action_size, 2 * state_size) self.fc_state_posterior = nn.Linear(history_size + state_size + action_size, 2 * state_size) @@ -194,12 +205,25 @@ class TransitionModel(nn.Module): distribution = torch.distributions.independent.Independent(distribution, 1) return distribution - def imagine_step(self, prev_state, prev_action, prev_history): - state_action = self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))) - prev_hist = prev_history.detach() - history = self.history_cell(torch.cat([state_action, prev_hist], dim=-1), prev_hist) - - state_prior = self.fc_state_prior(torch.cat([history, prev_state, prev_action], dim=-1)) + def stack_states(self, states, dim=0): + s = dict( + mean = torch.stack([state['mean'] for state in states], dim=dim), + std = torch.stack([state['std'] for state in states], dim=dim), + sample = torch.stack([state['sample'] for state in states], dim=dim), + history = torch.stack([state['history'] for state in states], dim=dim),) + if 'distribution' in states: + dist = dict(distribution = [state['distribution'] for state in states]) + s.update(dist) + return s + + def seq_to_batch(self, state, name): + return dict( + sample = torch.reshape(state[name], (state[name].shape[0]* state[name].shape[1], *state[name].shape[2:]))) + + def imagine_step(self, state, action, history): + state_action = self.ln(self.act_fn(self.fc_state_action(torch.cat([state, action], dim=-1)))) + imag_hist = self.history_cell(torch.cat([state_action, history], dim=-1), history) + state_prior = self.fc_state_prior(torch.cat([imag_hist, state, action], dim=-1)) state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1) state_prior_std = F.softplus(state_prior_std) @@ -208,19 +232,9 @@ class TransitionModel(nn.Module): # Sampling via reparameterization Trick sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std) - prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist} + prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": imag_hist, "distribution": state_prior_dist} return prior - def stack_states(self, states, dim=0): - s = dict( - mean = torch.stack([state['mean'] for state in states], dim=dim), - std = torch.stack([state['std'] for state in states], dim=dim), - sample = torch.stack([state['sample'] for state in states], dim=dim), - history = torch.stack([state['history'] for state in states], dim=dim),) - dist = dict(distribution = [state['distribution'] for state in states]) - s.update(dist) - return s - def imagine_rollout(self, state, action, history, horizon): imagined_priors = [] for i in range(horizon): @@ -231,10 +245,126 @@ class TransitionModel(nn.Module): imagined_priors = self.stack_states(imagined_priors, dim=0) return imagined_priors + def observe_step(self, prev_state, prev_action, prev_history, nonterms): + state_action = self.ln(self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1)))) + current_history = self.history_cell(torch.cat([state_action, prev_history], dim=-1), prev_history) + state_prior = self.fc_state_prior(torch.cat([prev_history, prev_state, prev_action], dim=-1)) + state_prior_mean, state_prior_std = torch.chunk(state_prior*nonterms, 2, dim=-1) + state_prior_std = F.softplus(state_prior_std) + 0.1 + sample_state_prior = state_prior_mean + torch.randn_like(state_prior_mean) * state_prior_std + prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": current_history} + return prior + + def observe_rollout(self, rollout_states, rollout_actions, init_history, nonterms): + observed_rollout = [] + for i in range(rollout_states.shape[0]): + actions = rollout_actions[i] * nonterms[i] + prior = self.observe_step(rollout_states[i], actions, init_history, nonterms[i]) + init_history = prior["history"] + observed_rollout.append(prior) + observed_rollout = self.stack_states(observed_rollout, dim=0) + return observed_rollout + def reparemeterize(self, mean, std): eps = torch.randn_like(std) return mean + eps * std +""" +class TransitionModel(nn.Module): + def __init__(self, state_size, hidden_size, action_size, history_size): + super().__init__() + + self.state_size = state_size + self.hidden_size = hidden_size + self.action_size = action_size + self.history_size = history_size + self.act_fn = nn.ELU() + + self.fc_state_action = nn.Linear(state_size + action_size, hidden_size) + self.history_cell = nn.GRUCell(hidden_size, history_size) + self.fc_state_mu = nn.Linear(history_size + hidden_size, state_size) + self.fc_state_sigma = nn.Linear(history_size + hidden_size, state_size) + + self.batch_norm = nn.BatchNorm1d(hidden_size) + self.batch_norm2 = nn.BatchNorm1d(state_size) + + self.min_sigma = 1e-4 + self.max_sigma = 1e0 + + def init_states(self, batch_size, device): + self.prev_state = torch.zeros(batch_size, self.state_size).to(device) + self.prev_action = torch.zeros(batch_size, self.action_size).to(device) + self.prev_history = torch.zeros(batch_size, self.history_size).to(device) + + def get_dist(self, mean, std): + distribution = torch.distributions.Normal(mean, std) + distribution = torch.distributions.independent.Independent(distribution, 1) + return distribution + + def stack_states(self, states, dim=0): + s = dict( + mean = torch.stack([state['mean'] for state in states], dim=dim), + std = torch.stack([state['std'] for state in states], dim=dim), + sample = torch.stack([state['sample'] for state in states], dim=dim), + history = torch.stack([state['history'] for state in states], dim=dim),) + if 'distribution' in states: + dist = dict(distribution = [state['distribution'] for state in states]) + s.update(dist) + return s + + def seq_to_batch(self, state, name): + return dict( + sample = torch.reshape(state[name], (state[name].shape[0]* state[name].shape[1], *state[name].shape[2:]))) + + def imagine_step(self, state, action, history): + next_state_action_enc = self.act_fn(self.batch_norm(self.fc_state_action(torch.cat([state, action], dim=-1)))) + imag_history = self.history_cell(next_state_action_enc, history) + next_state_mu = self.act_fn(self.batch_norm2(self.fc_state_mu(torch.cat([next_state_action_enc, imag_history], dim=-1)))) + next_state_sigma = torch.sigmoid(self.fc_state_sigma(torch.cat([next_state_action_enc, imag_history], dim=-1))) + next_state_sigma = self.min_sigma + (self.max_sigma - self.min_sigma) * next_state_sigma + + # Normal Distribution + next_state_dist = self.get_dist(next_state_mu, next_state_sigma) + next_state_sample = self.reparemeterize(next_state_mu, next_state_sigma) + prior = {"mean": next_state_mu, "std": next_state_sigma, "sample": next_state_sample, "history": imag_history, "distribution": next_state_dist} + return prior + + def imagine_rollout(self, state, action, history, horizon): + imagined_priors = [] + for i in range(horizon): + prior = self.imagine_step(state, action, history) + state = prior["sample"] + history = prior["history"] + imagined_priors.append(prior) + imagined_priors = self.stack_states(imagined_priors, dim=0) + return imagined_priors + + def observe_step(self, prev_state, prev_action, prev_history): + state_action_enc = self.act_fn(self.batch_norm(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1)))) + current_history = self.history_cell(state_action_enc, prev_history) + state_mu = self.act_fn(self.batch_norm2(self.fc_state_mu(torch.cat([state_action_enc, prev_history], dim=-1)))) + state_sigma = F.softplus(self.fc_state_sigma(torch.cat([state_action_enc, prev_history], dim=-1))) + + sample_state = state_mu + torch.randn_like(state_mu) * state_sigma + state_enc = {"mean": state_mu, "std": state_sigma, "sample": sample_state, "history": current_history} + return state_enc + + def observe_rollout(self, rollout_states, rollout_actions, init_history, nonterms): + observed_rollout = [] + for i in range(rollout_states.shape[0]): + rollout_states_ = rollout_states[i] + rollout_actions_ = rollout_actions[i] + init_history_ = nonterms[i] * init_history + state_enc = self.observe_step(rollout_states_, rollout_actions_, init_history_) + init_history = state_enc["history"] + observed_rollout.append(state_enc) + observed_rollout = self.stack_states(observed_rollout, dim=0) + return observed_rollout + + def reparemeterize(self, mean, std): + eps = torch.randn_like(mean) + return mean + eps * std + class TanhBijector(torch.distributions.Transform): def __init__(self): @@ -300,6 +430,7 @@ class ContrastiveHead(nn.Module): return logits +""" class CLUBSample(nn.Module): # Sampled version of the CLUB estimator def __init__(self, last_states, current_states, negative_current_states, predicted_current_states): super(CLUBSample, self).__init__() @@ -313,7 +444,7 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator sample = state_dict["sample"] #dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses mu = dist.mean var = dist.variance - return mu, var, sample + return mu.detach(), var.detach(), sample.detach() def loglikeli(self): _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states) @@ -330,15 +461,136 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator random_index = torch.randperm(sample_size).long() pos = (-(mu_curr - pred_sample)**2 /var_curr).sum(dim=1).mean(dim=0) - neg = (-(mu_curr - pred_sample[random_index])**2 /var_curr).sum(dim=1).mean(dim=0) - #neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0) + #neg = (-(mu_curr - pred_sample[random_index])**2 /var_curr).sum(dim=1).mean(dim=0) + neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0) upper_bound = pos - neg return upper_bound/2 def learning_loss(self): return - self.loglikeli() +""" + +class CLUBSample(nn.Module): # Sampled version of the CLUB estimator + def __init__(self, x_dim, y_dim, hidden_size): + super(CLUBSample, self).__init__() + self.p_mu = nn.Sequential(nn.Linear(x_dim, hidden_size//2), + nn.ReLU(), + nn.Linear(hidden_size//2, y_dim)) + + self.p_logvar = nn.Sequential(nn.Linear(x_dim, hidden_size//2), + nn.ReLU(), + nn.Linear(hidden_size//2, y_dim), + nn.Tanh()) + + def get_mu_logvar(self, x_samples): + mu = self.p_mu(x_samples) + logvar = self.p_logvar(x_samples) + return mu, logvar + + + def loglikeli(self, x_samples, y_samples): + mu, logvar = self.get_mu_logvar(x_samples) + return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0) + def forward(self, x_samples, y_samples, y_negatives): + mu, logvar = self.get_mu_logvar(x_samples) + + sample_size = x_samples.shape[0] + #random_index = torch.randint(sample_size, (sample_size,)).long() + random_index = torch.randperm(sample_size).long() + + positive = -(mu - y_samples)**2 / logvar.exp() + #negative = - (mu - y_samples[random_index])**2 / logvar.exp() + negative = -(mu - y_negatives)**2 / logvar.exp() + upper_bound = (positive.sum(dim = -1) - negative.sum(dim = -1)).mean() + return upper_bound/2. + + def learning_loss(self, x_samples, y_samples): + return -self.loglikeli(x_samples, y_samples) + + +class QFunction(nn.Module): + """MLP for q-function.""" + def __init__(self, obs_dim, action_dim, hidden_dim): + super().__init__() + + self.trunk = nn.Sequential( + nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, 1) + ) + + def forward(self, obs, action): + assert obs.size(0) == action.size(0) + + obs_action = torch.cat([obs, action], dim=1) + return self.trunk(obs_action) + + +class Critic(nn.Module): + """Critic network, employes two q-functions.""" + def __init__( + self, obs_shape, action_shape, hidden_dim, encoder_feature_dim): + super().__init__() + + self.Q1 = QFunction( + self.encoder.feature_dim, action_shape[0], hidden_dim + ) + self.Q2 = QFunction( + self.encoder.feature_dim, action_shape[0], hidden_dim + ) + + self.outputs = dict() + + def forward(self, obs, action, detach_encoder=False): + # detach_encoder allows to stop gradient propogation to encoder + obs = self.encoder(obs, detach=detach_encoder) + + q1 = self.Q1(obs, action) + q2 = self.Q2(obs, action) + + self.outputs['q1'] = q1 + self.outputs['q2'] = q2 + + return q1, q2 + +class SampleDist: + def __init__(self, dist, samples=100): + self._dist = dist + self._samples = samples + + @property + def name(self): + return 'SampleDist' + + def __getattr__(self, name): + return getattr(self._dist, name) + + def mean(self): + sample = self._dist.rsample(self._samples) + return torch.mean(sample, 0) + + def mode(self): + dist = self._dist.expand((self._samples, *self._dist.batch_shape)) + sample = dist.rsample() + logprob = dist.log_prob(sample) + batch_size = sample.size(1) + feature_size = sample.size(2) + indices = torch.argmax(logprob, dim=0).reshape(1, batch_size, 1).expand(1, batch_size, feature_size) + return torch.gather(sample, 0, indices).squeeze(0) + + def entropy(self): + dist = self._dist.expand((self._samples, *self._dist.batch_shape)) + sample = dist.rsample() + logprob = dist.log_prob(sample) + return -torch.mean(logprob, 0) + + def sample(self): + return self._dist.sample() + + if "__name__ == __main__": - pass \ No newline at end of file + tr = TransitionModel(50, 512, 1, 256) + diff --git a/DPI/replay_buffer.py b/DPI/replay_buffer.py new file mode 100644 index 0000000..242ca71 --- /dev/null +++ b/DPI/replay_buffer.py @@ -0,0 +1,51 @@ +import torch +import numpy as np + + +class ReplayBuffer: + + def __init__(self, size, obs_shape, action_size, seq_len, batch_size): + self.size = size + self.obs_shape = obs_shape + self.action_size = action_size + self.seq_len = seq_len + self.batch_size = batch_size + self.idx = 0 + self.full = False + self.observations = np.empty((size, *obs_shape), dtype=np.uint8) + self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8) + self.actions = np.empty((size, action_size), dtype=np.float32) + self.rewards = np.empty((size,), dtype=np.float32) + self.terminals = np.empty((size,), dtype=np.float32) + self.steps, self.episodes = 0, 0 + + def add(self, obs, ac, next_obs, rew, done): + self.observations[self.idx] = obs + self.next_observations[self.idx] = next_obs + self.actions[self.idx] = ac + self.rewards[self.idx] = rew + self.terminals[self.idx] = done + self.idx = (self.idx + 1) % self.size + self.full = self.full or self.idx == 0 + self.steps += 1 + self.episodes = self.episodes + (1 if done else 0) + + def _sample_idx(self, L): + valid_idx = False + while not valid_idx: + idx = np.random.randint(0, self.size if self.full else self.idx - L) + idxs = np.arange(idx, idx + L) % self.size + valid_idx = not self.idx in idxs[1:] + return idxs + + def _retrieve_batch(self, idxs, n, L): + vec_idxs = idxs.transpose().reshape(-1) # Unroll indices + observations = self.observations[vec_idxs] + next_observations = self.next_observations[vec_idxs] + return observations.reshape(L, n, *observations.shape[1:]),self.actions[vec_idxs].reshape(L, n, -1), next_observations.reshape(L, n, *next_observations.shape[1:]), self.rewards[vec_idxs].reshape(L, n), self.terminals[vec_idxs].reshape(L, n) + + def sample(self): + n = self.batch_size + l = self.seq_len + obs,acs,nxt_obs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l) + return obs,acs,nxt_obs,rews,terms diff --git a/DPI/train.py b/DPI/train.py index ed09de3..e804ac9 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -48,10 +48,10 @@ def parse_args(): parser.add_argument('--episode_length', default=51, type=int) # train parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc']) - parser.add_argument('--init_steps', default=10000, type=int) + parser.add_argument('--init_steps', default=5000, type=int) parser.add_argument('--num_train_steps', default=100000, type=int) parser.add_argument('--update_steps', default=100, type=int) - parser.add_argument('--batch_size', default=64, type=int) #512 + parser.add_argument('--batch_size', default=64, type=int) parser.add_argument('--state_size', default=50, type=int) parser.add_argument('--hidden_size', default=512, type=int) parser.add_argument('--history_size', default=256, type=int) @@ -66,20 +66,20 @@ def parse_args(): parser.add_argument('--num_eval_episodes', default=20, type=int) parser.add_argument('--evaluation_interval', default=10000, type=int) # TODO: master had 10000 # value - parser.add_argument('--value_lr', default=1e-6, type=float) + parser.add_argument('--value_lr', default=8e-5, type=float) parser.add_argument('--value_target_update_freq', default=100, type=int) parser.add_argument('--td_lambda', default=0.95, type=int) # actor - parser.add_argument('--actor_lr', default=1e-6, type=float) + parser.add_argument('--actor_lr', default=8e-5, type=float) parser.add_argument('--actor_beta', default=0.9, type=float) parser.add_argument('--actor_log_std_min', default=-10, type=float) parser.add_argument('--actor_log_std_max', default=2, type=float) parser.add_argument('--actor_update_freq', default=2, type=int) # world/encoder/decoder parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity']) - parser.add_argument('--world_model_lr', default=1e-5, type=float) - parser.add_argument('--decoder_lr', default=1e-5, type=float) - parser.add_argument('--reward_lr', default=1e-5, type=float) + parser.add_argument('--world_model_lr', default=6e-5, type=float) + parser.add_argument('--decoder_lr', default=6e-4, type=float) + parser.add_argument('--reward_lr', default=6e-5, type=float) parser.add_argument('--encoder_tau', default=0.001, type=float) parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction']) parser.add_argument('--num_layers', default=4, type=int) @@ -157,16 +157,19 @@ class DPI: obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84) state_size=self.args.state_size # 128 ).to(device) + self.obs_encoder.apply(self.init_weights) self.obs_encoder_momentum = ObservationEncoder( obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84) state_size=self.args.state_size # 128 ).to(device) + self.obs_encoder_momentum.apply(self.init_weights) self.obs_decoder = ObservationDecoder( state_size=self.args.state_size, # 128 output_shape=(self.args.channels*self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84) ).to(device) + self.obs_decoder.apply(self.init_weights) self.transition_model = TransitionModel( state_size=self.args.state_size, # 128 @@ -174,6 +177,7 @@ class DPI: action_size=self.env.action_space.shape[0], # 6 history_size=self.args.history_size, # 128 ).to(device) + self.transition_model.apply(self.init_weights) # Actor Model self.actor_model = Actor( @@ -181,7 +185,7 @@ class DPI: hidden_size=self.args.hidden_size, # 256, action_size=self.env.action_space.shape[0], # 6 ).to(device) - #self.actor_model.apply(self.init_weights) + self.actor_model.apply(self.init_weights) # Value Models @@ -189,16 +193,19 @@ class DPI: state_size=self.args.state_size, # 128 hidden_size=self.args.hidden_size, # 256 ).to(device) + self.value_model.apply(self.init_weights) self.target_value_model = ValueModel( state_size=self.args.state_size, # 128 hidden_size=self.args.hidden_size, # 256 ).to(device) + self.target_value_model.apply(self.init_weights) self.reward_model = RewardModel( state_size=self.args.state_size, # 128 hidden_size=self.args.hidden_size, # 256 ).to(device) + self.reward_model.apply(self.init_weights) # Contrastive Models self.prjoection_head = ProjectionHead( @@ -228,22 +235,21 @@ class DPI: self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.prjoection_head.parameters()) + \ list(self.transition_model.parameters()) + list(self.club_sample.parameters()) + \ list(self.contrastive_head.parameters()) - self.past_transition_parameters = self.transition_model.parameters() # optimizers self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr,eps=1e-6) - self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6) + self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6, weight_decay=1e-5) self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr,eps=1e-6) self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), self.args.decoder_lr,eps=1e-6) - self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6) + self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6, weight_decay=1e-5) # Create Modules - self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.club_sample, self.contrastive_head] + self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.club_sample, self.contrastive_head, + self.obs_encoder_momentum, self.prjoection_head_momentum] self.value_modules = [self.value_model] self.actor_modules = [self.actor_model] self.decoder_modules = [self.obs_decoder] self.reward_modules = [self.reward_model] - #self.decoder_modules = [self.obs_decoder] if use_saved: self._use_saved_models(saved_model_dir) @@ -282,7 +288,8 @@ class DPI: with torch.no_grad(): obs_ = torch.tensor(obs.copy(), dtype=torch.float32) obs_ = preprocess_obs(obs_).to(device) - state = self.get_features(obs_)["distribution"].rsample().unsqueeze(0) + #state = self.get_features(obs_)["sample"].unsqueeze(0) + state = self.get_features(obs_)["distribution"].rsample() action = actor_model(state) action = actor_model.add_exploration(action) action = action.cpu().numpy()[0] @@ -307,7 +314,7 @@ class DPI: initial_logs = OrderedDict() logger = Logger(logdir) - episodic_rews = self.collect_random_sequences(5000//args.action_repeat) + episodic_rews = self.collect_random_sequences(self.args.init_steps//args.action_repeat) self.global_step = self.data_buffer.steps initial_logs.update({ @@ -361,7 +368,8 @@ class DPI: """ def collect_batch(self): - obs_, acs_, nxt_obs_, rews_, terms_ = self.data_buffer.sample() + obs_, acs_, nxt_obs_, rews_, terms_ = self.data_buffer.sample() + obs = torch.tensor(obs_, dtype=torch.float32)[1:] last_obs = torch.tensor(obs_, dtype=torch.float32)[:-1] nxt_obs = torch.tensor(nxt_obs_, dtype=torch.float32)[1:] @@ -427,7 +435,7 @@ class DPI: if step % self.args.logging_freq: writer.add_scalar('World Loss/World Loss', world_loss.detach().item(), step) writer.add_scalar('Main Models Loss/Encoder Loss', enc_loss.detach().item(), step) - writer.add_scalar('Main Models Loss/Decoder Loss', dec_loss, step) + writer.add_scalar('Main Models Loss/Decoder Loss', dec_loss.detach().item(), step) writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), step) writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step) writer.add_scalar('Actor Critic Loss/Reward Loss', rew_loss.detach().item(), step) @@ -440,31 +448,28 @@ class DPI: # get features self.last_state_feat = self.get_features(last_obs) self.curr_state_feat = self.get_features(curr_obs) - self.nxt_state_feat = self.get_features(nxt_obs, momentum=True) + self.nxt_state_feat = self.get_features(nxt_obs) + self.nxt_state_feat_lb = self.get_features(nxt_obs, momentum=True) # states - self.last_state_enc = self.last_state_feat["sample"] - self.curr_state_enc = self.curr_state_feat["sample"] - self.nxt_state_enc = self.nxt_state_feat["sample"] - - # actions - actions = actions.clone() - nxt_actions = nxt_actions.clone() - - # rewards - rewards = rewards.clone() + self.last_state_enc = self.last_state_feat["distribution"].rsample() #self.last_state_feat["sample"] + self.curr_state_enc = self.curr_state_feat["distribution"].rsample() #self.curr_state_feat["sample"] + self.nxt_state_enc = self.nxt_state_feat["distribution"].rsample() #self.nxt_state_feat["sample"] + self.nxt_state_enc_lb = self.nxt_state_feat_lb["distribution"].rsample() #self.nxt_state_feat_lb["sample"] # predict next states self.transition_model.init_states(self.args.batch_size, device) # (N,128) self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history, nonterms) self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"]) - self.pred_curr_state_enc = self.pred_curr_state_dist.mean + #print(self.observed_rollout["mean"][0][0]) + self.pred_curr_state_enc = self.pred_curr_state_dist.rsample() #self.observed_rollout["sample"] # encoder loss enc_loss = self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist) # reward loss rew_dist = self.reward_model(self.curr_state_enc) + #print(torch.cat([rew_dist.mean[0], rewards[0]],dim=-1)) rew_loss = -torch.mean(rew_dist.log_prob(rewards)) # decoder loss @@ -472,13 +477,19 @@ class DPI: dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs)) # upper bound loss - _, ub_loss = self._upper_bound_minimization(self.curr_state_enc, - self.pred_curr_state_enc) + past_ub_loss = 0 + for i in range(self.curr_state_enc.shape[0]): + _, ub_loss = self._upper_bound_minimization(self.curr_state_enc[i], + self.pred_curr_state_enc[i]) + ub_loss = ub_loss + past_ub_loss + past_ub_loss = ub_loss + ub_loss = ub_loss / self.curr_state_enc.shape[0] + ub_loss = 0.01 * ub_loss # lower bound loss # contrastive projection vec_anchor = self.pred_curr_state_enc.detach() - vec_positive = self.nxt_state_enc.detach() + vec_positive = self.nxt_state_enc_lb.detach() z_anchor = self.prjoection_head(vec_anchor, nxt_actions) z_positive = self.prjoection_head_momentum(vec_positive, nxt_actions) @@ -489,7 +500,7 @@ class DPI: labels = torch.arange(logits.shape[0]).long().to(device) lb_loss = F.cross_entropy(logits, labels) + past_lb_loss past_lb_loss = lb_loss.detach().item() - lb_loss = -0.1 * lb_loss/(z_anchor.shape[0]) + lb_loss = -0.01 * lb_loss/(z_anchor.shape[0]) world_loss = enc_loss + ub_loss + lb_loss @@ -497,20 +508,27 @@ class DPI: def actor_model_losses(self): with torch.no_grad(): - curr_state_enc = self.transition_model.seq_to_batch(self.curr_state_feat, "sample")["sample"] - curr_state_hist = self.transition_model.seq_to_batch(self.observed_rollout, "history")["sample"] + #curr_state_enc = self.curr_state_enc.reshape(self.args.episode_length-1,-1) #self.transition_model.seq_to_batch(self.curr_state_feat, "sample")["sample"] + #curr_state_hist = self.observed_rollout["history"].reshape(self.args.episode_length-1,-1) #self.transition_model.seq_to_batch(self.observed_rollout, "history")["sample"] + curr_state_enc = self.curr_state_enc.reshape(-1, self.args.state_size) + curr_state_hist = self.observed_rollout["history"].reshape(-1, self.args.history_size) - with FreezeParameters(self.world_model_modules + self.decoder_modules + self.reward_modules): + with FreezeParameters(self.world_model_modules + self.decoder_modules + self.reward_modules + self.value_modules): imagine_horizon = self.args.imagine_horizon - action = self.actor_model(curr_state_enc) + action = self.actor_model(curr_state_enc.detach()) self.imagined_rollout = self.transition_model.imagine_rollout(curr_state_enc, - action, curr_state_hist, + action, curr_state_hist.detach(), imagine_horizon) + self.pred_nxt_state_dist = self.transition_model.get_dist(self.imagined_rollout["mean"], self.imagined_rollout["std"]) + #print(self.imagined_rollout["mean"][0][0]) + self.pred_nxt_state_enc = self.pred_nxt_state_dist.rsample() #self.transition_model.reparemeterize(self.imagined_rollout["mean"], self.imagined_rollout["std"]) with FreezeParameters(self.world_model_modules + self.value_modules + self.decoder_modules + self.reward_modules): - imag_rewards = self.reward_model(self.imagined_rollout["sample"]).mean - imag_values = self.value_model(self.imagined_rollout["sample"]).mean - discounts = self.args.discount * torch.ones_like(imag_rewards).detach() + imag_rewards_dist = self.reward_model(self.pred_nxt_state_enc) + imag_values_dist = self.value_model(self.pred_nxt_state_enc) + imag_rewards = imag_rewards_dist.mean + imag_values = imag_values_dist.mean + discounts = self.args.discount * torch.ones_like(imag_rewards).detach() self.returns = self._compute_lambda_return(imag_rewards[:-1], imag_values[:-1], @@ -525,7 +543,7 @@ class DPI: def value_model_losses(self): # value loss with torch.no_grad(): - value_feat = self.imagined_rollout["sample"][:-1].detach() + value_feat = self.pred_nxt_state_enc[:-1].detach() value_targ = self.returns.detach() value_dist = self.value_model(value_feat) value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1)) @@ -596,15 +614,14 @@ class DPI: obs = torch.tensor(obs.copy(), dtype=torch.float32).unsqueeze(0) obs_processed = preprocess_obs(obs).to(device) state = self.get_features(obs_processed)["distribution"].rsample() - action = self.actor_model(state).cpu().detach().numpy().squeeze() - + action = self.actor_model(state).cpu().detach().numpy().squeeze() next_obs, rew, done, _ = self.env.step(action) rewards += rew + obs = next_obs if self.args.save_video: self.env.video.record(self.env) self.env.video.save('/home/vedant/Curiosity/Curiosity/DPI/log/video/learned_model.mp4') - obs = next_obs obs = self.env.reset() episodic_rewards.append(rewards) print("Episodic rewards: ", episodic_rewards) @@ -679,6 +696,19 @@ class DPI: return x def _compute_lambda_return(self, rewards, values, discounts, td_lam, last_value): + next_values = torch.cat([values[1:], last_value[None]], 0) + target = rewards + discounts * next_values * (1 - td_lam) + timesteps = list(range(rewards.shape[0] - 1, -1, -1)) + outputs = [] + accumulated_reward = last_value + for t in timesteps: + inp = target[t] + discount_factor = discounts[t] + accumulated_reward = inp + discount_factor * td_lam * accumulated_reward + outputs.append(accumulated_reward) + returns = torch.flip(torch.stack(outputs), [0]) + return returns + """ next_values = torch.cat([values[1:], last_value.unsqueeze(0)],0) targets = rewards + discounts * next_values * (1-td_lam) rets =[] @@ -690,6 +720,7 @@ class DPI: returns = torch.flip(torch.stack(rets), [0]) return returns + """ def lambda_return(self,imged_reward, value_pred, bootstrap, discount=0.99, lambda_=0.95): # Setting lambda=1 gives a discounted Monte Carlo return. @@ -740,7 +771,7 @@ if __name__ == '__main__': device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') step = 0 - total_steps = 500000 + total_steps = 1000000 dpi = DPI(args) dpi.train(step,total_steps) dpi.evaluate() \ No newline at end of file diff --git a/DPI/utils.py b/DPI/utils.py index 3b61229..660fb48 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -63,17 +63,6 @@ def make_dir(dir_path): return dir_path -def preprocess_obs(obs, bits=5): - """Preprocessing image, see https://arxiv.org/abs/1807.03039.""" - bins = 2**bits - assert obs.dtype == torch.float32 - if bits < 8: - obs = torch.floor(obs / 2**(8 - bits)) - obs = obs / bins - obs = obs + torch.rand_like(obs) / bins - obs = obs - 0.5 - return obs - class FrameStack(gym.Wrapper): def __init__(self, env, k): @@ -338,8 +327,8 @@ def make_env(args): return env def shuffle_along_axis(a, axis): - idx = np.random.rand(*a.shape).argsort(axis=axis) - return np.take_along_axis(a,idx,axis=axis) + idx = np.random.rand(*a.shape).argsort(axis=axis) + return np.take_along_axis(a,idx,axis=axis) def preprocess_obs(obs): obs = (obs/255.0) - 0.5 @@ -374,6 +363,20 @@ def video_from_array(arr, high_noise, filename): out.write(frame) out.release() +def save_video(images): + """ + Image shape is (T, C, H, W) + Example:(50, 3, 84, 84) + """ + output_file = "output.avi" + fourcc = cv2.VideoWriter_fourcc(*'XVID') + fps = 2 + height, width, channels = 84,84,3 + out = cv2.VideoWriter(output_file, fourcc, fps, (width, height)) + for image in images: + image = np.uint8(image.transpose((1, 2, 0))) + out.write(image) + out.release() class CorruptVideos: def __init__(self, dir_path):