Compare commits


7 Commits
main ... idea_1

Author SHA1 Message Date
bb0265846e Update train model 2023-04-26 09:43:28 +02:00
ab4e7b9a22 Adding some new things 2023-04-24 17:42:37 +02:00
02a66cfb33 Adding after some changes 2023-04-22 13:07:22 +02:00
e7f5533ee6 Adding model 2023-04-20 14:55:54 +02:00
3fa5e8e74a New trained model 2023-04-18 16:47:30 +02:00
21cefbab48 Trying some ideas 2023-04-15 17:01:57 +02:00
9a2e9f420b Checking branch push 2023-04-15 15:54:09 +02:00
4 changed files with 988 additions and 330 deletions


@@ -23,23 +23,24 @@ class ObservationEncoder(nn.Module):
self.convs = nn.Sequential(*layers) self.convs = nn.Sequential(*layers)
self.fc = nn.Linear(256 * 3 * 3, 2 * state_size) self.fc = nn.Linear(256 * obs_shape[0], 2 * state_size) # 9 if 3 frames stacked
def forward(self, x): def forward(self, x):
x = self.convs(x) x_reshaped = x.reshape(-1, *x.shape[-3:])
x = x.view(x.size(0), -1) x_embed = self.convs(x_reshaped)
x = self.fc(x) x_embed = torch.reshape(x_embed, (*x.shape[:-3], -1))
x = self.fc(x_embed)
# Mean and standard deviation # Mean and standard deviation
mean, std = torch.chunk(x, 2, dim=-1) mean, std = torch.chunk(x, 2, dim=-1)
mean = nn.ELU()(mean)
std = F.softplus(std) std = F.softplus(std)
std = torch.clamp(std, min=0.0, max=1e5) std = torch.clamp(std, min=0.0, max=1e1)
# Normal Distribution # Normal Distribution
dist = self.get_dist(mean, std) dist = self.get_dist(mean, std)
# Sampling via reparameterization Trick # Sampling via reparameterization Trick
#x = dist.rsample()
x = self.reparameterize(mean, std) x = self.reparameterize(mean, std)
encoded_output = {"sample": x, "distribution": dist} encoded_output = {"sample": x, "distribution": dist}
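The reworked encoder flattens stacked frames through the conv trunk, then splits the fully connected output into a mean and a softplus standard deviation (now clamped to 1e1 instead of 1e5) and samples the latent state with the reparameterization trick. A minimal, self-contained sketch of that Gaussian head, assuming the 100-dimensional state_size set later in this diff; the repository's get_dist helper is approximated with an Independent Normal and the ELU applied to the mean is omitted here:

    import torch
    import torch.nn.functional as F

    def gaussian_head(fc_out):
        """Split an FC output into (mean, std) and draw a reparameterized sample."""
        mean, std = torch.chunk(fc_out, 2, dim=-1)
        std = torch.clamp(F.softplus(std), min=0.0, max=1e1)
        dist = torch.distributions.Independent(torch.distributions.Normal(mean, std), 1)
        sample = mean + torch.randn_like(std) * std   # reparameterization trick
        return {"sample": sample, "distribution": dist}

    out = gaussian_head(torch.randn(4, 2 * 100))       # batch of 4 encodings
    print(out["sample"].shape)                          # torch.Size([4, 100])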
@@ -63,7 +64,7 @@ class ObservationDecoder(nn.Module):
self.output_shape = output_shape self.output_shape = output_shape
self.input_size = 256 * 3 * 3 self.input_size = 256 * 3 * 3
self.in_channels = [self.input_size, 256, 128, 64] self.in_channels = [self.input_size, 256, 128, 64]
self.out_channels = [256, 128, 64, 3] self.out_channels = [256, 128, 64, 9]
if output_shape[1] == 84: if output_shape[1] == 84:
self.kernels = [5, 7, 5, 6] self.kernels = [5, 7, 5, 6]
@@ -94,43 +95,50 @@ class ObservationDecoder(nn.Module):
class Actor(nn.Module): class Actor(nn.Module):
def __init__(self, state_size, hidden_size, action_size, num_layers=5): def __init__(self, state_size, hidden_size, action_size, num_layers=4, min_std=1e-4, init_std=5, mean_scale=5):
super().__init__() super().__init__()
self.state_size = state_size self.state_size = state_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.action_size = action_size self.action_size = action_size
self.num_layers = num_layers self.num_layers = num_layers
self._min_std=torch.Tensor([1e-4])[0] self._min_std = min_std
self._init_std=torch.Tensor([5])[0] self._init_std = init_std
self._mean_scale=torch.Tensor([5])[0] self._mean_scale = mean_scale
layers = [] layers = []
for i in range(self.num_layers): for i in range(self.num_layers):
input_channels = state_size if i == 0 else self.hidden_size input_channels = state_size if i == 0 else self.hidden_size
output_channels = self.hidden_size if i!= self.num_layers-1 else 2*action_size layers.append(nn.Linear(input_channels, self.hidden_size))
layers.append(nn.Linear(input_channels, output_channels)) layers.append(nn.ReLU())
layers.append(nn.LeakyReLU()) layers.append(nn.Linear(self.hidden_size, 2*self.action_size))
self.action_model = nn.Sequential(*layers) self.action_model = nn.Sequential(*layers)
def get_dist(self, mean, std): def get_dist(self, mean, std):
distribution = torch.distributions.Normal(mean, std) distribution = torch.distributions.Normal(mean, std)
distribution = torch.distributions.transformed_distribution.TransformedDistribution(distribution, TanhBijector()) distribution = torch.distributions.transformed_distribution.TransformedDistribution(distribution, TanhBijector())
distribution = torch.distributions.independent.Independent(distribution, 1) distribution = torch.distributions.independent.Independent(distribution, 1)
distribution = SampleDist(distribution)
return distribution return distribution
def add_exploration(self, action, action_noise=0.3):
return torch.clamp(torch.distributions.Normal(action, action_noise).rsample(), -1, 1)
def forward(self, features): def forward(self, features):
out = self.action_model(features) out = self.action_model(features)
mean, std = torch.chunk(out, 2, dim=-1) mean, std = torch.chunk(out, 2, dim=-1)
raw_init_std = torch.log(torch.exp(self._init_std) - 1) raw_init_std = np.log(np.exp(self._init_std) - 1)
action_mean = self._mean_scale * torch.tanh(mean / self._mean_scale) action_mean = self._mean_scale * torch.tanh(mean / self._mean_scale)
action_std = F.softplus(std + raw_init_std) + self._min_std action_std = F.softplus(std + raw_init_std) + self._min_std
dist = self.get_dist(action_mean, action_std) dist = self.get_dist(action_mean, action_std)
sample = dist.rsample() sample = dist.rsample() #self.reparameterize(action_mean, action_std)
return sample return sample
def reparameterize(self, mu, std):
eps = torch.randn_like(std)
return mu + eps * std
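The new Actor builds a bounded action distribution: the mean is soft-limited through mean_scale * tanh(mean / mean_scale), the standard deviation is offset so that softplus starts near init_std, and add_exploration injects clamped Gaussian noise at collection time. A small sketch of just those transformations, assuming a 6-dimensional action space; the TanhBijector/SampleDist wrappers used by get_dist are omitted and replaced by a plain tanh squash here:

    import numpy as np
    import torch
    import torch.nn.functional as F

    def squash_action_params(out, init_std=5.0, min_std=1e-4, mean_scale=5.0):
        """Turn raw actor output into (mean, std) of a bounded action distribution."""
        mean, std = torch.chunk(out, 2, dim=-1)
        raw_init_std = np.log(np.exp(init_std) - 1)        # softplus inverse of init_std
        mean = mean_scale * torch.tanh(mean / mean_scale)  # soft-bound the mean
        std = F.softplus(std + raw_init_std) + min_std     # std starts near init_std
        return mean, std

    def add_exploration(action, action_noise=0.3):
        """Clamped Gaussian exploration noise, keeping actions in [-1, 1]."""
        return torch.clamp(torch.distributions.Normal(action, action_noise).rsample(), -1, 1)

    mean, std = squash_action_params(torch.randn(4, 2 * 6))
    action = torch.tanh(torch.distributions.Normal(mean, std).rsample())
    print(add_exploration(action).shape)                   # torch.Size([4, 6])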
class ValueModel(nn.Module): class ValueModel(nn.Module):
def __init__(self, state_size, hidden_size, num_layers=4): def __init__(self, state_size, hidden_size, num_layers=4):
@@ -140,11 +148,12 @@ class ValueModel(nn.Module):
self.num_layers = num_layers self.num_layers = num_layers
layers = [] layers = []
for i in range(self.num_layers): for i in range(self.num_layers-1):
input_channels = state_size if i == 0 else self.hidden_size input_channels = state_size if i == 0 else self.hidden_size
output_channels = self.hidden_size if i!= self.num_layers-1 else 1 output_channels = self.hidden_size
layers.append(nn.Linear(input_channels, output_channels)) layers.append(nn.Linear(input_channels, output_channels))
layers.append(nn.LeakyReLU()) layers.append(nn.LeakyReLU())
layers.append(nn.Linear(self.hidden_size, int(np.prod(1))))
self.value_model = nn.Sequential(*layers) self.value_model = nn.Sequential(*layers)
def forward(self, state): def forward(self, state):
@@ -169,6 +178,7 @@ class RewardModel(nn.Module):
return torch.distributions.independent.Independent( return torch.distributions.independent.Independent(
torch.distributions.Normal(reward, 1), 1) torch.distributions.Normal(reward, 1), 1)
"""
class TransitionModel(nn.Module): class TransitionModel(nn.Module):
def __init__(self, state_size, hidden_size, action_size, history_size): def __init__(self, state_size, hidden_size, action_size, history_size):
super().__init__() super().__init__()
@@ -180,6 +190,7 @@ class TransitionModel(nn.Module):
self.act_fn = nn.LeakyReLU() self.act_fn = nn.LeakyReLU()
self.fc_state_action = nn.Linear(state_size + action_size, hidden_size) self.fc_state_action = nn.Linear(state_size + action_size, hidden_size)
self.ln = nn.LayerNorm(hidden_size)
self.history_cell = nn.GRUCell(hidden_size + history_size, history_size) self.history_cell = nn.GRUCell(hidden_size + history_size, history_size)
self.fc_state_prior = nn.Linear(history_size + state_size + action_size, 2 * state_size) self.fc_state_prior = nn.Linear(history_size + state_size + action_size, 2 * state_size)
self.fc_state_posterior = nn.Linear(history_size + state_size + action_size, 2 * state_size) self.fc_state_posterior = nn.Linear(history_size + state_size + action_size, 2 * state_size)
@@ -194,12 +205,25 @@ class TransitionModel(nn.Module):
distribution = torch.distributions.independent.Independent(distribution, 1) distribution = torch.distributions.independent.Independent(distribution, 1)
return distribution return distribution
def imagine_step(self, prev_state, prev_action, prev_history): def stack_states(self, states, dim=0):
state_action = self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))) s = dict(
prev_hist = prev_history.detach() mean = torch.stack([state['mean'] for state in states], dim=dim),
history = self.history_cell(torch.cat([state_action, prev_hist], dim=-1), prev_hist) std = torch.stack([state['std'] for state in states], dim=dim),
sample = torch.stack([state['sample'] for state in states], dim=dim),
history = torch.stack([state['history'] for state in states], dim=dim),)
if 'distribution' in states:
dist = dict(distribution = [state['distribution'] for state in states])
s.update(dist)
return s
state_prior = self.fc_state_prior(torch.cat([history, prev_state, prev_action], dim=-1)) def seq_to_batch(self, state, name):
return dict(
sample = torch.reshape(state[name], (state[name].shape[0]* state[name].shape[1], *state[name].shape[2:])))
def imagine_step(self, state, action, history):
state_action = self.ln(self.act_fn(self.fc_state_action(torch.cat([state, action], dim=-1))))
imag_hist = self.history_cell(torch.cat([state_action, history], dim=-1), history)
state_prior = self.fc_state_prior(torch.cat([imag_hist, state, action], dim=-1))
state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1) state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)
state_prior_std = F.softplus(state_prior_std) state_prior_std = F.softplus(state_prior_std)
@@ -208,19 +232,9 @@ class TransitionModel(nn.Module):
# Sampling via reparameterization Trick # Sampling via reparameterization Trick
sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std) sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std)
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist} prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": imag_hist, "distribution": state_prior_dist}
return prior return prior
def stack_states(self, states, dim=0):
s = dict(
mean = torch.stack([state['mean'] for state in states], dim=dim),
std = torch.stack([state['std'] for state in states], dim=dim),
sample = torch.stack([state['sample'] for state in states], dim=dim),
history = torch.stack([state['history'] for state in states], dim=dim),)
dist = dict(distribution = [state['distribution'] for state in states])
s.update(dist)
return s
def imagine_rollout(self, state, action, history, horizon): def imagine_rollout(self, state, action, history, horizon):
imagined_priors = [] imagined_priors = []
for i in range(horizon): for i in range(horizon):
@@ -231,9 +245,125 @@ class TransitionModel(nn.Module):
imagined_priors = self.stack_states(imagined_priors, dim=0) imagined_priors = self.stack_states(imagined_priors, dim=0)
return imagined_priors return imagined_priors
def observe_step(self, prev_state, prev_action, prev_history, nonterms):
state_action = self.ln(self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))))
current_history = self.history_cell(torch.cat([state_action, prev_history], dim=-1), prev_history)
state_prior = self.fc_state_prior(torch.cat([prev_history, prev_state, prev_action], dim=-1))
state_prior_mean, state_prior_std = torch.chunk(state_prior*nonterms, 2, dim=-1)
state_prior_std = F.softplus(state_prior_std) + 0.1
sample_state_prior = state_prior_mean + torch.randn_like(state_prior_mean) * state_prior_std
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": current_history}
return prior
def observe_rollout(self, rollout_states, rollout_actions, init_history, nonterms):
observed_rollout = []
for i in range(rollout_states.shape[0]):
actions = rollout_actions[i] * nonterms[i]
prior = self.observe_step(rollout_states[i], actions, init_history, nonterms[i])
init_history = prior["history"]
observed_rollout.append(prior)
observed_rollout = self.stack_states(observed_rollout, dim=0)
return observed_rollout
def reparemeterize(self, mean, std): def reparemeterize(self, mean, std):
eps = torch.randn_like(std) eps = torch.randn_like(std)
return mean + eps * std return mean + eps * std
"""
class TransitionModel(nn.Module):
def __init__(self, state_size, hidden_size, action_size, history_size):
super().__init__()
self.state_size = state_size
self.hidden_size = hidden_size
self.action_size = action_size
self.history_size = history_size
self.act_fn = nn.ELU()
self.fc_state_action = nn.Linear(state_size + action_size, hidden_size)
self.history_cell = nn.GRUCell(hidden_size, history_size)
self.fc_state_mu = nn.Linear(history_size + hidden_size, state_size)
self.fc_state_sigma = nn.Linear(history_size + hidden_size, state_size)
self.batch_norm = nn.BatchNorm1d(hidden_size)
self.batch_norm2 = nn.BatchNorm1d(state_size)
self.min_sigma = 1e-4
self.max_sigma = 1e0
def init_states(self, batch_size, device):
self.prev_state = torch.zeros(batch_size, self.state_size).to(device)
self.prev_action = torch.zeros(batch_size, self.action_size).to(device)
self.prev_history = torch.zeros(batch_size, self.history_size).to(device)
def get_dist(self, mean, std):
distribution = torch.distributions.Normal(mean, std)
distribution = torch.distributions.independent.Independent(distribution, 1)
return distribution
def stack_states(self, states, dim=0):
s = dict(
mean = torch.stack([state['mean'] for state in states], dim=dim),
std = torch.stack([state['std'] for state in states], dim=dim),
sample = torch.stack([state['sample'] for state in states], dim=dim),
history = torch.stack([state['history'] for state in states], dim=dim),)
if 'distribution' in states:
dist = dict(distribution = [state['distribution'] for state in states])
s.update(dist)
return s
def seq_to_batch(self, state, name):
return dict(
sample = torch.reshape(state[name], (state[name].shape[0]* state[name].shape[1], *state[name].shape[2:])))
def imagine_step(self, state, action, history):
next_state_action_enc = self.act_fn(self.batch_norm(self.fc_state_action(torch.cat([state, action], dim=-1))))
imag_history = self.history_cell(next_state_action_enc, history)
next_state_mu = self.act_fn(self.batch_norm2(self.fc_state_mu(torch.cat([next_state_action_enc, imag_history], dim=-1))))
next_state_sigma = torch.sigmoid(self.fc_state_sigma(torch.cat([next_state_action_enc, imag_history], dim=-1)))
next_state_sigma = self.min_sigma + (self.max_sigma - self.min_sigma) * next_state_sigma
# Normal Distribution
next_state_dist = self.get_dist(next_state_mu, next_state_sigma)
next_state_sample = self.reparemeterize(next_state_mu, next_state_sigma)
prior = {"mean": next_state_mu, "std": next_state_sigma, "sample": next_state_sample, "history": imag_history, "distribution": next_state_dist}
return prior
def imagine_rollout(self, state, action, history, horizon):
imagined_priors = []
for i in range(horizon):
prior = self.imagine_step(state, action, history)
state = prior["sample"]
history = prior["history"]
imagined_priors.append(prior)
imagined_priors = self.stack_states(imagined_priors, dim=0)
return imagined_priors
def observe_step(self, prev_state, prev_action, prev_history):
state_action_enc = self.act_fn(self.batch_norm(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))))
current_history = self.history_cell(state_action_enc, prev_history)
state_mu = self.act_fn(self.batch_norm2(self.fc_state_mu(torch.cat([state_action_enc, prev_history], dim=-1))))
state_sigma = F.softplus(self.fc_state_sigma(torch.cat([state_action_enc, prev_history], dim=-1)))
sample_state = state_mu + torch.randn_like(state_mu) * state_sigma
state_enc = {"mean": state_mu, "std": state_sigma, "sample": sample_state, "history": current_history}
return state_enc
def observe_rollout(self, rollout_states, rollout_actions, init_history, nonterms):
observed_rollout = []
for i in range(rollout_states.shape[0]):
rollout_states_ = rollout_states[i]
rollout_actions_ = rollout_actions[i]
init_history_ = nonterms[i] * init_history
state_enc = self.observe_step(rollout_states_, rollout_actions_, init_history_)
init_history = state_enc["history"]
observed_rollout.append(state_enc)
observed_rollout = self.stack_states(observed_rollout, dim=0)
return observed_rollout
def reparemeterize(self, mean, std):
eps = torch.randn_like(mean)
return mean + eps * std
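Taken together, observe_step/observe_rollout filter a stored sequence through the GRU cell, while imagine_step/imagine_rollout unroll the same cell forward from a starting state using only its own samples. A hedged smoke-test of the active (un-commented) TransitionModel, assuming this file is importable as models and using the state/history/hidden sizes from the argparse defaults further down; batch and sequence lengths are arbitrary:

    import torch
    from models import TransitionModel  # the active class defined above

    state_size, hidden_size, action_size, history_size = 100, 512, 6, 256
    seq_len, batch, horizon = 10, 8, 15

    tm = TransitionModel(state_size, hidden_size, action_size, history_size)

    states = torch.zeros(seq_len, batch, state_size)
    actions = torch.zeros(seq_len, batch, action_size)
    nonterms = torch.ones(seq_len, batch, 1)

    tm.init_states(batch, device="cpu")
    observed = tm.observe_rollout(states, actions, tm.prev_history, nonterms)
    print(observed["sample"].shape)    # (seq_len, batch, state_size)

    imagined = tm.imagine_rollout(observed["sample"][-1],
                                  torch.zeros(batch, action_size),
                                  observed["history"][-1],
                                  horizon)
    print(imagined["mean"].shape)      # (horizon, batch, state_size)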
class TanhBijector(torch.distributions.Transform): class TanhBijector(torch.distributions.Transform):
@@ -300,6 +430,7 @@ class ContrastiveHead(nn.Module):
return logits return logits
"""
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def __init__(self, last_states, current_states, negative_current_states, predicted_current_states): def __init__(self, last_states, current_states, negative_current_states, predicted_current_states):
super(CLUBSample, self).__init__() super(CLUBSample, self).__init__()
@@ -313,7 +444,7 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
sample = state_dict["sample"] #dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses sample = state_dict["sample"] #dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses
mu = dist.mean mu = dist.mean
var = dist.variance var = dist.variance
return mu, var, sample return mu.detach(), var.detach(), sample.detach()
def loglikeli(self): def loglikeli(self):
_, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states) _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
@@ -330,15 +461,136 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
random_index = torch.randperm(sample_size).long() random_index = torch.randperm(sample_size).long()
pos = (-(mu_curr - pred_sample)**2 /var_curr).sum(dim=1).mean(dim=0) pos = (-(mu_curr - pred_sample)**2 /var_curr).sum(dim=1).mean(dim=0)
neg = (-(mu_curr - pred_sample[random_index])**2 /var_curr).sum(dim=1).mean(dim=0) #neg = (-(mu_curr - pred_sample[random_index])**2 /var_curr).sum(dim=1).mean(dim=0)
#neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0) neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0)
upper_bound = pos - neg upper_bound = pos - neg
return upper_bound/2 return upper_bound/2
def learning_loss(self): def learning_loss(self):
return - self.loglikeli() return - self.loglikeli()
"""
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def __init__(self, x_dim, y_dim, hidden_size):
super(CLUBSample, self).__init__()
self.p_mu = nn.Sequential(nn.Linear(x_dim, hidden_size//2),
nn.ReLU(),
nn.Linear(hidden_size//2, y_dim))
self.p_logvar = nn.Sequential(nn.Linear(x_dim, hidden_size//2),
nn.ReLU(),
nn.Linear(hidden_size//2, y_dim),
nn.Tanh())
def get_mu_logvar(self, x_samples):
mu = self.p_mu(x_samples)
logvar = self.p_logvar(x_samples)
return mu, logvar
def loglikeli(self, x_samples, y_samples):
mu, logvar = self.get_mu_logvar(x_samples)
return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0)
def forward(self, x_samples, y_samples, y_negatives):
mu, logvar = self.get_mu_logvar(x_samples)
sample_size = x_samples.shape[0]
#random_index = torch.randint(sample_size, (sample_size,)).long()
random_index = torch.randperm(sample_size).long()
positive = -(mu - y_samples)**2 / logvar.exp()
#negative = - (mu - y_samples[random_index])**2 / logvar.exp()
negative = -(mu - y_negatives)**2 / logvar.exp()
upper_bound = (positive.sum(dim = -1) - negative.sum(dim = -1)).mean()
return upper_bound/2.
def learning_loss(self, x_samples, y_samples):
return -self.loglikeli(x_samples, y_samples)
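Compared with the commented-out version above, the rewritten CLUBSample is stateless: it learns a variational network q(y|x) via learning_loss (negative log-likelihood) and then forward returns a sampled upper bound on the mutual information between x and y, here with explicit negatives. A minimal usage sketch with the state/hidden sizes from the argparse defaults; in the diff itself the CLUB parameters are folded into the world-model optimizer rather than trained separately as shown here:

    import torch
    from models import CLUBSample  # the rewritten estimator above

    state_size, hidden_size, batch = 100, 512, 64
    club = CLUBSample(x_dim=state_size, y_dim=state_size, hidden_size=hidden_size)
    opt = torch.optim.Adam(club.parameters(), lr=1e-4)

    x = torch.randn(batch, state_size)       # e.g. predicted current state
    y = torch.randn(batch, state_size)       # e.g. encoded current state
    y_neg = y[torch.randperm(batch)]         # negatives: shuffled batch

    opt.zero_grad()
    club.learning_loss(x, y).backward()      # fit q(y | x)
    opt.step()

    ub = club(x.detach(), y.detach(), y_neg.detach())   # MI upper-bound penalty
    print(float(ub))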
class QFunction(nn.Module):
"""MLP for q-function."""
def __init__(self, obs_dim, action_dim, hidden_dim):
super().__init__()
self.trunk = nn.Sequential(
nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, obs, action):
assert obs.size(0) == action.size(0)
obs_action = torch.cat([obs, action], dim=1)
return self.trunk(obs_action)
class Critic(nn.Module):
"""Critic network, employes two q-functions."""
def __init__(
self, obs_shape, action_shape, hidden_dim, encoder_feature_dim):
super().__init__()
self.Q1 = QFunction(
self.encoder.feature_dim, action_shape[0], hidden_dim
)
self.Q2 = QFunction(
self.encoder.feature_dim, action_shape[0], hidden_dim
)
self.outputs = dict()
def forward(self, obs, action, detach_encoder=False):
# detach_encoder allows stopping gradient propagation to the encoder
obs = self.encoder(obs, detach=detach_encoder)
q1 = self.Q1(obs, action)
q2 = self.Q2(obs, action)
self.outputs['q1'] = q1
self.outputs['q2'] = q2
return q1, q2
class SampleDist:
def __init__(self, dist, samples=100):
self._dist = dist
self._samples = samples
@property
def name(self):
return 'SampleDist'
def __getattr__(self, name):
return getattr(self._dist, name)
def mean(self):
sample = self._dist.rsample(self._samples)
return torch.mean(sample, 0)
def mode(self):
dist = self._dist.expand((self._samples, *self._dist.batch_shape))
sample = dist.rsample()
logprob = dist.log_prob(sample)
batch_size = sample.size(1)
feature_size = sample.size(2)
indices = torch.argmax(logprob, dim=0).reshape(1, batch_size, 1).expand(1, batch_size, feature_size)
return torch.gather(sample, 0, indices).squeeze(0)
def entropy(self):
dist = self._dist.expand((self._samples, *self._dist.batch_shape))
sample = dist.rsample()
logprob = dist.log_prob(sample)
return -torch.mean(logprob, 0)
def sample(self):
return self._dist.sample()
if "__name__ == __main__": if "__name__ == __main__":
pass tr = TransitionModel(50, 512, 1, 256)

DPI/replay_buffer.py (new file, 51 lines)

@@ -0,0 +1,51 @@
import torch
import numpy as np
class ReplayBuffer:
def __init__(self, size, obs_shape, action_size, seq_len, batch_size):
self.size = size
self.obs_shape = obs_shape
self.action_size = action_size
self.seq_len = seq_len
self.batch_size = batch_size
self.idx = 0
self.full = False
self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.actions = np.empty((size, action_size), dtype=np.float32)
self.rewards = np.empty((size,), dtype=np.float32)
self.terminals = np.empty((size,), dtype=np.float32)
self.steps, self.episodes = 0, 0
def add(self, obs, ac, next_obs, rew, done):
self.observations[self.idx] = obs
self.next_observations[self.idx] = next_obs
self.actions[self.idx] = ac
self.rewards[self.idx] = rew
self.terminals[self.idx] = done
self.idx = (self.idx + 1) % self.size
self.full = self.full or self.idx == 0
self.steps += 1
self.episodes = self.episodes + (1 if done else 0)
def _sample_idx(self, L):
valid_idx = False
while not valid_idx:
idx = np.random.randint(0, self.size if self.full else self.idx - L)
idxs = np.arange(idx, idx + L) % self.size
valid_idx = not self.idx in idxs[1:]
return idxs
def _retrieve_batch(self, idxs, n, L):
vec_idxs = idxs.transpose().reshape(-1) # Unroll indices
observations = self.observations[vec_idxs]
next_observations = self.next_observations[vec_idxs]
return observations.reshape(L, n, *observations.shape[1:]),self.actions[vec_idxs].reshape(L, n, -1), next_observations.reshape(L, n, *next_observations.shape[1:]), self.rewards[vec_idxs].reshape(L, n), self.terminals[vec_idxs].reshape(L, n)
def sample(self):
n = self.batch_size
l = self.seq_len
obs,acs,nxt_obs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
return obs,acs,nxt_obs,rews,terms
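A quick usage sketch of the new buffer: transitions go in flat, and sample() returns contiguous sequences shaped (seq_len, batch, ...). The 9-channel observation below assumes 3 stacked RGB frames, as noted elsewhere in this diff; the dummy transitions are only for illustration:

    import numpy as np
    from replay_buffer import ReplayBuffer  # DPI/replay_buffer.py above

    buffer = ReplayBuffer(size=1000, obs_shape=(9, 84, 84), action_size=6,
                          seq_len=51, batch_size=4)

    obs = np.zeros((9, 84, 84), dtype=np.uint8)
    for t in range(200):
        action = np.random.uniform(-1, 1, size=6).astype(np.float32)
        buffer.add(obs, action, obs, rew=0.0, done=(t % 50 == 49))

    obs_seq, act_seq, next_obs_seq, rew_seq, term_seq = buffer.sample()
    print(obs_seq.shape)   # (51, 4, 9, 84, 84) -> (seq_len, batch, *obs_shape)
    print(act_seq.shape)   # (51, 4, 6)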


@@ -6,11 +6,12 @@ import wandb
import random import random
import argparse import argparse
import numpy as np import numpy as np
from collections import OrderedDict
import utils import utils
from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image, shuffle_along_axis, Logger
from replay_buffer import ReplayBuffer
from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
from logger import Logger
from video import VideoRecorder from video import VideoRecorder
from dmc2gym.wrappers import set_global_var from dmc2gym.wrappers import set_global_var
@@ -40,19 +41,22 @@ def parse_args():
parser.add_argument('--resource_files', type=str) parser.add_argument('--resource_files', type=str)
parser.add_argument('--eval_resource_files', type=str) parser.add_argument('--eval_resource_files', type=str)
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none']) parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
parser.add_argument('--total_frames', default=1000, type=int) # 10000 parser.add_argument('--total_frames', default=5000, type=int) # 10000
parser.add_argument('--high_noise', action='store_true') parser.add_argument('--high_noise', action='store_true')
# replay buffer # replay buffer
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #50000 parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #50000
parser.add_argument('--episode_length', default=51, type=int) parser.add_argument('--episode_length', default=51, type=int)
# train # train
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc']) parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
parser.add_argument('--init_steps', default=10000, type=int) parser.add_argument('--init_steps', default=5000, type=int)
parser.add_argument('--num_train_steps', default=10000, type=int) parser.add_argument('--num_train_steps', default=100000, type=int)
parser.add_argument('--batch_size', default=30, type=int) #512 parser.add_argument('--update_steps', default=10, type=int)
parser.add_argument('--state_size', default=256, type=int) parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--hidden_size', default=128, type=int) parser.add_argument('--state_size', default=100, type=int)
parser.add_argument('--history_size', default=128, type=int) parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--history_size', default=256, type=int)
parser.add_argument('--episode_collection', default=5, type=int)
parser.add_argument('--episodes_buffer', default=5, type=int, help='Initial number of episodes to store in the buffer')
parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models') parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models')
parser.add_argument('--load_encoder', default=None, type=str) parser.add_argument('--load_encoder', default=None, type=str)
parser.add_argument('--imagine_horizon', default=15, type=str) parser.add_argument('--imagine_horizon', default=15, type=str)
@@ -60,42 +64,33 @@ def parse_args():
# eval # eval
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000 parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
parser.add_argument('--num_eval_episodes', default=20, type=int) parser.add_argument('--num_eval_episodes', default=20, type=int)
parser.add_argument('--evaluation_interval', default=10000, type=int) # TODO: master had 10000
# value # value
parser.add_argument('--value_lr', default=8e-5, type=float) parser.add_argument('--value_lr', default=8e-6, type=float)
parser.add_argument('--value_beta', default=0.9, type=float)
parser.add_argument('--value_tau', default=0.005, type=float)
parser.add_argument('--value_target_update_freq', default=100, type=int) parser.add_argument('--value_target_update_freq', default=100, type=int)
parser.add_argument('--td_lambda', default=0.95, type=int) parser.add_argument('--td_lambda', default=0.95, type=int)
# actor # actor
parser.add_argument('--actor_lr', default=8e-5, type=float) parser.add_argument('--actor_lr', default=8e-6, type=float)
parser.add_argument('--actor_beta', default=0.9, type=float) parser.add_argument('--actor_beta', default=0.9, type=float)
parser.add_argument('--actor_log_std_min', default=-10, type=float) parser.add_argument('--actor_log_std_min', default=-10, type=float)
parser.add_argument('--actor_log_std_max', default=2, type=float) parser.add_argument('--actor_log_std_max', default=2, type=float)
parser.add_argument('--actor_update_freq', default=2, type=int) parser.add_argument('--actor_update_freq', default=2, type=int)
# world/encoder/decoder # world/encoder/decoder
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity']) parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
parser.add_argument('--encoder_feature_dim', default=50, type=int) parser.add_argument('--world_model_lr', default=1e-6, type=float)
parser.add_argument('--world_model_lr', default=6e-4, type=float) parser.add_argument('--decoder_lr', default=6e-6, type=float)
parser.add_argument('--past_transition_lr', default=1e-3, type=float) parser.add_argument('--reward_lr', default=8e-6, type=float)
parser.add_argument('--encoder_lr', default=1e-3, type=float) parser.add_argument('--encoder_tau', default=0.005, type=float)
parser.add_argument('--encoder_tau', default=0.001, type=float)
parser.add_argument('--encoder_stride', default=1, type=int)
parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction']) parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
parser.add_argument('--decoder_lr', default=1e-3, type=float)
parser.add_argument('--decoder_update_freq', default=1, type=int)
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_layers', default=4, type=int)
parser.add_argument('--num_filters', default=32, type=int) parser.add_argument('--num_filters', default=32, type=int)
parser.add_argument('--aug', action='store_true') parser.add_argument('--aug', action='store_true')
# sac # sac
parser.add_argument('--discount', default=0.99, type=float) parser.add_argument('--discount', default=0.99, type=float)
parser.add_argument('--init_temperature', default=0.01, type=float)
parser.add_argument('--alpha_lr', default=1e-3, type=float)
parser.add_argument('--alpha_beta', default=0.9, type=float)
# misc # misc
parser.add_argument('--seed', default=1, type=int) parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--logging_freq', default=100, type=int) parser.add_argument('--logging_freq', default=100, type=int)
parser.add_argument('--saving_interval', default=1000, type=int) parser.add_argument('--saving_interval', default=2500, type=int)
parser.add_argument('--work_dir', default='.', type=str) parser.add_argument('--work_dir', default='.', type=str)
parser.add_argument('--save_tb', default=False, action='store_true') parser.add_argument('--save_tb', default=False, action='store_true')
parser.add_argument('--save_model', default=False, action='store_true') parser.add_argument('--save_model', default=False, action='store_true')
@@ -126,6 +121,7 @@ class DPI:
self.env = make_env(self.args) self.env = make_env(self.args)
#self.args.seed = np.random.randint(0, 1000) #self.args.seed = np.random.randint(0, 1000)
self.env.seed(self.args.seed) self.env.seed(self.args.seed)
self.global_episodes = 0
# noiseless environment setup # noiseless environment setup
self.args.version = 2 # env_id changes to v2 self.args.version = 2 # env_id changes to v2
@@ -137,14 +133,14 @@ class DPI:
self.env = utils.FrameStack(self.env, k=self.args.frame_stack) self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
self.env = utils.ActionRepeat(self.env, self.args.action_repeat) self.env = utils.ActionRepeat(self.env, self.args.action_repeat)
self.env = utils.NormalizeActions(self.env) self.env = utils.NormalizeActions(self.env)
self.env = utils.TimeLimit(self.env, 1000 // args.action_repeat)
# create replay buffer # create replay buffer
self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity, self.data_buffer = ReplayBuffer(self.args.replay_buffer_capacity,
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), self.env.observation_space.shape,
action_size=self.env.action_space.shape[0], self.env.action_space.shape[0],
seq_len=self.args.episode_length, self.args.episode_length,
batch_size=args.batch_size, self.args.batch_size)
args=self.args)
# create work directory # create work directory
utils.make_dir(self.args.work_dir) utils.make_dir(self.args.work_dir)
@@ -161,16 +157,19 @@ class DPI:
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84) obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84)
state_size=self.args.state_size # 128 state_size=self.args.state_size # 128
).to(device) ).to(device)
self.obs_encoder.apply(self.init_weights)
self.obs_encoder_momentum = ObservationEncoder( self.obs_encoder_momentum = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84) obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84)
state_size=self.args.state_size # 128 state_size=self.args.state_size # 128
).to(device) ).to(device)
self.obs_encoder_momentum.apply(self.init_weights)
self.obs_decoder = ObservationDecoder( self.obs_decoder = ObservationDecoder(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
output_shape=(self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84) output_shape=(self.args.channels*self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84)
).to(device) ).to(device)
self.obs_decoder.apply(self.init_weights)
self.transition_model = TransitionModel( self.transition_model = TransitionModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
@@ -178,6 +177,7 @@ class DPI:
action_size=self.env.action_space.shape[0], # 6 action_size=self.env.action_space.shape[0], # 6
history_size=self.args.history_size, # 128 history_size=self.args.history_size, # 128
).to(device) ).to(device)
self.transition_model.apply(self.init_weights)
# Actor Model # Actor Model
self.actor_model = Actor( self.actor_model = Actor(
@@ -185,22 +185,27 @@ class DPI:
hidden_size=self.args.hidden_size, # 256, hidden_size=self.args.hidden_size, # 256,
action_size=self.env.action_space.shape[0], # 6 action_size=self.env.action_space.shape[0], # 6
).to(device) ).to(device)
self.actor_model.apply(self.init_weights)
# Value Models # Value Models
self.value_model = ValueModel( self.value_model = ValueModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
).to(device) ).to(device)
self.value_model.apply(self.init_weights)
self.target_value_model = ValueModel( self.target_value_model = ValueModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
).to(device) ).to(device)
self.target_value_model.apply(self.init_weights)
self.reward_model = RewardModel( self.reward_model = RewardModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
).to(device) ).to(device)
self.reward_model.apply(self.init_weights)
# Contrastive Models # Contrastive Models
self.prjoection_head = ProjectionHead( self.prjoection_head = ProjectionHead(
@@ -219,23 +224,32 @@ class DPI:
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
).to(device) ).to(device)
self.club_sample = CLUBSample(
x_dim=self.args.state_size, # 128
y_dim=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
).to(device)
# model parameters # model parameters
self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_decoder.parameters()) + \ self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.prjoection_head.parameters()) + \
list(self.value_model.parameters()) + list(self.transition_model.parameters()) + \ list(self.transition_model.parameters()) + list(self.club_sample.parameters()) + \
list(self.prjoection_head.parameters()) list(self.contrastive_head.parameters())
self.past_transition_parameters = self.transition_model.parameters()
# optimizers # optimizers
self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr) self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr,eps=1e-6)
self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr) self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6)
self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr) self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr,eps=1e-6)
self.past_transition_opt = torch.optim.Adam(self.past_transition_parameters, self.args.past_transition_lr) self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), self.args.decoder_lr,eps=1e-6)
self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6)
# Create Modules # Create Modules
self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.reward_model, self.transition_model, self.prjoection_head] self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.club_sample, self.contrastive_head,
self.obs_encoder_momentum, self.prjoection_head_momentum]
self.value_modules = [self.value_model] self.value_modules = [self.value_model]
self.actor_modules = [self.actor_model] self.actor_modules = [self.actor_model]
self.decoder_modules = [self.obs_decoder]
self.reward_modules = [self.reward_model]
if use_saved: if use_saved:
self._use_saved_models(saved_model_dir) self._use_saved_models(saved_model_dir)
@@ -245,280 +259,432 @@ class DPI:
self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt'))) self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt'))) self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))
def collect_sequences(self, episodes, random=True, actor_model=None, encoder_model=None): def collect_random_sequences(self, seed_steps):
obs = self.env.reset() obs = self.env.reset()
done = False done = False
all_rews = [] all_rews = []
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files) self.global_episodes += 1
for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'): epi_reward = 0
if args.save_video: for _ in tqdm.tqdm(range(seed_steps), desc='Collecting episodes'):
self.env.video.init(enabled=True) action = self.env.action_space.sample()
next_obs, rew, done, _ = self.env.step(action)
self.data_buffer.add(obs, action, next_obs, rew, done)
obs = next_obs
epi_reward += rew
if done:
obs = self.env.reset()
done=False
all_rews.append(epi_reward)
epi_reward = 0
return all_rews
epi_reward = 0 def collect_sequences(self, collect_steps, actor_model):
for i in range(self.args.episode_length): obs = self.env.reset()
if random: done = False
action = self.env.action_space.sample() all_rews = []
else: self.global_episodes += 1
with torch.no_grad(): epi_reward = 0
obs_torch = torch.unsqueeze(torch.tensor(obs).float(),0).to(device) for episode_count in tqdm.tqdm(range(collect_steps), desc='Collecting episodes'):
state = self.obs_encoder(obs_torch)["distribution"].sample() with torch.no_grad():
action = self.actor_model(state).cpu().detach().numpy().squeeze() obs_ = torch.tensor(obs.copy(), dtype=torch.float32)
obs_ = preprocess_obs(obs_).to(device)
#state = self.get_features(obs_)["sample"].unsqueeze(0)
state = self.get_features(obs_)["distribution"].rsample()
action = actor_model(state)
action = actor_model.add_exploration(action)
action = action.cpu().numpy()[0]
next_obs, rew, done, _ = self.env.step(action)
self.data_buffer.add(obs, action, next_obs, rew, done)
next_obs, rew, done, _ = self.env.step(action) if done:
self.data_buffer.add(obs, action, next_obs, rew, episode_count+1, done) obs = self.env.reset()
done = False
if args.save_video: all_rews.append(epi_reward)
self.env.video.record(self.env) epi_reward = 0
else:
if done or i == self.args.episode_length-1: obs = next_obs
obs = self.env.reset()
done=False
else:
obs = next_obs
epi_reward += rew epi_reward += rew
all_rews.append(epi_reward)
if args.save_video:
self.env.video.save('noisy/%d.mp4' % episode_count)
print("Collected {} random episodes".format(episode_count+1))
return all_rews return all_rews
def train(self, step, total_steps): def train(self, step, total_steps):
counter = 0 # logger
while step < total_steps: logdir = os.path.dirname(os.path.realpath(__file__)) + "/log/logs/"
if not(os.path.exists(logdir)):
os.makedirs(logdir)
initial_logs = OrderedDict()
logger = Logger(logdir)
# collect experience episodic_rews = self.collect_random_sequences(self.args.init_steps//args.action_repeat)
if step !=0: self.global_step = self.data_buffer.steps
encoder = self.obs_encoder
actor = self.actor_model
#all_rews = self.collect_sequences(self.args.batch_size, random=True)
all_rews = self.collect_sequences(self.args.batch_size, random=False, actor_model=actor, encoder_model=encoder)
else:
all_rews = self.collect_sequences(self.args.batch_size, random=True)
# Group by steps and sample random batch initial_logs.update({
random_indices = self.data_buffer.sample_random_idx(self.args.batch_size * ((step//self.args.collection_interval)+1)) # random indices for batch 'train_avg_reward':np.mean(episodic_rews),
#random_indices = np.arange(self.args.batch_size * ((step//self.args.collection_interval)),self.args.batch_size * ((step//self.args.collection_interval)+1)) 'train_max_reward': np.max(episodic_rews),
last_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"observations", "cpu", random_indices=random_indices) 'train_min_reward': np.min(episodic_rews),
current_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"next_observations", device="cpu", random_indices=random_indices) 'train_std_reward':np.std(episodic_rews),
next_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"next_observations", device="cpu", offset=1, random_indices=random_indices) })
actions = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"actions", device=device, is_obs=False, random_indices=random_indices) logger.log_scalars(initial_logs, step=0)
next_actions = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"actions", device=device, is_obs=False, offset=1, random_indices=random_indices) logger.flush()
rewards = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"rewards", device=device, is_obs=False, offset=1, random_indices=random_indices)
# Preprocessing
last_observations = preprocess_obs(last_observations).to(device)
current_observations = preprocess_obs(current_observations).to(device)
next_observations = preprocess_obs(next_observations).to(device)
# Initialize transition model states
self.transition_model.init_states(self.args.batch_size, device) # (N,128)
self.history = self.transition_model.prev_history # (N,128)
# Train encoder
if step == 0:
step += 1
for _ in range(self.args.collection_interval // self.args.episode_length+1):
counter += 1
for i in range(self.args.episode_length-1):
if i > 0:
# Encode observations and next_observations
self.last_states_dict = self.get_features(last_observations[i])
self.current_states_dict = self.get_features(current_observations[i])
self.next_states_dict = self.get_features(next_observations[i], momentum=True)
self.action = actions[i] # (N,6)
self.next_action = next_actions[i] # (N,6)
history = self.transition_model.prev_history
# Encode negative observations
idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch
random_time_index = torch.randint(0, self.args.episode_length-2, (1,)).item() # random time index
negative_current_observations = current_observations[random_time_index][idx]
self.negative_current_states_dict = self.obs_encoder(negative_current_observations)
# Predict current state from past state with transition model
last_states_sample = self.last_states_dict["sample"]
predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
self.history = predicted_current_state_dict["history"]
# Calculate upper bound loss
likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict,
self.current_states_dict,
self.negative_current_states_dict,
predicted_current_state_dict
)
# Calculate encoder loss
encoder_loss = self._past_encoder_loss(self.current_states_dict,
predicted_current_state_dict)
# contrastive projection
vec_anchor = predicted_current_state_dict["sample"]
vec_positive = self.next_states_dict["sample"].detach()
z_anchor = self.prjoection_head(vec_anchor, self.action)
z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach()
# contrastive loss
logits = self.contrastive_head(z_anchor, z_positive)
labels = torch.arange(logits.shape[0]).long().to(device)
lb_loss = F.cross_entropy(logits, labels)
# behaviour learning
with FreezeParameters(self.world_model_modules):
imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(),
self.next_action, self.history.detach(),
imagine_horizon)
# decoder loss
horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon])
decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:]))
# reward loss
reward_dist = self.reward_model(self.current_states_dict["sample"])
reward_loss = -torch.mean(reward_dist.log_prob(rewards[:-1]))
# update models
world_model_loss = encoder_loss + 100 * ub_loss + lb_loss + reward_loss + decoder_loss * 1e-2
self.world_model_opt.zero_grad()
world_model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step()
# update momentum encoder
soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
# update momentum projection head
soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
# actor loss
with FreezeParameters(self.world_model_modules + self.value_modules):
imag_rew_dist = self.reward_model(imagined_rollout["sample"])
target_imag_val_dist = self.target_value_model(imagined_rollout["sample"])
imag_rews = imag_rew_dist.mean
target_imag_vals = target_imag_val_dist.mean
discounts = self.args.discount * torch.ones_like(imag_rews).detach()
self.target_returns = self._compute_lambda_return(imag_rews[:-1],
target_imag_vals[:-1],
discounts[:-1] ,
self.args.td_lambda,
target_imag_vals[-1])
discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
self.discounts = torch.cumprod(discounts, 0).detach()
actor_loss = -torch.mean(self.discounts * self.target_returns)
# update actor
self.actor_opt.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
self.actor_opt.step()
# value loss
with torch.no_grad():
value_feat = imagined_rollout["sample"][:-1].detach()
value_targ = self.target_returns.detach()
value_dist = self.value_model(value_feat)
value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
# update value
self.value_opt.zero_grad()
value_loss.backward()
nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
self.value_opt.step()
# update target value
if step % self.args.value_target_update_freq == 0:
self.target_value_model = copy.deepcopy(self.value_model)
# counter for reward
count = np.arange((counter-1) * (self.args.batch_size), (counter) * (self.args.batch_size))
if step % self.args.logging_freq: while self.global_step < total_steps:
writer.add_scalar('World Loss/World Loss', world_model_loss.detach().item(), step) logs = OrderedDict()
writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss.detach().item(), step) step += 1
writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, step) for update_steps in range(self.args.update_steps):
writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), step) model_loss, actor_loss, value_loss, actor_model = self.update((step-1)*args.update_steps + update_steps)
writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step)
writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), step)
step += 1 initial_logs.update({
if step>total_steps: 'model_loss' : model_loss,
print("Training finished") 'actor_loss': actor_loss,
break 'value_loss': value_loss,
'train_avg_reward':np.mean(episodic_rews),
# save model 'train_max_reward': np.max(episodic_rews),
if step % self.args.saving_interval == 0: 'train_min_reward': np.min(episodic_rews),
path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth" 'train_std_reward':np.std(episodic_rews),
self.save_models(path) })
logger.log_scalars(logs, self.global_step)
#torch.cuda.empty_cache() # memory leak issues
for j in range(len(all_rews)):
writer.add_scalar('Rewards/Rewards', all_rews[j], count[j])
def evaluate(self, env, eval_episodes, render=False): print("########## Global Step:", self.global_step, " ##########")
for key, value in initial_logs.items():
print(key, " : ", value)
episode_rew = np.zeros((eval_episodes)) episodic_rews = self.collect_sequences(1000//self.args.action_repeat, actor_model)
video_images = [[] for _ in range(eval_episodes)] if self.global_step % 3150 == 0 and self.data_buffer.steps!=0: #self.args.evaluation_interval == 0:
print("Saving model")
path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
self.save_models(path)
self.evaluate()
for i in range(eval_episodes): self.global_step = self.data_buffer.steps * self.args.action_repeat
obs = env.reset()
"""
# collect experience
if step !=0:
encoder = self.obs_encoder
actor = self.actor_model
all_rews = self.collect_sequences(self.args.episode_collection, actor_model=actor, encoder_model=encoder)
"""
def collect_batch(self):
obs_, acs_, nxt_obs_, rews_, terms_ = self.data_buffer.sample()
obs = torch.tensor(obs_, dtype=torch.float32)[1:]
last_obs = torch.tensor(obs_, dtype=torch.float32)[:-1]
nxt_obs = torch.tensor(nxt_obs_, dtype=torch.float32)[1:]
acs = torch.tensor(acs_, dtype=torch.float32)[:-1].to(device)
nxt_acs = torch.tensor(acs_, dtype=torch.float32)[1:].to(device)
rews = torch.tensor(rews_, dtype=torch.float32)[:-1].to(device).unsqueeze(-1)
nonterms = torch.tensor((1.0-terms_), dtype=torch.float32)[:-1].to(device).unsqueeze(-1)
last_obs = preprocess_obs(last_obs).to(device)
obs = preprocess_obs(obs).to(device)
nxt_obs = preprocess_obs(nxt_obs).to(device)
return last_obs, obs, nxt_obs, acs, rews, nxt_acs, nonterms
def update(self, step):
last_observations, current_observations, next_observations, actions, rewards, next_actions, nonterms = self.collect_batch()
#last_observations, current_observations, next_observations, actions, next_actions, rewards = self.select_one_batch()
world_loss, enc_loss, rew_loss, dec_loss, ub_loss, lb_loss = self.world_model_losses(last_observations,
current_observations,
next_observations,
actions,
next_actions,
rewards,
nonterms)
self.world_model_opt.zero_grad()
world_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step()
self.decoder_opt.zero_grad()
dec_loss.backward()
nn.utils.clip_grad_norm_(self.obs_decoder.parameters(), self.args.grad_clip_norm)
self.decoder_opt.step()
self.reward_opt.zero_grad()
rew_loss.backward()
nn.utils.clip_grad_norm_(self.reward_model.parameters(), self.args.grad_clip_norm)
self.reward_opt.step()
actor_loss = self.actor_model_losses()
self.actor_opt.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
self.actor_opt.step()
value_loss = self.value_model_losses()
self.value_opt.zero_grad()
value_loss.backward()
nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
self.value_opt.step()
# update momentum encoder and projection head
soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
# update target value networks
#if step % self.args.value_target_update_freq == 0:
# self.target_value_model = copy.deepcopy(self.value_model)
if step % self.args.logging_freq:
writer.add_scalar('World Loss/World Loss', world_loss.detach().item(), step)
writer.add_scalar('Main Models Loss/Encoder Loss', enc_loss.detach().item(), step)
writer.add_scalar('Main Models Loss/Decoder Loss', dec_loss.detach().item(), step)
writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), step)
writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step)
writer.add_scalar('Actor Critic Loss/Reward Loss', rew_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Lower Bound Loss', -lb_loss.detach().item(), step)
return world_loss.item(), actor_loss.item(), value_loss.item(), self.actor_model
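soft_update_params comes from utils and is not shown in this diff; the momentum encoder and projection head updates above rely on it being the usual Polyak/EMA rule. A sketch of that assumed form:

    import torch.nn as nn

    def soft_update_params(net: nn.Module, target_net: nn.Module, tau: float):
        """Polyak averaging: target <- tau * online + (1 - tau) * target."""
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)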
def world_model_losses(self, last_obs, curr_obs, nxt_obs, actions, nxt_actions, rewards, nonterms):
# get features
self.last_state_feat = self.get_features(last_obs)
self.curr_state_feat = self.get_features(curr_obs)
self.nxt_state_feat = self.get_features(nxt_obs)
self.nxt_state_feat_lb = self.get_features(nxt_obs, momentum=True)
# states
self.last_state_enc = self.last_state_feat["distribution"].rsample() #self.last_state_feat["sample"]
self.curr_state_enc = self.curr_state_feat["distribution"].rsample() #self.curr_state_feat["sample"]
self.nxt_state_enc = self.nxt_state_feat["distribution"].rsample() #self.nxt_state_feat["sample"]
self.nxt_state_enc_lb = self.nxt_state_feat_lb["distribution"].rsample() #self.nxt_state_feat_lb["sample"]
# predict next states
self.transition_model.init_states(self.args.batch_size, device) # (N,128)
self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history, nonterms)
self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"])
self.pred_curr_state_enc = self.pred_curr_state_dist.rsample() #self.observed_rollout["sample"]
# encoder loss
enc_loss = self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist)
# reward loss
rew_dist = self.reward_model(self.curr_state_enc.detach())
#print(torch.cat([rew_dist.mean[0], rewards[0]],dim=-1))
rew_loss = -torch.mean(rew_dist.log_prob(rewards))
# decoder loss
dec_dist = self.obs_decoder(self.nxt_state_enc.detach())
dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs))
# upper bound loss
past_ub_loss = 0
for i in range(self.curr_state_enc.shape[0]):
_, ub_loss = self._upper_bound_minimization(self.curr_state_enc[i],
self.pred_curr_state_enc[i])
ub_loss = ub_loss + past_ub_loss
past_ub_loss = ub_loss
ub_loss = ub_loss / self.curr_state_enc.shape[0]
ub_loss = 1 * ub_loss
# lower bound loss
# contrastive projection
vec_anchor = self.pred_curr_state_enc.detach()
vec_positive = self.nxt_state_enc_lb.detach()
z_anchor = self.prjoection_head(vec_anchor, nxt_actions)
z_positive = self.prjoection_head_momentum(vec_positive, nxt_actions)
# contrastive loss
past_lb_loss = 0
for i in range(z_anchor.shape[0]):
logits = self.contrastive_head(z_anchor[i], z_positive[i])
labels = torch.arange(logits.shape[0]).long().to(device)
lb_loss = F.cross_entropy(logits, labels) + past_lb_loss
past_lb_loss = lb_loss.detach().item()
lb_loss = -0.01 * lb_loss/(z_anchor.shape[0])
world_loss = enc_loss + ub_loss + lb_loss
return world_loss, enc_loss , rew_loss, dec_loss, ub_loss, lb_loss
def actor_model_losses(self):
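# Actor objective: roll the latent dynamics forward for `imagine_horizon` steps from
# the current posteriors, score the imagined states with the reward and value heads,
# and maximize the discounted lambda-returns.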
with torch.no_grad():
#curr_state_enc = self.curr_state_enc.reshape(self.args.episode_length-1,-1) #self.transition_model.seq_to_batch(self.curr_state_feat, "sample")["sample"]
#curr_state_hist = self.observed_rollout["history"].reshape(self.args.episode_length-1,-1) #self.transition_model.seq_to_batch(self.observed_rollout, "history")["sample"]
curr_state_enc = self.curr_state_enc.reshape(-1, self.args.state_size)
curr_state_hist = self.observed_rollout["history"].reshape(-1, self.args.history_size)
with FreezeParameters(self.world_model_modules + self.decoder_modules + self.reward_modules + self.value_modules):
imagine_horizon = self.args.imagine_horizon
action = self.actor_model(curr_state_enc.detach())
self.imagined_rollout = self.transition_model.imagine_rollout(curr_state_enc,
action, curr_state_hist.detach(),
imagine_horizon)
self.pred_nxt_state_dist = self.transition_model.get_dist(self.imagined_rollout["mean"], self.imagined_rollout["std"])
self.pred_nxt_state_enc = self.pred_nxt_state_dist.rsample() #self.transition_model.reparemeterize(self.imagined_rollout["mean"], self.imagined_rollout["std"])
with FreezeParameters(self.world_model_modules + self.value_modules + self.decoder_modules + self.reward_modules):
imag_rewards_dist = self.reward_model(self.pred_nxt_state_enc)
imag_values_dist = self.value_model(self.pred_nxt_state_enc)
imag_rewards = imag_rewards_dist.mean
imag_values = imag_values_dist.mean
#print(torch.cat([imag_rewards[0], imag_values[0]],dim=-1))
discounts = self.args.discount * torch.ones_like(imag_rewards).detach()
self.returns = self._compute_lambda_return(imag_rewards[:-1],
imag_values[:-1],
discounts[:-1] ,
self.args.td_lambda,
imag_values[-1])
discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
self.discounts = torch.cumprod(discounts, 0).detach()
actor_loss = -torch.mean(self.discounts * self.returns)
return actor_loss
def value_model_losses(self):
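# Value objective: regress the value distribution onto the detached lambda-returns,
# weighted by the accumulated discount factors.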
with torch.no_grad():
value_feat = self.pred_nxt_state_enc[:-1].detach()
value_targ = self.returns.detach()
value_dist = self.value_model(value_feat)
value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
return value_loss
def select_one_batch(self):
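# Sample a batch of fixed-length, temporally aligned chunks: pick random episodes and
# random start indices, then slice observations, actions and rewards so that the
# (last, current, next) sequences stay offset by one step.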
# collect sequences
non_zero_indices = np.nonzero(self.data_buffer.episode_count)[0]
current_obs = self.data_buffer.observations[non_zero_indices]
next_obs = self.data_buffer.next_observations[non_zero_indices]
actions_raw = self.data_buffer.actions[non_zero_indices]
rewards = self.data_buffer.rewards[non_zero_indices]
self.terms = np.where(self.data_buffer.terminals[non_zero_indices]!=False)[0]
# group by episodes
current_obs = self.grouped_arrays(current_obs)
next_obs = self.grouped_arrays(next_obs)
actions_raw = self.grouped_arrays(actions_raw)
rewards_ = self.grouped_arrays(rewards)
# select random chunks of episodes
if current_obs.shape[0] < self.args.batch_size:
random_episode_number = np.random.randint(0, current_obs.shape[0], self.args.batch_size)
else:
random_episode_number = random.sample(range(current_obs.shape[0]), self.args.batch_size)
# select random starting points
if current_obs[0].shape[0]-self.args.episode_length < self.args.batch_size:
init_index = np.random.randint(0, current_obs[0].shape[0]-self.args.episode_length-2, self.args.batch_size)
else:
init_index = np.asarray(random.sample(range(current_obs[0].shape[0]-self.args.episode_length), self.args.batch_size))
# shuffle
random.shuffle(random_episode_number)
random.shuffle(init_index)
# select first k elements
last_observations = self.select_first_k(current_obs, init_index, random_episode_number)[:-1]
current_observations = self.select_first_k(current_obs, init_index, random_episode_number)[1:]
next_observations = self.select_first_k(next_obs, init_index, random_episode_number)[:-1]
actions = self.select_first_k(actions_raw, init_index, random_episode_number)[:-1].to(device)
next_actions = self.select_first_k(actions_raw, init_index, random_episode_number)[1:].to(device)
rewards = self.select_first_k(rewards_, init_index, random_episode_number)[:-1].to(device)
# preprocessing
last_observations = preprocess_obs(last_observations).to(device)
current_observations = preprocess_obs(current_observations).to(device)
next_observations = preprocess_obs(next_observations).to(device)
return last_observations, current_observations, next_observations, actions, next_actions, rewards
def evaluate(self, eval_episodes=10):
path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
self.restore_checkpoint(path)
obs = self.env.reset()
done = False
#video = VideoRecorder(self.video_dir, resource_files=self.args.resource_files)
if self.args.save_video:
self.env.video.init(enabled=True)
episodic_rewards = []
for episode in range(eval_episodes):
rewards = 0
done = False
while not done:
with torch.no_grad():
obs = torch.tensor(obs.copy(), dtype=torch.float32).unsqueeze(0)
obs_processed = preprocess_obs(obs).to(device)
state = self.get_features(obs_processed)["distribution"].rsample()
action = self.actor_model(state).cpu().detach().numpy().squeeze()
next_obs, rew, done, _ = self.env.step(action)
rewards += rew
obs = next_obs
if self.args.save_video:
self.env.video.record(self.env)
self.env.video.save('/home/vedant/Curiosity/Curiosity/DPI/log/video/learned_model.mp4')
obs = self.env.reset()
episodic_rewards.append(rewards)
print("Episodic rewards: ", episodic_rewards)
print("Average episodic reward: ", np.mean(episodic_rewards))
def init_weights(self, m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def grouped_arrays(self,array):
indices = [0] + self.terms.tolist()
def subarrays():
for start, end in zip(indices[:-1], indices[1:]):
yield array[start:end]
try:
subarrays = np.stack(list(subarrays()), axis=0)
except ValueError:
subarrays = np.asarray(list(subarrays()))
return subarrays
def select_first_k(self, array, init_index, episode_number):
term_index = init_index + self.args.episode_length
array = array[episode_number]
array_list = []
for i in range(array.shape[0]):
array_list.append(array[i][init_index[i]:term_index[i]])
array = np.asarray(array_list)
if array.ndim == 5:
transposed_array = np.transpose(array, (1, 0, 2, 3, 4))
elif array.ndim == 4:
transposed_array = np.transpose(array, (1, 0, 2, 3))
elif array.ndim == 3:
transposed_array = np.transpose(array, (1, 0, 2))
elif array.ndim == 2:
transposed_array = np.transpose(array, (1, 0))
else:
transposed_array = np.expand_dims(array, axis=0)
#return torch.tensor(array).float()
return torch.tensor(transposed_array).float()
def _upper_bound_minimization(self, current_states, predicted_current_states):
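# CLUB upper bound on mutual information: negatives are built by shuffling the
# positive states along both batch axes before calling the CLUB sampler.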
current_negative_states = shuffle_along_axis(current_states.clone(), axis=0)
current_negative_states = shuffle_along_axis(current_negative_states, axis=1)
club_loss = self.club_sample(current_states, predicted_current_states, current_negative_states)
likelihood_loss = 0
return likelihood_loss, club_loss
def _encoder_loss(self, curr_states_dist, predicted_curr_states_dist):
# KL divergence loss
loss = torch.mean(torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist))
return loss
def get_features(self, x, momentum=False):
if self.args.aug:
crop_transform = T.RandomCrop(size=80)
cropped_x = torch.stack([crop_transform(x[i]) for i in range(x.size(0))])
padding = (2, 2, 2, 2)
x = F.pad(cropped_x, padding)
with torch.no_grad():
if momentum:
@ -528,6 +694,19 @@ class DPI:
return x
def _compute_lambda_return(self, rewards, values, discounts, td_lam, last_value):
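# TD(lambda) return, computed backwards in time:
# R_t = r_t + discount_t * ((1 - lambda) * V(s_{t+1}) + lambda * R_{t+1}),
# bootstrapped from last_value at the end of the horizon.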
next_values = torch.cat([values[1:], last_value[None]], 0)
target = rewards + discounts * next_values * (1 - td_lam)
timesteps = list(range(rewards.shape[0] - 1, -1, -1))
outputs = []
accumulated_reward = last_value
for t in timesteps:
inp = target[t]
discount_factor = discounts[t]
accumulated_reward = inp + discount_factor * td_lam * accumulated_reward
outputs.append(accumulated_reward)
returns = torch.flip(torch.stack(outputs), [0])
return returns
"""
next_values = torch.cat([values[1:], last_value.unsqueeze(0)],0)
targets = rewards + discounts * next_values * (1-td_lam)
rets =[]
@ -539,6 +718,25 @@ class DPI:
returns = torch.flip(torch.stack(rets), [0])
return returns
"""
def lambda_return(self,imged_reward, value_pred, bootstrap, discount=0.99, lambda_=0.95):
# Setting lambda=1 gives a discounted Monte Carlo return.
# Setting lambda=0 gives a fixed 1-step return.
next_values = torch.cat([value_pred[1:], bootstrap[None]], 0)
discount_tensor = discount * torch.ones_like(imged_reward) # pcont
inputs = imged_reward + discount_tensor * next_values * (1 - lambda_)
last = bootstrap
indices = reversed(range(len(inputs)))
outputs = []
for index in indices:
inp, disc = inputs[index], discount_tensor[index]
last = inp + disc * lambda_ * last
outputs.append(last)
outputs = list(reversed(outputs))
outputs = torch.stack(outputs, 0)
returns = outputs
return returns
def save_models(self, save_path):
torch.save(
@ -551,6 +749,17 @@ class DPI:
'value_optimizer': self.value_opt.state_dict(),
'world_model_optimizer': self.world_model_opt.state_dict(),}, save_path)
def restore_checkpoint(self, ckpt_path):
checkpoint = torch.load(ckpt_path)
self.transition_model.load_state_dict(checkpoint['rssm'])
self.actor_model.load_state_dict(checkpoint['actor'])
self.reward_model.load_state_dict(checkpoint['reward_model'])
self.obs_encoder.load_state_dict(checkpoint['obs_encoder'])
self.obs_decoder.load_state_dict(checkpoint['obs_decoder'])
self.world_model_opt.load_state_dict(checkpoint['world_model_optimizer'])
self.actor_opt.load_state_dict(checkpoint['actor_optimizer'])
self.value_opt.load_state_dict(checkpoint['value_optimizer'])
if __name__ == '__main__':
args = parse_args()
@ -560,6 +769,7 @@ if __name__ == '__main__':
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
step = 0
total_steps = 2000000
dpi = DPI(args)
dpi.train(step,total_steps)
dpi.evaluate()

View File

@ -1,10 +1,13 @@
import os
import random
import pickle
import numpy as np
from collections import deque
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import gym
import dmc2gym
@ -60,17 +63,6 @@ def make_dir(dir_path):
return dir_path
def preprocess_obs(obs, bits=5):
"""Preprocessing image, see https://arxiv.org/abs/1807.03039."""
bins = 2**bits
assert obs.dtype == torch.float32
if bits < 8:
obs = torch.floor(obs / 2**(8 - bits))
obs = obs / bins
obs = obs + torch.rand_like(obs) / bins
obs = obs - 0.5
return obs
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
@ -144,8 +136,90 @@ class NormalizeActions:
original = np.where(self._mask, original, action)
return self._env.step(original)
class TimeLimit:
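# Gym wrapper that truncates an episode after `duration` steps and tags the final
# transition with discount=1.0 so truncation is distinguishable from a true terminal.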
def __init__(self, env, duration):
self._env = env
self._duration = duration
self._step = None
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
assert self._step is not None, 'Must reset environment.'
obs, reward, done, info = self._env.step(action)
self._step += 1
if self._step >= self._duration:
done = True
if 'discount' not in info:
info['discount'] = np.array(1.0).astype(np.float32)
self._step = None
return obs, reward, done, info
def reset(self):
self._step = 0
return self._env.reset()
class ReplayBuffer: class ReplayBuffer:
def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
self.size = size
self.obs_shape = obs_shape
self.action_size = action_size
self.seq_len = seq_len
self.batch_size = batch_size
self.idx = 0
self.full = False
self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.actions = np.empty((size, action_size), dtype=np.float32)
self.rewards = np.empty((size,), dtype=np.float32)
self.terminals = np.empty((size,), dtype=np.float32)
self.steps, self.episodes = 0, 0
self.episode_count = np.zeros((size,), dtype=np.int32)
def add(self, obs, ac, next_obs, rew, done, episode_count):
self.observations[self.idx] = obs
self.next_observations[self.idx] = next_obs
self.actions[self.idx] = ac
self.rewards[self.idx] = rew
self.terminals[self.idx] = done
self.full = self.full or self.idx == 0
self.steps += 1
self.episodes = self.episodes + (1 if done else 0)
self.episode_count[self.idx] = episode_count
self.idx = (self.idx + 1) % self.size
def _sample_idx(self, L):
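# Rejection-sample a window of L consecutive indices that does not straddle the
# current write pointer (self.idx), so every sequence is built from contiguous,
# valid transitions.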
valid_idx = False
while not valid_idx:
idx = np.random.randint(0, self.size if self.full else self.idx - L)
idxs = np.arange(idx, idx + L) % self.size
valid_idx = not self.idx in idxs[1:]
return idxs
def _retrieve_batch(self, idxs, n, L):
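# Gather the flat indices and reshape each tensor to (L, n, ...), i.e. time-major
# sequences of length L for n batch elements.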
vec_idxs = idxs.transpose().reshape(-1) # Unroll indices
observations = self.observations[vec_idxs]
next_obs = self.next_observations[vec_idxs]
obs = observations.reshape(L, n, *observations.shape[1:])
next_obs = next_obs.reshape(L, n, *next_obs.shape[1:])
acs = self.actions[vec_idxs].reshape(L, n, -1)
rew = self.rewards[vec_idxs].reshape(L, n)
term = self.terminals[vec_idxs].reshape(L, n)
return obs, acs, next_obs, rew, term
def sample(self):
n = self.batch_size
l = self.seq_len
obs,acs,next_obs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
return obs,acs,next_obs,rews,terms
class ReplayBuffer1:
def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
self.size = size
self.obs_shape = obs_shape
@ -199,7 +273,11 @@ class ReplayBuffer:
def group_steps(self, buffer, variable, obs=True):
variable = getattr(buffer, variable)
non_zero_indices = np.nonzero(buffer.episode_count)[0]
variable = variable[non_zero_indices]
if obs:
variable = variable.reshape(-1, self.args.episode_length,
self.args.frame_stack*self.args.channels,
@ -214,8 +292,9 @@ class ReplayBuffer:
self.args.image_size,self.args.image_size)
return variable
def sample_random_idx(self, buffer_length, last=False):
init = 0 if last else buffer_length - self.args.batch_size
random_indices = random.sample(range(init, buffer_length), self.args.batch_size)
return random_indices
def group_and_sample_random_batch(self, buffer, variable_name, device, random_indices, is_obs=True, offset=0):
@ -247,19 +326,23 @@ def make_env(args):
)
return env
def shuffle_along_axis(a, axis):
idx = np.random.rand(*a.shape).argsort(axis=axis)
return np.take_along_axis(a,idx,axis=axis)
def preprocess_obs(obs):
obs = (obs/255.0) - 0.5
return obs
def soft_update_params(net, target_net, tau):
for param, target_param in zip(net.parameters(), target_net.parameters()):
target_param.data.copy_(
tau * param.detach().data + (1 - tau) * target_param.data
)
def save_image(array, filename):
array = array.transpose(1, 2, 0)
array = ((array+0.5) * 255).astype(np.uint8)
image = Image.fromarray(array)
image.save(filename)
@ -280,6 +363,20 @@ def video_from_array(arr, high_noise, filename):
out.write(frame)
out.release()
def save_video(images):
"""
Image shape is (T, C, H, W)
Example:(50, 3, 84, 84)
"""
output_file = "output.avi"
fourcc = cv2.VideoWriter_fourcc(*'XVID')
fps = 2
height, width, channels = 84,84,3
out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
for image in images:
image = np.uint8(image.transpose((1, 2, 0)))
out.write(image)
out.release()
class CorruptVideos:
def __init__(self, dir_path):
@ -353,3 +450,51 @@ class FreezeParameters:
def __exit__(self, exc_type, exc_val, exc_tb):
for i, param in enumerate(get_parameters(self.modules)):
param.requires_grad = self.param_states[i]
class Logger:
def __init__(self, log_dir, n_logged_samples=10, summary_writer=None):
self._log_dir = log_dir
print('########################')
print('logging outputs to ', log_dir)
print('########################')
self._n_logged_samples = n_logged_samples
self._summ_writer = SummaryWriter(log_dir, flush_secs=1, max_queue=1)
def log_scalar(self, scalar, name, step_):
self._summ_writer.add_scalar('{}'.format(name), scalar, step_)
def log_scalars(self, scalar_dict, step):
for key, value in scalar_dict.items():
print('{} : {}'.format(key, value))
self.log_scalar(value, key, step)
self.dump_scalars_to_pickle(scalar_dict, step)
def log_videos(self, videos, step, max_videos_to_save=1, fps=20, video_title='video'):
# cap the number of videos at what was actually collected
max_videos_to_save = np.min([max_videos_to_save, videos.shape[0]])
# find the longest rollout
max_length = videos[0].shape[0]
for i in range(max_videos_to_save):
if videos[i].shape[0]>max_length:
max_length = videos[i].shape[0]
# pad rollouts to all be same length
for i in range(max_videos_to_save):
if videos[i].shape[0]<max_length:
padding = np.tile([videos[i][-1]], (max_length-videos[i].shape[0],1,1,1))
videos[i] = np.concatenate([videos[i], padding], 0)
clip = mpy.ImageSequenceClip(list(videos[i]), fps=fps)
new_video_title = video_title+'{}_{}'.format(step, i) + '.gif'
filename = os.path.join(self._log_dir, new_video_title)
clip.write_gif(filename, fps=fps)
def dump_scalars_to_pickle(self, metrics, step, log_title=None):
log_path = os.path.join(self._log_dir, "scalar_data.pkl" if log_title is None else log_title)
with open(log_path, 'ab') as f:
pickle.dump({'step': step, **dict(metrics)}, f)
def flush(self):
self._summ_writer.flush()