Add DreamerV2 Files

Vedant Dave committed on 2023-07-17 10:48:40 +02:00
parent 1ddf72b0c3
commit 0ac3131dad
7 changed files with 2452 additions and 0 deletions

DreamerV2/configs.yaml (new file, 185 lines)

@@ -0,0 +1,185 @@
defaults:
gpu: 'none'
logdir: ./
traindir: null
evaldir: null
offline_traindir: ''
offline_evaldir: ''
seed: 0
steps: 1e7
eval_every: 1e4
log_every: 1e4
reset_every: 0
gpu_growth: True
precision: 32
debug: False
expl_gifs: False
# Environment
task: 'dmc_walker_walk'
size: [64, 64]
envs: 1
action_repeat: 2
time_limit: 1000
prefill: 2500
eval_noise: 0.0
clip_rewards: 'identity'
atari_grayscale: False
# Model
dyn_cell: 'gru'
dyn_hidden: 200
dyn_deter: 200
dyn_stoch: 50
dyn_discrete: 0
dyn_input_layers: 1
dyn_output_layers: 1
dyn_shared: False
dyn_mean_act: 'none'
dyn_std_act: 'sigmoid2'
dyn_min_std: 0.1
grad_heads: ['image', 'reward']
units: 400
reward_layers: 2
discount_layers: 3
value_layers: 3
actor_layers: 4
act: 'elu'
cnn_depth: 32
encoder_kernels: [4, 4, 4, 4]
decoder_kernels: [5, 5, 6, 6]
decoder_thin: True
value_head: 'normal'
kl_scale: '1.0'
kl_balance: '0.8'
kl_free: '1.0'
pred_discount: False
discount_scale: 1.0
reward_scale: 1.0
weight_decay: 0.0
# Training
batch_size: 50
batch_length: 50
train_every: 5
train_steps: 1
pretrain: 100
model_lr: 3e-4
value_lr: 8e-5
actor_lr: 8e-5
opt_eps: 1e-5
grad_clip: 100
value_grad_clip: 100
actor_grad_clip: 100
dataset_size: 0
oversample_ends: False
slow_value_target: True
slow_actor_target: True
slow_target_update: 100
slow_target_fraction: 1
opt: 'adam'
# Behavior.
discount: 0.99
discount_lambda: 0.95
imag_horizon: 15
imag_gradient: 'dynamics'
imag_gradient_mix: '0.1'
imag_sample: True
actor_dist: 'trunc_normal'
actor_entropy: '1e-4'
actor_state_entropy: 0.0
actor_init_std: 1.0
actor_min_std: 0.1
actor_disc: 5
actor_temp: 0.1
actor_outscale: 0.0
expl_amount: 0.0
eval_state_mean: False
collect_dyn_sample: True
behavior_stop_grad: True
value_decay: 0.0
future_entropy: False
# Exploration
expl_behavior: 'greedy'
expl_until: 0
expl_extr_scale: 0.0
expl_intr_scale: 1.0
disag_target: 'stoch'
disag_log: True
disag_models: 10
disag_offset: 1
disag_layers: 4
disag_units: 400
atari:
# General
task: 'atari_demon_attack'
steps: 3e7
eval_every: 1e5
log_every: 1e4
prefill: 50000
dataset_size: 2e6
pretrain: 0
precision: 16
# Environment
time_limit: 108000 # 30 minutes of game play.
atari_grayscale: True
action_repeat: 4
eval_noise: 0.001
train_every: 16
train_steps: 1
clip_rewards: 'tanh'
# Model
grad_heads: ['image', 'reward', 'discount']
dyn_cell: 'gru_layer_norm'
pred_discount: True
cnn_depth: 48
dyn_deter: 600
dyn_hidden: 600
dyn_stoch: 32
dyn_discrete: 32
reward_layers: 4
discount_layers: 4
value_layers: 4
actor_layers: 4
# Behavior
actor_dist: 'onehot'
actor_entropy: 'linear(3e-3,3e-4,2.5e6)'
expl_amount: 0.0
expl_until: 3e7
discount: 0.995
imag_gradient: 'both'
imag_gradient_mix: 'linear(0.1,0,2.5e6)'
# Training
discount_scale: 5.0
reward_scale: 1
weight_decay: 1e-6
model_lr: 2e-4
kl_scale: 0.1
kl_free: 0.0
actor_lr: 4e-5
value_lr: 1e-4
oversample_ends: True
# Disen
disen_cnn_depth: 16
disen_only_scale: 1.0
disen_discount_scale: 2000.0
disen_reward_scale: 2000.0
num_reward_opt_iters: 20
debug:
debug: True
pretrain: 1
prefill: 1
train_steps: 1
batch_size: 10
batch_length: 20
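
The named blocks above are layered at launch time: dreamer.py (below) loads this file and applies each block passed via --configs in order, so later blocks override earlier ones. A minimal sketch of that merge, assuming this configs.yaml path and ruamel.yaml as used in dreamer.py:

import pathlib
import ruamel.yaml as yaml

configs = yaml.safe_load(pathlib.Path('DreamerV2/configs.yaml').read_text())
merged = {}
for name in ['defaults', 'atari']:  # equivalent to passing `--configs defaults atari`
    merged.update(configs[name])
print(merged['dyn_deter'])  # 600: the atari block overrides the 200 set in defaults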

DreamerV2/dreamer.py (new file, 316 lines)

@@ -0,0 +1,316 @@
import argparse
import collections
import functools
import os
import pathlib
import sys
import warnings
warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['MUJOCO_GL'] = 'egl'
import numpy as np
import ruamel.yaml as yaml
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
tf.get_logger().setLevel('ERROR')
from tensorflow_probability import distributions as tfd
sys.path.append(str(pathlib.Path(__file__).parent))
import exploration as expl
import models
import tools
import wrappers
class Dreamer(tools.Module):
def __init__(self, config, logger, dataset):
self._config = config
self._logger = logger
self._float = prec.global_policy().compute_dtype
self._should_log = tools.Every(config.log_every)
self._should_train = tools.Every(config.train_every)
self._should_pretrain = tools.Once()
self._should_reset = tools.Every(config.reset_every)
self._should_expl = tools.Until(int(
config.expl_until / config.action_repeat))
self._metrics = collections.defaultdict(tf.metrics.Mean)
with tf.device('cpu:0'):
self._step = tf.Variable(count_steps(config.traindir), dtype=tf.int64)
# Schedules.
config.actor_entropy = (
lambda x=config.actor_entropy: tools.schedule(x, self._step))
config.actor_state_entropy = (
lambda x=config.actor_state_entropy: tools.schedule(x, self._step))
config.imag_gradient_mix = (
lambda x=config.imag_gradient_mix: tools.schedule(x, self._step))
self._dataset = iter(dataset)
self._wm = models.WorldModel(self._step, config)
self._task_behavior = models.ImagBehavior(
config, self._wm, config.behavior_stop_grad)
reward = lambda f, s, a: self._wm.heads['reward'](f).mode()
self._expl_behavior = dict(
greedy=lambda: self._task_behavior,
random=lambda: expl.Random(config),
plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
)[config.expl_behavior]()
# Train step to initialize variables including optimizer statistics.
self._train(next(self._dataset))
def __call__(self, obs, reset, state=None, training=True):
step = self._step.numpy().item()
if self._should_reset(step):
state = None
if state is not None and reset.any():
mask = tf.cast(1 - reset, self._float)[:, None]
state = tf.nest.map_structure(lambda x: x * mask, state)
if training and self._should_train(step):
steps = (
self._config.pretrain if self._should_pretrain()
else self._config.train_steps)
for _ in range(steps):
self._train(next(self._dataset))
if self._should_log(step):
for name, mean in self._metrics.items():
self._logger.scalar(name, float(mean.result()))
mean.reset_states()
openl_joint, openl_main, openl_disen, openl_mask = self._wm.video_pred(next(self._dataset))
self._logger.video('train_openl_joint', openl_joint)
self._logger.video('train_openl_main', openl_main)
self._logger.video('train_openl_disen', openl_disen)
self._logger.video('train_openl_mask', openl_mask)
self._logger.write(fps=True)
action, state = self._policy(obs, state, training)
if training:
self._step.assign_add(len(reset))
self._logger.step = self._config.action_repeat \
* self._step.numpy().item()
return action, state
@tf.function
def _policy(self, obs, state, training):
if state is None:
batch_size = len(obs['image'])
latent = self._wm.dynamics.initial(len(obs['image']))
action = tf.zeros((batch_size, self._config.num_actions), self._float)
else:
latent, action = state
embed = self._wm.encoder(self._wm.preprocess(obs))
latent, _ = self._wm.dynamics.obs_step(
latent, action, embed, self._config.collect_dyn_sample)
if self._config.eval_state_mean:
latent['stoch'] = latent['mean']
feat = self._wm.dynamics.get_feat(latent)
if not training:
action = self._task_behavior.actor(feat).mode()
elif self._should_expl(self._step):
action = self._expl_behavior.actor(feat).sample()
else:
action = self._task_behavior.actor(feat).sample()
if self._config.actor_dist == 'onehot_gumble':
action = tf.cast(
tf.one_hot(tf.argmax(action, axis=-1), self._config.num_actions),
action.dtype)
action = self._exploration(action, training)
state = (latent, action)
return action, state
def _exploration(self, action, training):
amount = self._config.expl_amount if training else self._config.eval_noise
if amount == 0:
return action
amount = tf.cast(amount, self._float)
if 'onehot' in self._config.actor_dist:
probs = amount / self._config.num_actions + (1 - amount) * action
return tools.OneHotDist(probs=probs).sample()
else:
return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
raise NotImplementedError(self._config.action_noise)
@tf.function
def _train(self, data):
print('Tracing train function.')
metrics = {}
embed, post, feat, kl, mets = self._wm.train(data)
metrics.update(mets)
start = post
if self._config.pred_discount: # Last step could be terminal.
start = {k: v[:, :-1] for k, v in post.items()}
embed, feat, kl = embed[:, :-1], feat[:, :-1], kl[:, :-1]
reward = lambda f, s, a: self._wm.heads['reward'](f).mode()
metrics.update(self._task_behavior.train(start, reward)[-1])
if self._config.expl_behavior != 'greedy':
mets = self._expl_behavior.train(start, feat, embed, kl)[-1]
metrics.update({'expl_' + key: value for key, value in mets.items()})
for name, value in metrics.items():
self._metrics[name].update_state(value)
def count_steps(folder):
return sum(int(str(n).split('-')[-1][:-4]) - 1 for n in folder.glob('*.npz'))
def make_dataset(episodes, config):
example = episodes[next(iter(episodes.keys()))]
types = {k: v.dtype for k, v in example.items()}
shapes = {k: (None,) + v.shape[1:] for k, v in example.items()}
generator = lambda: tools.sample_episodes(
episodes, config.batch_length, config.oversample_ends)
dataset = tf.data.Dataset.from_generator(generator, types, shapes)
dataset = dataset.batch(config.batch_size, drop_remainder=True)
dataset = dataset.prefetch(10)
return dataset
def make_env(config, logger, mode, train_eps, eval_eps):
suite, task = config.task.split('_', 1)
if suite == 'dmc':
env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
env = wrappers.NormalizeActions(env)
elif suite == 'atari':
env = wrappers.Atari(
task, config.action_repeat, config.size,
grayscale=config.atari_grayscale,
life_done=False and (mode == 'train'),
sticky_actions=True,
all_actions=True)
env = wrappers.OneHotAction(env)
else:
raise NotImplementedError(suite)
env = wrappers.TimeLimit(env, config.time_limit)
callbacks = [functools.partial(
process_episode, config, logger, mode, train_eps, eval_eps)]
env = wrappers.CollectDataset(env, callbacks)
env = wrappers.RewardObs(env)
return env
def process_episode(config, logger, mode, train_eps, eval_eps, episode):
directory = dict(train=config.traindir, eval=config.evaldir)[mode]
cache = dict(train=train_eps, eval=eval_eps)[mode]
filename = tools.save_episodes(directory, [episode])[0]
length = len(episode['reward']) - 1
score = float(episode['reward'].astype(np.float64).sum())
video = episode['image']
if mode == 'eval':
cache.clear()
if mode == 'train' and config.dataset_size:
total = 0
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
if total <= config.dataset_size - length:
total += len(ep['reward']) - 1
else:
del cache[key]
logger.scalar('dataset_size', total + length)
cache[str(filename)] = episode
print(f'{mode.title()} episode has {length} steps and return {score:.1f}.')
logger.scalar(f'{mode}_return', score)
logger.scalar(f'{mode}_length', length)
logger.scalar(f'{mode}_episodes', len(cache))
if mode == 'eval' or config.expl_gifs:
logger.video(f'{mode}_policy', video[None])
logger.write()
def main(logdir, config):
logdir = os.path.join(
logdir, config.task, 'Ours', str(config.seed))
logdir = pathlib.Path(logdir).expanduser()
config.traindir = config.traindir or logdir / 'train_eps'
config.evaldir = config.evaldir or logdir / 'eval_eps'
config.steps //= config.action_repeat
config.eval_every //= config.action_repeat
config.log_every //= config.action_repeat
config.time_limit //= config.action_repeat
config.act = getattr(tf.nn, config.act)
if config.debug:
tf.config.experimental_run_functions_eagerly(True)
if config.gpu_growth:
message = 'No GPU found. To actually train on CPU remove this assert.'
assert tf.config.experimental.list_physical_devices('GPU'), message
for gpu in tf.config.experimental.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(gpu, True)
assert config.precision in (16, 32), config.precision
if config.precision == 16:
prec.set_policy(prec.Policy('mixed_float16'))
print('Logdir', logdir)
logdir.mkdir(parents=True, exist_ok=True)
config.traindir.mkdir(parents=True, exist_ok=True)
config.evaldir.mkdir(parents=True, exist_ok=True)
step = count_steps(config.traindir)
logger = tools.Logger(logdir, config.action_repeat * step)
print('Create envs.')
if config.offline_traindir:
directory = config.offline_traindir.format(**vars(config))
else:
directory = config.traindir
train_eps = tools.load_episodes(directory, limit=config.dataset_size)
if config.offline_evaldir:
directory = config.offline_evaldir.format(**vars(config))
else:
directory = config.evaldir
eval_eps = tools.load_episodes(directory, limit=1)
make = lambda mode: make_env(config, logger, mode, train_eps, eval_eps)
train_envs = [make('train') for _ in range(config.envs)]
eval_envs = [make('eval') for _ in range(config.envs)]
acts = train_envs[0].action_space
config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0]
prefill = max(0, config.prefill - count_steps(config.traindir))
print(f'Prefill dataset ({prefill} steps).')
random_agent = lambda o, d, s: ([acts.sample() for _ in d], s)
tools.simulate(random_agent, train_envs, prefill)
tools.simulate(random_agent, eval_envs, episodes=1)
logger.step = config.action_repeat * count_steps(config.traindir)
print('Simulate agent.')
train_dataset = make_dataset(train_eps, config)
eval_dataset = iter(make_dataset(eval_eps, config))
agent = Dreamer(config, logger, train_dataset)
if (logdir / 'variables.pkl').exists():
agent.load(logdir / 'variables.pkl')
agent._should_pretrain._once = False
state = None
suite, task = config.task.split('_', 1)
num_eval_episodes = 10 if suite == 'procgen' else 1
while agent._step.numpy().item() < config.steps:
logger.write()
print('Start evaluation.')
openl_joint, openl_main, openl_disen, openl_mask = agent._wm.video_pred(next(eval_dataset))
logger.video('eval_openl_joint', openl_joint)
logger.video('eval_openl_main', openl_main)
logger.video('eval_openl_disen', openl_disen)
logger.video('eval_openl_mask', openl_mask)
eval_policy = functools.partial(agent, training=False)
tools.simulate(eval_policy, eval_envs, episodes=num_eval_episodes)
print('Start training.')
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
agent.save(logdir / 'variables.pkl')
for env in train_envs + eval_envs:
try:
env.close()
except Exception:
pass
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--configs', nargs='+', required=True)
args, remaining = parser.parse_known_args()
configs = yaml.safe_load((pathlib.Path(__file__).parent / 'configs.yaml').read_text())
config_ = {}
for name in args.configs:
config_.update(configs[name])
parser = argparse.ArgumentParser()
for key, value in config_.items():
arg_type = tools.args_type(value)
parser.add_argument(f'--{key}', type=arg_type, default=arg_type(value))
main(config_['logdir'], parser.parse_args(remaining))
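
The schedule wrappers above (config.actor_entropy, config.actor_state_entropy, config.imag_gradient_mix) accept either plain numbers or strings such as 'linear(3e-3,3e-4,2.5e6)'. They are parsed by tools.schedule, whose body is not shown in this commit; a minimal stand-in consistent with those strings, reading linear(initial, final, duration) as a linear ramp over training steps (an assumption about the exact format), would be:

import re

def schedule_sketch(string, step):
    # Hypothetical stand-in for tools.schedule: plain numbers pass through,
    # 'linear(initial,final,duration)' ramps from initial to final over `duration` steps.
    try:
        return float(string)
    except ValueError:
        pass
    match = re.match(r'linear\((.+),(.+),(.+)\)', string)
    if match:
        initial, final, duration = (float(g) for g in match.groups())
        mix = min(max(step / duration, 0.0), 1.0)
        return (1 - mix) * initial + mix * final
    raise NotImplementedError(string)

print(schedule_sketch('linear(3e-3,3e-4,2.5e6)', 0))      # 0.003
print(schedule_sketch('linear(3e-3,3e-4,2.5e6)', 2.5e6))  # 0.0003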

DreamerV2/exploration.py (new file, 83 lines)

@@ -0,0 +1,83 @@
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd
import models
import networks
import tools
class Random(tools.Module):
def __init__(self, config):
self._config = config
self._float = prec.global_policy().compute_dtype
def actor(self, feat):
shape = feat.shape[:-1] + [self._config.num_actions]
if self._config.actor_dist == 'onehot':
return tools.OneHotDist(tf.zeros(shape))
else:
ones = tf.ones(shape, self._float)
return tfd.Uniform(-ones, ones)
def train(self, start, feat, embed, kl):
return None, {}
class Plan2Explore(tools.Module):
def __init__(self, config, world_model, reward=None):
self._config = config
self._reward = reward
self._behavior = models.ImagBehavior(config, world_model)
self.actor = self._behavior.actor
size = {
'embed': 32 * config.cnn_depth,
'stoch': config.dyn_stoch,
'deter': config.dyn_deter,
'feat': config.dyn_stoch + config.dyn_deter,
}[self._config.disag_target]
kw = dict(
shape=size, layers=config.disag_layers, units=config.disag_units,
act=config.act)
self._networks = [
networks.DenseHead(**kw) for _ in range(config.disag_models)]
self._opt = tools.Optimizer(
'ensemble', config.model_lr, config.opt_eps, config.grad_clip,
config.weight_decay, opt=config.opt)
def train(self, start, feat, embed, kl):
metrics = {}
target = {
'embed': embed,
'stoch': start['stoch'],
'deter': start['deter'],
'feat': feat,
}[self._config.disag_target]
metrics.update(self._train_ensemble(feat, target))
metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1])
return None, metrics
def _intrinsic_reward(self, feat, state, action):
preds = [head(feat, tf.float32).mean() for head in self._networks]
disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1)
if self._config.disag_log:
disag = tf.math.log(disag)
reward = self._config.expl_intr_scale * disag
if self._config.expl_extr_scale:
reward += tf.cast(self._config.expl_extr_scale * self._reward(
feat, state, action), tf.float32)
return reward
def _train_ensemble(self, inputs, targets):
if self._config.disag_offset:
targets = targets[:, self._config.disag_offset:]
inputs = inputs[:, :-self._config.disag_offset]
targets = tf.stop_gradient(targets)
inputs = tf.stop_gradient(inputs)
with tf.GradientTape() as tape:
preds = [head(inputs) for head in self._networks]
likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
loss = -tf.cast(tf.reduce_sum(likes), tf.float32)
metrics = self._opt(tape, loss, self._networks)
return metrics
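
The intrinsic reward above is ensemble disagreement: the per-dimension standard deviation across the disag_models prediction means, averaged over the target dimension, optionally log-transformed, and scaled by expl_intr_scale. A small numpy sketch of that computation (shapes are illustrative, not taken from the code):

import numpy as np

def disagreement_reward(pred_means, intr_scale=1.0, use_log=True):
    # pred_means: [ensemble, batch, time, target_dim] mean predictions of the heads.
    disag = np.std(pred_means, axis=0).mean(axis=-1)  # std over heads, mean over target dims
    if use_log:
        disag = np.log(disag)
    return intr_scale * disag

preds = np.random.randn(10, 16, 50, 50)  # e.g. disag_models=10 heads, disag_target='stoch'
print(disagreement_reward(preds).shape)  # (16, 50)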

DreamerV2/models.py (new file, 429 lines)

@@ -0,0 +1,429 @@
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
import networks
import tools
class WorldModel(tools.Module):
def __init__(self, step, config):
self._step = step
self._config = config
channels = (1 if config.atari_grayscale else 3)
shape = config.size + (channels,)
########
# Main #
########
self.encoder = networks.ConvEncoder(
config.cnn_depth, config.act, config.encoder_kernels)
self.dynamics = networks.RSSM(
config.dyn_stoch, config.dyn_deter, config.dyn_hidden,
config.dyn_input_layers, config.dyn_output_layers, config.dyn_shared,
config.dyn_discrete, config.act, config.dyn_mean_act,
config.dyn_std_act, config.dyn_min_std, config.dyn_cell)
self.heads = {}
self.heads['reward'] = networks.DenseHead(
[], config.reward_layers, config.units, config.act)
if config.pred_discount:
self.heads['discount'] = networks.DenseHead(
[], config.discount_layers, config.units, config.act, dist='binary')
self._model_opt = tools.Optimizer(
'model', config.model_lr, config.opt_eps, config.grad_clip,
config.weight_decay, opt=config.opt)
self._scales = dict(
reward=config.reward_scale, discount=config.discount_scale)
#########
# Disen #
#########
self.disen_encoder = networks.ConvEncoder(
config.disen_cnn_depth, config.act, config.encoder_kernels)
self.disen_dynamics = networks.RSSM(
config.dyn_stoch, config.dyn_deter, config.dyn_hidden,
config.dyn_input_layers, config.dyn_output_layers, config.dyn_shared,
config.dyn_discrete, config.act, config.dyn_mean_act,
config.dyn_std_act, config.dyn_min_std, config.dyn_cell)
self.disen_heads = {}
self.disen_heads['reward'] = networks.DenseHead(
[], config.reward_layers, config.units, config.act)
if config.pred_discount:
self.disen_heads['discount'] = networks.DenseHead(
[], config.discount_layers, config.units, config.act, dist='binary')
self._disen_model_opt = tools.Optimizer(
'disen', config.model_lr, config.opt_eps, config.grad_clip,
config.weight_decay, opt=config.opt)
self._disen_heads_opt = {}
self._disen_heads_opt['reward'] = tools.Optimizer(
'disen_reward', config.model_lr, config.opt_eps, config.grad_clip,
config.weight_decay, opt=config.opt)
if config.pred_discount:
self._disen_heads_opt['discount'] = tools.Optimizer(
'disen_pcont', config.model_lr, config.opt_eps, config.grad_clip,
config.weight_decay, opt=config.opt)
# Negative reward/discount scales: the disen branch gets reversed gradients for these heads (adversarial objective).
self._disen_scales = dict(disen_only=config.disen_only_scale,
reward=-config.disen_reward_scale, discount=-config.disen_discount_scale)
self.disen_only_image_head = networks.ConvDecoder(
config.disen_cnn_depth, config.act, shape, config.decoder_kernels,
config.decoder_thin)
################
# Joint Decode #
################
self.image_head = networks.ConvDecoderMask(
config.cnn_depth, config.act, shape, config.decoder_kernels,
config.decoder_thin)
self.disen_image_head = networks.ConvDecoderMask(
config.disen_cnn_depth, config.act, shape, config.decoder_kernels,
config.decoder_thin)
self.joint_image_head = networks.ConvDecoderMaskEnsemble(
self.image_head, self.disen_image_head
)
def train(self, data):
data = self.preprocess(data)
with tf.GradientTape() as model_tape, tf.GradientTape() as disen_tape:
# kl schedule
kl_balance = tools.schedule(self._config.kl_balance, self._step)
kl_free = tools.schedule(self._config.kl_free, self._step)
kl_scale = tools.schedule(self._config.kl_scale, self._step)
# Main
embed = self.encoder(data)
post, prior = self.dynamics.observe(embed, data['action'])
kl_loss, kl_value = self.dynamics.kl_loss(
post, prior, kl_balance, kl_free, kl_scale)
feat = self.dynamics.get_feat(post)
likes = {}
for name, head in self.heads.items():
grad_head = (name in self._config.grad_heads)
inp = feat if grad_head else tf.stop_gradient(feat)
pred = head(inp, tf.float32)
like = pred.log_prob(tf.cast(data[name], tf.float32))
likes[name] = tf.reduce_mean(
like) * self._scales.get(name, 1.0)
# Disen
embed_disen = self.disen_encoder(data)
post_disen, prior_disen = self.disen_dynamics.observe(
embed_disen, data['action'])
kl_loss_disen, kl_value_disen = self.dynamics.kl_loss(
post_disen, prior_disen, kl_balance, kl_free, kl_scale)
feat_disen = self.disen_dynamics.get_feat(post_disen)
# Optimize disen reward/pcont till optimal
disen_metrics = dict(reward={}, discount={})
loss_disen = dict(reward=None, discount=None)
for _ in range(self._config.num_reward_opt_iters):
with tf.GradientTape() as disen_reward_tape, tf.GradientTape() as disen_pcont_tape:
disen_gradient_tapes = dict(
reward=disen_reward_tape, discount=disen_pcont_tape)
for name, head in self.disen_heads.items():
pred_disen = head(
tf.stop_gradient(feat_disen), tf.float32)
loss_disen[name] = -tf.reduce_mean(pred_disen.log_prob(
tf.cast(data[name], tf.float32)))
for name, head in self.disen_heads.items():
disen_metrics[name] = self._disen_heads_opt[name](
disen_gradient_tapes[name], loss_disen[name], [head], prefix='disen_neg')
# Compute likes for disen model (including negative gradients)
likes_disen = {}
for name, head in self.disen_heads.items():
pred_disen = head(feat_disen, tf.float32)
like_disen = pred_disen.log_prob(
tf.cast(data[name], tf.float32))
likes_disen[name] = tf.reduce_mean(
like_disen) * self._disen_scales.get(name, -1.0)
disen_only_image_pred = self.disen_only_image_head(
feat_disen, tf.float32)
disen_only_image_like = tf.reduce_mean(disen_only_image_pred.log_prob(
tf.cast(data['image'], tf.float32))) * self._disen_scales.get('disen_only', 1.0)
likes_disen['disen_only'] = disen_only_image_like
# Joint decode
image_pred_joint, _, _, _ = self.joint_image_head(
feat, feat_disen, tf.float32)
image_like = tf.reduce_mean(image_pred_joint.log_prob(
tf.cast(data['image'], tf.float32)))
likes['image'] = image_like
likes_disen['image'] = image_like
# Compute loss
model_loss = kl_loss - sum(likes.values())
disen_loss = kl_loss_disen - sum(likes_disen.values())
model_parts = [self.encoder, self.dynamics,
self.joint_image_head] + list(self.heads.values())
disen_parts = [self.disen_encoder, self.disen_dynamics,
self.joint_image_head, self.disen_only_image_head]
metrics = self._model_opt(
model_tape, model_loss, model_parts, prefix='main')
disen_model_metrics = self._disen_model_opt(
disen_tape, disen_loss, disen_parts, prefix='disen')
metrics['kl_balance'] = kl_balance
metrics['kl_free'] = kl_free
metrics['kl_scale'] = kl_scale
metrics.update({f'{name}_loss': -like for name,
like in likes.items()})
metrics['disen/disen_only_image_loss'] = -disen_only_image_like
metrics['disen/disen_reward_loss'] = -likes_disen['reward'] / \
self._disen_scales.get('reward', -1.0)
metrics['disen/disen_discount_loss'] = -likes_disen['discount'] / \
self._disen_scales.get('discount', -1.0)
metrics['kl'] = tf.reduce_mean(kl_value)
metrics['prior_ent'] = self.dynamics.get_dist(prior).entropy()
metrics['post_ent'] = self.dynamics.get_dist(post).entropy()
metrics['disen/kl'] = tf.reduce_mean(kl_value_disen)
metrics['disen/prior_ent'] = self.dynamics.get_dist(
prior_disen).entropy()
metrics['disen/post_ent'] = self.dynamics.get_dist(
post_disen).entropy()
metrics.update(
{f'{key}': value for key, value in disen_metrics['reward'].items()})
metrics.update(
{f'{key}': value for key, value in disen_metrics['discount'].items()})
metrics.update(
{f'{key}': value for key, value in disen_model_metrics.items()})
return embed, post, feat, kl_value, metrics
@tf.function
def preprocess(self, obs):
dtype = prec.global_policy().compute_dtype
obs = obs.copy()
obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5
obs['reward'] = getattr(tf, self._config.clip_rewards)(obs['reward'])
if 'discount' in obs:
obs['discount'] *= self._config.discount
for key, value in obs.items():
if tf.dtypes.as_dtype(value.dtype) in (
tf.float16, tf.float32, tf.float64):
obs[key] = tf.cast(value, dtype)
return obs
@tf.function
def video_pred(self, data):
data = self.preprocess(data)
truth = data['image'][:6] + 0.5
embed = self.encoder(data)
embed_disen = self.disen_encoder(data)
states, _ = self.dynamics.observe(
embed[:6, :5], data['action'][:6, :5])
states_disen, _ = self.disen_dynamics.observe(
embed_disen[:6, :5], data['action'][:6, :5])
feats = self.dynamics.get_feat(states)
feats_disen = self.disen_dynamics.get_feat(states_disen)
recon_joint, recon_main, recon_disen, recon_mask = self.joint_image_head(
feats, feats_disen)
recon_joint = recon_joint.mode()[:6]
recon_main = recon_main.mode()[:6]
recon_disen = recon_disen.mode()[:6]
recon_mask = recon_mask[:6]
init = {k: v[:, -1] for k, v in states.items()}
init_disen = {k: v[:, -1] for k, v in states_disen.items()}
prior = self.dynamics.imagine(
data['action'][:6, 5:], init)
prior_disen = self.disen_dynamics.imagine(
data['action'][:6, 5:], init_disen)
_feats = self.dynamics.get_feat(prior)
_feats_disen = self.disen_dynamics.get_feat(prior_disen)
openl_joint, openl_main, openl_disen, openl_mask = self.joint_image_head(
_feats, _feats_disen)
openl_joint = openl_joint.mode()
openl_main = openl_main.mode()
openl_disen = openl_disen.mode()
model_joint = tf.concat(
[recon_joint[:, :5] + 0.5, openl_joint + 0.5], 1)
error_joint = (model_joint - truth + 1) / 2
model_main = tf.concat(
[recon_main[:, :5] + 0.5, openl_main + 0.5], 1)
error_main = (model_main - truth + 1) / 2
model_disen = tf.concat(
[recon_disen[:, :5] + 0.5, openl_disen + 0.5], 1)
error_disen = (model_disen - truth + 1) / 2
model_mask = tf.concat(
[recon_mask[:, :5] + 0.5, openl_mask + 0.5], 1)
output_joint = tf.concat([truth, model_joint, error_joint], 2)
output_main = tf.concat([truth, model_main, error_main], 2)
output_disen = tf.concat([truth, model_disen, error_disen], 2)
output_mask = model_mask
return output_joint, output_main, output_disen, output_mask
class ImagBehavior(tools.Module):
def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
self._config = config
self._world_model = world_model
self._stop_grad_actor = stop_grad_actor
self._reward = reward
self.actor = networks.ActionHead(
config.num_actions, config.actor_layers, config.units, config.act,
config.actor_dist, config.actor_init_std, config.actor_min_std,
config.actor_disc, config.actor_temp, config.actor_outscale)
self.value = networks.DenseHead(
[], config.value_layers, config.units, config.act,
config.value_head)
if config.slow_value_target or config.slow_actor_target:
self._slow_value = networks.DenseHead(
[], config.value_layers, config.units, config.act)
self._updates = tf.Variable(0, tf.int64)
kw = dict(wd=config.weight_decay, opt=config.opt)
self._actor_opt = tools.Optimizer(
'actor', config.actor_lr, config.opt_eps, config.actor_grad_clip, **kw)
self._value_opt = tools.Optimizer(
'value', config.value_lr, config.opt_eps, config.value_grad_clip, **kw)
def train(
self, start, objective=None, imagine=None, tape=None, repeats=None):
objective = objective or self._reward
self._update_slow_target()
metrics = {}
with (tape or tf.GradientTape()) as actor_tape:
assert bool(objective) != bool(imagine)
if objective:
imag_feat, imag_state, imag_action = self._imagine(
start, self.actor, self._config.imag_horizon, repeats)
reward = objective(imag_feat, imag_state, imag_action)
else:
imag_feat, imag_state, imag_action, reward = imagine(start)
actor_ent = self.actor(imag_feat, tf.float32).entropy()
state_ent = self._world_model.dynamics.get_dist(
imag_state, tf.float32).entropy()
target, weights = self._compute_target(
imag_feat, reward, actor_ent, state_ent,
self._config.slow_actor_target)
actor_loss, mets = self._compute_actor_loss(
imag_feat, imag_state, imag_action, target, actor_ent, state_ent,
weights)
metrics.update(mets)
if self._config.slow_value_target != self._config.slow_actor_target:
target, weights = self._compute_target(
imag_feat, reward, actor_ent, state_ent,
self._config.slow_value_target)
with tf.GradientTape() as value_tape:
value = self.value(imag_feat, tf.float32)[:-1]
value_loss = -value.log_prob(tf.stop_gradient(target))
if self._config.value_decay:
value_loss += self._config.value_decay * value.mode()
value_loss = tf.reduce_mean(weights[:-1] * value_loss)
metrics['reward_mean'] = tf.reduce_mean(reward)
metrics['reward_std'] = tf.math.reduce_std(reward)
metrics['actor_ent'] = tf.reduce_mean(actor_ent)
metrics.update(self._actor_opt(actor_tape, actor_loss, [self.actor]))
metrics.update(self._value_opt(value_tape, value_loss, [self.value]))
return imag_feat, imag_state, imag_action, weights, metrics
def _imagine(self, start, policy, horizon, repeats=None):
dynamics = self._world_model.dynamics
if repeats:
start = {k: tf.repeat(v, repeats, axis=1)
for k, v in start.items()}
def flatten(x): return tf.reshape(x, [-1] + list(x.shape[2:]))
start = {k: flatten(v) for k, v in start.items()}
def step(prev, _):
state, _, _ = prev
feat = dynamics.get_feat(state)
inp = tf.stop_gradient(feat) if self._stop_grad_actor else feat
action = policy(inp).sample()
succ = dynamics.img_step(
state, action, sample=self._config.imag_sample)
return succ, feat, action
feat = 0 * dynamics.get_feat(start)
action = policy(feat).mode()
succ, feats, actions = tools.static_scan(
step, tf.range(horizon), (start, feat, action))
states = {k: tf.concat([
start[k][None], v[:-1]], 0) for k, v in succ.items()}
if repeats:
def unfold(tensor):
s = tensor.shape
return tf.reshape(tensor, [s[0], s[1] // repeats, repeats] + s[2:])
states, feats, actions = tf.nest.map_structure(
unfold, (states, feats, actions))
return feats, states, actions
def _compute_target(self, imag_feat, reward, actor_ent, state_ent, slow):
reward = tf.cast(reward, tf.float32)
if 'discount' in self._world_model.heads:
discount = self._world_model.heads['discount'](
imag_feat, tf.float32).mean()
else:
discount = self._config.discount * tf.ones_like(reward)
if self._config.future_entropy and tf.greater(
self._config.actor_entropy(), 0):
reward += self._config.actor_entropy() * actor_ent
if self._config.future_entropy and tf.greater(
self._config.actor_state_entropy(), 0):
reward += self._config.actor_state_entropy() * state_ent
if slow:
value = self._slow_value(imag_feat, tf.float32).mode()
else:
value = self.value(imag_feat, tf.float32).mode()
target = tools.lambda_return(
reward[:-1], value[:-1], discount[:-1],
bootstrap=value[-1], lambda_=self._config.discount_lambda, axis=0)
weights = tf.stop_gradient(tf.math.cumprod(tf.concat(
[tf.ones_like(discount[:1]), discount[:-1]], 0), 0))
return target, weights
def _compute_actor_loss(
self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent,
weights):
metrics = {}
inp = tf.stop_gradient(
imag_feat) if self._stop_grad_actor else imag_feat
policy = self.actor(inp, tf.float32)
actor_ent = policy.entropy()
if self._config.imag_gradient == 'dynamics':
actor_target = target
elif self._config.imag_gradient == 'reinforce':
imag_action = tf.cast(imag_action, tf.float32)
actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient(
target - self.value(imag_feat[:-1], tf.float32).mode())
elif self._config.imag_gradient == 'both':
imag_action = tf.cast(imag_action, tf.float32)
actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient(
target - self.value(imag_feat[:-1], tf.float32).mode())
mix = self._config.imag_gradient_mix()
actor_target = mix * target + (1 - mix) * actor_target
metrics['imag_gradient_mix'] = mix
else:
raise NotImplementedError(self._config.imag_gradient)
if not self._config.future_entropy and tf.greater(
self._config.actor_entropy(), 0):
actor_target += self._config.actor_entropy() * actor_ent[:-1]
if not self._config.future_entropy and tf.greater(
self._config.actor_state_entropy(), 0):
actor_target += self._config.actor_state_entropy() * state_ent[:-1]
actor_loss = -tf.reduce_mean(weights[:-1] * actor_target)
return actor_loss, metrics
def _update_slow_target(self):
if self._config.slow_value_target or self._config.slow_actor_target:
if self._updates % self._config.slow_target_update == 0:
mix = self._config.slow_target_fraction
for s, d in zip(self.value.variables, self._slow_value.variables):
d.assign(mix * s + (1 - mix) * d)
self._updates.assign_add(1)
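
_compute_target delegates to tools.lambda_return (defined in tools.py below; its listing is truncated at the end of this diff). The quantity it computes is the standard TD(lambda) target R_t = r_t + pcont_t * ((1 - lambda) * v_{t+1} + lambda * R_{t+1}), bootstrapped from the final value estimate. A minimal numpy sketch of that recursion, not the exact tools implementation:

import numpy as np

def lambda_return_sketch(reward, value, pcont, bootstrap, lam):
    # reward, value, pcont: [horizon, batch]; bootstrap: [batch] (value at the final step).
    next_values = np.concatenate([value[1:], bootstrap[None]], 0)
    returns = np.zeros_like(reward)
    last = bootstrap
    for t in reversed(range(len(reward))):
        last = reward[t] + pcont[t] * ((1 - lam) * next_values[t] + lam * last)
        returns[t] = last
    return returns

H, B = 14, 16  # imag_horizon - 1 after the [:-1] slices, imagined batch size (illustrative)
r, v, pc = np.zeros((H, B)), np.ones((H, B)), np.full((H, B), 0.99)
print(lambda_return_sketch(r, v, pc, np.ones(B), lam=0.95).shape)  # (14, 16)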

DreamerV2/networks.py (new file, 465 lines)

@@ -0,0 +1,465 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow_probability import distributions as tfd
from tensorflow.keras.mixed_precision import experimental as prec
import tools
class RSSM(tools.Module):
def __init__(
self, stoch=30, deter=200, hidden=200, layers_input=1, layers_output=1,
shared=False, discrete=False, act=tf.nn.elu, mean_act='none',
std_act='softplus', min_std=0.1, cell='keras'):
super().__init__()
self._stoch = stoch
self._deter = deter
self._hidden = hidden
self._min_std = min_std
self._layers_input = layers_input
self._layers_output = layers_output
self._shared = shared
self._discrete = discrete
self._act = act
self._mean_act = mean_act
self._std_act = std_act
self._embed = None
if cell == 'gru':
self._cell = tfkl.GRUCell(self._deter)
elif cell == 'gru_layer_norm':
self._cell = GRUCell(self._deter, norm=True)
else:
raise NotImplementedError(cell)
def initial(self, batch_size):
dtype = prec.global_policy().compute_dtype
if self._discrete:
state = dict(
logit=tf.zeros(
[batch_size, self._stoch, self._discrete], dtype),
stoch=tf.zeros(
[batch_size, self._stoch, self._discrete], dtype),
deter=self._cell.get_initial_state(None, batch_size, dtype))
else:
state = dict(
mean=tf.zeros([batch_size, self._stoch], dtype),
std=tf.zeros([batch_size, self._stoch], dtype),
stoch=tf.zeros([batch_size, self._stoch], dtype),
deter=self._cell.get_initial_state(None, batch_size, dtype))
return state
@tf.function
def observe(self, embed, action, state=None):
def swap(x): return tf.transpose(
x, [1, 0] + list(range(2, len(x.shape))))
if state is None:
state = self.initial(tf.shape(action)[0])
embed, action = swap(embed), swap(action)
post, prior = tools.static_scan(
lambda prev, inputs: self.obs_step(prev[0], *inputs),
(action, embed), (state, state))
post = {k: swap(v) for k, v in post.items()}
prior = {k: swap(v) for k, v in prior.items()}
return post, prior
@tf.function
def imagine(self, action, state=None):
def swap(x): return tf.transpose(
x, [1, 0] + list(range(2, len(x.shape))))
if state is None:
state = self.initial(tf.shape(action)[0])
assert isinstance(state, dict), state
action = swap(action)
prior = tools.static_scan(self.img_step, action, state)
prior = {k: swap(v) for k, v in prior.items()}
return prior
def get_feat(self, state):
stoch = state['stoch']
if self._discrete:
shape = stoch.shape[:-2] + [self._stoch * self._discrete]
stoch = tf.reshape(stoch, shape)
return tf.concat([stoch, state['deter']], -1)
def get_dist(self, state, dtype=None):
if self._discrete:
logit = state['logit']
logit = tf.cast(logit, tf.float32)
dist = tfd.Independent(tools.OneHotDist(logit), 1)
if dtype != tf.float32:
dist = tools.DtypeDist(dist, dtype or state['logit'].dtype)
else:
mean, std = state['mean'], state['std']
if dtype:
mean = tf.cast(mean, dtype)
std = tf.cast(std, dtype)
dist = tfd.MultivariateNormalDiag(mean, std)
return dist
@tf.function
def obs_step(self, prev_state, prev_action, embed, sample=True):
if not self._embed:
self._embed = embed.shape[-1]
prior = self.img_step(prev_state, prev_action, None, sample)
if self._shared:
post = self.img_step(prev_state, prev_action, embed, sample)
else:
x = tf.concat([prior['deter'], embed], -1)
for i in range(self._layers_output):
x = self.get(f'obi{i}', tfkl.Dense, self._hidden, self._act)(x)
stats = self._suff_stats_layer('obs', x)
if sample:
stoch = self.get_dist(stats).sample()
else:
stoch = self.get_dist(stats).mode()
post = {'stoch': stoch, 'deter': prior['deter'], **stats}
return post, prior
@tf.function
def img_step(self, prev_state, prev_action, embed=None, sample=True):
prev_stoch = prev_state['stoch']
if self._discrete:
shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete]
prev_stoch = tf.reshape(prev_stoch, shape)
if self._shared:
if embed is None:
shape = prev_action.shape[:-1] + [self._embed]
embed = tf.zeros(shape, prev_action.dtype)
x = tf.concat([prev_stoch, prev_action, embed], -1)
else:
x = tf.concat([prev_stoch, prev_action], -1)
for i in range(self._layers_input):
x = self.get(f'ini{i}', tfkl.Dense, self._hidden, self._act)(x)
x, deter = self._cell(x, [prev_state['deter']])
deter = deter[0] # Keras wraps the state in a list.
for i in range(self._layers_output):
x = self.get(f'imo{i}', tfkl.Dense, self._hidden, self._act)(x)
stats = self._suff_stats_layer('ims', x)
if sample:
stoch = self.get_dist(stats).sample()
else:
stoch = self.get_dist(stats).mode()
prior = {'stoch': stoch, 'deter': deter, **stats}
return prior
def _suff_stats_layer(self, name, x):
if self._discrete:
x = self.get(name, tfkl.Dense, self._stoch *
self._discrete, None)(x)
logit = tf.reshape(x, x.shape[:-1] + [self._stoch, self._discrete])
return {'logit': logit}
else:
x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x)
mean, std = tf.split(x, 2, -1)
mean = {
'none': lambda: mean,
'tanh5': lambda: 5.0 * tf.math.tanh(mean / 5.0),
}[self._mean_act]()
std = {
'softplus': lambda: tf.nn.softplus(std),
'abs': lambda: tf.math.abs(std + 1),
'sigmoid': lambda: tf.nn.sigmoid(std),
'sigmoid2': lambda: 2 * tf.nn.sigmoid(std / 2),
}[self._std_act]()
std = std + self._min_std
return {'mean': mean, 'std': std}
def kl_loss(self, post, prior, balance, free, scale):
kld = tfd.kl_divergence
def dist(x): return self.get_dist(x, tf.float32)
if balance == 0.5:
value = kld(dist(prior), dist(post))
loss = tf.reduce_mean(tf.maximum(value, free))
else:
def sg(x): return tf.nest.map_structure(tf.stop_gradient, x)
value = kld(dist(prior), dist(sg(post)))
pri = tf.reduce_mean(value)
pos = tf.reduce_mean(kld(dist(sg(prior)), dist(post)))
pri, pos = tf.maximum(pri, free), tf.maximum(pos, free)
loss = balance * pri + (1 - balance) * pos
loss *= scale
return loss, value
class ConvEncoder(tools.Module):
def __init__(
self, depth=32, act=tf.nn.relu, kernels=(4, 4, 4, 4)):
self._act = act
self._depth = depth
self._kernels = kernels
def __call__(self, obs):
kwargs = dict(strides=2, activation=self._act)
Conv = tfkl.Conv2D
x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:]))
x = self.get('h1', Conv, 1 * self._depth,
self._kernels[0], **kwargs)(x)
x = self.get('h2', Conv, 2 * self._depth,
self._kernels[1], **kwargs)(x)
x = self.get('h3', Conv, 4 * self._depth,
self._kernels[2], **kwargs)(x)
x = self.get('h4', Conv, 8 * self._depth,
self._kernels[3], **kwargs)(x)
x = tf.reshape(x, [x.shape[0], np.prod(x.shape[1:])])
shape = tf.concat([tf.shape(obs['image'])[:-3], [x.shape[-1]]], 0)
return tf.reshape(x, shape)
class ConvDecoder(tools.Module):
def __init__(
self, depth=32, act=tf.nn.relu, shape=(64, 64, 3), kernels=(5, 5, 6, 6),
thin=True):
self._act = act
self._depth = depth
self._shape = shape
self._kernels = kernels
self._thin = thin
def __call__(self, features, dtype=None):
kwargs = dict(strides=2, activation=self._act)
ConvT = tfkl.Conv2DTranspose
if self._thin:
x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features)
x = tf.reshape(x, [-1, 1, 1, 32 * self._depth])
else:
x = self.get('h1', tfkl.Dense, 128 * self._depth, None)(features)
x = tf.reshape(x, [-1, 2, 2, 32 * self._depth])
x = self.get('h2', ConvT, 4 * self._depth,
self._kernels[0], **kwargs)(x)
x = self.get('h3', ConvT, 2 * self._depth,
self._kernels[1], **kwargs)(x)
x = self.get('h4', ConvT, 1 * self._depth,
self._kernels[2], **kwargs)(x)
x = self.get(
'h5', ConvT, self._shape[-1], self._kernels[3], strides=2)(x)
mean = tf.reshape(x, tf.concat(
[tf.shape(features)[:-1], self._shape], 0))
if dtype:
mean = tf.cast(mean, dtype)
return tfd.Independent(tfd.Normal(mean, 1), len(self._shape))
class ConvDecoderMask(tools.Module):
def __init__(
self, depth=32, act=tf.nn.relu, shape=(64, 64, 3), kernels=(5, 5, 6, 6),
thin=True):
self._act = act
self._depth = depth
self._shape = shape
self._kernels = kernels
self._thin = thin
def __call__(self, features, dtype=None):
kwargs = dict(strides=2, activation=self._act)
ConvT = tfkl.Conv2DTranspose
if self._thin:
x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features)
x = tf.reshape(x, [-1, 1, 1, 32 * self._depth])
else:
x = self.get('h1', tfkl.Dense, 128 * self._depth, None)(features)
x = tf.reshape(x, [-1, 2, 2, 32 * self._depth])
x = self.get('h2', ConvT, 4 * self._depth,
self._kernels[0], **kwargs)(x)
x = self.get('h3', ConvT, 2 * self._depth,
self._kernels[1], **kwargs)(x)
x = self.get('h4', ConvT, 1 * self._depth,
self._kernels[2], **kwargs)(x)
x = self.get(
'h5', ConvT, 2 * self._shape[-1], self._kernels[3], strides=2)(x)
mean, mask = tf.split(x, [self._shape[-1], self._shape[-1]], -1)
mean = tf.reshape(mean, tf.concat(
[tf.shape(features)[:-1], self._shape], 0))
mask = tf.reshape(mask, tf.concat(
[tf.shape(features)[:-1], self._shape], 0))
if dtype:
mean = tf.cast(mean, dtype)
mask = tf.cast(mask, dtype)
return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)), mask
class ConvDecoderMaskEnsemble(tools.Module):
"""
Ensembles two ConvDecoderMask outputs (<Normal, mask> pairs) into a single masked reconstruction.
NOTE: remove pred1/pred2 for maximum performance.
"""
def __init__(self, decoder1, decoder2):
self._decoder1 = decoder1
self._decoder2 = decoder2
self._shape = decoder1._shape
def __call__(self, feat1, feat2, dtype=None):
kwargs = dict(strides=1, activation=tf.nn.sigmoid)
pred1, mask1 = self._decoder1(feat1, dtype)
pred2, mask2 = self._decoder2(feat2, dtype)
mean1 = pred1.submodules[0].loc
mean2 = pred2.submodules[0].loc
mask_feat = tf.concat([mask1, mask2], -1)
mask = self.get('mask1', tfkl.Conv2D, 1, 1, **kwargs)(mask_feat)
mask_use1 = mask
mask_use2 = 1-mask
mean = mean1 * tf.cast(mask_use1, mean1.dtype) + \
mean2 * tf.cast(mask_use2, mean2.dtype)
return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)), pred1, pred2, tf.cast(mask_use1, mean1.dtype)
class DenseHead(tools.Module):
def __init__(
self, shape, layers, units, act=tf.nn.elu, dist='normal', std=1.0):
self._shape = (shape,) if isinstance(shape, int) else shape
self._layers = layers
self._units = units
self._act = act
self._dist = dist
self._std = std
def __call__(self, features, dtype=None):
x = features
for index in range(self._layers):
x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
mean = self.get(f'hmean', tfkl.Dense, np.prod(self._shape))(x)
mean = tf.reshape(mean, tf.concat(
[tf.shape(features)[:-1], self._shape], 0))
if self._std == 'learned':
std = self.get(f'hstd', tfkl.Dense, np.prod(self._shape))(x)
std = tf.nn.softplus(std) + 0.01
std = tf.reshape(std, tf.concat(
[tf.shape(features)[:-1], self._shape], 0))
else:
std = self._std
if dtype:
mean, std = tf.cast(mean, dtype), tf.cast(std, dtype)
if self._dist == 'normal':
return tfd.Independent(tfd.Normal(mean, std), len(self._shape))
if self._dist == 'huber':
return tfd.Independent(
tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape))
if self._dist == 'binary':
return tfd.Independent(tfd.Bernoulli(mean), len(self._shape))
raise NotImplementedError(self._dist)
class ActionHead(tools.Module):
def __init__(
self, size, layers, units, act=tf.nn.elu, dist='trunc_normal',
init_std=0.0, min_std=0.1, action_disc=5, temp=0.1, outscale=0):
# assert min_std <= 2
self._size = size
self._layers = layers
self._units = units
self._dist = dist
self._act = act
self._min_std = min_std
self._init_std = init_std
self._action_disc = action_disc
self._temp = temp() if callable(temp) else temp
self._outscale = outscale
def __call__(self, features, dtype=None):
x = features
for index in range(self._layers):
kw = {}
if index == self._layers - 1 and self._outscale:
kw['kernel_initializer'] = tf.keras.initializers.VarianceScaling(
self._outscale)
x = self.get(f'h{index}', tfkl.Dense,
self._units, self._act, **kw)(x)
if self._dist == 'tanh_normal':
# https://www.desmos.com/calculator/rcmcf5jwe7
x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
if dtype:
x = tf.cast(x, dtype)
mean, std = tf.split(x, 2, -1)
mean = tf.tanh(mean)
std = tf.nn.softplus(std + self._init_std) + self._min_std
dist = tfd.Normal(mean, std)
dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
dist = tfd.Independent(dist, 1)
dist = tools.SampleDist(dist)
elif self._dist == 'tanh_normal_5':
x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
if dtype:
x = tf.cast(x, dtype)
mean, std = tf.split(x, 2, -1)
mean = 5 * tf.tanh(mean / 5)
std = tf.nn.softplus(std + 5) + 5
dist = tfd.Normal(mean, std)
dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
dist = tfd.Independent(dist, 1)
dist = tools.SampleDist(dist)
elif self._dist == 'normal':
x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
if dtype:
x = tf.cast(x, dtype)
mean, std = tf.split(x, 2, -1)
std = tf.nn.softplus(std + self._init_std) + self._min_std
dist = tfd.Normal(mean, std)
dist = tfd.Independent(dist, 1)
elif self._dist == 'normal_1':
mean = self.get(f'hout', tfkl.Dense, self._size)(x)
if dtype:
mean = tf.cast(mean, dtype)
dist = tfd.Normal(mean, 1)
dist = tfd.Independent(dist, 1)
elif self._dist == 'trunc_normal':
# https://www.desmos.com/calculator/mmuvuhnyxo
x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
x = tf.cast(x, tf.float32)
mean, std = tf.split(x, 2, -1)
mean = tf.tanh(mean)
std = 2 * tf.nn.sigmoid(std / 2) + self._min_std
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
dist = tools.DtypeDist(dist, dtype)
dist = tfd.Independent(dist, 1)
elif self._dist == 'onehot':
x = self.get(f'hout', tfkl.Dense, self._size)(x)
x = tf.cast(x, tf.float32)
dist = tools.OneHotDist(x, dtype=dtype)
dist = tools.DtypeDist(dist, dtype)
elif self._dist == 'onehot_gumble':
x = self.get(f'hout', tfkl.Dense, self._size)(x)
if dtype:
x = tf.cast(x, dtype)
temp = self._temp
dist = tools.GumbleDist(temp, x, dtype=dtype)
else:
raise NotImplementedError(self._dist)
return dist
class GRUCell(tf.keras.layers.AbstractRNNCell):
def __init__(self, size, norm=False, act=tf.tanh, update_bias=-1, **kwargs):
super().__init__()
self._size = size
self._act = act
self._norm = norm
self._update_bias = update_bias
self._layer = tfkl.Dense(3 * size, use_bias=norm is not None, **kwargs)
if norm:
self._norm = tfkl.LayerNormalization(dtype=tf.float32)
@property
def state_size(self):
return self._size
def call(self, inputs, state):
state = state[0] # Keras wraps the state in a list.
parts = self._layer(tf.concat([inputs, state], -1))
if self._norm:
dtype = parts.dtype
parts = tf.cast(parts, tf.float32)
parts = self._norm(parts)
parts = tf.cast(parts, dtype)
reset, cand, update = tf.split(parts, 3, -1)
reset = tf.nn.sigmoid(reset)
cand = self._act(reset * cand)
update = tf.nn.sigmoid(update + self._update_bias)
output = update * cand + (1 - update) * state
return output, [output]
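
ConvDecoderMaskEnsemble above fuses the two decoders with a learned per-pixel sigmoid mask: a 1x1 convolution over the concatenated mask features produces mask, and the joint mean is mask * mean1 + (1 - mask) * mean2. A tiny numpy illustration of that blend (random arrays stand in for the decoder outputs):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

mean_main = np.random.randn(4, 64, 64, 3)    # main decoder mean (task branch)
mean_disen = np.random.randn(4, 64, 64, 3)   # disen decoder mean (distractor branch)
mask_logits = np.random.randn(4, 64, 64, 1)  # stand-in for the 1x1 conv over mask features

mask = sigmoid(mask_logits)
joint_mean = mask * mean_main + (1 - mask) * mean_disen
print(joint_mean.shape)  # (4, 64, 64, 3)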

DreamerV2/tools.py (new file, 694 lines)

@@ -0,0 +1,694 @@
import datetime
import io
import json
import pathlib
import pickle
import re
import time
import uuid
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_probability as tfp
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd
# Patch to ignore seed to avoid synchronization across GPUs.
_orig_random_categorical = tf.random.categorical
def random_categorical(*args, **kwargs):
kwargs['seed'] = None
return _orig_random_categorical(*args, **kwargs)
tf.random.categorical = random_categorical
# Patch to ignore seed to avoid synchronization across GPUs.
_orig_random_normal = tf.random.normal
def random_normal(*args, **kwargs):
kwargs['seed'] = None
return _orig_random_normal(*args, **kwargs)
tf.random.normal = random_normal
class AttrDict(dict):
__setattr__ = dict.__setitem__
__getattr__ = dict.__getitem__
class Module(tf.Module):
def save(self, filename):
values = tf.nest.map_structure(lambda x: x.numpy(), self.variables)
amount = len(tf.nest.flatten(values))
count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
print(f'Save checkpoint with {amount} tensors and {count} parameters.')
with pathlib.Path(filename).open('wb') as f:
pickle.dump(values, f)
def load(self, filename):
with pathlib.Path(filename).open('rb') as f:
values = pickle.load(f)
amount = len(tf.nest.flatten(values))
count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
print(f'Load checkpoint with {amount} tensors and {count} parameters.')
tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values)
def get(self, name, ctor, *args, **kwargs):
# Create or get layer by name to avoid mentioning it in the constructor.
if not hasattr(self, '_modules'):
self._modules = {}
if name not in self._modules:
self._modules[name] = ctor(*args, **kwargs)
return self._modules[name]
def var_nest_names(nest):
if isinstance(nest, dict):
items = ' '.join(f'{k}:{var_nest_names(v)}' for k, v in nest.items())
return '{' + items + '}'
if isinstance(nest, (list, tuple)):
items = ' '.join(var_nest_names(v) for v in nest)
return '[' + items + ']'
if hasattr(nest, 'name') and hasattr(nest, 'shape'):
return nest.name + str(nest.shape).replace(', ', 'x')
if hasattr(nest, 'shape'):
return str(nest.shape).replace(', ', 'x')
return '?'
class Logger:
def __init__(self, logdir, step):
self._logdir = logdir
self._writer = tf.summary.create_file_writer(str(logdir), max_queue=1000)
self._last_step = None
self._last_time = None
self._scalars = {}
self._images = {}
self._videos = {}
self.step = step
def scalar(self, name, value):
self._scalars[name] = float(value)
def image(self, name, value):
self._images[name] = np.array(value)
def video(self, name, value):
self._videos[name] = np.array(value)
def write(self, fps=False):
scalars = list(self._scalars.items())
if fps:
scalars.append(('fps', self._compute_fps(self.step)))
print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
with (self._logdir / 'metrics.jsonl').open('a') as f:
f.write(json.dumps({'step': self.step, ** dict(scalars)}) + '\n')
with self._writer.as_default():
for name, value in scalars:
tf.summary.scalar('scalars/' + name, value, self.step)
for name, value in self._images.items():
tf.summary.image(name, value, self.step)
for name, value in self._videos.items():
video_summary(name, value, self.step)
self._writer.flush()
self._scalars = {}
self._images = {}
self._videos = {}
def _compute_fps(self, step):
if self._last_step is None:
self._last_time = time.time()
self._last_step = step
return 0
steps = step - self._last_step
duration = time.time() - self._last_time
self._last_time += duration
self._last_step = step
return steps / duration
def graph_summary(writer, step, fn, *args):
def inner(*args):
tf.summary.experimental.set_step(step.numpy().item())
with writer.as_default():
fn(*args)
return tf.numpy_function(inner, args, [])
def video_summary(name, video, step=None, fps=20):
name = name if isinstance(name, str) else name.decode('utf-8')
if np.issubdtype(video.dtype, np.floating):
video = np.clip(255 * video, 0, 255).astype(np.uint8)
B, T, H, W, C = video.shape
try:
frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C))
summary = tf1.Summary()
image = tf1.Summary.Image(height=B * H, width=T * W, colorspace=C)
image.encoded_image_string = encode_gif(frames, fps)
summary.value.add(tag=name, image=image)
tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
except (IOError, OSError) as e:
print('GIF summaries require ffmpeg in $PATH.', e)
frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C))
tf.summary.image(name, frames, step)
def encode_gif(frames, fps):
from subprocess import Popen, PIPE
h, w, c = frames[0].shape
pxfmt = {1: 'gray', 3: 'rgb24'}[c]
cmd = ' '.join([
f'ffmpeg -y -f rawvideo -vcodec rawvideo',
f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex',
f'[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
f'-r {fps:.02f} -f gif -'])
proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE)
for image in frames:
proc.stdin.write(image.tobytes())
out, err = proc.communicate()
if proc.returncode:
raise IOError('\n'.join([cmd, err.decode('utf8')]))
del proc
return out
def simulate(agent, envs, steps=0, episodes=0, state=None):
# Initialize or unpack simulation state.
if state is None:
step, episode = 0, 0
done = np.ones(len(envs), bool)
length = np.zeros(len(envs), np.int32)
obs = [None] * len(envs)
agent_state = None
else:
step, episode, done, length, obs, agent_state = state
while (steps and step < steps) or (episodes and episode < episodes):
# Reset envs if necessary.
if done.any():
indices = [index for index, d in enumerate(done) if d]
results = [envs[i].reset() for i in indices]
for index, result in zip(indices, results):
obs[index] = result
# Step agents.
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
action, agent_state = agent(obs, done, agent_state)
if isinstance(action, dict):
action = [
{k: np.array(action[k][i]) for k in action}
for i in range(len(envs))]
else:
action = np.array(action)
assert len(action) == len(envs)
# Step envs.
results = [e.step(a) for e, a in zip(envs, action)]
obs, _, done = zip(*[p[:3] for p in results])
obs = list(obs)
done = np.stack(done)
episode += int(done.sum())
length += 1
step += (done * length).sum()
length *= (1 - done)
# Return new state to allow resuming the simulation.
return (step - steps, episode - episodes, done, length, obs, agent_state)
def save_episodes(directory, episodes):
directory = pathlib.Path(directory).expanduser()
directory.mkdir(parents=True, exist_ok=True)
timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
filenames = []
for episode in episodes:
identifier = str(uuid.uuid4().hex)
length = len(episode['reward'])
filename = directory / f'{timestamp}-{identifier}-{length}.npz'
with io.BytesIO() as f1:
np.savez_compressed(f1, **episode)
f1.seek(0)
with filename.open('wb') as f2:
f2.write(f1.read())
filenames.append(filename)
return filenames
def sample_episodes(episodes, length=None, balance=False, seed=0):
random = np.random.RandomState(seed)
while True:
episode = random.choice(list(episodes.values()))
if length:
total = len(next(iter(episode.values())))
available = total - length
if available < 1:
# print(f'Skipped short episode of length {available}.')
continue
if balance:
index = min(random.randint(0, total), available)
else:
index = int(random.randint(0, available + 1))
episode = {k: v[index: index + length] for k, v in episode.items()}
yield episode
def load_episodes(directory, limit=None):
directory = pathlib.Path(directory).expanduser()
episodes = {}
total = 0
for filename in reversed(sorted(directory.glob('*.npz'))):
try:
with filename.open('rb') as f:
episode = np.load(f)
episode = {k: episode[k] for k in episode.keys()}
except Exception as e:
print(f'Could not load episode: {e}')
continue
episodes[str(filename)] = episode
total += len(episode['reward']) - 1
if limit and total >= limit:
break
return episodes
class DtypeDist:
def __init__(self, dist, dtype=None):
self._dist = dist
self._dtype = dtype or prec.global_policy().compute_dtype
@property
def name(self):
return 'DtypeDist'
def __getattr__(self, name):
return getattr(self._dist, name)
def mean(self):
return tf.cast(self._dist.mean(), self._dtype)
def mode(self):
return tf.cast(self._dist.mode(), self._dtype)
def entropy(self):
return tf.cast(self._dist.entropy(), self._dtype)
def sample(self, *args, **kwargs):
return tf.cast(self._dist.sample(*args, **kwargs), self._dtype)
class SampleDist:
def __init__(self, dist, samples=100):
self._dist = dist
self._samples = samples
@property
def name(self):
return 'SampleDist'
def __getattr__(self, name):
return getattr(self._dist, name)
def mean(self):
samples = self._dist.sample(self._samples)
return tf.reduce_mean(samples, 0)
def mode(self):
sample = self._dist.sample(self._samples)
logprob = self._dist.log_prob(sample)
return tf.gather(sample, tf.argmax(logprob))[0]
def entropy(self):
sample = self._dist.sample(self._samples)
logprob = self.log_prob(sample)
return -tf.reduce_mean(logprob, 0)
class OneHotDist(tfd.OneHotCategorical):
def __init__(self, logits=None, probs=None, dtype=None):
self._sample_dtype = dtype or prec.global_policy().compute_dtype
super().__init__(logits=logits, probs=probs)
def mode(self):
return tf.cast(super().mode(), self._sample_dtype)
def sample(self, sample_shape=(), seed=None):
# Straight through biased gradient estimator.
sample = tf.cast(super().sample(sample_shape, seed), self._sample_dtype)
probs = super().probs_parameter()
while len(probs.shape) < len(sample.shape):
probs = probs[None]
sample += tf.cast(probs - tf.stop_gradient(probs), self._sample_dtype)
return sample
class GumbleDist(tfd.RelaxedOneHotCategorical):
def __init__(self, temp, logits=None, probs=None, dtype=None):
self._sample_dtype = dtype or prec.global_policy().compute_dtype
self._exact = tfd.OneHotCategorical(logits=logits, probs=probs)
super().__init__(temp, logits=logits, probs=probs)
def mode(self):
return tf.cast(self._exact.mode(), self._sample_dtype)
def entropy(self):
return tf.cast(self._exact.entropy(), self._sample_dtype)
def sample(self, sample_shape=(), seed=None):
return tf.cast(super().sample(sample_shape, seed), self._sample_dtype)
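# Normal distribution whose log_prob is the negative pseudo-Huber distance
# -(sqrt((x - mean)^2 + threshold^2) - threshold) instead of the Gaussian log
# density, which keeps gradient magnitudes bounded for large errors.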
class UnnormalizedHuber(tfd.Normal):
def __init__(self, loc, scale, threshold=1, **kwargs):
self._threshold = tf.cast(threshold, loc.dtype)
super().__init__(loc, scale, **kwargs)
def log_prob(self, event):
return -(tf.math.sqrt(
(event - self.mean()) ** 2 + self._threshold ** 2) - self._threshold)
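# Truncated Normal that clips samples slightly inside [low, high] via a
# straight-through estimator so log-probabilities near the boundaries stay
# finite; `mult` optionally rescales the samples.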
class SafeTruncatedNormal(tfd.TruncatedNormal):
def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
super().__init__(loc, scale, low, high)
self._clip = clip
self._mult = mult
def sample(self, *args, **kwargs):
event = super().sample(*args, **kwargs)
if self._clip:
clipped = tf.clip_by_value(
event, self.low + self._clip, self.high - self._clip)
event = event - tf.stop_gradient(event) + tf.stop_gradient(clipped)
if self._mult:
event *= self._mult
return event
class TanhBijector(tfp.bijectors.Bijector):
def __init__(self, validate_args=False, name='tanh'):
super().__init__(
forward_min_event_ndims=0,
validate_args=validate_args,
name=name)
def _forward(self, x):
return tf.nn.tanh(x)
def _inverse(self, y):
dtype = y.dtype
y = tf.cast(y, tf.float32)
y = tf.where(
tf.less_equal(tf.abs(y), 1.),
tf.clip_by_value(y, -0.99999997, 0.99999997), y)
y = tf.atanh(y)
y = tf.cast(y, dtype)
return y
def _forward_log_det_jacobian(self, x):
log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype))
return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x))
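# Computes lambda-returns backwards along the given axis via the recursion
# R_t = r_t + pcont_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}),
# bootstrapped with the provided value beyond the horizon.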
def lambda_return(
reward, value, pcont, bootstrap, lambda_, axis):
# Setting lambda=1 gives a discounted Monte Carlo return.
# Setting lambda=0 gives a fixed 1-step return.
assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
if isinstance(pcont, (int, float)):
pcont = pcont * tf.ones_like(reward)
dims = list(range(reward.shape.ndims))
dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
if axis != 0:
reward = tf.transpose(reward, dims)
value = tf.transpose(value, dims)
pcont = tf.transpose(pcont, dims)
if bootstrap is None:
bootstrap = tf.zeros_like(value[-1])
next_values = tf.concat([value[1:], bootstrap[None]], 0)
inputs = reward + pcont * next_values * (1 - lambda_)
returns = static_scan(
lambda agg, cur: cur[0] + cur[1] * lambda_ * agg,
(inputs, pcont), bootstrap, reverse=True)
if axis != 0:
returns = tf.transpose(returns, dims)
return returns
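# Wraps a Keras optimizer with optional gradient clipping by global norm,
# multiplicative weight decay, and dynamic loss scaling when the global
# mixed-precision policy uses float16.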
class Optimizer(tf.Module):
def __init__(
self, name, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*',
opt='adam'):
assert wd is None or 0 <= wd < 1
assert not clip or 1 <= clip
self._name = name
self._clip = clip
self._wd = wd
self._wd_pattern = wd_pattern
self._opt = {
'adam': lambda: tf.optimizers.Adam(lr, epsilon=eps),
'nadam': lambda: tf.optimizers.Nadam(lr, epsilon=eps),
'adamax': lambda: tf.optimizers.Adamax(lr, epsilon=eps),
'sgd': lambda: tf.optimizers.SGD(lr),
'momentum': lambda: tf.optimizers.SGD(lr, 0.9),
}[opt]()
self._mixed = (prec.global_policy().compute_dtype == tf.float16)
if self._mixed:
self._opt = prec.LossScaleOptimizer(self._opt, 'dynamic')
@property
def variables(self):
return self._opt.variables()
def __call__(self, tape, loss, modules, prefix=None):
assert loss.dtype is tf.float32, self._name
modules = modules if hasattr(modules, '__len__') else (modules,)
varibs = tf.nest.flatten([module.variables for module in modules])
count = sum(np.prod(x.shape) for x in varibs)
print(f'Found {count} {self._name} parameters.')
assert len(loss.shape) == 0, loss.shape
tf.debugging.check_numerics(loss, self._name + '_loss')
if self._mixed:
with tape:
loss = self._opt.get_scaled_loss(loss)
grads = tape.gradient(loss, varibs)
if self._mixed:
grads = self._opt.get_unscaled_gradients(grads)
norm = tf.linalg.global_norm(grads)
if not self._mixed:
tf.debugging.check_numerics(norm, self._name + '_norm')
if self._clip:
grads, _ = tf.clip_by_global_norm(grads, self._clip, norm)
if self._wd:
self._apply_weight_decay(varibs)
self._opt.apply_gradients(zip(grads, varibs))
metrics = {}
if prefix:
metrics[f'{prefix}/{self._name}_loss'] = loss
metrics[f'{prefix}/{self._name}_grad_norm'] = norm
if self._mixed:
metrics[f'{prefix}/{self._name}_loss_scale'] = \
self._opt.loss_scale._current_loss_scale
else:
metrics[f'{self._name}_loss'] = loss
metrics[f'{self._name}_grad_norm'] = norm
if self._mixed:
metrics[f'{self._name}_loss_scale'] = \
self._opt.loss_scale._current_loss_scale
return metrics
def _apply_weight_decay(self, varibs):
nontrivial = (self._wd_pattern != r'.*')
if nontrivial:
print('Applied weight decay to variables:')
for var in varibs:
if re.search(self._wd_pattern, self._name + '/' + var.name):
if nontrivial:
print('- ' + self._name + '/' + var.name)
var.assign((1 - self._wd) * var)
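# Returns an argparse `type` callable that parses a string according to the
# type of the config default: booleans from 'True'/'False', ints or floats,
# and comma-separated tuples (e.g. '64,64' -> (64, 64) for a list default).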
def args_type(default):
def parse_string(x):
if default is None:
return x
if isinstance(default, bool):
return bool(['False', 'True'].index(x))
if isinstance(default, int):
return float(x) if ('e' in x or '.' in x) else int(x)
if isinstance(default, (list, tuple)):
return tuple(args_type(default[0])(y) for y in x.split(','))
return type(default)(x)
def parse_object(x):
if isinstance(default, (list, tuple)):
return tuple(x)
return x
return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x)
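# Python-loop equivalent of tf.scan over the leading axis of the nested
# inputs, optionally in reverse; used e.g. by lambda_return above.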
def static_scan(fn, inputs, start, reverse=False):
last = start
outputs = [[] for _ in tf.nest.flatten(start)]
indices = range(len(tf.nest.flatten(inputs)[0]))
if reverse:
indices = reversed(indices)
for index in indices:
inp = tf.nest.map_structure(lambda x: x[index], inputs)
last = fn(last, inp)
[o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))]
if reverse:
outputs = [list(reversed(x)) for x in outputs]
outputs = [tf.stack(x, 0) for x in outputs]
return tf.nest.pack_sequence_as(start, outputs)
def uniform_mixture(dist, dtype=None):
if dist.batch_shape[-1] == 1:
return tfd.BatchReshape(dist, dist.batch_shape[:-1])
dtype = dtype or prec.global_policy().compute_dtype
weights = tfd.Categorical(tf.zeros(dist.batch_shape, dtype))
return tfd.MixtureSameFamily(weights, dist)
def cat_mixture_entropy(dist):
if isinstance(dist, tfd.MixtureSameFamily):
probs = dist.components_distribution.probs_parameter()
else:
probs = dist.probs_parameter()
return -tf.reduce_mean(
tf.reduce_mean(probs, 2) *
tf.math.log(tf.reduce_mean(probs, 2) + 1e-8), -1)
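# Cross-entropy method planner: samples `proposals` action sequences from a
# diagonal Gaussian, scores them with imagined rollouts, refits the Gaussian
# to the top-k sequences, and returns the first action of the final mean.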
@tf.function
def cem_planner(
state, num_actions, horizon, proposals, topk, iterations, imagine,
objective):
dtype = prec.global_policy().compute_dtype
B, P = list(state.values())[0].shape[0], proposals
H, A = horizon, num_actions
flat_state = {k: tf.repeat(v, P, 0) for k, v in state.items()}
mean = tf.zeros((B, H, A), dtype)
std = tf.ones((B, H, A), dtype)
for _ in range(iterations):
proposals = tf.random.normal((B, P, H, A), dtype=dtype)
proposals = proposals * std[:, None] + mean[:, None]
proposals = tf.clip_by_value(proposals, -1, 1)
flat_proposals = tf.reshape(proposals, (B * P, H, A))
states = imagine(flat_proposals, flat_state)
scores = objective(states)
scores = tf.reshape(tf.reduce_sum(scores, -1), (B, P))
_, indices = tf.math.top_k(scores, topk, sorted=False)
best = tf.gather(proposals, indices, axis=1, batch_dims=1)
mean, var = tf.nn.moments(best, 1)
std = tf.sqrt(var + 1e-6)
return mean[:, 0, :]
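# Gradient-based planner: optimizes the mean and (softplus) std of the action
# distribution by ascending an ELBO-style objective (imagined return minus a
# KL to a standard Normal), using RMS-normalized gradient steps.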
@tf.function
def grad_planner(
state, num_actions, horizon, proposals, iterations, imagine, objective,
kl_scale, step_size):
dtype = prec.global_policy().compute_dtype
B, P = list(state.values())[0].shape[0], proposals
H, A = horizon, num_actions
flat_state = {k: tf.repeat(v, P, 0) for k, v in state.items()}
mean = tf.zeros((B, H, A), dtype)
rawstd = 0.54 * tf.ones((B, H, A), dtype)
for _ in range(iterations):
proposals = tf.random.normal((B, P, H, A), dtype=dtype)
with tf.GradientTape(watch_accessed_variables=False) as tape:
tape.watch(mean)
tape.watch(rawstd)
std = tf.nn.softplus(rawstd)
proposals = proposals * std[:, None] + mean[:, None]
proposals = (
tf.stop_gradient(tf.clip_by_value(proposals, -1, 1)) +
proposals - tf.stop_gradient(proposals))
flat_proposals = tf.reshape(proposals, (B * P, H, A))
states = imagine(flat_proposals, flat_state)
scores = objective(states)
scores = tf.reshape(tf.reduce_sum(scores, -1), (B, P))
div = tfd.kl_divergence(
tfd.Normal(mean, std),
tfd.Normal(tf.zeros_like(mean), tf.ones_like(std)))
elbo = tf.reduce_sum(scores) - kl_scale * div
elbo /= tf.cast(tf.reduce_prod(tf.shape(scores)), dtype)
grad_mean, grad_rawstd = tape.gradient(elbo, [mean, rawstd])
e, v = tf.nn.moments(grad_mean, [1, 2], keepdims=True)
grad_mean /= tf.sqrt(e * e + v + 1e-4)
e, v = tf.nn.moments(grad_rawstd, [1, 2], keepdims=True)
grad_rawstd /= tf.sqrt(e * e + v + 1e-4)
mean = tf.clip_by_value(mean + step_size * grad_mean, -1, 1)
rawstd = rawstd + step_size * grad_rawstd
return mean[:, 0, :]
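# Small scheduling helpers: Every fires at a fixed step interval, Once fires
# exactly one time, and Until is true until the given step budget is reached
# (or always, if the budget is falsy).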
class Every:
def __init__(self, every):
self._every = every
self._last = None
def __call__(self, step):
if not self._every:
return False
if self._last is None:
self._last = step
return True
if step >= self._last + self._every:
self._last += self._every
return True
return False
class Once:
def __init__(self):
self._once = True
def __call__(self):
if self._once:
self._once = False
return True
return False
class Until:
def __init__(self, until):
self._until = until
def __call__(self, step):
if not self._until:
return True
return step < self._until
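# Parses schedule strings from the config: a plain float is constant,
# 'linear(initial,final,duration)' interpolates over `duration` steps,
# 'warmup(steps,value)' ramps up linearly, and 'exp(initial,final,halflife)'
# decays exponentially, e.g. 'linear(3e-3,3e-4,2.5e6)'.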
def schedule(string, step):
try:
return float(string)
except ValueError:
step = tf.cast(step, tf.float32)
match = re.match(r'linear\((.+),(.+),(.+)\)', string)
if match:
initial, final, duration = [float(group) for group in match.groups()]
mix = tf.clip_by_value(step / duration, 0, 1)
return (1 - mix) * initial + mix * final
match = re.match(r'warmup\((.+),(.+)\)', string)
if match:
warmup, value = [float(group) for group in match.groups()]
scale = tf.clip_by_value(step / warmup, 0, 1)
return scale * value
match = re.match(r'exp\((.+),(.+),(.+)\)', string)
if match:
initial, final, halflife = [float(group) for group in match.groups()]
return (initial - final) * 0.5 ** (step / halflife) + final
raise NotImplementedError(string)

280
DreamerV2/wrappers.py Normal file
View File

@ -0,0 +1,280 @@
import threading
import gym
import numpy as np
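# Exposes dm_control suite tasks through a gym-style API with action repeat
# and rendered image observations alongside the proprioceptive state.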
class DeepMindControl:
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
domain, task = name.split('_', 1)
if domain == 'cup': # Only domain with multiple words.
domain = 'ball_in_cup'
if isinstance(domain, str):
from dm_control import suite
self._env = suite.load(domain, task)
else:
assert task is None
self._env = domain()
self._action_repeat = action_repeat
self._size = size
if camera is None:
camera = dict(quadruped=2).get(domain, 0)
self._camera = camera
@property
def observation_space(self):
spaces = {}
for key, value in self._env.observation_spec().items():
spaces[key] = gym.spaces.Box(
-np.inf, np.inf, value.shape, dtype=np.float32)
spaces['image'] = gym.spaces.Box(
0, 255, self._size + (3,), dtype=np.uint8)
return gym.spaces.Dict(spaces)
@property
def action_space(self):
spec = self._env.action_spec()
return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)
def step(self, action):
assert np.isfinite(action).all(), action
reward = 0
for _ in range(self._action_repeat):
time_step = self._env.step(action)
reward += time_step.reward or 0
if time_step.last():
break
obs = dict(time_step.observation)
obs['image'] = self.render()
done = time_step.last()
info = {'discount': np.array(time_step.discount, np.float32)}
return obs, reward, done, info
def reset(self):
time_step = self._env.reset()
obs = dict(time_step.observation)
obs['image'] = self.render()
return obs
def render(self, *args, **kwargs):
if kwargs.get('mode', 'rgb_array') != 'rgb_array':
raise ValueError("Only render mode 'rgb_array' is supported.")
return self._env.physics.render(*self._size, camera_id=self._camera)
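# ALE environment with optional sticky actions, wrapped in gym's
# AtariPreprocessing (no-op starts, frame skipping, resizing, grayscale);
# observations include both the image and the 128-byte RAM state.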
class Atari:
LOCK = threading.Lock()
def __init__(
self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30,
life_done=False, sticky_actions=True, all_actions=False):
assert size[0] == size[1]
import gym.wrappers
import gym.envs.atari
with self.LOCK:
env = gym.envs.atari.AtariEnv(
game=name, obs_type='image', frameskip=1,
repeat_action_probability=0.25 if sticky_actions else 0.0,
full_action_space=all_actions)
# Avoid unnecessary rendering in inner env.
env._get_obs = lambda: None
# Tell wrapper that the inner env has no action repeat.
env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0')
env = gym.wrappers.AtariPreprocessing(
env, noops, action_repeat, size[0], life_done, grayscale)
self._env = env
self._grayscale = grayscale
@property
def observation_space(self):
return gym.spaces.Dict({
'image': self._env.observation_space,
'ram': gym.spaces.Box(0, 255, (128,), np.uint8),
})
@property
def action_space(self):
return self._env.action_space
def close(self):
return self._env.close()
def reset(self):
with self.LOCK:
image = self._env.reset()
if self._grayscale:
image = image[..., None]
obs = {'image': image, 'ram': self._env.env._get_ram()}
return obs
def step(self, action):
image, reward, done, info = self._env.step(action)
if self._grayscale:
image = image[..., None]
obs = {'image': image, 'ram': self._env.env._get_ram()}
return obs, reward, done, info
def render(self, mode):
return self._env.render(mode)
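# Records every transition and, when an episode finishes, assembles it into
# arrays of the requested precision and passes it to the given callbacks
# (e.g. for saving to disk) via info['episode'].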
class CollectDataset:
def __init__(self, env, callbacks=None, precision=32):
self._env = env
self._callbacks = callbacks or ()
self._precision = precision
self._episode = None
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
obs, reward, done, info = self._env.step(action)
obs = {k: self._convert(v) for k, v in obs.items()}
transition = obs.copy()
transition['action'] = action
transition['reward'] = reward
transition['discount'] = info.get('discount', np.array(1 - float(done)))
self._episode.append(transition)
if done:
episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
episode = {k: self._convert(v) for k, v in episode.items()}
info['episode'] = episode
for callback in self._callbacks:
callback(episode)
return obs, reward, done, info
def reset(self):
obs = self._env.reset()
transition = obs.copy()
transition['action'] = np.zeros(self._env.action_space.shape)
transition['reward'] = 0.0
transition['discount'] = 1.0
self._episode = [transition]
return obs
def _convert(self, value):
value = np.array(value)
if np.issubdtype(value.dtype, np.floating):
dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
elif np.issubdtype(value.dtype, np.signedinteger):
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
elif np.issubdtype(value.dtype, np.uint8):
dtype = np.uint8
else:
raise NotImplementedError(value.dtype)
return value.astype(dtype)
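# Ends episodes after a fixed number of steps and marks the truncation with
# discount 1.0 so it is not treated as a true terminal state.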
class TimeLimit:
def __init__(self, env, duration):
self._env = env
self._duration = duration
self._step = None
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
assert self._step is not None, 'Must reset environment.'
obs, reward, done, info = self._env.step(action)
self._step += 1
if self._step >= self._duration:
done = True
if 'discount' not in info:
info['discount'] = np.array(1.0).astype(np.float32)
self._step = None
return obs, reward, done, info
def reset(self):
self._step = 0
return self._env.reset()
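# Rescales actions from [-1, 1] to the environment's finite action bounds,
# leaving unbounded dimensions untouched.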
class NormalizeActions:
def __init__(self, env):
self._env = env
self._mask = np.logical_and(
np.isfinite(env.action_space.low),
np.isfinite(env.action_space.high))
self._low = np.where(self._mask, env.action_space.low, -1)
self._high = np.where(self._mask, env.action_space.high, 1)
def __getattr__(self, name):
return getattr(self._env, name)
@property
def action_space(self):
low = np.where(self._mask, -np.ones_like(self._low), self._low)
high = np.where(self._mask, np.ones_like(self._low), self._high)
return gym.spaces.Box(low, high, dtype=np.float32)
def step(self, action):
original = (action + 1) / 2 * (self._high - self._low) + self._low
original = np.where(self._mask, original, action)
return self._env.step(original)
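# Exposes a discrete action space as one-hot vectors and validates that
# incoming actions are proper one-hot encodings.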
class OneHotAction:
def __init__(self, env):
assert isinstance(env.action_space, gym.spaces.Discrete)
self._env = env
self._random = np.random.RandomState()
def __getattr__(self, name):
return getattr(self._env, name)
@property
def action_space(self):
shape = (self._env.action_space.n,)
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
space.sample = self._sample_action
return space
def step(self, action):
index = np.argmax(action).astype(int)
reference = np.zeros_like(action)
reference[index] = 1
if not np.allclose(reference, action):
raise ValueError(f'Invalid one-hot action:\n{action}')
return self._env.step(index)
def reset(self):
return self._env.reset()
def _sample_action(self):
actions = self._env.action_space.n
index = self._random.randint(0, actions)
reference = np.zeros(actions, dtype=np.float32)
reference[index] = 1.0
return reference
class RewardObs:
def __init__(self, env):
self._env = env
def __getattr__(self, name):
return getattr(self._env, name)
@property
def observation_space(self):
spaces = self._env.observation_space.spaces
assert 'reward' not in spaces
spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32)
return gym.spaces.Dict(spaces)
def step(self, action):
obs, reward, done, info = self._env.step(action)
obs['reward'] = reward
return obs, reward, done, info
def reset(self):
obs = self._env.reset()
obs['reward'] = 0.0
return obs