# File: tia/Dreamer/dreamers.py
# (741 lines, 34 KiB, Python; page navigation: Raw Permalink Normal View History)
# Snapshot: 2023-07-17 08:48:01 +00:00
import tools
import models
from tensorflow_probability import distributions as tfd
from tensorflow.keras.mixed_precision import experimental as prec
import tensorflow as tf
import numpy as np
import collections
import functools
import json
import time
from env_tools import preprocess, count_steps
def load_dataset(directory, config):
    """Build the batched, preprocessed training ``tf.data.Dataset``.

    A single episode is read up front only to infer the per-key dtypes and
    shapes for ``from_generator``; the leading (time) axis is left as
    ``None``. Sequences are then streamed from ``tools.load_episodes`` using
    ``config.train_steps``, ``config.batch_length`` and
    ``config.dataset_balance``.

    Args:
        directory: Directory holding the stored episodes.
        config: Configuration; also forwarded to ``preprocess``.

    Returns:
        A dataset batched by ``config.batch_size`` (incomplete batches
        dropped), mapped through ``preprocess`` and prefetching 10 batches.
    """
    sample_episode = next(tools.load_episodes(directory, 1))
    output_types = {}
    output_shapes = {}
    for key, array in sample_episode.items():
        output_types[key] = array.dtype
        output_shapes[key] = (None,) + array.shape[1:]

    def generator():
        return tools.load_episodes(
            directory, config.train_steps, config.batch_length,
            config.dataset_balance)

    return (
        tf.data.Dataset.from_generator(generator, output_types, output_shapes)
        .batch(config.batch_size, drop_remainder=True)
        .map(functools.partial(preprocess, config=config))
        .prefetch(10)
    )
class Dreamer(tools.Module):
    """Model-based agent that learns behaviors inside a learned world model.

    The world model (convolutional encoder, RSSM dynamics, image decoder,
    reward head and an optional continuation head ``pcont``) is fit to
    replayed sequences. The actor and value heads are trained on
    trajectories imagined by the dynamics model starting from the posterior
    states. Training runs under a ``tf.distribute.MirroredStrategy``.
    """

    def __init__(self, config, datadir, actspace, writer):
        """Set up bookkeeping, the distributed dataset and all sub-models.

        Args:
            config: Experiment configuration read via attribute access.
            datadir: Episode directory; feeds ``load_dataset`` and seeds the
                environment-step counter through ``count_steps``.
            actspace: Action space. ``actspace.n`` is used as the action
                dimension when present, otherwise ``actspace.shape[0]``.
            writer: Summary writer used by the image summaries and flushed
                after every metrics dump.
        """
        self._c = config
        self._actspace = actspace
        self._actdim = actspace.n if hasattr(
            actspace, 'n') else actspace.shape[0]
        self._writer = writer
        self._random = np.random.RandomState(config.seed)
        self._should_pretrain = tools.Once()
        self._should_train = tools.Every(config.train_every)
        self._should_log = tools.Every(config.log_every)
        self._last_log = None
        self._last_time = time.time()
        self._metrics = collections.defaultdict(tf.metrics.Mean)
        self._metrics['expl_amount']  # Create variable for checkpoint.
        self._float = prec.global_policy().compute_dtype
        self._strategy = tf.distribute.MirroredStrategy()
        # The global environment-step counter is pinned to the CPU.
        with tf.device('cpu:0'):
            self._step = tf.Variable(count_steps(
                datadir, config), dtype=tf.int64)
        with self._strategy.scope():
            self._dataset = iter(self._strategy.experimental_distribute_dataset(
                load_dataset(datadir, self._c)))
            self._build_model()

    def __call__(self, obs, reset, state=None, training=True):
        """Train when scheduled, then choose actions for a batch of envs.

        Args:
            obs: Batch of observations. When ``state`` is None,
                ``len(obs['image'])`` fixes the batch size.
            reset: Per-environment reset flags (must support ``.any()`` and
                ``len``). Rows of ``state`` whose flag is set are zeroed.
            state: ``(latent, action)`` tuple returned by the previous call,
                or None to start from the initial latent and zero actions.
            training: If true, sample from the actor and advance the step
                counter; otherwise act with the actor's mode.

        Returns:
            ``(action, state)`` where ``state`` is fed to the next call.
        """
        step = self._step.numpy().item()
        tf.summary.experimental.set_step(step)
        if state is not None and reset.any():
            # Zero the recurrent state of environments that were just reset.
            mask = tf.cast(1 - reset, self._float)[:, None]
            state = tf.nest.map_structure(lambda x: x * mask, state)
        if self._should_train(step):
            log = self._should_log(step)
            # The first training phase runs ``pretrain`` steps.
            n = self._c.pretrain if self._should_pretrain() else self._c.train_steps
            print(f'Training for {n} steps.')
            with self._strategy.scope():
                for train_step in range(n):
                    # Videos are only written for the first step of a
                    # logging round.
                    log_images = self._c.log_images and log and train_step == 0
                    self.train(next(self._dataset), log_images)
            if log:
                self._write_summaries()
        action, state = self.policy(obs, state, training)
        if training:
            self._step.assign_add(len(reset) * self._c.action_repeat)
        return action, state

    @tf.function
    def policy(self, obs, state, training):
        """Filter the observation into the latent state and pick an action.

        Returns:
            ``(action, (latent, action))``; the action already includes the
            exploration noise from ``_exploration``.
        """
        if state is None:
            latent = self._dynamics.initial(len(obs['image']))
            action = tf.zeros((len(obs['image']), self._actdim), self._float)
        else:
            latent, action = state
        embed = self._encode(preprocess(obs, self._c))
        latent, _ = self._dynamics.obs_step(latent, action, embed)
        feat = self._dynamics.get_feat(latent)
        if training:
            action = self._actor(feat).sample()
        else:
            action = self._actor(feat).mode()
        action = self._exploration(action, training)
        state = (latent, action)
        return action, state

    def load(self, filename):
        """Load parameters and consume the one-shot pretraining trigger.

        After loading, ``__call__`` uses ``train_steps`` rather than
        ``pretrain`` for the next training phase, assuming ``tools.Once``
        fires only on its first call.
        """
        super().load(filename)
        self._should_pretrain()

    @tf.function()
    def train(self, data, log_images=False):
        """Run ``_train`` on every replica of the distribution strategy."""
        self._strategy.run(
            self._train, args=(data, log_images))

    def _train(self, data, log_images):
        """Per-replica update of the world model, actor and value head.

        Args:
            data: Per-replica batch dict with ``image``, ``action``,
                ``reward`` and, when ``config.pcont`` is set, ``discount``.
            log_images: Whether to also write open-loop video summaries.
        """
        with tf.GradientTape() as model_tape:
            # NOTE(review): perm [0, 1, 3, 4, 2] moves axis 2 to the end,
            # presumably (batch, time, C, H, W) -> (batch, time, H, W, C);
            # confirm against the stored layout. ``data`` is updated in place,
            # so the image summaries below see the transposed images.
            data["image"] = tf.transpose(data["image"], perm=[0, 1, 3, 4, 2])
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)
            image_pred = self._decode(feat)
            reward_pred = self._reward(feat)
            likes = tools.AttrDict()
            likes.image = tf.reduce_mean(image_pred.log_prob(data['image']))
            likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward']))
            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            # Free nats: the KL term is never pushed below ``free_nats``.
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())
            # Every loss is scaled by 1 / num_replicas, presumably so the
            # cross-replica gradient sum equals the global mean.
            model_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as actor_tape:
            imag_feat = self._imagine_ahead(post)
            reward = self._reward(imag_feat).mode()
            if self._c.pcont:
                pcont = self._pcont(imag_feat).mean()
            else:
                pcont = self._c.discount * tf.ones_like(reward)
            value = self._value(imag_feat).mode()
            returns = tools.lambda_return(
                reward[:-1], value[:-1], pcont[:-1],
                bootstrap=value[-1], lambda_=self._c.disclam, axis=0)
            # Weights: 1 for the first imagined step, then the cumulative
            # product of the continuation factors (shifted by one step).
            discount = tf.stop_gradient(tf.math.cumprod(tf.concat(
                [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0))
            actor_loss = -tf.reduce_mean(discount * returns)
            actor_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as value_tape:
            value_pred = self._value(imag_feat)[:-1]
            target = tf.stop_gradient(returns)
            value_loss = -tf.reduce_mean(
                discount * value_pred.log_prob(target))
            value_loss /= float(self._strategy.num_replicas_in_sync)

        model_norm = self._model_opt(model_tape, model_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)

        # Only the first replica writes summaries.
        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(
                    data, feat, prior_dist, post_dist, likes, div,
                    model_loss, value_loss, actor_loss, model_norm, value_norm,
                    actor_norm)
            if tf.equal(log_images, True):
                self._image_summaries(data, embed, image_pred)

    def _build_model(self):
        """Create all sub-models and optimizers, then run one training step.

        The trailing ``self.train`` call is a real optimization step on one
        batch, presumably made so that all variables exist before
        checkpoints are saved or loaded.
        """
        acts = dict(
            elu=tf.nn.elu, relu=tf.nn.relu, swish=tf.nn.swish,
            leaky_relu=tf.nn.leaky_relu)
        cnn_act = acts[self._c.cnn_act]
        act = acts[self._c.dense_act]
        self._encode = models.ConvEncoder(
            self._c.cnn_depth, cnn_act, self._c.image_size)
        self._dynamics = models.RSSM(
            self._c.stoch_size, self._c.deter_size, self._c.deter_size)
        self._decode = models.ConvDecoder(
            self._c.cnn_depth, cnn_act, (self._c.image_size, self._c.image_size, 3))
        self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act)
        if self._c.pcont:
            self._pcont = models.DenseDecoder(
                (), 3, self._c.num_units, 'binary', act=act)
        self._value = models.DenseDecoder((), 3, self._c.num_units, act=act)
        self._actor = models.ActionDecoder(
            self._actdim, 4, self._c.num_units, self._c.action_dist,
            init_std=self._c.action_init_std, act=act)
        model_modules = [self._encode, self._dynamics,
                         self._decode, self._reward]
        if self._c.pcont:
            model_modules.append(self._pcont)
        Optimizer = functools.partial(
            tools.Adam, wd=self._c.weight_decay, clip=self._c.grad_clip,
            wdpattern=self._c.weight_decay_pattern)
        self._model_opt = Optimizer('model', model_modules, self._c.model_lr)
        self._value_opt = Optimizer('value', [self._value], self._c.value_lr)
        self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr)
        self.train(next(self._dataset))

    def _exploration(self, action, training):
        """Apply the configured exploration scheme to ``action``.

        While training, the noise amount starts at ``expl_amount``. When
        ``expl_decay`` is set, it is halved every ``expl_decay`` steps, and
        it is floored at ``expl_min`` when that is set. During evaluation
        ``eval_noise`` is used, and the action is returned untouched when
        that is falsy.

        Schemes (``config.expl``):
            ``additive_gaussian``: Normal(action, amount) sample clipped to
                [-1, 1].
            ``completely_random``: uniform in [-1, 1], ignoring the action.
            ``epsilon_greedy``: with probability ``amount``, replace the
                action by a one-hot of a uniformly drawn index.

        Raises:
            NotImplementedError: for any other ``config.expl``.
        """
        if training:
            amount = self._c.expl_amount
            if self._c.expl_decay:
                amount *= 0.5 ** (tf.cast(self._step,
                                          tf.float32) / self._c.expl_decay)
            if self._c.expl_min:
                amount = tf.maximum(self._c.expl_min, amount)
            self._metrics['expl_amount'].update_state(amount)
        elif self._c.eval_noise:
            amount = self._c.eval_noise
        else:
            return action
        if self._c.expl == 'additive_gaussian':
            return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
        if self._c.expl == 'completely_random':
            return tf.random.uniform(action.shape, -1, 1)
        if self._c.expl == 'epsilon_greedy':
            # All-zero logits give a uniform categorical over the actions.
            indices = tfd.Categorical(0 * action).sample()
            # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
            return tf.where(
                tf.random.uniform(action.shape[:1], 0, 1) < amount,
                tf.one_hot(indices, action.shape[-1], dtype=self._float),
                action)
        raise NotImplementedError(self._c.expl)

    def _imagine_ahead(self, post):
        """Roll the actor forward ``config.horizon`` steps in latent space.

        Every posterior state (batch and time axes flattened together) is a
        start state. The actor sees stop-gradient features, so the actor
        loss reaches the policy only through the imagined dynamics.

        Returns:
            Features of the imagined states, as produced by ``get_feat``.
        """
        if self._c.pcont:  # Last step could be terminal.
            post = {k: v[:, :-1] for k, v in post.items()}

        def flatten(x):
            return tf.reshape(x, [-1] + list(x.shape[2:]))

        start = {k: flatten(v) for k, v in post.items()}

        def policy(state):
            return self._actor(
                tf.stop_gradient(self._dynamics.get_feat(state))).sample()

        states = tools.static_scan(
            lambda prev, _: self._dynamics.img_step(prev, policy(prev)),
            tf.range(self._c.horizon), start)
        imag_feat = self._dynamics.get_feat(states)
        return imag_feat

    def _scalar_summaries(
            self, data, feat, prior_dist, post_dist, likes, div,
            model_loss, value_loss, actor_loss, model_norm, value_norm,
            actor_norm):
        """Accumulate training scalars into ``self._metrics``.

        ``data`` is accepted for signature compatibility but not used.
        """
        self._metrics['model_grad_norm'].update_state(model_norm)
        self._metrics['value_grad_norm'].update_state(value_norm)
        self._metrics['actor_grad_norm'].update_state(actor_norm)
        self._metrics['prior_ent'].update_state(prior_dist.entropy())
        self._metrics['post_ent'].update_state(post_dist.entropy())
        for name, logprob in likes.items():
            self._metrics[name + '_loss'].update_state(-logprob)
        self._metrics['div'].update_state(div)
        self._metrics['model_loss'].update_state(model_loss)
        self._metrics['value_loss'].update_state(value_loss)
        self._metrics['actor_loss'].update_state(actor_loss)
        self._metrics['action_ent'].update_state(self._actor(feat).entropy())

    def _image_summaries(self, data, embed, image_pred):
        """Write an open-loop prediction video under ``agent/openl``.

        For the first 6 sequences, the first 5 frames are reconstructions
        and the rest are imagined from the posterior at step 5 using the
        recorded actions. Truth, model and error videos are concatenated
        along axis 2. The ``+ 0.5`` shifts assume images are centred
        around 0 (TODO confirm against ``preprocess``).
        """
        truth = data['image'][:6] + 0.5
        recon = image_pred.mode()[:6]
        init, _ = self._dynamics.observe(embed[:6, :5], data['action'][:6, :5])
        init = {k: v[:, -1] for k, v in init.items()}
        prior = self._dynamics.imagine(data['action'][:6, 5:], init)
        openl = self._decode(self._dynamics.get_feat(prior)).mode()
        model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
        error = (model - truth + 1) / 2
        openl = tf.concat([truth, model, error], 2)
        tools.graph_summary(
            self._writer, tools.video_summary, self._step, 'agent/openl', openl)

    def image_summary_from_data(self, data):
        """Write an open-loop video for an external batch (``agent/eval_openl``).

        NOTE(review): unlike ``_train``, ``data['image']`` is not transposed
        here; callers presumably pass images already in the encoder's
        layout — confirm.
        """
        truth = data['image'][:6] + 0.5
        embed = self._encode(data)
        post, _ = self._dynamics.observe(
            embed[:6, :5], data['action'][:6, :5])
        feat = self._dynamics.get_feat(post)
        init = {k: v[:, -1] for k, v in post.items()}
        recon = self._decode(feat).mode()[:6]
        prior = self._dynamics.imagine(data['action'][:6, 5:], init)
        openl = self._decode(self._dynamics.get_feat(prior)).mode()
        model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
        error = (model - truth + 1) / 2
        openl = tf.concat([truth, model, error], 2)
        tools.graph_summary(
            self._writer, tools.video_summary, self._step, 'agent/eval_openl', openl)

    def _write_summaries(self):
        """Dump and reset the accumulated metrics.

        Results go to ``metrics.jsonl`` in ``config.logdir``, to TF summaries
        under ``agent/`` and to stdout. From the second call on, ``fps``
        (environment steps per wall-clock second since the previous dump)
        is added.
        """
        step = int(self._step.numpy())
        metrics = [(k, float(v.result())) for k, v in self._metrics.items()]
        now = time.time()
        if self._last_log is not None:
            duration = now - self._last_time
            metrics.append(('fps', (step - self._last_log) / duration))
        self._last_log = step
        # Restart the timer on every dump, including the first, so the next
        # fps covers exactly the interval since this dump.
        self._last_time = now
        for metric in self._metrics.values():
            metric.reset_states()
        with (self._c.logdir / 'metrics.jsonl').open('a') as f:
            f.write(json.dumps({'step': step, **dict(metrics)}) + '\n')
        for name, value in metrics:
            tf.summary.scalar('agent/' + name, value)
        print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in metrics))
        self._writer.flush()
class SeparationDreamer(Dreamer):
    """Dreamer with two parallel world models: a task and a "disen" branch.

    The task branch (``_encode``, ``_dynamics``, ``_reward``, ``_pcont``)
    drives the actor and value heads as in ``Dreamer``. The disen branch
    (``_disen_encode``, ``_disen_dynamics``) has its own RSSM. Its loss
    rewards reconstructing the image, both jointly with the task branch via
    ``_joint_decode`` and alone via ``_disen_only_decode``. The loss also
    adds the disen reward head's log-likelihood scaled by
    ``disen_neg_rew_scale``, so with a positive scale, minimizing it
    discourages reward-predictive disen features. The disen reward head
    itself is fit separately on stop-gradient disen features.
    """

    def __init__(self, config, datadir, actspace, writer):
        # Must exist before ``super().__init__``: that call builds the model
        # and runs a training step, which updates ``_metrics_disen``.
        self._metrics_disen = collections.defaultdict(tf.metrics.Mean)
        self._metrics_disen['expl_amount']  # Create variable for checkpoint.
        super().__init__(config, datadir, actspace, writer)

    def _train(self, data, log_images):
        """Per-replica update of both world models, actor and value head."""
        # Persistent: the task-model, disen-model and joint-decoder
        # optimizers all take gradients from this one tape.
        with tf.GradientTape(persistent=True) as model_tape:
            # Task ("main") branch.
            # NOTE(review): perm [0, 1, 3, 4, 2] presumably converts
            # (batch, time, C, H, W) to channel-last; confirm the stored layout.
            data["image"] = tf.transpose(data["image"], perm=[0, 1, 3, 4, 2])
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)

            # Distractor ("disen") branch.
            embed_disen = self._disen_encode(data)
            post_disen, prior_disen = self._disen_dynamics.observe(
                embed_disen, data['action'])
            feat_disen = self._disen_dynamics.get_feat(post_disen)

            # Image prediction from the disen branch alone.
            image_pred_disen = self._disen_only_decode(feat_disen)

            # Joint image prediction; also returns the per-branch
            # predictions and the mixing mask.
            (image_pred_joint, image_pred_joint_main, image_pred_joint_disen,
             mask_pred) = self._joint_decode(feat, feat_disen)

            reward_pred = self._reward(feat)

            # Fit the disen reward head on frozen disen features. The norm is
            # initialized so it is defined even for zero iterations.
            reward_disen_norm = tf.constant(0.0)
            for _ in range(self._c.num_reward_opt_iters):
                with tf.GradientTape() as disen_reward_tape:
                    reward_pred_disen = self._disen_reward(
                        tf.stop_gradient(feat_disen))
                    reward_like_disen = reward_pred_disen.log_prob(
                        data['reward'])
                    reward_loss_disen = -tf.reduce_mean(reward_like_disen)
                    reward_loss_disen /= float(
                        self._strategy.num_replicas_in_sync)
                reward_disen_norm = self._disen_reward_opt(
                    disen_reward_tape, reward_loss_disen)

            # Reward likelihood under the freshly fitted head; gradients now
            # flow back into ``feat_disen``.
            reward_pred_disen = self._disen_reward(feat_disen)
            reward_like_disen = tf.reduce_mean(
                reward_pred_disen.log_prob(data['reward']))

            # Task model loss (images come from the joint decoder).
            likes = tools.AttrDict()
            likes.image = tf.reduce_mean(
                image_pred_joint.log_prob(data['image']))
            likes.reward = tf.reduce_mean(reward_pred.log_prob(
                data['reward'])) * self._c.reward_scale
            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())
            model_loss /= float(self._strategy.num_replicas_in_sync)

            # Disen model loss; the reward likelihood enters with a positive
            # sign, so minimizing the loss lowers it.
            likes_disen = tools.AttrDict()
            likes_disen.image = tf.reduce_mean(
                image_pred_joint.log_prob(data['image']))
            likes_disen.disen_only = tf.reduce_mean(
                image_pred_disen.log_prob(data['image']))
            reward_loss_disen = -reward_like_disen
            prior_dist_disen = self._disen_dynamics.get_dist(prior_disen)
            post_dist_disen = self._disen_dynamics.get_dist(post_disen)
            div_disen = tf.reduce_mean(tfd.kl_divergence(
                post_dist_disen, prior_dist_disen))
            div_disen = tf.maximum(div_disen, self._c.free_nats)
            model_loss_disen = (
                div_disen * self._c.disen_kl_scale
                + reward_like_disen * self._c.disen_neg_rew_scale
                - likes_disen.image
                - likes_disen.disen_only * self._c.disen_rec_scale)
            model_loss_disen /= float(self._strategy.num_replicas_in_sync)

            # The joint decoder is trained on both branches' losses.
            decode_loss = model_loss_disen + model_loss

        with tf.GradientTape() as actor_tape:
            imag_feat = self._imagine_ahead(post)
            reward = self._reward(imag_feat).mode()
            if self._c.pcont:
                pcont = self._pcont(imag_feat).mean()
            else:
                pcont = self._c.discount * tf.ones_like(reward)
            value = self._value(imag_feat).mode()
            returns = tools.lambda_return(
                reward[:-1], value[:-1], pcont[:-1],
                bootstrap=value[-1], lambda_=self._c.disclam, axis=0)
            discount = tf.stop_gradient(tf.math.cumprod(tf.concat(
                [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0))
            actor_loss = -tf.reduce_mean(discount * returns)
            actor_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as value_tape:
            value_pred = self._value(imag_feat)[:-1]
            target = tf.stop_gradient(returns)
            value_loss = -tf.reduce_mean(
                discount * value_pred.log_prob(target))
            value_loss /= float(self._strategy.num_replicas_in_sync)

        model_norm = self._model_opt(model_tape, model_loss)
        model_disen_norm = self._disen_opt(model_tape, model_loss_disen)
        self._decode_opt(model_tape, decode_loss)
        # A persistent tape holds its resources until released.
        del model_tape
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)

        # Only the first replica writes summaries.
        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(
                    data, feat, prior_dist, post_dist, likes, div,
                    model_loss, value_loss, actor_loss, model_norm, value_norm, actor_norm)
                self._scalar_summaries_disen(
                    prior_dist_disen, post_dist_disen, likes_disen, div_disen,
                    model_loss_disen, reward_loss_disen,
                    model_disen_norm, reward_disen_norm)
            if tf.equal(log_images, True):
                self._image_summaries_joint(
                    data, embed, embed_disen, image_pred_joint, mask_pred)
                self._image_summaries(
                    self._disen_dynamics, self._disen_decode, data, embed_disen, image_pred_joint_disen, tag='disen/openl_joint_disen')
                self._image_summaries(
                    self._disen_dynamics, self._disen_only_decode, data, embed_disen, image_pred_disen, tag='disen_only/openl_disen_only')
                self._image_summaries(
                    self._dynamics, self._main_decode, data, embed, image_pred_joint_main, tag='main/openl_joint_main')

    def _build_model(self):
        """Create both branches, the joint decoder and six optimizers.

        Optimizer groups: ``model`` (task encoder, dynamics, reward, pcont),
        ``disen`` (disen encoder, dynamics, disen-only decoder), ``decode``
        (joint decoder), ``disen_reward``, ``value`` and ``actor``. The
        trailing ``self.train`` call is a real training step on one batch.
        """
        acts = dict(
            elu=tf.nn.elu, relu=tf.nn.relu, swish=tf.nn.swish,
            leaky_relu=tf.nn.leaky_relu)
        cnn_act = acts[self._c.cnn_act]
        act = acts[self._c.dense_act]
        # Distractor dynamic model
        self._disen_encode = models.ConvEncoder(
            self._c.disen_cnn_depth, cnn_act, self._c.image_size)
        self._disen_dynamics = models.RSSM(
            self._c.disen_stoch_size, self._c.disen_deter_size, self._c.disen_deter_size)
        self._disen_only_decode = models.ConvDecoder(
            self._c.disen_cnn_depth, cnn_act, (self._c.image_size, self._c.image_size, 3))
        self._disen_reward = models.DenseDecoder(
            (), 2, self._c.num_units, act=act)
        # Task dynamic model
        self._encode = models.ConvEncoder(
            self._c.cnn_depth, cnn_act, self._c.image_size)
        self._dynamics = models.RSSM(
            self._c.stoch_size, self._c.deter_size, self._c.deter_size)
        self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act)
        if self._c.pcont:
            self._pcont = models.DenseDecoder(
                (), 3, self._c.num_units, 'binary', act=act)
        self._value = models.DenseDecoder((), 3, self._c.num_units, act=act)
        self._actor = models.ActionDecoder(
            self._actdim, 4, self._c.num_units, self._c.action_dist,
            init_std=self._c.action_init_std, act=act)
        # Joint decode
        self._main_decode = models.ConvDecoderMask(
            self._c.cnn_depth, cnn_act, (self._c.image_size, self._c.image_size, 3))
        self._disen_decode = models.ConvDecoderMask(
            self._c.disen_cnn_depth, cnn_act, (self._c.image_size, self._c.image_size, 3))
        self._joint_decode = models.ConvDecoderMaskEnsemble(
            self._main_decode, self._disen_decode, self._c.precision
        )
        disen_modules = [self._disen_encode,
                         self._disen_dynamics, self._disen_only_decode]
        model_modules = [self._encode, self._dynamics, self._reward]
        if self._c.pcont:
            model_modules.append(self._pcont)
        Optimizer = functools.partial(
            tools.Adam, wd=self._c.weight_decay, clip=self._c.grad_clip,
            wdpattern=self._c.weight_decay_pattern)
        self._model_opt = Optimizer('model', model_modules, self._c.model_lr)
        self._disen_opt = Optimizer('disen', disen_modules, self._c.model_lr)
        self._decode_opt = Optimizer(
            'decode', [self._joint_decode], self._c.model_lr)
        self._disen_reward_opt = Optimizer(
            'disen_reward', [self._disen_reward], self._c.disen_reward_lr)
        self._value_opt = Optimizer('value', [self._value], self._c.value_lr)
        self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr)
        self.train(next(self._dataset))

    def _scalar_summaries_disen(
            self, prior_dist_disen, post_dist_disen, likes_disen, div_disen,
            model_loss_disen, reward_loss_disen,
            model_disen_norm, reward_disen_norm):
        """Accumulate disen-branch scalars into ``self._metrics_disen``."""
        self._metrics_disen['model_grad_norm'].update_state(model_disen_norm)
        self._metrics_disen['reward_grad_norm'].update_state(reward_disen_norm)
        self._metrics_disen['prior_ent'].update_state(
            prior_dist_disen.entropy())
        self._metrics_disen['post_ent'].update_state(post_dist_disen.entropy())
        for name, logprob in likes_disen.items():
            self._metrics_disen[name + '_loss'].update_state(-logprob)
        self._metrics_disen['div'].update_state(div_disen)
        self._metrics_disen['model_loss'].update_state(model_loss_disen)
        self._metrics_disen['reward_loss'].update_state(
            reward_loss_disen)

    def _image_summaries(self, dynamics, decoder, data, embed, image_pred, tag='agent/openl'):
        """Open-loop video for an arbitrary ``dynamics``/``decoder`` pair.

        This overrides ``Dreamer._image_summaries`` with an extended
        signature. A ``models.ConvDecoderMask`` decoder returns a pair, and
        only its first element (the image distribution) is used.
        """
        truth = data['image'][:6] + 0.5
        recon = image_pred.mode()[:6]
        init, _ = dynamics.observe(embed[:6, :5], data['action'][:6, :5])
        init = {k: v[:, -1] for k, v in init.items()}
        prior = dynamics.imagine(data['action'][:6, 5:], init)
        if isinstance(decoder, models.ConvDecoderMask):
            openl, _ = decoder(dynamics.get_feat(prior))
            openl = openl.mode()
        else:
            openl = decoder(dynamics.get_feat(prior)).mode()
        model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
        error = (model - truth + 1) / 2
        openl = tf.concat([truth, model, error], 2)
        tools.graph_summary(
            self._writer, tools.video_summary, self._step, tag, openl)

    def _image_summaries_joint(self, data, embed, embed_disen, image_pred_joint, mask_pred):
        """Open-loop videos of the joint reconstruction and of its mask.

        Both branches are observed on the first 5 steps of the first 6
        sequences, then imagined with the remaining recorded actions.
        """
        truth = data['image'][:6] + 0.5
        recon_joint = image_pred_joint.mode()[:6]
        mask_pred = mask_pred[:6]
        init, _ = self._dynamics.observe(
            embed[:6, :5], data['action'][:6, :5])
        init_disen, _ = self._disen_dynamics.observe(
            embed_disen[:6, :5], data['action'][:6, :5])
        init = {k: v[:, -1] for k, v in init.items()}
        init_disen = {k: v[:, -1] for k, v in init_disen.items()}
        prior = self._dynamics.imagine(
            data['action'][:6, 5:], init)
        prior_disen = self._disen_dynamics.imagine(
            data['action'][:6, 5:], init_disen)
        feat = self._dynamics.get_feat(prior)
        feat_disen = self._disen_dynamics.get_feat(prior_disen)
        openl, _, _, openl_mask = self._joint_decode(feat, feat_disen)
        openl = openl.mode()
        model = tf.concat([recon_joint[:, :5] + 0.5, openl + 0.5], 1)
        error = (model - truth + 1) / 2
        openl = tf.concat([truth, model, error], 2)
        openl_mask = tf.concat([mask_pred[:, :5] + 0.5, openl_mask + 0.5], 1)
        tools.graph_summary(
            self._writer, tools.video_summary, self._step, 'joint/openl_joint', openl)
        tools.graph_summary(
            self._writer, tools.video_summary, self._step, 'mask/openl_mask', openl_mask)

    def image_summary_from_data(self, data):
        """Write ``summary/openl`` and ``summary/openl_mask`` for a batch.

        The first video stacks truth, main-only, disen-only, joint and
        error frames along axis 2. NOTE(review): ``data['image']`` is not
        transposed here, unlike in ``_train`` — confirm the caller's layout.
        """
        truth = data['image'][:6] + 0.5
        # main
        embed = self._encode(data)
        post, _ = self._dynamics.observe(
            embed[:6, :5], data['action'][:6, :5])
        feat = self._dynamics.get_feat(post)
        init = {k: v[:, -1] for k, v in post.items()}
        # disen
        embed_disen = self._disen_encode(data)
        post_disen, _ = self._disen_dynamics.observe(
            embed_disen[:6, :5], data['action'][:6, :5])
        feat_disen = self._disen_dynamics.get_feat(post_disen)
        init_disen = {k: v[:, -1] for k, v in post_disen.items()}
        # joint image pred
        recon_joint, recon_main, recon_disen, recon_mask = self._joint_decode(
            feat, feat_disen)
        recon_joint = recon_joint.mode()[:6]
        recon_main = recon_main.mode()[:6]
        recon_disen = recon_disen.mode()[:6]
        recon_mask = recon_mask[:6]
        prior = self._dynamics.imagine(
            data['action'][:6, 5:], init)
        prior_disen = self._disen_dynamics.imagine(
            data['action'][:6, 5:], init_disen)
        feat = self._dynamics.get_feat(prior)
        feat_disen = self._disen_dynamics.get_feat(prior_disen)
        openl_joint, openl_main, openl_disen, openl_mask = self._joint_decode(
            feat, feat_disen)
        openl_joint = openl_joint.mode()
        openl_main = openl_main.mode()
        openl_disen = openl_disen.mode()
        model_joint = tf.concat(
            [recon_joint[:, :5] + 0.5, openl_joint + 0.5], 1)
        error_joint = (model_joint - truth + 1) / 2
        model_main = tf.concat(
            [recon_main[:, :5] + 0.5, openl_main + 0.5], 1)
        model_disen = tf.concat(
            [recon_disen[:, :5] + 0.5, openl_disen + 0.5], 1)
        model_mask = tf.concat(
            [recon_mask[:, :5] + 0.5, openl_mask + 0.5], 1)
        output_joint = tf.concat(
            [truth, model_main, model_disen, model_joint, error_joint], 2)
        output_mask = model_mask
        tools.graph_summary(
            self._writer, tools.video_summary, self._step, 'summary/openl', output_joint)
        tools.graph_summary(
            self._writer, tools.video_summary, self._step, 'summary/openl_mask', output_mask)

    def _write_summaries(self):
        """Dump and reset the task and disen metrics.

        Task metrics (plus ``fps`` from the second call on) go to
        ``metrics.jsonl``, to TF summaries under ``agent/`` and to stdout.
        Disen metrics go to TF summaries under ``disen/`` and to stdout only.
        """
        step = int(self._step.numpy())
        metrics = [(k, float(v.result())) for k, v in self._metrics.items()]
        metrics_disen = [(k, float(v.result()))
                         for k, v in self._metrics_disen.items()]
        now = time.time()
        if self._last_log is not None:
            duration = now - self._last_time
            metrics.append(('fps', (step - self._last_log) / duration))
        self._last_log = step
        # Restart the timer on every dump, including the first, so the next
        # fps covers exactly the interval since this dump.
        self._last_time = now
        for metric in self._metrics.values():
            metric.reset_states()
        for metric in self._metrics_disen.values():
            metric.reset_states()
        with (self._c.logdir / 'metrics.jsonl').open('a') as f:
            f.write(json.dumps({'step': step, **dict(metrics)}) + '\n')
        for name, value in metrics:
            tf.summary.scalar('agent/' + name, value)
        for name, value in metrics_disen:
            tf.summary.scalar('disen/' + name, value)
        print('#'*30 + ' Main ' + '#'*30)
        print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in metrics))
        print('#'*30 + ' Disen ' + '#'*30)
        print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in metrics_disen))
        self._writer.flush()
class InverseDreamer(Dreamer):
    """Dreamer variant whose world model predicts actions instead of images.

    ``_decode`` is a ``models.InverseDecoder``. Its log-likelihood of the
    recorded actions replaces the image reconstruction term in the model
    loss. Actor and value training are the same as in ``Dreamer``.
    Construction is inherited unchanged from ``Dreamer``.
    """

    def _train(self, data, log_images):
        """Per-replica update; ``log_images`` is accepted but unused here."""
        with tf.GradientTape() as model_tape:
            # NOTE(review): unlike Dreamer._train, images are not transposed
            # before encoding — confirm the encoder receives its expected layout.
            embed = self._encode(data)
            post, prior = self._dynamics.observe(embed, data['action'])
            feat = self._dynamics.get_feat(post)
            action_pred = self._decode(feat)
            reward_pred = self._reward(feat)
            likes = tools.AttrDict()
            # NOTE(review): the target drops the last action of each
            # sequence; whether this aligns with what InverseDecoder predicts
            # from ``feat`` depends on models.InverseDecoder — confirm.
            likes.action = tf.reduce_mean(
                action_pred.log_prob(data['action'][:, :-1]))
            likes.reward = tf.reduce_mean(
                reward_pred.log_prob(data['reward']))
            if self._c.pcont:
                pcont_pred = self._pcont(feat)
                pcont_target = self._c.discount * data['discount']
                likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
                likes.pcont *= self._c.pcont_scale
            prior_dist = self._dynamics.get_dist(prior)
            post_dist = self._dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            div = tf.maximum(div, self._c.free_nats)
            model_loss = self._c.kl_scale * div - sum(likes.values())
            model_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as actor_tape:
            imag_feat = self._imagine_ahead(post)
            reward = self._reward(imag_feat).mode()
            if self._c.pcont:
                pcont = self._pcont(imag_feat).mean()
            else:
                pcont = self._c.discount * tf.ones_like(reward)
            value = self._value(imag_feat).mode()
            returns = tools.lambda_return(
                reward[:-1], value[:-1], pcont[:-1],
                bootstrap=value[-1], lambda_=self._c.disclam, axis=0)
            discount = tf.stop_gradient(tf.math.cumprod(tf.concat(
                [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0))
            actor_loss = -tf.reduce_mean(discount * returns)
            actor_loss /= float(self._strategy.num_replicas_in_sync)

        with tf.GradientTape() as value_tape:
            value_pred = self._value(imag_feat)[:-1]
            target = tf.stop_gradient(returns)
            value_loss = -tf.reduce_mean(
                discount * value_pred.log_prob(target))
            value_loss /= float(self._strategy.num_replicas_in_sync)

        model_norm = self._model_opt(model_tape, model_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        value_norm = self._value_opt(value_tape, value_loss)

        # Only the first replica writes summaries; no image summaries here.
        if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
            if self._c.log_scalars:
                self._scalar_summaries(
                    data, feat, prior_dist, post_dist, likes, div,
                    model_loss, value_loss, actor_loss, model_norm, value_norm,
                    actor_norm)

    def _build_model(self):
        """Like ``Dreamer._build_model`` but with an ``InverseDecoder``.

        The trailing ``self.train`` call is a real training step on one batch.
        """
        acts = dict(
            elu=tf.nn.elu, relu=tf.nn.relu, swish=tf.nn.swish,
            leaky_relu=tf.nn.leaky_relu)
        cnn_act = acts[self._c.cnn_act]
        act = acts[self._c.dense_act]
        self._encode = models.ConvEncoder(
            self._c.cnn_depth, cnn_act, self._c.image_size)
        self._dynamics = models.RSSM(
            self._c.stoch_size, self._c.deter_size, self._c.deter_size)
        self._decode = models.InverseDecoder(
            self._actdim, 4, self._c.num_units, act=act)
        self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act)
        if self._c.pcont:
            self._pcont = models.DenseDecoder(
                (), 3, self._c.num_units, 'binary', act=act)
        self._value = models.DenseDecoder((), 3, self._c.num_units, act=act)
        self._actor = models.ActionDecoder(
            self._actdim, 4, self._c.num_units, self._c.action_dist,
            init_std=self._c.action_init_std, act=act)
        model_modules = [self._encode, self._dynamics,
                         self._decode, self._reward]
        if self._c.pcont:
            model_modules.append(self._pcont)
        Optimizer = functools.partial(
            tools.Adam, wd=self._c.weight_decay, clip=self._c.grad_clip,
            wdpattern=self._c.weight_decay_pattern)
        self._model_opt = Optimizer('model', model_modules, self._c.model_lr)
        self._value_opt = Optimizer('value', [self._value], self._c.value_lr)
        self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr)
        self.train(next(self._dataset))