Add Dreamerv2 Files

commit 0ac3131dad (parent 1ddf72b0c3)
DreamerV2/configs.yaml (new file, 185 lines)
@@ -0,0 +1,185 @@
defaults:
  gpu: 'none'
  logdir: ./
  traindir: null
  evaldir: null
  offline_traindir: ''
  offline_evaldir: ''
  seed: 0
  steps: 1e7
  eval_every: 1e4
  log_every: 1e4
  reset_every: 0
  gpu_growth: True
  precision: 32
  debug: False
  expl_gifs: False

  # Environment
  task: 'dmc_walker_walk'
  size: [64, 64]
  envs: 1
  action_repeat: 2
  time_limit: 1000
  prefill: 2500
  eval_noise: 0.0
  clip_rewards: 'identity'
  atari_grayscale: False

  # Model
  dyn_cell: 'gru'
  dyn_hidden: 200
  dyn_deter: 200
  dyn_stoch: 50
  dyn_discrete: 0
  dyn_input_layers: 1
  dyn_output_layers: 1
  dyn_shared: False
  dyn_mean_act: 'none'
  dyn_std_act: 'sigmoid2'
  dyn_min_std: 0.1
  grad_heads: ['image', 'reward']
  units: 400
  reward_layers: 2
  discount_layers: 3
  value_layers: 3
  actor_layers: 4
  act: 'elu'
  cnn_depth: 32
  encoder_kernels: [4, 4, 4, 4]
  decoder_kernels: [5, 5, 6, 6]
  decoder_thin: True
  value_head: 'normal'
  kl_scale: '1.0'
  kl_balance: '0.8'
  kl_free: '1.0'
  pred_discount: False
  discount_scale: 1.0
  reward_scale: 1.0
  weight_decay: 0.0

  # Training
  batch_size: 50
  batch_length: 50
  train_every: 5
  train_steps: 1
  pretrain: 100
  model_lr: 3e-4
  value_lr: 8e-5
  actor_lr: 8e-5
  opt_eps: 1e-5
  grad_clip: 100
  value_grad_clip: 100
  actor_grad_clip: 100
  dataset_size: 0
  oversample_ends: False
  slow_value_target: True
  slow_actor_target: True
  slow_target_update: 100
  slow_target_fraction: 1
  opt: 'adam'

  # Behavior.
  discount: 0.99
  discount_lambda: 0.95
  imag_horizon: 15
  imag_gradient: 'dynamics'
  imag_gradient_mix: '0.1'
  imag_sample: True
  actor_dist: 'trunc_normal'
  actor_entropy: '1e-4'
  actor_state_entropy: 0.0
  actor_init_std: 1.0
  actor_min_std: 0.1
  actor_disc: 5
  actor_temp: 0.1
  actor_outscale: 0.0
  expl_amount: 0.0
  eval_state_mean: False
  collect_dyn_sample: True
  behavior_stop_grad: True
  value_decay: 0.0
  future_entropy: False

  # Exploration
  expl_behavior: 'greedy'
  expl_until: 0
  expl_extr_scale: 0.0
  expl_intr_scale: 1.0
  disag_target: 'stoch'
  disag_log: True
  disag_models: 10
  disag_offset: 1
  disag_layers: 4
  disag_units: 400

atari:

  # General
  task: 'atari_demon_attack'
  steps: 3e7
  eval_every: 1e5
  log_every: 1e4
  prefill: 50000
  dataset_size: 2e6
  pretrain: 0
  precision: 16

  # Environment
  time_limit: 108000  # 30 minutes of game play.
  atari_grayscale: True
  action_repeat: 4
  eval_noise: 0.001
  train_every: 16
  train_steps: 1
  clip_rewards: 'tanh'

  # Model
  grad_heads: ['image', 'reward', 'discount']
  dyn_cell: 'gru_layer_norm'
  pred_discount: True
  cnn_depth: 48
  dyn_deter: 600
  dyn_hidden: 600
  dyn_stoch: 32
  dyn_discrete: 32
  reward_layers: 4
  discount_layers: 4
  value_layers: 4
  actor_layers: 4

  # Behavior
  actor_dist: 'onehot'
  actor_entropy: 'linear(3e-3,3e-4,2.5e6)'
  expl_amount: 0.0
  expl_until: 3e7
  discount: 0.995
  imag_gradient: 'both'
  imag_gradient_mix: 'linear(0.1,0,2.5e6)'

  # Training
  discount_scale: 5.0
  reward_scale: 1
  weight_decay: 1e-6
  model_lr: 2e-4
  kl_scale: 0.1
  kl_free: 0.0
  actor_lr: 4e-5
  value_lr: 1e-4
  oversample_ends: True

  # Disen
  disen_cnn_depth: 16
  disen_only_scale: 1.0
  disen_discount_scale: 2000.0
  disen_reward_scale: 2000.0
  num_reward_opt_iters: 20

debug:

  debug: True
  pretrain: 1
  prefill: 1
  train_steps: 1
  batch_size: 10
  batch_length: 20
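The named blocks above are flat override sets rather than a nested config tree: dreamer.py (below) merges only the blocks listed on --configs into a single dict, with later blocks winning, and then exposes every key as a command-line flag. A minimal sketch of that composition, assuming a local checkout with this file at DreamerV2/configs.yaml:

import pathlib
import ruamel.yaml as yaml

configs = yaml.safe_load(pathlib.Path('DreamerV2/configs.yaml').read_text())

# Equivalent to `--configs defaults atari debug`; later blocks override earlier ones.
merged = {}
for name in ['defaults', 'atari', 'debug']:
  merged.update(configs[name])

print(merged['dyn_discrete'])  # 32, from the atari block
print(merged['batch_size'])    # 10, from the debug block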
DreamerV2/dreamer.py (new file, 316 lines)
@@ -0,0 +1,316 @@
import argparse
import collections
import functools
import os
import pathlib
import sys
import warnings

warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['MUJOCO_GL'] = 'egl'

import numpy as np
import ruamel.yaml as yaml
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec

tf.get_logger().setLevel('ERROR')

from tensorflow_probability import distributions as tfd

sys.path.append(str(pathlib.Path(__file__).parent))

import exploration as expl
import models
import tools
import wrappers


class Dreamer(tools.Module):

  def __init__(self, config, logger, dataset):
    self._config = config
    self._logger = logger
    self._float = prec.global_policy().compute_dtype
    self._should_log = tools.Every(config.log_every)
    self._should_train = tools.Every(config.train_every)
    self._should_pretrain = tools.Once()
    self._should_reset = tools.Every(config.reset_every)
    self._should_expl = tools.Until(int(
        config.expl_until / config.action_repeat))
    self._metrics = collections.defaultdict(tf.metrics.Mean)
    with tf.device('cpu:0'):
      self._step = tf.Variable(count_steps(config.traindir), dtype=tf.int64)
    # Schedules.
    config.actor_entropy = (
        lambda x=config.actor_entropy: tools.schedule(x, self._step))
    config.actor_state_entropy = (
        lambda x=config.actor_state_entropy: tools.schedule(x, self._step))
    config.imag_gradient_mix = (
        lambda x=config.imag_gradient_mix: tools.schedule(x, self._step))
    self._dataset = iter(dataset)
    self._wm = models.WorldModel(self._step, config)
    self._task_behavior = models.ImagBehavior(
        config, self._wm, config.behavior_stop_grad)
    reward = lambda f, s, a: self._wm.heads['reward'](f).mode()
    self._expl_behavior = dict(
        greedy=lambda: self._task_behavior,
        random=lambda: expl.Random(config),
        plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
    )[config.expl_behavior]()
    # Train step to initialize variables including optimizer statistics.
    self._train(next(self._dataset))

  def __call__(self, obs, reset, state=None, training=True):
    step = self._step.numpy().item()
    if self._should_reset(step):
      state = None
    if state is not None and reset.any():
      mask = tf.cast(1 - reset, self._float)[:, None]
      state = tf.nest.map_structure(lambda x: x * mask, state)
    if training and self._should_train(step):
      steps = (
          self._config.pretrain if self._should_pretrain()
          else self._config.train_steps)
      for _ in range(steps):
        self._train(next(self._dataset))
      if self._should_log(step):
        for name, mean in self._metrics.items():
          self._logger.scalar(name, float(mean.result()))
          mean.reset_states()
        openl_joint, openl_main, openl_disen, openl_mask = self._wm.video_pred(next(self._dataset))
        self._logger.video('train_openl_joint', openl_joint)
        self._logger.video('train_openl_main', openl_main)
        self._logger.video('train_openl_disen', openl_disen)
        self._logger.video('train_openl_mask', openl_mask)
        self._logger.write(fps=True)
    action, state = self._policy(obs, state, training)
    if training:
      self._step.assign_add(len(reset))
      self._logger.step = self._config.action_repeat \
          * self._step.numpy().item()
    return action, state

  @tf.function
  def _policy(self, obs, state, training):
    if state is None:
      batch_size = len(obs['image'])
      latent = self._wm.dynamics.initial(len(obs['image']))
      action = tf.zeros((batch_size, self._config.num_actions), self._float)
    else:
      latent, action = state
    embed = self._wm.encoder(self._wm.preprocess(obs))
    latent, _ = self._wm.dynamics.obs_step(
        latent, action, embed, self._config.collect_dyn_sample)
    if self._config.eval_state_mean:
      latent['stoch'] = latent['mean']
    feat = self._wm.dynamics.get_feat(latent)
    if not training:
      action = self._task_behavior.actor(feat).mode()
    elif self._should_expl(self._step):
      action = self._expl_behavior.actor(feat).sample()
    else:
      action = self._task_behavior.actor(feat).sample()
    if self._config.actor_dist == 'onehot_gumble':
      action = tf.cast(
          tf.one_hot(tf.argmax(action, axis=-1), self._config.num_actions),
          action.dtype)
    action = self._exploration(action, training)
    state = (latent, action)
    return action, state

  def _exploration(self, action, training):
    amount = self._config.expl_amount if training else self._config.eval_noise
    if amount == 0:
      return action
    amount = tf.cast(amount, self._float)
    if 'onehot' in self._config.actor_dist:
      probs = amount / self._config.num_actions + (1 - amount) * action
      return tools.OneHotDist(probs=probs).sample()
    else:
      return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
    raise NotImplementedError(self._config.action_noise)

  @tf.function
  def _train(self, data):
    print('Tracing train function.')
    metrics = {}
    embed, post, feat, kl, mets = self._wm.train(data)
    metrics.update(mets)
    start = post
    if self._config.pred_discount:  # Last step could be terminal.
      start = {k: v[:, :-1] for k, v in post.items()}
      embed, feat, kl = embed[:, :-1], feat[:, :-1], kl[:, :-1]
    reward = lambda f, s, a: self._wm.heads['reward'](f).mode()
    metrics.update(self._task_behavior.train(start, reward)[-1])
    if self._config.expl_behavior != 'greedy':
      mets = self._expl_behavior.train(start, feat, embed, kl)[-1]
      metrics.update({'expl_' + key: value for key, value in mets.items()})
    for name, value in metrics.items():
      self._metrics[name].update_state(value)


def count_steps(folder):
  return sum(int(str(n).split('-')[-1][:-4]) - 1 for n in folder.glob('*.npz'))


def make_dataset(episodes, config):
  example = episodes[next(iter(episodes.keys()))]
  types = {k: v.dtype for k, v in example.items()}
  shapes = {k: (None,) + v.shape[1:] for k, v in example.items()}
  generator = lambda: tools.sample_episodes(
      episodes, config.batch_length, config.oversample_ends)
  dataset = tf.data.Dataset.from_generator(generator, types, shapes)
  dataset = dataset.batch(config.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(10)
  return dataset


def make_env(config, logger, mode, train_eps, eval_eps):
  suite, task = config.task.split('_', 1)
  if suite == 'dmc':
    env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
    env = wrappers.NormalizeActions(env)
  elif suite == 'atari':
    env = wrappers.Atari(
        task, config.action_repeat, config.size,
        grayscale=config.atari_grayscale,
        life_done=False and (mode == 'train'),
        sticky_actions=True,
        all_actions=True)
    env = wrappers.OneHotAction(env)
  else:
    raise NotImplementedError(suite)
  env = wrappers.TimeLimit(env, config.time_limit)
  callbacks = [functools.partial(
      process_episode, config, logger, mode, train_eps, eval_eps)]
  env = wrappers.CollectDataset(env, callbacks)
  env = wrappers.RewardObs(env)
  return env


def process_episode(config, logger, mode, train_eps, eval_eps, episode):
  directory = dict(train=config.traindir, eval=config.evaldir)[mode]
  cache = dict(train=train_eps, eval=eval_eps)[mode]
  filename = tools.save_episodes(directory, [episode])[0]
  length = len(episode['reward']) - 1
  score = float(episode['reward'].astype(np.float64).sum())
  video = episode['image']
  if mode == 'eval':
    cache.clear()
  if mode == 'train' and config.dataset_size:
    total = 0
    for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
      if total <= config.dataset_size - length:
        total += len(ep['reward']) - 1
      else:
        del cache[key]
    logger.scalar('dataset_size', total + length)
  cache[str(filename)] = episode
  print(f'{mode.title()} episode has {length} steps and return {score:.1f}.')
  logger.scalar(f'{mode}_return', score)
  logger.scalar(f'{mode}_length', length)
  logger.scalar(f'{mode}_episodes', len(cache))
  if mode == 'eval' or config.expl_gifs:
    logger.video(f'{mode}_policy', video[None])
  logger.write()


def main(logdir, config):

  logdir = os.path.join(
      logdir, config.task, 'Ours', str(config.seed))

  logdir = pathlib.Path(logdir).expanduser()
  config.traindir = config.traindir or logdir / 'train_eps'
  config.evaldir = config.evaldir or logdir / 'eval_eps'
  config.steps //= config.action_repeat
  config.eval_every //= config.action_repeat
  config.log_every //= config.action_repeat
  config.time_limit //= config.action_repeat
  config.act = getattr(tf.nn, config.act)

  if config.debug:
    tf.config.experimental_run_functions_eagerly(True)
  if config.gpu_growth:
    message = 'No GPU found. To actually train on CPU remove this assert.'
    assert tf.config.experimental.list_physical_devices('GPU'), message
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
      tf.config.experimental.set_memory_growth(gpu, True)
  assert config.precision in (16, 32), config.precision
  if config.precision == 16:
    prec.set_policy(prec.Policy('mixed_float16'))
  print('Logdir', logdir)
  logdir.mkdir(parents=True, exist_ok=True)
  config.traindir.mkdir(parents=True, exist_ok=True)
  config.evaldir.mkdir(parents=True, exist_ok=True)
  step = count_steps(config.traindir)
  logger = tools.Logger(logdir, config.action_repeat * step)

  print('Create envs.')
  if config.offline_traindir:
    directory = config.offline_traindir.format(**vars(config))
  else:
    directory = config.traindir
  train_eps = tools.load_episodes(directory, limit=config.dataset_size)
  if config.offline_evaldir:
    directory = config.offline_evaldir.format(**vars(config))
  else:
    directory = config.evaldir
  eval_eps = tools.load_episodes(directory, limit=1)
  make = lambda mode: make_env(config, logger, mode, train_eps, eval_eps)
  train_envs = [make('train') for _ in range(config.envs)]
  eval_envs = [make('eval') for _ in range(config.envs)]
  acts = train_envs[0].action_space
  config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0]

  prefill = max(0, config.prefill - count_steps(config.traindir))
  print(f'Prefill dataset ({prefill} steps).')
  random_agent = lambda o, d, s: ([acts.sample() for _ in d], s)
  tools.simulate(random_agent, train_envs, prefill)
  tools.simulate(random_agent, eval_envs, episodes=1)
  logger.step = config.action_repeat * count_steps(config.traindir)

  print('Simulate agent.')
  train_dataset = make_dataset(train_eps, config)
  eval_dataset = iter(make_dataset(eval_eps, config))
  agent = Dreamer(config, logger, train_dataset)
  if (logdir / 'variables.pkl').exists():
    agent.load(logdir / 'variables.pkl')
    agent._should_pretrain._once = False

  state = None
  suite, task = config.task.split('_', 1)
  num_eval_episodes = 10 if suite == 'procgen' else 1
  while agent._step.numpy().item() < config.steps:
    logger.write()
    print('Start evaluation.')
    openl_joint, openl_main, openl_disen, openl_mask = agent._wm.video_pred(next(eval_dataset))
    logger.video('eval_openl_joint', openl_joint)
    logger.video('eval_openl_main', openl_main)
    logger.video('eval_openl_disen', openl_disen)
    logger.video('eval_openl_mask', openl_mask)
    eval_policy = functools.partial(agent, training=False)
    tools.simulate(eval_policy, eval_envs, episodes=num_eval_episodes)
    print('Start training.')
    state = tools.simulate(agent, train_envs, config.eval_every, state=state)
    agent.save(logdir / 'variables.pkl')
  for env in train_envs + eval_envs:
    try:
      env.close()
    except Exception:
      pass


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--configs', nargs='+', required=True)
  args, remaining = parser.parse_known_args()
  configs = yaml.safe_load((pathlib.Path(__file__).parent / 'configs.yaml').read_text())
  config_ = {}
  for name in args.configs:
    config_.update(configs[name])
  parser = argparse.ArgumentParser()
  for key, value in config_.items():
    arg_type = tools.args_type(value)
    parser.add_argument(f'--{key}', type=arg_type, default=arg_type(value))
  main(config_['logdir'], parser.parse_args(remaining))
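Note that main() divides steps, eval_every, log_every, and time_limit by action_repeat, while logger.step is scaled back up by action_repeat, so the training loop counts policy steps but logging is in environment frames. A small worked example with the atari block values (the numbers are illustrative arithmetic only):

action_repeat = 4
policy_steps = int(3e7) // action_repeat      # while-loop budget: 7,500,000 policy steps
eval_interval = int(1e5) // action_repeat     # evaluate every 25,000 policy steps
logged_frames = action_repeat * policy_steps  # logger.step at the end: 30,000,000 frames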
DreamerV2/exploration.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd

import models
import networks
import tools


class Random(tools.Module):

  def __init__(self, config):
    self._config = config
    self._float = prec.global_policy().compute_dtype

  def actor(self, feat):
    shape = feat.shape[:-1] + [self._config.num_actions]
    if self._config.actor_dist == 'onehot':
      return tools.OneHotDist(tf.zeros(shape))
    else:
      ones = tf.ones(shape, self._float)
      return tfd.Uniform(-ones, ones)

  def train(self, start, feat, embed, kl):
    return None, {}


class Plan2Explore(tools.Module):

  def __init__(self, config, world_model, reward=None):
    self._config = config
    self._reward = reward
    self._behavior = models.ImagBehavior(config, world_model)
    self.actor = self._behavior.actor
    size = {
        'embed': 32 * config.cnn_depth,
        'stoch': config.dyn_stoch,
        'deter': config.dyn_deter,
        'feat': config.dyn_stoch + config.dyn_deter,
    }[self._config.disag_target]
    kw = dict(
        shape=size, layers=config.disag_layers, units=config.disag_units,
        act=config.act)
    self._networks = [
        networks.DenseHead(**kw) for _ in range(config.disag_models)]
    self._opt = tools.Optimizer(
        'ensemble', config.model_lr, config.opt_eps, config.grad_clip,
        config.weight_decay, opt=config.opt)

  def train(self, start, feat, embed, kl):
    metrics = {}
    target = {
        'embed': embed,
        'stoch': start['stoch'],
        'deter': start['deter'],
        'feat': feat,
    }[self._config.disag_target]
    metrics.update(self._train_ensemble(feat, target))
    metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1])
    return None, metrics

  def _intrinsic_reward(self, feat, state, action):
    preds = [head(feat, tf.float32).mean() for head in self._networks]
    disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1)
    if self._config.disag_log:
      disag = tf.math.log(disag)
    reward = self._config.expl_intr_scale * disag
    if self._config.expl_extr_scale:
      reward += tf.cast(self._config.expl_extr_scale * self._reward(
          feat, state, action), tf.float32)
    return reward

  def _train_ensemble(self, inputs, targets):
    if self._config.disag_offset:
      targets = targets[:, self._config.disag_offset:]
      inputs = inputs[:, :-self._config.disag_offset]
    targets = tf.stop_gradient(targets)
    inputs = tf.stop_gradient(inputs)
    with tf.GradientTape() as tape:
      preds = [head(inputs) for head in self._networks]
      likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
      loss = -tf.cast(tf.reduce_sum(likes), tf.float32)
    metrics = self._opt(tape, loss, self._networks)
    return metrics
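Plan2Explore's intrinsic reward above is the (optionally log-transformed) disagreement of the ensemble heads, scaled by expl_intr_scale. A toy NumPy illustration of the same reduction, with made-up shapes:

import numpy as np

K, batch, target_dim = 10, 16, 50              # disag_models, batch size, disag_target size
preds = np.random.randn(K, batch, target_dim)  # each head's predicted mean
disag = preds.std(axis=0).mean(axis=-1)        # std across heads, averaged over target dims
intrinsic_reward = 1.0 * np.log(disag)         # expl_intr_scale * log-disagreement, shape (batch,)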
DreamerV2/models.py (new file, 429 lines)
@@ -0,0 +1,429 @@
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec

import networks
import tools


class WorldModel(tools.Module):

  def __init__(self, step, config):
    self._step = step
    self._config = config
    channels = (1 if config.atari_grayscale else 3)
    shape = config.size + (channels,)

    ########
    # Main #
    ########
    self.encoder = networks.ConvEncoder(
        config.cnn_depth, config.act, config.encoder_kernels)
    self.dynamics = networks.RSSM(
        config.dyn_stoch, config.dyn_deter, config.dyn_hidden,
        config.dyn_input_layers, config.dyn_output_layers, config.dyn_shared,
        config.dyn_discrete, config.act, config.dyn_mean_act,
        config.dyn_std_act, config.dyn_min_std, config.dyn_cell)
    self.heads = {}
    self.heads['reward'] = networks.DenseHead(
        [], config.reward_layers, config.units, config.act)
    if config.pred_discount:
      self.heads['discount'] = networks.DenseHead(
          [], config.discount_layers, config.units, config.act, dist='binary')
    self._model_opt = tools.Optimizer(
        'model', config.model_lr, config.opt_eps, config.grad_clip,
        config.weight_decay, opt=config.opt)
    self._scales = dict(
        reward=config.reward_scale, discount=config.discount_scale)

    #########
    # Disen #
    #########
    self.disen_encoder = networks.ConvEncoder(
        config.disen_cnn_depth, config.act, config.encoder_kernels)
    self.disen_dynamics = networks.RSSM(
        config.dyn_stoch, config.dyn_deter, config.dyn_hidden,
        config.dyn_input_layers, config.dyn_output_layers, config.dyn_shared,
        config.dyn_discrete, config.act, config.dyn_mean_act,
        config.dyn_std_act, config.dyn_min_std, config.dyn_cell)

    self.disen_heads = {}
    self.disen_heads['reward'] = networks.DenseHead(
        [], config.reward_layers, config.units, config.act)
    if config.pred_discount:
      self.disen_heads['discount'] = networks.DenseHead(
          [], config.discount_layers, config.units, config.act, dist='binary')

    self._disen_model_opt = tools.Optimizer(
        'disen', config.model_lr, config.opt_eps, config.grad_clip,
        config.weight_decay, opt=config.opt)

    self._disen_heads_opt = {}
    self._disen_heads_opt['reward'] = tools.Optimizer(
        'disen_reward', config.model_lr, config.opt_eps, config.grad_clip,
        config.weight_decay, opt=config.opt)
    if config.pred_discount:
      self._disen_heads_opt['discount'] = tools.Optimizer(
          'disen_pcont', config.model_lr, config.opt_eps, config.grad_clip,
          config.weight_decay, opt=config.opt)

    # negative signs for reward/discount here
    self._disen_scales = dict(disen_only=config.disen_only_scale,
        reward=-config.disen_reward_scale, discount=-config.disen_discount_scale)

    self.disen_only_image_head = networks.ConvDecoder(
        config.disen_cnn_depth, config.act, shape, config.decoder_kernels,
        config.decoder_thin)

    ################
    # Joint Decode #
    ################
    self.image_head = networks.ConvDecoderMask(
        config.cnn_depth, config.act, shape, config.decoder_kernels,
        config.decoder_thin)
    self.disen_image_head = networks.ConvDecoderMask(
        config.disen_cnn_depth, config.act, shape, config.decoder_kernels,
        config.decoder_thin)
    self.joint_image_head = networks.ConvDecoderMaskEnsemble(
        self.image_head, self.disen_image_head
    )

  def train(self, data):
    data = self.preprocess(data)
    with tf.GradientTape() as model_tape, tf.GradientTape() as disen_tape:

      # kl schedule
      kl_balance = tools.schedule(self._config.kl_balance, self._step)
      kl_free = tools.schedule(self._config.kl_free, self._step)
      kl_scale = tools.schedule(self._config.kl_scale, self._step)

      # Main
      embed = self.encoder(data)
      post, prior = self.dynamics.observe(embed, data['action'])
      kl_loss, kl_value = self.dynamics.kl_loss(
          post, prior, kl_balance, kl_free, kl_scale)
      feat = self.dynamics.get_feat(post)
      likes = {}
      for name, head in self.heads.items():
        grad_head = (name in self._config.grad_heads)
        inp = feat if grad_head else tf.stop_gradient(feat)
        pred = head(inp, tf.float32)
        like = pred.log_prob(tf.cast(data[name], tf.float32))
        likes[name] = tf.reduce_mean(
            like) * self._scales.get(name, 1.0)

      # Disen
      embed_disen = self.disen_encoder(data)
      post_disen, prior_disen = self.disen_dynamics.observe(
          embed_disen, data['action'])
      kl_loss_disen, kl_value_disen = self.dynamics.kl_loss(
          post_disen, prior_disen, kl_balance, kl_free, kl_scale)
      feat_disen = self.disen_dynamics.get_feat(post_disen)

      # Optimize disen reward/pcont till optimal
      disen_metrics = dict(reward={}, discount={})
      loss_disen = dict(reward=None, discount=None)
      for _ in range(self._config.num_reward_opt_iters):
        with tf.GradientTape() as disen_reward_tape, tf.GradientTape() as disen_pcont_tape:
          disen_gradient_tapes = dict(
              reward=disen_reward_tape, discount=disen_pcont_tape)
          for name, head in self.disen_heads.items():
            pred_disen = head(
                tf.stop_gradient(feat_disen), tf.float32)
            loss_disen[name] = -tf.reduce_mean(pred_disen.log_prob(
                tf.cast(data[name], tf.float32)))
        for name, head in self.disen_heads.items():
          disen_metrics[name] = self._disen_heads_opt[name](
              disen_gradient_tapes[name], loss_disen[name], [head], prefix='disen_neg')

      # Compute likes for disen model (including negative gradients)
      likes_disen = {}
      for name, head in self.disen_heads.items():
        pred_disen = head(feat_disen, tf.float32)
        like_disen = pred_disen.log_prob(
            tf.cast(data[name], tf.float32))
        likes_disen[name] = tf.reduce_mean(
            like_disen) * self._disen_scales.get(name, -1.0)
      disen_only_image_pred = self.disen_only_image_head(
          feat_disen, tf.float32)
      disen_only_image_like = tf.reduce_mean(disen_only_image_pred.log_prob(
          tf.cast(data['image'], tf.float32))) * self._disen_scales.get('disen_only', 1.0)
      likes_disen['disen_only'] = disen_only_image_like

      # Joint decode
      image_pred_joint, _, _, _ = self.joint_image_head(
          feat, feat_disen, tf.float32)
      image_like = tf.reduce_mean(image_pred_joint.log_prob(
          tf.cast(data['image'], tf.float32)))
      likes['image'] = image_like
      likes_disen['image'] = image_like

      # Compute loss
      model_loss = kl_loss - sum(likes.values())
      disen_loss = kl_loss_disen - sum(likes_disen.values())

    model_parts = [self.encoder, self.dynamics,
                   self.joint_image_head] + list(self.heads.values())
    disen_parts = [self.disen_encoder, self.disen_dynamics,
                   self.joint_image_head, self.disen_only_image_head]

    metrics = self._model_opt(
        model_tape, model_loss, model_parts, prefix='main')
    disen_model_metrics = self._disen_model_opt(
        disen_tape, disen_loss, disen_parts, prefix='disen')

    metrics['kl_balance'] = kl_balance
    metrics['kl_free'] = kl_free
    metrics['kl_scale'] = kl_scale
    metrics.update({f'{name}_loss': -like for name,
                    like in likes.items()})

    metrics['disen/disen_only_image_loss'] = -disen_only_image_like
    metrics['disen/disen_reward_loss'] = -likes_disen['reward'] / \
        self._disen_scales.get('reward', -1.0)
    metrics['disen/disen_discount_loss'] = -likes_disen['discount'] / \
        self._disen_scales.get('discount', -1.0)

    metrics['kl'] = tf.reduce_mean(kl_value)
    metrics['prior_ent'] = self.dynamics.get_dist(prior).entropy()
    metrics['post_ent'] = self.dynamics.get_dist(post).entropy()
    metrics['disen/kl'] = tf.reduce_mean(kl_value_disen)
    metrics['disen/prior_ent'] = self.dynamics.get_dist(
        prior_disen).entropy()
    metrics['disen/post_ent'] = self.dynamics.get_dist(
        post_disen).entropy()

    metrics.update(
        {f'{key}': value for key, value in disen_metrics['reward'].items()})
    metrics.update(
        {f'{key}': value for key, value in disen_metrics['discount'].items()})
    metrics.update(
        {f'{key}': value for key, value in disen_model_metrics.items()})

    return embed, post, feat, kl_value, metrics

  @tf.function
  def preprocess(self, obs):
    dtype = prec.global_policy().compute_dtype
    obs = obs.copy()
    obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5
    obs['reward'] = getattr(tf, self._config.clip_rewards)(obs['reward'])
    if 'discount' in obs:
      obs['discount'] *= self._config.discount
    for key, value in obs.items():
      if tf.dtypes.as_dtype(value.dtype) in (
          tf.float16, tf.float32, tf.float64):
        obs[key] = tf.cast(value, dtype)
    return obs

  @tf.function
  def video_pred(self, data):
    data = self.preprocess(data)
    truth = data['image'][:6] + 0.5

    embed = self.encoder(data)
    embed_disen = self.disen_encoder(data)
    states, _ = self.dynamics.observe(
        embed[:6, :5], data['action'][:6, :5])
    states_disen, _ = self.disen_dynamics.observe(
        embed_disen[:6, :5], data['action'][:6, :5])
    feats = self.dynamics.get_feat(states)
    feats_disen = self.disen_dynamics.get_feat(states_disen)
    recon_joint, recon_main, recon_disen, recon_mask = self.joint_image_head(
        feats, feats_disen)
    recon_joint = recon_joint.mode()[:6]
    recon_main = recon_main.mode()[:6]
    recon_disen = recon_disen.mode()[:6]
    recon_mask = recon_mask[:6]

    init = {k: v[:, -1] for k, v in states.items()}
    init_disen = {k: v[:, -1] for k, v in states_disen.items()}
    prior = self.dynamics.imagine(
        data['action'][:6, 5:], init)
    prior_disen = self.disen_dynamics.imagine(
        data['action'][:6, 5:], init_disen)
    _feats = self.dynamics.get_feat(prior)
    _feats_disen = self.disen_dynamics.get_feat(prior_disen)
    openl_joint, openl_main, openl_disen, openl_mask = self.joint_image_head(
        _feats, _feats_disen)
    openl_joint = openl_joint.mode()
    openl_main = openl_main.mode()
    openl_disen = openl_disen.mode()

    model_joint = tf.concat(
        [recon_joint[:, :5] + 0.5, openl_joint + 0.5], 1)
    error_joint = (model_joint - truth + 1) / 2
    model_main = tf.concat(
        [recon_main[:, :5] + 0.5, openl_main + 0.5], 1)
    error_main = (model_main - truth + 1) / 2
    model_disen = tf.concat(
        [recon_disen[:, :5] + 0.5, openl_disen + 0.5], 1)
    error_disen = (model_disen - truth + 1) / 2
    model_mask = tf.concat(
        [recon_mask[:, :5] + 0.5, openl_mask + 0.5], 1)

    output_joint = tf.concat([truth, model_joint, error_joint], 2)
    output_main = tf.concat([truth, model_main, error_main], 2)
    output_disen = tf.concat([truth, model_disen, error_disen], 2)
    output_mask = model_mask

    return output_joint, output_main, output_disen, output_mask


class ImagBehavior(tools.Module):

  def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
    self._config = config
    self._world_model = world_model
    self._stop_grad_actor = stop_grad_actor
    self._reward = reward
    self.actor = networks.ActionHead(
        config.num_actions, config.actor_layers, config.units, config.act,
        config.actor_dist, config.actor_init_std, config.actor_min_std,
        config.actor_dist, config.actor_temp, config.actor_outscale)
    self.value = networks.DenseHead(
        [], config.value_layers, config.units, config.act,
        config.value_head)
    if config.slow_value_target or config.slow_actor_target:
      self._slow_value = networks.DenseHead(
          [], config.value_layers, config.units, config.act)
      self._updates = tf.Variable(0, tf.int64)
    kw = dict(wd=config.weight_decay, opt=config.opt)
    self._actor_opt = tools.Optimizer(
        'actor', config.actor_lr, config.opt_eps, config.actor_grad_clip, **kw)
    self._value_opt = tools.Optimizer(
        'value', config.value_lr, config.opt_eps, config.value_grad_clip, **kw)

  def train(
      self, start, objective=None, imagine=None, tape=None, repeats=None):
    objective = objective or self._reward
    self._update_slow_target()
    metrics = {}
    with (tape or tf.GradientTape()) as actor_tape:
      assert bool(objective) != bool(imagine)
      if objective:
        imag_feat, imag_state, imag_action = self._imagine(
            start, self.actor, self._config.imag_horizon, repeats)
        reward = objective(imag_feat, imag_state, imag_action)
      else:
        imag_feat, imag_state, imag_action, reward = imagine(start)
      actor_ent = self.actor(imag_feat, tf.float32).entropy()
      state_ent = self._world_model.dynamics.get_dist(
          imag_state, tf.float32).entropy()
      target, weights = self._compute_target(
          imag_feat, reward, actor_ent, state_ent,
          self._config.slow_actor_target)
      actor_loss, mets = self._compute_actor_loss(
          imag_feat, imag_state, imag_action, target, actor_ent, state_ent,
          weights)
      metrics.update(mets)
    if self._config.slow_value_target != self._config.slow_actor_target:
      target, weights = self._compute_target(
          imag_feat, reward, actor_ent, state_ent,
          self._config.slow_value_target)
    with tf.GradientTape() as value_tape:
      value = self.value(imag_feat, tf.float32)[:-1]
      value_loss = -value.log_prob(tf.stop_gradient(target))
      if self._config.value_decay:
        value_loss += self._config.value_decay * value.mode()
      value_loss = tf.reduce_mean(weights[:-1] * value_loss)
    metrics['reward_mean'] = tf.reduce_mean(reward)
    metrics['reward_std'] = tf.math.reduce_std(reward)
    metrics['actor_ent'] = tf.reduce_mean(actor_ent)
    metrics.update(self._actor_opt(actor_tape, actor_loss, [self.actor]))
    metrics.update(self._value_opt(value_tape, value_loss, [self.value]))
    return imag_feat, imag_state, imag_action, weights, metrics

  def _imagine(self, start, policy, horizon, repeats=None):
    dynamics = self._world_model.dynamics
    if repeats:
      start = {k: tf.repeat(v, repeats, axis=1)
               for k, v in start.items()}

    def flatten(x): return tf.reshape(x, [-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}

    def step(prev, _):
      state, _, _ = prev
      feat = dynamics.get_feat(state)
      inp = tf.stop_gradient(feat) if self._stop_grad_actor else feat
      action = policy(inp).sample()
      succ = dynamics.img_step(
          state, action, sample=self._config.imag_sample)
      return succ, feat, action
    feat = 0 * dynamics.get_feat(start)
    action = policy(feat).mode()
    succ, feats, actions = tools.static_scan(
        step, tf.range(horizon), (start, feat, action))
    states = {k: tf.concat([
        start[k][None], v[:-1]], 0) for k, v in succ.items()}
    if repeats:
      def unfold(tensor):
        s = tensor.shape
        return tf.reshape(tensor, [s[0], s[1] // repeats, repeats] + s[2:])
      states, feats, actions = tf.nest.map_structure(
          unfold, (states, feats, actions))
    return feats, states, actions

  def _compute_target(self, imag_feat, reward, actor_ent, state_ent, slow):
    reward = tf.cast(reward, tf.float32)
    if 'discount' in self._world_model.heads:
      discount = self._world_model.heads['discount'](
          imag_feat, tf.float32).mean()
    else:
      discount = self._config.discount * tf.ones_like(reward)
    if self._config.future_entropy and tf.greater(
        self._config.actor_entropy(), 0):
      reward += self._config.actor_entropy() * actor_ent
    if self._config.future_entropy and tf.greater(
        self._config.actor_state_entropy(), 0):
      reward += self._config.actor_state_entropy() * state_ent
    if slow:
      value = self._slow_value(imag_feat, tf.float32).mode()
    else:
      value = self.value(imag_feat, tf.float32).mode()
    target = tools.lambda_return(
        reward[:-1], value[:-1], discount[:-1],
        bootstrap=value[-1], lambda_=self._config.discount_lambda, axis=0)
    weights = tf.stop_gradient(tf.math.cumprod(tf.concat(
        [tf.ones_like(discount[:1]), discount[:-1]], 0), 0))
    return target, weights

  def _compute_actor_loss(
      self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent,
      weights):
    metrics = {}
    inp = tf.stop_gradient(
        imag_feat) if self._stop_grad_actor else imag_feat
    policy = self.actor(inp, tf.float32)
    actor_ent = policy.entropy()
    if self._config.imag_gradient == 'dynamics':
      actor_target = target
    elif self._config.imag_gradient == 'reinforce':
      imag_action = tf.cast(imag_action, tf.float32)
      actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient(
          target - self.value(imag_feat[:-1], tf.float32).mode())
    elif self._config.imag_gradient == 'both':
      imag_action = tf.cast(imag_action, tf.float32)
      actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient(
          target - self.value(imag_feat[:-1], tf.float32).mode())
      mix = self._config.imag_gradient_mix()
      actor_target = mix * target + (1 - mix) * actor_target
      metrics['imag_gradient_mix'] = mix
    else:
      raise NotImplementedError(self._config.imag_gradient)
    if not self._config.future_entropy and tf.greater(
        self._config.actor_entropy(), 0):
      actor_target += self._config.actor_entropy() * actor_ent[:-1]
    if not self._config.future_entropy and tf.greater(
        self._config.actor_state_entropy(), 0):
      actor_target += self._config.actor_state_entropy() * state_ent[:-1]
    actor_loss = -tf.reduce_mean(weights[:-1] * actor_target)
    return actor_loss, metrics

  def _update_slow_target(self):
    if self._config.slow_value_target or self._config.slow_actor_target:
      if self._updates % self._config.slow_target_update == 0:
        mix = self._config.slow_target_fraction
        for s, d in zip(self.value.variables, self._slow_value.variables):
          d.assign(mix * s + (1 - mix) * d)
      self._updates.assign_add(1)
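_compute_target above delegates to tools.lambda_return, which is not included in this commit excerpt. For reference, a hedged NumPy sketch of the standard TD(lambda) recursion that call is expected to compute, time-major as at the call site:

import numpy as np

def lambda_return(reward, value, discount, bootstrap, lambda_):
  # returns[t] = reward[t] + discount[t] * ((1 - lambda) * value[t+1] + lambda * returns[t+1])
  # reward, value, discount: float arrays of shape (horizon, batch); bootstrap: (batch,)
  next_values = np.concatenate([value[1:], bootstrap[None]], 0)
  returns = np.zeros_like(reward)
  last = bootstrap
  for t in reversed(range(len(reward))):
    last = reward[t] + discount[t] * ((1 - lambda_) * next_values[t] + lambda_ * last)
    returns[t] = last
  return returns

With lambda_=1 this reduces to the discounted Monte-Carlo return bootstrapped at the horizon; with lambda_=0 it is the one-step TD target.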
DreamerV2/networks.py (new file, 465 lines)
@@ -0,0 +1,465 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import layers as tfkl
|
||||
from tensorflow_probability import distributions as tfd
|
||||
from tensorflow.keras.mixed_precision import experimental as prec
|
||||
|
||||
import tools
|
||||
|
||||
class RSSM(tools.Module):
|
||||
|
||||
def __init__(
|
||||
self, stoch=30, deter=200, hidden=200, layers_input=1, layers_output=1,
|
||||
shared=False, discrete=False, act=tf.nn.elu, mean_act='none',
|
||||
std_act='softplus', min_std=0.1, cell='keras'):
|
||||
super().__init__()
|
||||
self._stoch = stoch
|
||||
self._deter = deter
|
||||
self._hidden = hidden
|
||||
self._min_std = min_std
|
||||
self._layers_input = layers_input
|
||||
self._layers_output = layers_output
|
||||
self._shared = shared
|
||||
self._discrete = discrete
|
||||
self._act = act
|
||||
self._mean_act = mean_act
|
||||
self._std_act = std_act
|
||||
self._embed = None
|
||||
if cell == 'gru':
|
||||
self._cell = tfkl.GRUCell(self._deter)
|
||||
elif cell == 'gru_layer_norm':
|
||||
self._cell = GRUCell(self._deter, norm=True)
|
||||
else:
|
||||
raise NotImplementedError(cell)
|
||||
|
||||
def initial(self, batch_size):
|
||||
dtype = prec.global_policy().compute_dtype
|
||||
if self._discrete:
|
||||
state = dict(
|
||||
logit=tf.zeros(
|
||||
[batch_size, self._stoch, self._discrete], dtype),
|
||||
stoch=tf.zeros(
|
||||
[batch_size, self._stoch, self._discrete], dtype),
|
||||
deter=self._cell.get_initial_state(None, batch_size, dtype))
|
||||
else:
|
||||
state = dict(
|
||||
mean=tf.zeros([batch_size, self._stoch], dtype),
|
||||
std=tf.zeros([batch_size, self._stoch], dtype),
|
||||
stoch=tf.zeros([batch_size, self._stoch], dtype),
|
||||
deter=self._cell.get_initial_state(None, batch_size, dtype))
|
||||
return state
|
||||
|
||||
@tf.function
|
||||
def observe(self, embed, action, state=None):
|
||||
def swap(x): return tf.transpose(
|
||||
x, [1, 0] + list(range(2, len(x.shape))))
|
||||
if state is None:
|
||||
state = self.initial(tf.shape(action)[0])
|
||||
embed, action = swap(embed), swap(action)
|
||||
post, prior = tools.static_scan(
|
||||
lambda prev, inputs: self.obs_step(prev[0], *inputs),
|
||||
(action, embed), (state, state))
|
||||
post = {k: swap(v) for k, v in post.items()}
|
||||
prior = {k: swap(v) for k, v in prior.items()}
|
||||
return post, prior
|
||||
|
||||
@tf.function
|
||||
def imagine(self, action, state=None):
|
||||
def swap(x): return tf.transpose(
|
||||
x, [1, 0] + list(range(2, len(x.shape))))
|
||||
if state is None:
|
||||
state = self.initial(tf.shape(action)[0])
|
||||
assert isinstance(state, dict), state
|
||||
action = swap(action)
|
||||
prior = tools.static_scan(self.img_step, action, state)
|
||||
prior = {k: swap(v) for k, v in prior.items()}
|
||||
return prior
|
||||
|
||||
def get_feat(self, state):
|
||||
stoch = state['stoch']
|
||||
if self._discrete:
|
||||
shape = stoch.shape[:-2] + [self._stoch * self._discrete]
|
||||
stoch = tf.reshape(stoch, shape)
|
||||
return tf.concat([stoch, state['deter']], -1)
|
||||
|
||||
def get_dist(self, state, dtype=None):
|
||||
if self._discrete:
|
||||
logit = state['logit']
|
||||
logit = tf.cast(logit, tf.float32)
|
||||
dist = tfd.Independent(tools.OneHotDist(logit), 1)
|
||||
if dtype != tf.float32:
|
||||
dist = tools.DtypeDist(dist, dtype or state['logit'].dtype)
|
||||
else:
|
||||
mean, std = state['mean'], state['std']
|
||||
if dtype:
|
||||
mean = tf.cast(mean, dtype)
|
||||
std = tf.cast(std, dtype)
|
||||
dist = tfd.MultivariateNormalDiag(mean, std)
|
||||
return dist
|
||||
|
||||
@tf.function
|
||||
def obs_step(self, prev_state, prev_action, embed, sample=True):
|
||||
if not self._embed:
|
||||
self._embed = embed.shape[-1]
|
||||
prior = self.img_step(prev_state, prev_action, None, sample)
|
||||
if self._shared:
|
||||
post = self.img_step(prev_state, prev_action, embed, sample)
|
||||
else:
|
||||
x = tf.concat([prior['deter'], embed], -1)
|
||||
for i in range(self._layers_output):
|
||||
x = self.get(f'obi{i}', tfkl.Dense, self._hidden, self._act)(x)
|
||||
stats = self._suff_stats_layer('obs', x)
|
||||
if sample:
|
||||
stoch = self.get_dist(stats).sample()
|
||||
else:
|
||||
stoch = self.get_dist(stats).mode()
|
||||
post = {'stoch': stoch, 'deter': prior['deter'], **stats}
|
||||
return post, prior
|
||||
|
||||
@tf.function
|
||||
def img_step(self, prev_state, prev_action, embed=None, sample=True):
|
||||
prev_stoch = prev_state['stoch']
|
||||
if self._discrete:
|
||||
shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete]
|
||||
prev_stoch = tf.reshape(prev_stoch, shape)
|
||||
if self._shared:
|
||||
if embed is None:
|
||||
shape = prev_action.shape[:-1] + [self._embed]
|
||||
embed = tf.zeros(shape, prev_action.dtype)
|
||||
x = tf.concat([prev_stoch, prev_action, embed], -1)
|
||||
else:
|
||||
x = tf.concat([prev_stoch, prev_action], -1)
|
||||
for i in range(self._layers_input):
|
||||
x = self.get(f'ini{i}', tfkl.Dense, self._hidden, self._act)(x)
|
||||
x, deter = self._cell(x, [prev_state['deter']])
|
||||
deter = deter[0] # Keras wraps the state in a list.
|
||||
for i in range(self._layers_output):
|
||||
x = self.get(f'imo{i}', tfkl.Dense, self._hidden, self._act)(x)
|
||||
stats = self._suff_stats_layer('ims', x)
|
||||
if sample:
|
||||
stoch = self.get_dist(stats).sample()
|
||||
else:
|
||||
stoch = self.get_dist(stats).mode()
|
||||
prior = {'stoch': stoch, 'deter': deter, **stats}
|
||||
return prior
|
||||
|
||||
def _suff_stats_layer(self, name, x):
|
||||
if self._discrete:
|
||||
x = self.get(name, tfkl.Dense, self._stoch *
|
||||
self._discrete, None)(x)
|
||||
logit = tf.reshape(x, x.shape[:-1] + [self._stoch, self._discrete])
|
||||
return {'logit': logit}
|
||||
else:
|
||||
x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x)
|
||||
mean, std = tf.split(x, 2, -1)
|
||||
mean = {
|
||||
'none': lambda: mean,
|
||||
'tanh5': lambda: 5.0 * tf.math.tanh(mean / 5.0),
|
||||
}[self._mean_act]()
|
||||
std = {
|
||||
'softplus': lambda: tf.nn.softplus(std),
|
||||
'abs': lambda: tf.math.abs(std + 1),
|
||||
'sigmoid': lambda: tf.nn.sigmoid(std),
|
||||
'sigmoid2': lambda: 2 * tf.nn.sigmoid(std / 2),
|
||||
}[self._std_act]()
|
||||
std = std + self._min_std
|
||||
return {'mean': mean, 'std': std}
|
||||
|
||||
def kl_loss(self, post, prior, balance, free, scale):
|
||||
kld = tfd.kl_divergence
|
||||
def dist(x): return self.get_dist(x, tf.float32)
|
||||
if balance == 0.5:
|
||||
value = kld(dist(prior), dist(post))
|
||||
loss = tf.reduce_mean(tf.maximum(value, free))
|
||||
else:
|
||||
def sg(x): return tf.nest.map_structure(tf.stop_gradient, x)
|
||||
value = kld(dist(prior), dist(sg(post)))
|
||||
pri = tf.reduce_mean(value)
|
||||
pos = tf.reduce_mean(kld(dist(sg(prior)), dist(post)))
|
||||
pri, pos = tf.maximum(pri, free), tf.maximum(pos, free)
|
||||
loss = balance * pri + (1 - balance) * pos
|
||||
loss *= scale
|
||||
return loss, value
|
||||
|
||||
|
||||
class ConvEncoder(tools.Module):
|
||||
|
||||
def __init__(
|
||||
self, depth=32, act=tf.nn.relu, kernels=(4, 4, 4, 4)):
|
||||
self._act = act
|
||||
self._depth = depth
|
||||
self._kernels = kernels
|
||||
|
||||
def __call__(self, obs):
|
||||
kwargs = dict(strides=2, activation=self._act)
|
||||
Conv = tfkl.Conv2D
|
||||
x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:]))
|
||||
x = self.get('h1', Conv, 1 * self._depth,
|
||||
self._kernels[0], **kwargs)(x)
|
||||
x = self.get('h2', Conv, 2 * self._depth,
|
||||
self._kernels[1], **kwargs)(x)
|
||||
x = self.get('h3', Conv, 4 * self._depth,
|
||||
self._kernels[2], **kwargs)(x)
|
||||
x = self.get('h4', Conv, 8 * self._depth,
|
||||
self._kernels[3], **kwargs)(x)
|
||||
x = tf.reshape(x, [x.shape[0], np.prod(x.shape[1:])])
|
||||
shape = tf.concat([tf.shape(obs['image'])[:-3], [x.shape[-1]]], 0)
|
||||
return tf.reshape(x, shape)
|
||||
|
||||
|
||||
class ConvDecoder(tools.Module):
|
||||
|
||||
def __init__(
|
||||
self, depth=32, act=tf.nn.relu, shape=(64, 64, 3), kernels=(5, 5, 6, 6),
|
||||
thin=True):
|
||||
self._act = act
|
||||
self._depth = depth
|
||||
self._shape = shape
|
||||
self._kernels = kernels
|
||||
self._thin = thin
|
||||
|
||||
def __call__(self, features, dtype=None):
|
||||
kwargs = dict(strides=2, activation=self._act)
|
||||
ConvT = tfkl.Conv2DTranspose
|
||||
if self._thin:
|
||||
x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features)
|
||||
x = tf.reshape(x, [-1, 1, 1, 32 * self._depth])
|
||||
else:
|
||||
x = self.get('h1', tfkl.Dense, 128 * self._depth, None)(features)
|
||||
x = tf.reshape(x, [-1, 2, 2, 32 * self._depth])
|
||||
x = self.get('h2', ConvT, 4 * self._depth,
|
||||
self._kernels[0], **kwargs)(x)
|
||||
x = self.get('h3', ConvT, 2 * self._depth,
|
||||
self._kernels[1], **kwargs)(x)
|
||||
x = self.get('h4', ConvT, 1 * self._depth,
|
||||
self._kernels[2], **kwargs)(x)
|
||||
x = self.get(
|
||||
'h5', ConvT, self._shape[-1], self._kernels[3], strides=2)(x)
|
||||
mean = tf.reshape(x, tf.concat(
|
||||
[tf.shape(features)[:-1], self._shape], 0))
|
||||
if dtype:
|
||||
mean = tf.cast(mean, dtype)
|
||||
return tfd.Independent(tfd.Normal(mean, 1), len(self._shape))
|
||||
|
||||
|
||||
class ConvDecoderMask(tools.Module):
|
||||
|
||||
def __init__(
|
||||
self, depth=32, act=tf.nn.relu, shape=(64, 64, 3), kernels=(5, 5, 6, 6),
|
||||
thin=True):
|
||||
self._act = act
|
||||
self._depth = depth
|
||||
self._shape = shape
|
||||
self._kernels = kernels
|
||||
self._thin = thin
|
||||
|
||||
def __call__(self, features, dtype=None):
|
||||
kwargs = dict(strides=2, activation=self._act)
|
||||
ConvT = tfkl.Conv2DTranspose
|
||||
if self._thin:
|
||||
x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features)
|
||||
x = tf.reshape(x, [-1, 1, 1, 32 * self._depth])
|
||||
else:
|
||||
x = self.get('h1', tfkl.Dense, 128 * self._depth, None)(features)
|
||||
x = tf.reshape(x, [-1, 2, 2, 32 * self._depth])
|
||||
x = self.get('h2', ConvT, 4 * self._depth,
|
||||
self._kernels[0], **kwargs)(x)
|
||||
x = self.get('h3', ConvT, 2 * self._depth,
|
||||
self._kernels[1], **kwargs)(x)
|
||||
x = self.get('h4', ConvT, 1 * self._depth,
|
||||
self._kernels[2], **kwargs)(x)
|
||||
x = self.get(
|
||||
'h5', ConvT, 2 * self._shape[-1], self._kernels[3], strides=2)(x)
|
||||
mean, mask = tf.split(x, [self._shape[-1], self._shape[-1]], -1)
|
||||
mean = tf.reshape(mean, tf.concat(
|
||||
[tf.shape(features)[:-1], self._shape], 0))
|
||||
mask = tf.reshape(mask, tf.concat(
|
||||
[tf.shape(features)[:-1], self._shape], 0))
|
||||
if dtype:
|
||||
mean = tf.cast(mean, dtype)
|
||||
mask = tf.cast(mask, dtype)
|
||||
return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)), mask
|
||||
|
||||
|
||||
class ConvDecoderMaskEnsemble(tools.Module):
|
||||
"""
|
||||
ensemble two convdecoder with <Normal, mask> outputs
|
||||
NOTE: remove pred1/pred2 for maximum performance.
|
||||
"""
|
||||
|
||||
def __init__(self, decoder1, decoder2):
|
||||
self._decoder1 = decoder1
|
||||
self._decoder2 = decoder2
|
||||
self._shape = decoder1._shape
|
||||
|
||||
def __call__(self, feat1, feat2, dtype=None):
|
||||
kwargs = dict(strides=1, activation=tf.nn.sigmoid)
|
||||
pred1, mask1 = self._decoder1(feat1, dtype)
|
||||
pred2, mask2 = self._decoder2(feat2, dtype)
|
||||
mean1 = pred1.submodules[0].loc
|
||||
mean2 = pred2.submodules[0].loc
|
||||
mask_feat = tf.concat([mask1, mask2], -1)
|
||||
mask = self.get('mask1', tfkl.Conv2D, 1, 1, **kwargs)(mask_feat)
|
||||
mask_use1 = mask
|
||||
mask_use2 = 1-mask
|
||||
mean = mean1 * tf.cast(mask_use1, mean1.dtype) + \
|
||||
mean2 * tf.cast(mask_use2, mean2.dtype)
|
||||
return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)), pred1, pred2, tf.cast(mask_use1, mean1.dtype)
|
||||
|
||||
|
||||
class DenseHead(tools.Module):
|
||||
|
||||
def __init__(
|
||||
self, shape, layers, units, act=tf.nn.elu, dist='normal', std=1.0):
|
||||
self._shape = (shape,) if isinstance(shape, int) else shape
|
||||
self._layers = layers
|
||||
self._units = units
|
||||
self._act = act
|
||||
self._dist = dist
|
||||
self._std = std
|
||||
|
||||
def __call__(self, features, dtype=None):
|
||||
x = features
|
||||
for index in range(self._layers):
|
||||
x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
|
||||
mean = self.get(f'hmean', tfkl.Dense, np.prod(self._shape))(x)
|
||||
mean = tf.reshape(mean, tf.concat(
|
||||
[tf.shape(features)[:-1], self._shape], 0))
|
||||
if self._std == 'learned':
|
||||
std = self.get(f'hstd', tfkl.Dense, np.prod(self._shape))(x)
|
||||
std = tf.nn.softplus(std) + 0.01
|
||||
std = tf.reshape(std, tf.concat(
|
||||
[tf.shape(features)[:-1], self._shape], 0))
|
||||
else:
|
||||
std = self._std
|
||||
if dtype:
|
||||
mean, std = tf.cast(mean, dtype), tf.cast(std, dtype)
|
||||
if self._dist == 'normal':
|
||||
return tfd.Independent(tfd.Normal(mean, std), len(self._shape))
|
||||
if self._dist == 'huber':
|
||||
return tfd.Independent(
|
||||
tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape))
|
||||
if self._dist == 'binary':
|
||||
return tfd.Independent(tfd.Bernoulli(mean), len(self._shape))
|
||||
raise NotImplementedError(self._dist)
|
||||
|
||||
|
||||
class ActionHead(tools.Module):

  def __init__(
      self, size, layers, units, act=tf.nn.elu, dist='trunc_normal',
      init_std=0.0, min_std=0.1, action_disc=5, temp=0.1, outscale=0):
    # assert min_std <= 2
    self._size = size
    self._layers = layers
    self._units = units
    self._dist = dist
    self._act = act
    self._min_std = min_std
    self._init_std = init_std
    self._action_disc = action_disc
    self._temp = temp() if callable(temp) else temp
    self._outscale = outscale

  def __call__(self, features, dtype=None):
    x = features
    for index in range(self._layers):
      kw = {}
      if index == self._layers - 1 and self._outscale:
        kw['kernel_initializer'] = tf.keras.initializers.VarianceScaling(
            self._outscale)
      x = self.get(f'h{index}', tfkl.Dense, self._units, self._act, **kw)(x)
    if self._dist == 'tanh_normal':
      # https://www.desmos.com/calculator/rcmcf5jwe7
      x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      mean = tf.tanh(mean)
      std = tf.nn.softplus(std + self._init_std) + self._min_std
      dist = tfd.Normal(mean, std)
      dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
      dist = tfd.Independent(dist, 1)
      dist = tools.SampleDist(dist)
    elif self._dist == 'tanh_normal_5':
      x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      mean = 5 * tf.tanh(mean / 5)
      std = tf.nn.softplus(std + 5) + 5
      dist = tfd.Normal(mean, std)
      dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
      dist = tfd.Independent(dist, 1)
      dist = tools.SampleDist(dist)
    elif self._dist == 'normal':
      x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      std = tf.nn.softplus(std + self._init_std) + self._min_std
      dist = tfd.Normal(mean, std)
      dist = tfd.Independent(dist, 1)
    elif self._dist == 'normal_1':
      mean = self.get(f'hout', tfkl.Dense, self._size)(x)
      if dtype:
        mean = tf.cast(mean, dtype)
      dist = tfd.Normal(mean, 1)
      dist = tfd.Independent(dist, 1)
    elif self._dist == 'trunc_normal':
      # https://www.desmos.com/calculator/mmuvuhnyxo
      x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
      x = tf.cast(x, tf.float32)
      mean, std = tf.split(x, 2, -1)
      mean = tf.tanh(mean)
      std = 2 * tf.nn.sigmoid(std / 2) + self._min_std
      dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
      dist = tools.DtypeDist(dist, dtype)
      dist = tfd.Independent(dist, 1)
    elif self._dist == 'onehot':
      x = self.get(f'hout', tfkl.Dense, self._size)(x)
      x = tf.cast(x, tf.float32)
      dist = tools.OneHotDist(x, dtype=dtype)
      dist = tools.DtypeDist(dist, dtype)
    elif self._dist == 'onehot_gumble':
      x = self.get(f'hout', tfkl.Dense, self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      temp = self._temp
      dist = tools.GumbleDist(temp, x, dtype=dtype)
    else:
      raise NotImplementedError(self._dist)
    return dist


class GRUCell(tf.keras.layers.AbstractRNNCell):

  def __init__(self, size, norm=False, act=tf.tanh, update_bias=-1, **kwargs):
    super().__init__()
    self._size = size
    self._act = act
    self._norm = norm
    self._update_bias = update_bias
    self._layer = tfkl.Dense(3 * size, use_bias=norm is not None, **kwargs)
    if norm:
      self._norm = tfkl.LayerNormalization(dtype=tf.float32)

  @property
  def state_size(self):
    return self._size

  def call(self, inputs, state):
    state = state[0]  # Keras wraps the state in a list.
    parts = self._layer(tf.concat([inputs, state], -1))
    if self._norm:
      dtype = parts.dtype
      parts = tf.cast(parts, tf.float32)
      parts = self._norm(parts)
      parts = tf.cast(parts, dtype)
    reset, cand, update = tf.split(parts, 3, -1)
    reset = tf.nn.sigmoid(reset)
    cand = self._act(reset * cand)
    update = tf.nn.sigmoid(update + self._update_bias)
    output = update * cand + (1 - update) * state
    return output, [output]

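A minimal usage sketch, not part of the original file: the cell follows the Keras RNN-cell interface, so one deterministic-state transition can be driven directly; the shapes below are placeholders.

# Illustrative sketch with assumed shapes (not part of the original file).
cell = GRUCell(200, norm=True)
inputs = tf.zeros([16, 250])   # e.g. embedded stochastic state and action
state = [tf.zeros([16, 200])]  # Keras passes the recurrent state as a list
output, state = cell(inputs, state)
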
694
DreamerV2/tools.py
Normal file
694
DreamerV2/tools.py
Normal file
@ -0,0 +1,694 @@
import datetime
import io
import json
import pathlib
import pickle
import re
import time
import uuid

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_probability as tfp
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd


# Patch to ignore seed to avoid synchronization across GPUs.
_orig_random_categorical = tf.random.categorical
def random_categorical(*args, **kwargs):
  kwargs['seed'] = None
  return _orig_random_categorical(*args, **kwargs)
tf.random.categorical = random_categorical

# Patch to ignore seed to avoid synchronization across GPUs.
_orig_random_normal = tf.random.normal
def random_normal(*args, **kwargs):
  kwargs['seed'] = None
  return _orig_random_normal(*args, **kwargs)
tf.random.normal = random_normal


class AttrDict(dict):

  __setattr__ = dict.__setitem__
  __getattr__ = dict.__getitem__


class Module(tf.Module):

  def save(self, filename):
    values = tf.nest.map_structure(lambda x: x.numpy(), self.variables)
    amount = len(tf.nest.flatten(values))
    count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
    print(f'Save checkpoint with {amount} tensors and {count} parameters.')
    with pathlib.Path(filename).open('wb') as f:
      pickle.dump(values, f)

  def load(self, filename):
    with pathlib.Path(filename).open('rb') as f:
      values = pickle.load(f)
    amount = len(tf.nest.flatten(values))
    count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
    print(f'Load checkpoint with {amount} tensors and {count} parameters.')
    tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values)

  def get(self, name, ctor, *args, **kwargs):
    # Create or get layer by name to avoid mentioning it in the constructor.
    if not hasattr(self, '_modules'):
      self._modules = {}
    if name not in self._modules:
      self._modules[name] = ctor(*args, **kwargs)
    return self._modules[name]

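A minimal sketch, not in the original file, of the lazy layer creation that `get` enables; `TwoLayerHead` is a hypothetical example module.

# Hypothetical example module: layers are created by name on first call.
class TwoLayerHead(Module):

  def __call__(self, features):
    x = self.get('h1', tf.keras.layers.Dense, 400, tf.nn.elu)(features)
    x = self.get('h2', tf.keras.layers.Dense, 400, tf.nn.elu)(x)
    return self.get('out', tf.keras.layers.Dense, 1)(x)
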
def var_nest_names(nest):
  if isinstance(nest, dict):
    items = ' '.join(f'{k}:{var_nest_names(v)}' for k, v in nest.items())
    return '{' + items + '}'
  if isinstance(nest, (list, tuple)):
    items = ' '.join(var_nest_names(v) for v in nest)
    return '[' + items + ']'
  if hasattr(nest, 'name') and hasattr(nest, 'shape'):
    return nest.name + str(nest.shape).replace(', ', 'x')
  if hasattr(nest, 'shape'):
    return str(nest.shape).replace(', ', 'x')
  return '?'


class Logger:

  def __init__(self, logdir, step):
    self._logdir = logdir
    self._writer = tf.summary.create_file_writer(str(logdir), max_queue=1000)
    self._last_step = None
    self._last_time = None
    self._scalars = {}
    self._images = {}
    self._videos = {}
    self.step = step

  def scalar(self, name, value):
    self._scalars[name] = float(value)

  def image(self, name, value):
    self._images[name] = np.array(value)

  def video(self, name, value):
    self._videos[name] = np.array(value)

  def write(self, fps=False):
    scalars = list(self._scalars.items())
    if fps:
      scalars.append(('fps', self._compute_fps(self.step)))
    print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
    with (self._logdir / 'metrics.jsonl').open('a') as f:
      f.write(json.dumps({'step': self.step, **dict(scalars)}) + '\n')
    with self._writer.as_default():
      for name, value in scalars:
        tf.summary.scalar('scalars/' + name, value, self.step)
      for name, value in self._images.items():
        tf.summary.image(name, value, self.step)
      for name, value in self._videos.items():
        video_summary(name, value, self.step)
    self._writer.flush()
    self._scalars = {}
    self._images = {}
    self._videos = {}

  def _compute_fps(self, step):
    if self._last_step is None:
      self._last_time = time.time()
      self._last_step = step
      return 0
    steps = step - self._last_step
    duration = time.time() - self._last_time
    self._last_time += duration
    self._last_step = step
    return steps / duration

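A usage sketch, not part of the original file: metrics are buffered by name and flushed to TensorBoard and `metrics.jsonl` in a single `write` call; the log directory is a placeholder.

# Illustrative sketch with a placeholder log directory.
logger = Logger(pathlib.Path('./logdir'), step=0)
logger.scalar('train/return', 120.5)
logger.scalar('train/value_loss', 0.37)
logger.write(fps=True)
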
def graph_summary(writer, step, fn, *args):
  def inner(*args):
    tf.summary.experimental.set_step(step.numpy().item())
    with writer.as_default():
      fn(*args)
  return tf.numpy_function(inner, args, [])


def video_summary(name, video, step=None, fps=20):
  name = name if isinstance(name, str) else name.decode('utf-8')
  if np.issubdtype(video.dtype, np.floating):
    video = np.clip(255 * video, 0, 255).astype(np.uint8)
  B, T, H, W, C = video.shape
  try:
    frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C))
    summary = tf1.Summary()
    image = tf1.Summary.Image(height=B * H, width=T * W, colorspace=C)
    image.encoded_image_string = encode_gif(frames, fps)
    summary.value.add(tag=name, image=image)
    tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
  except (IOError, OSError) as e:
    print('GIF summaries require ffmpeg in $PATH.', e)
    frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C))
    tf.summary.image(name, frames, step)


def encode_gif(frames, fps):
  from subprocess import Popen, PIPE
  h, w, c = frames[0].shape
  pxfmt = {1: 'gray', 3: 'rgb24'}[c]
  cmd = ' '.join([
      f'ffmpeg -y -f rawvideo -vcodec rawvideo',
      f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex',
      f'[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
      f'-r {fps:.02f} -f gif -'])
  proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE)
  for image in frames:
    proc.stdin.write(image.tostring())
  out, err = proc.communicate()
  if proc.returncode:
    raise IOError('\n'.join([' '.join(cmd), err.decode('utf8')]))
  del proc
  return out


def simulate(agent, envs, steps=0, episodes=0, state=None):
  # Initialize or unpack simulation state.
  if state is None:
    step, episode = 0, 0
    done = np.ones(len(envs), np.bool)
    length = np.zeros(len(envs), np.int32)
    obs = [None] * len(envs)
    agent_state = None
  else:
    step, episode, done, length, obs, agent_state = state
  while (steps and step < steps) or (episodes and episode < episodes):
    # Reset envs if necessary.
    if done.any():
      indices = [index for index, d in enumerate(done) if d]
      results = [envs[i].reset() for i in indices]
      for index, result in zip(indices, results):
        obs[index] = result
    # Step agents.
    obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
    action, agent_state = agent(obs, done, agent_state)
    if isinstance(action, dict):
      action = [
          {k: np.array(action[k][i]) for k in action}
          for i in range(len(envs))]
    else:
      action = np.array(action)
    assert len(action) == len(envs)
    # Step envs.
    results = [e.step(a) for e, a in zip(envs, action)]
    obs, _, done = zip(*[p[:3] for p in results])
    obs = list(obs)
    done = np.stack(done)
    episode += int(done.sum())
    length += 1
    step += (done * length).sum()
    length *= (1 - done)
    # import pdb
    # pdb.set_trace()
  # Return new state to allow resuming the simulation.
  return (step - steps, episode - episodes, done, length, obs, agent_state)

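A sketch of the expected agent interface, not part of the original file: the agent maps `(obs, done, state)` to `(action, state)`, and the environments must already return dict observations (see wrappers.py); `train_envs` and `prefill` are placeholder names.

# Illustrative sketch with placeholder names (train_envs, prefill).
def random_agent(obs, done, state):
  actions = np.stack([train_envs[0].action_space.sample() for _ in done])
  return actions, None

sim_state = simulate(random_agent, train_envs, steps=prefill)
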
def save_episodes(directory, episodes):
  directory = pathlib.Path(directory).expanduser()
  directory.mkdir(parents=True, exist_ok=True)
  timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
  filenames = []
  for episode in episodes:
    identifier = str(uuid.uuid4().hex)
    length = len(episode['reward'])
    filename = directory / f'{timestamp}-{identifier}-{length}.npz'
    with io.BytesIO() as f1:
      np.savez_compressed(f1, **episode)
      f1.seek(0)
      with filename.open('wb') as f2:
        f2.write(f1.read())
    filenames.append(filename)
  return filenames


def sample_episodes(episodes, length=None, balance=False, seed=0):
  random = np.random.RandomState(seed)
  while True:
    episode = random.choice(list(episodes.values()))
    if length:
      total = len(next(iter(episode.values())))
      available = total - length
      if available < 1:
        # print(f'Skipped short episode of length {available}.')
        continue
      if balance:
        index = min(random.randint(0, total), available)
      else:
        index = int(random.randint(0, available + 1))
      episode = {k: v[index: index + length] for k, v in episode.items()}
    yield episode

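A sketch, not part of the original file, of how the generator is commonly wrapped into a `tf.data` pipeline of fixed-length segment batches; `make_dataset` and its defaults are assumptions mirroring the config values.

# Illustrative sketch; make_dataset and its defaults are assumptions.
def make_dataset(episodes, batch_size=50, batch_length=50):
  example = next(iter(episodes.values()))
  types = {k: v.dtype for k, v in example.items()}
  shapes = {k: (batch_length,) + v.shape[1:] for k, v in example.items()}
  generator = lambda: sample_episodes(episodes, batch_length)
  dataset = tf.data.Dataset.from_generator(generator, types, shapes)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  return dataset.prefetch(10)
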
def load_episodes(directory, limit=None):
  directory = pathlib.Path(directory).expanduser()
  episodes = {}
  total = 0
  for filename in reversed(sorted(directory.glob('*.npz'))):
    try:
      with filename.open('rb') as f:
        episode = np.load(f)
        episode = {k: episode[k] for k in episode.keys()}
    except Exception as e:
      print(f'Could not load episode: {e}')
      continue
    episodes[str(filename)] = episode
    total += len(episode['reward']) - 1
    if limit and total >= limit:
      break
  return episodes


class DtypeDist:

  def __init__(self, dist, dtype=None):
    self._dist = dist
    self._dtype = dtype or prec.global_policy().compute_dtype

  @property
  def name(self):
    return 'DtypeDist'

  def __getattr__(self, name):
    return getattr(self._dist, name)

  def mean(self):
    return tf.cast(self._dist.mean(), self._dtype)

  def mode(self):
    return tf.cast(self._dist.mode(), self._dtype)

  def entropy(self):
    return tf.cast(self._dist.entropy(), self._dtype)

  def sample(self, *args, **kwargs):
    return tf.cast(self._dist.sample(*args, **kwargs), self._dtype)


class SampleDist:

  def __init__(self, dist, samples=100):
    self._dist = dist
    self._samples = samples

  @property
  def name(self):
    return 'SampleDist'

  def __getattr__(self, name):
    return getattr(self._dist, name)

  def mean(self):
    samples = self._dist.sample(self._samples)
    return tf.reduce_mean(samples, 0)

  def mode(self):
    sample = self._dist.sample(self._samples)
    logprob = self._dist.log_prob(sample)
    return tf.gather(sample, tf.argmax(logprob))[0]

  def entropy(self):
    sample = self._dist.sample(self._samples)
    logprob = self.log_prob(sample)
    return -tf.reduce_mean(logprob, 0)


class OneHotDist(tfd.OneHotCategorical):

  def __init__(self, logits=None, probs=None, dtype=None):
    self._sample_dtype = dtype or prec.global_policy().compute_dtype
    super().__init__(logits=logits, probs=probs)

  def mode(self):
    return tf.cast(super().mode(), self._sample_dtype)

  def sample(self, sample_shape=(), seed=None):
    # Straight through biased gradient estimator.
    sample = tf.cast(super().sample(sample_shape, seed), self._sample_dtype)
    probs = super().probs_parameter()
    while len(probs.shape) < len(sample.shape):
      probs = probs[None]
    sample += tf.cast(probs - tf.stop_gradient(probs), self._sample_dtype)
    return sample

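A small demonstration, not part of the original file, of the straight-through estimator: the sample is a discrete one-hot vector in the forward pass but still carries gradients with respect to the logits. It assumes the default float32 compute policy.

# Illustrative sketch assuming a float32 compute policy.
logits = tf.Variable([[1.0, 2.0, 0.5]])
with tf.GradientTape() as tape:
  sample = OneHotDist(logits).sample()
  loss = tf.reduce_sum(sample * tf.constant([[0.0, 1.0, 2.0]]))
grads = tape.gradient(loss, logits)  # non-zero despite the discrete sample
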
class GumbleDist(tfd.RelaxedOneHotCategorical):

  def __init__(self, temp, logits=None, probs=None, dtype=None):
    self._sample_dtype = dtype or prec.global_policy().compute_dtype
    self._exact = tfd.OneHotCategorical(logits=logits, probs=probs)
    super().__init__(temp, logits=logits, probs=probs)

  def mode(self):
    return tf.cast(self._exact.mode(), self._sample_dtype)

  def entropy(self):
    return tf.cast(self._exact.entropy(), self._sample_dtype)

  def sample(self, sample_shape=(), seed=None):
    return tf.cast(super().sample(sample_shape, seed), self._sample_dtype)


class UnnormalizedHuber(tfd.Normal):

  def __init__(self, loc, scale, threshold=1, **kwargs):
    self._threshold = tf.cast(threshold, loc.dtype)
    super().__init__(loc, scale, **kwargs)

  def log_prob(self, event):
    return -(tf.math.sqrt(
        (event - self.mean()) ** 2 + self._threshold ** 2) - self._threshold)


class SafeTruncatedNormal(tfd.TruncatedNormal):

  def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
    super().__init__(loc, scale, low, high)
    self._clip = clip
    self._mult = mult

  def sample(self, *args, **kwargs):
    event = super().sample(*args, **kwargs)
    if self._clip:
      clipped = tf.clip_by_value(
          event, self.low + self._clip, self.high - self._clip)
      event = event - tf.stop_gradient(event) + tf.stop_gradient(clipped)
    if self._mult:
      event *= self._mult
    return event


class TanhBijector(tfp.bijectors.Bijector):

  def __init__(self, validate_args=False, name='tanh'):
    super().__init__(
        forward_min_event_ndims=0,
        validate_args=validate_args,
        name=name)

  def _forward(self, x):
    return tf.nn.tanh(x)

  def _inverse(self, y):
    dtype = y.dtype
    y = tf.cast(y, tf.float32)
    y = tf.where(
        tf.less_equal(tf.abs(y), 1.),
        tf.clip_by_value(y, -0.99999997, 0.99999997), y)
    y = tf.atanh(y)
    y = tf.cast(y, dtype)
    return y

  def _forward_log_det_jacobian(self, x):
    log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype))
    return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x))


def lambda_return(
    reward, value, pcont, bootstrap, lambda_, axis):
  # Setting lambda=1 gives a discounted Monte Carlo return.
  # Setting lambda=0 gives a fixed 1-step return.
  assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
  if isinstance(pcont, (int, float)):
    pcont = pcont * tf.ones_like(reward)
  dims = list(range(reward.shape.ndims))
  dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
  if axis != 0:
    reward = tf.transpose(reward, dims)
    value = tf.transpose(value, dims)
    pcont = tf.transpose(pcont, dims)
  if bootstrap is None:
    bootstrap = tf.zeros_like(value[-1])
  next_values = tf.concat([value[1:], bootstrap[None]], 0)
  inputs = reward + pcont * next_values * (1 - lambda_)
  returns = static_scan(
      lambda agg, cur: cur[0] + cur[1] * lambda_ * agg,
      (inputs, pcont), bootstrap, reverse=True)
  if axis != 0:
    returns = tf.transpose(returns, dims)
  return returns

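A usage sketch, not part of the original file: computing lambda-returns over an imagined rollout with time on axis 0, as the actor and value losses do; `reward`, `value`, and `discount` stand in for tensors of shape [horizon, batch].

# Illustrative sketch with placeholder tensors of shape [horizon, batch].
returns = lambda_return(
    reward[:-1], value[:-1], discount[:-1],
    bootstrap=value[-1], lambda_=0.95, axis=0)
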
class Optimizer(tf.Module):

  def __init__(
      self, name, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*',
      opt='adam'):
    assert 0 <= wd < 1
    assert not clip or 1 <= clip
    self._name = name
    self._clip = clip
    self._wd = wd
    self._wd_pattern = wd_pattern
    self._opt = {
        'adam': lambda: tf.optimizers.Adam(lr, epsilon=eps),
        'nadam': lambda: tf.optimizers.Nadam(lr, epsilon=eps),
        'adamax': lambda: tf.optimizers.Adamax(lr, epsilon=eps),
        'sgd': lambda: tf.optimizers.SGD(lr),
        'momentum': lambda: tf.optimizers.SGD(lr, 0.9),
    }[opt]()
    self._mixed = (prec.global_policy().compute_dtype == tf.float16)
    if self._mixed:
      self._opt = prec.LossScaleOptimizer(self._opt, 'dynamic')

  @property
  def variables(self):
    return self._opt.variables()

  def __call__(self, tape, loss, modules, prefix=None):
    assert loss.dtype is tf.float32, self._name
    modules = modules if hasattr(modules, '__len__') else (modules,)
    varibs = tf.nest.flatten([module.variables for module in modules])
    count = sum(np.prod(x.shape) for x in varibs)
    print(f'Found {count} {self._name} parameters.')
    assert len(loss.shape) == 0, loss.shape
    tf.debugging.check_numerics(loss, self._name + '_loss')
    if self._mixed:
      with tape:
        loss = self._opt.get_scaled_loss(loss)
    grads = tape.gradient(loss, varibs)
    if self._mixed:
      grads = self._opt.get_unscaled_gradients(grads)
    norm = tf.linalg.global_norm(grads)
    if not self._mixed:
      tf.debugging.check_numerics(norm, self._name + '_norm')
    if self._clip:
      grads, _ = tf.clip_by_global_norm(grads, self._clip, norm)
    if self._wd:
      self._apply_weight_decay(varibs)
    self._opt.apply_gradients(zip(grads, varibs))
    metrics = {}
    if prefix:
      metrics[f'{prefix}/{self._name}_loss'] = loss
      metrics[f'{prefix}/{self._name}_grad_norm'] = norm
      if self._mixed:
        metrics[f'{prefix}/{self._name}_loss_scale'] = \
            self._opt.loss_scale._current_loss_scale
    else:
      metrics[f'{self._name}_loss'] = loss
      metrics[f'{self._name}_grad_norm'] = norm
      if self._mixed:
        metrics[f'{self._name}_loss_scale'] = \
            self._opt.loss_scale._current_loss_scale
    return metrics

  def _apply_weight_decay(self, varibs):
    nontrivial = (self._wd_pattern != r'.*')
    if nontrivial:
      print('Applied weight decay to variables:')
    for var in varibs:
      if re.search(self._wd_pattern, self._name + '/' + var.name):
        if nontrivial:
          print('- ' + self._name + '/' + var.name)
        var.assign((1 - self._wd) * var)

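A usage sketch, not part of the original file: the optimizer consumes an open `GradientTape`, a scalar float32 loss, and the modules whose variables it should update; `world_model`, `compute_model_loss`, and `batch` are placeholders.

# Illustrative sketch with placeholder model, loss function, and batch.
model_opt = Optimizer('model', lr=3e-4, eps=1e-5, clip=100, wd=0.0)
with tf.GradientTape() as tape:
  loss = compute_model_loss(batch)
metrics = model_opt(tape, loss, [world_model])
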
def args_type(default):
  def parse_string(x):
    if default is None:
      return x
    if isinstance(default, bool):
      return bool(['False', 'True'].index(x))
    if isinstance(default, int):
      return float(x) if ('e' in x or '.' in x) else int(x)
    if isinstance(default, (list, tuple)):
      return tuple(args_type(default[0])(y) for y in x.split(','))
    return type(default)(x)
  def parse_object(x):
    if isinstance(default, (list, tuple)):
      return tuple(x)
    return x
  return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x)

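A sketch, not part of the original file, of how `args_type` turns the YAML defaults into argparse converters, so `--size 64,64` parses to the tuple `(64, 64)` and `--debug True` to a boolean; the defaults dict below is a small assumed subset.

# Illustrative sketch; the defaults dict is an assumed subset of the config.
import argparse
parser = argparse.ArgumentParser()
defaults = {'steps': 1e7, 'size': [64, 64], 'debug': False}
for key, value in defaults.items():
  parser.add_argument(f'--{key}', type=args_type(value), default=value)
args = parser.parse_args(['--size', '64,64', '--debug', 'True'])
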
def static_scan(fn, inputs, start, reverse=False):
  last = start
  outputs = [[] for _ in tf.nest.flatten(start)]
  indices = range(len(tf.nest.flatten(inputs)[0]))
  if reverse:
    indices = reversed(indices)
  for index in indices:
    inp = tf.nest.map_structure(lambda x: x[index], inputs)
    last = fn(last, inp)
    [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))]
  if reverse:
    outputs = [list(reversed(x)) for x in outputs]
  outputs = [tf.stack(x, 0) for x in outputs]
  return tf.nest.pack_sequence_as(start, outputs)

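A small example, not part of the original file: static_scan unrolls a Python loop over the leading axis, similar to tf.scan but with a statically known length; here it accumulates discounted sums over ten steps.

# Illustrative sketch: running[i] = rewards[i] + discount[i] * running[i-1].
rewards = tf.ones([10, 4])
discount = 0.99 * tf.ones([10, 4])
running = static_scan(
    lambda agg, cur: cur[0] + cur[1] * agg,
    (rewards, discount), tf.zeros([4]))
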
def uniform_mixture(dist, dtype=None):
  if dist.batch_shape[-1] == 1:
    return tfd.BatchReshape(dist, dist.batch_shape[:-1])
  dtype = dtype or prec.global_policy().compute_dtype
  weights = tfd.Categorical(tf.zeros(dist.batch_shape, dtype))
  return tfd.MixtureSameFamily(weights, dist)


def cat_mixture_entropy(dist):
  if isinstance(dist, tfd.MixtureSameFamily):
    probs = dist.components_distribution.probs_parameter()
  else:
    probs = dist.probs_parameter()
  return -tf.reduce_mean(
      tf.reduce_mean(probs, 2) *
      tf.math.log(tf.reduce_mean(probs, 2) + 1e-8), -1)


@tf.function
def cem_planner(
    state, num_actions, horizon, proposals, topk, iterations, imagine,
    objective):
  dtype = prec.global_policy().compute_dtype
  B, P = list(state.values())[0].shape[0], proposals
  H, A = horizon, num_actions
  flat_state = {k: tf.repeat(v, P, 0) for k, v in state.items()}
  mean = tf.zeros((B, H, A), dtype)
  std = tf.ones((B, H, A), dtype)
  for _ in range(iterations):
    proposals = tf.random.normal((B, P, H, A), dtype=dtype)
    proposals = proposals * std[:, None] + mean[:, None]
    proposals = tf.clip_by_value(proposals, -1, 1)
    flat_proposals = tf.reshape(proposals, (B * P, H, A))
    states = imagine(flat_proposals, flat_state)
    scores = objective(states)
    scores = tf.reshape(tf.reduce_sum(scores, -1), (B, P))
    _, indices = tf.math.top_k(scores, topk, sorted=False)
    best = tf.gather(proposals, indices, axis=1, batch_dims=1)
    mean, var = tf.nn.moments(best, 1)
    std = tf.sqrt(var + 1e-6)
  return mean[:, 0, :]


@tf.function
def grad_planner(
    state, num_actions, horizon, proposals, iterations, imagine, objective,
    kl_scale, step_size):
  dtype = prec.global_policy().compute_dtype
  B, P = list(state.values())[0].shape[0], proposals
  H, A = horizon, num_actions
  flat_state = {k: tf.repeat(v, P, 0) for k, v in state.items()}
  mean = tf.zeros((B, H, A), dtype)
  rawstd = 0.54 * tf.ones((B, H, A), dtype)
  for _ in range(iterations):
    proposals = tf.random.normal((B, P, H, A), dtype=dtype)
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(mean)
      tape.watch(rawstd)
      std = tf.nn.softplus(rawstd)
      proposals = proposals * std[:, None] + mean[:, None]
      proposals = (
          tf.stop_gradient(tf.clip_by_value(proposals, -1, 1)) +
          proposals - tf.stop_gradient(proposals))
      flat_proposals = tf.reshape(proposals, (B * P, H, A))
      states = imagine(flat_proposals, flat_state)
      scores = objective(states)
      scores = tf.reshape(tf.reduce_sum(scores, -1), (B, P))
      div = tfd.kl_divergence(
          tfd.Normal(mean, std),
          tfd.Normal(tf.zeros_like(mean), tf.ones_like(std)))
      elbo = tf.reduce_sum(scores) - kl_scale * div
      elbo /= tf.cast(tf.reduce_prod(tf.shape(scores)), dtype)
    grad_mean, grad_rawstd = tape.gradient(elbo, [mean, rawstd])
    e, v = tf.nn.moments(grad_mean, [1, 2], keepdims=True)
    grad_mean /= tf.sqrt(e * e + v + 1e-4)
    e, v = tf.nn.moments(grad_rawstd, [1, 2], keepdims=True)
    grad_rawstd /= tf.sqrt(e * e + v + 1e-4)
    mean = tf.clip_by_value(mean + step_size * grad_mean, -1, 1)
    rawstd = rawstd + step_size * grad_rawstd
  return mean[:, 0, :]


class Every:

  def __init__(self, every):
    self._every = every
    self._last = None

  def __call__(self, step):
    if not self._every:
      return False
    if self._last is None:
      self._last = step
      return True
    if step >= self._last + self._every:
      self._last += self._every
      return True
    return False


class Once:

  def __init__(self):
    self._once = True

  def __call__(self):
    if self._once:
      self._once = False
      return True
    return False


class Until:

  def __init__(self, until):
    self._until = until

  def __call__(self, step):
    if not self._until:
      return True
    return step < self._until


def schedule(string, step):
  try:
    return float(string)
  except ValueError:
    step = tf.cast(step, tf.float32)
    match = re.match(r'linear\((.+),(.+),(.+)\)', string)
    if match:
      initial, final, duration = [float(group) for group in match.groups()]
      mix = tf.clip_by_value(step / duration, 0, 1)
      return (1 - mix) * initial + mix * final
    match = re.match(r'warmup\((.+),(.+)\)', string)
    if match:
      warmup, value = [float(group) for group in match.groups()]
      scale = tf.clip_by_value(step / warmup, 0, 1)
      return scale * value
    match = re.match(r'exp\((.+),(.+),(.+)\)', string)
    if match:
      initial, final, halflife = [float(group) for group in match.groups()]
      return (initial - final) * 0.5 ** (step / halflife) + final
    raise NotImplementedError(string)

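Example inputs, not part of the original file: schedule parses the string-valued hyperparameters from configs.yaml (such as kl_scale or actor_entropy) and evaluates them at the current step.

# Illustrative values at step 50000.
step = tf.constant(50000, tf.int64)
schedule('1e-4', step)                 # constant -> 1e-4
schedule('linear(1.0,0.1,1e5)', step)  # halfway through the ramp -> 0.55
schedule('warmup(1e4,8e-5)', step)     # past the warmup -> 8e-5
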
280
DreamerV2/wrappers.py
Normal file
280
DreamerV2/wrappers.py
Normal file
@ -0,0 +1,280 @@
import threading

import gym
import numpy as np


class DeepMindControl:

  def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
    domain, task = name.split('_', 1)
    if domain == 'cup':  # Only domain with multiple words.
      domain = 'ball_in_cup'
    if isinstance(domain, str):
      from dm_control import suite
      self._env = suite.load(domain, task)
    else:
      assert task is None
      self._env = domain()
    self._action_repeat = action_repeat
    self._size = size
    if camera is None:
      camera = dict(quadruped=2).get(domain, 0)
    self._camera = camera

  @property
  def observation_space(self):
    spaces = {}
    for key, value in self._env.observation_spec().items():
      spaces[key] = gym.spaces.Box(
          -np.inf, np.inf, value.shape, dtype=np.float32)
    spaces['image'] = gym.spaces.Box(
        0, 255, self._size + (3,), dtype=np.uint8)
    return gym.spaces.Dict(spaces)

  @property
  def action_space(self):
    spec = self._env.action_spec()
    return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)

  def step(self, action):
    assert np.isfinite(action).all(), action
    reward = 0
    for _ in range(self._action_repeat):
      time_step = self._env.step(action)
      reward += time_step.reward or 0
      if time_step.last():
        break
    obs = dict(time_step.observation)
    obs['image'] = self.render()
    done = time_step.last()
    info = {'discount': np.array(time_step.discount, np.float32)}
    return obs, reward, done, info

  def reset(self):
    time_step = self._env.reset()
    obs = dict(time_step.observation)
    obs['image'] = self.render()
    return obs

  def render(self, *args, **kwargs):
    if kwargs.get('mode', 'rgb_array') != 'rgb_array':
      raise ValueError("Only render mode 'rgb_array' is supported.")
    return self._env.physics.render(*self._size, camera_id=self._camera)


class Atari:

  LOCK = threading.Lock()

  def __init__(
      self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30,
      life_done=False, sticky_actions=True, all_actions=False):
    assert size[0] == size[1]
    import gym.wrappers
    import gym.envs.atari
    with self.LOCK:
      env = gym.envs.atari.AtariEnv(
          game=name, obs_type='image', frameskip=1,
          repeat_action_probability=0.25 if sticky_actions else 0.0,
          full_action_space=all_actions)
    # Avoid unnecessary rendering in inner env.
    env._get_obs = lambda: None
    # Tell wrapper that the inner env has no action repeat.
    env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0')
    env = gym.wrappers.AtariPreprocessing(
        env, noops, action_repeat, size[0], life_done, grayscale)
    self._env = env
    self._grayscale = grayscale

  @property
  def observation_space(self):
    return gym.spaces.Dict({
        'image': self._env.observation_space,
        'ram': gym.spaces.Box(0, 255, (128,), np.uint8),
    })

  @property
  def action_space(self):
    return self._env.action_space

  def close(self):
    return self._env.close()

  def reset(self):
    with self.LOCK:
      image = self._env.reset()
    if self._grayscale:
      image = image[..., None]
    obs = {'image': image, 'ram': self._env.env._get_ram()}
    return obs

  def step(self, action):
    image, reward, done, info = self._env.step(action)
    if self._grayscale:
      image = image[..., None]
    obs = {'image': image, 'ram': self._env.env._get_ram()}
    return obs, reward, done, info

  def render(self, mode):
    return self._env.render(mode)


class CollectDataset:

  def __init__(self, env, callbacks=None, precision=32):
    self._env = env
    self._callbacks = callbacks or ()
    self._precision = precision
    self._episode = None

  def __getattr__(self, name):
    return getattr(self._env, name)

  def step(self, action):
    obs, reward, done, info = self._env.step(action)
    obs = {k: self._convert(v) for k, v in obs.items()}
    transition = obs.copy()
    transition['action'] = action
    transition['reward'] = reward
    transition['discount'] = info.get('discount', np.array(1 - float(done)))
    self._episode.append(transition)
    if done:
      episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
      episode = {k: self._convert(v) for k, v in episode.items()}
      info['episode'] = episode
      for callback in self._callbacks:
        callback(episode)
    return obs, reward, done, info

  def reset(self):
    obs = self._env.reset()
    transition = obs.copy()
    transition['action'] = np.zeros(self._env.action_space.shape)
    transition['reward'] = 0.0
    transition['discount'] = 1.0
    self._episode = [transition]
    return obs

  def _convert(self, value):
    value = np.array(value)
    if np.issubdtype(value.dtype, np.floating):
      dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
    elif np.issubdtype(value.dtype, np.signedinteger):
      dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
    elif np.issubdtype(value.dtype, np.uint8):
      dtype = np.uint8
    else:
      raise NotImplementedError(value.dtype)
    return value.astype(dtype)


class TimeLimit:

  def __init__(self, env, duration):
    self._env = env
    self._duration = duration
    self._step = None

  def __getattr__(self, name):
    return getattr(self._env, name)

  def step(self, action):
    assert self._step is not None, 'Must reset environment.'
    obs, reward, done, info = self._env.step(action)
    self._step += 1
    if self._step >= self._duration:
      done = True
      if 'discount' not in info:
        info['discount'] = np.array(1.0).astype(np.float32)
      self._step = None
    return obs, reward, done, info

  def reset(self):
    self._step = 0
    return self._env.reset()


class NormalizeActions:

  def __init__(self, env):
    self._env = env
    self._mask = np.logical_and(
        np.isfinite(env.action_space.low),
        np.isfinite(env.action_space.high))
    self._low = np.where(self._mask, env.action_space.low, -1)
    self._high = np.where(self._mask, env.action_space.high, 1)

  def __getattr__(self, name):
    return getattr(self._env, name)

  @property
  def action_space(self):
    low = np.where(self._mask, -np.ones_like(self._low), self._low)
    high = np.where(self._mask, np.ones_like(self._low), self._high)
    return gym.spaces.Box(low, high, dtype=np.float32)

  def step(self, action):
    original = (action + 1) / 2 * (self._high - self._low) + self._low
    original = np.where(self._mask, original, action)
    return self._env.step(original)


class OneHotAction:

  def __init__(self, env):
    assert isinstance(env.action_space, gym.spaces.Discrete)
    self._env = env
    self._random = np.random.RandomState()

  def __getattr__(self, name):
    return getattr(self._env, name)

  @property
  def action_space(self):
    shape = (self._env.action_space.n,)
    space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
    space.sample = self._sample_action
    return space

  def step(self, action):
    index = np.argmax(action).astype(int)
    reference = np.zeros_like(action)
    reference[index] = 1
    if not np.allclose(reference, action):
      raise ValueError(f'Invalid one-hot action:\n{action}')
    return self._env.step(index)

  def reset(self):
    return self._env.reset()

  def _sample_action(self):
    actions = self._env.action_space.n
    index = self._random.randint(0, actions)
    reference = np.zeros(actions, dtype=np.float32)
    reference[index] = 1.0
    return reference


class RewardObs:

  def __init__(self, env):
    self._env = env

  def __getattr__(self, name):
    return getattr(self._env, name)

  @property
  def observation_space(self):
    spaces = self._env.observation_space.spaces
    assert 'reward' not in spaces
    spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
    return gym.spaces.Dict(spaces)

  def step(self, action):
    obs, reward, done, info = self._env.step(action)
    obs['reward'] = reward
    return obs, reward, done, info

  def reset(self):
    obs = self._env.reset()
    obs['reward'] = 0.0
    return obs