import tools
import models
from tensorflow_probability import distributions as tfd
from tensorflow.keras.mixed_precision import experimental as prec
import tensorflow as tf
import numpy as np
import collections
import functools
import json
import time
from env_tools import preprocess, count_steps


def load_dataset(directory, config):
    """Build the training input pipeline from episodes stored on disk.

    One episode is loaded up front only to read the dtype and per-step shape
    of every key; the time axis is left unspecified (``None``).

    Args:
        directory: Location handed to ``tools.load_episodes``.
        config: Object providing ``train_steps``, ``batch_length``,
            ``dataset_balance`` and ``batch_size``; it is also forwarded to
            ``preprocess``.

    Returns:
        A ``tf.data.Dataset`` that is batched (remainder dropped), mapped
        through ``preprocess`` and prefetched.
    """
    episode = next(tools.load_episodes(directory, 1))
    types = {k: v.dtype for k, v in episode.items()}
    shapes = {k: (None,) + v.shape[1:] for k, v in episode.items()}

    def generator():
        return tools.load_episodes(
            directory, config.train_steps, config.batch_length,
            config.dataset_balance)

    dataset = tf.data.Dataset.from_generator(generator, types, shapes)
    dataset = dataset.batch(config.batch_size, drop_remainder=True)
    dataset = dataset.map(functools.partial(preprocess, config=config))
    dataset = dataset.prefetch(10)
    return dataset


class Dreamer(tools.Module):
    """Agent with a learned latent dynamics model and actor/value heads.

    The world model (encoder, RSSM dynamics, image decoder, reward head and
    optional ``pcont`` head) is trained on replayed data; the actor and value
    heads are trained on rollouts imagined with that model.
    """

    def __init__(self, config, datadir, actspace, writer):
        """Create counters, metrics, the distributed dataset and the model.

        Args:
            config: Hyper-parameter object (stored as ``self._c``).
            datadir: Episode directory passed to ``count_steps`` and
                ``load_dataset``.
            actspace: Action space; its ``n`` attribute is used as the action
                dimension when present, otherwise ``shape[0]``.
            writer: Summary writer, flushed in ``_write_summaries``.
        """
        self._c = config
        self._actspace = actspace
        self._actdim = (
            actspace.n if hasattr(actspace, 'n') else actspace.shape[0])
        self._writer = writer
        self._random = np.random.RandomState(config.seed)
        self._should_pretrain = tools.Once()
        self._should_train = tools.Every(config.train_every)
        self._should_log = tools.Every(config.log_every)
        self._last_log = None
        self._last_time = time.time()
        self._metrics = collections.defaultdict(tf.metrics.Mean)
        self._metrics['expl_amount']  # Create variable for checkpoint.
        self._float = prec.global_policy().compute_dtype
        self._strategy = tf.distribute.MirroredStrategy()
        with tf.device('cpu:0'):
            self._step = tf.Variable(
                count_steps(datadir, config), dtype=tf.int64)
        # NOTE(review): nesting restored from a whitespace-collapsed source;
        # the model is built inside the strategy scope together with the
        # dataset iterator -- confirm.
        with self._strategy.scope():
            self._dataset = iter(
                self._strategy.experimental_distribute_dataset(
                    load_dataset(datadir, self._c)))
            self._build_model()

    def __call__(self, obs, reset, state=None, training=True):
        """Optionally run training steps, then choose an action.

        Args:
            obs: Observation dict; ``policy`` reads ``obs['image']``.
            reset: Per-environment reset flags (must support ``.any()`` and
                ``len``); flagged entries of ``state`` are zeroed.
            state: ``None`` or the state returned by the previous call.
            training: When true, the step counter advances by
                ``len(reset) * action_repeat``.

        Returns:
            Tuple ``(action, state)`` as produced by ``policy``.
        """
        step = self._step.numpy().item()
        tf.summary.experimental.set_step(step)
        if state is not None and reset.any():
            mask = tf.cast(1 - reset, self._float)[:, None]
            state = tf.nest.map_structure(lambda x: x * mask, state)
        if self._should_train(step):
            log = self._should_log(step)
            n = (self._c.pretrain if self._should_pretrain()
                 else self._c.train_steps)
            print(f'Training for {n} steps.')
            with self._strategy.scope():
                for train_step in range(n):
                    # Images are logged at most once per training burst.
                    log_images = (
                        self._c.log_images and log and train_step == 0)
                    self.train(next(self._dataset), log_images)
            if log:
                self._write_summaries()
        action, state = self.policy(obs, state, training)
        if training:
            self._step.assign_add(len(reset) * self._c.action_repeat)
        return action, state

    @tf.function
    def policy(self, obs, state, training):
        """Update the latent state from ``obs`` and pick an action.

        Args:
            obs: Observation dict, passed through ``preprocess``.
            state: ``None`` (start from the initial latent and a zero action)
                or a ``(latent, action)`` tuple.
            training: Sample from the actor when true, use its mode otherwise;
                also forwarded to ``_exploration``.

        Returns:
            Tuple ``(action, (latent, action))``.
        """
        if state is None:
            latent = self._dynamics.initial(len(obs['image']))
            action = tf.zeros((len(obs['image']), self._actdim), self._float)
        else:
            latent, action = state
        embed = self._encode(preprocess(obs, self._c))
        latent, _ = self._dynamics.obs_step(latent, action, embed)
        feat = self._dynamics.get_feat(latent)
        if training:
            action = self._actor(feat).sample()
        else:
            action = self._actor(feat).mode()
        action = self._exploration(action, training)
        state = (latent, action)
        return action, state

    def load(self, filename):
        """Restore from ``filename`` and consume the one-shot pretrain flag."""
        super().load(filename)
        self._should_pretrain()

    @tf.function()
    def train(self, data, log_images=False):
        """Run ``_train`` on every replica of the distribution strategy."""
        self._strategy.experimental_run_v2(
            self._train, args=(data, log_images))

    def _behavior_losses(self, post):
        """Compute actor and value losses on imagined trajectories.

        Shared by every ``_train`` implementation in this module.

        Args:
            post: Posterior latent states used as imagination start points.

        Returns:
            Tuple ``(actor_tape, actor_loss, value_tape, value_loss)``; each
            loss is already divided by the number of replicas in sync.
        """
        with tf.GradientTape() as actor_tape:
            imag_feat = self._imagine_ahead(post)
            reward = self._reward(imag_feat).mode()
            if self._c.pcont:
                pcont = self._pcont(imag_feat).mean()
            else:
                pcont = self._c.discount * tf.ones_like(reward)
            value = self._value(imag_feat).mode()
            returns = tools.lambda_return(
                reward[:-1], value[:-1], pcont[:-1],
                bootstrap=value[-1], lambda_=self._c.disclam, axis=0)
            discount = tf.stop_gradient(tf.math.cumprod(tf.concat(
                [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0))
            actor_loss = -tf.reduce_mean(discount * returns)
            actor_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as value_tape:
            value_pred = self._value(imag_feat)[:-1]
            target = tf.stop_gradient(returns)
            value_loss = -tf.reduce_mean(
                discount * value_pred.log_prob(target))
            value_loss /= float(self._strategy.num_replicas_in_sync)

        return actor_tape, actor_loss, value_tape, value_loss

    def _train(self, data, log_images):
        """Per-replica update of the world model, actor and value heads.

        Args:
            data: Batch dict; reads ``'action'``, ``'image'``, ``'reward'``
                and, when ``pcont`` is enabled, ``'discount'``.
            log_images: Whether to emit the open-loop video summary.
        """
        with tf.GradientTape() as model_tape:
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)
            image_pred = self._decode(feat)
            reward_pred = self._reward(feat)
            likes = tools.AttrDict()
            likes.image = tf.reduce_mean(image_pred.log_prob(data['image']))
            likes.reward = tf.reduce_mean(
                reward_pred.log_prob(data['reward']))
            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(
                    pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            # The KL term is clipped from below at ``free_nats``.
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())
            model_loss /= float(self._strategy.num_replicas_in_sync)

        actor_tape, actor_loss, value_tape, value_loss = (
            self._behavior_losses(post))

        model_norm = self._model_opt(model_tape, model_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)

        # Summaries are written from the first replica only.
        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(
                    data, feat, prior_dist, post_dist, likes, div,
                    model_loss, value_loss, actor_loss, model_norm,
                    value_norm, actor_norm)
            if tf.equal(log_images, True):
                self._image_summaries(data, embed, image_pred)

    def _build_model(self):
        """Instantiate networks and optimizers, then run one training step."""
        acts = dict(
            elu=tf.nn.elu, relu=tf.nn.relu, swish=tf.nn.swish,
            leaky_relu=tf.nn.leaky_relu)
        cnn_act = acts[self._c.cnn_act]
        act = acts[self._c.dense_act]
        self._encode = models.ConvEncoder(
            self._c.cnn_depth, cnn_act, self._c.image_size)
        self._dynamics = models.RSSM(
            self._c.stoch_size, self._c.deter_size, self._c.deter_size)
        self._decode = models.ConvDecoder(
            self._c.cnn_depth, cnn_act,
            (self._c.image_size, self._c.image_size, 3))
        self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act)
        if self._c.pcont:
            self._pcont = models.DenseDecoder(
                (), 3, self._c.num_units, 'binary', act=act)
        self._value = models.DenseDecoder((), 3, self._c.num_units, act=act)
        self._actor = models.ActionDecoder(
            self._actdim, 4, self._c.num_units, self._c.action_dist,
            init_std=self._c.action_init_std, act=act)
        model_modules = [
            self._encode, self._dynamics, self._decode, self._reward]
        if self._c.pcont:
            model_modules.append(self._pcont)
        Optimizer = functools.partial(
            tools.Adam, wd=self._c.weight_decay, clip=self._c.grad_clip,
            wdpattern=self._c.weight_decay_pattern)
        self._model_opt = Optimizer('model', model_modules, self._c.model_lr)
        self._value_opt = Optimizer('value', [self._value], self._c.value_lr)
        self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr)
        # A single training call on real data; presumably this forces
        # variable creation before checkpoints are restored -- TODO confirm.
        self.train(next(self._dataset))

    def _exploration(self, action, training):
        """Apply the configured exploration noise to ``action``.

        During training the noise amount starts at ``expl_amount``, is
        optionally decayed with the step counter and floored at ``expl_min``.
        During evaluation ``eval_noise`` is used; if it is falsy the action
        is returned unchanged.

        Raises:
            NotImplementedError: If ``config.expl`` is not one of
                ``'additive_gaussian'``, ``'completely_random'`` or
                ``'epsilon_greedy'``.
        """
        if training:
            amount = self._c.expl_amount
            if self._c.expl_decay:
                amount *= 0.5 ** (
                    tf.cast(self._step, tf.float32) / self._c.expl_decay)
            if self._c.expl_min:
                amount = tf.maximum(self._c.expl_min, amount)
            self._metrics['expl_amount'].update_state(amount)
        elif self._c.eval_noise:
            amount = self._c.eval_noise
        else:
            return action
        if self._c.expl == 'additive_gaussian':
            return tf.clip_by_value(
                tfd.Normal(action, amount).sample(), -1, 1)
        if self._c.expl == 'completely_random':
            return tf.random.uniform(action.shape, -1, 1)
        if self._c.expl == 'epsilon_greedy':
            # Zero logits give a uniform categorical over the last axis.
            indices = tfd.Categorical(0 * action).sample()
            # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
            return tf.where(
                tf.random.uniform(action.shape[:1], 0, 1) < amount,
                tf.one_hot(indices, action.shape[-1], dtype=self._float),
                action)
        raise NotImplementedError(self._c.expl)

    def _imagine_ahead(self, post):
        """Roll the dynamics forward ``horizon`` steps under the actor.

        Args:
            post: Dict of posterior state tensors; the first two axes are
                merged to form the start states.

        Returns:
            Features of the imagined states.
        """
        if self._c.pcont:  # Last step could be terminal.
            post = {k: v[:, :-1] for k, v in post.items()}

        def flatten(x):
            return tf.reshape(x, [-1] + list(x.shape[2:]))

        start = {k: flatten(v) for k, v in post.items()}

        def policy(state):
            # Features are detached before they reach the actor.
            return self._actor(
                tf.stop_gradient(self._dynamics.get_feat(state))).sample()

        states = tools.static_scan(
            lambda prev, _: self._dynamics.img_step(prev, policy(prev)),
            tf.range(self._c.horizon), start)
        imag_feat = self._dynamics.get_feat(states)
        return imag_feat

    def _scalar_summaries(
            self, data, feat, prior_dist, post_dist, likes, div,
            model_loss, value_loss, actor_loss, model_norm, value_norm,
            actor_norm):
        """Accumulate scalar training statistics into ``self._metrics``."""
        self._metrics['model_grad_norm'].update_state(model_norm)
        self._metrics['value_grad_norm'].update_state(value_norm)
        self._metrics['actor_grad_norm'].update_state(actor_norm)
        self._metrics['prior_ent'].update_state(prior_dist.entropy())
        self._metrics['post_ent'].update_state(post_dist.entropy())
        for name, logprob in likes.items():
            self._metrics[name + '_loss'].update_state(-logprob)
        self._metrics['div'].update_state(div)
        self._metrics['model_loss'].update_state(model_loss)
        self._metrics['value_loss'].update_state(value_loss)
        self._metrics['actor_loss'].update_state(
            self._actor_loss_value(actor_loss))
        self._metrics['action_ent'].update_state(
            self._actor(feat).entropy())

    @staticmethod
    def _actor_loss_value(actor_loss):
        """Return ``actor_loss`` unchanged (single point for logging)."""
        return actor_loss

    def _image_summaries(self, data, embed, image_pred):
        """Log a video of truth, model output and error for 6 sequences.

        The first 5 steps show reconstructions; the remaining steps are
        open-loop predictions imagined from the state after step 5.
        """
        truth = data['image'][:6] + 0.5
        recon = image_pred.mode()[:6]
        init, _ = self._dynamics.observe(
            embed[:6, :5], data['action'][:6, :5])
        init = {k: v[:, -1] for k, v in init.items()}
        prior = self._dynamics.imagine(data['action'][:6, 5:], init)
        openl = self._decode(self._dynamics.get_feat(prior)).mode()
        model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
        error = (model - truth + 1) / 2
        openl = tf.concat([truth, model, error], 2)
        tools.graph_summary(
            self._writer, tools.video_summary, self._step,
            'agent/openl', openl)

    def image_summary_from_data(self, data):
        """Log the same open-loop video as ``_image_summaries`` from raw data.

        The batch is encoded here instead of reusing training-time tensors;
        the summary is written under ``'agent/eval_openl'``.
        """
        truth = data['image'][:6] + 0.5
        embed = self._encode(data)
        post, _ = self._dynamics.observe(
            embed[:6, :5], data['action'][:6, :5])
        feat = self._dynamics.get_feat(post)
        init = {k: v[:, -1] for k, v in post.items()}
        recon = self._decode(feat).mode()[:6]
        prior = self._dynamics.imagine(data['action'][:6, 5:], init)
        openl = self._decode(self._dynamics.get_feat(prior)).mode()
        model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
        error = (model - truth + 1) / 2
        openl = tf.concat([truth, model, error], 2)
        tools.graph_summary(
            self._writer, tools.video_summary, self._step,
            'agent/eval_openl', openl)

    def _write_summaries(self):
        """Flush accumulated metrics to ``metrics.jsonl``, TensorBoard, stdout.

        An ``fps`` entry (environment steps per wall-clock second since the
        previous call) is added from the second call onwards. All metrics are
        reset after being read.
        """
        step = int(self._step.numpy())
        metrics = [(k, float(v.result())) for k, v in self._metrics.items()]
        if self._last_log is not None:
            duration = time.time() - self._last_time
            self._last_time += duration
            metrics.append(('fps', (step - self._last_log) / duration))
        self._last_log = step
        for metric in self._metrics.values():
            metric.reset_states()
        with (self._c.logdir / 'metrics.jsonl').open('a') as f:
            f.write(json.dumps({'step': step, **dict(metrics)}) + '\n')
        for name, value in metrics:
            tf.summary.scalar('agent/' + name, value)
        print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in metrics))
        self._writer.flush()


class SeparationDreamer(Dreamer):
    """Dreamer variant with a second ("disen") world model.

    Two encoder/RSSM pairs are trained side by side. Their features are
    decoded jointly through ``models.ConvDecoderMaskEnsemble``; the disen
    model additionally has its own image decoder and a reward head whose
    log-likelihood is *added* to the disen model loss.
    """

    def __init__(self, config, datadir, actspace, writer):
        """Create disen metrics, then run the base-class initialisation.

        The metrics dict must exist before ``super().__init__`` because the
        base initialiser builds the model and runs one training step.
        """
        self._metrics_disen = collections.defaultdict(tf.metrics.Mean)
        # Touch the key so the defaultdict creates this metric up front.
        self._metrics_disen['expl_amount']
        super().__init__(config, datadir, actspace, writer)

    def _train(self, data, log_images):
        """Per-replica update of both world models, decoders, actor, value.

        Args:
            data: Batch dict; reads ``'action'``, ``'image'``, ``'reward'``
                and, when ``pcont`` is enabled, ``'discount'``.
            log_images: Whether to emit the video summaries.
        """
        # Persistent because three optimizers take gradients from this tape.
        # NOTE(review): nesting restored from a whitespace-collapsed source;
        # the reward-head optimisation loop is placed inside this tape, with
        # its own inner tape -- confirm.
        with tf.GradientTape(persistent=True) as model_tape:
            # main
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)

            # disen
            embed_disen = self._disen_encode(data)
            post_disen, prior_disen = self._disen_dynamics.observe(
                embed_disen, data['action'])
            feat_disen = self._disen_dynamics.get_feat(post_disen)

            # disen image pred
            image_pred_disen = self._disen_only_decode(feat_disen)

            # joint image pred
            (image_pred_joint, image_pred_joint_main,
             image_pred_joint_disen, mask_pred) = self._joint_decode(
                 feat, feat_disen)

            # reward pred
            reward_pred = self._reward(feat)

            # optimize disen reward predictor till optimal
            # Default keeps scalar logging working when the configured
            # iteration count is zero.
            reward_disen_norm = tf.zeros(())
            for _ in range(self._c.num_reward_opt_iters):
                with tf.GradientTape() as disen_reward_tape:
                    # Features are detached: only the reward head is updated.
                    reward_pred_disen = self._disen_reward(
                        tf.stop_gradient(feat_disen))
                    reward_like_disen = reward_pred_disen.log_prob(
                        data['reward'])
                    reward_loss_disen = -tf.reduce_mean(reward_like_disen)
                    reward_loss_disen /= float(
                        self._strategy.num_replicas_in_sync)
                reward_disen_norm = self._disen_reward_opt(
                    disen_reward_tape, reward_loss_disen)

            # disen reward pred with optimal reward predictor
            reward_pred_disen = self._disen_reward(feat_disen)

            # main model loss
            likes = tools.AttrDict()
            likes.image = tf.reduce_mean(
                image_pred_joint.log_prob(data['image']))
            likes.reward = tf.reduce_mean(reward_pred.log_prob(
                data['reward'])) * self._c.reward_scale
            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(
                    pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())
            model_loss /= float(self._strategy.num_replicas_in_sync)

            # disen model loss with reward negative gradient
            likes_disen = tools.AttrDict()
            likes_disen.image = tf.reduce_mean(
                image_pred_joint.log_prob(data['image']))
            likes_disen.disen_only = tf.reduce_mean(
                image_pred_disen.log_prob(data['image']))
            reward_like_disen = tf.reduce_mean(
                reward_pred_disen.log_prob(data['reward']))
            reward_loss_disen = -reward_like_disen
            prior_dist_disen = self._disen_dynamics.get_dist(prior_disen)
            post_dist_disen = self._disen_dynamics.get_dist(post_disen)
            div_disen = tf.reduce_mean(tfd.kl_divergence(
                post_dist_disen, prior_dist_disen))
            div_disen = tf.maximum(div_disen, self._c.free_nats)
            # The reward log-likelihood enters with a plus sign, so
            # minimising this loss lowers it.
            model_loss_disen = (
                div_disen * self._c.disen_kl_scale
                + reward_like_disen * self._c.disen_neg_rew_scale
                - likes_disen.image
                - likes_disen.disen_only * self._c.disen_rec_scale)
            model_loss_disen /= float(self._strategy.num_replicas_in_sync)

            decode_loss = model_loss_disen + model_loss

        actor_tape, actor_loss, value_tape, value_loss = (
            self._behavior_losses(post))

        model_norm = self._model_opt(model_tape, model_loss)
        model_disen_norm = self._disen_opt(model_tape, model_loss_disen)
        self._decode_opt(model_tape, decode_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)

        # Summaries are written from the first replica only.
        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(
                    data, feat, prior_dist, post_dist, likes, div,
                    model_loss, value_loss, actor_loss, model_norm,
                    value_norm, actor_norm)
                self._scalar_summaries_disen(
                    prior_dist_disen, post_dist_disen, likes_disen,
                    div_disen, model_loss_disen, reward_loss_disen,
                    model_disen_norm, reward_disen_norm)
            if tf.equal(log_images, True):
                self._image_summaries_joint(
                    data, embed, embed_disen, image_pred_joint, mask_pred)
                self._image_summaries(
                    self._disen_dynamics, self._disen_decode, data,
                    embed_disen, image_pred_joint_disen,
                    tag='disen/openl_joint_disen')
                self._image_summaries(
                    self._disen_dynamics, self._disen_only_decode, data,
                    embed_disen, image_pred_disen,
                    tag='disen_only/openl_disen_only')
                self._image_summaries(
                    self._dynamics, self._main_decode, data, embed,
                    image_pred_joint_main, tag='main/openl_joint_main')

    def _build_model(self):
        """Instantiate both world models, decoders, heads and optimizers."""
        acts = dict(
            elu=tf.nn.elu, relu=tf.nn.relu, swish=tf.nn.swish,
            leaky_relu=tf.nn.leaky_relu)
        cnn_act = acts[self._c.cnn_act]
        act = acts[self._c.dense_act]

        # Distractor dynamic model
        self._disen_encode = models.ConvEncoder(
            self._c.disen_cnn_depth, cnn_act, self._c.image_size)
        self._disen_dynamics = models.RSSM(
            self._c.disen_stoch_size, self._c.disen_deter_size,
            self._c.disen_deter_size)
        self._disen_only_decode = models.ConvDecoder(
            self._c.disen_cnn_depth, cnn_act,
            (self._c.image_size, self._c.image_size, 3))
        self._disen_reward = models.DenseDecoder(
            (), 2, self._c.num_units, act=act)

        # Task dynamic model
        self._encode = models.ConvEncoder(
            self._c.cnn_depth, cnn_act, self._c.image_size)
        self._dynamics = models.RSSM(
            self._c.stoch_size, self._c.deter_size, self._c.deter_size)
        self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act)
        if self._c.pcont:
            self._pcont = models.DenseDecoder(
                (), 3, self._c.num_units, 'binary', act=act)
        self._value = models.DenseDecoder((), 3, self._c.num_units, act=act)
        self._actor = models.ActionDecoder(
            self._actdim, 4, self._c.num_units, self._c.action_dist,
            init_std=self._c.action_init_std, act=act)

        # Joint decode
        self._main_decode = models.ConvDecoderMask(
            self._c.cnn_depth, cnn_act,
            (self._c.image_size, self._c.image_size, 3))
        self._disen_decode = models.ConvDecoderMask(
            self._c.disen_cnn_depth, cnn_act,
            (self._c.image_size, self._c.image_size, 3))
        self._joint_decode = models.ConvDecoderMaskEnsemble(
            self._main_decode, self._disen_decode, self._c.precision
        )

        disen_modules = [
            self._disen_encode, self._disen_dynamics,
            self._disen_only_decode]
        model_modules = [self._encode, self._dynamics, self._reward]
        if self._c.pcont:
            model_modules.append(self._pcont)

        Optimizer = functools.partial(
            tools.Adam, wd=self._c.weight_decay, clip=self._c.grad_clip,
            wdpattern=self._c.weight_decay_pattern)
        self._model_opt = Optimizer('model', model_modules, self._c.model_lr)
        self._disen_opt = Optimizer('disen', disen_modules, self._c.model_lr)
        self._decode_opt = Optimizer(
            'decode', [self._joint_decode], self._c.model_lr)
        self._disen_reward_opt = Optimizer(
            'disen_reward', [self._disen_reward], self._c.disen_reward_lr)
        self._value_opt = Optimizer('value', [self._value], self._c.value_lr)
        self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr)
        # A single training call on real data; presumably this forces
        # variable creation before checkpoints are restored -- TODO confirm.
        self.train(next(self._dataset))

    def _scalar_summaries_disen(
            self, prior_dist_disen, post_dist_disen, likes_disen, div_disen,
            model_loss_disen, reward_loss_disen,
            model_disen_norm, reward_disen_norm):
        """Accumulate disen-model statistics into ``self._metrics_disen``."""
        self._metrics_disen['model_grad_norm'].update_state(model_disen_norm)
        self._metrics_disen['reward_grad_norm'].update_state(
            reward_disen_norm)
        self._metrics_disen['prior_ent'].update_state(
            prior_dist_disen.entropy())
        self._metrics_disen['post_ent'].update_state(
            post_dist_disen.entropy())
        for name, logprob in likes_disen.items():
            self._metrics_disen[name + '_loss'].update_state(-logprob)
        self._metrics_disen['div'].update_state(div_disen)
        self._metrics_disen['model_loss'].update_state(model_loss_disen)
        self._metrics_disen['reward_loss'].update_state(
            reward_loss_disen)

    def _image_summaries(self, dynamics, decoder, data, embed, image_pred,
                         tag='agent/openl'):
        """Log an open-loop video for an arbitrary dynamics/decoder pair.

        Note that this override takes a different parameter list than
        ``Dreamer._image_summaries``.

        Args:
            dynamics: Dynamics model used to observe and imagine.
            decoder: Image decoder; a ``models.ConvDecoderMask`` returns a
                pair whose first element is the image distribution.
            data: Batch dict; reads ``'image'`` and ``'action'``.
            embed: Embeddings matching ``dynamics``.
            image_pred: Training-time image distribution used for the first
                5 (reconstruction) steps.
            tag: Summary name.
        """
        truth = data['image'][:6] + 0.5
        recon = image_pred.mode()[:6]
        init, _ = dynamics.observe(embed[:6, :5], data['action'][:6, :5])
        init = {k: v[:, -1] for k, v in init.items()}
        prior = dynamics.imagine(data['action'][:6, 5:], init)
        if isinstance(decoder, models.ConvDecoderMask):
            openl, _ = decoder(dynamics.get_feat(prior))
            openl = openl.mode()
        else:
            openl = decoder(dynamics.get_feat(prior)).mode()
        model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
        error = (model - truth + 1) / 2
        openl = tf.concat([truth, model, error], 2)
        tools.graph_summary(
            self._writer, tools.video_summary, self._step, tag, openl)

    def _image_summaries_joint(self, data, embed, embed_disen,
                               image_pred_joint, mask_pred):
        """Log open-loop videos of the joint reconstruction and its mask."""
        truth = data['image'][:6] + 0.5
        recon_joint = image_pred_joint.mode()[:6]
        mask_pred = mask_pred[:6]

        init, _ = self._dynamics.observe(
            embed[:6, :5], data['action'][:6, :5])
        init_disen, _ = self._disen_dynamics.observe(
            embed_disen[:6, :5], data['action'][:6, :5])
        init = {k: v[:, -1] for k, v in init.items()}
        init_disen = {k: v[:, -1] for k, v in init_disen.items()}
        prior = self._dynamics.imagine(
            data['action'][:6, 5:], init)
        prior_disen = self._disen_dynamics.imagine(
            data['action'][:6, 5:], init_disen)

        feat = self._dynamics.get_feat(prior)
        feat_disen = self._disen_dynamics.get_feat(prior_disen)
        openl, _, _, openl_mask = self._joint_decode(feat, feat_disen)

        openl = openl.mode()
        model = tf.concat([recon_joint[:, :5] + 0.5, openl + 0.5], 1)
        error = (model - truth + 1) / 2
        openl = tf.concat([truth, model, error], 2)
        openl_mask = tf.concat(
            [mask_pred[:, :5] + 0.5, openl_mask + 0.5], 1)

        tools.graph_summary(
            self._writer, tools.video_summary, self._step,
            'joint/openl_joint', openl)
        tools.graph_summary(
            self._writer, tools.video_summary, self._step,
            'mask/openl_mask', openl_mask)

    def image_summary_from_data(self, data):
        """Log joint/main/disen/mask open-loop videos computed from raw data.

        Writes ``'summary/openl'`` (truth, main, disen, joint and joint error
        concatenated along axis 2) and ``'summary/openl_mask'``.
        """
        truth = data['image'][:6] + 0.5

        # main
        embed = self._encode(data)
        post, _ = self._dynamics.observe(
            embed[:6, :5], data['action'][:6, :5])
        feat = self._dynamics.get_feat(post)
        init = {k: v[:, -1] for k, v in post.items()}

        # disen
        embed_disen = self._disen_encode(data)
        post_disen, _ = self._disen_dynamics.observe(
            embed_disen[:6, :5], data['action'][:6, :5])
        feat_disen = self._disen_dynamics.get_feat(post_disen)
        init_disen = {k: v[:, -1] for k, v in post_disen.items()}

        # joint image pred
        recon_joint, recon_main, recon_disen, recon_mask = (
            self._joint_decode(feat, feat_disen))
        recon_joint = recon_joint.mode()[:6]
        recon_main = recon_main.mode()[:6]
        recon_disen = recon_disen.mode()[:6]
        recon_mask = recon_mask[:6]

        prior = self._dynamics.imagine(
            data['action'][:6, 5:], init)
        prior_disen = self._disen_dynamics.imagine(
            data['action'][:6, 5:], init_disen)

        feat = self._dynamics.get_feat(prior)
        feat_disen = self._disen_dynamics.get_feat(prior_disen)

        openl_joint, openl_main, openl_disen, openl_mask = (
            self._joint_decode(feat, feat_disen))
        openl_joint = openl_joint.mode()
        openl_main = openl_main.mode()
        openl_disen = openl_disen.mode()

        model_joint = tf.concat(
            [recon_joint[:, :5] + 0.5, openl_joint + 0.5], 1)
        error_joint = (model_joint - truth + 1) / 2
        model_main = tf.concat(
            [recon_main[:, :5] + 0.5, openl_main + 0.5], 1)
        model_disen = tf.concat(
            [recon_disen[:, :5] + 0.5, openl_disen + 0.5], 1)
        model_mask = tf.concat(
            [recon_mask[:, :5] + 0.5, openl_mask + 0.5], 1)

        output_joint = tf.concat(
            [truth, model_main, model_disen, model_joint, error_joint], 2)
        output_mask = model_mask

        tools.graph_summary(
            self._writer, tools.video_summary, self._step,
            'summary/openl', output_joint)
        tools.graph_summary(
            self._writer, tools.video_summary, self._step,
            'summary/openl_mask', output_mask)

    def _write_summaries(self):
        """Flush main and disen metrics to TensorBoard and stdout.

        Only the main metrics (plus ``fps`` from the second call onwards) are
        appended to ``metrics.jsonl``. Both metric sets are reset after being
        read.
        """
        step = int(self._step.numpy())
        metrics = [(k, float(v.result())) for k, v in self._metrics.items()]
        metrics_disen = [
            (k, float(v.result())) for k, v in self._metrics_disen.items()]
        if self._last_log is not None:
            duration = time.time() - self._last_time
            self._last_time += duration
            metrics.append(('fps', (step - self._last_log) / duration))
        self._last_log = step
        for metric in self._metrics.values():
            metric.reset_states()
        for metric in self._metrics_disen.values():
            metric.reset_states()
        with (self._c.logdir / 'metrics.jsonl').open('a') as f:
            f.write(json.dumps({'step': step, **dict(metrics)}) + '\n')
        for name, value in metrics:
            tf.summary.scalar('agent/' + name, value)
        for name, value in metrics_disen:
            tf.summary.scalar('disen/' + name, value)
        print('#'*30 + ' Main ' + '#'*30)
        print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in metrics))
        print('#'*30 + ' Disen ' + '#'*30)
        print(f'[{step}]', ' / '.join(f'{k} {v:.1f}'
                                      for k, v in metrics_disen))
        self._writer.flush()


class InverseDreamer(Dreamer):
    """Dreamer variant whose decoder predicts actions instead of images.

    ``self._decode`` is a ``models.InverseDecoder``; its log-likelihood of
    ``data['action'][:, :-1]`` replaces the image reconstruction term.
    The constructor is inherited unchanged from ``Dreamer``.
    """

    def _train(self, data, log_images):
        """Per-replica update of the world model, actor and value heads.

        Args:
            data: Batch dict; reads ``'action'``, ``'reward'`` and, when
                ``pcont`` is enabled, ``'discount'``.
            log_images: Accepted for interface compatibility; this variant
                writes no image summaries.
        """
        with tf.GradientTape() as model_tape:
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)
            action_pred = self._decode(feat)
            reward_pred = self._reward(feat)
            likes = tools.AttrDict()
            likes.action = tf.reduce_mean(
                action_pred.log_prob(data['action'][:, :-1]))
            likes.reward = tf.reduce_mean(
                reward_pred.log_prob(data['reward']))
            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(
                    pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())
            model_loss /= float(self._strategy.num_replicas_in_sync)

        actor_tape, actor_loss, value_tape, value_loss = (
            self._behavior_losses(post))

        model_norm = self._model_opt(model_tape, model_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)

        # Summaries are written from the first replica only.
        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(
                    data, feat, prior_dist, post_dist, likes, div,
                    model_loss, value_loss, actor_loss, model_norm,
                    value_norm, actor_norm)

    def _build_model(self):
        """Instantiate networks and optimizers, then run one training step."""
        acts = dict(
            elu=tf.nn.elu, relu=tf.nn.relu, swish=tf.nn.swish,
            leaky_relu=tf.nn.leaky_relu)
        cnn_act = acts[self._c.cnn_act]
        act = acts[self._c.dense_act]
        self._encode = models.ConvEncoder(
            self._c.cnn_depth, cnn_act, self._c.image_size)
        self._dynamics = models.RSSM(
            self._c.stoch_size, self._c.deter_size, self._c.deter_size)
        self._decode = models.InverseDecoder(
            self._actdim, 4, self._c.num_units, act=act)
        self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act)
        if self._c.pcont:
            self._pcont = models.DenseDecoder(
                (), 3, self._c.num_units, 'binary', act=act)
        self._value = models.DenseDecoder((), 3, self._c.num_units, act=act)
        self._actor = models.ActionDecoder(
            self._actdim, 4, self._c.num_units, self._c.action_dist,
            init_std=self._c.action_init_std, act=act)
        model_modules = [
            self._encode, self._dynamics, self._decode, self._reward]
        if self._c.pcont:
            model_modules.append(self._pcont)
        Optimizer = functools.partial(
            tools.Adam, wd=self._c.weight_decay, clip=self._c.grad_clip,
            wdpattern=self._c.weight_decay_pattern)
        self._model_opt = Optimizer('model', model_modules, self._c.model_lr)
        self._value_opt = Optimizer('value', [self._value], self._c.value_lr)
        self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr)
        # A single training call on real data; presumably this forces
        # variable creation before checkpoints are restored -- TODO confirm.
        self.train(next(self._dataset))