diff --git a/DreamerV2/configs.yaml b/DreamerV2/configs.yaml new file mode 100644 index 0000000..872e3ae --- /dev/null +++ b/DreamerV2/configs.yaml @@ -0,0 +1,185 @@ +defaults: + gpu: 'none' + logdir: ./ + traindir: null + evaldir: null + offline_traindir: '' + offline_evaldir: '' + seed: 0 + steps: 1e7 + eval_every: 1e4 + log_every: 1e4 + reset_every: 0 + gpu_growth: True + precision: 32 + debug: False + expl_gifs: False + + # Environment + task: 'dmc_walker_walk' + size: [64, 64] + envs: 1 + action_repeat: 2 + time_limit: 1000 + prefill: 2500 + eval_noise: 0.0 + clip_rewards: 'identity' + atari_grayscale: False + + # Model + dyn_cell: 'gru' + dyn_hidden: 200 + dyn_deter: 200 + dyn_stoch: 50 + dyn_discrete: 0 + dyn_input_layers: 1 + dyn_output_layers: 1 + dyn_shared: False + dyn_mean_act: 'none' + dyn_std_act: 'sigmoid2' + dyn_min_std: 0.1 + grad_heads: ['image', 'reward'] + units: 400 + reward_layers: 2 + discount_layers: 3 + value_layers: 3 + actor_layers: 4 + act: 'elu' + cnn_depth: 32 + encoder_kernels: [4, 4, 4, 4] + decoder_kernels: [5, 5, 6, 6] + decoder_thin: True + value_head: 'normal' + kl_scale: '1.0' + kl_balance: '0.8' + kl_free: '1.0' + pred_discount: False + discount_scale: 1.0 + reward_scale: 1.0 + weight_decay: 0.0 + + # Training + batch_size: 50 + batch_length: 50 + train_every: 5 + train_steps: 1 + pretrain: 100 + model_lr: 3e-4 + value_lr: 8e-5 + actor_lr: 8e-5 + opt_eps: 1e-5 + grad_clip: 100 + value_grad_clip: 100 + actor_grad_clip: 100 + dataset_size: 0 + oversample_ends: False + slow_value_target: True + slow_actor_target: True + slow_target_update: 100 + slow_target_fraction: 1 + opt: 'adam' + + # Behavior. + discount: 0.99 + discount_lambda: 0.95 + imag_horizon: 15 + imag_gradient: 'dynamics' + imag_gradient_mix: '0.1' + imag_sample: True + actor_dist: 'trunc_normal' + actor_entropy: '1e-4' + actor_state_entropy: 0.0 + actor_init_std: 1.0 + actor_min_std: 0.1 + actor_disc: 5 + actor_temp: 0.1 + actor_outscale: 0.0 + expl_amount: 0.0 + eval_state_mean: False + collect_dyn_sample: True + behavior_stop_grad: True + value_decay: 0.0 + future_entropy: False + + # Exploration + expl_behavior: 'greedy' + expl_until: 0 + expl_extr_scale: 0.0 + expl_intr_scale: 1.0 + disag_target: 'stoch' + disag_log: True + disag_models: 10 + disag_offset: 1 + disag_layers: 4 + disag_units: 400 + +atari: + + # General + task: 'atari_demon_attack' + steps: 3e7 + eval_every: 1e5 + log_every: 1e4 + prefill: 50000 + dataset_size: 2e6 + pretrain: 0 + precision: 16 + + # Environment + time_limit: 108000 # 30 minutes of game play. 
+ atari_grayscale: True + action_repeat: 4 + eval_noise: 0.001 + train_every: 16 + train_steps: 1 + clip_rewards: 'tanh' + + # Model + grad_heads: ['image', 'reward', 'discount'] + dyn_cell: 'gru_layer_norm' + pred_discount: True + cnn_depth: 48 + dyn_deter: 600 + dyn_hidden: 600 + dyn_stoch: 32 + dyn_discrete: 32 + reward_layers: 4 + discount_layers: 4 + value_layers: 4 + actor_layers: 4 + + # Behavior + actor_dist: 'onehot' + actor_entropy: 'linear(3e-3,3e-4,2.5e6)' + expl_amount: 0.0 + expl_until: 3e7 + discount: 0.995 + imag_gradient: 'both' + imag_gradient_mix: 'linear(0.1,0,2.5e6)' + + # Training + discount_scale: 5.0 + reward_scale: 1 + weight_decay: 1e-6 + model_lr: 2e-4 + kl_scale: 0.1 + kl_free: 0.0 + actor_lr: 4e-5 + value_lr: 1e-4 + oversample_ends: True + + # Disen + disen_cnn_depth: 16 + disen_only_scale: 1.0 + disen_discount_scale: 2000.0 + disen_reward_scale: 2000.0 + num_reward_opt_iters: 20 + +debug: + + debug: True + pretrain: 1 + prefill: 1 + train_steps: 1 + batch_size: 10 + batch_length: 20 diff --git a/DreamerV2/dreamer.py b/DreamerV2/dreamer.py new file mode 100644 index 0000000..6883bfc --- /dev/null +++ b/DreamerV2/dreamer.py @@ -0,0 +1,316 @@ +import argparse +import collections +import functools +import os +import pathlib +import sys +import warnings + +warnings.filterwarnings('ignore', '.*box bound precision lowered.*') +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +os.environ['MUJOCO_GL'] = 'egl' + +import numpy as np +import ruamel.yaml as yaml +import tensorflow as tf +from tensorflow.keras.mixed_precision import experimental as prec + +tf.get_logger().setLevel('ERROR') + +from tensorflow_probability import distributions as tfd + +sys.path.append(str(pathlib.Path(__file__).parent)) + +import exploration as expl +import models +import tools +import wrappers + +class Dreamer(tools.Module): + + def __init__(self, config, logger, dataset): + self._config = config + self._logger = logger + self._float = prec.global_policy().compute_dtype + self._should_log = tools.Every(config.log_every) + self._should_train = tools.Every(config.train_every) + self._should_pretrain = tools.Once() + self._should_reset = tools.Every(config.reset_every) + self._should_expl = tools.Until(int( + config.expl_until / config.action_repeat)) + self._metrics = collections.defaultdict(tf.metrics.Mean) + with tf.device('cpu:0'): + self._step = tf.Variable(count_steps(config.traindir), dtype=tf.int64) + # Schedules. + config.actor_entropy = ( + lambda x=config.actor_entropy: tools.schedule(x, self._step)) + config.actor_state_entropy = ( + lambda x=config.actor_state_entropy: tools.schedule(x, self._step)) + config.imag_gradient_mix = ( + lambda x=config.imag_gradient_mix: tools.schedule(x, self._step)) + self._dataset = iter(dataset) + self._wm = models.WorldModel(self._step, config) + self._task_behavior = models.ImagBehavior( + config, self._wm, config.behavior_stop_grad) + reward = lambda f, s, a: self._wm.heads['reward'](f).mode() + self._expl_behavior = dict( + greedy=lambda: self._task_behavior, + random=lambda: expl.Random(config), + plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward), + )[config.expl_behavior]() + # Train step to initialize variables including optimizer statistics. 
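+    # Running one update here builds every variable (and optimizer slot), so a
+    # saved checkpoint can be restored onto the agent in main().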
+ self._train(next(self._dataset)) + + def __call__(self, obs, reset, state=None, training=True): + step = self._step.numpy().item() + if self._should_reset(step): + state = None + if state is not None and reset.any(): + mask = tf.cast(1 - reset, self._float)[:, None] + state = tf.nest.map_structure(lambda x: x * mask, state) + if training and self._should_train(step): + steps = ( + self._config.pretrain if self._should_pretrain() + else self._config.train_steps) + for _ in range(steps): + self._train(next(self._dataset)) + if self._should_log(step): + for name, mean in self._metrics.items(): + self._logger.scalar(name, float(mean.result())) + mean.reset_states() + openl_joint, openl_main, openl_disen, openl_mask = self._wm.video_pred(next(self._dataset)) + self._logger.video('train_openl_joint', openl_joint) + self._logger.video('train_openl_main', openl_main) + self._logger.video('train_openl_disen', openl_disen) + self._logger.video('train_openl_mask', openl_mask) + self._logger.write(fps=True) + action, state = self._policy(obs, state, training) + if training: + self._step.assign_add(len(reset)) + self._logger.step = self._config.action_repeat \ + * self._step.numpy().item() + return action, state + + @tf.function + def _policy(self, obs, state, training): + if state is None: + batch_size = len(obs['image']) + latent = self._wm.dynamics.initial(len(obs['image'])) + action = tf.zeros((batch_size, self._config.num_actions), self._float) + else: + latent, action = state + embed = self._wm.encoder(self._wm.preprocess(obs)) + latent, _ = self._wm.dynamics.obs_step( + latent, action, embed, self._config.collect_dyn_sample) + if self._config.eval_state_mean: + latent['stoch'] = latent['mean'] + feat = self._wm.dynamics.get_feat(latent) + if not training: + action = self._task_behavior.actor(feat).mode() + elif self._should_expl(self._step): + action = self._expl_behavior.actor(feat).sample() + else: + action = self._task_behavior.actor(feat).sample() + if self._config.actor_dist == 'onehot_gumble': + action = tf.cast( + tf.one_hot(tf.argmax(action, axis=-1), self._config.num_actions), + action.dtype) + action = self._exploration(action, training) + state = (latent, action) + return action, state + + def _exploration(self, action, training): + amount = self._config.expl_amount if training else self._config.eval_noise + if amount == 0: + return action + amount = tf.cast(amount, self._float) + if 'onehot' in self._config.actor_dist: + probs = amount / self._config.num_actions + (1 - amount) * action + return tools.OneHotDist(probs=probs).sample() + else: + return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1) + raise NotImplementedError(self._config.action_noise) + + @tf.function + def _train(self, data): + print('Tracing train function.') + metrics = {} + embed, post, feat, kl, mets = self._wm.train(data) + metrics.update(mets) + start = post + if self._config.pred_discount: # Last step could be terminal. 
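+      # Dropping the final time step keeps imagined rollouts from starting in a
+      # potentially terminal state.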
+ start = {k: v[:, :-1] for k, v in post.items()} + embed, feat, kl = embed[:, :-1], feat[:, :-1], kl[:, :-1] + reward = lambda f, s, a: self._wm.heads['reward'](f).mode() + metrics.update(self._task_behavior.train(start, reward)[-1]) + if self._config.expl_behavior != 'greedy': + mets = self._expl_behavior.train(start, feat, embed, kl)[-1] + metrics.update({'expl_' + key: value for key, value in mets.items()}) + for name, value in metrics.items(): + self._metrics[name].update_state(value) + + +def count_steps(folder): + return sum(int(str(n).split('-')[-1][:-4]) - 1 for n in folder.glob('*.npz')) + + +def make_dataset(episodes, config): + example = episodes[next(iter(episodes.keys()))] + types = {k: v.dtype for k, v in example.items()} + shapes = {k: (None,) + v.shape[1:] for k, v in example.items()} + generator = lambda: tools.sample_episodes( + episodes, config.batch_length, config.oversample_ends) + dataset = tf.data.Dataset.from_generator(generator, types, shapes) + dataset = dataset.batch(config.batch_size, drop_remainder=True) + dataset = dataset.prefetch(10) + return dataset + + +def make_env(config, logger, mode, train_eps, eval_eps): + suite, task = config.task.split('_', 1) + if suite == 'dmc': + env = wrappers.DeepMindControl(task, config.action_repeat, config.size) + env = wrappers.NormalizeActions(env) + elif suite == 'atari': + env = wrappers.Atari( + task, config.action_repeat, config.size, + grayscale=config.atari_grayscale, + life_done=False and (mode == 'train'), + sticky_actions=True, + all_actions=True) + env = wrappers.OneHotAction(env) + else: + raise NotImplementedError(suite) + env = wrappers.TimeLimit(env, config.time_limit) + callbacks = [functools.partial( + process_episode, config, logger, mode, train_eps, eval_eps)] + env = wrappers.CollectDataset(env, callbacks) + env = wrappers.RewardObs(env) + return env + + +def process_episode(config, logger, mode, train_eps, eval_eps, episode): + directory = dict(train=config.traindir, eval=config.evaldir)[mode] + cache = dict(train=train_eps, eval=eval_eps)[mode] + filename = tools.save_episodes(directory, [episode])[0] + length = len(episode['reward']) - 1 + score = float(episode['reward'].astype(np.float64).sum()) + video = episode['image'] + if mode == 'eval': + cache.clear() + if mode == 'train' and config.dataset_size: + total = 0 + for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])): + if total <= config.dataset_size - length: + total += len(ep['reward']) - 1 + else: + del cache[key] + logger.scalar('dataset_size', total + length) + cache[str(filename)] = episode + print(f'{mode.title()} episode has {length} steps and return {score:.1f}.') + logger.scalar(f'{mode}_return', score) + logger.scalar(f'{mode}_length', length) + logger.scalar(f'{mode}_episodes', len(cache)) + if mode == 'eval' or config.expl_gifs: + logger.video(f'{mode}_policy', video[None]) + logger.write() + + +def main(logdir, config): + + logdir = os.path.join( + logdir, config.task, 'Ours', str(config.seed)) + + logdir = pathlib.Path(logdir).expanduser() + config.traindir = config.traindir or logdir / 'train_eps' + config.evaldir = config.evaldir or logdir / 'eval_eps' + config.steps //= config.action_repeat + config.eval_every //= config.action_repeat + config.log_every //= config.action_repeat + config.time_limit //= config.action_repeat + config.act = getattr(tf.nn, config.act) + + if config.debug: + tf.config.experimental_run_functions_eagerly(True) + if config.gpu_growth: + message = 'No GPU found. 
To actually train on CPU remove this assert.' + assert tf.config.experimental.list_physical_devices('GPU'), message + for gpu in tf.config.experimental.list_physical_devices('GPU'): + tf.config.experimental.set_memory_growth(gpu, True) + assert config.precision in (16, 32), config.precision + if config.precision == 16: + prec.set_policy(prec.Policy('mixed_float16')) + print('Logdir', logdir) + logdir.mkdir(parents=True, exist_ok=True) + config.traindir.mkdir(parents=True, exist_ok=True) + config.evaldir.mkdir(parents=True, exist_ok=True) + step = count_steps(config.traindir) + logger = tools.Logger(logdir, config.action_repeat * step) + + print('Create envs.') + if config.offline_traindir: + directory = config.offline_traindir.format(**vars(config)) + else: + directory = config.traindir + train_eps = tools.load_episodes(directory, limit=config.dataset_size) + if config.offline_evaldir: + directory = config.offline_evaldir.format(**vars(config)) + else: + directory = config.evaldir + eval_eps = tools.load_episodes(directory, limit=1) + make = lambda mode: make_env(config, logger, mode, train_eps, eval_eps) + train_envs = [make('train') for _ in range(config.envs)] + eval_envs = [make('eval') for _ in range(config.envs)] + acts = train_envs[0].action_space + config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0] + + prefill = max(0, config.prefill - count_steps(config.traindir)) + print(f'Prefill dataset ({prefill} steps).') + random_agent = lambda o, d, s: ([acts.sample() for _ in d], s) + tools.simulate(random_agent, train_envs, prefill) + tools.simulate(random_agent, eval_envs, episodes=1) + logger.step = config.action_repeat * count_steps(config.traindir) + + print('Simulate agent.') + train_dataset = make_dataset(train_eps, config) + eval_dataset = iter(make_dataset(eval_eps, config)) + agent = Dreamer(config, logger, train_dataset) + if (logdir / 'variables.pkl').exists(): + agent.load(logdir / 'variables.pkl') + agent._should_pretrain._once = False + + state = None + suite, task = config.task.split('_', 1) + num_eval_episodes = 10 if suite == 'procgen' else 1 + while agent._step.numpy().item() < config.steps: + logger.write() + print('Start evaluation.') + openl_joint, openl_main, openl_disen, openl_mask = agent._wm.video_pred(next(eval_dataset)) + logger.video('eval_openl_joint', openl_joint) + logger.video('eval_openl_main', openl_main) + logger.video('eval_openl_disen', openl_disen) + logger.video('eval_openl_mask', openl_mask) + eval_policy = functools.partial(agent, training=False) + tools.simulate(eval_policy, eval_envs, episodes=num_eval_episodes) + print('Start training.') + state = tools.simulate(agent, train_envs, config.eval_every, state=state) + agent.save(logdir / 'variables.pkl') + for env in train_envs + eval_envs: + try: + env.close() + except Exception: + pass + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--configs', nargs='+', required=True) + args, remaining = parser.parse_known_args() + configs = yaml.safe_load((pathlib.Path(__file__).parent / 'configs.yaml').read_text()) + config_ = {} + for name in args.configs: + config_.update(configs[name]) + parser = argparse.ArgumentParser() + for key, value in config_.items(): + arg_type = tools.args_type(value) + parser.add_argument(f'--{key}', type=arg_type, default=arg_type(value)) + main(config_['logdir'], parser.parse_args(remaining)) diff --git a/DreamerV2/exploration.py b/DreamerV2/exploration.py new file mode 100644 index 0000000..4425ba5 --- /dev/null +++ 
b/DreamerV2/exploration.py
@@ -0,0 +1,83 @@
+import tensorflow as tf
+from tensorflow.keras.mixed_precision import experimental as prec
+from tensorflow_probability import distributions as tfd
+
+import models
+import networks
+import tools
+
+class Random(tools.Module):
+
+  def __init__(self, config):
+    self._config = config
+    self._float = prec.global_policy().compute_dtype
+
+  def actor(self, feat):
+    shape = feat.shape[:-1] + [self._config.num_actions]
+    if self._config.actor_dist == 'onehot':
+      return tools.OneHotDist(tf.zeros(shape))
+    else:
+      ones = tf.ones(shape, self._float)
+      return tfd.Uniform(-ones, ones)
+
+  def train(self, start, feat, embed, kl):
+    return None, {}
+
+
+class Plan2Explore(tools.Module):
+
+  def __init__(self, config, world_model, reward=None):
+    self._config = config
+    self._reward = reward
+    self._behavior = models.ImagBehavior(config, world_model)
+    self.actor = self._behavior.actor
+    size = {
+        'embed': 32 * config.cnn_depth,
+        'stoch': config.dyn_stoch,
+        'deter': config.dyn_deter,
+        'feat': config.dyn_stoch + config.dyn_deter,
+    }[self._config.disag_target]
+    kw = dict(
+        shape=size, layers=config.disag_layers, units=config.disag_units,
+        act=config.act)
+    self._networks = [
+        networks.DenseHead(**kw) for _ in range(config.disag_models)]
+    self._opt = tools.Optimizer(
+        'ensemble', config.model_lr, config.opt_eps, config.grad_clip,
+        config.weight_decay, opt=config.opt)
+
+  def train(self, start, feat, embed, kl):
+    metrics = {}
+    target = {
+        'embed': embed,
+        'stoch': start['stoch'],
+        'deter': start['deter'],
+        'feat': feat,
+    }[self._config.disag_target]
+    metrics.update(self._train_ensemble(feat, target))
+    metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1])
+    return None, metrics
+
+  def _intrinsic_reward(self, feat, state, action):
+    preds = [head(feat, tf.float32).mean() for head in self._networks]
+    disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1)
+    if self._config.disag_log:
+      disag = tf.math.log(disag)
+    reward = self._config.expl_intr_scale * disag
+    if self._config.expl_extr_scale:
+      reward += tf.cast(self._config.expl_extr_scale * self._reward(
+          feat, state, action), tf.float32)
+    return reward
+
+  def _train_ensemble(self, inputs, targets):
+    if self._config.disag_offset:
+      targets = targets[:, self._config.disag_offset:]
+      inputs = inputs[:, :-self._config.disag_offset]
+    targets = tf.stop_gradient(targets)
+    inputs = tf.stop_gradient(inputs)
+    with tf.GradientTape() as tape:
+      preds = [head(inputs) for head in self._networks]
+      likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
+      loss = -tf.cast(tf.reduce_sum(likes), tf.float32)
+    metrics = self._opt(tape, loss, self._networks)
+    return metrics
diff --git a/DreamerV2/models.py b/DreamerV2/models.py
new file mode 100644
index 0000000..8225adb
--- /dev/null
+++ b/DreamerV2/models.py
@@ -0,0 +1,429 @@
+import tensorflow as tf
+from tensorflow.keras.mixed_precision import experimental as prec
+
+import networks
+import tools
+
+
+class WorldModel(tools.Module):
+
+    def __init__(self, step, config):
+        self._step = step
+        self._config = config
+        channels = (1 if config.atari_grayscale else 3)
+        shape = config.size + (channels,)
+
+        ########
+        # Main #
+        ########
+        self.encoder = networks.ConvEncoder(
+            config.cnn_depth, config.act, config.encoder_kernels)
+        self.dynamics = networks.RSSM(
+            config.dyn_stoch, config.dyn_deter, config.dyn_hidden,
+            config.dyn_input_layers, config.dyn_output_layers, config.dyn_shared,
+            config.dyn_discrete, config.act,
config.dyn_mean_act, + config.dyn_std_act, config.dyn_min_std, config.dyn_cell) + self.heads = {} + self.heads['reward'] = networks.DenseHead( + [], config.reward_layers, config.units, config.act) + if config.pred_discount: + self.heads['discount'] = networks.DenseHead( + [], config.discount_layers, config.units, config.act, dist='binary') + self._model_opt = tools.Optimizer( + 'model', config.model_lr, config.opt_eps, config.grad_clip, + config.weight_decay, opt=config.opt) + self._scales = dict( + reward=config.reward_scale, discount=config.discount_scale) + + ######### + # Disen # + ######### + self.disen_encoder = networks.ConvEncoder( + config.disen_cnn_depth, config.act, config.encoder_kernels) + self.disen_dynamics = networks.RSSM( + config.dyn_stoch, config.dyn_deter, config.dyn_hidden, + config.dyn_input_layers, config.dyn_output_layers, config.dyn_shared, + config.dyn_discrete, config.act, config.dyn_mean_act, + config.dyn_std_act, config.dyn_min_std, config.dyn_cell) + + self.disen_heads = {} + self.disen_heads['reward'] = networks.DenseHead( + [], config.reward_layers, config.units, config.act) + if config.pred_discount: + self.disen_heads['discount'] = networks.DenseHead( + [], config.discount_layers, config.units, config.act, dist='binary') + + self._disen_model_opt = tools.Optimizer( + 'disen', config.model_lr, config.opt_eps, config.grad_clip, + config.weight_decay, opt=config.opt) + + self._disen_heads_opt = {} + self._disen_heads_opt['reward'] = tools.Optimizer( + 'disen_reward', config.model_lr, config.opt_eps, config.grad_clip, + config.weight_decay, opt=config.opt) + if config.pred_discount: + self._disen_heads_opt['discount'] = tools.Optimizer( + 'disen_pcont', config.model_lr, config.opt_eps, config.grad_clip, + config.weight_decay, opt=config.opt) + + # negative signs for reward/discount here + self._disen_scales = dict(disen_only=config.disen_only_scale, + reward=-config.disen_reward_scale, discount=-config.disen_discount_scale) + + self.disen_only_image_head = networks.ConvDecoder( + config.disen_cnn_depth, config.act, shape, config.decoder_kernels, + config.decoder_thin) + + ################ + # Joint Decode # + ################ + self.image_head = networks.ConvDecoderMask( + config.cnn_depth, config.act, shape, config.decoder_kernels, + config.decoder_thin) + self.disen_image_head = networks.ConvDecoderMask( + config.disen_cnn_depth, config.act, shape, config.decoder_kernels, + config.decoder_thin) + self.joint_image_head = networks.ConvDecoderMaskEnsemble( + self.image_head, self.disen_image_head + ) + + def train(self, data): + data = self.preprocess(data) + with tf.GradientTape() as model_tape, tf.GradientTape() as disen_tape: + + # kl schedule + kl_balance = tools.schedule(self._config.kl_balance, self._step) + kl_free = tools.schedule(self._config.kl_free, self._step) + kl_scale = tools.schedule(self._config.kl_scale, self._step) + + # Main + embed = self.encoder(data) + post, prior = self.dynamics.observe(embed, data['action']) + kl_loss, kl_value = self.dynamics.kl_loss( + post, prior, kl_balance, kl_free, kl_scale) + feat = self.dynamics.get_feat(post) + likes = {} + for name, head in self.heads.items(): + grad_head = (name in self._config.grad_heads) + inp = feat if grad_head else tf.stop_gradient(feat) + pred = head(inp, tf.float32) + like = pred.log_prob(tf.cast(data[name], tf.float32)) + likes[name] = tf.reduce_mean( + like) * self._scales.get(name, 1.0) + + # Disen + embed_disen = self.disen_encoder(data) + post_disen, prior_disen = 
self.disen_dynamics.observe( + embed_disen, data['action']) + kl_loss_disen, kl_value_disen = self.dynamics.kl_loss( + post_disen, prior_disen, kl_balance, kl_free, kl_scale) + feat_disen = self.disen_dynamics.get_feat(post_disen) + + # Optimize disen reward/pcont till optimal + disen_metrics = dict(reward={}, discount={}) + loss_disen = dict(reward=None, discount=None) + for _ in range(self._config.num_reward_opt_iters): + with tf.GradientTape() as disen_reward_tape, tf.GradientTape() as disen_pcont_tape: + disen_gradient_tapes = dict( + reward=disen_reward_tape, discount=disen_pcont_tape) + for name, head in self.disen_heads.items(): + pred_disen = head( + tf.stop_gradient(feat_disen), tf.float32) + loss_disen[name] = -tf.reduce_mean(pred_disen.log_prob( + tf.cast(data[name], tf.float32))) + for name, head in self.disen_heads.items(): + disen_metrics[name] = self._disen_heads_opt[name]( + disen_gradient_tapes[name], loss_disen[name], [head], prefix='disen_neg') + + # Compute likes for disen model (including negative gradients) + likes_disen = {} + for name, head in self.disen_heads.items(): + pred_disen = head(feat_disen, tf.float32) + like_disen = pred_disen.log_prob( + tf.cast(data[name], tf.float32)) + likes_disen[name] = tf.reduce_mean( + like_disen) * self._disen_scales.get(name, -1.0) + disen_only_image_pred = self.disen_only_image_head( + feat_disen, tf.float32) + disen_only_image_like = tf.reduce_mean(disen_only_image_pred.log_prob( + tf.cast(data['image'], tf.float32))) * self._disen_scales.get('disen_only', 1.0) + likes_disen['disen_only'] = disen_only_image_like + + # Joint decode + image_pred_joint, _, _, _ = self.joint_image_head( + feat, feat_disen, tf.float32) + image_like = tf.reduce_mean(image_pred_joint.log_prob( + tf.cast(data['image'], tf.float32))) + likes['image'] = image_like + likes_disen['image'] = image_like + + # Compute loss + model_loss = kl_loss - sum(likes.values()) + disen_loss = kl_loss_disen - sum(likes_disen.values()) + + model_parts = [self.encoder, self.dynamics, + self.joint_image_head] + list(self.heads.values()) + disen_parts = [self.disen_encoder, self.disen_dynamics, + self.joint_image_head, self.disen_only_image_head] + + metrics = self._model_opt( + model_tape, model_loss, model_parts, prefix='main') + disen_model_metrics = self._disen_model_opt( + disen_tape, disen_loss, disen_parts, prefix='disen') + + metrics['kl_balance'] = kl_balance + metrics['kl_free'] = kl_free + metrics['kl_scale'] = kl_scale + metrics.update({f'{name}_loss': -like for name, + like in likes.items()}) + + metrics['disen/disen_only_image_loss'] = -disen_only_image_like + metrics['disen/disen_reward_loss'] = -likes_disen['reward'] / \ + self._disen_scales.get('reward', -1.0) + metrics['disen/disen_discount_loss'] = -likes_disen['discount'] / \ + self._disen_scales.get('discount', -1.0) + + metrics['kl'] = tf.reduce_mean(kl_value) + metrics['prior_ent'] = self.dynamics.get_dist(prior).entropy() + metrics['post_ent'] = self.dynamics.get_dist(post).entropy() + metrics['disen/kl'] = tf.reduce_mean(kl_value_disen) + metrics['disen/prior_ent'] = self.dynamics.get_dist( + prior_disen).entropy() + metrics['disen/post_ent'] = self.dynamics.get_dist( + post_disen).entropy() + + metrics.update( + {f'{key}': value for key, value in disen_metrics['reward'].items()}) + metrics.update( + {f'{key}': value for key, value in disen_metrics['discount'].items()}) + metrics.update( + {f'{key}': value for key, value in disen_model_metrics.items()}) + + return embed, post, feat, kl_value, 
metrics + + @tf.function + def preprocess(self, obs): + dtype = prec.global_policy().compute_dtype + obs = obs.copy() + obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5 + obs['reward'] = getattr(tf, self._config.clip_rewards)(obs['reward']) + if 'discount' in obs: + obs['discount'] *= self._config.discount + for key, value in obs.items(): + if tf.dtypes.as_dtype(value.dtype) in ( + tf.float16, tf.float32, tf.float64): + obs[key] = tf.cast(value, dtype) + return obs + + @tf.function + def video_pred(self, data): + data = self.preprocess(data) + truth = data['image'][:6] + 0.5 + + embed = self.encoder(data) + embed_disen = self.disen_encoder(data) + states, _ = self.dynamics.observe( + embed[:6, :5], data['action'][:6, :5]) + states_disen, _ = self.disen_dynamics.observe( + embed_disen[:6, :5], data['action'][:6, :5]) + feats = self.dynamics.get_feat(states) + feats_disen = self.disen_dynamics.get_feat(states_disen) + recon_joint, recon_main, recon_disen, recon_mask = self.joint_image_head( + feats, feats_disen) + recon_joint = recon_joint.mode()[:6] + recon_main = recon_main.mode()[:6] + recon_disen = recon_disen.mode()[:6] + recon_mask = recon_mask[:6] + + init = {k: v[:, -1] for k, v in states.items()} + init_disen = {k: v[:, -1] for k, v in states_disen.items()} + prior = self.dynamics.imagine( + data['action'][:6, 5:], init) + prior_disen = self.disen_dynamics.imagine( + data['action'][:6, 5:], init_disen) + _feats = self.dynamics.get_feat(prior) + _feats_disen = self.disen_dynamics.get_feat(prior_disen) + openl_joint, openl_main, openl_disen, openl_mask = self.joint_image_head( + _feats, _feats_disen) + openl_joint = openl_joint.mode() + openl_main = openl_main.mode() + openl_disen = openl_disen.mode() + + model_joint = tf.concat( + [recon_joint[:, :5] + 0.5, openl_joint + 0.5], 1) + error_joint = (model_joint - truth + 1) / 2 + model_main = tf.concat( + [recon_main[:, :5] + 0.5, openl_main + 0.5], 1) + error_main = (model_main - truth + 1) / 2 + model_disen = tf.concat( + [recon_disen[:, :5] + 0.5, openl_disen + 0.5], 1) + error_disen = (model_disen - truth + 1) / 2 + model_mask = tf.concat( + [recon_mask[:, :5] + 0.5, openl_mask + 0.5], 1) + + output_joint = tf.concat([truth, model_joint, error_joint], 2) + output_main = tf.concat([truth, model_main, error_main], 2) + output_disen = tf.concat([truth, model_disen, error_disen], 2) + output_mask = model_mask + + return output_joint, output_main, output_disen, output_mask + + +class ImagBehavior(tools.Module): + + def __init__(self, config, world_model, stop_grad_actor=True, reward=None): + self._config = config + self._world_model = world_model + self._stop_grad_actor = stop_grad_actor + self._reward = reward + self.actor = networks.ActionHead( + config.num_actions, config.actor_layers, config.units, config.act, + config.actor_dist, config.actor_init_std, config.actor_min_std, + config.actor_dist, config.actor_temp, config.actor_outscale) + self.value = networks.DenseHead( + [], config.value_layers, config.units, config.act, + config.value_head) + if config.slow_value_target or config.slow_actor_target: + self._slow_value = networks.DenseHead( + [], config.value_layers, config.units, config.act) + self._updates = tf.Variable(0, tf.int64) + kw = dict(wd=config.weight_decay, opt=config.opt) + self._actor_opt = tools.Optimizer( + 'actor', config.actor_lr, config.opt_eps, config.actor_grad_clip, **kw) + self._value_opt = tools.Optimizer( + 'value', config.value_lr, config.opt_eps, config.value_grad_clip, **kw) + + def train( + 
self, start, objective=None, imagine=None, tape=None, repeats=None): + objective = objective or self._reward + self._update_slow_target() + metrics = {} + with (tape or tf.GradientTape()) as actor_tape: + assert bool(objective) != bool(imagine) + if objective: + imag_feat, imag_state, imag_action = self._imagine( + start, self.actor, self._config.imag_horizon, repeats) + reward = objective(imag_feat, imag_state, imag_action) + else: + imag_feat, imag_state, imag_action, reward = imagine(start) + actor_ent = self.actor(imag_feat, tf.float32).entropy() + state_ent = self._world_model.dynamics.get_dist( + imag_state, tf.float32).entropy() + target, weights = self._compute_target( + imag_feat, reward, actor_ent, state_ent, + self._config.slow_actor_target) + actor_loss, mets = self._compute_actor_loss( + imag_feat, imag_state, imag_action, target, actor_ent, state_ent, + weights) + metrics.update(mets) + if self._config.slow_value_target != self._config.slow_actor_target: + target, weights = self._compute_target( + imag_feat, reward, actor_ent, state_ent, + self._config.slow_value_target) + with tf.GradientTape() as value_tape: + value = self.value(imag_feat, tf.float32)[:-1] + value_loss = -value.log_prob(tf.stop_gradient(target)) + if self._config.value_decay: + value_loss += self._config.value_decay * value.mode() + value_loss = tf.reduce_mean(weights[:-1] * value_loss) + metrics['reward_mean'] = tf.reduce_mean(reward) + metrics['reward_std'] = tf.math.reduce_std(reward) + metrics['actor_ent'] = tf.reduce_mean(actor_ent) + metrics.update(self._actor_opt(actor_tape, actor_loss, [self.actor])) + metrics.update(self._value_opt(value_tape, value_loss, [self.value])) + return imag_feat, imag_state, imag_action, weights, metrics + + def _imagine(self, start, policy, horizon, repeats=None): + dynamics = self._world_model.dynamics + if repeats: + start = {k: tf.repeat(v, repeats, axis=1) + for k, v in start.items()} + + def flatten(x): return tf.reshape(x, [-1] + list(x.shape[2:])) + start = {k: flatten(v) for k, v in start.items()} + + def step(prev, _): + state, _, _ = prev + feat = dynamics.get_feat(state) + inp = tf.stop_gradient(feat) if self._stop_grad_actor else feat + action = policy(inp).sample() + succ = dynamics.img_step( + state, action, sample=self._config.imag_sample) + return succ, feat, action + feat = 0 * dynamics.get_feat(start) + action = policy(feat).mode() + succ, feats, actions = tools.static_scan( + step, tf.range(horizon), (start, feat, action)) + states = {k: tf.concat([ + start[k][None], v[:-1]], 0) for k, v in succ.items()} + if repeats: + def unfold(tensor): + s = tensor.shape + return tf.reshape(tensor, [s[0], s[1] // repeats, repeats] + s[2:]) + states, feats, actions = tf.nest.map_structure( + unfold, (states, feats, actions)) + return feats, states, actions + + def _compute_target(self, imag_feat, reward, actor_ent, state_ent, slow): + reward = tf.cast(reward, tf.float32) + if 'discount' in self._world_model.heads: + discount = self._world_model.heads['discount']( + imag_feat, tf.float32).mean() + else: + discount = self._config.discount * tf.ones_like(reward) + if self._config.future_entropy and tf.greater( + self._config.actor_entropy(), 0): + reward += self._config.actor_entropy() * actor_ent + if self._config.future_entropy and tf.greater( + self._config.actor_state_entropy(), 0): + reward += self._config.actor_state_entropy() * state_ent + if slow: + value = self._slow_value(imag_feat, tf.float32).mode() + else: + value = self.value(imag_feat, tf.float32).mode() 
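+        # Lambda-return computed along the imagined trajectory, bootstrapped
+        # from the value estimate at the final imagined step.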
+ target = tools.lambda_return( + reward[:-1], value[:-1], discount[:-1], + bootstrap=value[-1], lambda_=self._config.discount_lambda, axis=0) + weights = tf.stop_gradient(tf.math.cumprod(tf.concat( + [tf.ones_like(discount[:1]), discount[:-1]], 0), 0)) + return target, weights + + def _compute_actor_loss( + self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent, + weights): + metrics = {} + inp = tf.stop_gradient( + imag_feat) if self._stop_grad_actor else imag_feat + policy = self.actor(inp, tf.float32) + actor_ent = policy.entropy() + if self._config.imag_gradient == 'dynamics': + actor_target = target + elif self._config.imag_gradient == 'reinforce': + imag_action = tf.cast(imag_action, tf.float32) + actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient( + target - self.value(imag_feat[:-1], tf.float32).mode()) + elif self._config.imag_gradient == 'both': + imag_action = tf.cast(imag_action, tf.float32) + actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient( + target - self.value(imag_feat[:-1], tf.float32).mode()) + mix = self._config.imag_gradient_mix() + actor_target = mix * target + (1 - mix) * actor_target + metrics['imag_gradient_mix'] = mix + else: + raise NotImplementedError(self._config.imag_gradient) + if not self._config.future_entropy and tf.greater( + self._config.actor_entropy(), 0): + actor_target += self._config.actor_entropy() * actor_ent[:-1] + if not self._config.future_entropy and tf.greater( + self._config.actor_state_entropy(), 0): + actor_target += self._config.actor_state_entropy() * state_ent[:-1] + actor_loss = -tf.reduce_mean(weights[:-1] * actor_target) + return actor_loss, metrics + + def _update_slow_target(self): + if self._config.slow_value_target or self._config.slow_actor_target: + if self._updates % self._config.slow_target_update == 0: + mix = self._config.slow_target_fraction + for s, d in zip(self.value.variables, self._slow_value.variables): + d.assign(mix * s + (1 - mix) * d) + self._updates.assign_add(1) diff --git a/DreamerV2/networks.py b/DreamerV2/networks.py new file mode 100644 index 0000000..7a65d15 --- /dev/null +++ b/DreamerV2/networks.py @@ -0,0 +1,465 @@ +import numpy as np +import tensorflow as tf +from tensorflow.keras import layers as tfkl +from tensorflow_probability import distributions as tfd +from tensorflow.keras.mixed_precision import experimental as prec + +import tools + +class RSSM(tools.Module): + + def __init__( + self, stoch=30, deter=200, hidden=200, layers_input=1, layers_output=1, + shared=False, discrete=False, act=tf.nn.elu, mean_act='none', + std_act='softplus', min_std=0.1, cell='keras'): + super().__init__() + self._stoch = stoch + self._deter = deter + self._hidden = hidden + self._min_std = min_std + self._layers_input = layers_input + self._layers_output = layers_output + self._shared = shared + self._discrete = discrete + self._act = act + self._mean_act = mean_act + self._std_act = std_act + self._embed = None + if cell == 'gru': + self._cell = tfkl.GRUCell(self._deter) + elif cell == 'gru_layer_norm': + self._cell = GRUCell(self._deter, norm=True) + else: + raise NotImplementedError(cell) + + def initial(self, batch_size): + dtype = prec.global_policy().compute_dtype + if self._discrete: + state = dict( + logit=tf.zeros( + [batch_size, self._stoch, self._discrete], dtype), + stoch=tf.zeros( + [batch_size, self._stoch, self._discrete], dtype), + deter=self._cell.get_initial_state(None, batch_size, dtype)) + else: + state = dict( + mean=tf.zeros([batch_size, 
self._stoch], dtype), + std=tf.zeros([batch_size, self._stoch], dtype), + stoch=tf.zeros([batch_size, self._stoch], dtype), + deter=self._cell.get_initial_state(None, batch_size, dtype)) + return state + + @tf.function + def observe(self, embed, action, state=None): + def swap(x): return tf.transpose( + x, [1, 0] + list(range(2, len(x.shape)))) + if state is None: + state = self.initial(tf.shape(action)[0]) + embed, action = swap(embed), swap(action) + post, prior = tools.static_scan( + lambda prev, inputs: self.obs_step(prev[0], *inputs), + (action, embed), (state, state)) + post = {k: swap(v) for k, v in post.items()} + prior = {k: swap(v) for k, v in prior.items()} + return post, prior + + @tf.function + def imagine(self, action, state=None): + def swap(x): return tf.transpose( + x, [1, 0] + list(range(2, len(x.shape)))) + if state is None: + state = self.initial(tf.shape(action)[0]) + assert isinstance(state, dict), state + action = swap(action) + prior = tools.static_scan(self.img_step, action, state) + prior = {k: swap(v) for k, v in prior.items()} + return prior + + def get_feat(self, state): + stoch = state['stoch'] + if self._discrete: + shape = stoch.shape[:-2] + [self._stoch * self._discrete] + stoch = tf.reshape(stoch, shape) + return tf.concat([stoch, state['deter']], -1) + + def get_dist(self, state, dtype=None): + if self._discrete: + logit = state['logit'] + logit = tf.cast(logit, tf.float32) + dist = tfd.Independent(tools.OneHotDist(logit), 1) + if dtype != tf.float32: + dist = tools.DtypeDist(dist, dtype or state['logit'].dtype) + else: + mean, std = state['mean'], state['std'] + if dtype: + mean = tf.cast(mean, dtype) + std = tf.cast(std, dtype) + dist = tfd.MultivariateNormalDiag(mean, std) + return dist + + @tf.function + def obs_step(self, prev_state, prev_action, embed, sample=True): + if not self._embed: + self._embed = embed.shape[-1] + prior = self.img_step(prev_state, prev_action, None, sample) + if self._shared: + post = self.img_step(prev_state, prev_action, embed, sample) + else: + x = tf.concat([prior['deter'], embed], -1) + for i in range(self._layers_output): + x = self.get(f'obi{i}', tfkl.Dense, self._hidden, self._act)(x) + stats = self._suff_stats_layer('obs', x) + if sample: + stoch = self.get_dist(stats).sample() + else: + stoch = self.get_dist(stats).mode() + post = {'stoch': stoch, 'deter': prior['deter'], **stats} + return post, prior + + @tf.function + def img_step(self, prev_state, prev_action, embed=None, sample=True): + prev_stoch = prev_state['stoch'] + if self._discrete: + shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete] + prev_stoch = tf.reshape(prev_stoch, shape) + if self._shared: + if embed is None: + shape = prev_action.shape[:-1] + [self._embed] + embed = tf.zeros(shape, prev_action.dtype) + x = tf.concat([prev_stoch, prev_action, embed], -1) + else: + x = tf.concat([prev_stoch, prev_action], -1) + for i in range(self._layers_input): + x = self.get(f'ini{i}', tfkl.Dense, self._hidden, self._act)(x) + x, deter = self._cell(x, [prev_state['deter']]) + deter = deter[0] # Keras wraps the state in a list. 
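+    # The output layers map the updated deterministic state to the sufficient
+    # statistics (logits, or mean and std) of the prior over the stochastic state.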
+ for i in range(self._layers_output): + x = self.get(f'imo{i}', tfkl.Dense, self._hidden, self._act)(x) + stats = self._suff_stats_layer('ims', x) + if sample: + stoch = self.get_dist(stats).sample() + else: + stoch = self.get_dist(stats).mode() + prior = {'stoch': stoch, 'deter': deter, **stats} + return prior + + def _suff_stats_layer(self, name, x): + if self._discrete: + x = self.get(name, tfkl.Dense, self._stoch * + self._discrete, None)(x) + logit = tf.reshape(x, x.shape[:-1] + [self._stoch, self._discrete]) + return {'logit': logit} + else: + x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x) + mean, std = tf.split(x, 2, -1) + mean = { + 'none': lambda: mean, + 'tanh5': lambda: 5.0 * tf.math.tanh(mean / 5.0), + }[self._mean_act]() + std = { + 'softplus': lambda: tf.nn.softplus(std), + 'abs': lambda: tf.math.abs(std + 1), + 'sigmoid': lambda: tf.nn.sigmoid(std), + 'sigmoid2': lambda: 2 * tf.nn.sigmoid(std / 2), + }[self._std_act]() + std = std + self._min_std + return {'mean': mean, 'std': std} + + def kl_loss(self, post, prior, balance, free, scale): + kld = tfd.kl_divergence + def dist(x): return self.get_dist(x, tf.float32) + if balance == 0.5: + value = kld(dist(prior), dist(post)) + loss = tf.reduce_mean(tf.maximum(value, free)) + else: + def sg(x): return tf.nest.map_structure(tf.stop_gradient, x) + value = kld(dist(prior), dist(sg(post))) + pri = tf.reduce_mean(value) + pos = tf.reduce_mean(kld(dist(sg(prior)), dist(post))) + pri, pos = tf.maximum(pri, free), tf.maximum(pos, free) + loss = balance * pri + (1 - balance) * pos + loss *= scale + return loss, value + + +class ConvEncoder(tools.Module): + + def __init__( + self, depth=32, act=tf.nn.relu, kernels=(4, 4, 4, 4)): + self._act = act + self._depth = depth + self._kernels = kernels + + def __call__(self, obs): + kwargs = dict(strides=2, activation=self._act) + Conv = tfkl.Conv2D + x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:])) + x = self.get('h1', Conv, 1 * self._depth, + self._kernels[0], **kwargs)(x) + x = self.get('h2', Conv, 2 * self._depth, + self._kernels[1], **kwargs)(x) + x = self.get('h3', Conv, 4 * self._depth, + self._kernels[2], **kwargs)(x) + x = self.get('h4', Conv, 8 * self._depth, + self._kernels[3], **kwargs)(x) + x = tf.reshape(x, [x.shape[0], np.prod(x.shape[1:])]) + shape = tf.concat([tf.shape(obs['image'])[:-3], [x.shape[-1]]], 0) + return tf.reshape(x, shape) + + +class ConvDecoder(tools.Module): + + def __init__( + self, depth=32, act=tf.nn.relu, shape=(64, 64, 3), kernels=(5, 5, 6, 6), + thin=True): + self._act = act + self._depth = depth + self._shape = shape + self._kernels = kernels + self._thin = thin + + def __call__(self, features, dtype=None): + kwargs = dict(strides=2, activation=self._act) + ConvT = tfkl.Conv2DTranspose + if self._thin: + x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features) + x = tf.reshape(x, [-1, 1, 1, 32 * self._depth]) + else: + x = self.get('h1', tfkl.Dense, 128 * self._depth, None)(features) + x = tf.reshape(x, [-1, 2, 2, 32 * self._depth]) + x = self.get('h2', ConvT, 4 * self._depth, + self._kernels[0], **kwargs)(x) + x = self.get('h3', ConvT, 2 * self._depth, + self._kernels[1], **kwargs)(x) + x = self.get('h4', ConvT, 1 * self._depth, + self._kernels[2], **kwargs)(x) + x = self.get( + 'h5', ConvT, self._shape[-1], self._kernels[3], strides=2)(x) + mean = tf.reshape(x, tf.concat( + [tf.shape(features)[:-1], self._shape], 0)) + if dtype: + mean = tf.cast(mean, dtype) + return tfd.Independent(tfd.Normal(mean, 1), 
len(self._shape)) + + +class ConvDecoderMask(tools.Module): + + def __init__( + self, depth=32, act=tf.nn.relu, shape=(64, 64, 3), kernels=(5, 5, 6, 6), + thin=True): + self._act = act + self._depth = depth + self._shape = shape + self._kernels = kernels + self._thin = thin + + def __call__(self, features, dtype=None): + kwargs = dict(strides=2, activation=self._act) + ConvT = tfkl.Conv2DTranspose + if self._thin: + x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features) + x = tf.reshape(x, [-1, 1, 1, 32 * self._depth]) + else: + x = self.get('h1', tfkl.Dense, 128 * self._depth, None)(features) + x = tf.reshape(x, [-1, 2, 2, 32 * self._depth]) + x = self.get('h2', ConvT, 4 * self._depth, + self._kernels[0], **kwargs)(x) + x = self.get('h3', ConvT, 2 * self._depth, + self._kernels[1], **kwargs)(x) + x = self.get('h4', ConvT, 1 * self._depth, + self._kernels[2], **kwargs)(x) + x = self.get( + 'h5', ConvT, 2 * self._shape[-1], self._kernels[3], strides=2)(x) + mean, mask = tf.split(x, [self._shape[-1], self._shape[-1]], -1) + mean = tf.reshape(mean, tf.concat( + [tf.shape(features)[:-1], self._shape], 0)) + mask = tf.reshape(mask, tf.concat( + [tf.shape(features)[:-1], self._shape], 0)) + if dtype: + mean = tf.cast(mean, dtype) + mask = tf.cast(mask, dtype) + return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)), mask + + +class ConvDecoderMaskEnsemble(tools.Module): + """ + ensemble two convdecoder with outputs + NOTE: remove pred1/pred2 for maximum performance. + """ + + def __init__(self, decoder1, decoder2): + self._decoder1 = decoder1 + self._decoder2 = decoder2 + self._shape = decoder1._shape + + def __call__(self, feat1, feat2, dtype=None): + kwargs = dict(strides=1, activation=tf.nn.sigmoid) + pred1, mask1 = self._decoder1(feat1, dtype) + pred2, mask2 = self._decoder2(feat2, dtype) + mean1 = pred1.submodules[0].loc + mean2 = pred2.submodules[0].loc + mask_feat = tf.concat([mask1, mask2], -1) + mask = self.get('mask1', tfkl.Conv2D, 1, 1, **kwargs)(mask_feat) + mask_use1 = mask + mask_use2 = 1-mask + mean = mean1 * tf.cast(mask_use1, mean1.dtype) + \ + mean2 * tf.cast(mask_use2, mean2.dtype) + return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)), pred1, pred2, tf.cast(mask_use1, mean1.dtype) + + +class DenseHead(tools.Module): + + def __init__( + self, shape, layers, units, act=tf.nn.elu, dist='normal', std=1.0): + self._shape = (shape,) if isinstance(shape, int) else shape + self._layers = layers + self._units = units + self._act = act + self._dist = dist + self._std = std + + def __call__(self, features, dtype=None): + x = features + for index in range(self._layers): + x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x) + mean = self.get(f'hmean', tfkl.Dense, np.prod(self._shape))(x) + mean = tf.reshape(mean, tf.concat( + [tf.shape(features)[:-1], self._shape], 0)) + if self._std == 'learned': + std = self.get(f'hstd', tfkl.Dense, np.prod(self._shape))(x) + std = tf.nn.softplus(std) + 0.01 + std = tf.reshape(std, tf.concat( + [tf.shape(features)[:-1], self._shape], 0)) + else: + std = self._std + if dtype: + mean, std = tf.cast(mean, dtype), tf.cast(std, dtype) + if self._dist == 'normal': + return tfd.Independent(tfd.Normal(mean, std), len(self._shape)) + if self._dist == 'huber': + return tfd.Independent( + tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape)) + if self._dist == 'binary': + return tfd.Independent(tfd.Bernoulli(mean), len(self._shape)) + raise NotImplementedError(self._dist) + + +class ActionHead(tools.Module): + + def 
__init__( + self, size, layers, units, act=tf.nn.elu, dist='trunc_normal', + init_std=0.0, min_std=0.1, action_disc=5, temp=0.1, outscale=0): + # assert min_std <= 2 + self._size = size + self._layers = layers + self._units = units + self._dist = dist + self._act = act + self._min_std = min_std + self._init_std = init_std + self._action_disc = action_disc + self._temp = temp() if callable(temp) else temp + self._outscale = outscale + + def __call__(self, features, dtype=None): + x = features + for index in range(self._layers): + kw = {} + if index == self._layers - 1 and self._outscale: + kw['kernel_initializer'] = tf.keras.initializers.VarianceScaling( + self._outscale) + x = self.get(f'h{index}', tfkl.Dense, + self._units, self._act, **kw)(x) + if self._dist == 'tanh_normal': + # https://www.desmos.com/calculator/rcmcf5jwe7 + x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x) + if dtype: + x = tf.cast(x, dtype) + mean, std = tf.split(x, 2, -1) + mean = tf.tanh(mean) + std = tf.nn.softplus(std + self._init_std) + self._min_std + dist = tfd.Normal(mean, std) + dist = tfd.TransformedDistribution(dist, tools.TanhBijector()) + dist = tfd.Independent(dist, 1) + dist = tools.SampleDist(dist) + elif self._dist == 'tanh_normal_5': + x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x) + if dtype: + x = tf.cast(x, dtype) + mean, std = tf.split(x, 2, -1) + mean = 5 * tf.tanh(mean / 5) + std = tf.nn.softplus(std + 5) + 5 + dist = tfd.Normal(mean, std) + dist = tfd.TransformedDistribution(dist, tools.TanhBijector()) + dist = tfd.Independent(dist, 1) + dist = tools.SampleDist(dist) + elif self._dist == 'normal': + x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x) + if dtype: + x = tf.cast(x, dtype) + mean, std = tf.split(x, 2, -1) + std = tf.nn.softplus(std + self._init_std) + self._min_std + dist = tfd.Normal(mean, std) + dist = tfd.Independent(dist, 1) + elif self._dist == 'normal_1': + mean = self.get(f'hout', tfkl.Dense, self._size)(x) + if dtype: + mean = tf.cast(mean, dtype) + dist = tfd.Normal(mean, 1) + dist = tfd.Independent(dist, 1) + elif self._dist == 'trunc_normal': + # https://www.desmos.com/calculator/mmuvuhnyxo + x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x) + x = tf.cast(x, tf.float32) + mean, std = tf.split(x, 2, -1) + mean = tf.tanh(mean) + std = 2 * tf.nn.sigmoid(std / 2) + self._min_std + dist = tools.SafeTruncatedNormal(mean, std, -1, 1) + dist = tools.DtypeDist(dist, dtype) + dist = tfd.Independent(dist, 1) + elif self._dist == 'onehot': + x = self.get(f'hout', tfkl.Dense, self._size)(x) + x = tf.cast(x, tf.float32) + dist = tools.OneHotDist(x, dtype=dtype) + dist = tools.DtypeDist(dist, dtype) + elif self._dist == 'onehot_gumble': + x = self.get(f'hout', tfkl.Dense, self._size)(x) + if dtype: + x = tf.cast(x, dtype) + temp = self._temp + dist = tools.GumbleDist(temp, x, dtype=dtype) + else: + raise NotImplementedError(self._dist) + return dist + + +class GRUCell(tf.keras.layers.AbstractRNNCell): + + def __init__(self, size, norm=False, act=tf.tanh, update_bias=-1, **kwargs): + super().__init__() + self._size = size + self._act = act + self._norm = norm + self._update_bias = update_bias + self._layer = tfkl.Dense(3 * size, use_bias=norm is not None, **kwargs) + if norm: + self._norm = tfkl.LayerNormalization(dtype=tf.float32) + + @property + def state_size(self): + return self._size + + def call(self, inputs, state): + state = state[0] # Keras wraps the state in a list. 
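+    # A single dense layer produces the reset, candidate, and update parts;
+    # layer norm, when enabled, runs in float32 for mixed-precision stability.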
+ parts = self._layer(tf.concat([inputs, state], -1)) + if self._norm: + dtype = parts.dtype + parts = tf.cast(parts, tf.float32) + parts = self._norm(parts) + parts = tf.cast(parts, dtype) + reset, cand, update = tf.split(parts, 3, -1) + reset = tf.nn.sigmoid(reset) + cand = self._act(reset * cand) + update = tf.nn.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, [output] diff --git a/DreamerV2/tools.py b/DreamerV2/tools.py new file mode 100644 index 0000000..5b0dbb1 --- /dev/null +++ b/DreamerV2/tools.py @@ -0,0 +1,694 @@ +import datetime +import io +import json +import pathlib +import pickle +import re +import time +import uuid + +import numpy as np +import tensorflow as tf +import tensorflow.compat.v1 as tf1 +import tensorflow_probability as tfp +from tensorflow.keras.mixed_precision import experimental as prec +from tensorflow_probability import distributions as tfd + + +# Patch to ignore seed to avoid synchronization across GPUs. +_orig_random_categorical = tf.random.categorical +def random_categorical(*args, **kwargs): + kwargs['seed'] = None + return _orig_random_categorical(*args, **kwargs) +tf.random.categorical = random_categorical + +# Patch to ignore seed to avoid synchronization across GPUs. +_orig_random_normal = tf.random.normal +def random_normal(*args, **kwargs): + kwargs['seed'] = None + return _orig_random_normal(*args, **kwargs) +tf.random.normal = random_normal + + +class AttrDict(dict): + + __setattr__ = dict.__setitem__ + __getattr__ = dict.__getitem__ + + +class Module(tf.Module): + + def save(self, filename): + values = tf.nest.map_structure(lambda x: x.numpy(), self.variables) + amount = len(tf.nest.flatten(values)) + count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values))) + print(f'Save checkpoint with {amount} tensors and {count} parameters.') + with pathlib.Path(filename).open('wb') as f: + pickle.dump(values, f) + + def load(self, filename): + with pathlib.Path(filename).open('rb') as f: + values = pickle.load(f) + amount = len(tf.nest.flatten(values)) + count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values))) + print(f'Load checkpoint with {amount} tensors and {count} parameters.') + tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values) + + def get(self, name, ctor, *args, **kwargs): + # Create or get layer by name to avoid mentioning it in the constructor. + if not hasattr(self, '_modules'): + self._modules = {} + if name not in self._modules: + self._modules[name] = ctor(*args, **kwargs) + return self._modules[name] + + +def var_nest_names(nest): + if isinstance(nest, dict): + items = ' '.join(f'{k}:{var_nest_names(v)}' for k, v in nest.items()) + return '{' + items + '}' + if isinstance(nest, (list, tuple)): + items = ' '.join(var_nest_names(v) for v in nest) + return '[' + items + ']' + if hasattr(nest, 'name') and hasattr(nest, 'shape'): + return nest.name + str(nest.shape).replace(', ', 'x') + if hasattr(nest, 'shape'): + return str(nest.shape).replace(', ', 'x') + return '?' 
+ + +class Logger: + + def __init__(self, logdir, step): + self._logdir = logdir + self._writer = tf.summary.create_file_writer(str(logdir), max_queue=1000) + self._last_step = None + self._last_time = None + self._scalars = {} + self._images = {} + self._videos = {} + self.step = step + + def scalar(self, name, value): + self._scalars[name] = float(value) + + def image(self, name, value): + self._images[name] = np.array(value) + + def video(self, name, value): + self._videos[name] = np.array(value) + + def write(self, fps=False): + scalars = list(self._scalars.items()) + if fps: + scalars.append(('fps', self._compute_fps(self.step))) + print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars)) + with (self._logdir / 'metrics.jsonl').open('a') as f: + f.write(json.dumps({'step': self.step, ** dict(scalars)}) + '\n') + with self._writer.as_default(): + for name, value in scalars: + tf.summary.scalar('scalars/' + name, value, self.step) + for name, value in self._images.items(): + tf.summary.image(name, value, self.step) + for name, value in self._videos.items(): + video_summary(name, value, self.step) + self._writer.flush() + self._scalars = {} + self._images = {} + self._videos = {} + + def _compute_fps(self, step): + if self._last_step is None: + self._last_time = time.time() + self._last_step = step + return 0 + steps = step - self._last_step + duration = time.time() - self._last_time + self._last_time += duration + self._last_step = step + return steps / duration + + +def graph_summary(writer, step, fn, *args): + def inner(*args): + tf.summary.experimental.set_step(step.numpy().item()) + with writer.as_default(): + fn(*args) + return tf.numpy_function(inner, args, []) + + +def video_summary(name, video, step=None, fps=20): + name = name if isinstance(name, str) else name.decode('utf-8') + if np.issubdtype(video.dtype, np.floating): + video = np.clip(255 * video, 0, 255).astype(np.uint8) + B, T, H, W, C = video.shape + try: + frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C)) + summary = tf1.Summary() + image = tf1.Summary.Image(height=B * H, width=T * W, colorspace=C) + image.encoded_image_string = encode_gif(frames, fps) + summary.value.add(tag=name, image=image) + tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step) + except (IOError, OSError) as e: + print('GIF summaries require ffmpeg in $PATH.', e) + frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C)) + tf.summary.image(name, frames, step) + + +def encode_gif(frames, fps): + from subprocess import Popen, PIPE + h, w, c = frames[0].shape + pxfmt = {1: 'gray', 3: 'rgb24'}[c] + cmd = ' '.join([ + f'ffmpeg -y -f rawvideo -vcodec rawvideo', + f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex', + f'[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse', + f'-r {fps:.02f} -f gif -']) + proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE) + for image in frames: + proc.stdin.write(image.tostring()) + out, err = proc.communicate() + if proc.returncode: + raise IOError('\n'.join([' '.join(cmd), err.decode('utf8')])) + del proc + return out + + +def simulate(agent, envs, steps=0, episodes=0, state=None): + # Initialize or unpack simulation state. 
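+  # state is the tuple (step, episode, done, length, obs, agent_state)
+  # returned at the bottom of this function.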
+ if state is None: + step, episode = 0, 0 + done = np.ones(len(envs), np.bool) + length = np.zeros(len(envs), np.int32) + obs = [None] * len(envs) + agent_state = None + else: + step, episode, done, length, obs, agent_state = state + while (steps and step < steps) or (episodes and episode < episodes): + # Reset envs if necessary. + if done.any(): + indices = [index for index, d in enumerate(done) if d] + results = [envs[i].reset() for i in indices] + for index, result in zip(indices, results): + obs[index] = result + # Step agents. + obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]} + action, agent_state = agent(obs, done, agent_state) + if isinstance(action, dict): + action = [ + {k: np.array(action[k][i]) for k in action} + for i in range(len(envs))] + else: + action = np.array(action) + assert len(action) == len(envs) + # Step envs. + results = [e.step(a) for e, a in zip(envs, action)] + obs, _, done = zip(*[p[:3] for p in results]) + obs = list(obs) + done = np.stack(done) + episode += int(done.sum()) + length += 1 + step += (done * length).sum() + length *= (1 - done) + # import pdb + # pdb.set_trace() + # Return new state to allow resuming the simulation. + return (step - steps, episode - episodes, done, length, obs, agent_state) + + +def save_episodes(directory, episodes): + directory = pathlib.Path(directory).expanduser() + directory.mkdir(parents=True, exist_ok=True) + timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') + filenames = [] + for episode in episodes: + identifier = str(uuid.uuid4().hex) + length = len(episode['reward']) + filename = directory / f'{timestamp}-{identifier}-{length}.npz' + with io.BytesIO() as f1: + np.savez_compressed(f1, **episode) + f1.seek(0) + with filename.open('wb') as f2: + f2.write(f1.read()) + filenames.append(filename) + return filenames + + +def sample_episodes(episodes, length=None, balance=False, seed=0): + random = np.random.RandomState(seed) + while True: + episode = random.choice(list(episodes.values())) + if length: + total = len(next(iter(episode.values()))) + available = total - length + if available < 1: + # print(f'Skipped short episode of length {available}.') + continue + if balance: + index = min(random.randint(0, total), available) + else: + index = int(random.randint(0, available + 1)) + episode = {k: v[index: index + length] for k, v in episode.items()} + yield episode + + +def load_episodes(directory, limit=None): + directory = pathlib.Path(directory).expanduser() + episodes = {} + total = 0 + for filename in reversed(sorted(directory.glob('*.npz'))): + try: + with filename.open('rb') as f: + episode = np.load(f) + episode = {k: episode[k] for k in episode.keys()} + except Exception as e: + print(f'Could not load episode: {e}') + continue + episodes[str(filename)] = episode + total += len(episode['reward']) - 1 + if limit and total >= limit: + break + return episodes + + +class DtypeDist: + + def __init__(self, dist, dtype=None): + self._dist = dist + self._dtype = dtype or prec.global_policy().compute_dtype + + @property + def name(self): + return 'DtypeDist' + + def __getattr__(self, name): + return getattr(self._dist, name) + + def mean(self): + return tf.cast(self._dist.mean(), self._dtype) + + def mode(self): + return tf.cast(self._dist.mode(), self._dtype) + + def entropy(self): + return tf.cast(self._dist.entropy(), self._dtype) + + def sample(self, *args, **kwargs): + return tf.cast(self._dist.sample(*args, **kwargs), self._dtype) + + +class SampleDist: + + def __init__(self, dist, samples=100): + 
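+    # Wraps a distribution whose mean, mode, and entropy lack a closed form
+    # (e.g. a tanh-transformed Normal) and estimates them from `samples`
+    # Monte Carlo draws in the methods below.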
self._dist = dist + self._samples = samples + + @property + def name(self): + return 'SampleDist' + + def __getattr__(self, name): + return getattr(self._dist, name) + + def mean(self): + samples = self._dist.sample(self._samples) + return tf.reduce_mean(samples, 0) + + def mode(self): + sample = self._dist.sample(self._samples) + logprob = self._dist.log_prob(sample) + return tf.gather(sample, tf.argmax(logprob))[0] + + def entropy(self): + sample = self._dist.sample(self._samples) + logprob = self.log_prob(sample) + return -tf.reduce_mean(logprob, 0) + + +class OneHotDist(tfd.OneHotCategorical): + + def __init__(self, logits=None, probs=None, dtype=None): + self._sample_dtype = dtype or prec.global_policy().compute_dtype + super().__init__(logits=logits, probs=probs) + + def mode(self): + return tf.cast(super().mode(), self._sample_dtype) + + def sample(self, sample_shape=(), seed=None): + # Straight through biased gradient estimator. + sample = tf.cast(super().sample(sample_shape, seed), self._sample_dtype) + probs = super().probs_parameter() + while len(probs.shape) < len(sample.shape): + probs = probs[None] + sample += tf.cast(probs - tf.stop_gradient(probs), self._sample_dtype) + return sample + + +class GumbleDist(tfd.RelaxedOneHotCategorical): + + def __init__(self, temp, logits=None, probs=None, dtype=None): + self._sample_dtype = dtype or prec.global_policy().compute_dtype + self._exact = tfd.OneHotCategorical(logits=logits, probs=probs) + super().__init__(temp, logits=logits, probs=probs) + + def mode(self): + return tf.cast(self._exact.mode(), self._sample_dtype) + + def entropy(self): + return tf.cast(self._exact.entropy(), self._sample_dtype) + + def sample(self, sample_shape=(), seed=None): + return tf.cast(super().sample(sample_shape, seed), self._sample_dtype) + + +class UnnormalizedHuber(tfd.Normal): + + def __init__(self, loc, scale, threshold=1, **kwargs): + self._threshold = tf.cast(threshold, loc.dtype) + super().__init__(loc, scale, **kwargs) + + def log_prob(self, event): + return -(tf.math.sqrt( + (event - self.mean()) ** 2 + self._threshold ** 2) - self._threshold) + + +class SafeTruncatedNormal(tfd.TruncatedNormal): + + def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): + super().__init__(loc, scale, low, high) + self._clip = clip + self._mult = mult + + def sample(self, *args, **kwargs): + event = super().sample(*args, **kwargs) + if self._clip: + clipped = tf.clip_by_value( + event, self.low + self._clip, self.high - self._clip) + event = event - tf.stop_gradient(event) + tf.stop_gradient(clipped) + if self._mult: + event *= self._mult + return event + + +class TanhBijector(tfp.bijectors.Bijector): + + def __init__(self, validate_args=False, name='tanh'): + super().__init__( + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return tf.nn.tanh(x) + + def _inverse(self, y): + dtype = y.dtype + y = tf.cast(y, tf.float32) + y = tf.where( + tf.less_equal(tf.abs(y), 1.), + tf.clip_by_value(y, -0.99999997, 0.99999997), y) + y = tf.atanh(y) + y = tf.cast(y, dtype) + return y + + def _forward_log_det_jacobian(self, x): + log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype)) + return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x)) + + +def lambda_return( + reward, value, pcont, bootstrap, lambda_, axis): + # Setting lambda=1 gives a discounted Monte Carlo return. + # Setting lambda=0 gives a fixed 1-step return. 
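+  # Computed backwards over time as
+  #   R_t = r_t + pcont_t * ((1 - lambda_) * value_{t+1} + lambda_ * R_{t+1}),
+  # bootstrapped with R_{T+1} = bootstrap (zeros if bootstrap is None).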
+ assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape) + if isinstance(pcont, (int, float)): + pcont = pcont * tf.ones_like(reward) + dims = list(range(reward.shape.ndims)) + dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] + if axis != 0: + reward = tf.transpose(reward, dims) + value = tf.transpose(value, dims) + pcont = tf.transpose(pcont, dims) + if bootstrap is None: + bootstrap = tf.zeros_like(value[-1]) + next_values = tf.concat([value[1:], bootstrap[None]], 0) + inputs = reward + pcont * next_values * (1 - lambda_) + returns = static_scan( + lambda agg, cur: cur[0] + cur[1] * lambda_ * agg, + (inputs, pcont), bootstrap, reverse=True) + if axis != 0: + returns = tf.transpose(returns, dims) + return returns + + +class Optimizer(tf.Module): + + def __init__( + self, name, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*', + opt='adam'): + assert 0 <= wd < 1 + assert not clip or 1 <= clip + self._name = name + self._clip = clip + self._wd = wd + self._wd_pattern = wd_pattern + self._opt = { + 'adam': lambda: tf.optimizers.Adam(lr, epsilon=eps), + 'nadam': lambda: tf.optimizers.Nadam(lr, epsilon=eps), + 'adamax': lambda: tf.optimizers.Adamax(lr, epsilon=eps), + 'sgd': lambda: tf.optimizers.SGD(lr), + 'momentum': lambda: tf.optimizers.SGD(lr, 0.9), + }[opt]() + self._mixed = (prec.global_policy().compute_dtype == tf.float16) + if self._mixed: + self._opt = prec.LossScaleOptimizer(self._opt, 'dynamic') + + @property + def variables(self): + return self._opt.variables() + + def __call__(self, tape, loss, modules, prefix=None): + assert loss.dtype is tf.float32, self._name + modules = modules if hasattr(modules, '__len__') else (modules,) + varibs = tf.nest.flatten([module.variables for module in modules]) + count = sum(np.prod(x.shape) for x in varibs) + print(f'Found {count} {self._name} parameters.') + assert len(loss.shape) == 0, loss.shape + tf.debugging.check_numerics(loss, self._name + '_loss') + if self._mixed: + with tape: + loss = self._opt.get_scaled_loss(loss) + grads = tape.gradient(loss, varibs) + if self._mixed: + grads = self._opt.get_unscaled_gradients(grads) + norm = tf.linalg.global_norm(grads) + if not self._mixed: + tf.debugging.check_numerics(norm, self._name + '_norm') + if self._clip: + grads, _ = tf.clip_by_global_norm(grads, self._clip, norm) + if self._wd: + self._apply_weight_decay(varibs) + self._opt.apply_gradients(zip(grads, varibs)) + metrics = {} + if prefix: + metrics[f'{prefix}/{self._name}_loss'] = loss + metrics[f'{prefix}/{self._name}_grad_norm'] = norm + if self._mixed: + metrics[f'{prefix}/{self._name}_loss_scale'] = \ + self._opt.loss_scale._current_loss_scale + else: + metrics[f'{self._name}_loss'] = loss + metrics[f'{self._name}_grad_norm'] = norm + if self._mixed: + metrics[f'{self._name}_loss_scale'] = \ + self._opt.loss_scale._current_loss_scale + return metrics + + def _apply_weight_decay(self, varibs): + nontrivial = (self._wd_pattern != r'.*') + if nontrivial: + print('Applied weight decay to variables:') + for var in varibs: + if re.search(self._wd_pattern, self._name + '/' + var.name): + if nontrivial: + print('- ' + self._name + '/' + var.name) + var.assign((1 - self._wd) * var) + + +def args_type(default): + def parse_string(x): + if default is None: + return x + if isinstance(default, bool): + return bool(['False', 'True'].index(x)) + if isinstance(default, int): + return float(x) if ('e' in x or '.' 
in x) else int(x) + if isinstance(default, (list, tuple)): + return tuple(args_type(default[0])(y) for y in x.split(',')) + return type(default)(x) + def parse_object(x): + if isinstance(default, (list, tuple)): + return tuple(x) + return x + return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x) + + +def static_scan(fn, inputs, start, reverse=False): + last = start + outputs = [[] for _ in tf.nest.flatten(start)] + indices = range(len(tf.nest.flatten(inputs)[0])) + if reverse: + indices = reversed(indices) + for index in indices: + inp = tf.nest.map_structure(lambda x: x[index], inputs) + last = fn(last, inp) + [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] + if reverse: + outputs = [list(reversed(x)) for x in outputs] + outputs = [tf.stack(x, 0) for x in outputs] + return tf.nest.pack_sequence_as(start, outputs) + + +def uniform_mixture(dist, dtype=None): + if dist.batch_shape[-1] == 1: + return tfd.BatchReshape(dist, dist.batch_shape[:-1]) + dtype = dtype or prec.global_policy().compute_dtype + weights = tfd.Categorical(tf.zeros(dist.batch_shape, dtype)) + return tfd.MixtureSameFamily(weights, dist) + + +def cat_mixture_entropy(dist): + if isinstance(dist, tfd.MixtureSameFamily): + probs = dist.components_distribution.probs_parameter() + else: + probs = dist.probs_parameter() + return -tf.reduce_mean( + tf.reduce_mean(probs, 2) * + tf.math.log(tf.reduce_mean(probs, 2) + 1e-8), -1) + + +@tf.function +def cem_planner( + state, num_actions, horizon, proposals, topk, iterations, imagine, + objective): + dtype = prec.global_policy().compute_dtype + B, P = list(state.values())[0].shape[0], proposals + H, A = horizon, num_actions + flat_state = {k: tf.repeat(v, P, 0) for k, v in state.items()} + mean = tf.zeros((B, H, A), dtype) + std = tf.ones((B, H, A), dtype) + for _ in range(iterations): + proposals = tf.random.normal((B, P, H, A), dtype=dtype) + proposals = proposals * std[:, None] + mean[:, None] + proposals = tf.clip_by_value(proposals, -1, 1) + flat_proposals = tf.reshape(proposals, (B * P, H, A)) + states = imagine(flat_proposals, flat_state) + scores = objective(states) + scores = tf.reshape(tf.reduce_sum(scores, -1), (B, P)) + _, indices = tf.math.top_k(scores, topk, sorted=False) + best = tf.gather(proposals, indices, axis=1, batch_dims=1) + mean, var = tf.nn.moments(best, 1) + std = tf.sqrt(var + 1e-6) + return mean[:, 0, :] + + +@tf.function +def grad_planner( + state, num_actions, horizon, proposals, iterations, imagine, objective, + kl_scale, step_size): + dtype = prec.global_policy().compute_dtype + B, P = list(state.values())[0].shape[0], proposals + H, A = horizon, num_actions + flat_state = {k: tf.repeat(v, P, 0) for k, v in state.items()} + mean = tf.zeros((B, H, A), dtype) + rawstd = 0.54 * tf.ones((B, H, A), dtype) + for _ in range(iterations): + proposals = tf.random.normal((B, P, H, A), dtype=dtype) + with tf.GradientTape(watch_accessed_variables=False) as tape: + tape.watch(mean) + tape.watch(rawstd) + std = tf.nn.softplus(rawstd) + proposals = proposals * std[:, None] + mean[:, None] + proposals = ( + tf.stop_gradient(tf.clip_by_value(proposals, -1, 1)) + + proposals - tf.stop_gradient(proposals)) + flat_proposals = tf.reshape(proposals, (B * P, H, A)) + states = imagine(flat_proposals, flat_state) + scores = objective(states) + scores = tf.reshape(tf.reduce_sum(scores, -1), (B, P)) + div = tfd.kl_divergence( + tfd.Normal(mean, std), + tfd.Normal(tf.zeros_like(mean), tf.ones_like(std))) + elbo = tf.reduce_sum(scores) - kl_scale * 
div + elbo /= tf.cast(tf.reduce_prod(tf.shape(scores)), dtype) + grad_mean, grad_rawstd = tape.gradient(elbo, [mean, rawstd]) + e, v = tf.nn.moments(grad_mean, [1, 2], keepdims=True) + grad_mean /= tf.sqrt(e * e + v + 1e-4) + e, v = tf.nn.moments(grad_rawstd, [1, 2], keepdims=True) + grad_rawstd /= tf.sqrt(e * e + v + 1e-4) + mean = tf.clip_by_value(mean + step_size * grad_mean, -1, 1) + rawstd = rawstd + step_size * grad_rawstd + return mean[:, 0, :] + + +class Every: + + def __init__(self, every): + self._every = every + self._last = None + + def __call__(self, step): + if not self._every: + return False + if self._last is None: + self._last = step + return True + if step >= self._last + self._every: + self._last += self._every + return True + return False + + +class Once: + + def __init__(self): + self._once = True + + def __call__(self): + if self._once: + self._once = False + return True + return False + + +class Until: + + def __init__(self, until): + self._until = until + + def __call__(self, step): + if not self._until: + return True + return step < self._until + + +def schedule(string, step): + try: + return float(string) + except ValueError: + step = tf.cast(step, tf.float32) + match = re.match(r'linear\((.+),(.+),(.+)\)', string) + if match: + initial, final, duration = [float(group) for group in match.groups()] + mix = tf.clip_by_value(step / duration, 0, 1) + return (1 - mix) * initial + mix * final + match = re.match(r'warmup\((.+),(.+)\)', string) + if match: + warmup, value = [float(group) for group in match.groups()] + scale = tf.clip_by_value(step / warmup, 0, 1) + return scale * value + match = re.match(r'exp\((.+),(.+),(.+)\)', string) + if match: + initial, final, halflife = [float(group) for group in match.groups()] + return (initial - final) * 0.5 ** (step / halflife) + final + raise NotImplementedError(string) diff --git a/DreamerV2/wrappers.py b/DreamerV2/wrappers.py new file mode 100644 index 0000000..7abd6e9 --- /dev/null +++ b/DreamerV2/wrappers.py @@ -0,0 +1,280 @@ +import threading + +import gym +import numpy as np + +class DeepMindControl: + + def __init__(self, name, action_repeat=1, size=(64, 64), camera=None): + domain, task = name.split('_', 1) + if domain == 'cup': # Only domain with multiple words. 
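+      # e.g. the name 'cup_catch' becomes suite.load('ball_in_cup', 'catch').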
+ domain = 'ball_in_cup' + if isinstance(domain, str): + from dm_control import suite + self._env = suite.load(domain, task) + else: + assert task is None + self._env = domain() + self._action_repeat = action_repeat + self._size = size + if camera is None: + camera = dict(quadruped=2).get(domain, 0) + self._camera = camera + + @property + def observation_space(self): + spaces = {} + for key, value in self._env.observation_spec().items(): + spaces[key] = gym.spaces.Box( + -np.inf, np.inf, value.shape, dtype=np.float32) + spaces['image'] = gym.spaces.Box( + 0, 255, self._size + (3,), dtype=np.uint8) + return gym.spaces.Dict(spaces) + + @property + def action_space(self): + spec = self._env.action_spec() + return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) + + def step(self, action): + assert np.isfinite(action).all(), action + reward = 0 + for _ in range(self._action_repeat): + time_step = self._env.step(action) + reward += time_step.reward or 0 + if time_step.last(): + break + obs = dict(time_step.observation) + obs['image'] = self.render() + done = time_step.last() + info = {'discount': np.array(time_step.discount, np.float32)} + return obs, reward, done, info + + def reset(self): + time_step = self._env.reset() + obs = dict(time_step.observation) + obs['image'] = self.render() + return obs + + def render(self, *args, **kwargs): + if kwargs.get('mode', 'rgb_array') != 'rgb_array': + raise ValueError("Only render mode 'rgb_array' is supported.") + return self._env.physics.render(*self._size, camera_id=self._camera) + + +class Atari: + + LOCK = threading.Lock() + + def __init__( + self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30, + life_done=False, sticky_actions=True, all_actions=False): + assert size[0] == size[1] + import gym.wrappers + import gym.envs.atari + with self.LOCK: + env = gym.envs.atari.AtariEnv( + game=name, obs_type='image', frameskip=1, + repeat_action_probability=0.25 if sticky_actions else 0.0, + full_action_space=all_actions) + # Avoid unnecessary rendering in inner env. + env._get_obs = lambda: None + # Tell wrapper that the inner env has no action repeat. 
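+    # (gym's AtariPreprocessing, in the versions this code targets, asserts
+    # that the spec id contains 'NoFrameskip' whenever frame_skip > 1.)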
+ env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0') + env = gym.wrappers.AtariPreprocessing( + env, noops, action_repeat, size[0], life_done, grayscale) + self._env = env + self._grayscale = grayscale + + @property + def observation_space(self): + return gym.spaces.Dict({ + 'image': self._env.observation_space, + 'ram': gym.spaces.Box(0, 255, (128,), np.uint8), + }) + + @property + def action_space(self): + return self._env.action_space + + def close(self): + return self._env.close() + + def reset(self): + with self.LOCK: + image = self._env.reset() + if self._grayscale: + image = image[..., None] + obs = {'image': image, 'ram': self._env.env._get_ram()} + return obs + + def step(self, action): + image, reward, done, info = self._env.step(action) + if self._grayscale: + image = image[..., None] + obs = {'image': image, 'ram': self._env.env._get_ram()} + return obs, reward, done, info + + def render(self, mode): + return self._env.render(mode) + +class CollectDataset: + + def __init__(self, env, callbacks=None, precision=32): + self._env = env + self._callbacks = callbacks or () + self._precision = precision + self._episode = None + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs = {k: self._convert(v) for k, v in obs.items()} + transition = obs.copy() + transition['action'] = action + transition['reward'] = reward + transition['discount'] = info.get('discount', np.array(1 - float(done))) + self._episode.append(transition) + if done: + episode = {k: [t[k] for t in self._episode] for k in self._episode[0]} + episode = {k: self._convert(v) for k, v in episode.items()} + info['episode'] = episode + for callback in self._callbacks: + callback(episode) + return obs, reward, done, info + + def reset(self): + obs = self._env.reset() + transition = obs.copy() + transition['action'] = np.zeros(self._env.action_space.shape) + transition['reward'] = 0.0 + transition['discount'] = 1.0 + self._episode = [transition] + return obs + + def _convert(self, value): + value = np.array(value) + if np.issubdtype(value.dtype, np.floating): + dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision] + elif np.issubdtype(value.dtype, np.signedinteger): + dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] + elif np.issubdtype(value.dtype, np.uint8): + dtype = np.uint8 + else: + raise NotImplementedError(value.dtype) + return value.astype(dtype) + + +class TimeLimit: + + def __init__(self, env, duration): + self._env = env + self._duration = duration + self._step = None + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + assert self._step is not None, 'Must reset environment.' 
+ obs, reward, done, info = self._env.step(action) + self._step += 1 + if self._step >= self._duration: + done = True + if 'discount' not in info: + info['discount'] = np.array(1.0).astype(np.float32) + self._step = None + return obs, reward, done, info + + def reset(self): + self._step = 0 + return self._env.reset() + + +class NormalizeActions: + + def __init__(self, env): + self._env = env + self._mask = np.logical_and( + np.isfinite(env.action_space.low), + np.isfinite(env.action_space.high)) + self._low = np.where(self._mask, env.action_space.low, -1) + self._high = np.where(self._mask, env.action_space.high, 1) + + def __getattr__(self, name): + return getattr(self._env, name) + + @property + def action_space(self): + low = np.where(self._mask, -np.ones_like(self._low), self._low) + high = np.where(self._mask, np.ones_like(self._low), self._high) + return gym.spaces.Box(low, high, dtype=np.float32) + + def step(self, action): + original = (action + 1) / 2 * (self._high - self._low) + self._low + original = np.where(self._mask, original, action) + return self._env.step(original) + + +class OneHotAction: + + def __init__(self, env): + assert isinstance(env.action_space, gym.spaces.Discrete) + self._env = env + self._random = np.random.RandomState() + + def __getattr__(self, name): + return getattr(self._env, name) + + @property + def action_space(self): + shape = (self._env.action_space.n,) + space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) + space.sample = self._sample_action + return space + + def step(self, action): + index = np.argmax(action).astype(int) + reference = np.zeros_like(action) + reference[index] = 1 + if not np.allclose(reference, action): + raise ValueError(f'Invalid one-hot action:\n{action}') + return self._env.step(index) + + def reset(self): + return self._env.reset() + + def _sample_action(self): + actions = self._env.action_space.n + index = self._random.randint(0, actions) + reference = np.zeros(actions, dtype=np.float32) + reference[index] = 1.0 + return reference + + +class RewardObs: + + def __init__(self, env): + self._env = env + + def __getattr__(self, name): + return getattr(self._env, name) + + @property + def observation_space(self): + spaces = self._env.observation_space.spaces + assert 'reward' not in spaces + spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32) + return gym.spaces.Dict(spaces) + + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs['reward'] = reward + return obs, reward, done, info + + def reset(self): + obs = self._env.reset() + obs['reward'] = 0.0 + return obs
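+
+
+if __name__ == '__main__':
+  # Rough usage sketch under stated assumptions, not part of the training
+  # pipeline: compose the wrappers roughly as the dreamer entry point would
+  # for an Atari task, with the image size from configs.yaml. Requires gym
+  # with the Atari extras installed.
+  env = Atari('pong', action_repeat=4, size=(64, 64), grayscale=True)
+  env = OneHotAction(env)
+  # time_limit from configs.yaml; the entry point may scale it by the action
+  # repeat before passing it here.
+  env = TimeLimit(env, duration=108000)
+  env = RewardObs(env)
+  obs = env.reset()
+  action = env.action_space.sample()
+  obs, reward, done, info = env.step(action)
+  print(obs['image'].shape, reward, done)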