import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow_probability import distributions as tfd
from tensorflow.keras.mixed_precision import experimental as prec

import tools


class RSSM(tools.Module):
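  """Recurrent state-space model (RSSM) with deterministic and stochastic state.

  Each latent state has a deterministic part `deter` (the GRU state) and a
  stochastic part `stoch` (Gaussian by default, categorical when `discrete`
  is set). `observe` filters a sequence of embeddings and actions into
  posterior and prior states; `imagine` rolls the prior forward from actions
  alone.

  Illustrative usage sketch (shapes and dimensions are assumptions, not
  taken from this file):

      rssm = RSSM(stoch=30, deter=200)
      embed = tf.zeros([16, 50, 1024])   # [batch, time, embed_dim]
      action = tf.zeros([16, 50, 6])     # [batch, time, action_dim]
      post, prior = rssm.observe(embed, action)
      feat = rssm.get_feat(post)         # [batch, time, stoch + deter]
  """
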
  def __init__(
      self, stoch=30, deter=200, hidden=200, layers_input=1, layers_output=1,
      shared=False, discrete=False, act=tf.nn.elu, mean_act='none',
      std_act='softplus', min_std=0.1, cell='keras'):
    super().__init__()
    self._stoch = stoch
    self._deter = deter
    self._hidden = hidden
    self._min_std = min_std
    self._layers_input = layers_input
    self._layers_output = layers_output
    self._shared = shared
    self._discrete = discrete
    self._act = act
    self._mean_act = mean_act
    self._std_act = std_act
    self._embed = None
    if cell in ('keras', 'gru'):
      # The default 'keras' and 'gru' both select the standard Keras GRU
      # cell; without this, the default value would hit NotImplementedError.
      self._cell = tfkl.GRUCell(self._deter)
    elif cell == 'gru_layer_norm':
      self._cell = GRUCell(self._deter, norm=True)  # layer-normalized GRU below
    else:
      raise NotImplementedError(cell)

  def initial(self, batch_size):
    dtype = prec.global_policy().compute_dtype
    if self._discrete:
      state = dict(
          logit=tf.zeros(
              [batch_size, self._stoch, self._discrete], dtype),
          stoch=tf.zeros(
              [batch_size, self._stoch, self._discrete], dtype),
          deter=self._cell.get_initial_state(None, batch_size, dtype))
    else:
      state = dict(
          mean=tf.zeros([batch_size, self._stoch], dtype),
          std=tf.zeros([batch_size, self._stoch], dtype),
          stoch=tf.zeros([batch_size, self._stoch], dtype),
          deter=self._cell.get_initial_state(None, batch_size, dtype))
    return state
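
  # `observe` and `imagine` scan over time. Inputs arrive batch-major as
  # [batch, time, ...]; `swap` transposes to time-major for the scan and the
  # results are swapped back. For example (shapes assumed, not from this
  # file): embed [16, 50, 1024] -> [50, 16, 1024].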
  @tf.function
  def observe(self, embed, action, state=None):
    def swap(x): return tf.transpose(
        x, [1, 0] + list(range(2, len(x.shape))))
    if state is None:
      state = self.initial(tf.shape(action)[0])
    embed, action = swap(embed), swap(action)
    post, prior = tools.static_scan(
        lambda prev, inputs: self.obs_step(prev[0], *inputs),
        (action, embed), (state, state))
    post = {k: swap(v) for k, v in post.items()}
    prior = {k: swap(v) for k, v in prior.items()}
    return post, prior

  @tf.function
  def imagine(self, action, state=None):
    def swap(x): return tf.transpose(
        x, [1, 0] + list(range(2, len(x.shape))))
    if state is None:
      state = self.initial(tf.shape(action)[0])
    assert isinstance(state, dict), state
    action = swap(action)
    prior = tools.static_scan(self.img_step, action, state)
    prior = {k: swap(v) for k, v in prior.items()}
    return prior

  def get_feat(self, state):
    stoch = state['stoch']
    if self._discrete:
      shape = stoch.shape[:-2] + [self._stoch * self._discrete]
      stoch = tf.reshape(stoch, shape)
    return tf.concat([stoch, state['deter']], -1)

  def get_dist(self, state, dtype=None):
    if self._discrete:
      logit = state['logit']
      logit = tf.cast(logit, tf.float32)
      dist = tfd.Independent(tools.OneHotDist(logit), 1)
      if dtype != tf.float32:
        dist = tools.DtypeDist(dist, dtype or state['logit'].dtype)
    else:
      mean, std = state['mean'], state['std']
      if dtype:
        mean = tf.cast(mean, dtype)
        std = tf.cast(std, dtype)
      dist = tfd.MultivariateNormalDiag(mean, std)
    return dist

  @tf.function
  def obs_step(self, prev_state, prev_action, embed, sample=True):
    if not self._embed:
      self._embed = embed.shape[-1]
    prior = self.img_step(prev_state, prev_action, None, sample)
    if self._shared:
      post = self.img_step(prev_state, prev_action, embed, sample)
    else:
      x = tf.concat([prior['deter'], embed], -1)
      for i in range(self._layers_output):
        x = self.get(f'obi{i}', tfkl.Dense, self._hidden, self._act)(x)
      stats = self._suff_stats_layer('obs', x)
      if sample:
        stoch = self.get_dist(stats).sample()
      else:
        stoch = self.get_dist(stats).mode()
      post = {'stoch': stoch, 'deter': prior['deter'], **stats}
    return post, prior

  @tf.function
  def img_step(self, prev_state, prev_action, embed=None, sample=True):
    prev_stoch = prev_state['stoch']
    if self._discrete:
      shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete]
      prev_stoch = tf.reshape(prev_stoch, shape)
    if self._shared:
      if embed is None:
        shape = prev_action.shape[:-1] + [self._embed]
        embed = tf.zeros(shape, prev_action.dtype)
      x = tf.concat([prev_stoch, prev_action, embed], -1)
    else:
      x = tf.concat([prev_stoch, prev_action], -1)
    for i in range(self._layers_input):
      x = self.get(f'ini{i}', tfkl.Dense, self._hidden, self._act)(x)
    x, deter = self._cell(x, [prev_state['deter']])
    deter = deter[0]  # Keras wraps the state in a list.
    for i in range(self._layers_output):
      x = self.get(f'imo{i}', tfkl.Dense, self._hidden, self._act)(x)
    stats = self._suff_stats_layer('ims', x)
    if sample:
      stoch = self.get_dist(stats).sample()
    else:
      stoch = self.get_dist(stats).mode()
    prior = {'stoch': stoch, 'deter': deter, **stats}
    return prior

  def _suff_stats_layer(self, name, x):
    if self._discrete:
      x = self.get(name, tfkl.Dense, self._stoch * self._discrete, None)(x)
      logit = tf.reshape(x, x.shape[:-1] + [self._stoch, self._discrete])
      return {'logit': logit}
    else:
      x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x)
      mean, std = tf.split(x, 2, -1)
      mean = {
          'none': lambda: mean,
          'tanh5': lambda: 5.0 * tf.math.tanh(mean / 5.0),
      }[self._mean_act]()
      std = {
          'softplus': lambda: tf.nn.softplus(std),
          'abs': lambda: tf.math.abs(std + 1),
          'sigmoid': lambda: tf.nn.sigmoid(std),
          'sigmoid2': lambda: 2 * tf.nn.sigmoid(std / 2),
      }[self._std_act]()
      std = std + self._min_std
      return {'mean': mean, 'std': std}
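
  # KL balancing (as in DreamerV2): when `balance` != 0.5, the divergence is
  # computed twice with stop-gradients so the prior receives a stronger
  # training signal toward the posterior than vice versa:
  #   loss = balance * KL(prior || sg(post)) + (1 - balance) * KL(sg(prior) || post)
  # `free` clips each term from below (free nats).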
  def kl_loss(self, post, prior, balance, free, scale):
    kld = tfd.kl_divergence
    def dist(x): return self.get_dist(x, tf.float32)
    if balance == 0.5:
      value = kld(dist(prior), dist(post))
      loss = tf.reduce_mean(tf.maximum(value, free))
    else:
      def sg(x): return tf.nest.map_structure(tf.stop_gradient, x)
      value = kld(dist(prior), dist(sg(post)))
      pri = tf.reduce_mean(value)
      pos = tf.reduce_mean(kld(dist(sg(prior)), dist(post)))
      pri, pos = tf.maximum(pri, free), tf.maximum(pos, free)
      loss = balance * pri + (1 - balance) * pos
    loss *= scale
    return loss, value


class ConvEncoder(tools.Module):
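  """Four-layer convolutional image encoder.

  Flattens all leading batch/time dimensions, applies stride-2 convolutions
  with channel widths depth * (1, 2, 4, 8), and restores the leading
  dimensions on the flattened feature vector. For 64x64x3 inputs with the
  default kernels and depth=32 this yields a 1024-dimensional embedding (an
  inference from the defaults, not stated in this file).
  """
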
  def __init__(
      self, depth=32, act=tf.nn.relu, kernels=(4, 4, 4, 4)):
    self._act = act
    self._depth = depth
    self._kernels = kernels

  def __call__(self, obs):
    kwargs = dict(strides=2, activation=self._act)
    Conv = tfkl.Conv2D
    x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:]))
    x = self.get('h1', Conv, 1 * self._depth, self._kernels[0], **kwargs)(x)
    x = self.get('h2', Conv, 2 * self._depth, self._kernels[1], **kwargs)(x)
    x = self.get('h3', Conv, 4 * self._depth, self._kernels[2], **kwargs)(x)
    x = self.get('h4', Conv, 8 * self._depth, self._kernels[3], **kwargs)(x)
    x = tf.reshape(x, [x.shape[0], np.prod(x.shape[1:])])
    shape = tf.concat([tf.shape(obs['image'])[:-3], [x.shape[-1]]], 0)
    return tf.reshape(x, shape)


class ConvDecoder(tools.Module):
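  """Transposed-convolution image decoder returning a diagonal Normal.

  Maps feature vectors to an image distribution with unit standard
  deviation. With `thin=True` the dense stem reshapes to a 1x1 spatial map,
  otherwise to 2x2 with proportionally more units.
  """
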
  def __init__(
      self, depth=32, act=tf.nn.relu, shape=(64, 64, 3), kernels=(5, 5, 6, 6),
      thin=True):
    self._act = act
    self._depth = depth
    self._shape = shape
    self._kernels = kernels
    self._thin = thin

  def __call__(self, features, dtype=None):
    kwargs = dict(strides=2, activation=self._act)
    ConvT = tfkl.Conv2DTranspose
    if self._thin:
      x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features)
      x = tf.reshape(x, [-1, 1, 1, 32 * self._depth])
    else:
      x = self.get('h1', tfkl.Dense, 128 * self._depth, None)(features)
      x = tf.reshape(x, [-1, 2, 2, 32 * self._depth])
    x = self.get('h2', ConvT, 4 * self._depth, self._kernels[0], **kwargs)(x)
    x = self.get('h3', ConvT, 2 * self._depth, self._kernels[1], **kwargs)(x)
    x = self.get('h4', ConvT, 1 * self._depth, self._kernels[2], **kwargs)(x)
    x = self.get(
        'h5', ConvT, self._shape[-1], self._kernels[3], strides=2)(x)
    mean = tf.reshape(x, tf.concat(
        [tf.shape(features)[:-1], self._shape], 0))
    if dtype:
      mean = tf.cast(mean, dtype)
    return tfd.Independent(tfd.Normal(mean, 1), len(self._shape))


class ConvDecoderMask(tools.Module):
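  """ConvDecoder variant that additionally predicts a per-pixel mask.

  The final transposed convolution outputs twice the image channels; the
  result is split into an image mean and a mask of the same shape, so the
  mask can later gate this decoder's prediction against another decoder's
  (see ConvDecoderMaskEnsemble below).
  """
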
  def __init__(
      self, depth=32, act=tf.nn.relu, shape=(64, 64, 3), kernels=(5, 5, 6, 6),
      thin=True):
    self._act = act
    self._depth = depth
    self._shape = shape
    self._kernels = kernels
    self._thin = thin

  def __call__(self, features, dtype=None):
    kwargs = dict(strides=2, activation=self._act)
    ConvT = tfkl.Conv2DTranspose
    if self._thin:
      x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features)
      x = tf.reshape(x, [-1, 1, 1, 32 * self._depth])
    else:
      x = self.get('h1', tfkl.Dense, 128 * self._depth, None)(features)
      x = tf.reshape(x, [-1, 2, 2, 32 * self._depth])
    x = self.get('h2', ConvT, 4 * self._depth, self._kernels[0], **kwargs)(x)
    x = self.get('h3', ConvT, 2 * self._depth, self._kernels[1], **kwargs)(x)
    x = self.get('h4', ConvT, 1 * self._depth, self._kernels[2], **kwargs)(x)
    x = self.get(
        'h5', ConvT, 2 * self._shape[-1], self._kernels[3], strides=2)(x)
    mean, mask = tf.split(x, [self._shape[-1], self._shape[-1]], -1)
    mean = tf.reshape(mean, tf.concat(
        [tf.shape(features)[:-1], self._shape], 0))
    mask = tf.reshape(mask, tf.concat(
        [tf.shape(features)[:-1], self._shape], 0))
    if dtype:
      mean = tf.cast(mean, dtype)
      mask = tf.cast(mask, dtype)
    return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)), mask


class ConvDecoderMaskEnsemble(tools.Module):
  """Ensembles two ConvDecoderMask decoders via a learned per-pixel mask.

  Each decoder produces a (Normal, mask) pair; the masks are combined into a
  single gate that blends the two predicted means into one image
  distribution.

  NOTE: remove pred1/pred2 from the return value for maximum performance.
  """

  def __init__(self, decoder1, decoder2):
    self._decoder1 = decoder1
    self._decoder2 = decoder2
    self._shape = decoder1._shape
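
  # A 1x1 sigmoid convolution over the concatenated mask channels yields a
  # gate in [0, 1]; the blended mean is
  #   mean = gate * mean1 + (1 - gate) * mean2.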
  def __call__(self, feat1, feat2, dtype=None):
    kwargs = dict(strides=1, activation=tf.nn.sigmoid)
    pred1, mask1 = self._decoder1(feat1, dtype)
    pred2, mask2 = self._decoder2(feat2, dtype)
    mean1 = pred1.submodules[0].loc
    mean2 = pred2.submodules[0].loc
    mask_feat = tf.concat([mask1, mask2], -1)
    mask = self.get('mask1', tfkl.Conv2D, 1, 1, **kwargs)(mask_feat)
    mask_use1 = mask
    mask_use2 = 1 - mask
    mean = (mean1 * tf.cast(mask_use1, mean1.dtype) +
            mean2 * tf.cast(mask_use2, mean2.dtype))
    return (tfd.Independent(tfd.Normal(mean, 1), len(self._shape)),
            pred1, pred2, tf.cast(mask_use1, mean1.dtype))


class DenseHead(tools.Module):
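  """MLP head producing a distribution over a tensor of shape `shape`.

  Supports Normal, unnormalized-Huber, and Bernoulli outputs; the standard
  deviation is either a fixed float `std` or predicted by an extra layer
  when `std='learned'`.
  """
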
  def __init__(
      self, shape, layers, units, act=tf.nn.elu, dist='normal', std=1.0):
    self._shape = (shape,) if isinstance(shape, int) else shape
    self._layers = layers
    self._units = units
    self._act = act
    self._dist = dist
    self._std = std

  def __call__(self, features, dtype=None):
    x = features
    for index in range(self._layers):
      x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
    mean = self.get('hmean', tfkl.Dense, np.prod(self._shape))(x)
    mean = tf.reshape(mean, tf.concat(
        [tf.shape(features)[:-1], self._shape], 0))
    if self._std == 'learned':
      std = self.get('hstd', tfkl.Dense, np.prod(self._shape))(x)
      std = tf.nn.softplus(std) + 0.01
      std = tf.reshape(std, tf.concat(
          [tf.shape(features)[:-1], self._shape], 0))
    else:
      std = self._std
    if dtype:
      mean, std = tf.cast(mean, dtype), tf.cast(std, dtype)
    if self._dist == 'normal':
      return tfd.Independent(tfd.Normal(mean, std), len(self._shape))
    if self._dist == 'huber':
      return tfd.Independent(
          tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape))
    if self._dist == 'binary':
      return tfd.Independent(tfd.Bernoulli(mean), len(self._shape))
    raise NotImplementedError(self._dist)


class ActionHead(tools.Module):
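  """Policy head mapping features to an action distribution.

  Supported `dist` values include squashed Gaussians ('tanh_normal',
  'tanh_normal_5'), plain Gaussians ('normal', 'normal_1'), a truncated
  Gaussian on [-1, 1] ('trunc_normal'), and categorical variants ('onehot',
  'onehot_gumble').
  """
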
  def __init__(
      self, size, layers, units, act=tf.nn.elu, dist='trunc_normal',
      init_std=0.0, min_std=0.1, action_disc=5, temp=0.1, outscale=0):
    # assert min_std <= 2
    self._size = size
    self._layers = layers
    self._units = units
    self._dist = dist
    self._act = act
    self._min_std = min_std
    self._init_std = init_std
    self._action_disc = action_disc
    self._temp = temp() if callable(temp) else temp
    self._outscale = outscale

  def __call__(self, features, dtype=None):
    x = features
    for index in range(self._layers):
      kw = {}
      if index == self._layers - 1 and self._outscale:
        kw['kernel_initializer'] = tf.keras.initializers.VarianceScaling(
            self._outscale)
      x = self.get(f'h{index}', tfkl.Dense, self._units, self._act, **kw)(x)
    if self._dist == 'tanh_normal':
      # https://www.desmos.com/calculator/rcmcf5jwe7
      x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      mean = tf.tanh(mean)
      std = tf.nn.softplus(std + self._init_std) + self._min_std
      dist = tfd.Normal(mean, std)
      dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
      dist = tfd.Independent(dist, 1)
      dist = tools.SampleDist(dist)
    elif self._dist == 'tanh_normal_5':
      x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      mean = 5 * tf.tanh(mean / 5)
      std = tf.nn.softplus(std + 5) + 5
      dist = tfd.Normal(mean, std)
      dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
      dist = tfd.Independent(dist, 1)
      dist = tools.SampleDist(dist)
    elif self._dist == 'normal':
      x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      std = tf.nn.softplus(std + self._init_std) + self._min_std
      dist = tfd.Normal(mean, std)
      dist = tfd.Independent(dist, 1)
    elif self._dist == 'normal_1':
      mean = self.get('hout', tfkl.Dense, self._size)(x)
      if dtype:
        mean = tf.cast(mean, dtype)
      dist = tfd.Normal(mean, 1)
      dist = tfd.Independent(dist, 1)
    elif self._dist == 'trunc_normal':
      # https://www.desmos.com/calculator/mmuvuhnyxo
      x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
      x = tf.cast(x, tf.float32)
      mean, std = tf.split(x, 2, -1)
      mean = tf.tanh(mean)
      std = 2 * tf.nn.sigmoid(std / 2) + self._min_std
      dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
      dist = tools.DtypeDist(dist, dtype)
      dist = tfd.Independent(dist, 1)
    elif self._dist == 'onehot':
      x = self.get('hout', tfkl.Dense, self._size)(x)
      x = tf.cast(x, tf.float32)
      dist = tools.OneHotDist(x, dtype=dtype)
      dist = tools.DtypeDist(dist, dtype)
    elif self._dist == 'onehot_gumble':
      x = self.get('hout', tfkl.Dense, self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      temp = self._temp
      dist = tools.GumbleDist(temp, x, dtype=dtype)
    else:
      raise NotImplementedError(self._dist)
    return dist


class GRUCell(tf.keras.layers.AbstractRNNCell):
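  """GRU cell with optional layer normalization of the gate pre-activations.

  The reset gate, candidate, and update gate come from a single fused Dense
  layer; `update_bias` (default -1) biases the update gate toward keeping
  the previous state.
  """
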
  def __init__(self, size, norm=False, act=tf.tanh, update_bias=-1, **kwargs):
    super().__init__()
    self._size = size
    self._act = act
    self._norm = norm
    self._update_bias = update_bias
    # Note: `norm` is a bool here, so `norm is not None` is always True and
    # the fused layer always uses a bias.
    self._layer = tfkl.Dense(3 * size, use_bias=norm is not None, **kwargs)
    if norm:
      self._norm = tfkl.LayerNormalization(dtype=tf.float32)

  @property
  def state_size(self):
    return self._size

  def call(self, inputs, state):
    state = state[0]  # Keras wraps the state in a list.
    parts = self._layer(tf.concat([inputs, state], -1))
    if self._norm:
      dtype = parts.dtype
      parts = tf.cast(parts, tf.float32)
      parts = self._norm(parts)
      parts = tf.cast(parts, dtype)
    reset, cand, update = tf.split(parts, 3, -1)
    reset = tf.nn.sigmoid(reset)
    cand = self._act(reset * cand)
    update = tf.nn.sigmoid(update + self._update_bias)
    output = update * cand + (1 - update) * state
    return output, [output]