85 lines
2.9 KiB
Python
85 lines
2.9 KiB
Python
import os
|
|
import json
|
|
import dmc2gym
|
|
import tensorflow as tf
|
|
from tensorflow.keras.mixed_precision import experimental as prec
|
|
|
|
import tools
|
|
import wrappers
|
|
|
|
def preprocess(obs, config):
    """Return a normalized shallow copy of an observation dict.

    The 'image' entry is cast to the compute dtype of the current
    mixed-precision policy, divided by 255 and shifted by -0.5 (so values in
    [0, 255] map to [-0.5, 0.5]). The 'reward' entry is passed through the
    transform selected by ``config.clip_rewards`` ('none' or 'tanh'); any
    other setting raises ``KeyError``. The caller's dict is not modified.
    """
    compute_dtype = prec.global_policy().compute_dtype
    processed = obs.copy()
    reward_transforms = {'none': lambda x: x, 'tanh': tf.tanh}
    # Keep the preprocessing ops on the CPU.
    with tf.device('cpu:0'):
        scaled_image = tf.cast(processed['image'], compute_dtype) / 255.0
        processed['image'] = scaled_image - 0.5
        transform = reward_transforms[config.clip_rewards]
        processed['reward'] = transform(processed['reward'])
    return processed
|
|
|
|
def count_steps(datadir, config):
    """Return the step count of the episodes in ``datadir`` times ``action_repeat``.

    Uses the second element of ``tools.count_episodes(datadir)`` (presumably
    the number of stored agent steps -- confirm against ``tools``) and scales
    it by ``config.action_repeat``.
    """
    stored_steps = tools.count_episodes(datadir)[1]
    return stored_steps * config.action_repeat
|
|
|
|
def summarize_episode(episode, config, datadir, writer, prefix):
    """Report a finished episode to stdout, ``metrics.jsonl`` and TensorBoard.

    Args:
      episode: Mapping whose 'reward' entry supports ``len`` and ``.sum()``;
        when ``prefix == 'test'`` its 'image' entry is also logged as a video.
      config: Provides ``action_repeat`` and ``logdir`` (a path-like object
        supporting ``/`` and ``.open``).
      datadir: Directory whose stored episodes are counted via
        ``tools.count_episodes``.
      writer: Summary writer made the default while scalars are written.
      prefix: Metric namespace (e.g. 'train' or 'test').
    """
    episodes, steps = tools.count_episodes(datadir)
    # Same value as count_steps(datadir, config), but reuses the counts above
    # instead of scanning datadir a second time.
    step = steps * config.action_repeat
    # The first reward entry is not counted as a step (presumably the reset
    # observation -- confirm against the Collect wrapper).
    agent_steps = len(episode['reward']) - 1
    length = agent_steps * config.action_repeat
    ret = float(episode['reward'].sum())
    print(f'{prefix.title()} episode of length {length} with return {ret:.1f}.')
    metrics = [
        (f'{prefix}/return', ret),
        (f'{prefix}/length', agent_steps),
        ('episodes', episodes),
    ]
    with (config.logdir / 'metrics.jsonl').open('a') as f:
        f.write(json.dumps(dict([('step', step)] + metrics)) + '\n')
    with writer.as_default():  # Env might run in a different thread.
        tf.summary.experimental.set_step(step)
        for name, value in metrics:
            tf.summary.scalar('sim/' + name, value)
        if prefix == 'test':
            tools.video_summary(f'sim/{prefix}/video', episode['image'][None])
|
|
|
|
def make_env(config, writer, prefix, datadir, video_dir, store):
    """Build a pixel-based dmc2gym environment wrapped for data collection.

    ``config.task`` has the form '<suite>_<domain>_<task>_<distractor>', e.g.
    'dmc_cheetah_run_driving' or 'dmc_finger_turn_easy_none'. The suite
    prefix is parsed but not used. Supported distractors:
      - 'driving': video background from ``video_dir/*.mp4`` (1000 frames),
      - 'noise':   noise background,
      - 'none':    no distractor.

    The environment is wrapped with DMC2GYMWrapper, a TimeLimit of
    ``config.time_limit / config.action_repeat`` steps, a Collect wrapper
    (which saves episodes to ``datadir`` when ``store`` is true and always
    calls ``summarize_episode``), and RewardObs.

    Raises:
      NotImplementedError: if the distractor name is not supported.
    """
    _, remainder = config.task.split('_', 1)
    # Distractor names contain no underscore, so split from the right. This
    # keeps task names that contain underscores (e.g. 'turn_easy') intact.
    domain_task, distractor = remainder.rsplit('_', 1)
    # Some dm_control domain names contain underscores themselves.
    for multi_word_domain in ('ball_in_cup', 'point_mass'):
        if domain_task.startswith(multi_word_domain + '_'):
            domain = multi_word_domain
            task = domain_task[len(multi_word_domain) + 1:]
            break
    else:
        domain, task = domain_task.split('_', 1)
        if domain == 'cup':  # Dreamer-style shorthand for ball_in_cup.
            domain = 'ball_in_cup'

    if distractor == 'driving':
        img_source = 'video'
        total_frames = 1000
        resource_files = os.path.join(video_dir, '*.mp4')
    elif distractor == 'noise':
        img_source = 'noise'
        total_frames = None
        resource_files = None
    elif distractor == 'none':
        img_source = None
        total_frames = None
        resource_files = None
    else:
        raise NotImplementedError(
            f'Unsupported distractor {distractor!r} in task {config.task!r}.')

    env = dmc2gym.make(
        domain_name=domain,
        task_name=task,
        resource_files=resource_files,
        img_source=img_source,
        total_frames=total_frames,
        seed=config.seed,
        visualize_reward=False,
        from_pixels=True,
        height=config.image_size,
        width=config.image_size,
        frame_skip=config.action_repeat,
    )
    env = wrappers.DMC2GYMWrapper(env)
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)
    callbacks = []
    if store:
        callbacks.append(lambda ep: tools.save_episodes(datadir, [ep]))
    callbacks.append(
        lambda ep: summarize_episode(ep, config, datadir, writer, prefix))
    env = wrappers.Collect(env, callbacks, config.precision)
    env = wrappers.RewardObs(env)
    return env
|