298 lines
11 KiB
Python
298 lines
11 KiB
Python
|
import numpy as np
|
||
|
import tensorflow as tf
|
||
|
from tensorflow.keras import layers as tfkl
|
||
|
from tensorflow_probability import distributions as tfd
|
||
|
from tensorflow.keras.mixed_precision import experimental as prec
|
||
|
|
||
|
import tools
|
||
|
|
||
|
|
||
|
class RSSM(tools.Module):
    """Recurrent State-Space Model.

    The latent state is a dict with a deterministic part (`deter`, the GRU
    state) and a stochastic part (`stoch`) sampled from a diagonal Gaussian
    parameterized by `mean`/`std`.
    """

    def __init__(self, stoch=30, deter=200, hidden=200, act=tf.nn.elu):
        # stoch: size of the stochastic latent.
        # deter: size of the deterministic GRU state.
        # hidden: width of the intermediate dense layers.
        super().__init__()
        self._activation = act
        self._stoch_size = stoch
        self._deter_size = deter
        self._hidden_size = hidden
        self._cell = tfkl.GRUCell(self._deter_size)

    def initial(self, batch_size):
        """Return an all-zero initial state dict for a batch."""
        # Use the mixed-precision compute dtype so the state matches the
        # dtype of subsequent layer outputs.
        dtype = prec.global_policy().compute_dtype
        return dict(
            mean=tf.zeros([batch_size, self._stoch_size], dtype),
            std=tf.zeros([batch_size, self._stoch_size], dtype),
            stoch=tf.zeros([batch_size, self._stoch_size], dtype),
            deter=self._cell.get_initial_state(None, batch_size, dtype))

    @tf.function
    def observe(self, embed, action, state=None):
        """Filter a sequence: run obs_step at every timestep.

        `embed` and `action` are batch-major [batch, time, dim]; they are
        transposed to time-major for the scan and the results transposed
        back. Returns (posterior, prior) state dicts of stacked sequences.
        """
        if state is None:
            state = self.initial(tf.shape(action)[0])
        embed = tf.transpose(embed, [1, 0, 2])
        action = tf.transpose(action, [1, 0, 2])
        # obs_step returns (post, prior); prev[0] threads the posterior
        # forward as the previous state.
        post, prior = tools.static_scan(
            lambda prev, inputs: self.obs_step(
                prev[0], *inputs),
            (action, embed), (state, state))
        post = {k: tf.transpose(v, [1, 0, 2]) for k, v in post.items()}
        prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()}
        return post, prior

    @tf.function
    def imagine(self, action, state=None):
        """Roll out the prior (img_step) over an action sequence, no observations."""
        if state is None:
            state = self.initial(tf.shape(action)[0])
        assert isinstance(state, dict), state
        action = tf.transpose(action, [1, 0, 2])  # batch-major -> time-major
        prior = tools.static_scan(self.img_step, action, state)
        prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()}
        return prior

    def get_feat(self, state):
        # Feature vector for downstream heads: stochastic and deterministic
        # parts concatenated.
        return tf.concat([state['stoch'], state['deter']], -1)

    def get_dist(self, state):
        # Diagonal Gaussian over the stochastic latent.
        return tfd.MultivariateNormalDiag(state['mean'], state['std'])

    @tf.function
    def obs_step(self, prev_state, prev_action, embed):
        """One filtering step: prior via img_step, then posterior from prior + embed."""
        prior = self.img_step(prev_state, prev_action)
        x = tf.concat([prior['deter'], embed], -1)
        x = self.get('obs1', tfkl.Dense, self._hidden_size,
                     self._activation)(x)
        x = self.get('obs2', tfkl.Dense, 2 * self._stoch_size, None)(x)
        mean, std = tf.split(x, 2, -1)
        std = tf.nn.softplus(std) + 0.1  # keep std strictly positive
        stoch = self.get_dist({'mean': mean, 'std': std}).sample()
        # Posterior reuses the prior's deterministic state.
        post = {'mean': mean, 'std': std,
                'stoch': stoch, 'deter': prior['deter']}
        return post, prior

    @tf.function
    def img_step(self, prev_state, prev_action):
        """One prior (imagination) step from the previous state and action."""
        x = tf.concat([prev_state['stoch'], prev_action], -1)
        x = self.get('img1', tfkl.Dense, self._hidden_size,
                     self._activation)(x)
        x, deter = self._cell(x, [prev_state['deter']])
        deter = deter[0]  # Keras wraps the state in a list.
        x = self.get('img2', tfkl.Dense, self._hidden_size,
                     self._activation)(x)
        x = self.get('img3', tfkl.Dense, 2 * self._stoch_size, None)(x)
        mean, std = tf.split(x, 2, -1)
        std = tf.nn.softplus(std) + 0.1  # keep std strictly positive
        stoch = self.get_dist({'mean': mean, 'std': std}).sample()
        prior = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': deter}
        return prior
|
||
|
|
||
|
|
||
|
class ConvEncoder(tools.Module):
    """Convolutional encoder mapping image observations to a flat feature vector.

    Supports 64x64, 32x32 and 84x84 inputs; any other resolution is rejected
    at construction time.
    """

    def __init__(self, depth=32, act=tf.nn.relu, image_size=64):
        self._act = act
        self._depth = depth
        self._image_size = image_size

        # Output dimensionality and per-layer kernel sizes depend on the
        # input resolution (all conv strides are fixed at 2 in __call__).
        if image_size == 64:
            self._outdim = 32 * self._depth
            self._kernel_sizes = [4, 4, 4, 4]
        elif image_size == 32:
            self._outdim = 8 * self._depth
            self._kernel_sizes = [3, 3, 3, 3]
        elif image_size == 84:
            self._outdim = 72 * self._depth
            self._kernel_sizes = [4, 4, 4, 4]
        else:
            # Include the unsupported size so the failure is diagnosable
            # (consistent with the other decoders' error style).
            raise NotImplementedError(image_size)

    def __call__(self, obs):
        """Encode obs['image'] to features, preserving leading batch/time dims."""
        kwargs = dict(strides=2, activation=self._act)
        # Collapse all leading dims into one batch dim for the conv stack.
        x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:]))
        # Four stride-2 conv layers with channel multipliers 1, 2, 4, 8;
        # layer names h1..h4 match the original weight names.
        for i, (mult, kernel) in enumerate(zip((1, 2, 4, 8), self._kernel_sizes)):
            x = self.get(f'h{i + 1}', tfkl.Conv2D, mult * self._depth,
                         kernel, **kwargs)(x)
        # Restore the original leading dims with a flat feature axis.
        shape = tf.concat([tf.shape(obs['image'])[:-3], [self._outdim]], 0)
        return tf.reshape(x, shape)
|
||
|
|
||
|
|
||
|
class ConvDecoder(tools.Module):
    """Transposed-conv decoder mapping features to a distribution over images.

    Supports output shapes with height 64, 32 or 84; others are rejected.
    """

    def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)):
        self._act = act
        self._depth = depth
        self._shape = shape

        # Dense projection size and per-layer kernel sizes depend on the
        # target resolution (strides are fixed at 2).
        if shape[0] == 64:
            self._outdim = 32 * self._depth
            self._kernel_sizes = [5, 5, 6, 6]
        elif shape[0] == 32:
            self._outdim = 8 * self._depth
            self._kernel_sizes = [3, 3, 3, 4]
        elif shape[0] == 84:
            self._outdim = 72 * self._depth
            self._kernel_sizes = [7, 6, 6, 6]
        else:
            # Include the unsupported shape so the failure is diagnosable.
            raise NotImplementedError(shape)

    def __call__(self, features):
        """Decode features to an Independent Normal over images (unit std)."""
        kwargs = dict(strides=2, activation=self._act)
        x = self.get('h1', tfkl.Dense, self._outdim, None)(features)
        # Start deconvolution from a 1x1 spatial map.
        x = tf.reshape(x, [-1, 1, 1, self._outdim])
        x = self.get('h2', tfkl.Conv2DTranspose,
                     4 * self._depth, self._kernel_sizes[0], **kwargs)(x)
        x = self.get('h3', tfkl.Conv2DTranspose,
                     2 * self._depth, self._kernel_sizes[1], **kwargs)(x)
        x = self.get('h4', tfkl.Conv2DTranspose,
                     1 * self._depth, self._kernel_sizes[2], **kwargs)(x)
        # Final layer: linear activation, output channels = image channels.
        x = self.get('h5', tfkl.Conv2DTranspose,
                     self._shape[-1], self._kernel_sizes[3], strides=2)(x)
        mean = tf.reshape(x, tf.concat(
            [tf.shape(features)[:-1], self._shape], 0))
        return tfd.Independent(tfd.Normal(mean, 1), len(self._shape))
|
||
|
|
||
|
|
||
|
class ConvDecoderMask(tools.Module):
    """Transposed-conv decoder producing an image distribution plus a mask tensor.

    Like ConvDecoder, but the final layer emits twice the image channels and
    splits them into a reconstruction mean and per-pixel mask features.
    """

    def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)):
        self._act = act
        self._depth = depth
        self._shape = shape

        # Dense projection size and per-layer kernel sizes depend on the
        # target resolution (strides are fixed at 2).
        if shape[0] == 64:
            self._outdim = 32 * self._depth
            self._kernel_sizes = [5, 5, 6, 6]
        elif shape[0] == 32:
            self._outdim = 8 * self._depth
            self._kernel_sizes = [3, 3, 3, 4]
        elif shape[0] == 84:
            self._outdim = 72 * self._depth
            self._kernel_sizes = [7, 6, 6, 6]
        else:
            # Include the unsupported shape so the failure is diagnosable.
            raise NotImplementedError(shape)

    def __call__(self, features):
        """Decode features to (Independent Normal over images, mask tensor)."""
        kwargs = dict(strides=2, activation=self._act)
        x = self.get('h1', tfkl.Dense, self._outdim, None)(features)
        # Start deconvolution from a 1x1 spatial map.
        x = tf.reshape(x, [-1, 1, 1, self._outdim])
        x = self.get('h2', tfkl.Conv2DTranspose,
                     4 * self._depth, self._kernel_sizes[0], **kwargs)(x)
        x = self.get('h3', tfkl.Conv2DTranspose,
                     2 * self._depth, self._kernel_sizes[1], **kwargs)(x)
        x = self.get('h4', tfkl.Conv2DTranspose,
                     1 * self._depth, self._kernel_sizes[2], **kwargs)(x)
        # Emit mean and mask, `channels` each. (Generalizes the previous
        # hard-coded 3-channel split, which only worked for RGB; identical
        # layer width and split for the default shape[-1] == 3.)
        channels = self._shape[-1]
        x = self.get('h5', tfkl.Conv2DTranspose,
                     2 * channels, self._kernel_sizes[3], strides=2)(x)
        mean, mask = tf.split(x, 2, -1)
        target_shape = tf.concat([tf.shape(features)[:-1], self._shape], 0)
        mean = tf.reshape(mean, target_shape)
        mask = tf.reshape(mask, target_shape)
        return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)), mask
|
||
|
|
||
|
|
||
|
class ConvDecoderMaskEnsemble(tools.Module):
    """Blend two <Normal, mask> conv decoders with a learned per-pixel weight.

    A 1x1 sigmoid convolution over the concatenated mask features yields a
    blending weight in [0, 1]; the combined mean is a convex combination of
    the two decoder means.
    """

    def __init__(self, decoder1, decoder2, precision):
        self._decoder1 = decoder1
        self._decoder2 = decoder2
        self._precision = 'float' + str(precision)
        self._shape = decoder1._shape

    def __call__(self, feat1, feat2):
        dist1, mask_feat1 = self._decoder1(feat1)
        dist2, mask_feat2 = self._decoder2(feat2)
        # Pull the per-pixel means out of the wrapped Normal distributions.
        loc1 = dist1.submodules[0].loc
        loc2 = dist2.submodules[0].loc
        # Learned blending weight from both decoders' mask features.
        stacked = tf.concat([mask_feat1, mask_feat2], -1)
        weight = self.get(
            'mask1', tfkl.Conv2D, 1, 1,
            strides=1, activation=tf.nn.sigmoid)(stacked)
        w1 = tf.cast(weight, self._precision)
        w2 = tf.cast(1 - weight, self._precision)
        blended = loc1 * w1 + loc2 * w2
        joint = tfd.Independent(tfd.Normal(blended, 1), len(self._shape))
        return joint, dist1, dist2, w1
|
||
|
|
||
|
|
||
|
class InverseDecoder(tools.Module):
    """MLP predicting the quantity linking consecutive feature pairs.

    Pairs each timestep's features with its successor's and regresses a
    `shape`-sized output as an Independent Normal (unit std). Presumably an
    inverse-dynamics (action) head — confirm against the caller.
    """

    def __init__(self, shape, layers, units, act=tf.nn.elu):
        self._shape = shape
        self._layers = layers
        self._units = units
        self._act = act

    def __call__(self, features):
        # Concatenate consecutive timesteps: [batch, T-1, 2 * feat].
        x = tf.concat([features[:, :-1], features[:, 1:]], -1)
        for index in range(self._layers):
            x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
        # Plain string literal: 'hout' was an f-string with no placeholders.
        x = self.get('hout', tfkl.Dense, np.prod(self._shape))(x)
        return tfd.Independent(tfd.Normal(x, 1), 1)
|
||
|
|
||
|
|
||
|
class DenseDecoder(tools.Module):
    """MLP head producing a Normal or Bernoulli distribution over `shape`.

    Raises NotImplementedError for any other `dist` value.
    """

    def __init__(self, shape, layers, units, dist='normal', act=tf.nn.elu):
        self._shape = shape
        self._layers = layers
        self._units = units
        self._dist = dist
        self._act = act

    def __call__(self, features):
        x = features
        for index in range(self._layers):
            x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
        # Plain string literal: 'hout' was an f-string with no placeholders.
        x = self.get('hout', tfkl.Dense, np.prod(self._shape))(x)
        # Unflatten the output to the target event shape.
        x = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0))
        if self._dist == 'normal':
            return tfd.Independent(tfd.Normal(x, 1), len(self._shape))
        if self._dist == 'binary':
            # Bernoulli parameterized by logits.
            return tfd.Independent(tfd.Bernoulli(x), len(self._shape))
        raise NotImplementedError(self._dist)
|
||
|
|
||
|
|
||
|
class ActionDecoder(tools.Module):
    """MLP policy head producing a tanh-squashed Normal or one-hot action dist."""

    def __init__(
            self, size, layers, units, dist='tanh_normal', act=tf.nn.elu,
            min_std=1e-4, init_std=5, mean_scale=5):
        self._size = size
        self._layers = layers
        self._units = units
        self._dist = dist
        self._act = act
        self._min_std = min_std
        self._init_std = init_std
        self._mean_scale = mean_scale

    def __call__(self, features):
        # Shift so that softplus(0 + raw_init_std) == init_std.
        raw_init_std = np.log(np.exp(self._init_std) - 1)
        x = features
        for index in range(self._layers):
            x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
        if self._dist == 'tanh_normal':
            # https://www.desmos.com/calculator/rcmcf5jwe7
            x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
            mean, std = tf.split(x, 2, -1)
            # Soft-clip the mean to [-mean_scale, mean_scale].
            mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
            std = tf.nn.softplus(std + raw_init_std) + self._min_std
            dist = tfd.Normal(mean, std)
            dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
            dist = tfd.Independent(dist, 1)
            dist = tools.SampleDist(dist)
        elif self._dist == 'onehot':
            x = self.get('hout', tfkl.Dense, self._size)(x)
            dist = tools.OneHotDist(x)
        else:
            # Bug fix: was `raise NotImplementedError(dist)` — `dist` is
            # unbound in this branch, so an unknown dist raised NameError
            # instead of the intended NotImplementedError.
            raise NotImplementedError(self._dist)
        return dist
|