tia/DreamerV2/exploration.py

84 lines
2.8 KiB
Python
Raw Permalink Normal View History

2021-06-30 01:20:44 +00:00
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd
import models
import networks
import tools
class Random(tools.Module):
    """Exploration behavior whose actor ignores the features' content.

    The actor returns a fixed distribution whose batch shape follows the
    leading dimensions of the input features; `train` is a no-op so the
    class can be used wherever a trainable exploration behavior is expected.
    """

    def __init__(self, config):
        """Store the config and the global mixed-precision compute dtype.

        Args:
            config: Object exposing `num_actions` and `actor_dist`.
        """
        self._config = config
        # The Keras mixed-precision policy exposes `compute_dtype`; the
        # previous `comput_dtype` spelling raised AttributeError here.
        self._float = prec.global_policy().compute_dtype

    def actor(self, feat):
        """Return an action distribution shaped like `feat`'s leading dims.

        Args:
            feat: Tensor whose last axis is replaced by `config.num_actions`
                to form the distribution's parameter shape.

        Returns:
            `tools.OneHotDist` built from an all-zero tensor when
            `config.actor_dist == 'onehot'`; otherwise a `tfd.Uniform`
            over [-1, 1] in the policy compute dtype.
        """
        shape = feat.shape[:-1] + [self._config.num_actions]
        if self._config.actor_dist == 'onehot':
            # Zeros are presumably interpreted as logits by OneHotDist,
            # i.e. a uniform categorical -- confirm against tools.OneHotDist.
            return tools.OneHotDist(tf.zeros(shape))
        ones = tf.ones(shape, self._float)
        return tfd.Uniform(-ones, ones)

    def train(self, start, feat, embed, kl):
        """Do nothing; there are no parameters to update.

        Returns:
            A `(None, {})` pair, mirroring the `(value, metrics)` shape
            returned by `Plan2Explore.train`.
        """
        return None, {}
class Plan2Explore(tools.Module):
    """Exploration behavior rewarded by the spread of an ensemble's outputs.

    An ensemble of `networks.DenseHead` models is fit to predict a target
    selected by `config.disag_target`. The standard deviation across the
    ensemble's mean predictions is used as the intrinsic reward for training
    an imagined-rollout behavior (`models.ImagBehavior`).
    """

    def __init__(self, config, world_model, reward=None):
        """Build the behavior, the prediction ensemble and its optimizer.

        Args:
            config: Configuration object; the `disag_*`, `dyn_*`,
                `cnn_depth`, `act`, `model_lr`, `opt_eps`, `grad_clip`,
                `weight_decay` and `opt` attributes are read here.
            world_model: Forwarded to `models.ImagBehavior`.
            reward: Optional callable `(feat, state, action)` used for the
                extrinsic term when `config.expl_extr_scale` is truthy.
        """
        self._config = config
        self._reward = reward
        self._behavior = models.ImagBehavior(config, world_model)
        self.actor = self._behavior.actor
        # Output width of each ensemble member, by prediction target.
        target_sizes = {
            'embed': 32 * config.cnn_depth,
            'stoch': config.dyn_stoch,
            'deter': config.dyn_deter,
            'feat': config.dyn_stoch + config.dyn_deter,
        }
        head_kwargs = {
            'shape': target_sizes[self._config.disag_target],
            'layers': config.disag_layers,
            'units': config.disag_units,
            'act': config.act,
        }
        self._networks = [
            networks.DenseHead(**head_kwargs)
            for _ in range(config.disag_models)
        ]
        self._opt = tools.Optimizer(
            'ensemble',
            config.model_lr,
            config.opt_eps,
            config.grad_clip,
            config.weight_decay,
            opt=config.opt,
        )

    def train(self, start, feat, embed, kl):
        """Fit the ensemble, then train the behavior on the intrinsic reward.

        Args:
            start: Mapping with at least `'stoch'` and `'deter'` entries;
                also passed to the behavior's `train`.
            feat: Ensemble inputs (and the target when
                `config.disag_target == 'feat'`).
            embed: Target when `config.disag_target == 'embed'`.
            kl: Unused.

        Returns:
            `(None, metrics)` where `metrics` merges the ensemble optimizer
            metrics with the last element returned by the behavior's `train`.
        """
        candidates = {
            'embed': embed,
            'stoch': start['stoch'],
            'deter': start['deter'],
            'feat': feat,
        }
        chosen = candidates[self._config.disag_target]
        metrics = dict(self._train_ensemble(feat, chosen))
        behavior_out = self._behavior.train(start, self._intrinsic_reward)
        metrics.update(behavior_out[-1])
        return None, metrics

    def _intrinsic_reward(self, feat, state, action):
        """Compute the exploration reward from ensemble spread.

        The reward is `expl_intr_scale` times the std across ensemble mean
        predictions, averaged over the last axis (optionally log-transformed
        when `config.disag_log` is set). If `config.expl_extr_scale` is
        truthy, the scaled output of the supplied reward callable is added
        after casting to float32.
        """
        means = [net(feat, tf.float32).mean() for net in self._networks]
        spread = tf.math.reduce_std(means, 0)
        disag = tf.reduce_mean(spread, -1)
        if self._config.disag_log:
            disag = tf.math.log(disag)
        reward = self._config.expl_intr_scale * disag
        if not self._config.expl_extr_scale:
            return reward
        extrinsic = self._config.expl_extr_scale * self._reward(
            feat, state, action)
        return reward + tf.cast(extrinsic, tf.float32)

    def _train_ensemble(self, inputs, targets):
        """Take one optimizer step maximizing the ensemble's log-likelihood.

        When `config.disag_offset` is non-zero, targets are shifted forward
        along axis 1 by that many steps relative to the inputs. Gradients
        are blocked from flowing into both inputs and targets.

        Returns:
            Whatever the `tools.Optimizer` call returns (used as metrics).
        """
        offset = self._config.disag_offset
        if offset:
            targets = targets[:, offset:]
            inputs = inputs[:, :-offset]
        targets = tf.stop_gradient(targets)
        inputs = tf.stop_gradient(inputs)
        with tf.GradientTape() as tape:
            log_likes = [
                tf.reduce_mean(net(inputs).log_prob(targets))
                for net in self._networks
            ]
            total = tf.reduce_sum(log_likes)
            loss = -tf.cast(total, tf.float32)
        return self._opt(tape, loss, self._networks)