Compare commits


No commits in common. "idea_1" and "main" have entirely different histories.
idea_1 ... main

4 changed files with 325 additions and 983 deletions

View File

@@ -23,24 +23,23 @@ class ObservationEncoder(nn.Module):
self.convs = nn.Sequential(*layers) self.convs = nn.Sequential(*layers)
self.fc = nn.Linear(256 * obs_shape[0], 2 * state_size) # 9 if 3 frames stacked self.fc = nn.Linear(256 * 3 * 3, 2 * state_size)
def forward(self, x): def forward(self, x):
x_reshaped = x.reshape(-1, *x.shape[-3:]) x = self.convs(x)
x_embed = self.convs(x_reshaped) x = x.view(x.size(0), -1)
x_embed = torch.reshape(x_embed, (*x.shape[:-3], -1)) x = self.fc(x)
x = self.fc(x_embed)
# Mean and standard deviation # Mean and standard deviation
mean, std = torch.chunk(x, 2, dim=-1) mean, std = torch.chunk(x, 2, dim=-1)
mean = nn.ELU()(mean)
std = F.softplus(std) std = F.softplus(std)
std = torch.clamp(std, min=0.0, max=1e1) std = torch.clamp(std, min=0.0, max=1e5)
# Normal Distribution # Normal Distribution
dist = self.get_dist(mean, std) dist = self.get_dist(mean, std)
# Sampling via reparameterization Trick # Sampling via reparameterization Trick
#x = dist.rsample()
x = self.reparameterize(mean, std) x = self.reparameterize(mean, std)
encoded_output = {"sample": x, "distribution": dist} encoded_output = {"sample": x, "distribution": dist}
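The encoder above samples its latent state with the reparameterization trick rather than sampling the Normal directly, so gradients can flow through the sampling step. A minimal self-contained sketch of that trick (toy shapes, not the repo's):

import torch

def reparameterize(mean, std):
    # Sample from N(mean, std) as mean + std * eps with eps ~ N(0, 1),
    # which keeps the sample differentiable w.r.t. mean and std.
    eps = torch.randn_like(std)
    return mean + eps * std

mean = torch.zeros(4, 8, requires_grad=True)
std = torch.ones(4, 8, requires_grad=True)
sample = reparameterize(mean, std)
sample.sum().backward()   # gradients reach mean and std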
@@ -64,7 +63,7 @@ class ObservationDecoder(nn.Module):
self.output_shape = output_shape self.output_shape = output_shape
self.input_size = 256 * 3 * 3 self.input_size = 256 * 3 * 3
self.in_channels = [self.input_size, 256, 128, 64] self.in_channels = [self.input_size, 256, 128, 64]
self.out_channels = [256, 128, 64, 9] self.out_channels = [256, 128, 64, 3]
if output_shape[1] == 84: if output_shape[1] == 84:
self.kernels = [5, 7, 5, 6] self.kernels = [5, 7, 5, 6]
@@ -95,50 +94,43 @@ class ObservationDecoder(nn.Module):
class Actor(nn.Module): class Actor(nn.Module):
def __init__(self, state_size, hidden_size, action_size, num_layers=4, min_std=1e-4, init_std=5, mean_scale=5): def __init__(self, state_size, hidden_size, action_size, num_layers=5):
super().__init__() super().__init__()
self.state_size = state_size self.state_size = state_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.action_size = action_size self.action_size = action_size
self.num_layers = num_layers self.num_layers = num_layers
self._min_std = min_std self._min_std=torch.Tensor([1e-4])[0]
self._init_std = init_std self._init_std=torch.Tensor([5])[0]
self._mean_scale = mean_scale self._mean_scale=torch.Tensor([5])[0]
layers = [] layers = []
for i in range(self.num_layers): for i in range(self.num_layers):
input_channels = state_size if i == 0 else self.hidden_size input_channels = state_size if i == 0 else self.hidden_size
layers.append(nn.Linear(input_channels, self.hidden_size)) output_channels = self.hidden_size if i!= self.num_layers-1 else 2*action_size
layers.append(nn.ReLU()) layers.append(nn.Linear(input_channels, output_channels))
layers.append(nn.Linear(self.hidden_size, 2*self.action_size)) layers.append(nn.LeakyReLU())
self.action_model = nn.Sequential(*layers) self.action_model = nn.Sequential(*layers)
def get_dist(self, mean, std): def get_dist(self, mean, std):
distribution = torch.distributions.Normal(mean, std) distribution = torch.distributions.Normal(mean, std)
distribution = torch.distributions.transformed_distribution.TransformedDistribution(distribution, TanhBijector()) distribution = torch.distributions.transformed_distribution.TransformedDistribution(distribution, TanhBijector())
distribution = torch.distributions.independent.Independent(distribution, 1) distribution = torch.distributions.independent.Independent(distribution, 1)
distribution = SampleDist(distribution)
return distribution return distribution
def add_exploration(self, action, action_noise=0.3):
return torch.clamp(torch.distributions.Normal(action, action_noise).rsample(), -1, 1)
def forward(self, features): def forward(self, features):
out = self.action_model(features) out = self.action_model(features)
mean, std = torch.chunk(out, 2, dim=-1) mean, std = torch.chunk(out, 2, dim=-1)
raw_init_std = np.log(np.exp(self._init_std) - 1) raw_init_std = torch.log(torch.exp(self._init_std) - 1)
action_mean = self._mean_scale * torch.tanh(mean / self._mean_scale) action_mean = self._mean_scale * torch.tanh(mean / self._mean_scale)
action_std = F.softplus(std + raw_init_std) + self._min_std action_std = F.softplus(std + raw_init_std) + self._min_std
dist = self.get_dist(action_mean, action_std) dist = self.get_dist(action_mean, action_std)
sample = dist.rsample() #self.reparameterize(action_mean, action_std) sample = dist.rsample()
return sample return sample
def reparameterize(self, mu, std):
eps = torch.randn_like(std)
return mu + eps * std
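The Actor builds its policy by passing a Normal through a tanh bijector and wrapping it in Independent (plus, in one branch, a SampleDist helper). Below is a hedged sketch of the same idea using only torch's built-in TanhTransform; the repo's custom TanhBijector and SampleDist may differ in details such as numerical clamping:

import torch
import torch.distributions as td

def tanh_gaussian(mean, std):
    # Squash a diagonal Normal through tanh so samples land in (-1, 1),
    # then treat the last dimension as a single event with Independent.
    base = td.Normal(mean, std)
    squashed = td.TransformedDistribution(base, [td.TanhTransform(cache_size=1)])
    return td.Independent(squashed, 1)

dist = tanh_gaussian(torch.zeros(2, 6), torch.ones(2, 6))
action = dist.rsample()            # differentiable, bounded action sample
log_prob = dist.log_prob(action)   # one log-probability per batch element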
class ValueModel(nn.Module): class ValueModel(nn.Module):
def __init__(self, state_size, hidden_size, num_layers=4): def __init__(self, state_size, hidden_size, num_layers=4):
@@ -148,12 +140,11 @@ class ValueModel(nn.Module):
self.num_layers = num_layers self.num_layers = num_layers
layers = [] layers = []
for i in range(self.num_layers-1): for i in range(self.num_layers):
input_channels = state_size if i == 0 else self.hidden_size input_channels = state_size if i == 0 else self.hidden_size
output_channels = self.hidden_size output_channels = self.hidden_size if i!= self.num_layers-1 else 1
layers.append(nn.Linear(input_channels, output_channels)) layers.append(nn.Linear(input_channels, output_channels))
layers.append(nn.LeakyReLU()) layers.append(nn.LeakyReLU())
layers.append(nn.Linear(self.hidden_size, int(np.prod(1))))
self.value_model = nn.Sequential(*layers) self.value_model = nn.Sequential(*layers)
def forward(self, state): def forward(self, state):
@@ -178,7 +169,6 @@ class RewardModel(nn.Module):
return torch.distributions.independent.Independent( return torch.distributions.independent.Independent(
torch.distributions.Normal(reward, 1), 1) torch.distributions.Normal(reward, 1), 1)
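The reward head returns Independent(Normal(reward, 1), 1), so the negative log-likelihood used later in the training losses reduces to a scaled squared error plus a constant. A small self-contained check of that equivalence (random tensors, purely illustrative):

import math
import torch
import torch.distributions as td

pred_reward = torch.randn(16, 1)
true_reward = torch.randn(16, 1)

dist = td.Independent(td.Normal(pred_reward, 1.0), 1)
nll = -dist.log_prob(true_reward).mean()

# With unit variance, -log N(x | mu, 1) = 0.5 * (x - mu)^2 + 0.5 * log(2 * pi).
mse_term = 0.5 * ((pred_reward - true_reward) ** 2).sum(dim=-1).mean()
assert torch.isclose(nll, mse_term + 0.5 * math.log(2 * math.pi))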
"""
class TransitionModel(nn.Module): class TransitionModel(nn.Module):
def __init__(self, state_size, hidden_size, action_size, history_size): def __init__(self, state_size, hidden_size, action_size, history_size):
super().__init__() super().__init__()
@@ -190,7 +180,6 @@ class TransitionModel(nn.Module):
self.act_fn = nn.LeakyReLU() self.act_fn = nn.LeakyReLU()
self.fc_state_action = nn.Linear(state_size + action_size, hidden_size) self.fc_state_action = nn.Linear(state_size + action_size, hidden_size)
self.ln = nn.LayerNorm(hidden_size)
self.history_cell = nn.GRUCell(hidden_size + history_size, history_size) self.history_cell = nn.GRUCell(hidden_size + history_size, history_size)
self.fc_state_prior = nn.Linear(history_size + state_size + action_size, 2 * state_size) self.fc_state_prior = nn.Linear(history_size + state_size + action_size, 2 * state_size)
self.fc_state_posterior = nn.Linear(history_size + state_size + action_size, 2 * state_size) self.fc_state_posterior = nn.Linear(history_size + state_size + action_size, 2 * state_size)
@@ -205,25 +194,12 @@ class TransitionModel(nn.Module):
distribution = torch.distributions.independent.Independent(distribution, 1) distribution = torch.distributions.independent.Independent(distribution, 1)
return distribution return distribution
def stack_states(self, states, dim=0): def imagine_step(self, prev_state, prev_action, prev_history):
s = dict( state_action = self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1)))
mean = torch.stack([state['mean'] for state in states], dim=dim), prev_hist = prev_history.detach()
std = torch.stack([state['std'] for state in states], dim=dim), history = self.history_cell(torch.cat([state_action, prev_hist], dim=-1), prev_hist)
sample = torch.stack([state['sample'] for state in states], dim=dim),
history = torch.stack([state['history'] for state in states], dim=dim),)
if 'distribution' in states:
dist = dict(distribution = [state['distribution'] for state in states])
s.update(dist)
return s
def seq_to_batch(self, state, name): state_prior = self.fc_state_prior(torch.cat([history, prev_state, prev_action], dim=-1))
return dict(
sample = torch.reshape(state[name], (state[name].shape[0]* state[name].shape[1], *state[name].shape[2:])))
def imagine_step(self, state, action, history):
state_action = self.ln(self.act_fn(self.fc_state_action(torch.cat([state, action], dim=-1))))
imag_hist = self.history_cell(torch.cat([state_action, history], dim=-1), history)
state_prior = self.fc_state_prior(torch.cat([imag_hist, state, action], dim=-1))
state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1) state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)
state_prior_std = F.softplus(state_prior_std) state_prior_std = F.softplus(state_prior_std)
@@ -232,103 +208,19 @@ class TransitionModel(nn.Module):
# Sampling via reparameterization Trick # Sampling via reparameterization Trick
sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std) sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std)
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": imag_hist, "distribution": state_prior_dist} prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
return prior return prior
def imagine_rollout(self, state, action, history, horizon):
imagined_priors = []
for i in range(horizon):
prior = self.imagine_step(state, action, history)
state = prior["sample"]
history = prior["history"]
imagined_priors.append(prior)
imagined_priors = self.stack_states(imagined_priors, dim=0)
return imagined_priors
def observe_step(self, prev_state, prev_action, prev_history, nonterms):
state_action = self.ln(self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))))
current_history = self.history_cell(torch.cat([state_action, prev_history], dim=-1), prev_history)
state_prior = self.fc_state_prior(torch.cat([prev_history, prev_state, prev_action], dim=-1))
state_prior_mean, state_prior_std = torch.chunk(state_prior*nonterms, 2, dim=-1)
state_prior_std = F.softplus(state_prior_std) + 0.1
sample_state_prior = state_prior_mean + torch.randn_like(state_prior_mean) * state_prior_std
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": current_history}
return prior
def observe_rollout(self, rollout_states, rollout_actions, init_history, nonterms):
observed_rollout = []
for i in range(rollout_states.shape[0]):
actions = rollout_actions[i] * nonterms[i]
prior = self.observe_step(rollout_states[i], actions, init_history, nonterms[i])
init_history = prior["history"]
observed_rollout.append(prior)
observed_rollout = self.stack_states(observed_rollout, dim=0)
return observed_rollout
def reparemeterize(self, mean, std):
eps = torch.randn_like(std)
return mean + eps * std
"""
class TransitionModel(nn.Module):
def __init__(self, state_size, hidden_size, action_size, history_size):
super().__init__()
self.state_size = state_size
self.hidden_size = hidden_size
self.action_size = action_size
self.history_size = history_size
self.act_fn = nn.ELU()
self.fc_state_action = nn.Linear(state_size + action_size, hidden_size)
self.history_cell = nn.GRUCell(hidden_size, history_size)
self.fc_state_mu = nn.Linear(history_size + hidden_size, state_size)
self.fc_state_sigma = nn.Linear(history_size + hidden_size, state_size)
self.batch_norm = nn.BatchNorm1d(hidden_size)
self.batch_norm2 = nn.BatchNorm1d(state_size)
self.min_sigma = 1e-4
self.max_sigma = 1e0
def init_states(self, batch_size, device):
self.prev_state = torch.zeros(batch_size, self.state_size).to(device)
self.prev_action = torch.zeros(batch_size, self.action_size).to(device)
self.prev_history = torch.zeros(batch_size, self.history_size).to(device)
def get_dist(self, mean, std):
distribution = torch.distributions.Normal(mean, std)
distribution = torch.distributions.independent.Independent(distribution, 1)
return distribution
def stack_states(self, states, dim=0): def stack_states(self, states, dim=0):
s = dict( s = dict(
mean = torch.stack([state['mean'] for state in states], dim=dim), mean = torch.stack([state['mean'] for state in states], dim=dim),
std = torch.stack([state['std'] for state in states], dim=dim), std = torch.stack([state['std'] for state in states], dim=dim),
sample = torch.stack([state['sample'] for state in states], dim=dim), sample = torch.stack([state['sample'] for state in states], dim=dim),
history = torch.stack([state['history'] for state in states], dim=dim),) history = torch.stack([state['history'] for state in states], dim=dim),)
if 'distribution' in states:
dist = dict(distribution = [state['distribution'] for state in states]) dist = dict(distribution = [state['distribution'] for state in states])
s.update(dist) s.update(dist)
return s return s
def seq_to_batch(self, state, name):
return dict(
sample = torch.reshape(state[name], (state[name].shape[0]* state[name].shape[1], *state[name].shape[2:])))
def imagine_step(self, state, action, history):
next_state_action_enc = self.act_fn(self.batch_norm(self.fc_state_action(torch.cat([state, action], dim=-1))))
imag_history = self.history_cell(next_state_action_enc, history)
next_state_mu = self.act_fn(self.batch_norm2(self.fc_state_mu(torch.cat([next_state_action_enc, imag_history], dim=-1))))
next_state_sigma = torch.sigmoid(self.fc_state_sigma(torch.cat([next_state_action_enc, imag_history], dim=-1)))
next_state_sigma = self.min_sigma + (self.max_sigma - self.min_sigma) * next_state_sigma
# Normal Distribution
next_state_dist = self.get_dist(next_state_mu, next_state_sigma)
next_state_sample = self.reparemeterize(next_state_mu, next_state_sigma)
prior = {"mean": next_state_mu, "std": next_state_sigma, "sample": next_state_sample, "history": imag_history, "distribution": next_state_dist}
return prior
def imagine_rollout(self, state, action, history, horizon): def imagine_rollout(self, state, action, history, horizon):
imagined_priors = [] imagined_priors = []
for i in range(horizon): for i in range(horizon):
@@ -339,30 +231,8 @@ class TransitionModel(nn.Module):
imagined_priors = self.stack_states(imagined_priors, dim=0) imagined_priors = self.stack_states(imagined_priors, dim=0)
return imagined_priors return imagined_priors
def observe_step(self, prev_state, prev_action, prev_history):
state_action_enc = self.act_fn(self.batch_norm(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))))
current_history = self.history_cell(state_action_enc, prev_history)
state_mu = self.act_fn(self.batch_norm2(self.fc_state_mu(torch.cat([state_action_enc, prev_history], dim=-1))))
state_sigma = F.softplus(self.fc_state_sigma(torch.cat([state_action_enc, prev_history], dim=-1)))
sample_state = state_mu + torch.randn_like(state_mu) * state_sigma
state_enc = {"mean": state_mu, "std": state_sigma, "sample": sample_state, "history": current_history}
return state_enc
def observe_rollout(self, rollout_states, rollout_actions, init_history, nonterms):
observed_rollout = []
for i in range(rollout_states.shape[0]):
rollout_states_ = rollout_states[i]
rollout_actions_ = rollout_actions[i]
init_history_ = nonterms[i] * init_history
state_enc = self.observe_step(rollout_states_, rollout_actions_, init_history_)
init_history = state_enc["history"]
observed_rollout.append(state_enc)
observed_rollout = self.stack_states(observed_rollout, dim=0)
return observed_rollout
def reparemeterize(self, mean, std): def reparemeterize(self, mean, std):
eps = torch.randn_like(mean) eps = torch.randn_like(std)
return mean + eps * std return mean + eps * std
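Both TransitionModel variants follow the same recurrent pattern: embed (state, action), update a GRU-based history, then read off a Gaussian over the next latent state. A compact, self-contained sketch of that loop with illustrative sizes (taken loosely from the 50/512/1/256 instantiation at the bottom of this file):

import torch
import torch.nn as nn
import torch.nn.functional as F

state_size, action_size, hidden_size, history_size = 50, 1, 512, 256
fc_state_action = nn.Linear(state_size + action_size, hidden_size)
history_cell = nn.GRUCell(hidden_size + history_size, history_size)
fc_state_prior = nn.Linear(history_size + state_size + action_size, 2 * state_size)

def imagine_step(state, action, history):
    # One latent step: encode (s, a), advance the deterministic history,
    # then predict mean/std of the next stochastic state and sample it.
    sa = F.leaky_relu(fc_state_action(torch.cat([state, action], dim=-1)))
    history = history_cell(torch.cat([sa, history], dim=-1), history)
    prior = fc_state_prior(torch.cat([history, state, action], dim=-1))
    mean, std = torch.chunk(prior, 2, dim=-1)
    std = F.softplus(std) + 0.1
    sample = mean + torch.randn_like(std) * std
    return sample, history

state = torch.zeros(8, state_size)
action = torch.zeros(8, action_size)
history = torch.zeros(8, history_size)
for _ in range(5):  # a 5-step imagined rollout
    state, history = imagine_step(state, action, history)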
@@ -430,7 +300,6 @@ class ContrastiveHead(nn.Module):
return logits return logits
"""
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def __init__(self, last_states, current_states, negative_current_states, predicted_current_states): def __init__(self, last_states, current_states, negative_current_states, predicted_current_states):
super(CLUBSample, self).__init__() super(CLUBSample, self).__init__()
@@ -444,7 +313,7 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
sample = state_dict["sample"] #dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses sample = state_dict["sample"] #dist.sample() # Use state_dict["sample"] if you want to use the same sample for all the losses
mu = dist.mean mu = dist.mean
var = dist.variance var = dist.variance
return mu.detach(), var.detach(), sample.detach() return mu, var, sample
def loglikeli(self): def loglikeli(self):
_, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states) _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
@@ -461,136 +330,15 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
random_index = torch.randperm(sample_size).long() random_index = torch.randperm(sample_size).long()
pos = (-(mu_curr - pred_sample)**2 /var_curr).sum(dim=1).mean(dim=0) pos = (-(mu_curr - pred_sample)**2 /var_curr).sum(dim=1).mean(dim=0)
#neg = (-(mu_curr - pred_sample[random_index])**2 /var_curr).sum(dim=1).mean(dim=0) neg = (-(mu_curr - pred_sample[random_index])**2 /var_curr).sum(dim=1).mean(dim=0)
neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0) #neg = (-(mu_neg - pred_sample)**2 /var_neg).sum(dim=1).mean(dim=0)
upper_bound = pos - neg upper_bound = pos - neg
return upper_bound/2 return upper_bound/2
def learning_loss(self): def learning_loss(self):
return - self.loglikeli() return - self.loglikeli()
"""
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def __init__(self, x_dim, y_dim, hidden_size):
super(CLUBSample, self).__init__()
self.p_mu = nn.Sequential(nn.Linear(x_dim, hidden_size//2),
nn.ReLU(),
nn.Linear(hidden_size//2, y_dim))
self.p_logvar = nn.Sequential(nn.Linear(x_dim, hidden_size//2),
nn.ReLU(),
nn.Linear(hidden_size//2, y_dim),
nn.Tanh())
def get_mu_logvar(self, x_samples):
mu = self.p_mu(x_samples)
logvar = self.p_logvar(x_samples)
return mu, logvar
def loglikeli(self, x_samples, y_samples):
mu, logvar = self.get_mu_logvar(x_samples)
return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0)
def forward(self, x_samples, y_samples, y_negatives):
mu, logvar = self.get_mu_logvar(x_samples)
sample_size = x_samples.shape[0]
#random_index = torch.randint(sample_size, (sample_size,)).long()
random_index = torch.randperm(sample_size).long()
positive = -(mu - y_samples)**2 / logvar.exp()
#negative = - (mu - y_samples[random_index])**2 / logvar.exp()
negative = -(mu - y_negatives)**2 / logvar.exp()
upper_bound = (positive.sum(dim = -1) - negative.sum(dim = -1)).mean()
return upper_bound/2.
def learning_loss(self, x_samples, y_samples):
return -self.loglikeli(x_samples, y_samples)
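For reference, a hedged usage sketch of the second CLUBSample variant above (the one taking x_dim, y_dim, hidden_size). Here x_samples would be predicted current states, y_samples the matching encoded states, and y_negatives a deliberately mismatched batch; the dimensions are illustrative, not the repo's defaults:

import torch

club = CLUBSample(x_dim=50, y_dim=50, hidden_size=256)

x_samples = torch.randn(32, 50)
y_samples = torch.randn(32, 50)
y_negatives = y_samples[torch.randperm(32)]

mi_upper_bound = club(x_samples, y_samples, y_negatives)   # added to the world-model loss
estimator_loss = club.learning_loss(x_samples, y_samples)  # trains the variational q(y|x)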
class QFunction(nn.Module):
"""MLP for q-function."""
def __init__(self, obs_dim, action_dim, hidden_dim):
super().__init__()
self.trunk = nn.Sequential(
nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, obs, action):
assert obs.size(0) == action.size(0)
obs_action = torch.cat([obs, action], dim=1)
return self.trunk(obs_action)
class Critic(nn.Module):
"""Critic network, employes two q-functions."""
def __init__(
self, obs_shape, action_shape, hidden_dim, encoder_feature_dim):
super().__init__()
self.Q1 = QFunction(
self.encoder.feature_dim, action_shape[0], hidden_dim
)
self.Q2 = QFunction(
self.encoder.feature_dim, action_shape[0], hidden_dim
)
self.outputs = dict()
def forward(self, obs, action, detach_encoder=False):
# detach_encoder allows to stop gradient propogation to encoder
obs = self.encoder(obs, detach=detach_encoder)
q1 = self.Q1(obs, action)
q2 = self.Q2(obs, action)
self.outputs['q1'] = q1
self.outputs['q2'] = q2
return q1, q2
class SampleDist:
def __init__(self, dist, samples=100):
self._dist = dist
self._samples = samples
@property
def name(self):
return 'SampleDist'
def __getattr__(self, name):
return getattr(self._dist, name)
def mean(self):
sample = self._dist.rsample(self._samples)
return torch.mean(sample, 0)
def mode(self):
dist = self._dist.expand((self._samples, *self._dist.batch_shape))
sample = dist.rsample()
logprob = dist.log_prob(sample)
batch_size = sample.size(1)
feature_size = sample.size(2)
indices = torch.argmax(logprob, dim=0).reshape(1, batch_size, 1).expand(1, batch_size, feature_size)
return torch.gather(sample, 0, indices).squeeze(0)
def entropy(self):
dist = self._dist.expand((self._samples, *self._dist.batch_shape))
sample = dist.rsample()
logprob = dist.log_prob(sample)
return -torch.mean(logprob, 0)
def sample(self):
return self._dist.sample()
if "__name__ == __main__": if "__name__ == __main__":
tr = TransitionModel(50, 512, 1, 256) pass

View File

@@ -1,51 +0,0 @@
import torch
import numpy as np
class ReplayBuffer:
def __init__(self, size, obs_shape, action_size, seq_len, batch_size):
self.size = size
self.obs_shape = obs_shape
self.action_size = action_size
self.seq_len = seq_len
self.batch_size = batch_size
self.idx = 0
self.full = False
self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.actions = np.empty((size, action_size), dtype=np.float32)
self.rewards = np.empty((size,), dtype=np.float32)
self.terminals = np.empty((size,), dtype=np.float32)
self.steps, self.episodes = 0, 0
def add(self, obs, ac, next_obs, rew, done):
self.observations[self.idx] = obs
self.next_observations[self.idx] = next_obs
self.actions[self.idx] = ac
self.rewards[self.idx] = rew
self.terminals[self.idx] = done
self.idx = (self.idx + 1) % self.size
self.full = self.full or self.idx == 0
self.steps += 1
self.episodes = self.episodes + (1 if done else 0)
def _sample_idx(self, L):
valid_idx = False
while not valid_idx:
idx = np.random.randint(0, self.size if self.full else self.idx - L)
idxs = np.arange(idx, idx + L) % self.size
valid_idx = not self.idx in idxs[1:]
return idxs
def _retrieve_batch(self, idxs, n, L):
vec_idxs = idxs.transpose().reshape(-1) # Unroll indices
observations = self.observations[vec_idxs]
next_observations = self.next_observations[vec_idxs]
return observations.reshape(L, n, *observations.shape[1:]),self.actions[vec_idxs].reshape(L, n, -1), next_observations.reshape(L, n, *next_observations.shape[1:]), self.rewards[vec_idxs].reshape(L, n), self.terminals[vec_idxs].reshape(L, n)
def sample(self):
n = self.batch_size
l = self.seq_len
obs,acs,nxt_obs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
return obs,acs,nxt_obs,rews,terms
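A small usage sketch of the ReplayBuffer removed above (illustrative shapes: three stacked 84x84 RGB frames and a 1-D action). sample() returns sequence-major tensors of shape (seq_len, batch_size, ...):

import numpy as np

buffer = ReplayBuffer(size=1000, obs_shape=(9, 84, 84), action_size=1,
                      seq_len=50, batch_size=16)

obs = np.zeros((9, 84, 84), dtype=np.uint8)
action = np.zeros(1, dtype=np.float32)
for t in range(300):
    done = (t % 100 == 99)
    buffer.add(obs, action, obs, 0.0, done)

observations, actions, next_observations, rewards, terminals = buffer.sample()
# observations.shape == (50, 16, 9, 84, 84)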

View File

@@ -6,12 +6,11 @@ import wandb
import random import random
import argparse import argparse
import numpy as np import numpy as np
from collections import OrderedDict
import utils import utils
from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image, shuffle_along_axis, Logger from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image
from replay_buffer import ReplayBuffer
from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
from logger import Logger
from video import VideoRecorder from video import VideoRecorder
from dmc2gym.wrappers import set_global_var from dmc2gym.wrappers import set_global_var
@@ -41,22 +40,19 @@ def parse_args():
parser.add_argument('--resource_files', type=str) parser.add_argument('--resource_files', type=str)
parser.add_argument('--eval_resource_files', type=str) parser.add_argument('--eval_resource_files', type=str)
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none']) parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
parser.add_argument('--total_frames', default=5000, type=int) # 10000 parser.add_argument('--total_frames', default=1000, type=int) # 10000
parser.add_argument('--high_noise', action='store_true') parser.add_argument('--high_noise', action='store_true')
# replay buffer # replay buffer
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #50000 parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #50000
parser.add_argument('--episode_length', default=51, type=int) parser.add_argument('--episode_length', default=51, type=int)
# train # train
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc']) parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
parser.add_argument('--init_steps', default=5000, type=int) parser.add_argument('--init_steps', default=10000, type=int)
parser.add_argument('--num_train_steps', default=100000, type=int) parser.add_argument('--num_train_steps', default=10000, type=int)
parser.add_argument('--update_steps', default=10, type=int) parser.add_argument('--batch_size', default=30, type=int) #512
parser.add_argument('--batch_size', default=64, type=int) parser.add_argument('--state_size', default=256, type=int)
parser.add_argument('--state_size', default=100, type=int) parser.add_argument('--hidden_size', default=128, type=int)
parser.add_argument('--hidden_size', default=512, type=int) parser.add_argument('--history_size', default=128, type=int)
parser.add_argument('--history_size', default=256, type=int)
parser.add_argument('--episode_collection', default=5, type=int)
parser.add_argument('--episodes_buffer', default=5, type=int, help='Initial number of episodes to store in the buffer')
parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models') parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models')
parser.add_argument('--load_encoder', default=None, type=str) parser.add_argument('--load_encoder', default=None, type=str)
parser.add_argument('--imagine_horizon', default=15, type=str) parser.add_argument('--imagine_horizon', default=15, type=str)
@@ -64,33 +60,42 @@ def parse_args():
# eval # eval
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000 parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
parser.add_argument('--num_eval_episodes', default=20, type=int) parser.add_argument('--num_eval_episodes', default=20, type=int)
parser.add_argument('--evaluation_interval', default=10000, type=int) # TODO: master had 10000
# value # value
parser.add_argument('--value_lr', default=8e-6, type=float) parser.add_argument('--value_lr', default=8e-5, type=float)
parser.add_argument('--value_beta', default=0.9, type=float)
parser.add_argument('--value_tau', default=0.005, type=float)
parser.add_argument('--value_target_update_freq', default=100, type=int) parser.add_argument('--value_target_update_freq', default=100, type=int)
parser.add_argument('--td_lambda', default=0.95, type=int) parser.add_argument('--td_lambda', default=0.95, type=int)
# actor # actor
parser.add_argument('--actor_lr', default=8e-6, type=float) parser.add_argument('--actor_lr', default=8e-5, type=float)
parser.add_argument('--actor_beta', default=0.9, type=float) parser.add_argument('--actor_beta', default=0.9, type=float)
parser.add_argument('--actor_log_std_min', default=-10, type=float) parser.add_argument('--actor_log_std_min', default=-10, type=float)
parser.add_argument('--actor_log_std_max', default=2, type=float) parser.add_argument('--actor_log_std_max', default=2, type=float)
parser.add_argument('--actor_update_freq', default=2, type=int) parser.add_argument('--actor_update_freq', default=2, type=int)
# world/encoder/decoder # world/encoder/decoder
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity']) parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
parser.add_argument('--world_model_lr', default=1e-6, type=float) parser.add_argument('--encoder_feature_dim', default=50, type=int)
parser.add_argument('--decoder_lr', default=6e-6, type=float) parser.add_argument('--world_model_lr', default=6e-4, type=float)
parser.add_argument('--reward_lr', default=8e-6, type=float) parser.add_argument('--past_transition_lr', default=1e-3, type=float)
parser.add_argument('--encoder_tau', default=0.005, type=float) parser.add_argument('--encoder_lr', default=1e-3, type=float)
parser.add_argument('--encoder_tau', default=0.001, type=float)
parser.add_argument('--encoder_stride', default=1, type=int)
parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction']) parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
parser.add_argument('--decoder_lr', default=1e-3, type=float)
parser.add_argument('--decoder_update_freq', default=1, type=int)
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_layers', default=4, type=int)
parser.add_argument('--num_filters', default=32, type=int) parser.add_argument('--num_filters', default=32, type=int)
parser.add_argument('--aug', action='store_true') parser.add_argument('--aug', action='store_true')
# sac # sac
parser.add_argument('--discount', default=0.99, type=float) parser.add_argument('--discount', default=0.99, type=float)
parser.add_argument('--init_temperature', default=0.01, type=float)
parser.add_argument('--alpha_lr', default=1e-3, type=float)
parser.add_argument('--alpha_beta', default=0.9, type=float)
# misc # misc
parser.add_argument('--seed', default=1, type=int) parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--logging_freq', default=100, type=int) parser.add_argument('--logging_freq', default=100, type=int)
parser.add_argument('--saving_interval', default=2500, type=int) parser.add_argument('--saving_interval', default=1000, type=int)
parser.add_argument('--work_dir', default='.', type=str) parser.add_argument('--work_dir', default='.', type=str)
parser.add_argument('--save_tb', default=False, action='store_true') parser.add_argument('--save_tb', default=False, action='store_true')
parser.add_argument('--save_model', default=False, action='store_true') parser.add_argument('--save_model', default=False, action='store_true')
@@ -121,7 +126,6 @@ class DPI:
self.env = make_env(self.args) self.env = make_env(self.args)
#self.args.seed = np.random.randint(0, 1000) #self.args.seed = np.random.randint(0, 1000)
self.env.seed(self.args.seed) self.env.seed(self.args.seed)
self.global_episodes = 0
# noiseless environment setup # noiseless environment setup
self.args.version = 2 # env_id changes to v2 self.args.version = 2 # env_id changes to v2
@@ -133,14 +137,14 @@ class DPI:
self.env = utils.FrameStack(self.env, k=self.args.frame_stack) self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
self.env = utils.ActionRepeat(self.env, self.args.action_repeat) self.env = utils.ActionRepeat(self.env, self.args.action_repeat)
self.env = utils.NormalizeActions(self.env) self.env = utils.NormalizeActions(self.env)
self.env = utils.TimeLimit(self.env, 1000 // args.action_repeat)
# create replay buffer # create replay buffer
self.data_buffer = ReplayBuffer(self.args.replay_buffer_capacity, self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity,
self.env.observation_space.shape, obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size),
self.env.action_space.shape[0], action_size=self.env.action_space.shape[0],
self.args.episode_length, seq_len=self.args.episode_length,
self.args.batch_size) batch_size=args.batch_size,
args=self.args)
# create work directory # create work directory
utils.make_dir(self.args.work_dir) utils.make_dir(self.args.work_dir)
@@ -157,19 +161,16 @@ class DPI:
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84) obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84)
state_size=self.args.state_size # 128 state_size=self.args.state_size # 128
).to(device) ).to(device)
self.obs_encoder.apply(self.init_weights)
self.obs_encoder_momentum = ObservationEncoder( self.obs_encoder_momentum = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84) obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84)
state_size=self.args.state_size # 128 state_size=self.args.state_size # 128
).to(device) ).to(device)
self.obs_encoder_momentum.apply(self.init_weights)
self.obs_decoder = ObservationDecoder( self.obs_decoder = ObservationDecoder(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
output_shape=(self.args.channels*self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84) output_shape=(self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84)
).to(device) ).to(device)
self.obs_decoder.apply(self.init_weights)
self.transition_model = TransitionModel( self.transition_model = TransitionModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
@@ -177,7 +178,6 @@ class DPI:
action_size=self.env.action_space.shape[0], # 6 action_size=self.env.action_space.shape[0], # 6
history_size=self.args.history_size, # 128 history_size=self.args.history_size, # 128
).to(device) ).to(device)
self.transition_model.apply(self.init_weights)
# Actor Model # Actor Model
self.actor_model = Actor( self.actor_model = Actor(
@@ -185,27 +185,22 @@ class DPI:
hidden_size=self.args.hidden_size, # 256, hidden_size=self.args.hidden_size, # 256,
action_size=self.env.action_space.shape[0], # 6 action_size=self.env.action_space.shape[0], # 6
).to(device) ).to(device)
self.actor_model.apply(self.init_weights)
# Value Models # Value Models
self.value_model = ValueModel( self.value_model = ValueModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
).to(device) ).to(device)
self.value_model.apply(self.init_weights)
self.target_value_model = ValueModel( self.target_value_model = ValueModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
).to(device) ).to(device)
self.target_value_model.apply(self.init_weights)
self.reward_model = RewardModel( self.reward_model = RewardModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
).to(device) ).to(device)
self.reward_model.apply(self.init_weights)
# Contrastive Models # Contrastive Models
self.prjoection_head = ProjectionHead( self.prjoection_head = ProjectionHead(
@@ -224,32 +219,23 @@ class DPI:
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
).to(device) ).to(device)
self.club_sample = CLUBSample(
x_dim=self.args.state_size, # 128
y_dim=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
).to(device)
# model parameters # model parameters
self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.prjoection_head.parameters()) + \ self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_decoder.parameters()) + \
list(self.transition_model.parameters()) + list(self.club_sample.parameters()) + \ list(self.value_model.parameters()) + list(self.transition_model.parameters()) + \
list(self.contrastive_head.parameters()) list(self.prjoection_head.parameters())
self.past_transition_parameters = self.transition_model.parameters()
# optimizers # optimizers
self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr,eps=1e-6) self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr)
self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6) self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr)
self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr,eps=1e-6) self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr)
self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), self.args.decoder_lr,eps=1e-6) self.past_transition_opt = torch.optim.Adam(self.past_transition_parameters, self.args.past_transition_lr)
self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6)
# Create Modules # Create Modules
self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.club_sample, self.contrastive_head, self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.reward_model, self.transition_model, self.prjoection_head]
self.obs_encoder_momentum, self.prjoection_head_momentum]
self.value_modules = [self.value_model] self.value_modules = [self.value_model]
self.actor_modules = [self.actor_model] self.actor_modules = [self.actor_model]
self.decoder_modules = [self.obs_decoder]
self.reward_modules = [self.reward_model]
if use_saved: if use_saved:
self._use_saved_models(saved_model_dir) self._use_saved_models(saved_model_dir)
@@ -259,432 +245,280 @@ class DPI:
self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt'))) self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt'))) self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))
def collect_random_sequences(self, seed_steps): def collect_sequences(self, episodes, random=True, actor_model=None, encoder_model=None):
obs = self.env.reset() obs = self.env.reset()
done = False done = False
all_rews = [] all_rews = []
self.global_episodes += 1 #video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
if args.save_video:
self.env.video.init(enabled=True)
epi_reward = 0 epi_reward = 0
for _ in tqdm.tqdm(range(seed_steps), desc='Collecting episodes'): for i in range(self.args.episode_length):
if random:
action = self.env.action_space.sample() action = self.env.action_space.sample()
else:
with torch.no_grad():
obs_torch = torch.unsqueeze(torch.tensor(obs).float(),0).to(device)
state = self.obs_encoder(obs_torch)["distribution"].sample()
action = self.actor_model(state).cpu().detach().numpy().squeeze()
next_obs, rew, done, _ = self.env.step(action) next_obs, rew, done, _ = self.env.step(action)
self.data_buffer.add(obs, action, next_obs, rew, done) self.data_buffer.add(obs, action, next_obs, rew, episode_count+1, done)
obs = next_obs
epi_reward += rew if args.save_video:
if done: self.env.video.record(self.env)
if done or i == self.args.episode_length-1:
obs = self.env.reset() obs = self.env.reset()
done=False done=False
all_rews.append(epi_reward)
epi_reward = 0
return all_rews
def collect_sequences(self, collect_steps, actor_model):
obs = self.env.reset()
done = False
all_rews = []
self.global_episodes += 1
epi_reward = 0
for episode_count in tqdm.tqdm(range(collect_steps), desc='Collecting episodes'):
with torch.no_grad():
obs_ = torch.tensor(obs.copy(), dtype=torch.float32)
obs_ = preprocess_obs(obs_).to(device)
#state = self.get_features(obs_)["sample"].unsqueeze(0)
state = self.get_features(obs_)["distribution"].rsample()
action = actor_model(state)
action = actor_model.add_exploration(action)
action = action.cpu().numpy()[0]
next_obs, rew, done, _ = self.env.step(action)
self.data_buffer.add(obs, action, next_obs, rew, done)
if done:
obs = self.env.reset()
done = False
all_rews.append(epi_reward)
epi_reward = 0
else: else:
obs = next_obs obs = next_obs
epi_reward += rew epi_reward += rew
all_rews.append(epi_reward)
if args.save_video:
self.env.video.save('noisy/%d.mp4' % episode_count)
print("Collected {} random episodes".format(episode_count+1))
return all_rews return all_rews
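During non-random collection, the chosen branch queries the encoder and actor and (in one version) perturbs the action with Actor.add_exploration before stepping the environment. A minimal sketch of that exploration noise, independent of the rest of the agent:

import torch

def add_exploration(action, action_noise=0.3):
    # Add zero-mean Gaussian noise around the policy action and clamp the
    # result back into the environment's [-1, 1] action range.
    return torch.clamp(torch.distributions.Normal(action, action_noise).rsample(), -1, 1)

action = torch.tensor([[0.95, -0.20]])
noisy_action = add_exploration(action)   # still inside [-1, 1]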
def train(self, step, total_steps): def train(self, step, total_steps):
# logger counter = 0
logdir = os.path.dirname(os.path.realpath(__file__)) + "/log/logs/" while step < total_steps:
if not(os.path.exists(logdir)):
os.makedirs(logdir)
initial_logs = OrderedDict()
logger = Logger(logdir)
episodic_rews = self.collect_random_sequences(self.args.init_steps//args.action_repeat)
self.global_step = self.data_buffer.steps
initial_logs.update({
'train_avg_reward':np.mean(episodic_rews),
'train_max_reward': np.max(episodic_rews),
'train_min_reward': np.min(episodic_rews),
'train_std_reward':np.std(episodic_rews),
})
logger.log_scalars(initial_logs, step=0)
logger.flush()
while self.global_step < total_steps:
logs = OrderedDict()
step += 1
for update_steps in range(self.args.update_steps):
model_loss, actor_loss, value_loss, actor_model = self.update((step-1)*args.update_steps + update_steps)
initial_logs.update({
'model_loss' : model_loss,
'actor_loss': actor_loss,
'value_loss': value_loss,
'train_avg_reward':np.mean(episodic_rews),
'train_max_reward': np.max(episodic_rews),
'train_min_reward': np.min(episodic_rews),
'train_std_reward':np.std(episodic_rews),
})
logger.log_scalars(logs, self.global_step)
print("########## Global Step:", self.global_step, " ##########")
for key, value in initial_logs.items():
print(key, " : ", value)
episodic_rews = self.collect_sequences(1000//self.args.action_repeat, actor_model)
if self.global_step % 3150 == 0 and self.data_buffer.steps!=0: #self.args.evaluation_interval == 0:
print("Saving model")
path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
self.save_models(path)
self.evaluate()
self.global_step = self.data_buffer.steps * self.args.action_repeat
"""
# collect experience # collect experience
if step !=0: if step !=0:
encoder = self.obs_encoder encoder = self.obs_encoder
actor = self.actor_model actor = self.actor_model
all_rews = self.collect_sequences(self.args.episode_collection, actor_model=actor, encoder_model=encoder) #all_rews = self.collect_sequences(self.args.batch_size, random=True)
""" all_rews = self.collect_sequences(self.args.batch_size, random=False, actor_model=actor, encoder_model=encoder)
else:
all_rews = self.collect_sequences(self.args.batch_size, random=True)
def collect_batch(self): # Group by steps and sample random batch
obs_, acs_, nxt_obs_, rews_, terms_ = self.data_buffer.sample() random_indices = self.data_buffer.sample_random_idx(self.args.batch_size * ((step//self.args.collection_interval)+1)) # random indices for batch
#random_indices = np.arange(self.args.batch_size * ((step//self.args.collection_interval)),self.args.batch_size * ((step//self.args.collection_interval)+1))
last_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"observations", "cpu", random_indices=random_indices)
current_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"next_observations", device="cpu", random_indices=random_indices)
next_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"next_observations", device="cpu", offset=1, random_indices=random_indices)
actions = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"actions", device=device, is_obs=False, random_indices=random_indices)
next_actions = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"actions", device=device, is_obs=False, offset=1, random_indices=random_indices)
rewards = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"rewards", device=device, is_obs=False, offset=1, random_indices=random_indices)
obs = torch.tensor(obs_, dtype=torch.float32)[1:] # Preprocessing
last_obs = torch.tensor(obs_, dtype=torch.float32)[:-1] last_observations = preprocess_obs(last_observations).to(device)
nxt_obs = torch.tensor(nxt_obs_, dtype=torch.float32)[1:] current_observations = preprocess_obs(current_observations).to(device)
acs = torch.tensor(acs_, dtype=torch.float32)[:-1].to(device) next_observations = preprocess_obs(next_observations).to(device)
nxt_acs = torch.tensor(acs_, dtype=torch.float32)[1:].to(device)
rews = torch.tensor(rews_, dtype=torch.float32)[:-1].to(device).unsqueeze(-1)
nonterms = torch.tensor((1.0-terms_), dtype=torch.float32)[:-1].to(device).unsqueeze(-1)
last_obs = preprocess_obs(last_obs).to(device) # Initialize transition model states
obs = preprocess_obs(obs).to(device) self.transition_model.init_states(self.args.batch_size, device) # (N,128)
nxt_obs = preprocess_obs(nxt_obs).to(device) self.history = self.transition_model.prev_history # (N,128)
return last_obs, obs, nxt_obs, acs, rews, nxt_acs, nonterms # Train encoder
if step == 0:
step += 1
for _ in range(self.args.collection_interval // self.args.episode_length+1):
counter += 1
for i in range(self.args.episode_length-1):
if i > 0:
# Encode observations and next_observations
self.last_states_dict = self.get_features(last_observations[i])
self.current_states_dict = self.get_features(current_observations[i])
self.next_states_dict = self.get_features(next_observations[i], momentum=True)
self.action = actions[i] # (N,6)
self.next_action = next_actions[i] # (N,6)
history = self.transition_model.prev_history
def update(self, step): # Encode negative observations
last_observations, current_observations, next_observations, actions, rewards, next_actions, nonterms = self.collect_batch() idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch
random_time_index = torch.randint(0, self.args.episode_length-2, (1,)).item() # random time index
negative_current_observations = current_observations[random_time_index][idx]
self.negative_current_states_dict = self.obs_encoder(negative_current_observations)
#last_observations, current_observations, next_observations, actions, next_actions, rewards = self.select_one_batch() # Predict current state from past state with transition model
last_states_sample = self.last_states_dict["sample"]
predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
self.history = predicted_current_state_dict["history"]
# Calculate upper bound loss
likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict,
self.current_states_dict,
self.negative_current_states_dict,
predicted_current_state_dict
)
world_loss, enc_loss, rew_loss, dec_loss, ub_loss, lb_loss = self.world_model_losses(last_observations, # Calculate encoder loss
current_observations, encoder_loss = self._past_encoder_loss(self.current_states_dict,
next_observations, predicted_current_state_dict)
actions,
next_actions, # contrastive projection
rewards, vec_anchor = predicted_current_state_dict["sample"]
nonterms) vec_positive = self.next_states_dict["sample"].detach()
z_anchor = self.prjoection_head(vec_anchor, self.action)
z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach()
# contrastive loss
logits = self.contrastive_head(z_anchor, z_positive)
labels = torch.arange(logits.shape[0]).long().to(device)
lb_loss = F.cross_entropy(logits, labels)
# behaviour learning
with FreezeParameters(self.world_model_modules):
imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(),
self.next_action, self.history.detach(),
imagine_horizon)
# decoder loss
horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon])
decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:]))
# reward loss
reward_dist = self.reward_model(self.current_states_dict["sample"])
reward_loss = -torch.mean(reward_dist.log_prob(rewards[:-1]))
# update models
world_model_loss = encoder_loss + 100 * ub_loss + lb_loss + reward_loss + decoder_loss * 1e-2
self.world_model_opt.zero_grad() self.world_model_opt.zero_grad()
world_loss.backward() world_model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm) nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step() self.world_model_opt.step()
self.decoder_opt.zero_grad() # update momentum encoder
dec_loss.backward() soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
nn.utils.clip_grad_norm_(self.obs_decoder.parameters(), self.args.grad_clip_norm)
self.decoder_opt.step()
self.reward_opt.zero_grad() # update momentum projection head
rew_loss.backward() soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
nn.utils.clip_grad_norm_(self.reward_model.parameters(), self.args.grad_clip_norm)
self.reward_opt.step()
actor_loss = self.actor_model_losses() # actor loss
with FreezeParameters(self.world_model_modules + self.value_modules):
imag_rew_dist = self.reward_model(imagined_rollout["sample"])
target_imag_val_dist = self.target_value_model(imagined_rollout["sample"])
imag_rews = imag_rew_dist.mean
target_imag_vals = target_imag_val_dist.mean
discounts = self.args.discount * torch.ones_like(imag_rews).detach()
self.target_returns = self._compute_lambda_return(imag_rews[:-1],
target_imag_vals[:-1],
discounts[:-1] ,
self.args.td_lambda,
target_imag_vals[-1])
discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
self.discounts = torch.cumprod(discounts, 0).detach()
actor_loss = -torch.mean(self.discounts * self.target_returns)
# update actor
self.actor_opt.zero_grad() self.actor_opt.zero_grad()
actor_loss.backward() actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm) nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
self.actor_opt.step() self.actor_opt.step()
value_loss = self.value_model_losses() # value loss
with torch.no_grad():
value_feat = imagined_rollout["sample"][:-1].detach()
value_targ = self.target_returns.detach()
value_dist = self.value_model(value_feat)
value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
# update value
self.value_opt.zero_grad() self.value_opt.zero_grad()
value_loss.backward() value_loss.backward()
nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm) nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
self.value_opt.step() self.value_opt.step()
# update momentum encoder and projection head # update target value
soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau) if step % self.args.value_target_update_freq == 0:
soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau) self.target_value_model = copy.deepcopy(self.value_model)
# counter for reward
count = np.arange((counter-1) * (self.args.batch_size), (counter) * (self.args.batch_size))
# update target value networks
#if step % self.args.value_target_update_freq == 0:
# self.target_value_model = copy.deepcopy(self.value_model)
if step % self.args.logging_freq: if step % self.args.logging_freq:
writer.add_scalar('World Loss/World Loss', world_loss.detach().item(), step) writer.add_scalar('World Loss/World Loss', world_model_loss.detach().item(), step)
writer.add_scalar('Main Models Loss/Encoder Loss', enc_loss.detach().item(), step) writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss.detach().item(), step)
writer.add_scalar('Main Models Loss/Decoder Loss', dec_loss.detach().item(), step) writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, step)
writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), step) writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), step)
writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step) writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step)
writer.add_scalar('Actor Critic Loss/Reward Loss', rew_loss.detach().item(), step) writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step) writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Lower Bound Loss', -lb_loss.detach().item(), step) writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), step)
return world_loss.item(), actor_loss.item(), value_loss.item(), self.actor_model step += 1
if step>total_steps:
print("Training finished")
break
def world_model_losses(self, last_obs, curr_obs, nxt_obs, actions, nxt_actions, rewards, nonterms): # save model
# get features if step % self.args.saving_interval == 0:
self.last_state_feat = self.get_features(last_obs) path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
self.curr_state_feat = self.get_features(curr_obs) self.save_models(path)
self.nxt_state_feat = self.get_features(nxt_obs)
self.nxt_state_feat_lb = self.get_features(nxt_obs, momentum=True)
# states #torch.cuda.empty_cache() # memory leak issues
self.last_state_enc = self.last_state_feat["distribution"].rsample() #self.last_state_feat["sample"]
self.curr_state_enc = self.curr_state_feat["distribution"].rsample() #self.curr_state_feat["sample"]
self.nxt_state_enc = self.nxt_state_feat["distribution"].rsample() #self.nxt_state_feat["sample"]
self.nxt_state_enc_lb = self.nxt_state_feat_lb["distribution"].rsample() #self.nxt_state_feat_lb["sample"]
# predict next states for j in range(len(all_rews)):
self.transition_model.init_states(self.args.batch_size, device) # (N,128) writer.add_scalar('Rewards/Rewards', all_rews[j], count[j])
self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history, nonterms)
self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"])
self.pred_curr_state_enc = self.pred_curr_state_dist.rsample() #self.observed_rollout["sample"]
# encoder loss
enc_loss = self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist)
# reward loss def evaluate(self, env, eval_episodes, render=False):
rew_dist = self.reward_model(self.curr_state_enc.detach())
#print(torch.cat([rew_dist.mean[0], rewards[0]],dim=-1))
rew_loss = -torch.mean(rew_dist.log_prob(rewards))
# decoder loss episode_rew = np.zeros((eval_episodes))
dec_dist = self.obs_decoder(self.nxt_state_enc.detach())
dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs))
# upper bound loss video_images = [[] for _ in range(eval_episodes)]
past_ub_loss = 0
for i in range(self.curr_state_enc.shape[0]):
_, ub_loss = self._upper_bound_minimization(self.curr_state_enc[i],
self.pred_curr_state_enc[i])
ub_loss = ub_loss + past_ub_loss
past_ub_loss = ub_loss
ub_loss = ub_loss / self.curr_state_enc.shape[0]
ub_loss = 1 * ub_loss
# lower bound loss for i in range(eval_episodes):
# contrastive projection obs = env.reset()
vec_anchor = self.pred_curr_state_enc.detach()
vec_positive = self.nxt_state_enc_lb.detach()
z_anchor = self.prjoection_head(vec_anchor, nxt_actions)
z_positive = self.prjoection_head_momentum(vec_positive, nxt_actions)
# contrastive loss
past_lb_loss = 0
for i in range(z_anchor.shape[0]):
logits = self.contrastive_head(z_anchor[i], z_positive[i])
labels = torch.arange(logits.shape[0]).long().to(device)
lb_loss = F.cross_entropy(logits, labels) + past_lb_loss
past_lb_loss = lb_loss.detach().item()
lb_loss = -0.01 * lb_loss/(z_anchor.shape[0])
world_loss = enc_loss + ub_loss + lb_loss
return world_loss, enc_loss , rew_loss, dec_loss, ub_loss, lb_loss
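The lower-bound term above is an InfoNCE-style cross-entropy over similarity logits, with the momentum projection of the next state as the positive and the rest of the batch as negatives. The repo's ContrastiveHead uses a learned bilinear similarity; the sketch below uses a plain normalized dot product to show the structure of the loss, so treat it as an approximation rather than the exact head:

import torch
import torch.nn.functional as F

def info_nce(z_anchor, z_positive, temperature=0.1):
    # Row i of the logits compares anchor i against every positive; the
    # matching index is the label, everything else acts as a negative.
    z_anchor = F.normalize(z_anchor, dim=-1)
    z_positive = F.normalize(z_positive, dim=-1)
    logits = z_anchor @ z_positive.t() / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    return F.cross_entropy(logits, labels)

loss = info_nce(torch.randn(32, 128), torch.randn(32, 128))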
def actor_model_losses(self):
with torch.no_grad():
#curr_state_enc = self.curr_state_enc.reshape(self.args.episode_length-1,-1) #self.transition_model.seq_to_batch(self.curr_state_feat, "sample")["sample"]
#curr_state_hist = self.observed_rollout["history"].reshape(self.args.episode_length-1,-1) #self.transition_model.seq_to_batch(self.observed_rollout, "history")["sample"]
curr_state_enc = self.curr_state_enc.reshape(-1, self.args.state_size)
curr_state_hist = self.observed_rollout["history"].reshape(-1, self.args.history_size)
with FreezeParameters(self.world_model_modules + self.decoder_modules + self.reward_modules + self.value_modules):
imagine_horizon = self.args.imagine_horizon
action = self.actor_model(curr_state_enc.detach())
self.imagined_rollout = self.transition_model.imagine_rollout(curr_state_enc,
action, curr_state_hist.detach(),
imagine_horizon)
self.pred_nxt_state_dist = self.transition_model.get_dist(self.imagined_rollout["mean"], self.imagined_rollout["std"])
self.pred_nxt_state_enc = self.pred_nxt_state_dist.rsample() #self.transition_model.reparemeterize(self.imagined_rollout["mean"], self.imagined_rollout["std"])
with FreezeParameters(self.world_model_modules + self.value_modules + self.decoder_modules + self.reward_modules):
imag_rewards_dist = self.reward_model(self.pred_nxt_state_enc)
imag_values_dist = self.value_model(self.pred_nxt_state_enc)
imag_rewards = imag_rewards_dist.mean
imag_values = imag_values_dist.mean
#print(torch.cat([imag_rewards[0], imag_values[0]],dim=-1))
discounts = self.args.discount * torch.ones_like(imag_rewards).detach()
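# Lambda-returns along the imagined trajectory, bootstrapped with the last predicted value.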
self.returns = self._compute_lambda_return(imag_rewards[:-1],
imag_values[:-1],
discounts[:-1] ,
self.args.td_lambda,
imag_values[-1])
discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
self.discounts = torch.cumprod(discounts, 0).detach()
actor_loss = -torch.mean(self.discounts * self.returns)
return actor_loss
def value_model_losses(self):
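# Fit the value model to the detached lambda-returns by maximizing their log-likelihood
# under the predicted value distribution, weighted by the same imagination discounts.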
with torch.no_grad():
value_feat = self.pred_nxt_state_enc[:-1].detach()
value_targ = self.returns.detach()
value_dist = self.value_model(value_feat)
value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
return value_loss
def select_one_batch(self):
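# Draw a random batch of episode chunks from the replay buffer: pick batch_size episodes
# and start indices, slice episode_length consecutive steps from each, and return
# time-major (T, B, ...) tensors for observations, actions and rewards.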
# collect sequences
non_zero_indices = np.nonzero(self.data_buffer.episode_count)[0]
current_obs = self.data_buffer.observations[non_zero_indices]
next_obs = self.data_buffer.next_observations[non_zero_indices]
actions_raw = self.data_buffer.actions[non_zero_indices]
rewards = self.data_buffer.rewards[non_zero_indices]
self.terms = np.nonzero(self.data_buffer.terminals[non_zero_indices])[0]
# group by episodes
current_obs = self.grouped_arrays(current_obs)
next_obs = self.grouped_arrays(next_obs)
actions_raw = self.grouped_arrays(actions_raw)
rewards_ = self.grouped_arrays(rewards)
# select random chunks of episodes
if current_obs.shape[0] < self.args.batch_size:
random_episode_number = np.random.randint(0, current_obs.shape[0], self.args.batch_size)
else:
random_episode_number = random.sample(range(current_obs.shape[0]), self.args.batch_size)
# select random starting points
if current_obs[0].shape[0]-self.args.episode_length < self.args.batch_size:
init_index = np.random.randint(0, current_obs[0].shape[0]-self.args.episode_length-2, self.args.batch_size)
else:
init_index = np.asarray(random.sample(range(current_obs[0].shape[0]-self.args.episode_length), self.args.batch_size))
# shuffle
random.shuffle(random_episode_number)
random.shuffle(init_index)
# select first k elements
last_observations = self.select_first_k(current_obs, init_index, random_episode_number)[:-1]
current_observations = self.select_first_k(current_obs, init_index, random_episode_number)[1:]
next_observations = self.select_first_k(next_obs, init_index, random_episode_number)[:-1]
actions = self.select_first_k(actions_raw, init_index, random_episode_number)[:-1].to(device)
next_actions = self.select_first_k(actions_raw, init_index, random_episode_number)[1:].to(device)
rewards = self.select_first_k(rewards_, init_index, random_episode_number)[:-1].to(device)
# preprocessing
last_observations = preprocess_obs(last_observations).to(device)
current_observations = preprocess_obs(current_observations).to(device)
next_observations = preprocess_obs(next_observations).to(device)
return last_observations, current_observations, next_observations, actions, next_actions, rewards
def evaluate(self, eval_episodes=10):
path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
self.restore_checkpoint(path)
obs = self.env.reset()
done = False done = False
prev_state = self.rssm.init_state(1, self.device)
prev_action = torch.zeros(1, self.action_size).to(self.device)
#video = VideoRecorder(self.video_dir, resource_files=self.args.resource_files)
if self.args.save_video:
self.env.video.init(enabled=True)
episodic_rewards = []
for episode in range(eval_episodes):
rewards = 0
done = False
while not done: while not done:
with torch.no_grad(): with torch.no_grad():
obs = torch.tensor(obs.copy(), dtype=torch.float32).unsqueeze(0) posterior, action = self.act_with_world_model(obs, prev_state, prev_action)
obs_processed = preprocess_obs(obs).to(device) action = action[0].cpu().numpy()
state = self.get_features(obs_processed)["distribution"].rsample() next_obs, rew, done, _ = env.step(action)
action = self.actor_model(state).cpu().detach().numpy().squeeze() prev_state = posterior
next_obs, rew, done, _ = self.env.step(action) prev_action = torch.tensor(action, dtype=torch.float32).to(self.device).unsqueeze(0)
rewards += rew
episode_rew[i] += rew
if render:
video_images[i].append(obs['image'].transpose(1,2,0).copy())
obs = next_obs obs = next_obs
return episode_rew, np.array(video_images[:self.args.max_videos_to_save])
if self.args.save_video: def _upper_bound_minimization(self, last_states, current_states, negative_current_states, predicted_current_states):
self.env.video.record(self.env) club_sample = CLUBSample(last_states,
self.env.video.save('/home/vedant/Curiosity/Curiosity/DPI/log/video/learned_model.mp4') current_states,
obs = self.env.reset() negative_current_states,
episodic_rewards.append(rewards) predicted_current_states)
print("Episodic rewards: ", episodic_rewards) likelihood_loss = club_sample.learning_loss()
print("Average episodic reward: ", np.mean(episodic_rewards)) club_loss = club_sample()
def init_weights(self, m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def grouped_arrays(self,array):
indices = [0] + self.terms.tolist()
def subarrays():
for start, end in zip(indices[:-1], indices[1:]):
yield array[start:end]
try:
subarrays = np.stack(list(subarrays()), axis=0)
except ValueError:
subarrays = np.asarray(list(subarrays()))
return subarrays
def select_first_k(self, array, init_index, episode_number):
term_index = init_index + self.args.episode_length
array = array[episode_number]
array_list = []
for i in range(array.shape[0]):
array_list.append(array[i][init_index[i]:term_index[i]])
array = np.asarray(array_list)
if array.ndim == 5:
transposed_array = np.transpose(array, (1, 0, 2, 3, 4))
elif array.ndim == 4:
transposed_array = np.transpose(array, (1, 0, 2, 3))
elif array.ndim == 3:
transposed_array = np.transpose(array, (1, 0, 2))
elif array.ndim == 2:
transposed_array = np.transpose(array, (1, 0))
else:
transposed_array = np.expand_dims(array, axis=0)
#return torch.tensor(array).float()
return torch.tensor(transposed_array).float()
def _upper_bound_minimization(self, current_states, predicted_current_states):
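# Negatives are created by scrambling the encoded current states along the batch and
# feature axes; the CLUB estimator then provides a sample-based upper bound on the mutual
# information between the encoded and the predicted current states.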
current_negative_states = shuffle_along_axis(current_states.clone(), axis=0)
current_negative_states = shuffle_along_axis(current_negative_states, axis=1)
club_loss = self.club_sample(current_states, predicted_current_states, current_negative_states)
likelihood_loss = 0
return likelihood_loss, club_loss return likelihood_loss, club_loss
def _encoder_loss(self, curr_states_dist, predicted_curr_states_dist): def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict):
# current state distribution
curr_states_dist = curr_states_dict["distribution"]
# predicted current state distribution
predicted_curr_states_dist = predicted_curr_states_dict["distribution"]
# KL divergence loss # KL divergence loss
loss = torch.mean(torch.distributions.kl.kl_divergence(curr_states_dist,predicted_curr_states_dist)) loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean()
return loss return loss
def get_features(self, x, momentum=False): def get_features(self, x, momentum=False):
if self.args.aug: if self.args.aug:
crop_transform = T.RandomCrop(size=80) x = T.RandomCrop((80, 80))(x) # (None,80,80,4)
cropped_x = torch.stack([crop_transform(x[i]) for i in range(x.size(0))]) x = T.functional.pad(x, (4, 4, 4, 4), "symmetric") # (None,88,88,4)
padding = (2, 2, 2, 2) x = T.RandomCrop((84, 84))(x) # (None,84,84,4)
x = F.pad(cropped_x, padding)
with torch.no_grad(): with torch.no_grad():
if momentum: if momentum:
@ -694,19 +528,6 @@ class DPI:
return x return x
def _compute_lambda_return(self, rewards, values, discounts, td_lam, last_value): def _compute_lambda_return(self, rewards, values, discounts, td_lam, last_value):
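# Generalized lambda-return, computed backwards in time:
#   G_t = r_t + gamma_t * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}),  with G_T = V(s_T).
# lambda = 1 recovers the discounted Monte Carlo return, lambda = 0 the one-step TD target.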
next_values = torch.cat([values[1:], last_value[None]], 0)
target = rewards + discounts * next_values * (1 - td_lam)
timesteps = list(range(rewards.shape[0] - 1, -1, -1))
outputs = []
accumulated_reward = last_value
for t in timesteps:
inp = target[t]
discount_factor = discounts[t]
accumulated_reward = inp + discount_factor * td_lam * accumulated_reward
outputs.append(accumulated_reward)
returns = torch.flip(torch.stack(outputs), [0])
return returns
"""
next_values = torch.cat([values[1:], last_value.unsqueeze(0)],0) next_values = torch.cat([values[1:], last_value.unsqueeze(0)],0)
targets = rewards + discounts * next_values * (1-td_lam) targets = rewards + discounts * next_values * (1-td_lam)
rets =[] rets =[]
@ -718,25 +539,6 @@ class DPI:
returns = torch.flip(torch.stack(rets), [0]) returns = torch.flip(torch.stack(rets), [0])
return returns return returns
"""
def lambda_return(self,imged_reward, value_pred, bootstrap, discount=0.99, lambda_=0.95):
# Setting lambda=1 gives a discounted Monte Carlo return.
# Setting lambda=0 gives a fixed 1-step return.
next_values = torch.cat([value_pred[1:], bootstrap[None]], 0)
discount_tensor = discount * torch.ones_like(imged_reward) # pcont
inputs = imged_reward + discount_tensor * next_values * (1 - lambda_)
last = bootstrap
indices = reversed(range(len(inputs)))
outputs = []
for index in indices:
inp, disc = inputs[index], discount_tensor[index]
last = inp + disc * lambda_ * last
outputs.append(last)
outputs = list(reversed(outputs))
outputs = torch.stack(outputs, 0)
returns = outputs
return returns
def save_models(self, save_path): def save_models(self, save_path):
torch.save( torch.save(
@ -749,17 +551,6 @@ class DPI:
'value_optimizer': self.value_opt.state_dict(), 'value_optimizer': self.value_opt.state_dict(),
'world_model_optimizer': self.world_model_opt.state_dict(),}, save_path) 'world_model_optimizer': self.world_model_opt.state_dict(),}, save_path)
def restore_checkpoint(self, ckpt_path):
checkpoint = torch.load(ckpt_path)
self.transition_model.load_state_dict(checkpoint['rssm'])
self.actor_model.load_state_dict(checkpoint['actor'])
self.reward_model.load_state_dict(checkpoint['reward_model'])
self.obs_encoder.load_state_dict(checkpoint['obs_encoder'])
self.obs_decoder.load_state_dict(checkpoint['obs_decoder'])
self.world_model_opt.load_state_dict(checkpoint['world_model_optimizer'])
self.actor_opt.load_state_dict(checkpoint['actor_optimizer'])
self.value_opt.load_state_dict(checkpoint['value_optimizer'])
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
@ -769,7 +560,6 @@ if __name__ == '__main__':
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
step = 0 step = 0
total_steps = 2000000 total_steps = 10000
dpi = DPI(args) dpi = DPI(args)
dpi.train(step,total_steps) dpi.train(step,total_steps)
dpi.evaluate()

View File

@ -1,13 +1,10 @@
import os import os
import random import random
import pickle
import numpy as np import numpy as np
from collections import deque from collections import deque
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import gym import gym
import dmc2gym import dmc2gym
@ -63,6 +60,17 @@ def make_dir(dir_path):
return dir_path return dir_path
def preprocess_obs(obs, bits=5):
"""Preprocessing image, see https://arxiv.org/abs/1807.03039."""
bins = 2**bits
assert obs.dtype == torch.float32
if bits < 8:
obs = torch.floor(obs / 2**(8 - bits))
obs = obs / bins
obs = obs + torch.rand_like(obs) / bins
obs = obs - 0.5
return obs
class FrameStack(gym.Wrapper): class FrameStack(gym.Wrapper):
def __init__(self, env, k): def __init__(self, env, k):
@ -136,90 +144,8 @@ class NormalizeActions:
original = np.where(self._mask, original, action) original = np.where(self._mask, original, action)
return self._env.step(original) return self._env.step(original)
class TimeLimit:
def __init__(self, env, duration):
self._env = env
self._duration = duration
self._step = None
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
assert self._step is not None, 'Must reset environment.'
obs, reward, done, info = self._env.step(action)
self._step += 1
if self._step >= self._duration:
done = True
if 'discount' not in info:
info['discount'] = np.array(1.0).astype(np.float32)
self._step = None
return obs, reward, done, info
def reset(self):
self._step = 0
return self._env.reset()
class ReplayBuffer: class ReplayBuffer:
def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
self.size = size
self.obs_shape = obs_shape
self.action_size = action_size
self.seq_len = seq_len
self.batch_size = batch_size
self.idx = 0
self.full = False
self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.actions = np.empty((size, action_size), dtype=np.float32)
self.rewards = np.empty((size,), dtype=np.float32)
self.terminals = np.empty((size,), dtype=np.float32)
self.steps, self.episodes = 0, 0
self.episode_count = np.zeros((size,), dtype=np.int32)
def add(self, obs, ac, next_obs, rew, done, episode_count):
self.observations[self.idx] = obs
self.next_observations[self.idx] = next_obs
self.actions[self.idx] = ac
self.rewards[self.idx] = rew
self.terminals[self.idx] = done
self.full = self.full or self.idx == 0
self.steps += 1
self.episodes = self.episodes + (1 if done else 0)
self.episode_count[self.idx] = episode_count
self.idx = (self.idx + 1) % self.size
def _sample_idx(self, L):
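# Rejection-sample a start index so that the L-step window does not wrap across the
# buffer's current write position (which would splice together unrelated transitions).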
valid_idx = False
while not valid_idx:
idx = np.random.randint(0, self.size if self.full else self.idx - L)
idxs = np.arange(idx, idx + L) % self.size
valid_idx = not self.idx in idxs[1:]
return idxs
def _retrieve_batch(self, idxs, n, L):
vec_idxs = idxs.transpose().reshape(-1) # Unroll indices
observations = self.observations[vec_idxs]
next_obs = self.next_observations[vec_idxs]
obs = observations.reshape(L, n, *observations.shape[1:])
next_obs = next_obs.reshape(L, n, *next_obs.shape[1:])
acs = self.actions[vec_idxs].reshape(L, n, -1)
rew = self.rewards[vec_idxs].reshape(L, n)
term = self.terminals[vec_idxs].reshape(L, n)
return obs, acs, next_obs, rew, term
def sample(self):
n = self.batch_size
l = self.seq_len
obs,acs,next_obs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
return obs,acs,next_obs,rews,terms
class ReplayBuffer1:
def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args): def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
self.size = size self.size = size
self.obs_shape = obs_shape self.obs_shape = obs_shape
@ -273,11 +199,7 @@ class ReplayBuffer1:
def group_steps(self, buffer, variable, obs=True): def group_steps(self, buffer, variable, obs=True):
variable = getattr(buffer, variable) variable = getattr(buffer, variable)
non_zero_indices = np.nonzero(buffer.episode_count)[0] non_zero_indices = np.nonzero(buffer.episode_count)[0]
print(buffer.episode_count)
variable = variable[non_zero_indices] variable = variable[non_zero_indices]
print(variable.shape)
exit()
if obs: if obs:
variable = variable.reshape(-1, self.args.episode_length, variable = variable.reshape(-1, self.args.episode_length,
self.args.frame_stack*self.args.channels, self.args.frame_stack*self.args.channels,
@ -292,9 +214,8 @@ class ReplayBuffer1:
self.args.image_size,self.args.image_size) self.args.image_size,self.args.image_size)
return variable return variable
def sample_random_idx(self, buffer_length, last=False): def sample_random_idx(self, buffer_length):
init = 0 if last else buffer_length - self.args.batch_size random_indices = random.sample(range(0, buffer_length), self.args.batch_size)
random_indices = random.sample(range(init, buffer_length), self.args.batch_size)
return random_indices return random_indices
def group_and_sample_random_batch(self, buffer, variable_name, device, random_indices, is_obs=True, offset=0): def group_and_sample_random_batch(self, buffer, variable_name, device, random_indices, is_obs=True, offset=0):
@ -326,23 +247,19 @@ def make_env(args):
) )
return env return env
def shuffle_along_axis(a, axis):
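# Independently permute the entries of `a` along the given axis
# (random values are argsorted to obtain per-slice permutations).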
idx = np.random.rand(*a.shape).argsort(axis=axis)
return np.take_along_axis(a,idx,axis=axis)
def preprocess_obs(obs): def preprocess_obs(obs):
obs = (obs/255.0) - 0.5 obs = obs/255.0 - 0.5
return obs return obs
def soft_update_params(net, target_net, tau): def soft_update_params(net, target_net, tau):
for param, target_param in zip(net.parameters(), target_net.parameters()): for param, target_param in zip(net.parameters(), target_net.parameters()):
target_param.data.copy_( target_param.data.copy_(
tau * param.detach().data + (1 - tau) * target_param.data tau * param.data + (1 - tau) * target_param.data
) )
def save_image(array, filename): def save_image(array, filename):
array = array.transpose(1, 2, 0) array = array.transpose(1, 2, 0)
array = ((array+0.5) * 255).astype(np.uint8) array = (array * 255).astype(np.uint8)
image = Image.fromarray(array) image = Image.fromarray(array)
image.save(filename) image.save(filename)
@ -363,20 +280,6 @@ def video_from_array(arr, high_noise, filename):
out.write(frame) out.write(frame)
out.release() out.release()
def save_video(images):
"""
Image shape is (T, C, H, W)
Example:(50, 3, 84, 84)
"""
output_file = "output.avi"
fourcc = cv2.VideoWriter_fourcc(*'XVID')
fps = 2
height, width, channels = 84,84,3
out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
for image in images:
image = np.uint8(image.transpose((1, 2, 0)))
out.write(image)
out.release()
class CorruptVideos: class CorruptVideos:
def __init__(self, dir_path): def __init__(self, dir_path):
@ -450,51 +353,3 @@ class FreezeParameters:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
for i, param in enumerate(get_parameters(self.modules)): for i, param in enumerate(get_parameters(self.modules)):
param.requires_grad = self.param_states[i] param.requires_grad = self.param_states[i]
class Logger:
def __init__(self, log_dir, n_logged_samples=10, summary_writer=None):
self._log_dir = log_dir
print('########################')
print('logging outputs to ', log_dir)
print('########################')
self._n_logged_samples = n_logged_samples
self._summ_writer = SummaryWriter(log_dir, flush_secs=1, max_queue=1)
def log_scalar(self, scalar, name, step_):
self._summ_writer.add_scalar('{}'.format(name), scalar, step_)
def log_scalars(self, scalar_dict, step):
for key, value in scalar_dict.items():
print('{} : {}'.format(key, value))
self.log_scalar(value, key, step)
self.dump_scalars_to_pickle(scalar_dict, step)
def log_videos(self, videos, step, max_videos_to_save=1, fps=20, video_title='video'):
# max rollout length
max_videos_to_save = np.min([max_videos_to_save, videos.shape[0]])
max_length = videos[0].shape[0]
for i in range(max_videos_to_save):
if videos[i].shape[0]>max_length:
max_length = videos[i].shape[0]
# pad rollouts to all be same length
for i in range(max_videos_to_save):
if videos[i].shape[0]<max_length:
padding = np.tile([videos[i][-1]], (max_length-videos[i].shape[0],1,1,1))
videos[i] = np.concatenate([videos[i], padding], 0)
clip = mpy.ImageSequenceClip(list(videos[i]), fps=fps)
new_video_title = video_title+'{}_{}'.format(step, i) + '.gif'
filename = os.path.join(self._log_dir, new_video_title)
clip.write_gif(filename, fps=fps)
def dump_scalars_to_pickle(self, metrics, step, log_title=None):
log_path = os.path.join(self._log_dir, "scalar_data.pkl" if log_title is None else log_title)
with open(log_path, 'ab') as f:
pickle.dump({'step': step, **dict(metrics)}, f)
def flush(self):
self._summ_writer.flush()