commit 8ebc277814: DB

README.md (new file, 60 lines)
@@ -0,0 +1,60 @@
# Dynamic Bottleneck

## Introduction

This is a TensorFlow-based implementation of our paper

**"Dynamic Bottleneck for Robust Self-Supervised Exploration", NeurIPS 2021.**

## Prerequisites

- Python 3.6 or 3.7
- tensorflow-gpu 1.x and tensorflow-probability
- OpenAI [baselines](https://github.com/openai/baselines)
- OpenAI [Gym](http://gym.openai.com/)
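The repository does not pin exact package versions; a typical environment setup might look like the following (the versions shown are only illustrative and should be matched to your CUDA/driver stack):

```
pip install tensorflow-gpu==1.15 tensorflow-probability==0.8.0
pip install "gym[atari]"
pip install git+https://github.com/openai/baselines
```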
## Installation and Usage

### Atari games

The following command trains a pure exploration agent on "Breakout" with the default experiment parameters:

```
python run.py --env BreakoutNoFrameskip-v4
```

### Atari games with Random-Box noise

The following command trains a pure exploration agent on "Breakout" with random-box noise:

```
python run.py --env BreakoutNoFrameskip-v4 --randomBoxNoise
```

### Atari games with Gaussian noise

The following command trains a pure exploration agent on "Breakout" with Gaussian pixel noise:

```
python run.py --env BreakoutNoFrameskip-v4 --pixelNoise
```

### Atari games with sticky actions

The following command trains a pure exploration agent on "sticky Breakout", where each action is repeated with probability 0.25:

```
python run.py --env BreakoutNoFrameskip-v4 --stickyAtari
```

### Baselines

- **ICM**: We use the official [code](https://github.com/openai/large-scale-curiosity) of "Curiosity-driven Exploration by Self-supervised Prediction, ICML 2017" and "Large-Scale Study of Curiosity-Driven Learning, ICLR 2019".
- **Disagreement**: We use the official [code](https://github.com/pathak22/exploration-by-disagreement) of "Self-Supervised Exploration via Disagreement, ICML 2019".
- **CB**: We use the official [code](https://github.com/whyjay/curiosity-bottleneck) of "Curiosity-Bottleneck: Exploration by Distilling Task-Specific Novelty, ICML 2019".
__init__.py (new file, 1 line)
@@ -0,0 +1 @@
#############
cnn_policy.py (new file, 71 lines)
@@ -0,0 +1,71 @@
import tensorflow as tf
from baselines.common.distributions import make_pdtype
from utils import getsess, small_convnet, activ, fc, flatten_two_dims, unflatten_first_dim


class CnnPolicy(object):
    def __init__(self, ob_space, ac_space, hidsize,
                 ob_mean, ob_std, feat_dim, layernormalize, nl, scope="policy"):
        """ ob_space: (84,84,4); ac_space: 4;
        ob_mean.shape=(84,84,4); ob_std=1.7; hidsize: 512;
        feat_dim: 512; layernormalize: False; nl: tf.nn.leaky_relu.
        """
        if layernormalize:
            print("Warning: policy is operating on top of layer-normed features. It might slow down the training.")
        self.layernormalize = layernormalize
        self.nl = nl
        self.ob_mean = ob_mean
        self.ob_std = ob_std
        with tf.variable_scope(scope):
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.ac_pdtype = make_pdtype(ac_space)
            self.ph_ob = tf.placeholder(dtype=tf.int32,
                                        shape=(None, None) + ob_space.shape, name='ob')
            self.ph_ac = self.ac_pdtype.sample_placeholder([None, None], name='ac')
            self.pd = self.vpred = None
            self.hidsize = hidsize
            self.feat_dim = feat_dim
            self.scope = scope
            pdparamsize = self.ac_pdtype.param_shape()[0]

            sh = tf.shape(self.ph_ob)  # ph_ob.shape = (None,None,84,84,4)
            x = flatten_two_dims(self.ph_ob)  # x.shape = (None,84,84,4)

            self.flat_features = self.get_features(x, reuse=False)  # shape=(None,512)
            self.features = unflatten_first_dim(self.flat_features, sh)  # shape=(None,None,512)

            with tf.variable_scope(scope, reuse=False):
                x = fc(self.flat_features, units=hidsize, activation=activ)  # activ=tf.nn.relu
                x = fc(x, units=hidsize, activation=activ)  # shared trunk for value and policy
                pdparam = fc(x, name='pd', units=pdparamsize, activation=None)  # logits, shape=(None,4)
                vpred = fc(x, name='value_function_output', units=1, activation=None)  # shape=(None,1)
            pdparam = unflatten_first_dim(pdparam, sh)  # shape=(None,None,4)
            self.vpred = unflatten_first_dim(vpred, sh)[:, :, 0]  # value function, shape=(None,None)
            self.pd = pd = self.ac_pdtype.pdfromflat(pdparam)  # mean, neglogp, kl, entropy, sample
            self.a_samp = pd.sample()
            self.entropy = pd.entropy()  # (None,None)
            self.nlp_samp = pd.neglogp(self.a_samp)  # -log pi(a|s), (None,None)

    def get_features(self, x, reuse):
        x_has_timesteps = (x.get_shape().ndims == 5)
        if x_has_timesteps:
            sh = tf.shape(x)
            x = flatten_two_dims(x)

        with tf.variable_scope(self.scope + "_features", reuse=reuse):
            x = (tf.to_float(x) - self.ob_mean) / self.ob_std
            x = small_convnet(x, nl=self.nl, feat_dim=self.feat_dim, last_nl=None, layernormalize=self.layernormalize)

        if x_has_timesteps:
            x = unflatten_first_dim(x, sh)
        return x

    def get_ac_value_nlp(self, ob):
        # ob.shape=(128,84,84,4), ob[:, None].shape=(128,1,84,84,4)
        a, vpred, nlp = \
            getsess().run([self.a_samp, self.vpred, self.nlp_samp],
                          feed_dict={self.ph_ob: ob[:, None]})
        return a[:, 0], vpred[:, 0], nlp[:, 0]
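Note (not part of this commit): the policy leans on the `flatten_two_dims` / `unflatten_first_dim` helpers from `utils.py` to merge and restore the leading (batch, time) axes around the convnet. An illustrative numpy analogue of that reshaping:

```
import numpy as np

x = np.zeros((128, 1, 84, 84, 4))                        # (nenvs, nsteps, H, W, C)
flat = x.reshape((-1,) + x.shape[2:])                    # flatten_two_dims -> (128, 84, 84, 4)
feats = np.zeros((flat.shape[0], 512))                   # stand-in for small_convnet output
restored = feats.reshape(x.shape[:2] + feats.shape[1:])  # unflatten_first_dim -> (128, 1, 512)
print(flat.shape, restored.shape)
```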
cppo_agent.py (new file, 260 lines)
@@ -0,0 +1,260 @@
import time

import numpy as np
import tensorflow as tf
from baselines.common import explained_variance
from baselines.common.mpi_moments import mpi_moments
from baselines.common.running_mean_std import RunningMeanStd
from mpi4py import MPI
from mpi_utils import MpiAdamOptimizer
from rollouts import Rollout
from utils import bcast_tf_vars_from_root, get_mean_and_std
from vec_env import ShmemVecEnv as VecEnv

getsess = tf.get_default_session


class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol, ent_coef, gamma, lam, nepochs, lr, cliprange,
                 nminibatches, normrew, normadv, use_news, ext_coeff, int_coeff, nsteps_per_seg, nsegs_per_env,
                 dynamic_bottleneck):
        self.dynamic_bottleneck = dynamic_bottleneck
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space  # Box(84,84,4)
            self.ac_space = ac_space  # Discrete(4)
            self.stochpol = stochpol  # cnn policy
            self.nepochs = nepochs  # 3
            self.lr = lr  # 1e-4
            self.cliprange = cliprange  # 0.1
            self.nsteps_per_seg = nsteps_per_seg  # 128
            self.nsegs_per_env = nsegs_per_env  # 1
            self.nminibatches = nminibatches  # 8
            self.gamma = gamma  # 0.99
            self.lam = lam  # 0.99
            self.normrew = normrew  # 1
            self.normadv = normadv  # 1
            self.use_news = use_news  # False
            self.ext_coeff = ext_coeff  # 0.0
            self.int_coeff = int_coeff  # 1.0
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])  # -log pi(a|s)
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret) ** 2)
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = - self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (- ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {'tot': self.total_loss, 'pg': pg_loss, 'vf': vf_loss, 'ent': entropy, 'approxkl': approxkl, 'clipfrac': clipfrac}

            # Dynamic Bottleneck loss; attached later by the trainer (see run.py)
            self.db_loss = None

    def start_interaction(self, env_fns, dynamic_bottleneck, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        params_db = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="DB")
        print("***total params:", np.sum([np.prod(v.get_shape().as_list()) for v in params]))  # idf:10,172,133
        print("***DB params:", np.sum([np.prod(v.get_shape().as_list()) for v in params_db]))  # idf:10,172,133

        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr, comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)  # compute gradients
        self._train = trainer.apply_gradients(gradsandvars)

        # Train DB (without clipping)
        # gradsandvars_db = trainer.compute_gradients(self.db_loss, params_db)
        # self._train_db = trainer.apply_gradients(gradsandvars_db)

        # Train DB with gradient clipping
        gradients_db, variables_db = zip(*trainer.compute_gradients(self.db_loss, params_db))
        gradients_db, self.norm_var = tf.clip_by_global_norm(gradients_db, 50.0)
        self._train_db = trainer.apply_gradients(zip(gradients_db, variables_db))

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)  # 128
        self.nlump = nlump  # 1
        self.lump_stride = nenvs // self.nlump  # 128/1=128
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride: (l + 1) * self.lump_stride], spaces=[self.ob_space, self.ac_space]) for
            l in range(self.nlump)]

        self.rollout = Rollout(ob_space=self.ob_space, ac_space=self.ac_space, nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env, nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamic_bottleneck=dynamic_bottleneck)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        # Dynamic Bottleneck reward normalization
        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        # normalize the Dynamic Bottleneck rewards
        if self.normrew:
            rffs = np.array([self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std ** 2, rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)  # shape=(128,128)
        else:
            rews = np.copy(self.rollout.buf_rews)

        self.calculate_advantages(rews=rews, use_news=self.use_news, gamma=self.gamma, lam=self.lam)

        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(), self.buf_rets.ravel()),
            DB_rew=np.mean(self.rollout.buf_rews),
            DB_rew_norm=np.mean(rews),
            recent_best_ext_ret=self.rollout.current_max
        )
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env, self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),  # numpy shape=(128,128,84,84,4)
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([
            (self.dynamic_bottleneck.last_ob,  # shape=(128,1,84,84,4)
             self.rollout.buf_obs_last.reshape([self.nenvs * self.nsegs_per_env, 1, *self.ob_space.shape]))
        ])
        mblossvals = []
        for _ in range(self.nepochs):  # nepochs = 3
            np.random.shuffle(envinds)  # envinds = [0,1,2,...,127]
            # nenvs=128, nsegs_per_env=1, envsperbatch=16
            for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}  # feed_dict
                fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})  # , self.dynamic_bottleneck.l2_aux_loss_tf: l2_aux_loss_fd})
                mblossvals.append(getsess().run(self._losses + (self._train,), fd)[:-1])

        # gradient norm computation
        # print("gradient norm:", getsess().run(self.norm_var, fd))

        # momentum update of the DB encoder parameters
        print("Momentum Update DB Encoder")
        getsess().run(self.dynamic_bottleneck.momentum_updates)
        DB_loss_info = getsess().run(self.dynamic_bottleneck.loss_info, fd)

        mblossvals = [mblossvals[0]]
        info.update(zip(['opt_' + ln for ln in self.loss_names], np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({dn: (np.mean(dvs) if len(dvs) > 0 else 0) for (dn, dvs) in self.rollout.statlists.items()})
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info, DB_loss_info

    def step(self):
        self.rollout.collect_rollout()
        update_info, DB_loss_info = self.update()
        return {'update': update_info, "DB_loss_info": DB_loss_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)


class RewardForwardFilter(object):
    def __init__(self, gamma):
        self.rewems = None
        self.gamma = gamma

    def update(self, rews):
        if self.rewems is None:
            self.rewems = rews
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems
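An illustrative, single-process sketch (not repository code) of the `normrew` path in `PpoOptimizer.update()`: intrinsic rewards are passed through the discounted forward filter, a running standard deviation is estimated from the filtered values, and the raw rewards are divided by that estimate before computing GAE advantages. The real code tracks the statistics with `RunningMeanStd` and `mpi_moments` across workers.

```
import numpy as np

gamma = 0.99
rewems = None                          # state of RewardForwardFilter
buf_rews = np.random.rand(4, 8)        # (nenvs, nsteps) intrinsic rewards, toy sizes

rffs = []
for rew in buf_rews.T:                 # iterate over time steps
    rewems = rew if rewems is None else rewems * gamma + rew
    rffs.append(rewems)
rff_std = np.std(np.asarray(rffs))     # running std estimate of the filtered rewards
normed_rews = buf_rews / np.sqrt(rff_std ** 2 + 1e-8)
print(normed_rews.shape)               # (4, 8)
```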
dynamic_bottleneck.py (new file, 168 lines)
@@ -0,0 +1,168 @@
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from utils import getsess

tfd = tfp.distributions

from utils import flatten_two_dims, unflatten_first_dim, SmallConv, TransitionNetwork, normal_parse_params, \
    ProjectionHead, ContrastiveHead, rec_log_prob, GenerativeNetworkGaussianFix


class DynamicBottleneck(object):
    def __init__(self, policy, tau, loss_kl_weight, loss_l2_weight, loss_nce_weight, aug, feat_dim=512, scope='DB'):
        self.scope = scope
        self.feat_dim = feat_dim
        self.policy = policy
        self.hidsize = policy.hidsize  # 512
        self.ob_space = policy.ob_space  # Box(84, 84, 4)
        self.ac_space = policy.ac_space  # Discrete(4)
        self.obs = self.policy.ph_ob  # shape=(None,None,84,84,4)
        self.ob_mean = self.policy.ob_mean  # shape=(84,84,4)
        self.ob_std = self.policy.ob_std  # 1.8
        self.tau = tau  # tau for updating the momentum network
        self.loss_kl_weight = loss_kl_weight
        self.loss_l2_weight = loss_l2_weight
        self.loss_nce_weight = loss_nce_weight
        self.aug = aug

        with tf.variable_scope(scope):
            self.feature_conv = SmallConv(feat_dim=self.feat_dim, name="DB_main")  # (None, None, 512)
            self.feature_conv_momentum = SmallConv(feat_dim=self.feat_dim, name="DB_momentum")  # (None, None, 512)
            self.transition_model = TransitionNetwork(name="DB_transition")  # (None, None, 256)
            self.generative_model = GenerativeNetworkGaussianFix(name="DB_generative")  # (None, None, 512)
            self.projection_head = ProjectionHead(name="DB_projection_main")  # projection head
            self.projection_head_momentum = ProjectionHead(name="DB_projection_momentum")  # projection head (momentum)
            self.contrastive_head = ContrastiveHead(temperature=1.0, name="DB_contrastive")

            # (None,1,84,84,4)
            self.last_ob = tf.placeholder(dtype=tf.int32, shape=(None, 1) + self.ob_space.shape, name='last_ob')
            self.next_ob = tf.concat([self.obs[:, 1:], self.last_ob], 1)  # (None,None,84,84,4)

            self.features = self.get_features(self.obs)  # (None,None,512)
            self.next_features = self.get_features(self.next_ob, momentum=True)  # (None,None,512), stop gradient
            self.ac = self.policy.ph_ac  # (None, None)
            self.ac_pad = tf.one_hot(self.ac, self.ac_space.n, axis=2)

            # transition model
            latent_params = self.transition_model([self.features, self.ac_pad])  # (None, None, 256)
            self.latent_dis = normal_parse_params(latent_params, 1e-3)  # Gaussian: mu, sigma = (None, None, 128)

            # prior
            sh = tf.shape(self.latent_dis.mean())  # sh=(None, None, 128)
            self.prior_dis = tfd.Normal(loc=tf.zeros(sh), scale=tf.ones(sh))

            # KL divergence
            kl = tfp.distributions.kl_divergence(self.latent_dis, self.prior_dis)  # (None, None, 128)
            kl = tf.reduce_sum(kl, axis=-1)  # (None, None)

            # generative network
            latent = self.latent_dis.sample()  # (None, None, 128)
            rec_params = self.generative_model(latent)  # (None, None, 1024)
            assert rec_params.get_shape().as_list()[-1] == 1024 and len(rec_params.get_shape().as_list()) == 3
            rec_dis = normal_parse_params(rec_params, 0.1)  # reconstruction distribution

            rec_vec = rec_dis.sample()  # sample from the reconstruction distribution
            assert rec_vec.get_shape().as_list()[-1] == 512 and len(rec_vec.get_shape().as_list()) == 3

            # contrastive projection
            z_a = self.projection_head(rec_vec)  # (None, 128)
            z_pos = tf.stop_gradient(self.projection_head_momentum(self.next_features))  # (None, 128)
            assert z_a.get_shape().as_list()[-1] == 128 and len(z_a.get_shape().as_list()) == 2

            # contrastive loss
            logits = self.contrastive_head([z_a, z_pos])  # (batch_size, batch_size)
            labels = tf.one_hot(tf.range(int(16 * 128)), depth=16 * 128)  # (batch_size, batch_size)
            rec_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)  # (batch_size, )
            rec_log_nce = -1. * rec_loss
            rec_log_nce = unflatten_first_dim(rec_log_nce, sh)  # shape=(None, None), (128,128)

            # L2 loss
            log_prob = rec_dis.log_prob(self.next_features)  # (None, None, 512)
            assert len(log_prob.get_shape().as_list()) == 3 and log_prob.get_shape().as_list()[-1] == 512
            rec_log_l2 = tf.reduce_sum(log_prob, axis=-1)
            rec_log = rec_log_nce * self.loss_nce_weight + rec_log_l2 * self.loss_l2_weight

            # loss
            self.loss = kl * self.loss_kl_weight - rec_log
            self.loss_info = {"DB_NCELoss": -1. * tf.reduce_mean(rec_log_nce),
                              "DB_NCELoss_w": -1. * tf.reduce_mean(rec_log_nce) * self.loss_nce_weight,
                              "DB_L2Loss": -1. * tf.reduce_mean(rec_log_l2),
                              "DB_L2Loss_w": -1. * tf.reduce_mean(rec_log_l2) * self.loss_l2_weight,
                              "DB_KLLoss": tf.reduce_mean(kl),
                              "DB_KLLoss_w": tf.reduce_mean(kl) * self.loss_kl_weight,
                              "DB_Loss": tf.reduce_mean(self.loss)}

            # intrinsic reward
            self.intrinsic_reward = self.intrinsic_contrastive()
            self.intrinsic_reward = tf.stop_gradient(self.intrinsic_reward)

            # update the momentum network
            self.init_updates, self.momentum_updates = self.get_momentum_updates(tau=self.tau)
        print("*** DB Total Components:", len(self.ib_get_vars(name='DB/')), ", Total Variables:", self.ib_get_params(self.ib_get_vars(name='DB/')), "\n")

    @staticmethod
    def ib_get_vars(name):
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)

    @staticmethod
    def ib_get_params(vars):
        return np.sum([np.prod(v.shape) for v in vars])

    def get_momentum_updates(self, tau):  # tau=0.001
        main_var = self.ib_get_vars(name='DB/DB_features/DB_main') + self.ib_get_vars(name="DB/DB_projection_main")
        momentum_var = self.ib_get_vars(name='DB/DB_features_1/DB_momentum') + self.ib_get_vars(name="DB/DB_projection_momentum")

        # print("\n\n momentum_var:", momentum_var)
        assert len(main_var) > 0 and len(main_var) == len(momentum_var)
        print("***In DB, feature & projection has ", len(main_var), "components, ", self.ib_get_params(main_var), "parameters.")

        soft_updates = []
        init_updates = []
        assert len(main_var) == len(momentum_var)
        for var, tvar in zip(main_var, momentum_var):
            init_updates.append(tf.assign(tvar, var))
            soft_updates.append(tf.assign(tvar, (1. - tau) * tvar + tau * var))
        assert len(init_updates) == len(main_var)
        assert len(soft_updates) == len(main_var)
        return tf.group(*init_updates), tf.group(*soft_updates)

    def get_features(self, x, momentum=False):  # x.shape=(None,None,84,84,4)
        x_has_timesteps = (x.get_shape().ndims == 5)  # True
        if x_has_timesteps:
            sh = tf.shape(x)
            x = flatten_two_dims(x)  # (None,84,84,4)

        if self.aug:
            print(x.get_shape().as_list())
            x = tf.image.random_crop(x, size=[128 * 16, 80, 80, 4])  # (None,80,80,4)
            x = tf.pad(x, [[0, 0], [4, 4], [4, 4], [0, 0]], "SYMMETRIC")  # (None,88,88,4)
            x = tf.image.random_crop(x, size=[128 * 16, 84, 84, 4])  # (None,84,84,4)

        with tf.variable_scope(self.scope + "_features"):
            x = (tf.to_float(x) - self.ob_mean) / self.ob_std
            if momentum:
                x = tf.stop_gradient(self.feature_conv_momentum(x))  # (None,512)
            else:
                x = self.feature_conv(x)  # (None,512)
        if x_has_timesteps:
            x = unflatten_first_dim(x, sh)  # (None,None,512)
        return x

    def intrinsic_contrastive(self):
        kl = tfp.distributions.kl_divergence(self.latent_dis, self.prior_dis)  # (None, None, 128)
        rew = tf.reduce_sum(kl, axis=-1)  # (None, None)
        return rew

    def calculate_db_reward(self, ob, last_ob, acs):
        n_chunks = 8
        n = ob.shape[0]
        chunk_size = n // n_chunks
        assert n % n_chunks == 0
        sli = lambda i: slice(i * chunk_size, (i + 1) * chunk_size)

        # compute the intrinsic reward in chunks to limit memory use
        rew = np.concatenate([getsess().run(self.intrinsic_reward,
                                            {self.obs: ob[sli(i)], self.last_ob: last_ob[sli(i)],
                                             self.ac: acs[sli(i)]}) for i in range(n_chunks)], 0)
        return rew
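An illustrative numpy sketch (not repository code) of the soft update built in `get_momentum_updates()`: the momentum encoder tracks the online encoder by an exponential moving average, theta_momentum <- (1 - tau) * theta_momentum + tau * theta.

```
import numpy as np

tau = 0.001
theta = np.random.randn(4)            # online ("DB_main") parameters
theta_momentum = theta.copy()         # init_updates: hard copy once at start-up

for _ in range(100):                  # momentum_updates: run after every PPO update
    theta = theta + 0.01 * np.random.randn(4)   # stand-in for a gradient step on theta
    theta_momentum = (1.0 - tau) * theta_momentum + tau * theta
```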
mpi_utils.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import numpy as np
import tensorflow as tf
from mpi4py import MPI


class MpiAdamOptimizer(tf.train.AdamOptimizer):
    """Adam optimizer that averages gradients across mpi processes."""

    def __init__(self, comm, **kwargs):
        self.comm = comm
        tf.train.AdamOptimizer.__init__(self, **kwargs)

    def compute_gradients(self, loss, var_list, **kwargs):
        grads_and_vars = tf.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs)
        grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
        flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0)
        shapes = [v.shape.as_list() for g, v in grads_and_vars]
        sizes = [int(np.prod(s)) for s in shapes]

        _task_id, num_tasks = self.comm.Get_rank(), self.comm.Get_size()
        buf = np.zeros(sum(sizes), np.float32)

        def _collect_grads(flat_grad):
            self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
            np.divide(buf, float(num_tasks), out=buf)
            return buf

        avg_flat_grad = tf.py_func(_collect_grads, [flat_grad], tf.float32)
        avg_flat_grad.set_shape(flat_grad.shape)
        avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
        avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
                              for g, (_, v) in zip(avg_grads, grads_and_vars)]

        return avg_grads_and_vars
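A minimal mpi4py sketch (not repository code) of the averaging trick inside `MpiAdamOptimizer.compute_gradients()`: each worker supplies its flattened gradient, `Allreduce` sums them across ranks, and dividing by the world size gives every rank the same averaged gradient to feed into Adam.

```
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
flat_grad = np.ones(4, dtype=np.float32) * (comm.Get_rank() + 1)  # stand-in local gradient
buf = np.zeros_like(flat_grad)
comm.Allreduce(flat_grad, buf, op=MPI.SUM)   # sum of all workers' gradients
buf /= comm.Get_size()                       # average
# launch with e.g. `mpiexec -n 4 python sketch.py`; every rank ends up with the same average
print(comm.Get_rank(), buf)
```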
para.json (new file, 101 lines)
@@ -0,0 +1,101 @@
{
  "standard": {
    "Alien": {"kl": 0.001, "nce": 0.1},
    "Asteroids": {"kl": 0.1, "nce": 0.01},
    "BankHeist": {"kl": 0.001, "nce": 0.1},
    "BeamRider": {"kl": 0.001, "nce": 0.01},
    "Boxing": {"kl": 0.1, "nce": 0.1},
    "Breakout": {"kl": 0.1, "nce": 0.1},
    "Centipede": {"kl": 0.1, "nce": 0.1},
    "ChopperCommand": {"kl": 0.1, "nce": 0.01},
    "CrazyClimber": {"kl": 0.1, "nce": 0.1},
    "Gopher": {"kl": 0.001, "nce": 0.01},
    "Gravitar": {"kl": 0.1, "nce": 0.01},
    "Kangaroo": {"kl": 0.1, "nce": 0.1},
    "KungFuMaster": {"kl": 0.1, "nce": 0.01},
    "MsPacman": {"kl": 0.1, "nce": 0.1},
    "Seaquest": {"kl": 0.1, "nce": 0.1},
    "Solaris": {"kl": 0.1, "nce": 0.1},
    "Tennis": {"kl": 0.1, "nce": 0.01},
    "TimePilot": {"kl": 0.1, "nce": 0.01},
    "UpNDown": {"kl": 0.1, "nce": 0.01},
    "VideoPinball": {"kl": 0.1, "nce": 0.01},
    "WizardOfWor": {"kl": 0.1, "nce": 0.1},
    "Zaxxon": {"kl": 0.1, "nce": 0.01},
    "other": {"kl": 0.1, "nce": 0.01}
  },

  "randomBox": {
    "Alien": {"kl": 0.001, "nce": 0.1},
    "Asteroids": {"kl": 0.1, "nce": 0.01},
    "BankHeist": {"kl": 0.001, "nce": 0.1},
    "BeamRider": {"kl": 0.001, "nce": 0.01},
    "Boxing": {"kl": 0.1, "nce": 0.1},
    "Breakout": {"kl": 0.1, "nce": 0.1},
    "Centipede": {"kl": 0.1, "nce": 0.1},
    "ChopperCommand": {"kl": 0.1, "nce": 0.01},
    "CrazyClimber": {"kl": 0.001, "nce": 0.01},
    "Gopher": {"kl": 0.001, "nce": 0.01},
    "Gravitar": {"kl": 0.1, "nce": 0.01},
    "Kangaroo": {"kl": 0.001, "nce": 0.01},
    "KungFuMaster": {"kl": 0.001, "nce": 0.01},
    "MsPacman": {"kl": 0.001, "nce": 0.01},
    "Seaquest": {"kl": 0.001, "nce": 0.01},
    "Solaris": {"kl": 0.1, "nce": 0.1},
    "Tennis": {"kl": 0.1, "nce": 0.01},
    "TimePilot": {"kl": 0.1, "nce": 0.01},
    "UpNDown": {"kl": 0.001, "nce": 0.01},
    "VideoPinball": {"kl": 0.1, "nce": 0.01},
    "WizardOfWor": {"kl": 0.1, "nce": 0.1},
    "Zaxxon": {"kl": 0.1, "nce": 0.01},
    "other": {"kl": 0.001, "nce": 0.01}
  },

  "stickyAtari": {
    "Alien": {"kl": 0.001, "nce": 0.1},
    "Asteroids": {"kl": 0.1, "nce": 0.01},
    "BankHeist": {"kl": 0.001, "nce": 0.1},
    "BeamRider": {"kl": 0.001, "nce": 0.01},
    "Boxing": {"kl": 0.1, "nce": 0.1},
    "Breakout": {"kl": 0.1, "nce": 0.1},
    "Centipede": {"kl": 0.1, "nce": 0.1},
    "ChopperCommand": {"kl": 0.1, "nce": 0.01},
    "CrazyClimber": {"kl": 0.1, "nce": 0.01},
    "Gopher": {"kl": 0.001, "nce": 0.01},
    "Gravitar": {"kl": 0.1, "nce": 0.01},
    "Kangaroo": {"kl": 0.001, "nce": 0.01},
    "KungFuMaster": {"kl": 0.001, "nce": 0.01},
    "MsPacman": {"kl": 0.1, "nce": 0.1},
    "Seaquest": {"kl": 0.1, "nce": 0.1},
    "Solaris": {"kl": 0.1, "nce": 0.1},
    "Tennis": {"kl": 0.1, "nce": 0.01},
    "TimePilot": {"kl": 0.1, "nce": 0.01},
    "UpNDown": {"kl": 0.1, "nce": 0.01},
    "VideoPinball": {"kl": 0.1, "nce": 0.01},
    "WizardOfWor": {"kl": 0.1, "nce": 0.1},
    "Zaxxon": {"kl": 0.1, "nce": 0.01},
    "other": {"kl": 0.1, "nce": 0.01}
  },

  "pixelNoise": {
    "Alien": {"kl": 0.001, "nce": 0.1},
    "Asteroids": {"kl": 0.1, "nce": 0.01},
    "BankHeist": {"kl": 0.001, "nce": 0.1},
    "BeamRider": {"kl": 0.001, "nce": 0.01},
    "Boxing": {"kl": 0.1, "nce": 0.1},
    "Breakout": {"kl": 0.1, "nce": 0.1},
    "Centipede": {"kl": 0.1, "nce": 0.1},
    "ChopperCommand": {"kl": 0.1, "nce": 0.01},
    "CrazyClimber": {"kl": 0.001, "nce": 0.01},
    "Gopher": {"kl": 0.001, "nce": 0.01},
    "Gravitar": {"kl": 0.1, "nce": 0.01},
    "Kangaroo": {"kl": 0.001, "nce": 0.01},
    "KungFuMaster": {"kl": 0.001, "nce": 0.01},
    "MsPacman": {"kl": 0.001, "nce": 0.01},
    "Seaquest": {"kl": 0.1, "nce": 0.1},
    "Solaris": {"kl": 0.1, "nce": 0.1},
    "Tennis": {"kl": 0.1, "nce": 0.01},
    "TimePilot": {"kl": 0.1, "nce": 0.01},
    "UpNDown": {"kl": 0.1, "nce": 0.01},
    "VideoPinball": {"kl": 0.1, "nce": 0.01},
    "WizardOfWor": {"kl": 0.1, "nce": 0.1},
    "Zaxxon": {"kl": 0.1, "nce": 0.01},
    "other": {"kl": 0.1, "nce": 0.01}
  }
}
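A hedged sketch (mirroring the argument handling at the bottom of run.py below) of how these per-game weights are consumed: pick the block that matches the noise setting, fall back to the "other" entry for unlisted games, and override the CLI defaults.

```
import json

with open("para.json") as f:
    para = json.load(f)

setting = "standard"          # or "randomBox" / "stickyAtari" / "pixelNoise"
game = "Breakout"             # args.env with the "NoFrameskip-v4" suffix stripped
game = game if game in para[setting] else "other"
loss_kl_weight = para[setting][game]["kl"]
loss_nce_weight = para[setting][game]["nce"]
print(game, loss_kl_weight, loss_nce_weight)
```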
recorder.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import os
import pickle

from baselines import logger
from mpi4py import MPI


class Recorder(object):
    def __init__(self, nenvs, nlumps):
        self.nenvs = nenvs
        self.nlumps = nlumps
        self.nenvs_per_lump = nenvs // nlumps
        self.acs = [[] for _ in range(nenvs)]
        self.int_rews = [[] for _ in range(nenvs)]
        self.ext_rews = [[] for _ in range(nenvs)]
        self.ep_infos = [{} for _ in range(nenvs)]
        self.filenames = [self.get_filename(i) for i in range(nenvs)]
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info("episode recordings saved to ", self.filenames[0])

    def record(self, timestep, lump, acs, infos, int_rew, ext_rew, news):
        for out_index in range(self.nenvs_per_lump):
            in_index = out_index + lump * self.nenvs_per_lump
            if timestep == 0:
                self.acs[in_index].append(acs[out_index])
            else:
                if self.is_first_episode_step(in_index):
                    try:
                        self.ep_infos[in_index]['random_state'] = infos[out_index]['random_state']
                    except:
                        pass

                self.int_rews[in_index].append(int_rew[out_index])
                self.ext_rews[in_index].append(ext_rew[out_index])

                if news[out_index]:
                    self.ep_infos[in_index]['ret'] = infos[out_index]['episode']['r']
                    self.ep_infos[in_index]['len'] = infos[out_index]['episode']['l']
                    self.dump_episode(in_index)

                self.acs[in_index].append(acs[out_index])

    def dump_episode(self, i):
        episode = {'acs': self.acs[i],
                   'int_rew': self.int_rews[i],
                   'info': self.ep_infos[i]}
        filename = self.filenames[i]
        if self.episode_worth_saving(i):
            with open(filename, 'ab') as f:
                pickle.dump(episode, f, protocol=-1)
        self.acs[i].clear()
        self.int_rews[i].clear()
        self.ext_rews[i].clear()
        self.ep_infos[i].clear()

    def episode_worth_saving(self, i):
        return (i == 0 and MPI.COMM_WORLD.Get_rank() == 0)

    def is_first_episode_step(self, i):
        return len(self.int_rews[i]) == 0

    def get_filename(self, i):
        filename = os.path.join(logger.get_dir(), 'env{}_{}.pk'.format(MPI.COMM_WORLD.Get_rank(), i))
        return filename
rollouts.py (new file, 177 lines)
@@ -0,0 +1,177 @@
from collections import deque, defaultdict

import numpy as np
from mpi4py import MPI
# from utils import add_noise
from recorder import Recorder


class Rollout(object):
    def __init__(self, ob_space, ac_space, nenvs, nsteps_per_seg, nsegs_per_env, nlumps, envs, policy,
                 int_rew_coeff, ext_rew_coeff, record_rollouts, dynamic_bottleneck):  # , noisy_box, noisy_p):
        # int_rew_coeff=1.0, ext_rew_coeff=0.0, record_rollouts=True
        self.nenvs = nenvs  # 128
        self.nsteps_per_seg = nsteps_per_seg  # 128
        self.nsegs_per_env = nsegs_per_env  # 1
        self.nsteps = self.nsteps_per_seg * self.nsegs_per_env  # 128
        self.ob_space = ob_space  # Box(84,84,4)
        self.ac_space = ac_space  # Discrete(4)
        self.nlumps = nlumps  # 1
        self.lump_stride = nenvs // self.nlumps  # 128
        self.envs = envs
        self.policy = policy
        self.dynamic_bottleneck = dynamic_bottleneck  # Dynamic Bottleneck

        self.reward_fun = lambda ext_rew, int_rew: ext_rew_coeff * np.clip(ext_rew, -1., 1.) + int_rew_coeff * int_rew

        self.buf_vpreds = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_nlps = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_rews = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_ext_rews = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_acs = np.empty((nenvs, self.nsteps, *self.ac_space.shape), self.ac_space.dtype)
        self.buf_obs = np.empty((nenvs, self.nsteps, *self.ob_space.shape), self.ob_space.dtype)
        self.buf_obs_last = np.empty((nenvs, self.nsegs_per_env, *self.ob_space.shape), np.float32)

        self.buf_news = np.zeros((nenvs, self.nsteps), np.float32)
        self.buf_new_last = self.buf_news[:, 0, ...].copy()
        self.buf_vpred_last = self.buf_vpreds[:, 0, ...].copy()

        self.env_results = [None] * self.nlumps
        self.int_rew = np.zeros((nenvs,), np.float32)

        self.recorder = Recorder(nenvs=self.nenvs, nlumps=self.nlumps) if record_rollouts else None
        self.statlists = defaultdict(lambda: deque([], maxlen=100))
        self.stats = defaultdict(float)
        self.best_ext_ret = None
        self.all_visited_rooms = []
        self.all_scores = []

        self.step_count = 0

        # noise box in the observation
        # self.noisy_box = noisy_box
        # self.noisy_p = noisy_p

    def collect_rollout(self):
        self.ep_infos_new = []
        for t in range(self.nsteps):
            self.rollout_step()
        self.calculate_reward()
        self.update_info()

    def calculate_reward(self):  # the intrinsic reward comes from the Dynamic Bottleneck
        db_rew = self.dynamic_bottleneck.calculate_db_reward(
            ob=self.buf_obs, last_ob=self.buf_obs_last, acs=self.buf_acs)
        self.buf_rews[:] = self.reward_fun(int_rew=db_rew, ext_rew=self.buf_ext_rews)

    def rollout_step(self):
        t = self.step_count % self.nsteps
        s = t % self.nsteps_per_seg
        for l in range(self.nlumps):  # nlumps=1
            obs, prevrews, news, infos = self.env_get(l)
            # if t > 0:
            #     prev_feat = self.prev_feat[l]
            #     prev_acs = self.prev_acs[l]
            for info in infos:
                epinfo = info.get('episode', {})
                mzepinfo = info.get('mz_episode', {})
                retroepinfo = info.get('retro_episode', {})
                epinfo.update(mzepinfo)
                epinfo.update(retroepinfo)
                if epinfo:
                    if "n_states_visited" in info:
                        epinfo["n_states_visited"] = info["n_states_visited"]
                        epinfo["states_visited"] = info["states_visited"]
                    self.ep_infos_new.append((self.step_count, epinfo))

            # slice(0,128), lump_stride=128
            sli = slice(l * self.lump_stride, (l + 1) * self.lump_stride)

            acs, vpreds, nlps = self.policy.get_ac_value_nlp(obs)
            self.env_step(l, acs)

            # self.prev_feat[l] = dyn_feat
            # self.prev_acs[l] = acs
            self.buf_obs[sli, t] = obs  # obs.shape=(128,84,84,4)
            self.buf_news[sli, t] = news  # shape=(128,), True/False
            self.buf_vpreds[sli, t] = vpreds  # shape=(128,)
            self.buf_nlps[sli, t] = nlps  # -log pi(a|s), shape=(128,)
            self.buf_acs[sli, t] = acs  # shape=(128,)
            if t > 0:
                self.buf_ext_rews[sli, t - 1] = prevrews  # prevrews.shape=(128,)
            if self.recorder is not None:
                self.recorder.record(timestep=self.step_count, lump=l, acs=acs, infos=infos, int_rew=self.int_rew[sli],
                                     ext_rew=prevrews, news=news)

        self.step_count += 1
        if s == self.nsteps_per_seg - 1:  # nsteps_per_seg=128
            for l in range(self.nlumps):  # nlumps=1
                sli = slice(l * self.lump_stride, (l + 1) * self.lump_stride)
                nextobs, ext_rews, nextnews, _ = self.env_get(l)
                self.buf_obs_last[sli, t // self.nsteps_per_seg] = nextobs
                if t == self.nsteps - 1:  # t=127
                    self.buf_new_last[sli] = nextnews
                    self.buf_ext_rews[sli, t] = ext_rews
                    _, self.buf_vpred_last[sli], _ = self.policy.get_ac_value_nlp(nextobs)

    def update_info(self):
        all_ep_infos = MPI.COMM_WORLD.allgather(self.ep_infos_new)
        all_ep_infos = sorted(sum(all_ep_infos, []), key=lambda x: x[0])
        if all_ep_infos:
            all_ep_infos = [i_[1] for i_ in all_ep_infos]  # remove the step_count
            keys_ = all_ep_infos[0].keys()
            all_ep_infos = {k: [i[k] for i in all_ep_infos] for k in keys_}
            # all_ep_infos: {'r': [0.0, 0.0, 0.0], 'l': [124, 125, 127], 't': [6.60745, 12.034875, 10.772788]}

            self.statlists['eprew'].extend(all_ep_infos['r'])
            self.stats['eprew_recent'] = np.mean(all_ep_infos['r'])
            self.statlists['eplen'].extend(all_ep_infos['l'])
            self.stats['epcount'] += len(all_ep_infos['l'])
            self.stats['tcount'] += sum(all_ep_infos['l'])
            if 'visited_rooms' in keys_:
                # Montezuma specific logging.
                self.stats['visited_rooms'] = sorted(list(set.union(*all_ep_infos['visited_rooms'])))
                self.stats['pos_count'] = np.mean(all_ep_infos['pos_count'])
                self.all_visited_rooms.extend(self.stats['visited_rooms'])
                self.all_scores.extend(all_ep_infos["r"])
                self.all_scores = sorted(list(set(self.all_scores)))
                self.all_visited_rooms = sorted(list(set(self.all_visited_rooms)))
                if MPI.COMM_WORLD.Get_rank() == 0:
                    print("All visited rooms")
                    print(self.all_visited_rooms)
                    print("All scores")
                    print(self.all_scores)
            if 'levels' in keys_:
                # Retro logging
                temp = sorted(list(set.union(*all_ep_infos['levels'])))
                self.all_visited_rooms.extend(temp)
                self.all_visited_rooms = sorted(list(set(self.all_visited_rooms)))
                if MPI.COMM_WORLD.Get_rank() == 0:
                    print("All visited levels")
                    print(self.all_visited_rooms)

            current_max = np.max(all_ep_infos['r'])
        else:
            current_max = None
        self.ep_infos_new = []

        # best_ext_ret
        if current_max is not None:
            if (self.best_ext_ret is None) or (current_max > self.best_ext_ret):
                self.best_ext_ret = current_max
        self.current_max = current_max

    def env_step(self, l, acs):
        self.envs[l].step_async(acs)
        self.env_results[l] = None

    def env_get(self, l):
        if self.step_count == 0:
            ob = self.envs[l].reset()
            out = self.env_results[l] = (ob, None, np.ones(self.lump_stride, bool), {})
        else:
            if self.env_results[l] is None:
                out = self.env_results[l] = self.envs[l].step_wait()
            else:
                out = self.env_results[l]
        return out
run.py (new file, 287 lines)
@@ -0,0 +1,287 @@
#!/usr/bin/env python
try:
    from OpenGL import GLU
except:
    print("no OpenGL.GLU")
import functools
import os.path as osp
from functools import partial
import os
import gym
import tensorflow as tf
from baselines import logger
from baselines.bench import Monitor
from baselines.common.atari_wrappers import NoopResetEnv, FrameStack
from mpi4py import MPI

from dynamic_bottleneck import DynamicBottleneck
from cnn_policy import CnnPolicy
from cppo_agent import PpoOptimizer
from utils import random_agent_ob_mean_std
from wrappers import MontezumaInfoWrapper, make_mario_env, make_robo_pong, make_robo_hockey, \
    make_multi_pong, AddRandomStateToInfo, MaxAndSkipEnv, ProcessFrame84, ExtraTimeLimit, StickyActionEnv
import datetime
from wrappers import PixelNoiseWrapper, RandomBoxNoiseWrapper
import json

getsess = tf.get_default_session


def start_experiment(**args):
    make_env = partial(make_env_all_params, add_monitor=True, args=args)

    trainer = Trainer(make_env=make_env,
                      num_timesteps=args['num_timesteps'], hps=args,
                      envs_per_process=args['envs_per_process'])
    log, tf_sess, saver, logger_dir = get_experiment_environment(**args)
    with log, tf_sess:
        logdir = logger.get_dir()
        print("results will be saved to ", logdir)
        trainer.train(saver, logger_dir)


class Trainer(object):
    def __init__(self, make_env, hps, num_timesteps, envs_per_process):
        self.make_env = make_env
        self.hps = hps
        self.envs_per_process = envs_per_process
        self.num_timesteps = num_timesteps
        self._set_env_vars()

        self.policy = CnnPolicy(scope='pol',
                                ob_space=self.ob_space,
                                ac_space=self.ac_space,
                                hidsize=512,
                                feat_dim=512,
                                ob_mean=self.ob_mean,
                                ob_std=self.ob_std,
                                layernormalize=False,
                                nl=tf.nn.leaky_relu)

        self.dynamic_bottleneck = DynamicBottleneck(
            policy=self.policy, feat_dim=512, tau=hps['momentum_tau'], loss_kl_weight=hps['loss_kl_weight'],
            loss_nce_weight=hps['loss_nce_weight'], loss_l2_weight=hps['loss_l2_weight'], aug=hps['aug'])

        self.agent = PpoOptimizer(
            scope='ppo',
            ob_space=self.ob_space,
            ac_space=self.ac_space,
            stochpol=self.policy,
            use_news=hps['use_news'],
            gamma=hps['gamma'],
            lam=hps["lambda"],
            nepochs=hps['nepochs'],
            nminibatches=hps['nminibatches'],
            lr=hps['lr'],
            cliprange=0.1,
            nsteps_per_seg=hps['nsteps_per_seg'],
            nsegs_per_env=hps['nsegs_per_env'],
            ent_coef=hps['ent_coeff'],
            normrew=hps['norm_rew'],
            normadv=hps['norm_adv'],
            ext_coeff=hps['ext_coeff'],
            int_coeff=hps['int_coeff'],
            dynamic_bottleneck=self.dynamic_bottleneck
        )

        self.agent.to_report['db'] = tf.reduce_mean(self.dynamic_bottleneck.loss)
        self.agent.total_loss += self.agent.to_report['db']

        self.agent.db_loss = tf.reduce_mean(self.dynamic_bottleneck.loss)

        self.agent.to_report['feat_var'] = tf.reduce_mean(tf.nn.moments(self.dynamic_bottleneck.features, [0, 1])[1])

    def _set_env_vars(self):
        env = self.make_env(0, add_monitor=False)
        # ob_space.shape=(84, 84, 4), ac_space=Discrete(4)
        self.ob_space, self.ac_space = env.observation_space, env.action_space
        self.ob_mean, self.ob_std = random_agent_ob_mean_std(env)
        del env
        self.envs = [functools.partial(self.make_env, i) for i in range(self.envs_per_process)]

    def train(self, saver, logger_dir):
        self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamic_bottleneck=self.dynamic_bottleneck)
        previous_saved_tcount = 0

        # initialize the momentum (target) network of the Dynamic Bottleneck
        print("***Init Momentum Network in Dynamic-Bottleneck.")
        getsess().run(self.dynamic_bottleneck.init_updates)

        while True:
            info = self.agent.step()
            if info['DB_loss_info']:  # for debugging
                logger.logkvs(info['DB_loss_info'])
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if self.hps["save_period"] and (int(self.agent.rollout.stats['tcount'] / self.hps["save_freq"]) > previous_saved_tcount):
                previous_saved_tcount += 1
                save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_" + str(previous_saved_tcount) + ".ckpt"))
                print("Periodically model saved in path:", save_path)
            if self.agent.rollout.stats['tcount'] > self.num_timesteps:
                save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_last.ckpt"))
                print("Model saved in path:", save_path)
                break

        self.agent.stop_interaction()


def make_env_all_params(rank, add_monitor, args):
    if args["env_kind"] == 'atari':
        env = gym.make(args['env'])
        assert 'NoFrameskip' in env.spec.id
        if args["stickyAtari"]:
            env._max_episode_steps = args['max_episode_steps'] * 4
            env = StickyActionEnv(env)
        else:
            env = NoopResetEnv(env, noop_max=args['noop_max'])
        env = MaxAndSkipEnv(env, skip=4)
        if args['pixelNoise']:  # add pixel noise
            env = PixelNoiseWrapper(env)
        if args['randomBoxNoise']:
            env = RandomBoxNoiseWrapper(env)
        env = ProcessFrame84(env, crop=False)
        env = FrameStack(env, 4)
        # env = ExtraTimeLimit(env, args['max_episode_steps'])
        if not args["stickyAtari"]:
            env = ExtraTimeLimit(env, args['max_episode_steps'])
        if 'Montezuma' in args['env']:
            env = MontezumaInfoWrapper(env)
        env = AddRandomStateToInfo(env)
    elif args["env_kind"] == 'mario':
        env = make_mario_env()
    elif args["env_kind"] == "retro_multi":
        env = make_multi_pong()
    elif args["env_kind"] == 'robopong':
        if args["env"] == "pong":
            env = make_robo_pong()
        elif args["env"] == "hockey":
            env = make_robo_hockey()

    if add_monitor:
        env = Monitor(env, osp.join(logger.get_dir(), '%.2i' % rank))
    return env


def get_experiment_environment(**args):
    from utils import setup_mpi_gpus, setup_tensorflow_session
    from baselines.common import set_global_seeds
    from gym.utils.seeding import hash_seed
    process_seed = args["seed"] + 1000 * MPI.COMM_WORLD.Get_rank()
    process_seed = hash_seed(process_seed, max_bytes=4)
    set_global_seeds(process_seed)
    setup_mpi_gpus()

    # log dir name
    logger_dir = './logs/' + args["env"].replace("NoFrameskip-v4", "")
    # logger_dir += "-KLloss-" + str(args["loss_kl_weight"])
    # logger_dir += "-NCEloss-" + str(args["loss_nce_weight"])
    # logger_dir += "-L2loss-" + str(args["loss_l2_weight"])
    if args['pixelNoise'] is True:
        logger_dir += "-pixelNoise"
    if args['randomBoxNoise'] is True:
        logger_dir += "-randomBoxNoise"
    if args['stickyAtari'] is True:
        logger_dir += "-stickyAtari"
    if args["comments"] != "":
        logger_dir += '-' + args["comments"]
    logger_dir += datetime.datetime.now().strftime("-%m-%d-%H-%M-%S")

    # write config
    logger.configure(dir=logger_dir)
    with open(os.path.join(logger_dir, 'parameters.txt'), 'w') as f:
        f.write("\n".join([str(x[0]) + ": " + str(x[1]) for x in args.items()]))

    logger_context = logger.scoped_configure(
        dir=logger_dir,
        format_strs=['stdout', 'log', 'csv'] if MPI.COMM_WORLD.Get_rank() == 0 else ['log'])
    tf_context = setup_tensorflow_session()

    # saver
    saver = tf.train.Saver()
    return logger_context, tf_context, saver, logger_dir


def add_environments_params(parser):
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4', type=str)
    parser.add_argument('--max-episode-steps', help='maximum number of timesteps for episode', default=4500, type=int)
    parser.add_argument('--env_kind', type=str, default="atari")
    parser.add_argument('--noop_max', type=int, default=30)
    parser.add_argument('--stickyAtari', action='store_true', default=False)
    parser.add_argument('--pixelNoise', action='store_true', default=False)
    parser.add_argument('--randomBoxNoise', action='store_true', default=False)


def add_optimization_params(parser):
    parser.add_argument('--lambda', type=float, default=0.95)
    parser.add_argument('--gamma', type=float, default=0.99)  # lambda and gamma are used for the GAE advantages
    parser.add_argument('--nminibatches', type=int, default=8)
    parser.add_argument('--norm_adv', type=int, default=1)
    parser.add_argument('--norm_rew', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--ent_coeff', type=float, default=0.001)
    parser.add_argument('--nepochs', type=int, default=3)
    parser.add_argument('--num_timesteps', type=int, default=int(1e8))
    parser.add_argument('--save_period', action='store_true', default=False)
    parser.add_argument('--save_freq', type=int, default=int(1e7))
    # Parameters of the Dynamic Bottleneck
    parser.add_argument('--loss_kl_weight', type=float, default=0.1)  # KL loss weight
    parser.add_argument('--loss_l2_weight', type=float, default=0.1)  # L2 loss weight
    parser.add_argument('--loss_nce_weight', type=float, default=0.01)  # NCE loss weight
    parser.add_argument('--momentum_tau', type=float, default=0.001)  # momentum tau
    parser.add_argument('--aug', action='store_true', default=False)  # data augmentation (bottleneck)
    parser.add_argument('--comments', type=str, default="")


def add_rollout_params(parser):
    parser.add_argument('--nsteps_per_seg', type=int, default=128)
    parser.add_argument('--nsegs_per_env', type=int, default=1)
    parser.add_argument('--envs_per_process', type=int, default=128)
    parser.add_argument('--nlumps', type=int, default=1)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add_environments_params(parser)
    add_optimization_params(parser)
    add_rollout_params(parser)

    parser.add_argument('--exp_name', type=str, default='')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--dyn_from_pixels', type=int, default=0)
    parser.add_argument('--use_news', type=int, default=0)
    parser.add_argument('--ext_coeff', type=float, default=0.)
    parser.add_argument('--int_coeff', type=float, default=1.)
    parser.add_argument('--layernorm', type=int, default=0)

    args = parser.parse_args()

    # load per-game parameters
    with open("para.json") as f:
        d = json.load(f)
    env_name_para = args.env.replace("NoFrameskip-v4", "")
    if env_name_para not in list(d["standard"].keys()):
        env_name_para = "other"

    if args.pixelNoise is True:
        print("pixel noise")
        args.loss_kl_weight = d["pixelNoise"][env_name_para]["kl"]
        args.loss_nce_weight = d["pixelNoise"][env_name_para]["nce"]
    elif args.randomBoxNoise is True:
        print("random box noise")
        args.loss_kl_weight = d["randomBox"][env_name_para]["kl"]
        args.loss_nce_weight = d["randomBox"][env_name_para]["nce"]
    elif args.stickyAtari is True:
        print("sticky atari")
        args.loss_kl_weight = d["stickyAtari"][env_name_para]["kl"]
        args.loss_nce_weight = d["stickyAtari"][env_name_para]["nce"]
    else:
        print("standard atari")
        args.loss_kl_weight = d["standard"][env_name_para]["kl"]
        args.loss_nce_weight = d["standard"][env_name_para]["nce"]

    print("env_name:", env_name_para, "kl:", args.loss_kl_weight, ", nce:", args.loss_nce_weight)
    start_experiment(**args.__dict__)
430
utils.py
Normal file
430
utils.py
Normal file
@ -0,0 +1,430 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import platform
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from baselines.common.tf_util import normc_initializer
|
||||
from mpi4py import MPI
|
||||
import tensorflow_probability as tfp
|
||||
import os
|
||||
import numpy as np
|
||||
tfd = tfp.distributions
|
||||
|
||||
layers = tf.keras.layers
|
||||
|
||||
|
||||
def bcast_tf_vars_from_root(sess, vars):
|
||||
"""
|
||||
Send the root node's parameters to every worker.
|
||||
|
||||
Arguments:
|
||||
sess: the TensorFlow session.
|
||||
vars: all parameter variables including optimizer's
|
||||
"""
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
for var in vars:
|
||||
if rank == 0:
|
||||
MPI.COMM_WORLD.bcast(sess.run(var))
|
||||
else:
|
||||
sess.run(tf.assign(var, MPI.COMM_WORLD.bcast(None)))
|
||||
|
||||
|
||||
def get_mean_and_std(array):
|
||||
comm = MPI.COMM_WORLD
|
||||
task_id, num_tasks = comm.Get_rank(), comm.Get_size()
|
||||
local_mean = np.array(np.mean(array))
|
||||
sum_of_means = np.zeros((), dtype=np.float32)
|
||||
comm.Allreduce(local_mean, sum_of_means, op=MPI.SUM)
|
||||
mean = sum_of_means / num_tasks
|
||||
|
||||
n_array = array - mean
|
||||
sqs = n_array ** 2
|
||||
local_mean = np.array(np.mean(sqs))
|
||||
sum_of_means = np.zeros((), dtype=np.float32)
|
||||
comm.Allreduce(local_mean, sum_of_means, op=MPI.SUM)
|
||||
var = sum_of_means / num_tasks
|
||||
std = var ** 0.5
|
||||
return mean, std
|
||||
|
||||
|
||||
def guess_available_gpus(n_gpus=None):
|
||||
if n_gpus is not None:
|
||||
return list(range(n_gpus))
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ:
|
||||
cuda_visible_divices = os.environ['CUDA_VISIBLE_DEVICES']
|
||||
cuda_visible_divices = cuda_visible_divices.split(',')
|
||||
return [int(n) for n in cuda_visible_divices]
|
||||
nvidia_dir = '/proc/driver/nvidia/gpus/'
|
||||
if os.path.exists(nvidia_dir):
|
||||
n_gpus = len(os.listdir(nvidia_dir))
|
||||
return list(range(n_gpus))
|
||||
raise Exception("Couldn't guess the available gpus on this machine")
|
||||
|
||||
|
||||
def setup_mpi_gpus():
|
||||
"""
|
||||
Set CUDA_VISIBLE_DEVICES using MPI.
|
||||
"""
|
||||
available_gpus = guess_available_gpus()
|
||||
|
||||
node_id = platform.node()
|
||||
nodes_ordered_by_rank = MPI.COMM_WORLD.allgather(node_id)
|
||||
processes_outranked_on_this_node = [n for n in nodes_ordered_by_rank[:MPI.COMM_WORLD.Get_rank()] if n == node_id]
|
||||
local_rank = len(processes_outranked_on_this_node)
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(available_gpus[local_rank])
|
||||
|
||||
|
||||
def guess_available_cpus():
|
||||
return int(multiprocessing.cpu_count())
|
||||
|
||||
|
||||
def setup_tensorflow_session():
|
||||
num_cpu = guess_available_cpus()
|
||||
|
||||
tf_config = tf.ConfigProto(
|
||||
inter_op_parallelism_threads=num_cpu,
|
||||
intra_op_parallelism_threads=num_cpu
|
||||
)
|
||||
tf_config.gpu_options.allow_growth = True
|
||||
return tf.Session(config=tf_config)
|
||||
|
||||
|
||||
def random_agent_ob_mean_std(env, nsteps=10000):
|
||||
ob = np.asarray(env.reset())
|
||||
if MPI.COMM_WORLD.Get_rank() == 0:
|
||||
obs = [ob]
|
||||
for _ in range(nsteps):
|
||||
ac = env.action_space.sample()
|
||||
ob, _, done, _ = env.step(ac)
|
||||
if done:
|
||||
ob = env.reset()
|
||||
obs.append(np.asarray(ob))
|
||||
mean = np.mean(obs, 0).astype(np.float32)
|
||||
std = np.std(obs, 0).mean().astype(np.float32)
|
||||
else:
|
||||
mean = np.empty(shape=ob.shape, dtype=np.float32)
|
||||
std = np.empty(shape=(), dtype=np.float32)
|
||||
MPI.COMM_WORLD.Bcast(mean, root=0)
|
||||
MPI.COMM_WORLD.Bcast(std, root=0)
|
||||
return mean, std
|
||||
|
||||
|
||||
def layernorm(x):
|
||||
m, v = tf.nn.moments(x, -1, keep_dims=True)
|
||||
return (x - m) / (tf.sqrt(v) + 1e-8)
|
||||
|
||||
|
||||
getsess = tf.get_default_session
|
||||
|
||||
fc = partial(tf.layers.dense, kernel_initializer=normc_initializer(1.))
|
||||
activ = tf.nn.relu
|
||||
|
||||
|
||||
def flatten_two_dims(x):
|
||||
return tf.reshape(x, [-1] + x.get_shape().as_list()[2:])
|
||||
|
||||
|
||||
def unflatten_first_dim(x, sh):
|
||||
return tf.reshape(x, [sh[0], sh[1]] + x.get_shape().as_list()[1:])
|
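A minimal shape sketch of the two reshape helpers above (shapes are illustrative; assumes a statically known feature dimension):

```
# Illustrative only: flatten_two_dims merges the leading (batch, time) axes,
# unflatten_first_dim restores them from the original dynamic shape.
import tensorflow as tf

x = tf.zeros([4, 128, 512])               # (n_envs, n_steps, features)
sh = tf.shape(x)                          # dynamic shape: sh[0]=4, sh[1]=128
flat = flatten_two_dims(x)                # -> (4*128, 512)
restored = unflatten_first_dim(flat, sh)  # -> (4, 128, 512)
```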
||||
|
||||
|
||||
def add_pos_bias(x):
|
||||
with tf.variable_scope(name_or_scope=None, default_name="pos_bias"):
|
||||
b = tf.get_variable(name="pos_bias", shape=[1] + x.get_shape().as_list()[1:], dtype=tf.float32,
|
||||
initializer=tf.zeros_initializer())
|
||||
return x + b
|
||||
|
||||
|
||||
def small_convnet(x, nl, feat_dim, last_nl, layernormalize, batchnorm=False):
|
||||
# e.g. nl=tf.nn.leaky_relu, feat_dim=512, last_nl=None, layernormalize=False, batchnorm=False
|
||||
bn = tf.layers.batch_normalization if batchnorm else lambda x: x
|
||||
x = bn(tf.layers.conv2d(x, filters=32, kernel_size=8, strides=(4, 4), activation=nl))
|
||||
x = bn(tf.layers.conv2d(x, filters=64, kernel_size=4, strides=(2, 2), activation=nl))
|
||||
x = bn(tf.layers.conv2d(x, filters=64, kernel_size=3, strides=(1, 1), activation=nl))
|
||||
x = tf.reshape(x, (-1, np.prod(x.get_shape().as_list()[1:])))
|
||||
x = bn(fc(x, units=feat_dim, activation=None))
|
||||
if last_nl is not None:
|
||||
x = last_nl(x)
|
||||
if layernormalize:
|
||||
x = layernorm(x)
|
||||
return x
|
||||
|
||||
|
||||
# newly added
|
||||
class SmallConv(tf.keras.Model):
|
||||
def __init__(self, feat_dim, name=None):
|
||||
super(SmallConv, self).__init__(name=name)
|
||||
self.conv1 = layers.Conv2D(filters=32, kernel_size=8, strides=(4, 4), activation=tf.nn.leaky_relu)
|
||||
self.conv2 = layers.Conv2D(filters=64, kernel_size=4, strides=(2, 2), activation=tf.nn.leaky_relu)
|
||||
self.conv3 = layers.Conv2D(filters=64, kernel_size=3, strides=(1, 1), activation=tf.nn.leaky_relu)
|
||||
self.fc = layers.Dense(units=feat_dim, activation=None)
|
||||
|
||||
def call(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = tf.reshape(x, (-1, np.prod(x.get_shape().as_list()[1:])))
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
# newly added
|
||||
class ResBlock(tf.keras.Model):
|
||||
def __init__(self, hidsize):
|
||||
super(ResBlock, self).__init__()
|
||||
self.hidsize = hidsize
|
||||
self.dense1 = layers.Dense(hidsize, activation=tf.nn.leaky_relu)
|
||||
self.dense2 = layers.Dense(hidsize, activation=None)
|
||||
|
||||
def call(self, xs):
|
||||
x, a = xs
|
||||
res = self.dense1(tf.concat([x, a], axis=-1))
|
||||
res = self.dense2(tf.concat([res, a], axis=-1))
|
||||
assert x.get_shape().as_list()[-1] == self.hidsize and res.get_shape().as_list()[-1] == self.hidsize
|
||||
return x + res
|
||||
|
||||
|
||||
# newly added
|
||||
class TransitionNetwork(tf.keras.Model):
|
||||
def __init__(self, hidsize=256, name=None):
|
||||
super(TransitionNetwork, self).__init__(name=name)
|
||||
self.hidsize = hidsize
|
||||
self.dense1 = layers.Dense(hidsize, activation=tf.nn.leaky_relu)
|
||||
self.residual_block1 = ResBlock(hidsize)
|
||||
self.residual_block2 = ResBlock(hidsize)
|
||||
self.dense2 = layers.Dense(hidsize, activation=None)
|
||||
|
||||
def call(self, xs):
|
||||
s, a = xs
|
||||
sh = tf.shape(a) # sh=(None,None,4)
|
||||
assert len(s.get_shape().as_list()) == 3 and s.get_shape().as_list()[-1] in [512, 256]
|
||||
assert len(a.get_shape().as_list()) == 3
|
||||
|
||||
x = flatten_two_dims(s) # shape=(None,512)
|
||||
a = flatten_two_dims(a) # shape=(None,4)
|
||||
|
||||
#
|
||||
x = self.dense1(tf.concat([x, a], axis=-1)) # (None, 256)
|
||||
x = self.residual_block1([x, a]) # (None, 256)
|
||||
x = self.residual_block2([x, a]) # (None, 256)
|
||||
x = self.dense2(tf.concat([x, a], axis=-1)) # (None, 256)
|
||||
x = unflatten_first_dim(x, sh) # shape=(None, None, 256)
|
||||
return x
|
||||
|
||||
|
||||
class GenerativeNetworkGaussianFix(tf.keras.Model):
|
||||
def __init__(self, hidsize=256, outsize=512, name=None):
|
||||
super(GenerativeNetworkGaussianFix, self).__init__(name=name)
|
||||
self.outsize = outsize
|
||||
self.dense1 = layers.Dense(hidsize, activation=tf.nn.leaky_relu)
|
||||
self.dense2 = layers.Dense(outsize, activation=tf.nn.leaky_relu)
|
||||
self.var_single = tf.Variable(1.0, trainable=True)
|
||||
|
||||
self.residual_block1 = tf.keras.Sequential([
|
||||
layers.Dense(hidsize, activation=tf.nn.leaky_relu), # 256
|
||||
layers.Dense(hidsize, activation=None)
|
||||
])
|
||||
self.residual_block2 = tf.keras.Sequential([
|
||||
layers.Dense(hidsize, activation=tf.nn.leaky_relu), # 256
|
||||
layers.Dense(hidsize, activation=None)
|
||||
])
|
||||
self.residual_block3 = tf.keras.Sequential([
|
||||
layers.Dense(outsize, activation=tf.nn.leaky_relu), # 512
|
||||
layers.Dense(outsize, activation=None)
|
||||
])
|
||||
|
||||
def call(self, z):
|
||||
sh = tf.shape(z) # z, sh=(None,None,128)
|
||||
assert z.get_shape().as_list()[-1] == 128 and len(z.get_shape().as_list()) == 3
|
||||
z = flatten_two_dims(z) # shape=(None,128)
|
||||
|
||||
x = self.dense1(z) # (None, 256)
|
||||
x = x + self.residual_block1(x) # (None, 256)
|
||||
x = x + self.residual_block2(x) # (None, 256)
|
||||
|
||||
# variance
|
||||
var_tile = tf.tile(tf.expand_dims(tf.expand_dims(self.var_single, axis=0), axis=0), [16*128, self.outsize])  # tile the scalar variance to (16*128, outsize); note the hard-coded flattened batch size
|
||||
|
||||
# mean
|
||||
x = self.dense2(x) # (None, 512)
|
||||
x = x + self.residual_block3(x) # (None, 512) mean
|
||||
|
||||
# concat and return
|
||||
x = tf.concat([x, var_tile], axis=-1) # (None, 1024)
|
||||
x = unflatten_first_dim(x, sh) # shape=(None, None, 1024)
|
||||
return x
|
||||
|
||||
|
||||
class GenerativeNetworkGaussian(tf.keras.Model):
|
||||
def __init__(self, hidsize=256, outsize=512, name=None):
|
||||
super(GenerativeNetworkGaussian, self).__init__(name=name)
|
||||
self.dense1 = layers.Dense(hidsize, activation=tf.nn.leaky_relu)
|
||||
self.dense2 = layers.Dense(outsize, activation=tf.nn.leaky_relu)
|
||||
self.dense3 = layers.Dense(outsize*2, activation=tf.nn.leaky_relu)
|
||||
|
||||
self.residual_block1 = tf.keras.Sequential([
|
||||
layers.Dense(hidsize, activation=tf.nn.leaky_relu), # 256
|
||||
layers.Dense(hidsize, activation=None)
|
||||
])
|
||||
self.residual_block2 = tf.keras.Sequential([
|
||||
layers.Dense(hidsize, activation=tf.nn.leaky_relu), # 256
|
||||
layers.Dense(hidsize, activation=None)
|
||||
])
|
||||
self.residual_block3 = tf.keras.Sequential([
|
||||
layers.Dense(outsize, activation=tf.nn.leaky_relu), # 512
|
||||
layers.Dense(outsize, activation=None)
|
||||
])
|
||||
|
||||
def call(self, z):
|
||||
sh = tf.shape(z) # z, sh=(None,None,128)
|
||||
assert z.get_shape().as_list()[-1] == 128 and len(z.get_shape().as_list()) == 3
|
||||
z = flatten_two_dims(z) # shape=(None,128)
|
||||
|
||||
x = self.dense1(z) # (None, 256)
|
||||
x = x + self.residual_block1(x) # (None, 256)
|
||||
x = x + self.residual_block2(x) # (None, 256)
|
||||
x = self.dense2(x) # (None, 512)
|
||||
x = x + self.residual_block3(x) # (None, 512)
|
||||
x = self.dense3(x) # (None, 1024)
|
||||
x = unflatten_first_dim(x, sh) # shape=(None, None, 1024)
|
||||
return x
|
||||
|
||||
|
||||
class ProjectionHead(tf.keras.Model):
|
||||
def __init__(self, name=None):
|
||||
super(ProjectionHead, self).__init__(name=name)
|
||||
self.dense1 = layers.Dense(256, activation=None)
|
||||
self.dense2 = layers.Dense(128, activation=None)
|
||||
self.ln1 = layers.LayerNormalization()
|
||||
self.ln2 = layers.LayerNormalization()
|
||||
|
||||
def call(self, x, ln=False):
|
||||
assert x.get_shape().as_list()[-1] == 512 and len(x.get_shape().as_list()) == 3
|
||||
x = flatten_two_dims(x) # shape=(None,512)
|
||||
x = self.dense1(x) # shape=(None,256)
|
||||
x = self.ln1(x) # layer norm
|
||||
x = tf.nn.relu(x) # relu
|
||||
x = self.dense2(x) # shape=(None,128)
|
||||
x = self.ln2(x)
|
||||
return x
|
||||
|
||||
|
||||
class ContrastiveHead(tf.keras.Model):
|
||||
def __init__(self, temperature, z_dim=128, name=None):
|
||||
super(ContrastiveHead, self).__init__(name=name)
|
||||
self.W = tf.Variable(tf.random.uniform((z_dim, z_dim)), name='W_Contras')
|
||||
self.temperature = temperature
|
||||
|
||||
def call(self, z_a_pos):
|
||||
z_a, z_pos = z_a_pos
|
||||
Wz = tf.linalg.matmul(self.W, z_pos, transpose_b=True)  # (z_dim, B)
|
||||
logits = tf.linalg.matmul(z_a, Wz) # (B,B) logits.shape = (32,32)
|
||||
logits = logits - tf.reduce_max(logits, 1)[:, None]  # subtract per-row max for numerical stability
|
||||
logits = logits * self.temperature
|
||||
return logits
|
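One plausible way the bilinear logits above are consumed is an InfoNCE-style cross-entropy in which the i-th anchor's positive key sits at index i; the sketch below is a hedged illustration, not the exact loss wiring used elsewhere in the training code.

```
# Hypothetical sketch: InfoNCE-style loss on top of ContrastiveHead logits.
# Row i scores anchor i against every key; the positive target is index i.
import tensorflow as tf

def nce_loss_from_logits(logits):
    labels = tf.range(tf.shape(logits)[0])  # diagonal targets
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(losses)
```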
||||
|
||||
|
||||
def rec_log_prob(rec_params, s_next, min_sigma=1e-2):
|
||||
# rec_params.shape = (None, None, 1024)
|
||||
distr = normal_parse_params(rec_params, min_sigma)
|
||||
log_prob = distr.log_prob(s_next) # (None, None, 512)
|
||||
assert len(log_prob.get_shape().as_list()) == 3 and log_prob.get_shape().as_list()[-1] == 512
|
||||
return tf.reduce_sum(log_prob, axis=-1)
|
||||
|
||||
|
||||
def normal_parse_params(params, min_sigma=0.0):
|
||||
n = params.shape[0]
|
||||
d = params.shape[-1] # channel
|
||||
mu = params[..., :d // 2] #
|
||||
sigma_params = params[..., d // 2:]
|
||||
sigma = tf.math.softplus(sigma_params)
|
||||
sigma = tf.clip_by_value(t=sigma, clip_value_min=min_sigma, clip_value_max=1e5)
|
||||
|
||||
distr = tfd.Normal(loc=mu, scale=sigma) #
|
||||
return distr
|
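The two helpers above work together: the 1024-wide decoder output splits into a 512-dim mean and a 512-dim pre-softplus scale. A small sketch with placeholder shapes (this mirrors what `rec_log_prob` does):

```
# Illustrative only: parse (batch, time, 1024) parameters into a factorized
# Gaussian over 512-dim features and score a target under it.
import tensorflow as tf

rec_params = tf.zeros([4, 128, 1024])   # [mean | pre-softplus sigma]
s_next = tf.zeros([4, 128, 512])        # target next-state features
distr = normal_parse_params(rec_params, min_sigma=1e-2)
log_prob = tf.reduce_sum(distr.log_prob(s_next), axis=-1)  # (4, 128)
```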
||||
|
||||
|
||||
def tile_images(array, n_cols=None, max_images=None, div=1):
|
||||
if max_images is not None:
|
||||
array = array[:max_images]
|
||||
if len(array.shape) == 4 and array.shape[3] == 1:
|
||||
array = array[:, :, :, 0]
|
||||
assert len(array.shape) in [3, 4], "wrong number of dimensions - shape {}".format(array.shape)
|
||||
if len(array.shape) == 4:
|
||||
assert array.shape[3] == 3, "wrong number of channels - shape {}".format(array.shape)
|
||||
if n_cols is None:
|
||||
n_cols = max(int(np.sqrt(array.shape[0])) // div * div, div)
|
||||
n_rows = int(np.ceil(float(array.shape[0]) / n_cols))
|
||||
|
||||
def cell(i, j):
|
||||
ind = i * n_cols + j
|
||||
return array[ind] if ind < array.shape[0] else np.zeros(array[0].shape)
|
||||
|
||||
def row(i):
|
||||
return np.concatenate([cell(i, j) for j in range(n_cols)], axis=1)
|
||||
|
||||
return np.concatenate([row(i) for i in range(n_rows)], axis=0)
|
||||
|
||||
|
||||
|
||||
import distutils.spawn
|
||||
import subprocess
|
||||
|
||||
|
||||
def save_np_as_mp4(frames, filename, frames_per_sec=30):
|
||||
print(filename)
|
||||
if distutils.spawn.find_executable('avconv') is not None:
|
||||
backend = 'avconv'
|
||||
elif distutils.spawn.find_executable('ffmpeg') is not None:
|
||||
backend = 'ffmpeg'
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"""Found neither the ffmpeg nor avconv executables. On OS X, you can install ffmpeg via `brew install ffmpeg`. On most Ubuntu variants, `sudo apt-get install ffmpeg` should do it. On Ubuntu 14.04, however, you'll need to install avconv with `sudo apt-get install libav-tools`.""")
|
||||
|
||||
h, w = frames[0].shape[:2]
|
||||
output_path = filename
|
||||
cmdline = (backend,
|
||||
'-nostats',
|
||||
'-loglevel', 'error', # suppress warnings
|
||||
'-y',
|
||||
'-r', '%d' % frames_per_sec,
|
||||
|
||||
# input
|
||||
'-f', 'rawvideo',
|
||||
'-s:v', '{}x{}'.format(w, h),
|
||||
'-pix_fmt', 'rgb24',
|
||||
'-i', '-', # this used to be /dev/stdin, which is not Windows-friendly
|
||||
|
||||
# output
|
||||
'-vcodec', 'libx264',
|
||||
'-pix_fmt', 'yuv420p',
|
||||
output_path)
|
||||
|
||||
print('saving ', output_path)
|
||||
if hasattr(os, 'setsid'): # setsid not present on Windows
|
||||
process = subprocess.Popen(cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid)
|
||||
else:
|
||||
process = subprocess.Popen(cmdline, stdin=subprocess.PIPE)
|
||||
process.stdin.write(np.array(frames).tobytes())
|
||||
process.stdin.close()
|
||||
ret = process.wait()
|
||||
if ret != 0:
|
||||
print("VideoRecorder encoder exited with status {}".format(ret))
|
||||
|
||||
|
||||
# ExponentialSchedule
|
||||
class ExponentialSchedule(object):
|
||||
def __init__(self, start_value, decay_factor, end_value, outside_value=None):
|
||||
"""Exponential Schedule.
|
||||
y = max(start_value * (1.0 - decay_factor) ** (t / 1e5), end_value)
|
||||
"""
|
||||
assert 0.0 <= decay_factor <= 1.0
|
||||
self.start_value = start_value
|
||||
self.decay_factor = decay_factor
|
||||
self.end_value = end_value
|
||||
|
||||
def value(self, t):
|
||||
v = self.start_value * np.power(1.0 - self.decay_factor, t/int(1e5))
|
||||
return np.maximum(v, self.end_value)
|
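A hedged usage sketch of the schedule (note that `value()` divides the step count by 1e5 before applying the exponent; numbers are illustrative):

```
# Illustrative only: decay a loss coefficient from 1.0 toward a floor of 0.1.
sched = ExponentialSchedule(start_value=1.0, decay_factor=0.05, end_value=0.1)
print(sched.value(0))        # 1.0
print(sched.value(200000))   # 0.95 ** 2 = 0.9025
```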
222
vec_env.py
Normal file
@ -0,0 +1,222 @@
|
||||
"""
|
||||
An interface for asynchronous vectorized environments.
|
||||
"""
|
||||
|
||||
|
||||
import ctypes
|
||||
from abc import ABC, abstractmethod
|
||||
from multiprocessing import Pipe, Array, Process
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from baselines import logger
|
||||
|
||||
_NP_TO_CT = {np.float32: ctypes.c_float,
|
||||
np.int32: ctypes.c_int32,
|
||||
np.int8: ctypes.c_int8,
|
||||
np.uint8: ctypes.c_char,
|
||||
np.bool: ctypes.c_bool}
|
||||
_CT_TO_NP = {v: k for k, v in _NP_TO_CT.items()}
|
||||
|
||||
|
||||
class CloudpickleWrapper(object):
|
||||
"""
|
||||
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
||||
"""
|
||||
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def __getstate__(self):
|
||||
import cloudpickle
|
||||
return cloudpickle.dumps(self.x)
|
||||
|
||||
def __setstate__(self, ob):
|
||||
import pickle
|
||||
self.x = pickle.loads(ob)
|
||||
|
||||
|
||||
class VecEnv(ABC):
|
||||
"""
|
||||
An abstract asynchronous, vectorized environment.
|
||||
"""
|
||||
|
||||
def __init__(self, num_envs, observation_space, action_space):
|
||||
self.num_envs = num_envs
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
"""
|
||||
Reset all the environments and return an array of
|
||||
observations, or a tuple of observation arrays.
|
||||
|
||||
If step_async is still doing work, that work will
|
||||
be cancelled and step_wait() should not be called
|
||||
until step_async() is invoked again.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step_async(self, actions):
|
||||
"""
|
||||
Tell all the environments to start taking a step
|
||||
with the given actions.
|
||||
Call step_wait() to get the results of the step.
|
||||
|
||||
You should not call this if a step_async run is
|
||||
already pending.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step_wait(self):
|
||||
"""
|
||||
Wait for the step taken with step_async().
|
||||
|
||||
Returns (obs, rews, dones, infos):
|
||||
- obs: an array of observations, or a tuple of
|
||||
arrays of observations.
|
||||
- rews: an array of rewards
|
||||
- dones: an array of "episode done" booleans
|
||||
- infos: a sequence of info objects
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
"""
|
||||
Clean up the environments' resources.
|
||||
"""
|
||||
pass
|
||||
|
||||
def step(self, actions):
|
||||
self.step_async(actions)
|
||||
return self.step_wait()
|
||||
|
||||
def render(self):
|
||||
logger.warn('Render not defined for %s' % self)
|
||||
|
||||
|
||||
class ShmemVecEnv(VecEnv):
|
||||
"""
|
||||
An AsyncEnv that uses multiprocessing to run multiple
|
||||
environments in parallel.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns, spaces=None):
|
||||
"""
|
||||
If you don't specify observation_space, we'll have to create a dummy
|
||||
environment to get it.
|
||||
"""
|
||||
if spaces:
|
||||
observation_space, action_space = spaces
|
||||
else:
|
||||
logger.log('Creating dummy env object to get spaces')
|
||||
with logger.scoped_configure(format_strs=[]):
|
||||
dummy = env_fns[0]()
|
||||
observation_space, action_space = dummy.observation_space, dummy.action_space
|
||||
dummy.close()
|
||||
del dummy
|
||||
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
|
||||
|
||||
obs_spaces = observation_space.spaces if isinstance(self.observation_space, gym.spaces.Tuple) else (
|
||||
self.observation_space,)
|
||||
self.obs_bufs = [tuple(Array(_NP_TO_CT[s.dtype.type], int(np.prod(s.shape))) for s in obs_spaces) for _ in
|
||||
env_fns]
|
||||
self.obs_shapes = [s.shape for s in obs_spaces]
|
||||
self.obs_dtypes = [s.dtype for s in obs_spaces]
|
||||
|
||||
self.parent_pipes = []
|
||||
self.procs = []
|
||||
for env_fn, obs_buf in zip(env_fns, self.obs_bufs):
|
||||
wrapped_fn = CloudpickleWrapper(env_fn)
|
||||
parent_pipe, child_pipe = Pipe()
|
||||
proc = Process(target=_subproc_worker,
|
||||
args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes))
|
||||
proc.daemon = True
|
||||
self.procs.append(proc)
|
||||
self.parent_pipes.append(parent_pipe)
|
||||
proc.start()
|
||||
child_pipe.close()
|
||||
self.waiting_step = False
|
||||
|
||||
def reset(self):
|
||||
if self.waiting_step:
|
||||
logger.warn('Called reset() while waiting for the step to complete')
|
||||
self.step_wait()
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(('reset', None))
|
||||
return self._decode_obses([pipe.recv() for pipe in self.parent_pipes])
|
||||
|
||||
def step_async(self, actions):
|
||||
assert len(actions) == len(self.parent_pipes)
|
||||
for pipe, act in zip(self.parent_pipes, actions):
|
||||
pipe.send(('step', act))
|
||||
|
||||
def step_wait(self):
|
||||
outs = [pipe.recv() for pipe in self.parent_pipes]
|
||||
obs, rews, dones, infos = zip(*outs)
|
||||
return self._decode_obses(obs), np.array(rews), np.array(dones), infos
|
||||
|
||||
def close(self):
|
||||
if self.waiting_step:
|
||||
self.step_wait()
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(('close', None))
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.recv()
|
||||
pipe.close()
|
||||
for proc in self.procs:
|
||||
proc.join()
|
||||
|
||||
def _decode_obses(self, obs):
|
||||
"""
|
||||
Turn the observation responses into a single numpy
|
||||
array, possibly via shared memory.
|
||||
"""
|
||||
obs = []
|
||||
for i, shape in enumerate(self.obs_shapes):
|
||||
bufs = [b[i] for b in self.obs_bufs]
|
||||
o = [np.frombuffer(b.get_obj(), dtype=self.obs_dtypes[i]).reshape(shape) for b in bufs]
|
||||
obs.append(np.array(o))
|
||||
return tuple(obs) if len(obs) > 1 else obs[0]
|
||||
|
||||
|
||||
def _subproc_worker(pipe, parent_pipe, env_fn_wrapper, obs_buf, obs_shape):
|
||||
"""
|
||||
Control a single environment instance using IPC and
|
||||
shared memory.
|
||||
|
||||
If obs_buf is not None, it is a shared-memory buffer
|
||||
for communicating observations.
|
||||
"""
|
||||
|
||||
def _write_obs(obs):
|
||||
if not isinstance(obs, tuple):
|
||||
obs = (obs,)
|
||||
for o, b, s in zip(obs, obs_buf, obs_shape):
|
||||
dst = b.get_obj()
|
||||
dst_np = np.frombuffer(dst, dtype=_CT_TO_NP[dst._type_]).reshape(s) # pylint: disable=W0212
|
||||
np.copyto(dst_np, o)
|
||||
|
||||
env = env_fn_wrapper.x()
|
||||
parent_pipe.close()
|
||||
try:
|
||||
while True:
|
||||
cmd, data = pipe.recv()
|
||||
if cmd == 'reset':
|
||||
pipe.send(_write_obs(env.reset()))
|
||||
elif cmd == 'step':
|
||||
obs, reward, done, info = env.step(data)
|
||||
if done:
|
||||
obs = env.reset()
|
||||
pipe.send((_write_obs(obs), reward, done, info))
|
||||
elif cmd == 'close':
|
||||
pipe.send(None)
|
||||
break
|
||||
else:
|
||||
raise RuntimeError('Got unrecognized cmd %s' % cmd)
|
||||
finally:
|
||||
env.close()
|
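A hedged usage sketch of the shared-memory vectorized environment (the environment id and worker count are illustrative):

```
# Illustrative only: run four Gym environments in parallel worker processes.
import gym

def make_env():
    return gym.make("BreakoutNoFrameskip-v4")

if __name__ == "__main__":
    venv = ShmemVecEnv([make_env for _ in range(4)])
    obs = venv.reset()                                    # (4, 210, 160, 3)
    actions = [venv.action_space.sample() for _ in range(4)]
    obs, rews, dones, infos = venv.step(actions)
    venv.close()
```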
475
wrappers.py
Normal file
@ -0,0 +1,475 @@
|
||||
import itertools
|
||||
from collections import deque
|
||||
from copy import copy
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import random
|
||||
|
||||
|
||||
def unwrap(env):
|
||||
if hasattr(env, "unwrapped"):
|
||||
return env.unwrapped
|
||||
elif hasattr(env, "env"):
|
||||
return unwrap(env.env)
|
||||
elif hasattr(env, "leg_env"):
|
||||
return unwrap(env.leg_env)
|
||||
else:
|
||||
return env
|
||||
|
||||
|
||||
class MaxAndSkipEnv(gym.Wrapper):
|
||||
def __init__(self, env, skip=4):
|
||||
"""Return only every `skip`-th frame"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
# most recent raw observations (for max pooling across time steps)
|
||||
self._obs_buffer = deque(maxlen=2)
|
||||
self._skip = skip
|
||||
|
||||
def step(self, action):
|
||||
"""Repeat action, sum reward, and max over last observations."""
|
||||
total_reward = 0.0
|
||||
done = None
|
||||
acc_info = {}
|
||||
for _ in range(self._skip):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
acc_info.update(info)
|
||||
self._obs_buffer.append(obs)
|
||||
total_reward += reward
|
||||
if done:
|
||||
break
|
||||
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
|
||||
|
||||
return max_frame, total_reward, done, acc_info
|
||||
|
||||
def reset(self):
|
||||
"""Clear past frame buffer and init. to first obs. from inner env."""
|
||||
self._obs_buffer.clear()
|
||||
obs = self.env.reset()
|
||||
self._obs_buffer.append(obs)
|
||||
return obs
|
||||
|
||||
|
||||
class ProcessFrame84(gym.ObservationWrapper):
|
||||
def __init__(self, env, crop=True):
|
||||
self.crop = crop
|
||||
super(ProcessFrame84, self).__init__(env)
|
||||
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
|
||||
|
||||
def observation(self, obs):
|
||||
return ProcessFrame84.process(obs, crop=self.crop)
|
||||
|
||||
@staticmethod
|
||||
def process(frame, crop=True):
|
||||
if frame.size == 210 * 160 * 3:
|
||||
img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
|
||||
elif frame.size == 250 * 160 * 3:
|
||||
img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
|
||||
elif frame.size == 224 * 240 * 3: # mario resolution
|
||||
img = np.reshape(frame, [224, 240, 3]).astype(np.float32)
|
||||
else:
|
||||
assert False, "Unknown resolution." + str(frame.size)
|
||||
img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
|
||||
size = (84, 110 if crop else 84)
|
||||
resized_screen = np.array(Image.fromarray(img).resize(size,
|
||||
resample=Image.BILINEAR), dtype=np.uint8)
|
||||
x_t = resized_screen[18:102, :] if crop else resized_screen
|
||||
x_t = np.reshape(x_t, [84, 84, 1])
|
||||
return x_t.astype(np.uint8)
|
||||
|
||||
|
||||
class ExtraTimeLimit(gym.Wrapper):
|
||||
def __init__(self, env, max_episode_steps=None):
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self._max_episode_steps = max_episode_steps
|
||||
self._elapsed_steps = 0
|
||||
|
||||
def step(self, action):
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
self._elapsed_steps += 1
|
||||
if self._elapsed_steps > self._max_episode_steps:
|
||||
done = True
|
||||
return observation, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
self._elapsed_steps = 0
|
||||
return self.env.reset()
|
||||
|
||||
|
||||
class AddRandomStateToInfo(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
"""Adds the random state to the info field on the first step after reset
|
||||
"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
def step(self, action):
|
||||
ob, r, d, info = self.env.step(action)
|
||||
if self.random_state_copy is not None:
|
||||
info['random_state'] = self.random_state_copy
|
||||
self.random_state_copy = None
|
||||
return ob, r, d, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
""" Do no-op action for a number of steps in [1, noop_max]."""
|
||||
self.random_state_copy = copy(self.unwrapped.np_random)
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
|
||||
class MontezumaInfoWrapper(gym.Wrapper):
|
||||
ram_map = {
|
||||
"room": dict(
|
||||
index=3,
|
||||
),
|
||||
"x": dict(
|
||||
index=42,
|
||||
),
|
||||
"y": dict(
|
||||
index=43,
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, env):
|
||||
super(MontezumaInfoWrapper, self).__init__(env)
|
||||
self.visited = set()
|
||||
self.visited_rooms = set()
|
||||
|
||||
def step(self, action):
|
||||
obs, rew, done, info = self.env.step(action)
|
||||
ram_state = unwrap(self.env).ale.getRAM()
|
||||
for name, properties in MontezumaInfoWrapper.ram_map.items():
|
||||
info[name] = ram_state[properties['index']]
|
||||
pos = (info['x'], info['y'], info['room'])
|
||||
self.visited.add(pos)
|
||||
self.visited_rooms.add(info["room"])
|
||||
if done:
|
||||
info['mz_episode'] = dict(pos_count=len(self.visited),
|
||||
visited_rooms=copy(self.visited_rooms))
|
||||
self.visited.clear()
|
||||
self.visited_rooms.clear()
|
||||
return obs, rew, done, info
|
||||
|
||||
def reset(self):
|
||||
return self.env.reset()
|
||||
|
||||
|
||||
class MarioXReward(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.current_level = [0, 0]
|
||||
self.visited_levels = set()
|
||||
self.visited_levels.add(tuple(self.current_level))
|
||||
self.current_max_x = 0.
|
||||
|
||||
def reset(self):
|
||||
ob = self.env.reset()
|
||||
self.current_level = [0, 0]
|
||||
self.visited_levels = set()
|
||||
self.visited_levels.add(tuple(self.current_level))
|
||||
self.current_max_x = 0.
|
||||
return ob
|
||||
|
||||
def step(self, action):
|
||||
ob, reward, done, info = self.env.step(action)
|
||||
levellow, levelhigh, xscrollHi, xscrollLo = \
|
||||
info["levelLo"], info["levelHi"], info["xscrollHi"], info["xscrollLo"]
|
||||
currentx = xscrollHi * 256 + xscrollLo
|
||||
new_level = [levellow, levelhigh]
|
||||
if new_level != self.current_level:
|
||||
self.current_level = new_level
|
||||
self.current_max_x = 0.
|
||||
reward = 0.
|
||||
self.visited_levels.add(tuple(self.current_level))
|
||||
else:
|
||||
if currentx > self.current_max_x:
|
||||
delta = currentx - self.current_max_x
|
||||
self.current_max_x = currentx
|
||||
reward = delta
|
||||
else:
|
||||
reward = 0.
|
||||
if done:
|
||||
info["levels"] = copy(self.visited_levels)
|
||||
info["retro_episode"] = dict(levels=copy(self.visited_levels))
|
||||
return ob, reward, done, info
|
||||
|
||||
|
||||
class LimitedDiscreteActions(gym.ActionWrapper):
|
||||
KNOWN_BUTTONS = {"A", "B"}
|
||||
KNOWN_SHOULDERS = {"L", "R"}
|
||||
|
||||
'''
|
||||
Reproduces the action space from curiosity paper.
|
||||
'''
|
||||
|
||||
def __init__(self, env, all_buttons, whitelist=KNOWN_BUTTONS | KNOWN_SHOULDERS):
|
||||
gym.ActionWrapper.__init__(self, env)
|
||||
|
||||
self._num_buttons = len(all_buttons)
|
||||
button_keys = {i for i in range(len(all_buttons)) if all_buttons[i] in whitelist & self.KNOWN_BUTTONS}
|
||||
buttons = [(), *zip(button_keys), *itertools.combinations(button_keys, 2)]
|
||||
shoulder_keys = {i for i in range(len(all_buttons)) if all_buttons[i] in whitelist & self.KNOWN_SHOULDERS}
|
||||
shoulders = [(), *zip(shoulder_keys), *itertools.permutations(shoulder_keys, 2)]
|
||||
arrows = [(), (4,), (5,), (6,), (7,)] # (), up, down, left, right
|
||||
acts = []
|
||||
acts += arrows
|
||||
acts += buttons[1:]
|
||||
acts += [a + b for a in arrows[-2:] for b in buttons[1:]]
|
||||
self._actions = acts
|
||||
self.action_space = gym.spaces.Discrete(len(self._actions))
|
||||
|
||||
def action(self, a):
|
||||
mask = np.zeros(self._num_buttons)
|
||||
for i in self._actions[a]:
|
||||
mask[i] = 1
|
||||
return mask
|
||||
|
||||
|
||||
class FrameSkip(gym.Wrapper):
|
||||
def __init__(self, env, n):
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.n = n
|
||||
|
||||
def step(self, action):
|
||||
done = False
|
||||
totrew = 0
|
||||
for _ in range(self.n):
|
||||
ob, rew, done, info = self.env.step(action)
|
||||
totrew += rew
|
||||
if done: break
|
||||
return ob, totrew, done, info
|
||||
|
||||
|
||||
def make_mario_env(crop=True, frame_stack=True, clip_rewards=False):
|
||||
assert clip_rewards is False
|
||||
import gym
|
||||
import retro
|
||||
from baselines.common.atari_wrappers import FrameStack
|
||||
|
||||
gym.undo_logger_setup()
|
||||
env = retro.make('SuperMarioBros-Nes', 'Level1-1')
|
||||
buttons = env.BUTTONS
|
||||
env = MarioXReward(env)
|
||||
env = FrameSkip(env, 4)
|
||||
env = ProcessFrame84(env, crop=crop)
|
||||
if frame_stack:
|
||||
env = FrameStack(env, 4)
|
||||
env = LimitedDiscreteActions(env, buttons)
|
||||
return env
|
||||
|
||||
|
||||
class OneChannel(gym.ObservationWrapper):
|
||||
def __init__(self, env, crop=True):
|
||||
self.crop = crop
|
||||
super(OneChannel, self).__init__(env)
|
||||
assert env.observation_space.dtype == np.uint8
|
||||
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
|
||||
|
||||
def observation(self, obs):
|
||||
return obs[:, :, 2:3]
|
||||
|
||||
|
||||
class RetroALEActions(gym.ActionWrapper):
|
||||
def __init__(self, env, all_buttons, n_players=1):
|
||||
gym.ActionWrapper.__init__(self, env)
|
||||
self.n_players = n_players
|
||||
self._num_buttons = len(all_buttons)
|
||||
bs = [-1, 0, 4, 5, 6, 7]
|
||||
actions = []
|
||||
|
||||
def update_actions(old_actions, offset=0):
|
||||
actions = []
|
||||
for b in old_actions:
|
||||
for button in bs:
|
||||
action = []
|
||||
action.extend(b)
|
||||
if button != -1:
|
||||
action.append(button + offset)
|
||||
actions.append(action)
|
||||
return actions
|
||||
|
||||
current_actions = [[]]
|
||||
for i in range(self.n_players):
|
||||
current_actions = update_actions(current_actions, i * self._num_buttons)
|
||||
self._actions = current_actions
|
||||
self.action_space = gym.spaces.Discrete(len(self._actions))
|
||||
|
||||
def action(self, a):
|
||||
mask = np.zeros(self._num_buttons * self.n_players)
|
||||
for i in self._actions[a]:
|
||||
mask[i] = 1
|
||||
return mask
|
||||
|
||||
|
||||
class NoReward(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
def step(self, action):
|
||||
ob, rew, done, info = self.env.step(action)
|
||||
return ob, 0.0, done, info
|
||||
|
||||
|
||||
def make_multi_pong(frame_stack=True):
|
||||
import gym
|
||||
import retro
|
||||
from baselines.common.atari_wrappers import FrameStack
|
||||
gym.undo_logger_setup()
|
||||
game_env = env = retro.make('Pong-Atari2600', players=2)
|
||||
env = RetroALEActions(env, game_env.BUTTONS, n_players=2)
|
||||
env = NoReward(env)
|
||||
env = FrameSkip(env, 4)
|
||||
env = ProcessFrame84(env, crop=False)
|
||||
if frame_stack:
|
||||
env = FrameStack(env, 4)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def make_robo_pong(frame_stack=True):
|
||||
from baselines.common.atari_wrappers import FrameStack
|
||||
import roboenvs as robo
|
||||
|
||||
env = robo.make_robopong()
|
||||
env = robo.DiscretizeActionWrapper(env, 2)
|
||||
env = robo.MultiDiscreteToUsual(env)
|
||||
env = OneChannel(env)
|
||||
if frame_stack:
|
||||
env = FrameStack(env, 4)
|
||||
|
||||
env = AddRandomStateToInfo(env)
|
||||
return env
|
||||
|
||||
|
||||
def make_robo_hockey(frame_stack=True):
|
||||
from baselines.common.atari_wrappers import FrameStack
|
||||
import roboenvs as robo
|
||||
|
||||
env = robo.make_robohockey()
|
||||
env = robo.DiscretizeActionWrapper(env, 2)
|
||||
env = robo.MultiDiscreteToUsual(env)
|
||||
env = OneChannel(env)
|
||||
if frame_stack:
|
||||
env = FrameStack(env, 4)
|
||||
env = AddRandomStateToInfo(env)
|
||||
return env
|
||||
|
||||
|
||||
def make_unity_maze(env_id, seed=0, rank=0, expID=0, frame_stack=True,
|
||||
logdir=None, ext_coeff=1.0, recordUnityVid=False, **kwargs):
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
try:
|
||||
sys.path.insert(0, os.path.abspath("ml-agents/python/"))
|
||||
from unityagents import UnityEnvironment
|
||||
from unity_wrapper import GymWrapper
|
||||
except ImportError:
|
||||
print("Import error in unity environment. Ignore if not using unity.")
|
||||
pass
|
||||
from baselines.common.atari_wrappers import FrameStack
|
||||
# gym.undo_logger_setup() # deprecated in new version of gym
|
||||
|
||||
# max 20 workers per expID, max 30 experiments per machine
|
||||
if rank >= 0 and rank <= 200:
|
||||
time.sleep(rank * 2)
|
||||
env = UnityEnvironment(file_name='envs/' + env_id, worker_id=(expID % 60) * 200 + rank)
|
||||
maxsteps = 3000 if 'big' in env_id else 500
|
||||
env = GymWrapper(env, seed=seed, rank=rank, expID=expID, maxsteps=maxsteps, **kwargs)
|
||||
if "big" in env_id:
|
||||
env = UnityRoomCounterWrapper(env, use_ext_reward=(ext_coeff != 0.0))
|
||||
if rank == 1 and recordUnityVid:
|
||||
env = RecordBestScores(env, directory=logdir, freq=1)
|
||||
print('Loaded environment %s with rank %d\n\n' % (env_id, rank))
|
||||
|
||||
# env = NoReward(env)
|
||||
# env = FrameSkip(env, 4)
|
||||
env = ProcessFrame84(env, crop=False)
|
||||
if frame_stack:
|
||||
env = FrameStack(env, 4)
|
||||
return env
|
||||
|
||||
|
||||
class StickyActionEnv(gym.Wrapper):
|
||||
def __init__(self, env, p=0.25):
|
||||
super(StickyActionEnv, self).__init__(env)
|
||||
self.p = p
|
||||
self.last_action = 0
|
||||
|
||||
def reset(self):
|
||||
self.last_action = 0
|
||||
return self.env.reset()
|
||||
|
||||
def step(self, action):
|
||||
if self.unwrapped.np_random.uniform() < self.p:
|
||||
action = self.last_action
|
||||
self.last_action = action
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
return obs, reward, done, info
|
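A hedged sketch of wrapping an environment with sticky actions at the default repeat probability (the environment id is illustrative):

```
# Illustrative only: each action is repeated with probability 0.25.
import gym

env = gym.make("BreakoutNoFrameskip-v4")
env = StickyActionEnv(env, p=0.25)
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
```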
||||
|
||||
|
||||
############## Pixel-Noise #################
|
||||
class PixelNoiseWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env, strength=80):
|
||||
""" The source must produce a image with a shape that's compatible to `env.observation_space`.
|
||||
"""
|
||||
super(PixelNoiseWrapper, self).__init__(env)
|
||||
self.env = env
|
||||
self.obs_shape = env.observation_space.shape[:2]
|
||||
self.strength = strength
|
||||
|
||||
def observation(self, obs):
|
||||
mask = (obs == (0, 0, 0))  # shape=(210,160,3) for full-size Atari frames
|
||||
noise = np.maximum(np.random.randn(self.obs_shape[0], self.obs_shape[1], 3) * self.strength, 0)
|
||||
obs[mask] = noise[mask]
|
||||
self._last_ob = obs
|
||||
return obs
|
||||
|
||||
def render(self, mode='rgb_array'):
|
||||
img = self._last_ob
|
||||
return img
|
||||
|
||||
|
||||
############# Random Box Noise #################
|
||||
class RandomBoxNoiseWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env, strength=0.1):
|
||||
super(RandomBoxNoiseWrapper, self).__init__(env)
|
||||
self.obs_shape = env.observation_space.shape[:2] # 210, 160
|
||||
self.strength = strength
|
||||
|
||||
def observation(self, obs, w=20):
|
||||
n1 = self.obs_shape[1] // w
|
||||
n2 = self.obs_shape[0] // w
|
||||
|
||||
idx_list = np.arange(n1*n2)
|
||||
random.shuffle(idx_list)
|
||||
|
||||
num_of_box = n1 * n2 * self.strength  # expected number of noisy boxes (strength is the fraction of boxes)
|
||||
idx_list = idx_list[:np.random.randint(num_of_box-5, num_of_box+5)]
|
||||
|
||||
for idx in idx_list:
|
||||
y = (idx // n1) * w
|
||||
x = (idx % n1) * w
|
||||
obs[y:y+w, x:x+w, :] += np.random.normal(0, 255*0.3, size=(w, w, 3)).astype(np.uint8)
|
||||
|
||||
obs = np.clip(obs, 0, 255)
|
||||
self._last_ob = obs
|
||||
return obs
|
||||
|
||||
def render(self, mode='rgb_array'):
|
||||
img = self._last_ob
|
||||
return img
|
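A hedged sketch of applying the two observation-noise wrappers defined above (environment id and strengths are illustrative):

```
# Illustrative only: inject background pixel noise or random-box noise.
import gym

env = gym.make("BreakoutNoFrameskip-v4")
env = PixelNoiseWrapper(env, strength=80)          # Gaussian noise on black pixels
# or: env = RandomBoxNoiseWrapper(env, strength=0.1)
obs = env.reset()
```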
||||
|