commit 8ebc277814 ("DB")

README.md (new file)
@@ -0,0 +1,60 @@
# Dynamic Bottleneck

## Introduction

This is a TensorFlow-based implementation of our paper

**"Dynamic Bottleneck for Robust Self-Supervised Exploration", NeurIPS 2021**

## Prerequisites

python 3.6 or 3.7,
tensorflow-gpu 1.x, tensorflow-probability,
OpenAI [baselines](https://github.com/openai/baselines),
OpenAI [Gym](http://gym.openai.com/)
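A quick way to check the prerequisites is an import test (the TF 1.x assertion reflects the requirement above; exact package versions are up to you):

```
# Quick import check for the prerequisites listed above.
import tensorflow as tf
import tensorflow_probability as tfp
import gym
import baselines

assert tf.__version__.startswith("1."), "this codebase targets TensorFlow 1.x"
print(tf.__version__, tfp.__version__, gym.__version__)
```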
## Installation and Usage

### Atari games

The following command trains a pure exploration agent on "Breakout" with the default experiment parameters.

```
python run.py --env BreakoutNoFrameskip-v4
```

### Atari games with Random-Box noise

The following command trains a pure exploration agent on "Breakout" with Random-Box noise.

```
python run.py --env BreakoutNoFrameskip-v4 --randomBoxNoise
```

### Atari games with Gaussian noise

The following command trains a pure exploration agent on "Breakout" with Gaussian pixel noise.

```
python run.py --env BreakoutNoFrameskip-v4 --pixelNoise
```

### Atari games with sticky actions

The following command trains a pure exploration agent on "sticky Breakout", where each action is repeated with probability 0.25.

```
python run.py --env BreakoutNoFrameskip-v4 --stickyAtari
```

### Baselines

- **ICM**: We use the official [code](https://github.com/openai/large-scale-curiosity) of "Curiosity-driven Exploration by Self-supervised Prediction, ICML 2017" and "Large-Scale Study of Curiosity-Driven Learning, ICLR 2019".
- **Disagreement**: We use the official [code](https://github.com/pathak22/exploration-by-disagreement) of "Self-Supervised Exploration via Disagreement, ICML 2019".
- **CB**: We use the official [code](https://github.com/whyjay/curiosity-bottleneck) of "Curiosity-Bottleneck: Exploration by Distilling Task-Specific Novelty, ICML 2019".
__init__.py (new file)
@@ -0,0 +1 @@
#############
cnn_policy.py (new file)
@@ -0,0 +1,71 @@
import tensorflow as tf
from baselines.common.distributions import make_pdtype
from utils import getsess, small_convnet, activ, fc, flatten_two_dims, unflatten_first_dim


class CnnPolicy(object):
    def __init__(self, ob_space, ac_space, hidsize,
                 ob_mean, ob_std, feat_dim, layernormalize, nl, scope="policy"):
        """ob_space: (84,84,4); ac_space: 4;
        ob_mean.shape=(84,84,4); ob_std=1.7; hidsize: 512;
        feat_dim: 512; layernormalize: False; nl: tf.nn.leaky_relu.
        """
        if layernormalize:
            print("Warning: policy is operating on top of layer-normed features. It might slow down the training.")
        self.layernormalize = layernormalize
        self.nl = nl
        self.ob_mean = ob_mean
        self.ob_std = ob_std
        with tf.variable_scope(scope):
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.ac_pdtype = make_pdtype(ac_space)
            self.ph_ob = tf.placeholder(dtype=tf.int32,
                                        shape=(None, None) + ob_space.shape, name='ob')
            self.ph_ac = self.ac_pdtype.sample_placeholder([None, None], name='ac')
            self.pd = self.vpred = None
            self.hidsize = hidsize
            self.feat_dim = feat_dim
            self.scope = scope
            pdparamsize = self.ac_pdtype.param_shape()[0]

            sh = tf.shape(self.ph_ob)  # ph_ob.shape = (None,None,84,84,4)
            x = flatten_two_dims(self.ph_ob)  # x.shape = (None,84,84,4)

            self.flat_features = self.get_features(x, reuse=False)  # shape=(None,512)
            self.features = unflatten_first_dim(self.flat_features, sh)  # shape=(None,None,512)

            with tf.variable_scope(scope, reuse=False):
                x = fc(self.flat_features, units=hidsize, activation=activ)  # activ=tf.nn.relu
                x = fc(x, units=hidsize, activation=activ)  # shared trunk for value and policy heads
                pdparam = fc(x, name='pd', units=pdparamsize, activation=None)  # logits, shape=(None,4)
                vpred = fc(x, name='value_function_output', units=1, activation=None)  # shape=(None,1)
            pdparam = unflatten_first_dim(pdparam, sh)  # shape=(None,None,4)
            self.vpred = unflatten_first_dim(vpred, sh)[:, :, 0]  # value function, shape=(None,None)
            self.pd = pd = self.ac_pdtype.pdfromflat(pdparam)  # provides mean, neglogp, kl, entropy, sample
            self.a_samp = pd.sample()  # sampled action
            self.entropy = pd.entropy()  # (None,None)
            self.nlp_samp = pd.neglogp(self.a_samp)  # -log pi(a|s), shape=(None,None)

    def get_features(self, x, reuse):
        x_has_timesteps = (x.get_shape().ndims == 5)
        if x_has_timesteps:
            sh = tf.shape(x)
            x = flatten_two_dims(x)

        with tf.variable_scope(self.scope + "_features", reuse=reuse):
            x = (tf.to_float(x) - self.ob_mean) / self.ob_std
            x = small_convnet(x, nl=self.nl, feat_dim=self.feat_dim, last_nl=None, layernormalize=self.layernormalize)

        if x_has_timesteps:
            x = unflatten_first_dim(x, sh)
        return x

    def get_ac_value_nlp(self, ob):
        # ob.shape=(128,84,84,4); ob[:, None].shape=(128,1,84,84,4)
        a, vpred, nlp = \
            getsess().run([self.a_samp, self.vpred, self.nlp_samp],
                          feed_dict={self.ph_ob: ob[:, None]})
        return a[:, 0], vpred[:, 0], nlp[:, 0]
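For orientation, a hypothetical smoke test of CnnPolicy (the spaces, observation statistics, batch size, and session handling here are our assumptions, loosely mirroring how run.py wires things up):

```
# Hypothetical smoke test; assumes the utils helpers are importable and a
# default session is active, as in run.py.
import numpy as np
import tensorflow as tf
from gym.spaces import Box, Discrete
from cnn_policy import CnnPolicy

ob_space = Box(low=0, high=255, shape=(84, 84, 4), dtype=np.uint8)
ac_space = Discrete(4)
policy = CnnPolicy(ob_space=ob_space, ac_space=ac_space, hidsize=512,
                   ob_mean=np.zeros((84, 84, 4), np.float32), ob_std=1.7,
                   feat_dim=512, layernormalize=False, nl=tf.nn.leaky_relu)

with tf.Session().as_default():
    tf.get_default_session().run(tf.global_variables_initializer())
    obs = np.zeros((8, 84, 84, 4), np.uint8)           # a small batch of frames
    acs, vpreds, nlps = policy.get_ac_value_nlp(obs)   # each has shape (8,)
```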
cppo_agent.py (new file)
@@ -0,0 +1,260 @@
import time

import numpy as np
import tensorflow as tf
from baselines.common import explained_variance
from baselines.common.mpi_moments import mpi_moments
from baselines.common.running_mean_std import RunningMeanStd
from mpi4py import MPI
from mpi_utils import MpiAdamOptimizer
from rollouts import Rollout
from utils import bcast_tf_vars_from_root, get_mean_and_std
from vec_env import ShmemVecEnv as VecEnv

getsess = tf.get_default_session


class PpoOptimizer(object):
    envs = None

    def __init__(self, *, scope, ob_space, ac_space, stochpol, ent_coef, gamma, lam, nepochs, lr, cliprange,
                 nminibatches, normrew, normadv, use_news, ext_coeff, int_coeff, nsteps_per_seg, nsegs_per_env,
                 dynamic_bottleneck):
        self.dynamic_bottleneck = dynamic_bottleneck
        with tf.variable_scope(scope):
            self.use_recorder = True
            self.n_updates = 0
            self.scope = scope
            self.ob_space = ob_space  # Box(84,84,4)
            self.ac_space = ac_space  # Discrete(4)
            self.stochpol = stochpol  # CNN policy
            self.nepochs = nepochs  # 3
            self.lr = lr  # 1e-4
            self.cliprange = cliprange  # 0.1
            self.nsteps_per_seg = nsteps_per_seg  # 128
            self.nsegs_per_env = nsegs_per_env  # 1
            self.nminibatches = nminibatches  # 8
            self.gamma = gamma  # 0.99
            self.lam = lam  # 0.95
            self.normrew = normrew  # 1
            self.normadv = normadv  # 1
            self.use_news = use_news  # False
            self.ext_coeff = ext_coeff  # 0.0
            self.int_coeff = int_coeff  # 1.0
            self.ph_adv = tf.placeholder(tf.float32, [None, None])
            self.ph_ret = tf.placeholder(tf.float32, [None, None])
            self.ph_rews = tf.placeholder(tf.float32, [None, None])
            self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])  # -log pi(a|s)
            self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
            self.ph_lr = tf.placeholder(tf.float32, [])
            self.ph_cliprange = tf.placeholder(tf.float32, [])
            neglogpac = self.stochpol.pd.neglogp(self.stochpol.ph_ac)
            entropy = tf.reduce_mean(self.stochpol.pd.entropy())
            vpred = self.stochpol.vpred

            vf_loss = 0.5 * tf.reduce_mean((vpred - self.ph_ret) ** 2)
            ratio = tf.exp(self.ph_oldnlp - neglogpac)  # p_new / p_old
            negadv = -self.ph_adv
            pg_losses1 = negadv * ratio
            pg_losses2 = negadv * tf.clip_by_value(ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
            pg_loss_surr = tf.maximum(pg_losses1, pg_losses2)
            pg_loss = tf.reduce_mean(pg_loss_surr)
            ent_loss = (-ent_coef) * entropy
            approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - self.ph_oldnlp))
            clipfrac = tf.reduce_mean(tf.to_float(tf.abs(pg_losses2 - pg_loss_surr) > 1e-6))

            self.total_loss = pg_loss + ent_loss + vf_loss
            self.to_report = {'tot': self.total_loss, 'pg': pg_loss, 'vf': vf_loss, 'ent': entropy,
                              'approxkl': approxkl, 'clipfrac': clipfrac}

            # add bai
            self.db_loss = None

    def start_interaction(self, env_fns, dynamic_bottleneck, nlump=2):
        self.loss_names, self._losses = zip(*list(self.to_report.items()))

        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        params_db = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="DB")
        print("***total params:", np.sum([np.prod(v.get_shape().as_list()) for v in params]))  # idf: 10,172,133
        print("***DB params:", np.sum([np.prod(v.get_shape().as_list()) for v in params_db]))

        if MPI.COMM_WORLD.Get_size() > 1:
            trainer = MpiAdamOptimizer(learning_rate=self.ph_lr, comm=MPI.COMM_WORLD)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=self.ph_lr)
        gradsandvars = trainer.compute_gradients(self.total_loss, params)  # compute gradients
        self._train = trainer.apply_gradients(gradsandvars)

        # Train DB (without gradient clipping):
        # gradsandvars_db = trainer.compute_gradients(self.db_loss, params_db)
        # self._train_db = trainer.apply_gradients(gradsandvars_db)

        # Train DB with gradient clipping
        gradients_db, variables_db = zip(*trainer.compute_gradients(self.db_loss, params_db))
        gradients_db, self.norm_var = tf.clip_by_global_norm(gradients_db, 50.0)
        self._train_db = trainer.apply_gradients(zip(gradients_db, variables_db))

        if MPI.COMM_WORLD.Get_rank() == 0:
            getsess().run(tf.variables_initializer(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        bcast_tf_vars_from_root(getsess(), tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        self.all_visited_rooms = []
        self.all_scores = []
        self.nenvs = nenvs = len(env_fns)  # 128
        self.nlump = nlump  # 1
        self.lump_stride = nenvs // self.nlump  # 128/1=128
        self.envs = [
            VecEnv(env_fns[l * self.lump_stride: (l + 1) * self.lump_stride], spaces=[self.ob_space, self.ac_space]) for
            l in range(self.nlump)]

        self.rollout = Rollout(ob_space=self.ob_space, ac_space=self.ac_space, nenvs=nenvs,
                               nsteps_per_seg=self.nsteps_per_seg,
                               nsegs_per_env=self.nsegs_per_env, nlumps=self.nlump,
                               envs=self.envs,
                               policy=self.stochpol,
                               int_rew_coeff=self.int_coeff,
                               ext_rew_coeff=self.ext_coeff,
                               record_rollouts=self.use_recorder,
                               dynamic_bottleneck=dynamic_bottleneck)

        self.buf_advs = np.zeros((nenvs, self.rollout.nsteps), np.float32)
        self.buf_rets = np.zeros((nenvs, self.rollout.nsteps), np.float32)

        # add bai. Dynamic Bottleneck reward normalization
        if self.normrew:
            self.rff = RewardForwardFilter(self.gamma)
            self.rff_rms = RunningMeanStd()

        self.step_count = 0
        self.t_last_update = time.time()
        self.t_start = time.time()

    def stop_interaction(self):
        for env in self.envs:
            env.close()

    def calculate_advantages(self, rews, use_news, gamma, lam):
        nsteps = self.rollout.nsteps
        lastgaelam = 0
        for t in range(nsteps - 1, -1, -1):  # t = nsteps-1 ... 0
            nextnew = self.rollout.buf_news[:, t + 1] if t + 1 < nsteps else self.rollout.buf_new_last
            if not use_news:
                nextnew = 0
            nextvals = self.rollout.buf_vpreds[:, t + 1] if t + 1 < nsteps else self.rollout.buf_vpred_last
            nextnotnew = 1 - nextnew
            delta = rews[:, t] + gamma * nextvals * nextnotnew - self.rollout.buf_vpreds[:, t]
            self.buf_advs[:, t] = lastgaelam = delta + gamma * lam * nextnotnew * lastgaelam
        self.buf_rets[:] = self.buf_advs + self.rollout.buf_vpreds

    def update(self):
        # add bai. use dynamic bottleneck
        if self.normrew:
            rffs = np.array([self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std ** 2, rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)  # shape=(128,128)
        else:
            rews = np.copy(self.rollout.buf_rews)

        self.calculate_advantages(rews=rews, use_news=self.use_news, gamma=self.gamma, lam=self.lam)

        info = dict(
            advmean=self.buf_advs.mean(),
            advstd=self.buf_advs.std(),
            retmean=self.buf_rets.mean(),
            retstd=self.buf_rets.std(),
            vpredmean=self.rollout.buf_vpreds.mean(),
            vpredstd=self.rollout.buf_vpreds.std(),
            ev=explained_variance(self.rollout.buf_vpreds.ravel(), self.buf_rets.ravel()),
            DB_rew=np.mean(self.rollout.buf_rews),  # add bai.
            DB_rew_norm=np.mean(rews),  # add bai.
            recent_best_ext_ret=self.rollout.current_max
        )
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)
        envsperbatch = (self.nenvs * self.nsegs_per_env) // self.nminibatches
        envsperbatch = max(1, envsperbatch)
        envinds = np.arange(self.nenvs * self.nsegs_per_env)

        def resh(x):
            if self.nsegs_per_env == 1:
                return x
            sh = x.shape
            return x.reshape((sh[0] * self.nsegs_per_env, self.nsteps_per_seg) + sh[2:])

        ph_buf = [
            (self.stochpol.ph_ac, resh(self.rollout.buf_acs)),
            (self.ph_rews, resh(self.rollout.buf_rews)),
            (self.ph_oldvpred, resh(self.rollout.buf_vpreds)),
            (self.ph_oldnlp, resh(self.rollout.buf_nlps)),
            (self.stochpol.ph_ob, resh(self.rollout.buf_obs)),  # numpy shape=(128,128,84,84,4)
            (self.ph_ret, resh(self.buf_rets)),
            (self.ph_adv, resh(self.buf_advs)),
        ]
        ph_buf.extend([
            (self.dynamic_bottleneck.last_ob,  # shape=(128,1,84,84,4)
             self.rollout.buf_obs_last.reshape([self.nenvs * self.nsegs_per_env, 1, *self.ob_space.shape]))
        ])
        mblossvals = []
        for _ in range(self.nepochs):  # nepochs = 3
            np.random.shuffle(envinds)  # envinds = [0, 1, 2, ..., 127]
            # nenvs=128, nsegs_per_env=1, envsperbatch=16
            for start in range(0, self.nenvs * self.nsegs_per_env, envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]
                fd = {ph: buf[mbenvinds] for (ph, buf) in ph_buf}  # feed_dict
                fd.update({self.ph_lr: self.lr, self.ph_cliprange: self.cliprange})  # , self.dynamic_bottleneck.l2_aux_loss_tf: l2_aux_loss_fd})
                mblossvals.append(getsess().run(self._losses + (self._train,), fd)[:-1])

        # gradient norm computation
        # print("gradient norm:", getsess().run(self.norm_var, fd))

        # momentum update of the DB encoder parameters
        print("Momentum Update DB Encoder")
        getsess().run(self.dynamic_bottleneck.momentum_updates)
        DB_loss_info = getsess().run(self.dynamic_bottleneck.loss_info, fd)  # evaluated on the last minibatch

        mblossvals = [mblossvals[0]]
        info.update(zip(['opt_' + ln for ln in self.loss_names], np.mean([mblossvals[0]], axis=0)))
        info["rank"] = MPI.COMM_WORLD.Get_rank()
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({dn: (np.mean(dvs) if len(dvs) > 0 else 0) for (dn, dvs) in self.rollout.statlists.items()})
        info.update(self.rollout.stats)
        if "states_visited" in info:
            info.pop("states_visited")
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = MPI.COMM_WORLD.Get_size() * self.rollout.nsteps * self.nenvs / (tnow - self.t_last_update)
        self.t_last_update = tnow

        return info, DB_loss_info

    def step(self):
        self.rollout.collect_rollout()
        update_info, DB_loss_info = self.update()
        return {'update': update_info, "DB_loss_info": DB_loss_info}

    def get_var_values(self):
        return self.stochpol.get_var_values()

    def set_var_values(self, vv):
        self.stochpol.set_var_values(vv)


class RewardForwardFilter(object):
    def __init__(self, gamma):
        self.rewems = None
        self.gamma = gamma

    def update(self, rews):
        if self.rewems is None:
            self.rewems = rews
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems
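To make the GAE recursion in calculate_advantages concrete, here is a small numeric sketch for one environment and three steps, with use_news=False (so nextnotnew is always 1); all numbers are made up:

```
# Illustrative numbers only; mirrors calculate_advantages with nextnotnew = 1.
import numpy as np

gamma, lam = 0.99, 0.95
rews = np.array([1.0, 0.0, 1.0])        # normalized intrinsic rewards
vpreds = np.array([0.5, 0.4, 0.3])      # value predictions per step
vpred_last = 0.2                        # bootstrap value after the segment

advs, lastgaelam = np.zeros(3), 0.0
for t in range(2, -1, -1):
    nextval = vpreds[t + 1] if t + 1 < 3 else vpred_last
    delta = rews[t] + gamma * nextval - vpreds[t]
    advs[t] = lastgaelam = delta + gamma * lam * lastgaelam
rets = advs + vpreds                    # regression targets for the value head
```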
dynamic_bottleneck.py (new file)
@@ -0,0 +1,168 @@
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from utils import getsess

tfd = tfp.distributions

from utils import flatten_two_dims, unflatten_first_dim, SmallConv, TransitionNetwork, normal_parse_params, \
    ProjectionHead, ContrastiveHead, rec_log_prob, GenerativeNetworkGaussianFix


class DynamicBottleneck(object):
    def __init__(self, policy, tau, loss_kl_weight, loss_l2_weight, loss_nce_weight, aug, feat_dim=512, scope='DB'):
        self.scope = scope
        self.feat_dim = feat_dim
        self.policy = policy
        self.hidsize = policy.hidsize  # 512
        self.ob_space = policy.ob_space  # Box(84, 84, 4)
        self.ac_space = policy.ac_space  # Discrete(4)
        self.obs = self.policy.ph_ob  # shape=(None,None,84,84,4)
        self.ob_mean = self.policy.ob_mean  # shape=(84,84,4)
        self.ob_std = self.policy.ob_std  # ~1.8
        self.tau = tau  # tau for updating the momentum network
        self.loss_kl_weight = loss_kl_weight
        self.loss_l2_weight = loss_l2_weight
        self.loss_nce_weight = loss_nce_weight
        self.aug = aug

        with tf.variable_scope(scope):
            self.feature_conv = SmallConv(feat_dim=self.feat_dim, name="DB_main")  # (None, None, 512)
            self.feature_conv_momentum = SmallConv(feat_dim=self.feat_dim, name="DB_momentum")  # (None, None, 512)
            self.transition_model = TransitionNetwork(name="DB_transition")  # (None, None, 256)
            self.generative_model = GenerativeNetworkGaussianFix(name="DB_generative")  # (None, None, 512)
            self.projection_head = ProjectionHead(name="DB_projection_main")  # projection head
            self.projection_head_momentum = ProjectionHead(name="DB_projection_momentum")  # momentum projection head
            self.contrastive_head = ContrastiveHead(temperature=1.0, name="DB_contrastive")

            # (None,1,84,84,4)
            self.last_ob = tf.placeholder(dtype=tf.int32, shape=(None, 1) + self.ob_space.shape, name='last_ob')
            self.next_ob = tf.concat([self.obs[:, 1:], self.last_ob], 1)  # (None,None,84,84,4)

            self.features = self.get_features(self.obs)  # (None,None,512)
            self.next_features = self.get_features(self.next_ob, momentum=True)  # (None,None,512), stop gradient
            self.ac = self.policy.ph_ac  # (None, None)
            self.ac_pad = tf.one_hot(self.ac, self.ac_space.n, axis=2)

            # transition model
            latent_params = self.transition_model([self.features, self.ac_pad])  # (None, None, 256)
            self.latent_dis = normal_parse_params(latent_params, 1e-3)  # Gaussian; mu, sigma = (None, None, 128)

            # prior
            sh = tf.shape(self.latent_dis.mean())  # sh=(None, None, 128)
            self.prior_dis = tfd.Normal(loc=tf.zeros(sh), scale=tf.ones(sh))

            # KL between the transition posterior and the prior
            kl = tfp.distributions.kl_divergence(self.latent_dis, self.prior_dis)  # (None, None, 128)
            kl = tf.reduce_sum(kl, axis=-1)  # (None, None)

            # generative network
            latent = self.latent_dis.sample()  # (None, None, 128)
            rec_params = self.generative_model(latent)  # (None, None, 1024)
            assert rec_params.get_shape().as_list()[-1] == 1024 and len(rec_params.get_shape().as_list()) == 3
            rec_dis = normal_parse_params(rec_params, 0.1)  # reconstruction distribution

            rec_vec = rec_dis.sample()  # sample of the reconstruction, (None, None, 512)
            assert rec_vec.get_shape().as_list()[-1] == 512 and len(rec_vec.get_shape().as_list()) == 3

            # contrastive projection
            z_a = self.projection_head(rec_vec)  # (None, 128)
            z_pos = tf.stop_gradient(self.projection_head_momentum(self.next_features))  # (None, 128)
            assert z_a.get_shape().as_list()[-1] == 128 and len(z_a.get_shape().as_list()) == 2

            # contrastive (NCE) loss; the batch size 16*128=2048 (envsperbatch * nsteps) is hard-coded
            logits = self.contrastive_head([z_a, z_pos])  # (batch_size, batch_size)
            labels = tf.one_hot(tf.range(int(16 * 128)), depth=16 * 128)  # (batch_size, batch_size)
            rec_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)  # (batch_size,)
            rec_log_nce = -1. * rec_loss
            rec_log_nce = unflatten_first_dim(rec_log_nce, sh)  # shape=(None, None), here (128,128)

            # L2 (log-likelihood) term
            log_prob = rec_dis.log_prob(self.next_features)  # (None, None, 512)
            assert len(log_prob.get_shape().as_list()) == 3 and log_prob.get_shape().as_list()[-1] == 512
            rec_log_l2 = tf.reduce_sum(log_prob, axis=-1)
            rec_log = rec_log_nce * self.loss_nce_weight + rec_log_l2 * self.loss_l2_weight

            # total loss: weighted KL minus the weighted reconstruction terms
            self.loss = kl * self.loss_kl_weight - rec_log
            self.loss_info = {"DB_NCELoss": -1. * tf.reduce_mean(rec_log_nce),
                              "DB_NCELoss_w": -1. * tf.reduce_mean(rec_log_nce) * self.loss_nce_weight,
                              "DB_L2Loss": -1. * tf.reduce_mean(rec_log_l2),
                              "DB_L2Loss_w": -1. * tf.reduce_mean(rec_log_l2) * self.loss_l2_weight,
                              "DB_KLLoss": tf.reduce_mean(kl),
                              "DB_KLLoss_w": tf.reduce_mean(kl) * self.loss_kl_weight,
                              "DB_Loss": tf.reduce_mean(self.loss)}

            # intrinsic reward
            self.intrinsic_reward = self.intrinsic_contrastive()
            self.intrinsic_reward = tf.stop_gradient(self.intrinsic_reward)

            # updates for the momentum network
            self.init_updates, self.momentum_updates = self.get_momentum_updates(tau=self.tau)
        print("*** DB Total Components:", len(self.ib_get_vars(name='DB/')),
              ", Total Variables:", self.ib_get_params(self.ib_get_vars(name='DB/')), "\n")

    @staticmethod
    def ib_get_vars(name):
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)

    @staticmethod
    def ib_get_params(vars):
        return np.sum([np.prod(v.shape) for v in vars])

    def get_momentum_updates(self, tau):  # tau=0.001
        main_var = self.ib_get_vars(name='DB/DB_features/DB_main') + self.ib_get_vars(name="DB/DB_projection_main")
        momentum_var = self.ib_get_vars(name='DB/DB_features_1/DB_momentum') + self.ib_get_vars(name="DB/DB_projection_momentum")

        # print("\n\n momentum_var:", momentum_var)
        assert len(main_var) > 0 and len(main_var) == len(momentum_var)
        print("***In DB, feature & projection have", len(main_var), "components,", self.ib_get_params(main_var), "parameters.")

        soft_updates = []
        init_updates = []
        assert len(main_var) == len(momentum_var)
        for var, tvar in zip(main_var, momentum_var):
            init_updates.append(tf.assign(tvar, var))
            soft_updates.append(tf.assign(tvar, (1. - tau) * tvar + tau * var))
        assert len(init_updates) == len(main_var)
        assert len(soft_updates) == len(main_var)
        return tf.group(*init_updates), tf.group(*soft_updates)

    def get_features(self, x, momentum=False):  # x.shape=(None,None,84,84,4)
        x_has_timesteps = (x.get_shape().ndims == 5)  # True
        if x_has_timesteps:
            sh = tf.shape(x)
            x = flatten_two_dims(x)  # (None,84,84,4)

        if self.aug:
            print(x.get_shape().as_list())
            x = tf.image.random_crop(x, size=[128 * 16, 80, 80, 4])  # (None,80,80,4)
            x = tf.pad(x, [[0, 0], [4, 4], [4, 4], [0, 0]], "SYMMETRIC")  # (None,88,88,4)
            x = tf.image.random_crop(x, size=[128 * 16, 84, 84, 4])  # (None,84,84,4)

        with tf.variable_scope(self.scope + "_features"):
            x = (tf.to_float(x) - self.ob_mean) / self.ob_std
            if momentum:
                x = tf.stop_gradient(self.feature_conv_momentum(x))  # (None,512)
            else:
                x = self.feature_conv(x)  # (None,512)
        if x_has_timesteps:
            x = unflatten_first_dim(x, sh)  # (None,None,512)
        return x

    def intrinsic_contrastive(self):
        kl = tfp.distributions.kl_divergence(self.latent_dis, self.prior_dis)  # (None, None, 128)
        rew = tf.reduce_sum(kl, axis=-1)  # (None, None)
        return rew

    def calculate_db_reward(self, ob, last_ob, acs):
        n_chunks = 8
        n = ob.shape[0]
        chunk_size = n // n_chunks
        assert n % n_chunks == 0
        sli = lambda i: slice(i * chunk_size, (i + 1) * chunk_size)

        # compute the intrinsic reward in chunks to bound memory use
        rew = np.concatenate([getsess().run(self.intrinsic_reward,
                                            {self.obs: ob[sli(i)], self.last_ob: last_ob[sli(i)],
                                             self.ac: acs[sli(i)]}) for i in range(n_chunks)], 0)
        return rew
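The intrinsic reward defined above is the KL divergence between the transition posterior N(mu, sigma) and the standard normal prior, summed over latent dimensions. A small closed-form sketch with made-up numbers:

```
# KL(N(mu, sigma) || N(0, 1)) per latent dimension, in closed form.
import numpy as np

mu = np.array([0.5, -0.3])
sigma = np.array([1.2, 0.8])
kl = np.log(1.0 / sigma) + (sigma ** 2 + mu ** 2 - 1.0) / 2.0
reward = kl.sum()  # summed over latent dims, as in intrinsic_contrastive()
```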
mpi_utils.py (new file)
@@ -0,0 +1,33 @@
import numpy as np
import tensorflow as tf
from mpi4py import MPI


class MpiAdamOptimizer(tf.train.AdamOptimizer):
    """Adam optimizer that averages gradients across mpi processes."""

    def __init__(self, comm, **kwargs):
        self.comm = comm
        tf.train.AdamOptimizer.__init__(self, **kwargs)

    def compute_gradients(self, loss, var_list, **kwargs):
        grads_and_vars = tf.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs)
        grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
        flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0)
        shapes = [v.shape.as_list() for g, v in grads_and_vars]
        sizes = [int(np.prod(s)) for s in shapes]

        _task_id, num_tasks = self.comm.Get_rank(), self.comm.Get_size()
        buf = np.zeros(sum(sizes), np.float32)

        def _collect_grads(flat_grad):
            self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
            np.divide(buf, float(num_tasks), out=buf)
            return buf

        avg_flat_grad = tf.py_func(_collect_grads, [flat_grad], tf.float32)
        avg_flat_grad.set_shape(flat_grad.shape)
        avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
        avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
                              for g, (_, v) in zip(avg_grads, grads_and_vars)]

        return avg_grads_and_vars
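The _collect_grads callback above is the heart of the optimizer: every rank contributes its flattened gradient, Allreduce sums them, and dividing by the world size leaves all ranks applying identical averaged gradients. A standalone sketch of the same pattern (toy values; runnable under mpirun or as a single process):

```
# Toy version of the gradient averaging done inside MpiAdamOptimizer.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
local_grad = np.full(3, comm.Get_rank(), np.float32)  # pretend per-rank gradient
buf = np.zeros_like(local_grad)
comm.Allreduce(local_grad, buf, op=MPI.SUM)           # sum across ranks
avg_grad = buf / comm.Get_size()                      # identical on every rank
```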
para.json (new file)
@@ -0,0 +1,101 @@
{
  "standard": {
    "Alien": {"kl": 0.001, "nce": 0.1},
    "Asteroids": {"kl": 0.1, "nce": 0.01},
    "BankHeist": {"kl": 0.001, "nce": 0.1},
    "BeamRider": {"kl": 0.001, "nce": 0.01},
    "Boxing": {"kl": 0.1, "nce": 0.1},
    "Breakout": {"kl": 0.1, "nce": 0.1},
    "Centipede": {"kl": 0.1, "nce": 0.1},
    "ChopperCommand": {"kl": 0.1, "nce": 0.01},
    "CrazyClimber": {"kl": 0.1, "nce": 0.1},
    "Gopher": {"kl": 0.001, "nce": 0.01},
    "Gravitar": {"kl": 0.1, "nce": 0.01},
    "Kangaroo": {"kl": 0.1, "nce": 0.1},
    "KungFuMaster": {"kl": 0.1, "nce": 0.01},
    "MsPacman": {"kl": 0.1, "nce": 0.1},
    "Seaquest": {"kl": 0.1, "nce": 0.1},
    "Solaris": {"kl": 0.1, "nce": 0.1},
    "Tennis": {"kl": 0.1, "nce": 0.01},
    "TimePilot": {"kl": 0.1, "nce": 0.01},
    "UpNDown": {"kl": 0.1, "nce": 0.01},
    "VideoPinball": {"kl": 0.1, "nce": 0.01},
    "WizardOfWor": {"kl": 0.1, "nce": 0.1},
    "Zaxxon": {"kl": 0.1, "nce": 0.01},
    "other": {"kl": 0.1, "nce": 0.01}
  },

  "randomBox": {
    "Alien": {"kl": 0.001, "nce": 0.1},
    "Asteroids": {"kl": 0.1, "nce": 0.01},
    "BankHeist": {"kl": 0.001, "nce": 0.1},
    "BeamRider": {"kl": 0.001, "nce": 0.01},
    "Boxing": {"kl": 0.1, "nce": 0.1},
    "Breakout": {"kl": 0.1, "nce": 0.1},
    "Centipede": {"kl": 0.1, "nce": 0.1},
    "ChopperCommand": {"kl": 0.1, "nce": 0.01},
    "CrazyClimber": {"kl": 0.001, "nce": 0.01},
    "Gopher": {"kl": 0.001, "nce": 0.01},
    "Gravitar": {"kl": 0.1, "nce": 0.01},
    "Kangaroo": {"kl": 0.001, "nce": 0.01},
    "KungFuMaster": {"kl": 0.001, "nce": 0.01},
    "MsPacman": {"kl": 0.001, "nce": 0.01},
    "Seaquest": {"kl": 0.001, "nce": 0.01},
    "Solaris": {"kl": 0.1, "nce": 0.1},
    "Tennis": {"kl": 0.1, "nce": 0.01},
    "TimePilot": {"kl": 0.1, "nce": 0.01},
    "UpNDown": {"kl": 0.001, "nce": 0.01},
    "VideoPinball": {"kl": 0.1, "nce": 0.01},
    "WizardOfWor": {"kl": 0.1, "nce": 0.1},
    "Zaxxon": {"kl": 0.1, "nce": 0.01},
    "other": {"kl": 0.001, "nce": 0.01}
  },

  "stickyAtari": {
    "Alien": {"kl": 0.001, "nce": 0.1},
    "Asteroids": {"kl": 0.1, "nce": 0.01},
    "BankHeist": {"kl": 0.001, "nce": 0.1},
    "BeamRider": {"kl": 0.001, "nce": 0.01},
    "Boxing": {"kl": 0.1, "nce": 0.1},
    "Breakout": {"kl": 0.1, "nce": 0.1},
    "Centipede": {"kl": 0.1, "nce": 0.1},
    "ChopperCommand": {"kl": 0.1, "nce": 0.01},
    "CrazyClimber": {"kl": 0.1, "nce": 0.01},
    "Gopher": {"kl": 0.001, "nce": 0.01},
    "Gravitar": {"kl": 0.1, "nce": 0.01},
    "Kangaroo": {"kl": 0.001, "nce": 0.01},
    "KungFuMaster": {"kl": 0.001, "nce": 0.01},
    "MsPacman": {"kl": 0.1, "nce": 0.1},
    "Seaquest": {"kl": 0.1, "nce": 0.1},
    "Solaris": {"kl": 0.1, "nce": 0.1},
    "Tennis": {"kl": 0.1, "nce": 0.01},
    "TimePilot": {"kl": 0.1, "nce": 0.01},
    "UpNDown": {"kl": 0.1, "nce": 0.01},
    "VideoPinball": {"kl": 0.1, "nce": 0.01},
    "WizardOfWor": {"kl": 0.1, "nce": 0.1},
    "Zaxxon": {"kl": 0.1, "nce": 0.01},
    "other": {"kl": 0.1, "nce": 0.01}
  },

  "pixelNoise": {
    "Alien": {"kl": 0.001, "nce": 0.1},
    "Asteroids": {"kl": 0.1, "nce": 0.01},
    "BankHeist": {"kl": 0.001, "nce": 0.1},
    "BeamRider": {"kl": 0.001, "nce": 0.01},
    "Boxing": {"kl": 0.1, "nce": 0.1},
    "Breakout": {"kl": 0.1, "nce": 0.1},
    "Centipede": {"kl": 0.1, "nce": 0.1},
    "ChopperCommand": {"kl": 0.1, "nce": 0.01},
    "CrazyClimber": {"kl": 0.001, "nce": 0.01},
    "Gopher": {"kl": 0.001, "nce": 0.01},
    "Gravitar": {"kl": 0.1, "nce": 0.01},
    "Kangaroo": {"kl": 0.001, "nce": 0.01},
    "KungFuMaster": {"kl": 0.001, "nce": 0.01},
    "MsPacman": {"kl": 0.001, "nce": 0.01},
    "Seaquest": {"kl": 0.1, "nce": 0.1},
    "Solaris": {"kl": 0.1, "nce": 0.1},
    "Tennis": {"kl": 0.1, "nce": 0.01},
    "TimePilot": {"kl": 0.1, "nce": 0.01},
    "UpNDown": {"kl": 0.1, "nce": 0.01},
    "VideoPinball": {"kl": 0.1, "nce": 0.01},
    "WizardOfWor": {"kl": 0.1, "nce": 0.1},
    "Zaxxon": {"kl": 0.1, "nce": 0.01},
    "other": {"kl": 0.1, "nce": 0.01}
  }
}
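These weights are consumed by run.py, which falls back to the "other" entry for games that are not listed. A minimal sketch of that lookup:

```
# How run.py resolves per-game loss weights from para.json.
import json

with open("para.json") as f:
    d = json.load(f)
cfg = d["standard"].get("Breakout", d["standard"]["other"])
print(cfg)  # {'kl': 0.1, 'nce': 0.1}
```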
recorder.py (new file)
@@ -0,0 +1,63 @@
import os
import pickle

from baselines import logger
from mpi4py import MPI


class Recorder(object):
    def __init__(self, nenvs, nlumps):
        self.nenvs = nenvs
        self.nlumps = nlumps
        self.nenvs_per_lump = nenvs // nlumps
        self.acs = [[] for _ in range(nenvs)]
        self.int_rews = [[] for _ in range(nenvs)]
        self.ext_rews = [[] for _ in range(nenvs)]
        self.ep_infos = [{} for _ in range(nenvs)]
        self.filenames = [self.get_filename(i) for i in range(nenvs)]
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info("episode recordings saved to ", self.filenames[0])

    def record(self, timestep, lump, acs, infos, int_rew, ext_rew, news):
        for out_index in range(self.nenvs_per_lump):
            in_index = out_index + lump * self.nenvs_per_lump
            if timestep == 0:
                self.acs[in_index].append(acs[out_index])
            else:
                if self.is_first_episode_step(in_index):
                    try:
                        self.ep_infos[in_index]['random_state'] = infos[out_index]['random_state']
                    except KeyError:
                        pass

                self.int_rews[in_index].append(int_rew[out_index])
                self.ext_rews[in_index].append(ext_rew[out_index])

                if news[out_index]:
                    self.ep_infos[in_index]['ret'] = infos[out_index]['episode']['r']
                    self.ep_infos[in_index]['len'] = infos[out_index]['episode']['l']
                    self.dump_episode(in_index)

                self.acs[in_index].append(acs[out_index])

    def dump_episode(self, i):
        episode = {'acs': self.acs[i],
                   'int_rew': self.int_rews[i],
                   'info': self.ep_infos[i]}
        filename = self.filenames[i]
        if self.episode_worth_saving(i):
            with open(filename, 'ab') as f:
                pickle.dump(episode, f, protocol=-1)
        self.acs[i].clear()
        self.int_rews[i].clear()
        self.ext_rews[i].clear()
        self.ep_infos[i].clear()

    def episode_worth_saving(self, i):
        return i == 0 and MPI.COMM_WORLD.Get_rank() == 0

    def is_first_episode_step(self, i):
        return len(self.int_rews[i]) == 0

    def get_filename(self, i):
        filename = os.path.join(logger.get_dir(), 'env{}_{}.pk'.format(MPI.COMM_WORLD.Get_rank(), i))
        return filename
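dump_episode appends one pickle per finished episode to a single file. A sketch for reading recordings back (the filename follows get_filename's pattern; the actual path depends on the logger directory):

```
# Read back episodes dumped by Recorder: one pickle per episode, appended
# to the same file, so load in a loop until EOF.
import pickle

episodes = []
with open("env0_0.pk", "rb") as f:  # rank 0, env 0; path is an assumption
    while True:
        try:
            episodes.append(pickle.load(f))
        except EOFError:
            break
```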
rollouts.py (new file)
@@ -0,0 +1,177 @@
from collections import deque, defaultdict

import numpy as np
from mpi4py import MPI
# from utils import add_noise
from recorder import Recorder


class Rollout(object):
    def __init__(self, ob_space, ac_space, nenvs, nsteps_per_seg, nsegs_per_env, nlumps, envs, policy,
                 int_rew_coeff, ext_rew_coeff, record_rollouts, dynamic_bottleneck):  # , noisy_box, noisy_p):
        # int_rew_coeff=1.0, ext_rew_coeff=0.0, record_rollouts=True
        self.nenvs = nenvs  # 128
        self.nsteps_per_seg = nsteps_per_seg  # 128
        self.nsegs_per_env = nsegs_per_env  # 1
        self.nsteps = self.nsteps_per_seg * self.nsegs_per_env  # 128
        self.ob_space = ob_space  # Box(84,84,4)
        self.ac_space = ac_space  # Discrete(4)
        self.nlumps = nlumps  # 1
        self.lump_stride = nenvs // self.nlumps  # 128
        self.envs = envs
        self.policy = policy
        self.dynamic_bottleneck = dynamic_bottleneck  # Dynamic Bottleneck

        self.reward_fun = lambda ext_rew, int_rew: ext_rew_coeff * np.clip(ext_rew, -1., 1.) + int_rew_coeff * int_rew

        self.buf_vpreds = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_nlps = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_rews = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_ext_rews = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_acs = np.empty((nenvs, self.nsteps, *self.ac_space.shape), self.ac_space.dtype)
        self.buf_obs = np.empty((nenvs, self.nsteps, *self.ob_space.shape), self.ob_space.dtype)
        self.buf_obs_last = np.empty((nenvs, self.nsegs_per_env, *self.ob_space.shape), np.float32)

        self.buf_news = np.zeros((nenvs, self.nsteps), np.float32)
        self.buf_new_last = self.buf_news[:, 0, ...].copy()
        self.buf_vpred_last = self.buf_vpreds[:, 0, ...].copy()

        self.env_results = [None] * self.nlumps
        self.int_rew = np.zeros((nenvs,), np.float32)

        self.recorder = Recorder(nenvs=self.nenvs, nlumps=self.nlumps) if record_rollouts else None
        self.statlists = defaultdict(lambda: deque([], maxlen=100))
        self.stats = defaultdict(float)
        self.best_ext_ret = None
        self.all_visited_rooms = []
        self.all_scores = []

        self.step_count = 0

        # add bai. Noise box in observation
        # self.noisy_box = noisy_box
        # self.noisy_p = noisy_p

    def collect_rollout(self):
        self.ep_infos_new = []
        for t in range(self.nsteps):
            self.rollout_step()
        self.calculate_reward()
        self.update_info()

    def calculate_reward(self):  # the reward comes from the Dynamic Bottleneck
        db_rew = self.dynamic_bottleneck.calculate_db_reward(
            ob=self.buf_obs, last_ob=self.buf_obs_last, acs=self.buf_acs)
        self.buf_rews[:] = self.reward_fun(int_rew=db_rew, ext_rew=self.buf_ext_rews)

    def rollout_step(self):
        t = self.step_count % self.nsteps
        s = t % self.nsteps_per_seg
        for l in range(self.nlumps):  # nlumps=1
            obs, prevrews, news, infos = self.env_get(l)
            # if t > 0:
            #     prev_feat = self.prev_feat[l]
            #     prev_acs = self.prev_acs[l]
            for info in infos:
                epinfo = info.get('episode', {})
                mzepinfo = info.get('mz_episode', {})
                retroepinfo = info.get('retro_episode', {})
                epinfo.update(mzepinfo)
                epinfo.update(retroepinfo)
                if epinfo:
                    if "n_states_visited" in info:
                        epinfo["n_states_visited"] = info["n_states_visited"]
                        epinfo["states_visited"] = info["states_visited"]
                    self.ep_infos_new.append((self.step_count, epinfo))

            # slice(0,128), lump_stride=128
            sli = slice(l * self.lump_stride, (l + 1) * self.lump_stride)

            acs, vpreds, nlps = self.policy.get_ac_value_nlp(obs)
            self.env_step(l, acs)

            # self.prev_feat[l] = dyn_feat
            # self.prev_acs[l] = acs
            self.buf_obs[sli, t] = obs  # obs.shape=(128,84,84,4)
            self.buf_news[sli, t] = news  # shape=(128,), True/False
            self.buf_vpreds[sli, t] = vpreds  # shape=(128,)
            self.buf_nlps[sli, t] = nlps  # -log pi(a|s), shape=(128,)
            self.buf_acs[sli, t] = acs  # shape=(128,)
            if t > 0:
                self.buf_ext_rews[sli, t - 1] = prevrews  # prevrews.shape=(128,)
            if self.recorder is not None:
                self.recorder.record(timestep=self.step_count, lump=l, acs=acs, infos=infos, int_rew=self.int_rew[sli],
                                     ext_rew=prevrews, news=news)

        self.step_count += 1
        if s == self.nsteps_per_seg - 1:  # nsteps_per_seg=128
            for l in range(self.nlumps):  # nlumps=1
                sli = slice(l * self.lump_stride, (l + 1) * self.lump_stride)
                nextobs, ext_rews, nextnews, _ = self.env_get(l)
                self.buf_obs_last[sli, t // self.nsteps_per_seg] = nextobs
                if t == self.nsteps - 1:  # t=127
                    self.buf_new_last[sli] = nextnews
                    self.buf_ext_rews[sli, t] = ext_rews
                    _, self.buf_vpred_last[sli], _ = self.policy.get_ac_value_nlp(nextobs)

    def update_info(self):
        all_ep_infos = MPI.COMM_WORLD.allgather(self.ep_infos_new)
        all_ep_infos = sorted(sum(all_ep_infos, []), key=lambda x: x[0])
        if all_ep_infos:
            all_ep_infos = [i_[1] for i_ in all_ep_infos]  # remove the step_count
            keys_ = all_ep_infos[0].keys()
            all_ep_infos = {k: [i[k] for i in all_ep_infos] for k in keys_}
            # all_ep_infos: {'r': [0.0, 0.0, 0.0], 'l': [124, 125, 127], 't': [6.60745, 12.034875, 10.772788]}

            self.statlists['eprew'].extend(all_ep_infos['r'])
            self.stats['eprew_recent'] = np.mean(all_ep_infos['r'])
            self.statlists['eplen'].extend(all_ep_infos['l'])
            self.stats['epcount'] += len(all_ep_infos['l'])
            self.stats['tcount'] += sum(all_ep_infos['l'])
            if 'visited_rooms' in keys_:
                # Montezuma-specific logging.
                self.stats['visited_rooms'] = sorted(list(set.union(*all_ep_infos['visited_rooms'])))
                self.stats['pos_count'] = np.mean(all_ep_infos['pos_count'])
                self.all_visited_rooms.extend(self.stats['visited_rooms'])
                self.all_scores.extend(all_ep_infos["r"])
                self.all_scores = sorted(list(set(self.all_scores)))
                self.all_visited_rooms = sorted(list(set(self.all_visited_rooms)))
                if MPI.COMM_WORLD.Get_rank() == 0:
                    print("All visited rooms")
                    print(self.all_visited_rooms)
                    print("All scores")
                    print(self.all_scores)
            if 'levels' in keys_:
                # Retro logging
                temp = sorted(list(set.union(*all_ep_infos['levels'])))
                self.all_visited_rooms.extend(temp)
                self.all_visited_rooms = sorted(list(set(self.all_visited_rooms)))
                if MPI.COMM_WORLD.Get_rank() == 0:
                    print("All visited levels")
                    print(self.all_visited_rooms)

            current_max = np.max(all_ep_infos['r'])
        else:
            current_max = None
        self.ep_infos_new = []

        # track the best extrinsic return seen so far
        if current_max is not None:
            if (self.best_ext_ret is None) or (current_max > self.best_ext_ret):
                self.best_ext_ret = current_max
        self.current_max = current_max

    def env_step(self, l, acs):
        self.envs[l].step_async(acs)
        self.env_results[l] = None

    def env_get(self, l):
        if self.step_count == 0:
            ob = self.envs[l].reset()
            out = self.env_results[l] = (ob, None, np.ones(self.lump_stride, bool), {})
        else:
            if self.env_results[l] is None:
                out = self.env_results[l] = self.envs[l].step_wait()
            else:
                out = self.env_results[l]
        return out
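reward_fun above is the only place where extrinsic and intrinsic rewards are mixed. A numeric sketch with the default coefficients (ext_coeff=0.0, int_coeff=1.0), under which training is pure exploration:

```
# Illustrative values: the clipped extrinsic reward is multiplied by 0.0,
# so only the Dynamic Bottleneck intrinsic reward survives.
import numpy as np

ext_rew = np.array([2.0, -3.0])                      # clipped to [1, -1]
int_rew = np.array([0.5, 0.1])
rew = 0.0 * np.clip(ext_rew, -1., 1.) + 1.0 * int_rew
print(rew)  # [0.5 0.1]
```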
run.py (new file)
@@ -0,0 +1,287 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
try:
|
||||||
|
from OpenGL import GLU
|
||||||
|
except:
|
||||||
|
print("no OpenGL.GLU")
|
||||||
|
import functools
|
||||||
|
import os.path as osp
|
||||||
|
from functools import partial
|
||||||
|
import os
|
||||||
|
import gym
|
||||||
|
import tensorflow as tf
|
||||||
|
from baselines import logger
|
||||||
|
from baselines.bench import Monitor
|
||||||
|
from baselines.common.atari_wrappers import NoopResetEnv, FrameStack
|
||||||
|
from mpi4py import MPI
|
||||||
|
|
||||||
|
from dynamic_bottleneck import DynamicBottleneck
|
||||||
|
from cnn_policy import CnnPolicy
|
||||||
|
from cppo_agent import PpoOptimizer
|
||||||
|
from utils import random_agent_ob_mean_std
|
||||||
|
from wrappers import MontezumaInfoWrapper, make_mario_env, make_robo_pong, make_robo_hockey, \
|
||||||
|
make_multi_pong, AddRandomStateToInfo, MaxAndSkipEnv, ProcessFrame84, ExtraTimeLimit, StickyActionEnv
|
||||||
|
import datetime
|
||||||
|
from wrappers import PixelNoiseWrapper, RandomBoxNoiseWrapper
|
||||||
|
import json
|
||||||
|
|
||||||
|
getsess = tf.get_default_session
|
||||||
|
|
||||||
|
|
||||||
|
def start_experiment(**args):
|
||||||
|
make_env = partial(make_env_all_params, add_monitor=True, args=args)
|
||||||
|
|
||||||
|
trainer = Trainer(make_env=make_env,
|
||||||
|
num_timesteps=args['num_timesteps'], hps=args,
|
||||||
|
envs_per_process=args['envs_per_process'])
|
||||||
|
log, tf_sess, saver, logger_dir = get_experiment_environment(**args)
|
||||||
|
with log, tf_sess:
|
||||||
|
logdir = logger.get_dir()
|
||||||
|
print("results will be saved to ", logdir)
|
||||||
|
trainer.train(saver, logger_dir)
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer(object):
|
||||||
|
def __init__(self, make_env, hps, num_timesteps, envs_per_process):
|
||||||
|
self.make_env = make_env
|
||||||
|
self.hps = hps
|
||||||
|
self.envs_per_process = envs_per_process
|
||||||
|
self.num_timesteps = num_timesteps
|
||||||
|
self._set_env_vars()
|
||||||
|
|
||||||
|
self.policy = CnnPolicy(scope='pol',
|
||||||
|
ob_space=self.ob_space,
|
||||||
|
ac_space=self.ac_space,
|
||||||
|
hidsize=512,
|
||||||
|
feat_dim=512,
|
||||||
|
ob_mean=self.ob_mean,
|
||||||
|
ob_std=self.ob_std,
|
||||||
|
layernormalize=False,
|
||||||
|
nl=tf.nn.leaky_relu)
|
||||||
|
|
||||||
|
self.dynamic_bottleneck = DynamicBottleneck(
|
||||||
|
policy=self.policy, feat_dim=512, tau=hps['momentum_tau'], loss_kl_weight=hps['loss_kl_weight'],
|
||||||
|
loss_nce_weight=hps['loss_nce_weight'], loss_l2_weight=hps['loss_l2_weight'], aug=hps['aug'])
|
||||||
|
|
||||||
|
self.agent = PpoOptimizer(
|
||||||
|
scope='ppo',
|
||||||
|
ob_space=self.ob_space,
|
||||||
|
ac_space=self.ac_space,
|
||||||
|
stochpol=self.policy,
|
||||||
|
use_news=hps['use_news'],
|
||||||
|
gamma=hps['gamma'],
|
||||||
|
lam=hps["lambda"],
|
||||||
|
nepochs=hps['nepochs'],
|
||||||
|
nminibatches=hps['nminibatches'],
|
||||||
|
lr=hps['lr'],
|
||||||
|
cliprange=0.1,
|
||||||
|
nsteps_per_seg=hps['nsteps_per_seg'],
|
||||||
|
nsegs_per_env=hps['nsegs_per_env'],
|
||||||
|
ent_coef=hps['ent_coeff'],
|
||||||
|
normrew=hps['norm_rew'],
|
||||||
|
normadv=hps['norm_adv'],
|
||||||
|
ext_coeff=hps['ext_coeff'],
|
||||||
|
int_coeff=hps['int_coeff'],
|
||||||
|
dynamic_bottleneck=self.dynamic_bottleneck
|
||||||
|
)
|
||||||
|
|
||||||
|
self.agent.to_report['db'] = tf.reduce_mean(self.dynamic_bottleneck.loss)
|
||||||
|
self.agent.total_loss += self.agent.to_report['db']
|
||||||
|
|
||||||
|
self.agent.db_loss = tf.reduce_mean(self.dynamic_bottleneck.loss)
|
||||||
|
|
||||||
|
self.agent.to_report['feat_var'] = tf.reduce_mean(tf.nn.moments(self.dynamic_bottleneck.features, [0, 1])[1])
|
||||||
|
|
||||||
|
def _set_env_vars(self):
|
||||||
|
env = self.make_env(0, add_monitor=False)
|
||||||
|
# ob_space.shape=(84, 84, 4) ac_space.shape=Discrete(4)
|
||||||
|
self.ob_space, self.ac_space = env.observation_space, env.action_space
|
||||||
|
self.ob_mean, self.ob_std = random_agent_ob_mean_std(env)
|
||||||
|
del env
|
||||||
|
self.envs = [functools.partial(self.make_env, i) for i in range(self.envs_per_process)]
|
||||||
|
|
||||||
|
def train(self, saver, logger_dir):
|
||||||
|
self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamic_bottleneck=self.dynamic_bottleneck)
|
||||||
|
previous_saved_tcount = 0
|
||||||
|
|
||||||
|
# add bai. initialize IB parameters
|
||||||
|
print("***Init Momentum Network in Dynamic-Bottleneck.")
|
||||||
|
getsess().run(self.dynamic_bottleneck.init_updates)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
info = self.agent.step() #
|
||||||
|
if info['DB_loss_info']: # add bai. for debug
|
||||||
|
logger.logkvs(info['DB_loss_info'])
|
||||||
|
if info['update']:
|
||||||
|
logger.logkvs(info['update'])
|
||||||
|
logger.dumpkvs()
|
||||||
|
if self.hps["save_period"] and (int(self.agent.rollout.stats['tcount'] / self.hps["save_freq"]) > previous_saved_tcount):
|
||||||
|
previous_saved_tcount += 1
|
||||||
|
save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_"+str(previous_saved_tcount)+".ckpt"))
|
||||||
|
print("Periodically model saved in path:", save_path)
|
||||||
|
if self.agent.rollout.stats['tcount'] > self.num_timesteps:
|
||||||
|
save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_last.ckpt"))
|
||||||
|
print("Model saved in path:", save_path)
|
||||||
|
break
|
||||||
|
|
||||||
|
self.agent.stop_interaction()
|
||||||
|
|
||||||
|
|
||||||
|
def make_env_all_params(rank, add_monitor, args):
|
||||||
|
if args["env_kind"] == 'atari':
|
||||||
|
env = gym.make(args['env'])
|
||||||
|
assert 'NoFrameskip' in env.spec.id
|
||||||
|
if args["stickyAtari"]: #
|
||||||
|
env._max_episode_steps = args['max_episode_steps'] * 4
|
||||||
|
env = StickyActionEnv(env)
|
||||||
|
else:
|
||||||
|
env = NoopResetEnv(env, noop_max=args['noop_max'])
|
||||||
|
env = MaxAndSkipEnv(env, skip=4) #
|
||||||
|
if args['pixelNoise']: # add pixel noise
|
||||||
|
env = PixelNoiseWrapper(env)
|
||||||
|
if args['randomBoxNoise']:
|
||||||
|
env = RandomBoxNoiseWrapper(env)
|
||||||
|
env = ProcessFrame84(env, crop=False) #
|
||||||
|
env = FrameStack(env, 4) #
|
||||||
|
# env = ExtraTimeLimit(env, args['max_episode_steps'])
|
||||||
|
if not args["stickyAtari"]:
|
||||||
|
env = ExtraTimeLimit(env, args['max_episode_steps']) #
|
||||||
|
if 'Montezuma' in args['env']: #
|
||||||
|
env = MontezumaInfoWrapper(env)
|
||||||
|
env = AddRandomStateToInfo(env)
|
||||||
|
elif args["env_kind"] == 'mario': #
|
||||||
|
env = make_mario_env()
|
||||||
|
elif args["env_kind"] == "retro_multi": #
|
||||||
|
env = make_multi_pong()
|
||||||
|
elif args["env_kind"] == 'robopong':
|
||||||
|
if args["env"] == "pong":
|
||||||
|
env = make_robo_pong()
|
||||||
|
elif args["env"] == "hockey":
|
||||||
|
env = make_robo_hockey()
|
||||||
|
|
||||||
|
if add_monitor:
|
||||||
|
env = Monitor(env, osp.join(logger.get_dir(), '%.2i' % rank))
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
def get_experiment_environment(**args):
|
||||||
|
from utils import setup_mpi_gpus, setup_tensorflow_session
|
||||||
|
from baselines.common import set_global_seeds
|
||||||
|
from gym.utils.seeding import hash_seed
|
||||||
|
process_seed = args["seed"] + 1000 * MPI.COMM_WORLD.Get_rank()
|
||||||
|
process_seed = hash_seed(process_seed, max_bytes=4)
|
||||||
|
set_global_seeds(process_seed)
|
||||||
|
setup_mpi_gpus()
|
||||||
|
|
||||||
|
# log dir name
|
||||||
|
logger_dir = './logs/' + args["env"].replace("NoFrameskip-v4", "")
|
||||||
|
# logger_dir += "-KLloss-"+str(args["loss_kl_weight"])
|
||||||
|
# logger_dir += "-NCEloss-" + str(args["loss_nce_weight"])
|
||||||
|
# logger_dir += "-L2loss-" + str(args["loss_l2_weight"])
|
||||||
|
if args['pixelNoise'] is True:
|
||||||
|
logger_dir += "-pixelNoise"
|
||||||
|
if args['randomBoxNoise'] is True:
|
||||||
|
logger_dir += "-randomBoxNoise"
|
||||||
|
if args['stickyAtari'] is True:
|
||||||
|
logger_dir += "-stickyAtari"
|
||||||
|
if args["comments"] != "":
|
||||||
|
logger_dir += '-' + args["comments"]
|
||||||
|
logger_dir += datetime.datetime.now().strftime("-%m-%d-%H-%M-%S")
|
||||||
|
|
||||||
|
# write config
|
||||||
|
logger.configure(dir=logger_dir)
|
||||||
|
with open(os.path.join(logger_dir, 'parameters.txt'), 'w') as f:
|
||||||
|
f.write("\n".join([str(x[0]) + ": " + str(x[1]) for x in args.items()]))
|
||||||
|
|
||||||
|
logger_context = logger.scoped_configure(
|
||||||
|
dir=logger_dir,
|
||||||
|
format_strs=['stdout', 'log', 'csv'] if MPI.COMM_WORLD.Get_rank() == 0 else ['log'])
|
||||||
|
tf_context = setup_tensorflow_session()
|
||||||
|
|
||||||
|
# saver
|
||||||
|
saver = tf.train.Saver()
|
||||||
|
return logger_context, tf_context, saver, logger_dir
|
||||||
|
|
||||||
|
|
||||||
|
def add_environments_params(parser):
|
||||||
|
parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4', type=str)
|
||||||
|
parser.add_argument('--max-episode-steps', help='maximum number of timesteps for episode', default=4500, type=int)
|
||||||
|
parser.add_argument('--env_kind', type=str, default="atari")
|
||||||
|
parser.add_argument('--noop_max', type=int, default=30)
|
||||||
|
parser.add_argument('--stickyAtari', action='store_true', default=False)
|
||||||
|
parser.add_argument('--pixelNoise', action='store_true', default=False)
|
||||||
|
parser.add_argument('--randomBoxNoise', action='store_true', default=False)
|
||||||
|
|
||||||
|
|
||||||
|
def add_optimization_params(parser):
|
||||||
|
parser.add_argument('--lambda', type=float, default=0.95)
|
||||||
|
parser.add_argument('--gamma', type=float, default=0.99) # lambda, gamma 用于计算 GAE advantage
|
||||||
|
parser.add_argument('--nminibatches', type=int, default=8)
|
||||||
|
parser.add_argument('--norm_adv', type=int, default=1) #
|
||||||
|
parser.add_argument('--norm_rew', type=int, default=1) #
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-4) #
|
||||||
|
parser.add_argument('--ent_coeff', type=float, default=0.001) #
|
||||||
|
parser.add_argument('--nepochs', type=int, default=3) #
|
||||||
|
parser.add_argument('--num_timesteps', type=int, default=int(1e8))
|
||||||
|
parser.add_argument('--save_period', action='store_true', default=False) # 1e7
|
||||||
|
parser.add_argument('--save_freq', type=int, default=int(1e7)) # 1e7
|
||||||
|
# Parameters of Dynamic-Bottleneck
|
||||||
|
parser.add_argument('--loss_kl_weight', type=float, default=0.1) # KL loss weight
|
||||||
|
parser.add_argument('--loss_l2_weight', type=float, default=0.1) # l2 loss weight
|
||||||
|
parser.add_argument('--loss_nce_weight', type=float, default=0.01) # nce loss weight
|
||||||
|
parser.add_argument('--momentum_tau', type=float, default=0.001) # momentum tau
|
||||||
|
parser.add_argument('--aug', action='store_true', default=False) # data augmentation (bottleneck)
|
||||||
|
parser.add_argument('--comments', type=str, default="")
|
||||||
|
|
||||||
|
|
||||||
|
def add_rollout_params(parser):
|
||||||
|
parser.add_argument('--nsteps_per_seg', type=int, default=128)
|
||||||
|
parser.add_argument('--nsegs_per_env', type=int, default=1)
|
||||||
|
parser.add_argument('--envs_per_process', type=int, default=128)
|
||||||
|
parser.add_argument('--nlumps', type=int, default=1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
add_environments_params(parser)
|
||||||
|
add_optimization_params(parser)
|
||||||
|
add_rollout_params(parser)
|
||||||
|
|
||||||
|
parser.add_argument('--exp_name', type=str, default='')
|
||||||
|
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
|
||||||
|
parser.add_argument('--dyn_from_pixels', type=int, default=0)
|
||||||
|
parser.add_argument('--use_news', type=int, default=0)
|
||||||
|
parser.add_argument('--ext_coeff', type=float, default=0.)
|
||||||
|
parser.add_argument('--int_coeff', type=float, default=1.)
|
||||||
|
parser.add_argument('--layernorm', type=int, default=0)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# load paramets
|
||||||
|
with open("para.json") as f:
|
||||||
|
d = json.load(f)
|
||||||
|
env_name_para = args.env.replace("NoFrameskip-v4", "")
|
||||||
|
if env_name_para not in list(d["standard"].keys()):
|
||||||
|
env_name_para = "other"
|
||||||
|
|
||||||
|
if args.pixelNoise is True:
|
||||||
|
print("pixel noise")
|
||||||
|
args.loss_kl_weight = d["pixelNoise"][env_name_para]["kl"]
|
||||||
|
args.loss_nce_weight = d["pixelNoise"][env_name_para]["nce"]
|
||||||
|
elif args.randomBoxNoise is True:
|
||||||
|
print("random box noise")
|
||||||
|
args.loss_kl_weight = d["randomBox"][env_name_para]["kl"]
|
||||||
|
args.loss_nce_weight = d["randomBox"][env_name_para]["nce"]
|
||||||
|
elif args.stickyAtari is True:
|
||||||
|
print("sticky noise")
|
||||||
|
args.loss_kl_weight = d["stickyAtari"][env_name_para]["kl"]
|
||||||
|
args.loss_nce_weight = d["stickyAtari"][env_name_para]["nce"]
|
||||||
|
else:
|
||||||
|
print("standard atari")
|
||||||
|
args.loss_kl_weight = d["standard"][env_name_para]["kl"]
|
||||||
|
args.loss_nce_weight = d["standard"][env_name_para]["nce"]
|
||||||
|
|
||||||
|
print("env_name:", env_name_para, "kl:", args.loss_kl_weight, ", nce:", args.loss_nce_weight)
|
||||||
|
start_experiment(**args.__dict__)
|
||||||
|
|
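# Illustrative sketch (an assumption inferred from the lookups above, not shipped in
# this diff): para.json appears to map a noise setting and a game name to per-game
# loss weights, roughly like:
#
#   {
#     "standard":    {"Breakout": {"kl": 0.1, "nce": 0.01}, "other": {"kl": 0.1, "nce": 0.01}},
#     "pixelNoise":  {"Breakout": {"kl": 0.1, "nce": 0.01}, "other": {"kl": 0.1, "nce": 0.01}},
#     "randomBox":   {"Breakout": {"kl": 0.1, "nce": 0.01}, "other": {"kl": 0.1, "nce": 0.01}},
#     "stickyAtari": {"Breakout": {"kl": 0.1, "nce": 0.01}, "other": {"kl": 0.1, "nce": 0.01}}
#   }
#
# The numeric values here are placeholders; the repository's para.json holds the tuned weights.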
430
utils.py
Normal file
@ -0,0 +1,430 @@
import multiprocessing
import os
import platform
from functools import partial

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from baselines.common.tf_util import normc_initializer
from mpi4py import MPI

tfd = tfp.distributions

layers = tf.keras.layers


def bcast_tf_vars_from_root(sess, vars):
    """
    Send the root node's parameters to every worker.

    Arguments:
      sess: the TensorFlow session.
      vars: all parameter variables, including the optimizer's.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    for var in vars:
        if rank == 0:
            MPI.COMM_WORLD.bcast(sess.run(var))
        else:
            sess.run(tf.assign(var, MPI.COMM_WORLD.bcast(None)))


def get_mean_and_std(array):
    comm = MPI.COMM_WORLD
    task_id, num_tasks = comm.Get_rank(), comm.Get_size()
    local_mean = np.array(np.mean(array))
    sum_of_means = np.zeros((), dtype=np.float32)
    comm.Allreduce(local_mean, sum_of_means, op=MPI.SUM)
    mean = sum_of_means / num_tasks

    n_array = array - mean
    sqs = n_array ** 2
    local_mean = np.array(np.mean(sqs))
    sum_of_means = np.zeros((), dtype=np.float32)
    comm.Allreduce(local_mean, sum_of_means, op=MPI.SUM)
    var = sum_of_means / num_tasks
    std = var ** 0.5
    return mean, std
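# Usage sketch (assumes MPI has been launched via mpirun/mpiexec): every rank passes
# its local array and receives the mean/std of the pooled data, e.g.
#
#   local_rewards = np.random.randn(1000)          # this worker's samples
#   mean, std = get_mean_and_std(local_rewards)    # identical on every rank
#
# Note the variance is computed around the global mean, so the result matches pooling
# all workers' samples exactly only when each rank contributes equally many samples.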


def guess_available_gpus(n_gpus=None):
    if n_gpus is not None:
        return list(range(n_gpus))
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        cuda_visible_devices = os.environ['CUDA_VISIBLE_DEVICES']
        cuda_visible_devices = cuda_visible_devices.split(',')
        return [int(n) for n in cuda_visible_devices]
    nvidia_dir = '/proc/driver/nvidia/gpus/'
    if os.path.exists(nvidia_dir):
        n_gpus = len(os.listdir(nvidia_dir))
        return list(range(n_gpus))
    raise Exception("Couldn't guess the available gpus on this machine")


def setup_mpi_gpus():
    """
    Set CUDA_VISIBLE_DEVICES using MPI.
    """
    available_gpus = guess_available_gpus()

    node_id = platform.node()
    nodes_ordered_by_rank = MPI.COMM_WORLD.allgather(node_id)
    processes_outranked_on_this_node = [n for n in nodes_ordered_by_rank[:MPI.COMM_WORLD.Get_rank()] if n == node_id]
    local_rank = len(processes_outranked_on_this_node)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(available_gpus[local_rank])


def guess_available_cpus():
    return int(multiprocessing.cpu_count())


def setup_tensorflow_session():
    num_cpu = guess_available_cpus()

    tf_config = tf.ConfigProto(
        inter_op_parallelism_threads=num_cpu,
        intra_op_parallelism_threads=num_cpu
    )
    tf_config.gpu_options.allow_growth = True
    return tf.Session(config=tf_config)


def random_agent_ob_mean_std(env, nsteps=10000):
    ob = np.asarray(env.reset())
    if MPI.COMM_WORLD.Get_rank() == 0:
        obs = [ob]
        for _ in range(nsteps):
            ac = env.action_space.sample()
            ob, _, done, _ = env.step(ac)
            if done:
                ob = env.reset()
            obs.append(np.asarray(ob))
        mean = np.mean(obs, 0).astype(np.float32)
        std = np.std(obs, 0).mean().astype(np.float32)
    else:
        mean = np.empty(shape=ob.shape, dtype=np.float32)
        std = np.empty(shape=(), dtype=np.float32)
    MPI.COMM_WORLD.Bcast(mean, root=0)
    MPI.COMM_WORLD.Bcast(std, root=0)
    return mean, std


def layernorm(x):
    m, v = tf.nn.moments(x, -1, keep_dims=True)
    return (x - m) / (tf.sqrt(v) + 1e-8)


getsess = tf.get_default_session

fc = partial(tf.layers.dense, kernel_initializer=normc_initializer(1.))
activ = tf.nn.relu


def flatten_two_dims(x):
    return tf.reshape(x, [-1] + x.get_shape().as_list()[2:])


def unflatten_first_dim(x, sh):
    return tf.reshape(x, [sh[0], sh[1]] + x.get_shape().as_list()[1:])
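# Shape round-trip sketch: these two helpers move between per-environment and
# flattened batches, e.g. for a rollout of 128 envs x 16 steps with 512-d features:
#
#   x = tf.zeros([128, 16, 512])
#   flat = flatten_two_dims(x)                      # -> (128*16, 512)
#   back = unflatten_first_dim(flat, tf.shape(x))   # -> (128, 16, 512)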


def add_pos_bias(x):
    with tf.variable_scope(name_or_scope=None, default_name="pos_bias"):
        b = tf.get_variable(name="pos_bias", shape=[1] + x.get_shape().as_list()[1:], dtype=tf.float32,
                            initializer=tf.zeros_initializer())
        return x + b


def small_convnet(x, nl, feat_dim, last_nl, layernormalize, batchnorm=False):
    # Typical call: nl=tf.nn.leaky_relu, feat_dim=512, last_nl=None, layernormalize=False, batchnorm=False.
    bn = tf.layers.batch_normalization if batchnorm else lambda x: x
    x = bn(tf.layers.conv2d(x, filters=32, kernel_size=8, strides=(4, 4), activation=nl))
    x = bn(tf.layers.conv2d(x, filters=64, kernel_size=4, strides=(2, 2), activation=nl))
    x = bn(tf.layers.conv2d(x, filters=64, kernel_size=3, strides=(1, 1), activation=nl))
    x = tf.reshape(x, (-1, np.prod(x.get_shape().as_list()[1:])))
    x = bn(fc(x, units=feat_dim, activation=None))
    if last_nl is not None:
        x = last_nl(x)
    if layernormalize:
        x = layernorm(x)
    return x


# newly added
class SmallConv(tf.keras.Model):
    def __init__(self, feat_dim, name=None):
        super(SmallConv, self).__init__(name=name)
        self.conv1 = layers.Conv2D(filters=32, kernel_size=8, strides=(4, 4), activation=tf.nn.leaky_relu)
        self.conv2 = layers.Conv2D(filters=64, kernel_size=4, strides=(2, 2), activation=tf.nn.leaky_relu)
        self.conv3 = layers.Conv2D(filters=64, kernel_size=3, strides=(1, 1), activation=tf.nn.leaky_relu)
        self.fc = layers.Dense(units=feat_dim, activation=None)

    def call(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = tf.reshape(x, (-1, np.prod(x.get_shape().as_list()[1:])))
        x = self.fc(x)
        return x


# newly added
class ResBlock(tf.keras.Model):
    def __init__(self, hidsize):
        super(ResBlock, self).__init__()
        self.hidsize = hidsize
        self.dense1 = layers.Dense(hidsize, activation=tf.nn.leaky_relu)
        self.dense2 = layers.Dense(hidsize, activation=None)

    def call(self, xs):
        x, a = xs
        res = self.dense1(tf.concat([x, a], axis=-1))
        res = self.dense2(tf.concat([res, a], axis=-1))
        assert x.get_shape().as_list()[-1] == self.hidsize and res.get_shape().as_list()[-1] == self.hidsize
        return x + res


# newly added
class TransitionNetwork(tf.keras.Model):
    def __init__(self, hidsize=256, name=None):
        super(TransitionNetwork, self).__init__(name=name)
        self.hidsize = hidsize
        self.dense1 = layers.Dense(hidsize, activation=tf.nn.leaky_relu)
        self.residual_block1 = ResBlock(hidsize)
        self.residual_block2 = ResBlock(hidsize)
        self.dense2 = layers.Dense(hidsize, activation=None)

    def call(self, xs):
        s, a = xs
        sh = tf.shape(a)  # sh = (None, None, n_actions)
        assert len(s.get_shape().as_list()) == 3 and s.get_shape().as_list()[-1] in [512, 256]
        assert len(a.get_shape().as_list()) == 3

        x = flatten_two_dims(s)  # shape = (None, 512)
        a = flatten_two_dims(a)  # shape = (None, n_actions)

        x = self.dense1(tf.concat([x, a], axis=-1))  # (None, 256)
        x = self.residual_block1([x, a])  # (None, 256)
        x = self.residual_block2([x, a])  # (None, 256)
        x = self.dense2(tf.concat([x, a], axis=-1))  # (None, 256)
        x = unflatten_first_dim(x, sh)  # shape = (None, None, 256)
        return x
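# Usage sketch (shapes follow the asserts above; the sizes are illustrative): the
# transition model consumes a batch of feature sequences s and one-hot action
# sequences a and returns the hidden next-step representation:
#
#   s = tf.zeros([128, 16, 512])                 # (n_envs, n_steps, feat_dim)
#   a = tf.zeros([128, 16, 4])                   # (n_envs, n_steps, n_actions), one-hot
#   z = TransitionNetwork(hidsize=256)([s, a])   # -> (128, 16, 256)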


class GenerativeNetworkGaussianFix(tf.keras.Model):
    def __init__(self, hidsize=256, outsize=512, name=None):
        super(GenerativeNetworkGaussianFix, self).__init__(name=name)
        self.outsize = outsize
        self.dense1 = layers.Dense(hidsize, activation=tf.nn.leaky_relu)
        self.dense2 = layers.Dense(outsize, activation=tf.nn.leaky_relu)
        self.var_single = tf.Variable(1.0, trainable=True)  # a single shared, learned variance parameter

        self.residual_block1 = tf.keras.Sequential([
            layers.Dense(hidsize, activation=tf.nn.leaky_relu),  # 256
            layers.Dense(hidsize, activation=None)
        ])
        self.residual_block2 = tf.keras.Sequential([
            layers.Dense(hidsize, activation=tf.nn.leaky_relu),  # 256
            layers.Dense(hidsize, activation=None)
        ])
        self.residual_block3 = tf.keras.Sequential([
            layers.Dense(outsize, activation=tf.nn.leaky_relu),  # 512
            layers.Dense(outsize, activation=None)
        ])

    def call(self, z):
        sh = tf.shape(z)  # z, sh = (None, None, 128)
        assert z.get_shape().as_list()[-1] == 128 and len(z.get_shape().as_list()) == 3
        z = flatten_two_dims(z)  # shape = (None, 128)

        x = self.dense1(z)  # (None, 256)
        x = x + self.residual_block1(x)  # (None, 256)
        x = x + self.residual_block2(x)  # (None, 256)

        # variance: tile the single scalar over the flattened batch.
        # Note the 16*128 batch size is hard-coded here, so this layer assumes a fixed rollout shape.
        var_tile = tf.tile(tf.expand_dims(tf.expand_dims(self.var_single, axis=0), axis=0), [16*128, self.outsize])

        # mean
        x = self.dense2(x)  # (None, 512)
        x = x + self.residual_block3(x)  # (None, 512), mean

        # concatenate and return
        x = tf.concat([x, var_tile], axis=-1)  # (None, 1024)
        x = unflatten_first_dim(x, sh)  # shape = (None, None, 1024)
        return x


class GenerativeNetworkGaussian(tf.keras.Model):
    def __init__(self, hidsize=256, outsize=512, name=None):
        super(GenerativeNetworkGaussian, self).__init__(name=name)
        self.dense1 = layers.Dense(hidsize, activation=tf.nn.leaky_relu)
        self.dense2 = layers.Dense(outsize, activation=tf.nn.leaky_relu)
        self.dense3 = layers.Dense(outsize*2, activation=tf.nn.leaky_relu)

        self.residual_block1 = tf.keras.Sequential([
            layers.Dense(hidsize, activation=tf.nn.leaky_relu),  # 256
            layers.Dense(hidsize, activation=None)
        ])
        self.residual_block2 = tf.keras.Sequential([
            layers.Dense(hidsize, activation=tf.nn.leaky_relu),  # 256
            layers.Dense(hidsize, activation=None)
        ])
        self.residual_block3 = tf.keras.Sequential([
            layers.Dense(outsize, activation=tf.nn.leaky_relu),  # 512
            layers.Dense(outsize, activation=None)
        ])

    def call(self, z):
        sh = tf.shape(z)  # z, sh = (None, None, 128)
        assert z.get_shape().as_list()[-1] == 128 and len(z.get_shape().as_list()) == 3
        z = flatten_two_dims(z)  # shape = (None, 128)

        x = self.dense1(z)  # (None, 256)
        x = x + self.residual_block1(x)  # (None, 256)
        x = x + self.residual_block2(x)  # (None, 256)
        x = self.dense2(x)  # (None, 512)
        x = x + self.residual_block3(x)  # (None, 512)
        x = self.dense3(x)  # (None, 1024)
        x = unflatten_first_dim(x, sh)  # shape = (None, None, 1024)
        return x


class ProjectionHead(tf.keras.Model):
    def __init__(self, name=None):
        super(ProjectionHead, self).__init__(name=name)
        self.dense1 = layers.Dense(256, activation=None)
        self.dense2 = layers.Dense(128, activation=None)
        self.ln1 = layers.LayerNormalization()
        self.ln2 = layers.LayerNormalization()

    def call(self, x, ln=False):
        assert x.get_shape().as_list()[-1] == 512 and len(x.get_shape().as_list()) == 3
        x = flatten_two_dims(x)  # shape = (None, 512)
        x = self.dense1(x)  # shape = (None, 256)
        x = self.ln1(x)  # layer norm
        x = tf.nn.relu(x)  # relu
        x = self.dense2(x)  # shape = (None, 128)
        x = self.ln2(x)
        return x


class ContrastiveHead(tf.keras.Model):
    def __init__(self, temperature, z_dim=128, name=None):
        super(ContrastiveHead, self).__init__(name=name)
        self.W = tf.Variable(tf.random.uniform((z_dim, z_dim)), name='W_Contras')
        self.temperature = temperature

    def call(self, z_a_pos):
        z_a, z_pos = z_a_pos
        Wz = tf.linalg.matmul(self.W, z_pos, transpose_b=True)  # (z_dim, B)
        logits = tf.linalg.matmul(z_a, Wz)  # (B, B)
        logits = logits - tf.reduce_max(logits, 1)[:, None]  # stabilize: subtract the row-wise max
        logits = logits * self.temperature
        return logits
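# Sketch of how these logits can feed an InfoNCE-style loss (an assumption based on
# the bilinear-logits form used in CURL-like methods; the training code is not part
# of this file): row i of `logits` scores anchor i against all B candidates, with the
# positive on the diagonal, so the loss is a softmax cross-entropy with labels 0..B-1:
#
#   labels = tf.range(tf.shape(logits)[0])
#   loss = tf.reduce_mean(
#       tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))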


def rec_log_prob(rec_params, s_next, min_sigma=1e-2):
    # rec_params.shape = (None, None, 1024)
    distr = normal_parse_params(rec_params, min_sigma)
    log_prob = distr.log_prob(s_next)  # (None, None, 512)
    assert len(log_prob.get_shape().as_list()) == 3 and log_prob.get_shape().as_list()[-1] == 512
    return tf.reduce_sum(log_prob, axis=-1)


def normal_parse_params(params, min_sigma=0.0):
    d = params.shape[-1]  # channel dimension
    mu = params[..., :d // 2]  # first half: mean
    sigma_params = params[..., d // 2:]  # second half: pre-activation scale
    sigma = tf.math.softplus(sigma_params)
    sigma = tf.clip_by_value(t=sigma, clip_value_min=min_sigma, clip_value_max=1e5)

    distr = tfd.Normal(loc=mu, scale=sigma)
    return distr
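# Example: a (..., 1024) tensor parameterizes a 512-d diagonal Gaussian; the first
# 512 channels are the mean and the last 512 are softplus-ed into the scale:
#
#   params = tf.zeros([32, 8, 1024])
#   distr = normal_parse_params(params, min_sigma=1e-2)
#   distr.log_prob(tf.zeros([32, 8, 512]))   # -> (32, 8, 512); summed over the last axis in rec_log_prob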


def tile_images(array, n_cols=None, max_images=None, div=1):
    if max_images is not None:
        array = array[:max_images]
    if len(array.shape) == 4 and array.shape[3] == 1:
        array = array[:, :, :, 0]
    assert len(array.shape) in [3, 4], "wrong number of dimensions - shape {}".format(array.shape)
    if len(array.shape) == 4:
        assert array.shape[3] == 3, "wrong number of channels - shape {}".format(array.shape)
    if n_cols is None:
        n_cols = max(int(np.sqrt(array.shape[0])) // div * div, div)
    n_rows = int(np.ceil(float(array.shape[0]) / n_cols))

    def cell(i, j):
        ind = i * n_cols + j
        return array[ind] if ind < array.shape[0] else np.zeros(array[0].shape)

    def row(i):
        return np.concatenate([cell(i, j) for j in range(n_cols)], axis=1)

    return np.concatenate([row(i) for i in range(n_rows)], axis=0)
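# Example: 16 RGB frames of 84x84 tile into a 4x4 grid image:
#
#   grid = tile_images(np.zeros((16, 84, 84, 3)))   # -> (336, 336, 3)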


import distutils.spawn
import subprocess


def save_np_as_mp4(frames, filename, frames_per_sec=30):
    print(filename)
    if distutils.spawn.find_executable('avconv') is not None:
        backend = 'avconv'
    elif distutils.spawn.find_executable('ffmpeg') is not None:
        backend = 'ffmpeg'
    else:
        raise NotImplementedError(
            """Found neither the ffmpeg nor avconv executables. On OS X, you can install ffmpeg via `brew install ffmpeg`. On most Ubuntu variants, `sudo apt-get install ffmpeg` should do it. On Ubuntu 14.04, however, you'll need to install avconv with `sudo apt-get install libav-tools`.""")

    h, w = frames[0].shape[:2]
    output_path = filename
    cmdline = (backend,
               '-nostats',
               '-loglevel', 'error',  # suppress warnings
               '-y',
               '-r', '%d' % frames_per_sec,

               # input
               '-f', 'rawvideo',
               '-s:v', '{}x{}'.format(w, h),
               '-pix_fmt', 'rgb24',
               '-i', '-',  # this used to be /dev/stdin, which is not Windows-friendly

               # output
               '-vcodec', 'libx264',
               '-pix_fmt', 'yuv420p',
               output_path)

    print('saving ', output_path)
    if hasattr(os, 'setsid'):  # setsid not present on Windows
        process = subprocess.Popen(cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid)
    else:
        process = subprocess.Popen(cmdline, stdin=subprocess.PIPE)
    process.stdin.write(np.array(frames).tobytes())
    process.stdin.close()
    ret = process.wait()
    if ret != 0:
        print("VideoRecorder encoder exited with status {}".format(ret))


class ExponentialSchedule(object):
    def __init__(self, start_value, decay_factor, end_value, outside_value=None):
        """Exponential schedule:
            y = max(start_value * (1.0 - decay_factor) ^ (t / 1e5), end_value)
        """
        assert 0.0 <= decay_factor <= 1.0
        self.start_value = start_value
        self.decay_factor = decay_factor
        self.end_value = end_value

    def value(self, t):
        v = self.start_value * np.power(1.0 - self.decay_factor, t/int(1e5))
        return np.maximum(v, self.end_value)
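# Usage sketch: with the 1e5-step scaling in value(), a decay_factor of 0.5 halves
# the value every 100k steps until end_value is reached:
#
#   sched = ExponentialSchedule(start_value=1.0, decay_factor=0.5, end_value=0.05)
#   sched.value(0)        # -> 1.0
#   sched.value(200000)   # -> 0.25
#   sched.value(1000000)  # -> 0.05 (clipped at end_value)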
222
vec_env.py
Normal file
@ -0,0 +1,222 @@
"""
|
||||||
|
An interface for asynchronous vectorized environments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import ctypes
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from multiprocessing import Pipe, Array, Process
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
from baselines import logger
|
||||||
|
|
||||||
|
_NP_TO_CT = {np.float32: ctypes.c_float,
|
||||||
|
np.int32: ctypes.c_int32,
|
||||||
|
np.int8: ctypes.c_int8,
|
||||||
|
np.uint8: ctypes.c_char,
|
||||||
|
np.bool: ctypes.c_bool}
|
||||||
|
_CT_TO_NP = {v: k for k, v in _NP_TO_CT.items()}
|
||||||
|
|
||||||
|
|
||||||
|
class CloudpickleWrapper(object):
|
||||||
|
"""
|
||||||
|
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, x):
|
||||||
|
self.x = x
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
import cloudpickle
|
||||||
|
return cloudpickle.dumps(self.x)
|
||||||
|
|
||||||
|
def __setstate__(self, ob):
|
||||||
|
import pickle
|
||||||
|
self.x = pickle.loads(ob)
|
||||||
|
|
||||||
|
|
||||||
|
class VecEnv(ABC):
|
||||||
|
"""
|
||||||
|
An abstract asynchronous, vectorized environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_envs, observation_space, action_space):
|
||||||
|
self.num_envs = num_envs
|
||||||
|
self.observation_space = observation_space
|
||||||
|
self.action_space = action_space
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Reset all the environments and return an array of
|
||||||
|
observations, or a tuple of observation arrays.
|
||||||
|
|
||||||
|
If step_async is still doing work, that work will
|
||||||
|
be cancelled and step_wait() should not be called
|
||||||
|
until step_async() is invoked again.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def step_async(self, actions):
|
||||||
|
"""
|
||||||
|
Tell all the environments to start taking a step
|
||||||
|
with the given actions.
|
||||||
|
Call step_wait() to get the results of the step.
|
||||||
|
|
||||||
|
You should not call this if a step_async run is
|
||||||
|
already pending.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def step_wait(self):
|
||||||
|
"""
|
||||||
|
Wait for the step taken with step_async().
|
||||||
|
|
||||||
|
Returns (obs, rews, dones, infos):
|
||||||
|
- obs: an array of observations, or a tuple of
|
||||||
|
arrays of observations.
|
||||||
|
- rews: an array of rewards
|
||||||
|
- dones: an array of "episode done" booleans
|
||||||
|
- infos: a sequence of info objects
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def close(self):
|
||||||
|
"""
|
||||||
|
Clean up the environments' resources.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def step(self, actions):
|
||||||
|
self.step_async(actions)
|
||||||
|
return self.step_wait()
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
logger.warn('Render not defined for %s' % self)
|
||||||
|
|
||||||
|
|
||||||
|
class ShmemVecEnv(VecEnv):
|
||||||
|
"""
|
||||||
|
An AsyncEnv that uses multiprocessing to run multiple
|
||||||
|
environments in parallel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env_fns, spaces=None):
|
||||||
|
"""
|
||||||
|
If you don't specify observation_space, we'll have to create a dummy
|
||||||
|
environment to get it.
|
||||||
|
"""
|
||||||
|
if spaces:
|
||||||
|
observation_space, action_space = spaces
|
||||||
|
else:
|
||||||
|
logger.log('Creating dummy env object to get spaces')
|
||||||
|
with logger.scoped_configure(format_strs=[]):
|
||||||
|
dummy = env_fns[0]()
|
||||||
|
observation_space, action_space = dummy.observation_space, dummy.action_space
|
||||||
|
dummy.close()
|
||||||
|
del dummy
|
||||||
|
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
|
||||||
|
|
||||||
|
obs_spaces = observation_space.spaces if isinstance(self.observation_space, gym.spaces.Tuple) else (
|
||||||
|
self.observation_space,)
|
||||||
|
self.obs_bufs = [tuple(Array(_NP_TO_CT[s.dtype.type], int(np.prod(s.shape))) for s in obs_spaces) for _ in
|
||||||
|
env_fns]
|
||||||
|
self.obs_shapes = [s.shape for s in obs_spaces]
|
||||||
|
self.obs_dtypes = [s.dtype for s in obs_spaces]
|
||||||
|
|
||||||
|
self.parent_pipes = []
|
||||||
|
self.procs = []
|
||||||
|
for env_fn, obs_buf in zip(env_fns, self.obs_bufs):
|
||||||
|
wrapped_fn = CloudpickleWrapper(env_fn)
|
||||||
|
parent_pipe, child_pipe = Pipe()
|
||||||
|
proc = Process(target=_subproc_worker,
|
||||||
|
args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes))
|
||||||
|
proc.daemon = True
|
||||||
|
self.procs.append(proc)
|
||||||
|
self.parent_pipes.append(parent_pipe)
|
||||||
|
proc.start()
|
||||||
|
child_pipe.close()
|
||||||
|
self.waiting_step = False
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
if self.waiting_step:
|
||||||
|
logger.warn('Called reset() while waiting for the step to complete')
|
||||||
|
self.step_wait()
|
||||||
|
for pipe in self.parent_pipes:
|
||||||
|
pipe.send(('reset', None))
|
||||||
|
return self._decode_obses([pipe.recv() for pipe in self.parent_pipes])
|
||||||
|
|
||||||
|
def step_async(self, actions):
|
||||||
|
assert len(actions) == len(self.parent_pipes)
|
||||||
|
for pipe, act in zip(self.parent_pipes, actions):
|
||||||
|
pipe.send(('step', act))
|
||||||
|
|
||||||
|
def step_wait(self):
|
||||||
|
outs = [pipe.recv() for pipe in self.parent_pipes]
|
||||||
|
obs, rews, dones, infos = zip(*outs)
|
||||||
|
return self._decode_obses(obs), np.array(rews), np.array(dones), infos
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.waiting_step:
|
||||||
|
self.step_wait()
|
||||||
|
for pipe in self.parent_pipes:
|
||||||
|
pipe.send(('close', None))
|
||||||
|
for pipe in self.parent_pipes:
|
||||||
|
pipe.recv()
|
||||||
|
pipe.close()
|
||||||
|
for proc in self.procs:
|
||||||
|
proc.join()
|
||||||
|
|
||||||
|
def _decode_obses(self, obs):
|
||||||
|
"""
|
||||||
|
Turn the observation responses into a single numpy
|
||||||
|
array, possibly via shared memory.
|
||||||
|
"""
|
||||||
|
obs = []
|
||||||
|
for i, shape in enumerate(self.obs_shapes):
|
||||||
|
bufs = [b[i] for b in self.obs_bufs]
|
||||||
|
o = [np.frombuffer(b.get_obj(), dtype=self.obs_dtypes[i]).reshape(shape) for b in bufs]
|
||||||
|
obs.append(np.array(o))
|
||||||
|
return tuple(obs) if len(obs) > 1 else obs[0]
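# Usage sketch: construct with a list of thunks; each environment runs in its own
# process and ships observations back through the shared-memory buffers above:
#
#   def make_env():
#       return gym.make('BreakoutNoFrameskip-v4')
#   venv = ShmemVecEnv([make_env for _ in range(8)])
#   obs = venv.reset()                                   # (8, 210, 160, 3)
#   obs, rews, dones, infos = venv.step([venv.action_space.sample()] * 8)
#   venv.close()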


def _subproc_worker(pipe, parent_pipe, env_fn_wrapper, obs_buf, obs_shape):
    """
    Control a single environment instance using IPC and
    shared memory.

    If obs_buf is not None, it is a shared-memory buffer
    for communicating observations.
    """

    def _write_obs(obs):
        if not isinstance(obs, tuple):
            obs = (obs,)
        for o, b, s in zip(obs, obs_buf, obs_shape):
            dst = b.get_obj()
            dst_np = np.frombuffer(dst, dtype=_CT_TO_NP[dst._type_]).reshape(s)  # pylint: disable=W0212
            np.copyto(dst_np, o)

    env = env_fn_wrapper.x()
    parent_pipe.close()
    try:
        while True:
            cmd, data = pipe.recv()
            if cmd == 'reset':
                pipe.send(_write_obs(env.reset()))
            elif cmd == 'step':
                obs, reward, done, info = env.step(data)
                if done:
                    obs = env.reset()
                pipe.send((_write_obs(obs), reward, done, info))
            elif cmd == 'close':
                pipe.send(None)
                break
            else:
                raise RuntimeError('Got unrecognized cmd %s' % cmd)
    finally:
        env.close()
475
wrappers.py
Normal file
@ -0,0 +1,475 @@
import itertools
from collections import deque
from copy import copy

import gym
import numpy as np
from PIL import Image
import random


def unwrap(env):
    if hasattr(env, "unwrapped"):
        return env.unwrapped
    elif hasattr(env, "env"):
        return unwrap(env.env)
    elif hasattr(env, "leg_env"):
        return unwrap(env.leg_env)
    else:
        return env


class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        acc_info = {}
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            acc_info.update(info)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)

        return max_frame, total_reward, done, acc_info

    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs


class ProcessFrame84(gym.ObservationWrapper):
    def __init__(self, env, crop=True):
        self.crop = crop
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return ProcessFrame84.process(obs, crop=self.crop)

    @staticmethod
    def process(frame, crop=True):
        if frame.size == 210 * 160 * 3:
            img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
        elif frame.size == 250 * 160 * 3:
            img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
        elif frame.size == 224 * 240 * 3:  # mario resolution
            img = np.reshape(frame, [224, 240, 3]).astype(np.float32)
        else:
            assert False, "Unknown resolution. " + str(frame.size)
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114  # RGB -> luminance
        size = (84, 110 if crop else 84)
        resized_screen = np.array(Image.fromarray(img).resize(size,
                                                              resample=Image.BILINEAR), dtype=np.uint8)
        x_t = resized_screen[18:102, :] if crop else resized_screen
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)


class ExtraTimeLimit(gym.Wrapper):
    def __init__(self, env, max_episode_steps=None):
        gym.Wrapper.__init__(self, env)
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        self._elapsed_steps += 1
        if self._elapsed_steps > self._max_episode_steps:
            done = True
        return observation, reward, done, info

    def reset(self):
        self._elapsed_steps = 0
        return self.env.reset()


class AddRandomStateToInfo(gym.Wrapper):
    def __init__(self, env):
        """Adds the random state to the info field on the first step after reset
        """
        gym.Wrapper.__init__(self, env)

    def step(self, action):
        ob, r, d, info = self.env.step(action)
        if self.random_state_copy is not None:
            info['random_state'] = self.random_state_copy
            self.random_state_copy = None
        return ob, r, d, info

    def reset(self, **kwargs):
        """Snapshot the environment's RNG state so the first step can expose it."""
        self.random_state_copy = copy(self.unwrapped.np_random)
        return self.env.reset(**kwargs)


class MontezumaInfoWrapper(gym.Wrapper):
    ram_map = {
        "room": dict(
            index=3,
        ),
        "x": dict(
            index=42,
        ),
        "y": dict(
            index=43,
        ),
    }

    def __init__(self, env):
        super(MontezumaInfoWrapper, self).__init__(env)
        self.visited = set()
        self.visited_rooms = set()

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        ram_state = unwrap(self.env).ale.getRAM()
        for name, properties in MontezumaInfoWrapper.ram_map.items():
            info[name] = ram_state[properties['index']]
        pos = (info['x'], info['y'], info['room'])
        self.visited.add(pos)
        self.visited_rooms.add(info["room"])
        if done:
            info['mz_episode'] = dict(pos_count=len(self.visited),
                                      visited_rooms=copy(self.visited_rooms))
            self.visited.clear()
            self.visited_rooms.clear()
        return obs, rew, done, info

    def reset(self):
        return self.env.reset()


class MarioXReward(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.current_level = [0, 0]
        self.visited_levels = set()
        self.visited_levels.add(tuple(self.current_level))
        self.current_max_x = 0.

    def reset(self):
        ob = self.env.reset()
        self.current_level = [0, 0]
        self.visited_levels = set()
        self.visited_levels.add(tuple(self.current_level))
        self.current_max_x = 0.
        return ob

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        levellow, levelhigh, xscrollHi, xscrollLo = \
            info["levelLo"], info["levelHi"], info["xscrollHi"], info["xscrollLo"]
        currentx = xscrollHi * 256 + xscrollLo
        new_level = [levellow, levelhigh]
        if new_level != self.current_level:
            self.current_level = new_level
            self.current_max_x = 0.
            reward = 0.
            self.visited_levels.add(tuple(self.current_level))
        else:
            if currentx > self.current_max_x:
                delta = currentx - self.current_max_x
                self.current_max_x = currentx
                reward = delta
            else:
                reward = 0.
        if done:
            info["levels"] = copy(self.visited_levels)
            info["retro_episode"] = dict(levels=copy(self.visited_levels))
        return ob, reward, done, info


class LimitedDiscreteActions(gym.ActionWrapper):
    KNOWN_BUTTONS = {"A", "B"}
    KNOWN_SHOULDERS = {"L", "R"}

    '''
    Reproduces the action space from the curiosity paper.
    '''

    def __init__(self, env, all_buttons, whitelist=KNOWN_BUTTONS | KNOWN_SHOULDERS):
        gym.ActionWrapper.__init__(self, env)

        self._num_buttons = len(all_buttons)
        button_keys = {i for i in range(len(all_buttons)) if all_buttons[i] in whitelist & self.KNOWN_BUTTONS}
        buttons = [(), *zip(button_keys), *itertools.combinations(button_keys, 2)]
        shoulder_keys = {i for i in range(len(all_buttons)) if all_buttons[i] in whitelist & self.KNOWN_SHOULDERS}
        shoulders = [(), *zip(shoulder_keys), *itertools.permutations(shoulder_keys, 2)]
        arrows = [(), (4,), (5,), (6,), (7,)]  # (), up, down, left, right
        acts = []
        acts += arrows
        acts += buttons[1:]
        acts += [a + b for a in arrows[-2:] for b in buttons[1:]]
        self._actions = acts
        self.action_space = gym.spaces.Discrete(len(self._actions))

    def action(self, a):
        mask = np.zeros(self._num_buttons)
        for i in self._actions[a]:
            mask[i] = 1
        return mask


class FrameSkip(gym.Wrapper):
    def __init__(self, env, n):
        gym.Wrapper.__init__(self, env)
        self.n = n

    def step(self, action):
        done = False
        totrew = 0
        for _ in range(self.n):
            ob, rew, done, info = self.env.step(action)
            totrew += rew
            if done:
                break
        return ob, totrew, done, info


def make_mario_env(crop=True, frame_stack=True, clip_rewards=False):
    assert clip_rewards is False
    import gym
    import retro
    from baselines.common.atari_wrappers import FrameStack

    gym.undo_logger_setup()
    env = retro.make('SuperMarioBros-Nes', 'Level1-1')
    buttons = env.BUTTONS
    env = MarioXReward(env)
    env = FrameSkip(env, 4)
    env = ProcessFrame84(env, crop=crop)
    if frame_stack:
        env = FrameStack(env, 4)
    env = LimitedDiscreteActions(env, buttons)
    return env


class OneChannel(gym.ObservationWrapper):
    def __init__(self, env, crop=True):
        self.crop = crop
        super(OneChannel, self).__init__(env)
        assert env.observation_space.dtype == np.uint8
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return obs[:, :, 2:3]


class RetroALEActions(gym.ActionWrapper):
    def __init__(self, env, all_buttons, n_players=1):
        gym.ActionWrapper.__init__(self, env)
        self.n_players = n_players
        self._num_buttons = len(all_buttons)
        bs = [-1, 0, 4, 5, 6, 7]

        def update_actions(old_actions, offset=0):
            actions = []
            for b in old_actions:
                for button in bs:
                    action = []
                    action.extend(b)
                    if button != -1:
                        action.append(button + offset)
                    actions.append(action)
            return actions

        current_actions = [[]]
        for i in range(self.n_players):
            current_actions = update_actions(current_actions, i * self._num_buttons)
        self._actions = current_actions
        self.action_space = gym.spaces.Discrete(len(self._actions))

    def action(self, a):
        mask = np.zeros(self._num_buttons * self.n_players)
        for i in self._actions[a]:
            mask[i] = 1
        return mask


class NoReward(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)

    def step(self, action):
        ob, rew, done, info = self.env.step(action)
        return ob, 0.0, done, info


def make_multi_pong(frame_stack=True):
    import gym
    import retro
    from baselines.common.atari_wrappers import FrameStack
    gym.undo_logger_setup()
    game_env = env = retro.make('Pong-Atari2600', players=2)
    env = RetroALEActions(env, game_env.BUTTONS, n_players=2)
    env = NoReward(env)
    env = FrameSkip(env, 4)
    env = ProcessFrame84(env, crop=False)
    if frame_stack:
        env = FrameStack(env, 4)

    return env


def make_robo_pong(frame_stack=True):
    from baselines.common.atari_wrappers import FrameStack
    import roboenvs as robo

    env = robo.make_robopong()
    env = robo.DiscretizeActionWrapper(env, 2)
    env = robo.MultiDiscreteToUsual(env)
    env = OneChannel(env)
    if frame_stack:
        env = FrameStack(env, 4)

    env = AddRandomStateToInfo(env)
    return env


def make_robo_hockey(frame_stack=True):
    from baselines.common.atari_wrappers import FrameStack
    import roboenvs as robo

    env = robo.make_robohockey()
    env = robo.DiscretizeActionWrapper(env, 2)
    env = robo.MultiDiscreteToUsual(env)
    env = OneChannel(env)
    if frame_stack:
        env = FrameStack(env, 4)
    env = AddRandomStateToInfo(env)
    return env


def make_unity_maze(env_id, seed=0, rank=0, expID=0, frame_stack=True,
                    logdir=None, ext_coeff=1.0, recordUnityVid=False, **kwargs):
    import os
    import sys
    import time
    try:
        sys.path.insert(0, os.path.abspath("ml-agents/python/"))
        from unityagents import UnityEnvironment
        from unity_wrapper import GymWrapper
    except ImportError:
        print("Import error in unity environment. Ignore if not using unity.")
        pass
    from baselines.common.atari_wrappers import FrameStack
    # gym.undo_logger_setup()  # deprecated in newer versions of gym

    # max 20 workers per expID, max 30 experiments per machine
    if rank >= 0 and rank <= 200:
        time.sleep(rank * 2)
    env = UnityEnvironment(file_name='envs/' + env_id, worker_id=(expID % 60) * 200 + rank)
    maxsteps = 3000 if 'big' in env_id else 500
    env = GymWrapper(env, seed=seed, rank=rank, expID=expID, maxsteps=maxsteps, **kwargs)
    if "big" in env_id:
        env = UnityRoomCounterWrapper(env, use_ext_reward=(ext_coeff != 0.0))
    if rank == 1 and recordUnityVid:
        env = RecordBestScores(env, directory=logdir, freq=1)
    print('Loaded environment %s with rank %d\n\n' % (env_id, rank))

    # env = NoReward(env)
    # env = FrameSkip(env, 4)
    env = ProcessFrame84(env, crop=False)
    if frame_stack:
        env = FrameStack(env, 4)
    return env


class StickyActionEnv(gym.Wrapper):
    def __init__(self, env, p=0.25):
        super(StickyActionEnv, self).__init__(env)
        self.p = p
        self.last_action = 0

    def reset(self):
        self.last_action = 0
        return self.env.reset()

    def step(self, action):
        if self.unwrapped.np_random.uniform() < self.p:
            action = self.last_action
        self.last_action = action
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info
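# Usage sketch: with probability p the previous action is repeated, which injects
# action-level stochasticity (this is what the --stickyAtari setting enables):
#
#   env = gym.make('BreakoutNoFrameskip-v4')
#   env = StickyActionEnv(env, p=0.25)   # 25% chance of repeating the last action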


############## Pixel-Noise #################
class PixelNoiseWrapper(gym.ObservationWrapper):
    def __init__(self, env, strength=80):
        """ The source must produce an image with a shape that's compatible with `env.observation_space`.
        """
        super(PixelNoiseWrapper, self).__init__(env)
        self.env = env
        self.obs_shape = env.observation_space.shape[:2]
        self.strength = strength

    def observation(self, obs):
        mask = (obs == (0, 0, 0))  # True where the pixel is black; shape = (210, 160, 3)
        noise = np.maximum(np.random.randn(self.obs_shape[0], self.obs_shape[1], 3) * self.strength, 0)
        obs[mask] = noise[mask]  # overwrite only the black background pixels
        self._last_ob = obs
        return obs

    def render(self, mode='rgb_array'):
        img = self._last_ob
        return img


############# Random-Box Noise #################
class RandomBoxNoiseWrapper(gym.ObservationWrapper):
    def __init__(self, env, strength=0.1):
        super(RandomBoxNoiseWrapper, self).__init__(env)
        self.obs_shape = env.observation_space.shape[:2]  # (210, 160)
        self.strength = strength

    def observation(self, obs, w=20):
        n1 = self.obs_shape[1] // w
        n2 = self.obs_shape[0] // w

        idx_list = np.arange(n1 * n2)
        random.shuffle(idx_list)

        num_of_box = int(n1 * n2 * self.strength)  # expected number of noisy boxes (strength is the ratio)
        idx_list = idx_list[:np.random.randint(num_of_box - 5, num_of_box + 5)]

        for idx in idx_list:
            y = (idx // n1) * w
            x = (idx % n1) * w
            obs[y:y + w, x:x + w, :] += np.random.normal(0, 255 * 0.3, size=(w, w, 3)).astype(np.uint8)

        obs = np.clip(obs, 0, 255)
        self._last_ob = obs
        return obs

    def render(self, mode='rgb_array'):
        img = self._last_ob
        return img