import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec

import networks
import tools


class WorldModel(tools.Module):

    def __init__(self, step, config):
        self._step = step
        self._config = config
        channels = (1 if config.atari_grayscale else 3)
        shape = config.size + (channels,)

        ########
        # Main #
        ########
        self.encoder = networks.ConvEncoder(
            config.cnn_depth, config.act, config.encoder_kernels)
        self.dynamics = networks.RSSM(
            config.dyn_stoch, config.dyn_deter, config.dyn_hidden,
            config.dyn_input_layers, config.dyn_output_layers,
            config.dyn_shared, config.dyn_discrete, config.act,
            config.dyn_mean_act, config.dyn_std_act, config.dyn_min_std,
            config.dyn_cell)
        self.heads = {}
        self.heads['reward'] = networks.DenseHead(
            [], config.reward_layers, config.units, config.act)
        if config.pred_discount:
            self.heads['discount'] = networks.DenseHead(
                [], config.discount_layers, config.units, config.act,
                dist='binary')
        self._model_opt = tools.Optimizer(
            'model', config.model_lr, config.opt_eps, config.grad_clip,
            config.weight_decay, opt=config.opt)
        self._scales = dict(
            reward=config.reward_scale, discount=config.discount_scale)

        #########
        # Disen #
        #########
        self.disen_encoder = networks.ConvEncoder(
            config.disen_cnn_depth, config.act, config.encoder_kernels)
        self.disen_dynamics = networks.RSSM(
            config.dyn_stoch, config.dyn_deter, config.dyn_hidden,
            config.dyn_input_layers, config.dyn_output_layers,
            config.dyn_shared, config.dyn_discrete, config.act,
            config.dyn_mean_act, config.dyn_std_act, config.dyn_min_std,
            config.dyn_cell)
        self.disen_heads = {}
        self.disen_heads['reward'] = networks.DenseHead(
            [], config.reward_layers, config.units, config.act)
        if config.pred_discount:
            self.disen_heads['discount'] = networks.DenseHead(
                [], config.discount_layers, config.units, config.act,
                dist='binary')
        self._disen_model_opt = tools.Optimizer(
            'disen', config.model_lr, config.opt_eps, config.grad_clip,
            config.weight_decay, opt=config.opt)
        self._disen_heads_opt = {}
        self._disen_heads_opt['reward'] = tools.Optimizer(
            'disen_reward', config.model_lr, config.opt_eps,
            config.grad_clip, config.weight_decay, opt=config.opt)
        if config.pred_discount:
            self._disen_heads_opt['discount'] = tools.Optimizer(
                'disen_pcont', config.model_lr, config.opt_eps,
                config.grad_clip, config.weight_decay, opt=config.opt)
        # negative signs for reward/discount here
        self._disen_scales = dict(
            disen_only=config.disen_only_scale,
            reward=-config.disen_reward_scale,
            discount=-config.disen_discount_scale)
        self.disen_only_image_head = networks.ConvDecoder(
            config.disen_cnn_depth, config.act, shape,
            config.decoder_kernels, config.decoder_thin)

        ################
        # Joint Decode #
        ################
        self.image_head = networks.ConvDecoderMask(
            config.cnn_depth, config.act, shape,
            config.decoder_kernels, config.decoder_thin)
        self.disen_image_head = networks.ConvDecoderMask(
            config.disen_cnn_depth, config.act, shape,
            config.decoder_kernels, config.decoder_thin)
        self.joint_image_head = networks.ConvDecoderMaskEnsemble(
            self.image_head, self.disen_image_head)

    def train(self, data):
        data = self.preprocess(data)

        with tf.GradientTape() as model_tape, tf.GradientTape() as disen_tape:
            # kl schedule
            kl_balance = tools.schedule(self._config.kl_balance, self._step)
            kl_free = tools.schedule(self._config.kl_free, self._step)
            kl_scale = tools.schedule(self._config.kl_scale, self._step)

            # Main
            embed = self.encoder(data)
            post, prior = self.dynamics.observe(embed, data['action'])
            kl_loss, kl_value = self.dynamics.kl_loss(
                post, prior, kl_balance, kl_free, kl_scale)
            feat = self.dynamics.get_feat(post)
            likes = {}
            for name, head in self.heads.items():
                grad_head = (name in self._config.grad_heads)
                inp = feat if grad_head else tf.stop_gradient(feat)
                pred = head(inp, tf.float32)
                like = pred.log_prob(tf.cast(data[name], tf.float32))
                likes[name] = tf.reduce_mean(
                    like) * self._scales.get(name, 1.0)

            # Disen
            embed_disen = self.disen_encoder(data)
            post_disen, prior_disen = self.disen_dynamics.observe(
                embed_disen, data['action'])
            kl_loss_disen, kl_value_disen = self.dynamics.kl_loss(
                post_disen, prior_disen, kl_balance, kl_free, kl_scale)
            feat_disen = self.disen_dynamics.get_feat(post_disen)

            # Optimize disen reward/pcont till optimal
            disen_metrics = dict(reward={}, discount={})
            loss_disen = dict(reward=None, discount=None)
            for _ in range(self._config.num_reward_opt_iters):
                with tf.GradientTape() as disen_reward_tape, tf.GradientTape() as disen_pcont_tape:
                    disen_gradient_tapes = dict(
                        reward=disen_reward_tape,
                        discount=disen_pcont_tape)
                    for name, head in self.disen_heads.items():
                        pred_disen = head(
                            tf.stop_gradient(feat_disen), tf.float32)
                        loss_disen[name] = -tf.reduce_mean(
                            pred_disen.log_prob(
                                tf.cast(data[name], tf.float32)))
                for name, head in self.disen_heads.items():
                    disen_metrics[name] = self._disen_heads_opt[name](
                        disen_gradient_tapes[name], loss_disen[name], [head],
                        prefix='disen_neg')

            # Compute likes for disen model (including negative gradients)
            likes_disen = {}
            for name, head in self.disen_heads.items():
                pred_disen = head(feat_disen, tf.float32)
                like_disen = pred_disen.log_prob(
                    tf.cast(data[name], tf.float32))
                likes_disen[name] = tf.reduce_mean(
                    like_disen) * self._disen_scales.get(name, -1.0)
            disen_only_image_pred = self.disen_only_image_head(
                feat_disen, tf.float32)
            disen_only_image_like = tf.reduce_mean(
                disen_only_image_pred.log_prob(
                    tf.cast(data['image'], tf.float32))
            ) * self._disen_scales.get('disen_only', 1.0)
            likes_disen['disen_only'] = disen_only_image_like

            # Joint decode
            image_pred_joint, _, _, _ = self.joint_image_head(
                feat, feat_disen, tf.float32)
            image_like = tf.reduce_mean(image_pred_joint.log_prob(
                tf.cast(data['image'], tf.float32)))
            likes['image'] = image_like
            likes_disen['image'] = image_like

            # Compute loss
            model_loss = kl_loss - sum(likes.values())
            disen_loss = kl_loss_disen - sum(likes_disen.values())

        model_parts = [self.encoder, self.dynamics,
                       self.joint_image_head] + list(self.heads.values())
        disen_parts = [self.disen_encoder, self.disen_dynamics,
                       self.joint_image_head, self.disen_only_image_head]

        metrics = self._model_opt(
            model_tape, model_loss, model_parts, prefix='main')
        disen_model_metrics = self._disen_model_opt(
            disen_tape, disen_loss, disen_parts, prefix='disen')

        metrics['kl_balance'] = kl_balance
        metrics['kl_free'] = kl_free
        metrics['kl_scale'] = kl_scale
        metrics.update(
            {f'{name}_loss': -like for name, like in likes.items()})
        metrics['disen/disen_only_image_loss'] = -disen_only_image_like
        metrics['disen/disen_reward_loss'] = -likes_disen['reward'] / \
            self._disen_scales.get('reward', -1.0)
        metrics['disen/disen_discount_loss'] = -likes_disen['discount'] / \
            self._disen_scales.get('discount', -1.0)
        metrics['kl'] = tf.reduce_mean(kl_value)
        metrics['prior_ent'] = self.dynamics.get_dist(prior).entropy()
        metrics['post_ent'] = self.dynamics.get_dist(post).entropy()
        metrics['disen/kl'] = tf.reduce_mean(kl_value_disen)
        metrics['disen/prior_ent'] = self.dynamics.get_dist(
            prior_disen).entropy()
        metrics['disen/post_ent'] = self.dynamics.get_dist(
            post_disen).entropy()
        metrics.update(
            {f'{key}': value for key, value in
             disen_metrics['reward'].items()})
        metrics.update(
            {f'{key}': value for key, value in
             disen_metrics['discount'].items()})
        metrics.update(
            {f'{key}': value for key, value in disen_model_metrics.items()})

        return embed, post, feat, kl_value, metrics

    @tf.function
    def preprocess(self, obs):
        dtype = prec.global_policy().compute_dtype
        obs = obs.copy()
        obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5
        obs['reward'] = getattr(tf, self._config.clip_rewards)(obs['reward'])
        if 'discount' in obs:
            obs['discount'] *= self._config.discount
        for key, value in obs.items():
            if tf.dtypes.as_dtype(value.dtype) in (
                    tf.float16, tf.float32, tf.float64):
                obs[key] = tf.cast(value, dtype)
        return obs

    @tf.function
    def video_pred(self, data):
        data = self.preprocess(data)
        truth = data['image'][:6] + 0.5

        embed = self.encoder(data)
        embed_disen = self.disen_encoder(data)
        states, _ = self.dynamics.observe(
            embed[:6, :5], data['action'][:6, :5])
        states_disen, _ = self.disen_dynamics.observe(
            embed_disen[:6, :5], data['action'][:6, :5])
        feats = self.dynamics.get_feat(states)
        feats_disen = self.disen_dynamics.get_feat(states_disen)
        recon_joint, recon_main, recon_disen, recon_mask = self.joint_image_head(
            feats, feats_disen)
        recon_joint = recon_joint.mode()[:6]
        recon_main = recon_main.mode()[:6]
        recon_disen = recon_disen.mode()[:6]
        recon_mask = recon_mask[:6]

        init = {k: v[:, -1] for k, v in states.items()}
        init_disen = {k: v[:, -1] for k, v in states_disen.items()}
        prior = self.dynamics.imagine(data['action'][:6, 5:], init)
        prior_disen = self.disen_dynamics.imagine(
            data['action'][:6, 5:], init_disen)
        _feats = self.dynamics.get_feat(prior)
        _feats_disen = self.disen_dynamics.get_feat(prior_disen)
        openl_joint, openl_main, openl_disen, openl_mask = self.joint_image_head(
            _feats, _feats_disen)
        openl_joint = openl_joint.mode()
        openl_main = openl_main.mode()
        openl_disen = openl_disen.mode()

        model_joint = tf.concat(
            [recon_joint[:, :5] + 0.5, openl_joint + 0.5], 1)
        error_joint = (model_joint - truth + 1) / 2
        model_main = tf.concat(
            [recon_main[:, :5] + 0.5, openl_main + 0.5], 1)
        error_main = (model_main - truth + 1) / 2
        model_disen = tf.concat(
            [recon_disen[:, :5] + 0.5, openl_disen + 0.5], 1)
        error_disen = (model_disen - truth + 1) / 2
        model_mask = tf.concat(
            [recon_mask[:, :5] + 0.5, openl_mask + 0.5], 1)

        output_joint = tf.concat([truth, model_joint, error_joint], 2)
        output_main = tf.concat([truth, model_main, error_main], 2)
        output_disen = tf.concat([truth, model_disen, error_disen], 2)
        output_mask = model_mask

        return output_joint, output_main, output_disen, output_mask


class ImagBehavior(tools.Module):

    def __init__(self, config, world_model, stop_grad_actor=True,
                 reward=None):
        self._config = config
        self._world_model = world_model
        self._stop_grad_actor = stop_grad_actor
        self._reward = reward
        self.actor = networks.ActionHead(
            config.num_actions, config.actor_layers, config.units, config.act,
            config.actor_dist, config.actor_init_std, config.actor_min_std,
            config.actor_dist, config.actor_temp, config.actor_outscale)
        self.value = networks.DenseHead(
            [], config.value_layers, config.units, config.act,
            config.value_head)
        if config.slow_value_target or config.slow_actor_target:
            self._slow_value = networks.DenseHead(
                [], config.value_layers, config.units, config.act)
            self._updates = tf.Variable(0, tf.int64)
        kw = dict(wd=config.weight_decay, opt=config.opt)
        self._actor_opt = tools.Optimizer(
            'actor', config.actor_lr, config.opt_eps, config.actor_grad_clip,
            **kw)
        self._value_opt = tools.Optimizer(
            'value', config.value_lr, config.opt_eps, config.value_grad_clip,
            **kw)

    def train(
            self, start, objective=None, imagine=None, tape=None,
            repeats=None):
        objective = objective or self._reward
        self._update_slow_target()
        metrics = {}

        with (tape or tf.GradientTape()) as actor_tape:
            assert bool(objective) != bool(imagine)
            if objective:
                imag_feat, imag_state, imag_action = self._imagine(
                    start, self.actor, self._config.imag_horizon, repeats)
                reward = objective(imag_feat, imag_state, imag_action)
            else:
                imag_feat, imag_state, imag_action, reward = imagine(start)
            actor_ent = self.actor(imag_feat, tf.float32).entropy()
            state_ent = self._world_model.dynamics.get_dist(
                imag_state, tf.float32).entropy()
            target, weights = self._compute_target(
                imag_feat, reward, actor_ent, state_ent,
                self._config.slow_actor_target)
            actor_loss, mets = self._compute_actor_loss(
                imag_feat, imag_state, imag_action, target, actor_ent,
                state_ent, weights)
            metrics.update(mets)

        if self._config.slow_value_target != self._config.slow_actor_target:
            target, weights = self._compute_target(
                imag_feat, reward, actor_ent, state_ent,
                self._config.slow_value_target)

        with tf.GradientTape() as value_tape:
            value = self.value(imag_feat, tf.float32)[:-1]
            value_loss = -value.log_prob(tf.stop_gradient(target))
            if self._config.value_decay:
                value_loss += self._config.value_decay * value.mode()
            value_loss = tf.reduce_mean(weights[:-1] * value_loss)

        metrics['reward_mean'] = tf.reduce_mean(reward)
        metrics['reward_std'] = tf.math.reduce_std(reward)
        metrics['actor_ent'] = tf.reduce_mean(actor_ent)
        metrics.update(self._actor_opt(actor_tape, actor_loss, [self.actor]))
        metrics.update(self._value_opt(value_tape, value_loss, [self.value]))
        return imag_feat, imag_state, imag_action, weights, metrics

    def _imagine(self, start, policy, horizon, repeats=None):
        dynamics = self._world_model.dynamics
        if repeats:
            start = {k: tf.repeat(v, repeats, axis=1)
                     for k, v in start.items()}

        def flatten(x):
            return tf.reshape(x, [-1] + list(x.shape[2:]))
        start = {k: flatten(v) for k, v in start.items()}

        def step(prev, _):
            state, _, _ = prev
            feat = dynamics.get_feat(state)
            inp = tf.stop_gradient(feat) if self._stop_grad_actor else feat
            action = policy(inp).sample()
            succ = dynamics.img_step(
                state, action, sample=self._config.imag_sample)
            return succ, feat, action

        feat = 0 * dynamics.get_feat(start)
        action = policy(feat).mode()
        succ, feats, actions = tools.static_scan(
            step, tf.range(horizon), (start, feat, action))
        states = {k: tf.concat([start[k][None], v[:-1]], 0)
                  for k, v in succ.items()}
        if repeats:
            def unfold(tensor):
                s = tensor.shape
                return tf.reshape(
                    tensor, [s[0], s[1] // repeats, repeats] + s[2:])
            states, feats, actions = tf.nest.map_structure(
                unfold, (states, feats, actions))
        return feats, states, actions

    def _compute_target(self, imag_feat, reward, actor_ent, state_ent, slow):
        reward = tf.cast(reward, tf.float32)
        if 'discount' in self._world_model.heads:
            discount = self._world_model.heads['discount'](
                imag_feat, tf.float32).mean()
        else:
            discount = self._config.discount * tf.ones_like(reward)
        if self._config.future_entropy and tf.greater(
                self._config.actor_entropy(), 0):
            reward += self._config.actor_entropy() * actor_ent
        if self._config.future_entropy and tf.greater(
                self._config.actor_state_entropy(), 0):
            reward += self._config.actor_state_entropy() * state_ent
        if slow:
            value = self._slow_value(imag_feat, tf.float32).mode()
        else:
            value = self.value(imag_feat, tf.float32).mode()
        target = tools.lambda_return(
            reward[:-1], value[:-1], discount[:-1], bootstrap=value[-1],
            lambda_=self._config.discount_lambda, axis=0)
        weights = tf.stop_gradient(tf.math.cumprod(tf.concat(
            [tf.ones_like(discount[:1]), discount[:-1]], 0), 0))
        return target, weights

    def _compute_actor_loss(
            self, imag_feat, imag_state, imag_action, target, actor_ent,
            state_ent, weights):
        metrics = {}
        inp = tf.stop_gradient(
            imag_feat) if self._stop_grad_actor else imag_feat
        policy = self.actor(inp, tf.float32)
        actor_ent = policy.entropy()
        if self._config.imag_gradient == 'dynamics':
            actor_target = target
        elif self._config.imag_gradient == 'reinforce':
            imag_action = tf.cast(imag_action, tf.float32)
            actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient(
                target - self.value(imag_feat[:-1], tf.float32).mode())
        elif self._config.imag_gradient == 'both':
            imag_action = tf.cast(imag_action, tf.float32)
            actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient(
                target - self.value(imag_feat[:-1], tf.float32).mode())
            mix = self._config.imag_gradient_mix()
            actor_target = mix * target + (1 - mix) * actor_target
            metrics['imag_gradient_mix'] = mix
        else:
            raise NotImplementedError(self._config.imag_gradient)
        if not self._config.future_entropy and tf.greater(
                self._config.actor_entropy(), 0):
            actor_target += self._config.actor_entropy() * actor_ent[:-1]
        if not self._config.future_entropy and tf.greater(
                self._config.actor_state_entropy(), 0):
            actor_target += self._config.actor_state_entropy() * state_ent[:-1]
        actor_loss = -tf.reduce_mean(weights[:-1] * actor_target)
        return actor_loss, metrics

    def _update_slow_target(self):
        if self._config.slow_value_target or self._config.slow_actor_target:
            if self._updates % self._config.slow_target_update == 0:
                mix = self._config.slow_target_fraction
                for s, d in zip(self.value.variables,
                                self._slow_value.variables):
                    d.assign(mix * s + (1 - mix) * d)
            self._updates.assign_add(1)
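

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as comments). It shows how the two
# modules above are meant to fit together: `WorldModel.train` consumes a batch
# dict with 'image', 'action', 'reward' (and 'discount' when
# `config.pred_discount` is set) and returns posterior states that can seed
# imagination in `ImagBehavior.train`. The `config` object and `batch` below
# are hypothetical placeholders; the real values come from the training
# config and replay buffer of the surrounding codebase.
#
#   step = tf.Variable(0, dtype=tf.int64)
#   world_model = WorldModel(step, config)
#   behavior = ImagBehavior(
#       config, world_model,
#       reward=lambda feat, state, action:
#           world_model.heads['reward'](feat, tf.float32).mode())
#
#   embed, post, feat, kl_value, wm_metrics = world_model.train(batch)
#   imag_feat, imag_state, imag_action, weights, ac_metrics = \
#       behavior.train(post)
# ---------------------------------------------------------------------------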