"""World-model components: latent dynamics (RSSM), image encoders/decoders,
and dense/action decoder heads built on `tools.Module`.

All networks create their layers lazily through `tools.Module.get`, so the
variables are built on first call.
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow_probability import distributions as tfd
from tensorflow.keras.mixed_precision import experimental as prec

import tools


class RSSM(tools.Module):
    """Recurrent latent dynamics model with stochastic and deterministic state.

    A latent state is a dict with keys:
      'mean', 'std' : parameters of a diagonal Gaussian over the stochastic part,
      'stoch'       : a sample from that Gaussian,
      'deter'       : the GRU cell's deterministic hidden state.
    """

    def __init__(self, stoch=30, deter=200, hidden=200, act=tf.nn.elu):
        super().__init__()
        self._activation = act
        self._stoch_size = stoch
        self._deter_size = deter
        self._hidden_size = hidden
        self._cell = tfkl.GRUCell(self._deter_size)

    def initial(self, batch_size):
        """Return an all-zeros initial state dict for `batch_size` rows."""
        dtype = prec.global_policy().compute_dtype
        return dict(
            mean=tf.zeros([batch_size, self._stoch_size], dtype),
            std=tf.zeros([batch_size, self._stoch_size], dtype),
            stoch=tf.zeros([batch_size, self._stoch_size], dtype),
            deter=self._cell.get_initial_state(None, batch_size, dtype))

    @tf.function
    def observe(self, embed, action, state=None):
        """Filter a sequence of observations.

        Args:
          embed: (batch, time, embed_dim) observation embeddings.
          action: (batch, time, action_dim) actions.
          state: optional initial state dict; zeros if None.

        Returns:
          (post, prior): state dicts with (batch, time, ...) entries.
        """
        if state is None:
            state = self.initial(tf.shape(action)[0])
        # static_scan iterates over the leading axis, so put time first.
        embed = tf.transpose(embed, [1, 0, 2])
        action = tf.transpose(action, [1, 0, 2])
        post, prior = tools.static_scan(
            lambda prev, inputs: self.obs_step(prev[0], *inputs),
            (action, embed), (state, state))
        # Move time back to the second axis for the caller.
        post = {k: tf.transpose(v, [1, 0, 2]) for k, v in post.items()}
        prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()}
        return post, prior

    @tf.function
    def imagine(self, action, state=None):
        """Roll out the prior over a (batch, time, action_dim) action sequence."""
        if state is None:
            state = self.initial(tf.shape(action)[0])
        assert isinstance(state, dict), state
        action = tf.transpose(action, [1, 0, 2])  # time-major for static_scan
        prior = tools.static_scan(self.img_step, action, state)
        prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()}
        return prior

    def get_feat(self, state):
        """Concatenate stochastic and deterministic parts into one feature vector."""
        return tf.concat([state['stoch'], state['deter']], -1)

    def get_dist(self, state):
        """Diagonal Gaussian over the stochastic part of `state`."""
        return tfd.MultivariateNormalDiag(state['mean'], state['std'])

    @tf.function
    def obs_step(self, prev_state, prev_action, embed):
        """One filtering step; returns (posterior, prior) state dicts."""
        prior = self.img_step(prev_state, prev_action)
        x = tf.concat([prior['deter'], embed], -1)
        x = self.get('obs1', tfkl.Dense, self._hidden_size, self._activation)(x)
        x = self.get('obs2', tfkl.Dense, 2 * self._stoch_size, None)(x)
        mean, std = tf.split(x, 2, -1)
        std = tf.nn.softplus(std) + 0.1  # keep the std strictly positive
        stoch = self.get_dist({'mean': mean, 'std': std}).sample()
        post = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': prior['deter']}
        return post, prior

    @tf.function
    def img_step(self, prev_state, prev_action):
        """One prior (imagination) step without access to an observation."""
        x = tf.concat([prev_state['stoch'], prev_action], -1)
        x = self.get('img1', tfkl.Dense, self._hidden_size, self._activation)(x)
        x, deter = self._cell(x, [prev_state['deter']])
        deter = deter[0]  # Keras wraps the state in a list.
        x = self.get('img2', tfkl.Dense, self._hidden_size, self._activation)(x)
        x = self.get('img3', tfkl.Dense, 2 * self._stoch_size, None)(x)
        mean, std = tf.split(x, 2, -1)
        std = tf.nn.softplus(std) + 0.1  # keep the std strictly positive
        stoch = self.get_dist({'mean': mean, 'std': std}).sample()
        prior = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': deter}
        return prior


class ConvEncoder(tools.Module):
    """Four-layer convolutional image encoder producing a flat embedding."""

    def __init__(self, depth=32, act=tf.nn.relu, image_size=64):
        self._act = act
        self._depth = depth
        self._image_size = image_size
        # Output dim and kernels are tuned per supported input resolution.
        if image_size == 64:
            self._outdim = 32 * self._depth
            self._kernel_sizes = [4, 4, 4, 4]
        elif image_size == 32:
            self._outdim = 8 * self._depth
            self._kernel_sizes = [3, 3, 3, 3]
        elif image_size == 84:
            self._outdim = 72 * self._depth
            self._kernel_sizes = [4, 4, 4, 4]
        else:
            raise NotImplementedError

    def __call__(self, obs):
        kwargs = dict(strides=2, activation=self._act)
        # Fold any leading batch/time axes into a single batch axis.
        x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:]))
        x = self.get('h1', tfkl.Conv2D, 1 * self._depth, self._kernel_sizes[0], **kwargs)(x)
        x = self.get('h2', tfkl.Conv2D, 2 * self._depth, self._kernel_sizes[1], **kwargs)(x)
        x = self.get('h3', tfkl.Conv2D, 4 * self._depth, self._kernel_sizes[2], **kwargs)(x)
        x = self.get('h4', tfkl.Conv2D, 8 * self._depth, self._kernel_sizes[3], **kwargs)(x)
        # Restore the original leading axes with a flat feature axis.
        shape = tf.concat([tf.shape(obs['image'])[:-3], [self._outdim]], 0)
        return tf.reshape(x, shape)


class ConvDecoder(tools.Module):
    """Transposed-conv image decoder returning an independent Normal per pixel."""

    def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)):
        self._act = act
        self._depth = depth
        self._shape = shape
        # Dense output dim and kernels are tuned per target resolution.
        if shape[0] == 64:
            self._outdim = 32 * self._depth
            self._kernel_sizes = [5, 5, 6, 6]
        elif shape[0] == 32:
            self._outdim = 8 * self._depth
            self._kernel_sizes = [3, 3, 3, 4]
        elif shape[0] == 84:
            self._outdim = 72 * self._depth
            self._kernel_sizes = [7, 6, 6, 6]
        else:
            raise NotImplementedError

    def __call__(self, features):
        kwargs = dict(strides=2, activation=self._act)
        x = self.get('h1', tfkl.Dense, self._outdim, None)(features)
        x = tf.reshape(x, [-1, 1, 1, self._outdim])  # 1x1 spatial seed
        x = self.get('h2', tfkl.Conv2DTranspose, 4 * self._depth, self._kernel_sizes[0], **kwargs)(x)
        x = self.get('h3', tfkl.Conv2DTranspose, 2 * self._depth, self._kernel_sizes[1], **kwargs)(x)
        x = self.get('h4', tfkl.Conv2DTranspose, 1 * self._depth, self._kernel_sizes[2], **kwargs)(x)
        # Final layer: no activation, so the mean is unbounded.
        x = self.get('h5', tfkl.Conv2DTranspose, self._shape[-1], self._kernel_sizes[3], strides=2)(x)
        mean = tf.reshape(x, tf.concat(
            [tf.shape(features)[:-1], self._shape], 0))
        return tfd.Independent(tfd.Normal(mean, 1), len(self._shape))


class ConvDecoderMask(tools.Module):
    """Image decoder that additionally predicts a per-pixel mask tensor."""

    def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)):
        self._act = act
        self._depth = depth
        self._shape = shape
        if shape[0] == 64:
            self._outdim = 32 * self._depth
            self._kernel_sizes = [5, 5, 6, 6]
        elif shape[0] == 32:
            self._outdim = 8 * self._depth
            self._kernel_sizes = [3, 3, 3, 4]
        elif shape[0] == 84:
            self._outdim = 72 * self._depth
            self._kernel_sizes = [7, 6, 6, 6]
        else:
            raise NotImplementedError

    def __call__(self, features):
        kwargs = dict(strides=2, activation=self._act)
        x = self.get('h1', tfkl.Dense, self._outdim, None)(features)
        x = tf.reshape(x, [-1, 1, 1, self._outdim])
        x = self.get('h2', tfkl.Conv2DTranspose, 4 * self._depth, self._kernel_sizes[0], **kwargs)(x)
        x = self.get('h3', tfkl.Conv2DTranspose, 2 * self._depth, self._kernel_sizes[1], **kwargs)(x)
        x = self.get('h4', tfkl.Conv2DTranspose, 1 * self._depth, self._kernel_sizes[2], **kwargs)(x)
        # Last layer emits image channels plus 3 mask channels.
        x = self.get('h5', tfkl.Conv2DTranspose, 3 + self._shape[-1], self._kernel_sizes[3], strides=2)(x)
        # NOTE(review): the split hardcodes 3 image channels, which only matches
        # the `3 + self._shape[-1]` conv output when shape[-1] == 3 — confirm
        # before using non-RGB shapes.
        mean, mask = tf.split(x, [3, 3], -1)
        mean = tf.reshape(mean, tf.concat(
            [tf.shape(features)[:-1], self._shape], 0))
        mask = tf.reshape(mask, tf.concat(
            [tf.shape(features)[:-1], self._shape], 0))
        return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)), mask


class ConvDecoderMaskEnsemble(tools.Module):
    """Blend two mask-producing decoders with a learned per-pixel gate."""

    def __init__(self, decoder1, decoder2, precision):
        self._decoder1 = decoder1
        self._decoder2 = decoder2
        self._precision = 'float' + str(precision)
        self._shape = decoder1._shape

    def __call__(self, feat1, feat2):
        kwargs = dict(strides=1, activation=tf.nn.sigmoid)
        pred1, mask1 = self._decoder1(feat1)
        pred2, mask2 = self._decoder2(feat2)
        # Pull the per-decoder means out of the wrapped Normal distributions.
        mean1 = pred1.submodules[0].loc
        mean2 = pred2.submodules[0].loc
        # 1x1 sigmoid conv over both masks -> blending weight in (0, 1).
        mask_feat = tf.concat([mask1, mask2], -1)
        mask = self.get('mask1', tfkl.Conv2D, 1, 1, **kwargs)(mask_feat)
        mask_use1 = mask
        mask_use2 = 1 - mask
        mean = mean1 * tf.cast(mask_use1, self._precision) + \
            mean2 * tf.cast(mask_use2, self._precision)
        return (tfd.Independent(tfd.Normal(mean, 1), len(self._shape)),
                pred1, pred2, tf.cast(mask_use1, self._precision))


class InverseDecoder(tools.Module):
    """MLP over pairs of consecutive features (inverse-model style head)."""

    def __init__(self, shape, layers, units, act=tf.nn.elu):
        self._shape = shape
        self._layers = layers
        self._units = units
        self._act = act

    def __call__(self, features):
        # Pair each timestep's features with the next timestep's features.
        x = tf.concat([features[:, :-1], features[:, 1:]], -1)
        for index in range(self._layers):
            x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
        x = self.get('hout', tfkl.Dense, np.prod(self._shape))(x)
        return tfd.Independent(tfd.Normal(x, 1), 1)


class DenseDecoder(tools.Module):
    """MLP head returning a Normal or Bernoulli distribution over `shape`."""

    def __init__(self, shape, layers, units, dist='normal', act=tf.nn.elu):
        self._shape = shape
        self._layers = layers
        self._units = units
        self._dist = dist
        self._act = act

    def __call__(self, features):
        x = features
        for index in range(self._layers):
            x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
        x = self.get('hout', tfkl.Dense, np.prod(self._shape))(x)
        x = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0))
        if self._dist == 'normal':
            return tfd.Independent(tfd.Normal(x, 1), len(self._shape))
        if self._dist == 'binary':
            return tfd.Independent(tfd.Bernoulli(x), len(self._shape))
        raise NotImplementedError(self._dist)


class ActionDecoder(tools.Module):
    """Policy head producing a tanh-squashed Normal or one-hot distribution."""

    def __init__(
            self, size, layers, units, dist='tanh_normal', act=tf.nn.elu,
            min_std=1e-4, init_std=5, mean_scale=5):
        self._size = size
        self._layers = layers
        self._units = units
        self._dist = dist
        self._act = act
        self._min_std = min_std
        self._init_std = init_std
        self._mean_scale = mean_scale

    def __call__(self, features):
        # Inverse-softplus of init_std, so the untrained std starts near init_std.
        raw_init_std = np.log(np.exp(self._init_std) - 1)
        x = features
        for index in range(self._layers):
            x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
        if self._dist == 'tanh_normal':
            # https://www.desmos.com/calculator/rcmcf5jwe7
            x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
            mean, std = tf.split(x, 2, -1)
            # Soft-clip the mean to [-mean_scale, mean_scale].
            mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
            std = tf.nn.softplus(std + raw_init_std) + self._min_std
            dist = tfd.Normal(mean, std)
            dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
            dist = tfd.Independent(dist, 1)
            dist = tools.SampleDist(dist)
        elif self._dist == 'onehot':
            x = self.get('hout', tfkl.Dense, self._size)(x)
            dist = tools.OneHotDist(x)
        else:
            # Fixed: the original raised NotImplementedError(dist), but `dist`
            # is unassigned on this path and produced a NameError instead.
            raise NotImplementedError(self._dist)
        return dist