# tia/DreamerV2/networks.py
# 2021-06-29 21:20:44 -04:00
#
# 466 lines
# 18 KiB
# Python

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow_probability import distributions as tfd
from tensorflow.keras.mixed_precision import experimental as prec
import tools
class RSSM(tools.Module):
    """Recurrent state-space model with a deterministic and a stochastic state.

    The deterministic part ``deter`` is advanced by a recurrent cell. The
    stochastic part ``stoch`` is either a diagonal Gaussian (``mean``/``std``)
    or, when ``discrete`` is set, ``stoch`` independent categoricals with
    ``discrete`` classes each (``logit``).

    Args:
        stoch: Number of stochastic latent variables.
        deter: Size of the deterministic recurrent state.
        hidden: Width of the dense layers around the recurrent cell.
        layers_input: Number of dense layers before the recurrent cell.
        layers_output: Number of dense layers after the cell (prior) and on
            ``[deter, embed]`` (posterior, when ``shared`` is false).
        shared: If true, the posterior reuses the prior network and receives
            the embedding as an extra input (zeros when computing the prior).
        discrete: Falsy for Gaussian latents, otherwise the number of classes
            per stochastic variable.
        act: Activation of the hidden dense layers.
        mean_act: Gaussian mean transform: 'none' or 'tanh5'.
        std_act: Gaussian std transform: 'softplus', 'abs', 'sigmoid' or
            'sigmoid2'; ``min_std`` is added afterwards.
        min_std: Value added to the transformed Gaussian std.
        cell: 'gru' or 'keras' for ``tfkl.GRUCell``; 'gru_layer_norm' for
            ``GRUCell(norm=True)``.

    Raises:
        NotImplementedError: For an unknown ``cell``.
    """

    def __init__(
            self, stoch=30, deter=200, hidden=200, layers_input=1,
            layers_output=1, shared=False, discrete=False, act=tf.nn.elu,
            mean_act='none', std_act='softplus', min_std=0.1, cell='keras'):
        super().__init__()
        self._stoch = stoch
        self._deter = deter
        self._hidden = hidden
        self._min_std = min_std
        self._layers_input = layers_input
        self._layers_output = layers_output
        self._shared = shared
        self._discrete = discrete
        self._act = act
        self._mean_act = mean_act
        self._std_act = std_act
        # Width of the observation embedding; recorded by the first obs_step
        # so the shared prior can feed zeros of the same width.
        self._embed = None
        # 'keras' is the default value of `cell`; it previously fell through
        # to NotImplementedError, making RSSM() unconstructible with defaults.
        if cell in ('gru', 'keras'):
            self._cell = tfkl.GRUCell(self._deter)
        elif cell == 'gru_layer_norm':
            self._cell = GRUCell(self._deter, norm=True)
        else:
            raise NotImplementedError(cell)

    @staticmethod
    def _swap(x):
        """Exchange the first two axes of ``x`` (batch-major <-> time-major)."""
        return tf.transpose(x, [1, 0] + list(range(2, len(x.shape))))

    def _flatten_stoch(self, stoch):
        """Merge the trailing ``[stoch, discrete]`` axes of a discrete latent.

        Gaussian latents are returned unchanged.
        """
        if self._discrete:
            shape = stoch.shape[:-2] + [self._stoch * self._discrete]
            stoch = tf.reshape(stoch, shape)
        return stoch

    def initial(self, batch_size):
        """Return an all-zero state dict for ``batch_size`` sequences.

        Tensors use the global mixed-precision compute dtype. Discrete models
        carry ``logit``/``stoch`` shaped ``[batch, stoch, discrete]``; Gaussian
        models carry ``mean``/``std``/``stoch`` shaped ``[batch, stoch]``.
        ``deter`` comes from the recurrent cell's ``get_initial_state``.
        """
        dtype = prec.global_policy().compute_dtype
        if self._discrete:
            state = dict(
                logit=tf.zeros(
                    [batch_size, self._stoch, self._discrete], dtype),
                stoch=tf.zeros(
                    [batch_size, self._stoch, self._discrete], dtype),
                deter=self._cell.get_initial_state(None, batch_size, dtype))
        else:
            state = dict(
                mean=tf.zeros([batch_size, self._stoch], dtype),
                std=tf.zeros([batch_size, self._stoch], dtype),
                stoch=tf.zeros([batch_size, self._stoch], dtype),
                deter=self._cell.get_initial_state(None, batch_size, dtype))
        return state

    @tf.function
    def observe(self, embed, action, state=None):
        """Roll posterior and prior over a batch-major sequence.

        ``embed`` and ``action`` have the batch on axis 0 (``action``'s first
        dim sizes the default initial state) and are swapped to time-major
        before ``tools.static_scan``. Returns ``(post, prior)`` dicts whose
        tensors are swapped back to batch-major.
        """
        swap = self._swap
        if state is None:
            state = self.initial(tf.shape(action)[0])
        embed, action = swap(embed), swap(action)
        post, prior = tools.static_scan(
            lambda prev, inputs: self.obs_step(prev[0], *inputs),
            (action, embed), (state, state))
        post = {k: swap(v) for k, v in post.items()}
        prior = {k: swap(v) for k, v in prior.items()}
        return post, prior

    @tf.function
    def imagine(self, action, state=None):
        """Roll the prior forward under a batch-major ``action`` sequence.

        Returns a dict of batch-major prior tensors.
        """
        swap = self._swap
        if state is None:
            state = self.initial(tf.shape(action)[0])
        assert isinstance(state, dict), state
        action = swap(action)
        prior = tools.static_scan(self.img_step, action, state)
        prior = {k: swap(v) for k, v in prior.items()}
        return prior

    def get_feat(self, state):
        """Concatenate the (flattened) stochastic and deterministic state."""
        stoch = self._flatten_stoch(state['stoch'])
        return tf.concat([stoch, state['deter']], -1)

    def get_dist(self, state, dtype=None):
        """Return the latent distribution described by ``state``.

        Discrete: an ``Independent(OneHotDist)`` over float32 logits, wrapped
        in ``tools.DtypeDist`` unless ``dtype`` is float32 (``dtype`` defaults
        to the logits' dtype). Gaussian: a ``MultivariateNormalDiag``, with
        mean/std cast to ``dtype`` when it is given.
        """
        if self._discrete:
            logit = state['logit']
            logit = tf.cast(logit, tf.float32)
            dist = tfd.Independent(tools.OneHotDist(logit), 1)
            if dtype != tf.float32:
                dist = tools.DtypeDist(dist, dtype or state['logit'].dtype)
        else:
            mean, std = state['mean'], state['std']
            if dtype:
                mean = tf.cast(mean, dtype)
                std = tf.cast(std, dtype)
            dist = tfd.MultivariateNormalDiag(mean, std)
        return dist

    @tf.function
    def obs_step(self, prev_state, prev_action, embed, sample=True):
        """One filtering step; returns ``(post, prior)`` state dicts.

        The posterior shares ``deter`` with the prior. The latent is sampled
        when ``sample`` is true, otherwise the distribution mode is used.
        """
        if not self._embed:
            self._embed = embed.shape[-1]
        prior = self.img_step(prev_state, prev_action, None, sample)
        if self._shared:
            post = self.img_step(prev_state, prev_action, embed, sample)
        else:
            x = tf.concat([prior['deter'], embed], -1)
            for i in range(self._layers_output):
                x = self.get(
                    f'obi{i}', tfkl.Dense, self._hidden, self._act)(x)
            stats = self._suff_stats_layer('obs', x)
            if sample:
                stoch = self.get_dist(stats).sample()
            else:
                stoch = self.get_dist(stats).mode()
            post = {'stoch': stoch, 'deter': prior['deter'], **stats}
        return post, prior

    @tf.function
    def img_step(self, prev_state, prev_action, embed=None, sample=True):
        """One prior (imagination) step; returns the next state dict.

        In ``shared`` mode ``embed`` is appended to the cell input, with
        zeros of the recorded embedding width when ``embed`` is None.
        """
        prev_stoch = self._flatten_stoch(prev_state['stoch'])
        if self._shared:
            if embed is None:
                shape = prev_action.shape[:-1] + [self._embed]
                embed = tf.zeros(shape, prev_action.dtype)
            x = tf.concat([prev_stoch, prev_action, embed], -1)
        else:
            x = tf.concat([prev_stoch, prev_action], -1)
        for i in range(self._layers_input):
            x = self.get(f'ini{i}', tfkl.Dense, self._hidden, self._act)(x)
        x, deter = self._cell(x, [prev_state['deter']])
        deter = deter[0]  # Keras wraps the state in a list.
        for i in range(self._layers_output):
            x = self.get(f'imo{i}', tfkl.Dense, self._hidden, self._act)(x)
        stats = self._suff_stats_layer('ims', x)
        if sample:
            stoch = self.get_dist(stats).sample()
        else:
            stoch = self.get_dist(stats).mode()
        prior = {'stoch': stoch, 'deter': deter, **stats}
        return prior

    def _suff_stats_layer(self, name, x):
        """Project ``x`` to distribution parameters via dense layer ``name``.

        Returns ``{'logit'}`` shaped ``[..., stoch, discrete]`` for discrete
        latents, else ``{'mean', 'std'}`` with the configured transforms.
        """
        if self._discrete:
            x = self.get(name, tfkl.Dense, self._stoch *
                         self._discrete, None)(x)
            logit = tf.reshape(
                x, x.shape[:-1] + [self._stoch, self._discrete])
            return {'logit': logit}
        x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x)
        mean, std = tf.split(x, 2, -1)
        mean = {
            'none': lambda: mean,
            'tanh5': lambda: 5.0 * tf.math.tanh(mean / 5.0),
        }[self._mean_act]()
        std = {
            'softplus': lambda: tf.nn.softplus(std),
            'abs': lambda: tf.math.abs(std + 1),
            'sigmoid': lambda: tf.nn.sigmoid(std),
            'sigmoid2': lambda: 2 * tf.nn.sigmoid(std / 2),
        }[self._std_act]()
        std = std + self._min_std
        return {'mean': mean, 'std': std}

    def kl_loss(self, post, prior, balance, free, scale):
        """Return ``(loss, value)`` for the posterior/prior KL term.

        ``value`` is the element-wise ``kl_divergence(prior, post)``. With
        ``balance == 0.5`` the free-nats floor ``free`` is applied per element
        before averaging. Otherwise the prior term (posterior stop-gradiented)
        is weighted by ``balance`` and the posterior term (prior
        stop-gradiented) by ``1 - balance``, each floored at ``free`` after
        averaging. The loss is multiplied by ``scale``.

        NOTE(review): the divergence is computed as KL(prior || post); the
        DreamerV2 paper writes the objective as KL(post || prior). Confirm
        which direction is intended before changing it.
        """
        kld = tfd.kl_divergence

        def dist(x):
            return self.get_dist(x, tf.float32)

        if balance == 0.5:
            value = kld(dist(prior), dist(post))
            loss = tf.reduce_mean(tf.maximum(value, free))
        else:
            def sg(x):
                return tf.nest.map_structure(tf.stop_gradient, x)
            value = kld(dist(prior), dist(sg(post)))
            pri = tf.reduce_mean(value)
            pos = tf.reduce_mean(kld(dist(sg(prior)), dist(post)))
            pri, pos = tf.maximum(pri, free), tf.maximum(pos, free)
            loss = balance * pri + (1 - balance) * pos
        loss *= scale
        return loss, value
class ConvEncoder(tools.Module):
    """Four stride-2 ``Conv2D`` layers mapping ``obs['image']`` to a flat vector.

    Args:
        depth: Channel multiplier; the layers have 1x, 2x, 4x and 8x ``depth``
            output channels.
        act: Activation applied after every convolution.
        kernels: Kernel sizes of the four convolutions.
    """

    def __init__(
            self, depth=32, act=tf.nn.relu, kernels=(4, 4, 4, 4)):
        # Initialize the tf.Module base as RSSM does; without it the module's
        # `name` / `name_scope` attributes are never set.
        super().__init__()
        self._act = act
        self._depth = depth
        self._kernels = kernels

    def __call__(self, obs):
        """Encode ``obs['image']`` shaped ``[*lead_dims, H, W, C]``.

        Returns a tensor shaped ``[*lead_dims, features]``.
        """
        kwargs = dict(strides=2, activation=self._act)
        image = obs['image']
        # Fold all leading dims into a single batch dim for Conv2D.
        x = tf.reshape(image, (-1,) + tuple(image.shape[-3:]))
        # Layers h1..h4 with 1x, 2x, 4x, 8x depth channels.
        for index, mult in enumerate((1, 2, 4, 8)):
            x = self.get(f'h{index + 1}', tfkl.Conv2D, mult * self._depth,
                         self._kernels[index], **kwargs)(x)
        # Flatten the spatial dims. Use -1 for the batch dim: x.shape[0] is
        # None when the leading dims are not statically known (e.g. while
        # tracing a tf.function), and tf.reshape rejects None.
        x = tf.reshape(x, [-1, int(np.prod(x.shape[1:]))])
        shape = tf.concat([tf.shape(image)[:-3], [x.shape[-1]]], 0)
        return tf.reshape(x, shape)
class ConvDecoder(tools.Module):
    """Transposed-conv decoder giving a unit-variance Normal over images.

    Args:
        depth: Channel multiplier for the hidden layers.
        act: Activation of the hidden transposed convolutions.
        shape: Event shape ``(H, W, C)`` of the decoded image.
        kernels: Kernel sizes of the four transposed convolutions.
        thin: If true, project features to a ``1x1x(32*depth)`` map;
            otherwise to a ``2x2x(32*depth)`` map.
    """

    def __init__(
            self, depth=32, act=tf.nn.relu, shape=(64, 64, 3),
            kernels=(5, 5, 6, 6), thin=True):
        # Initialize the tf.Module base as RSSM does.
        super().__init__()
        self._act = act
        self._depth = depth
        self._shape = shape
        self._kernels = kernels
        self._thin = thin

    def __call__(self, features, dtype=None):
        """Decode ``features`` shaped ``[*lead_dims, F]``.

        Returns ``Independent(Normal(mean, 1))`` with ``len(shape)`` event
        dims; ``mean`` is cast to ``dtype`` when given. The final reshape
        requires the transposed-conv output to match ``shape``.
        """
        kwargs = dict(strides=2, activation=self._act)
        ConvT = tfkl.Conv2DTranspose
        if self._thin:
            x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features)
            x = tf.reshape(x, [-1, 1, 1, 32 * self._depth])
        else:
            x = self.get('h1', tfkl.Dense, 128 * self._depth, None)(features)
            x = tf.reshape(x, [-1, 2, 2, 32 * self._depth])
        x = self.get('h2', ConvT, 4 * self._depth,
                     self._kernels[0], **kwargs)(x)
        x = self.get('h3', ConvT, 2 * self._depth,
                     self._kernels[1], **kwargs)(x)
        x = self.get('h4', ConvT, 1 * self._depth,
                     self._kernels[2], **kwargs)(x)
        # Output layer: no activation, one channel per image channel.
        x = self.get(
            'h5', ConvT, self._shape[-1], self._kernels[3], strides=2)(x)
        mean = tf.reshape(x, tf.concat(
            [tf.shape(features)[:-1], self._shape], 0))
        if dtype:
            mean = tf.cast(mean, dtype)
        return tfd.Independent(tfd.Normal(mean, 1), len(self._shape))
class ConvDecoderMask(tools.Module):
    """Transposed-conv decoder that also emits a raw per-pixel mask map.

    Same architecture as a plain conv decoder, but the last layer outputs
    ``2 * C`` channels, split into the image mean and a mask tensor.

    Args:
        depth: Channel multiplier for the hidden layers.
        act: Activation of the hidden transposed convolutions.
        shape: Event shape ``(H, W, C)`` of the decoded image.
        kernels: Kernel sizes of the four transposed convolutions.
        thin: If true, project features to a ``1x1x(32*depth)`` map;
            otherwise to a ``2x2x(32*depth)`` map.
    """

    def __init__(
            self, depth=32, act=tf.nn.relu, shape=(64, 64, 3),
            kernels=(5, 5, 6, 6), thin=True):
        # Initialize the tf.Module base as RSSM does.
        super().__init__()
        self._act = act
        self._depth = depth
        self._shape = shape
        self._kernels = kernels
        self._thin = thin

    def __call__(self, features, dtype=None):
        """Decode ``features`` shaped ``[*lead_dims, F]``.

        Returns ``(Independent(Normal(mean, 1)), mask)``; ``mean`` and
        ``mask`` are both shaped ``[*lead_dims, *shape]`` and cast to
        ``dtype`` when given. ``mask`` has no activation applied.
        """
        kwargs = dict(strides=2, activation=self._act)
        ConvT = tfkl.Conv2DTranspose
        if self._thin:
            x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features)
            x = tf.reshape(x, [-1, 1, 1, 32 * self._depth])
        else:
            x = self.get('h1', tfkl.Dense, 128 * self._depth, None)(features)
            x = tf.reshape(x, [-1, 2, 2, 32 * self._depth])
        x = self.get('h2', ConvT, 4 * self._depth,
                     self._kernels[0], **kwargs)(x)
        x = self.get('h3', ConvT, 2 * self._depth,
                     self._kernels[1], **kwargs)(x)
        x = self.get('h4', ConvT, 1 * self._depth,
                     self._kernels[2], **kwargs)(x)
        channels = self._shape[-1]
        # Output layer: no activation; first half is the mean, second the mask.
        x = self.get(
            'h5', ConvT, 2 * channels, self._kernels[3], strides=2)(x)
        mean, mask = tf.split(x, [channels, channels], -1)
        out_shape = tf.concat([tf.shape(features)[:-1], self._shape], 0)
        mean = tf.reshape(mean, out_shape)
        mask = tf.reshape(mask, out_shape)
        if dtype:
            mean = tf.cast(mean, dtype)
            mask = tf.cast(mask, dtype)
        return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)), mask
class ConvDecoderMaskEnsemble(tools.Module):
    """Blend two ``<Normal, mask>`` conv decoders with a learned mask.

    Each decoder returns ``(Independent(Normal), mask)``. The two masks are
    concatenated and fed to a 1x1 sigmoid ``Conv2D`` producing ``m``; the
    blended mean is ``m * mean1 + (1 - m) * mean2``.

    NOTE: remove pred1/pred2 (from the return value) for maximum performance.
    """

    def __init__(self, decoder1, decoder2):
        # Initialize the tf.Module base as RSSM does.
        super().__init__()
        self._decoder1 = decoder1
        self._decoder2 = decoder2
        self._shape = decoder1._shape

    def __call__(self, feat1, feat2, dtype=None):
        """Decode ``feat1``/``feat2`` and blend the two means.

        Returns ``(dist, pred1, pred2, mask)``: ``dist`` is a unit-variance
        Normal over the blended mean, ``pred1``/``pred2`` are the individual
        decoder distributions, and ``mask`` is the weight given to decoder 1,
        cast to decoder 1's mean dtype.
        """
        kwargs = dict(strides=1, activation=tf.nn.sigmoid)
        pred1, mask1 = self._decoder1(feat1, dtype)
        pred2, mask2 = self._decoder2(feat2, dtype)
        # Use the public Independent.distribution accessor instead of relying
        # on the traversal order of tf.Module.submodules.
        mean1 = pred1.distribution.loc
        mean2 = pred2.distribution.loc
        mask_feat = tf.concat([mask1, mask2], -1)
        mask = self.get('mask1', tfkl.Conv2D, 1, 1, **kwargs)(mask_feat)
        mask_use1 = tf.cast(mask, mean1.dtype)
        mask_use2 = tf.cast(1 - mask, mean2.dtype)
        mean = mean1 * mask_use1 + mean2 * mask_use2
        return (tfd.Independent(tfd.Normal(mean, 1), len(self._shape)),
                pred1, pred2, mask_use1)
class DenseHead(tools.Module):
    """MLP head mapping features to a distribution with event shape ``shape``.

    Args:
        shape: Event shape; an int is treated as ``(shape,)``.
        layers: Number of hidden dense layers.
        units: Width of each hidden layer.
        act: Hidden-layer activation.
        dist: 'normal', 'huber' (``tools.UnnormalizedHuber`` with threshold
            1.0) or 'binary' (Bernoulli using the head output as logits).
        std: A fixed standard deviation, or 'learned' to predict it with a
            separate dense layer (softplus + 0.01). Unused for 'binary'.

    Raises:
        NotImplementedError: From ``__call__`` for an unknown ``dist``.
    """

    def __init__(
            self, shape, layers, units, act=tf.nn.elu, dist='normal',
            std=1.0):
        # Initialize the tf.Module base as RSSM does.
        super().__init__()
        self._shape = (shape,) if isinstance(shape, int) else shape
        self._layers = layers
        self._units = units
        self._act = act
        self._dist = dist
        self._std = std

    def __call__(self, features, dtype=None):
        """Return the output distribution for ``features`` (``[*lead, F]``).

        ``mean`` (and a learned ``std``) are reshaped to ``[*lead, *shape]``
        and cast to ``dtype`` when given.
        """
        x = features
        for index in range(self._layers):
            x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
        out_shape = tf.concat([tf.shape(features)[:-1], self._shape], 0)
        size = int(np.prod(self._shape))
        mean = self.get('hmean', tfkl.Dense, size)(x)
        mean = tf.reshape(mean, out_shape)
        if self._std == 'learned':
            std = self.get('hstd', tfkl.Dense, size)(x)
            std = tf.nn.softplus(std) + 0.01
            std = tf.reshape(std, out_shape)
        else:
            std = self._std
        if dtype:
            mean, std = tf.cast(mean, dtype), tf.cast(std, dtype)
        event_ndims = len(self._shape)
        if self._dist == 'normal':
            return tfd.Independent(tfd.Normal(mean, std), event_ndims)
        if self._dist == 'huber':
            return tfd.Independent(
                tools.UnnormalizedHuber(mean, std, 1.0), event_ndims)
        if self._dist == 'binary':
            return tfd.Independent(tfd.Bernoulli(logits=mean), event_ndims)
        raise NotImplementedError(self._dist)
class ActionHead(tools.Module):
    """MLP policy head returning an action distribution.

    Args:
        size: Action dimensionality (or number of discrete actions).
        layers: Number of hidden dense layers.
        units: Width of each hidden layer.
        act: Hidden-layer activation.
        dist: One of 'tanh_normal', 'tanh_normal_5', 'normal', 'normal_1',
            'trunc_normal', 'onehot', 'onehot_gumble'.
        init_std: Offset added to the raw std before softplus
            ('tanh_normal', 'normal').
        min_std: Value added to the std ('tanh_normal', 'normal',
            'trunc_normal').
        action_disc: Stored on the instance; not used by ``__call__``.
        temp: Temperature for 'onehot_gumble'; a callable is evaluated once
            here and its result stored.
        outscale: If nonzero, scale of a ``VarianceScaling`` kernel
            initializer for the last hidden layer.

    Raises:
        NotImplementedError: From ``__call__`` for an unknown ``dist``.
    """

    def __init__(
            self, size, layers, units, act=tf.nn.elu, dist='trunc_normal',
            init_std=0.0, min_std=0.1, action_disc=5, temp=0.1, outscale=0):
        # Initialize the tf.Module base as RSSM does.
        super().__init__()
        self._size = size
        self._layers = layers
        self._units = units
        self._dist = dist
        self._act = act
        self._min_std = min_std
        self._init_std = init_std
        self._action_disc = action_disc
        self._temp = temp() if callable(temp) else temp
        self._outscale = outscale

    def __call__(self, features, dtype=None):
        """Return the action distribution for ``features``."""
        x = features
        for index in range(self._layers):
            kw = {}
            # NOTE(review): `outscale` affects the last *hidden* layer, not
            # the 'hout' output layer below — confirm this is intended.
            if index == self._layers - 1 and self._outscale:
                kw['kernel_initializer'] = (
                    tf.keras.initializers.VarianceScaling(self._outscale))
            x = self.get(f'h{index}', tfkl.Dense,
                         self._units, self._act, **kw)(x)
        if self._dist == 'tanh_normal':
            # https://www.desmos.com/calculator/rcmcf5jwe7
            x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
            if dtype:
                x = tf.cast(x, dtype)
            mean, std = tf.split(x, 2, -1)
            mean = tf.tanh(mean)
            std = tf.nn.softplus(std + self._init_std) + self._min_std
            dist = tfd.Normal(mean, std)
            dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
            dist = tfd.Independent(dist, 1)
            dist = tools.SampleDist(dist)
        elif self._dist == 'tanh_normal_5':
            x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
            if dtype:
                x = tf.cast(x, dtype)
            mean, std = tf.split(x, 2, -1)
            mean = 5 * tf.tanh(mean / 5)
            # NOTE(review): std is floored at 5 here (softplus(.) + 5) and
            # ignores init_std/min_std, unlike 'tanh_normal' — confirm.
            std = tf.nn.softplus(std + 5) + 5
            dist = tfd.Normal(mean, std)
            dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
            dist = tfd.Independent(dist, 1)
            dist = tools.SampleDist(dist)
        elif self._dist == 'normal':
            x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
            if dtype:
                x = tf.cast(x, dtype)
            mean, std = tf.split(x, 2, -1)
            std = tf.nn.softplus(std + self._init_std) + self._min_std
            dist = tfd.Normal(mean, std)
            dist = tfd.Independent(dist, 1)
        elif self._dist == 'normal_1':
            mean = self.get('hout', tfkl.Dense, self._size)(x)
            if dtype:
                mean = tf.cast(mean, dtype)
            dist = tfd.Normal(mean, 1)
            dist = tfd.Independent(dist, 1)
        elif self._dist == 'trunc_normal':
            # https://www.desmos.com/calculator/mmuvuhnyxo
            x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
            x = tf.cast(x, tf.float32)
            mean, std = tf.split(x, 2, -1)
            mean = tf.tanh(mean)
            std = 2 * tf.nn.sigmoid(std / 2) + self._min_std
            dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
            dist = tools.DtypeDist(dist, dtype)
            dist = tfd.Independent(dist, 1)
        elif self._dist == 'onehot':
            x = self.get('hout', tfkl.Dense, self._size)(x)
            x = tf.cast(x, tf.float32)
            dist = tools.OneHotDist(x, dtype=dtype)
            dist = tools.DtypeDist(dist, dtype)
        elif self._dist == 'onehot_gumble':
            x = self.get('hout', tfkl.Dense, self._size)(x)
            if dtype:
                x = tf.cast(x, dtype)
            temp = self._temp
            dist = tools.GumbleDist(temp, x, dtype=dtype)
        else:
            raise NotImplementedError(self._dist)
        return dist
class GRUCell(tf.keras.layers.AbstractRNNCell):
    """GRU cell computing all gate pre-activations with one dense layer.

    The dense layer sees ``[inputs, state]``; its output is optionally
    layer-normalized in float32 and then split into reset, candidate and
    update parts.

    Args:
        size: Size of the hidden state (also the output size).
        norm: If truthy, apply ``LayerNormalization`` to the pre-activations.
        act: Activation for the candidate state.
        update_bias: Constant added to the update-gate logits before sigmoid.
        **kwargs: Forwarded to the internal ``tfkl.Dense`` layer.
    """

    def __init__(self, size, norm=False, act=tf.tanh, update_bias=-1, **kwargs):
        super().__init__()
        self._size = size
        self._act = act
        self._norm = norm
        self._update_bias = update_bias
        # NOTE(review): `norm is not None` is True for both True and False,
        # so the dense layer has a bias unless norm=None is passed explicitly.
        # Kept as-is so the layer's variables stay unchanged.
        self._layer = tfkl.Dense(3 * size, use_bias=norm is not None, **kwargs)
        if norm:
            # The flag is replaced by the layer itself; `call` only tests its
            # truthiness.
            self._norm = tfkl.LayerNormalization(dtype=tf.float32)

    @property
    def state_size(self):
        """Size of the single recurrent state tensor."""
        return self._size

    def call(self, inputs, state):
        """Advance one step; returns ``(new_state, [new_state])``."""
        prev = state[0]  # Keras passes the state wrapped in a list.
        gates = self._layer(tf.concat([inputs, prev], -1))
        if self._norm:
            # Normalize in float32 (the norm layer is float32), then cast back.
            orig_dtype = gates.dtype
            gates = tf.cast(self._norm(tf.cast(gates, tf.float32)), orig_dtype)
        reset_logit, cand_logit, update_logit = tf.split(gates, 3, -1)
        reset_gate = tf.nn.sigmoid(reset_logit)
        candidate = self._act(reset_gate * cand_logit)
        update_gate = tf.nn.sigmoid(update_logit + self._update_bias)
        new_state = update_gate * candidate + (1 - update_gate) * prev
        return new_state, [new_state]