tia/Dreamer/models.py
2023-07-17 10:48:01 +02:00

298 lines
11 KiB
Python

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow_probability import distributions as tfd
from tensorflow.keras.mixed_precision import experimental as prec
import tools
class RSSM(tools.Module):
    """Recurrent state-space model: GRU deterministic path + Gaussian latent.

    Every state handled by this module is a dict with the keys 'mean', 'std',
    'stoch' (stochastic latent, `stoch` wide) and 'deter' (GRU state, `deter`
    wide).
    """

    def __init__(self, stoch=30, deter=200, hidden=200, act=tf.nn.elu):
        super().__init__()
        self._activation = act
        self._stoch_size = stoch
        self._deter_size = deter
        self._hidden_size = hidden
        self._cell = tfkl.GRUCell(self._deter_size)

    def initial(self, batch_size):
        """Return an all-zero state for `batch_size` rows in the compute dtype."""
        dtype = prec.global_policy().compute_dtype
        zeros = lambda: tf.zeros([batch_size, self._stoch_size], dtype)
        return dict(
            mean=zeros(),
            std=zeros(),
            stoch=zeros(),
            deter=self._cell.get_initial_state(None, batch_size, dtype))

    @staticmethod
    def _swap_batch_time(tensor):
        # Swap the first two axes of a rank-3 tensor (perm [1, 0, 2]).
        return tf.transpose(tensor, [1, 0, 2])

    @tf.function
    def observe(self, embed, action, state=None):
        """Filter a sequence, returning (posterior, prior) state dicts.

        Axis 0 of `action` is the batch axis (the default initial state is
        sized from it); axis 1 is swapped to the front before
        `tools.static_scan` and swapped back afterwards. Inputs must be rank 3
        because of the fixed transpose permutation.
        """
        if state is None:
            state = self.initial(tf.shape(action)[0])
        step = lambda prev, inputs: self.obs_step(prev[0], *inputs)
        post, prior = tools.static_scan(
            step,
            (self._swap_batch_time(action), self._swap_batch_time(embed)),
            (state, state))
        post = {key: self._swap_batch_time(val) for key, val in post.items()}
        prior = {key: self._swap_batch_time(val) for key, val in prior.items()}
        return post, prior

    @tf.function
    def imagine(self, action, state=None):
        """Roll the prior forward over an action sequence (no observations)."""
        if state is None:
            state = self.initial(tf.shape(action)[0])
        assert isinstance(state, dict), state
        prior = tools.static_scan(
            self.img_step, self._swap_batch_time(action), state)
        return {key: self._swap_batch_time(val) for key, val in prior.items()}

    def get_feat(self, state):
        """Concatenate stochastic and deterministic parts on the last axis."""
        return tf.concat([state['stoch'], state['deter']], -1)

    def get_dist(self, state):
        """Diagonal Gaussian parameterised by the state's 'mean' and 'std'."""
        return tfd.MultivariateNormalDiag(state['mean'], state['std'])

    def _sample_latent(self, stats):
        # Split raw layer output into mean / pre-softplus std, keep std above
        # 0.1 and draw one sample of the stochastic latent.
        mean, std = tf.split(stats, 2, -1)
        std = tf.nn.softplus(std) + 0.1
        stoch = self.get_dist({'mean': mean, 'std': std}).sample()
        return {'mean': mean, 'std': std, 'stoch': stoch}

    @tf.function
    def obs_step(self, prev_state, prev_action, embed):
        """One posterior step: prior transition, then correction by `embed`."""
        prior = self.img_step(prev_state, prev_action)
        hidden = tf.concat([prior['deter'], embed], -1)
        hidden = self.get('obs1', tfkl.Dense, self._hidden_size,
                          self._activation)(hidden)
        stats = self.get('obs2', tfkl.Dense, 2 * self._stoch_size,
                         None)(hidden)
        post = self._sample_latent(stats)
        # The posterior shares the deterministic state of the prior.
        post['deter'] = prior['deter']
        return post, prior

    @tf.function
    def img_step(self, prev_state, prev_action):
        """One prior step: advance the GRU and sample a new stochastic latent."""
        hidden = tf.concat([prev_state['stoch'], prev_action], -1)
        hidden = self.get('img1', tfkl.Dense, self._hidden_size,
                          self._activation)(hidden)
        hidden, new_states = self._cell(hidden, [prev_state['deter']])
        deter = new_states[0]  # Keras wraps the state in a list.
        hidden = self.get('img2', tfkl.Dense, self._hidden_size,
                          self._activation)(hidden)
        stats = self.get('img3', tfkl.Dense, 2 * self._stoch_size,
                         None)(hidden)
        prior = self._sample_latent(stats)
        prior['deter'] = deter
        return prior
class ConvEncoder(tools.Module):
    """Four stride-2 Conv2D layers that flatten an image to a feature vector.

    Args:
        depth: base channel count; layers use 1x, 2x, 4x and 8x this value.
        act: activation applied after every conv layer.
        image_size: spatial size of the input; one of 64, 32 or 84.

    Raises:
        NotImplementedError: for any other `image_size`.
    """

    def __init__(self, depth=32, act=tf.nn.relu, image_size=64):
        self._act = act
        self._depth = depth
        self._image_size = image_size
        # (image size, flattened width in units of depth, kernel per layer)
        layouts = (
            (64, 32, [4, 4, 4, 4]),
            (32, 8, [3, 3, 3, 3]),
            (84, 72, [4, 4, 4, 4]),
        )
        for size, width, kernels in layouts:
            if image_size == size:
                break
        else:
            raise NotImplementedError
        self._outdim = width * self._depth
        self._kernel_sizes = kernels

    def __call__(self, obs):
        """Encode `obs['image']`, keeping all leading (non-image) axes.

        The last three axes of the image are treated as one image; the output
        has the image's leading axes followed by a single axis of size
        `self._outdim`.
        """
        image = obs['image']
        hidden = tf.reshape(image, (-1,) + tuple(image.shape[-3:]))
        widths = (1, 2, 4, 8)
        for index, (mult, kernel) in enumerate(
                zip(widths, self._kernel_sizes), start=1):
            hidden = self.get(f'h{index}', tfkl.Conv2D, mult * self._depth,
                              kernel, strides=2, activation=self._act)(hidden)
        leading = tf.shape(image)[:-3]
        return tf.reshape(hidden, tf.concat([leading, [self._outdim]], 0))
class ConvDecoder(tools.Module):
    """Decode a feature vector into a unit-variance Normal over images.

    Args:
        depth: base channel count for the transposed-conv stack.
        act: activation for the hidden transposed-conv layers.
        shape: output image shape; `shape[0]` must be 64, 32 or 84.

    Raises:
        NotImplementedError: for any other `shape[0]`.
    """

    def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)):
        self._act = act
        self._depth = depth
        self._shape = shape
        # (height, dense width in units of depth, kernel per transposed conv)
        layouts = (
            (64, 32, [5, 5, 6, 6]),
            (32, 8, [3, 3, 3, 4]),
            (84, 72, [7, 6, 6, 6]),
        )
        for height, width, kernels in layouts:
            if shape[0] == height:
                break
        else:
            raise NotImplementedError
        self._outdim = width * self._depth
        self._kernel_sizes = kernels

    def __call__(self, features):
        """Return Independent(Normal(mean, 1)) with event shape `self._shape`."""
        hidden = self.get('h1', tfkl.Dense, self._outdim, None)(features)
        hidden = tf.reshape(hidden, [-1, 1, 1, self._outdim])
        # Three activated upsampling stages with shrinking channel counts.
        stages = zip(('h2', 'h3', 'h4'), (4, 2, 1), self._kernel_sizes[:3])
        for name, mult, kernel in stages:
            hidden = self.get(name, tfkl.Conv2DTranspose, mult * self._depth,
                              kernel, strides=2, activation=self._act)(hidden)
        # Final stage has no activation and emits the image channels.
        hidden = self.get('h5', tfkl.Conv2DTranspose, self._shape[-1],
                          self._kernel_sizes[3], strides=2)(hidden)
        target = tf.concat([tf.shape(features)[:-1], self._shape], 0)
        mean = tf.reshape(hidden, target)
        return tfd.Independent(tfd.Normal(mean, 1), len(self._shape))
class ConvDecoderMask(tools.Module):
    """Image decoder that also returns a raw per-pixel mask map.

    The architecture matches a plain conv decoder, except that the last
    transposed conv emits ``2 * C`` channels (``C = shape[-1]``). The first
    ``C`` channels are the image mean. The last ``C`` channels are returned,
    without any activation, as the mask.

    Args:
        depth: base channel count for the transposed-conv stack.
        act: activation for the hidden transposed-conv layers.
        shape: output image shape; `shape[0]` must be 64, 32 or 84.

    Raises:
        NotImplementedError: for any other `shape[0]`.
    """

    def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)):
        self._act = act
        self._depth = depth
        self._shape = shape
        if shape[0] == 64:
            self._outdim = 32 * self._depth
            self._kernel_sizes = [5, 5, 6, 6]
        elif shape[0] == 32:
            self._outdim = 8 * self._depth
            self._kernel_sizes = [3, 3, 3, 4]
        elif shape[0] == 84:
            self._outdim = 72 * self._depth
            self._kernel_sizes = [7, 6, 6, 6]
        else:
            raise NotImplementedError

    def __call__(self, features):
        """Return ``(Independent(Normal(mean, 1)), mask)``.

        Both `mean` and `mask` are reshaped to the leading axes of `features`
        followed by `self._shape`.
        """
        kwargs = dict(strides=2, activation=self._act)
        channels = self._shape[-1]
        x = self.get('h1', tfkl.Dense, self._outdim, None)(features)
        x = tf.reshape(x, [-1, 1, 1, self._outdim])
        x = self.get('h2', tfkl.Conv2DTranspose,
                     4 * self._depth, self._kernel_sizes[0], **kwargs)(x)
        x = self.get('h3', tfkl.Conv2DTranspose,
                     2 * self._depth, self._kernel_sizes[1], **kwargs)(x)
        x = self.get('h4', tfkl.Conv2DTranspose,
                     1 * self._depth, self._kernel_sizes[2], **kwargs)(x)
        # Mean and mask are stacked on the channel axis. Each has `channels`
        # channels, so that both can be reshaped to `self._shape` below.
        x = self.get('h5', tfkl.Conv2DTranspose,
                     2 * channels, self._kernel_sizes[3], strides=2)(x)
        mean, mask = tf.split(x, [channels, channels], -1)
        out_shape = tf.concat([tf.shape(features)[:-1], self._shape], 0)
        mean = tf.reshape(mean, out_shape)
        mask = tf.reshape(mask, out_shape)
        return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)), mask
class ConvDecoderMaskEnsemble(tools.Module):
    """Blend two mask-producing conv decoders into one image distribution.

    Each sub-decoder returns ``(Independent(Normal), mask)``. The two masks
    are concatenated on the last axis and passed through a learned 1x1
    convolution with sigmoid activation. This gives a weight ``w``, and the
    blended mean is ``w * mean1 + (1 - w) * mean2``.

    Args:
        decoder1, decoder2: decoders whose call returns
            ``(distribution, mask)``, e.g. ConvDecoderMask. `decoder1` must
            expose ``_shape``.
        precision: bit width; the weights are cast to ``'float<precision>'``,
            so this must name a valid TF float dtype (e.g. 16 or 32).
    """

    def __init__(self, decoder1, decoder2, precision):
        self._decoder1 = decoder1
        self._decoder2 = decoder2
        self._precision = 'float' + str(precision)
        self._shape = decoder1._shape

    def __call__(self, feat1, feat2):
        """Return ``(blended_dist, pred1, pred2, weight1)``.

        ``pred1``/``pred2`` are the sub-decoders' own distributions.
        ``weight1`` is the sigmoid weight applied to decoder1's mean, cast to
        the configured precision.
        """
        pred1, mask1 = self._decoder1(feat1)
        pred2, mask2 = self._decoder2(feat2)
        # Use the public accessor for the wrapped Normal instead of relying on
        # the internal ordering of tf.Module.submodules.
        mean1 = pred1.distribution.loc
        mean2 = pred2.distribution.loc
        mask_feat = tf.concat([mask1, mask2], -1)
        mask = self.get('mask1', tfkl.Conv2D, 1, 1,
                        strides=1, activation=tf.nn.sigmoid)(mask_feat)
        weight1 = tf.cast(mask, self._precision)
        weight2 = tf.cast(1 - mask, self._precision)
        mean = mean1 * weight1 + mean2 * weight2
        blended = tfd.Independent(tfd.Normal(mean, 1), len(self._shape))
        return blended, pred1, pred2, weight1
class InverseDecoder(tools.Module):
    """MLP head over pairs of neighbouring entries along axis 1 of the features.

    Args:
        shape: output shape; the final Dense layer has ``prod(shape)`` units.
        layers: number of hidden Dense layers.
        units: width of each hidden layer.
        act: hidden-layer activation.
    """

    def __init__(self, shape, layers, units, act=tf.nn.elu):
        self._shape = shape
        self._layers = layers
        self._units = units
        self._act = act

    def __call__(self, features):
        """Return a unit-variance Normal (one reinterpreted event axis).

        Entry ``t`` and entry ``t + 1`` along axis 1 (presumably time — confirm
        with callers) are concatenated on the last axis before the MLP, so
        the output has one fewer entry on that axis.
        """
        current, following = features[:, :-1], features[:, 1:]
        hidden = tf.concat([current, following], -1)
        for index in range(self._layers):
            hidden = self.get(f'h{index}', tfkl.Dense, self._units,
                              self._act)(hidden)
        out = self.get('hout', tfkl.Dense, np.prod(self._shape))(hidden)
        return tfd.Independent(tfd.Normal(out, 1), 1)
class DenseDecoder(tools.Module):
    """MLP head producing a distribution with event shape `shape`.

    Args:
        shape: event shape of the output distribution.
        layers: number of hidden Dense layers.
        units: width of each hidden layer.
        dist: 'normal' (unit-variance Normal, network output as mean) or
            'binary' (Bernoulli, network output passed as its first argument,
            i.e. logits).
        act: hidden-layer activation.

    Raises:
        NotImplementedError: at call time, for any other `dist`.
    """

    def __init__(self, shape, layers, units, dist='normal', act=tf.nn.elu):
        self._shape = shape
        self._layers = layers
        self._units = units
        self._dist = dist
        self._act = act

    def __call__(self, features):
        hidden = features
        for index in range(self._layers):
            hidden = self.get(f'h{index}', tfkl.Dense, self._units,
                              self._act)(hidden)
        flat = self.get('hout', tfkl.Dense, np.prod(self._shape))(hidden)
        # Restore the leading axes of `features` and append the event shape.
        target = tf.concat([tf.shape(features)[:-1], self._shape], 0)
        params = tf.reshape(flat, target)
        if self._dist == 'normal':
            base = tfd.Normal(params, 1)
        elif self._dist == 'binary':
            base = tfd.Bernoulli(params)
        else:
            raise NotImplementedError(self._dist)
        return tfd.Independent(base, len(self._shape))
class ActionDecoder(tools.Module):
    """MLP policy head mapping features to an action distribution.

    Args:
        size: action dimensionality (or number of discrete actions).
        layers: number of hidden Dense layers.
        units: width of each hidden layer.
        dist: 'tanh_normal' (Normal transformed by tools.TanhBijector and
            wrapped in tools.SampleDist) or 'onehot' (tools.OneHotDist over
            `size` logits).
        act: hidden-layer activation.
        min_std: constant added to the softplus std.
        init_std: std produced when the pre-activation std output is zero.
        mean_scale: the mean is soft-clipped to (-mean_scale, mean_scale).

    Raises:
        NotImplementedError: at call time, for any other `dist`.
    """

    def __init__(
            self, size, layers, units, dist='tanh_normal', act=tf.nn.elu,
            min_std=1e-4, init_std=5, mean_scale=5):
        self._size = size
        self._layers = layers
        self._units = units
        self._dist = dist
        self._act = act
        self._min_std = min_std
        self._init_std = init_std
        self._mean_scale = mean_scale

    def __call__(self, features):
        # Inverse softplus of init_std: softplus(0 + raw_init_std) == init_std.
        raw_init_std = np.log(np.exp(self._init_std) - 1)
        x = features
        for index in range(self._layers):
            x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
        if self._dist == 'tanh_normal':
            # https://www.desmos.com/calculator/rcmcf5jwe7
            x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
            mean, std = tf.split(x, 2, -1)
            # Soft-clip the mean into (-mean_scale, mean_scale).
            mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
            std = tf.nn.softplus(std + raw_init_std) + self._min_std
            dist = tfd.Normal(mean, std)
            dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
            dist = tfd.Independent(dist, 1)
            dist = tools.SampleDist(dist)
        elif self._dist == 'onehot':
            x = self.get('hout', tfkl.Dense, self._size)(x)
            dist = tools.OneHotDist(x)
        else:
            # Report the configured name. The local `dist` is never assigned
            # on this path, so referencing it would raise UnboundLocalError.
            raise NotImplementedError(self._dist)
        return dist