# tia/Dreamer/run.py
import wrappers
import tools
import tensorflow as tf
import argparse
import functools
import yaml
import os
import pathlib
import sys
from tensorflow.keras.mixed_precision import experimental as prec
from dreamers import Dreamer, SeparationDreamer, InverseDreamer
from env_tools import count_steps, make_env
# Maps the value of the --method command-line option to the agent class that
# main() instantiates for it.
METHOD2DREAMER = dict(
    dreamer=Dreamer,
    tia=SeparationDreamer,
    inverse=InverseDreamer,
)

# Process-wide settings: TF_CPP_MIN_LOG_LEVEL and MUJOCO_GL are set in the
# environment, and the TensorFlow Python logger is limited to errors.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['MUJOCO_GL'] = 'egl'
tf.get_logger().setLevel('ERROR')

# Append the directory that contains this script to the module search path.
sys.path.append(str(pathlib.Path(__file__).parent))
def main(method, config):
    """Collect data, train and periodically evaluate an agent.

    The run directory is ``<config.logdir>/<config.task>/<method>/<config.seed>``.
    It receives a ``config.yaml`` dump of the configuration, an ``episodes``
    directory, the latest checkpoint ``variables.pkl`` and a ``snapshots``
    sub-directory with step-stamped checkpoints.

    Args:
        method: Key into ``METHOD2DREAMER`` selecting the agent class.
        config: Attribute namespace holding the run configuration (it must
            support ``vars()``). It is mutated in place: ``logdir`` is
            replaced by the resolved run directory as a ``pathlib.Path`` and
            ``steps`` is coerced to ``int``.

    Raises:
        AssertionError: If ``config.precision`` is neither 16 nor 32.
        KeyError: If ``method`` is not a key of ``METHOD2DREAMER``.
    """
    # NOTE(review): 'separation' is not a key of METHOD2DREAMER and is not
    # among the --method choices offered by the command-line parser, so this
    # branch looks unreachable from the CLI. Confirm whether 'tia' was meant
    # before changing it: doing so would relocate existing run directories.
    if method == 'separation':
        config.logdir = os.path.join(
            config.logdir, config.task,
            'separation' + '_' + str(config.disen_neg_rew_scale) +
            '_' + str(config.disen_rec_scale),
            str(config.seed))
    else:
        config.logdir = os.path.join(
            config.logdir, config.task,
            method,
            str(config.seed))

    logdir = pathlib.Path(config.logdir)
    logdir.mkdir(parents=True, exist_ok=True)
    snapshot_dir = pathlib.Path(os.path.join(config.logdir, 'snapshots'))
    snapshot_dir.mkdir(parents=True, exist_ok=True)

    # Dump the configuration while config.logdir is still a plain string.
    with open(os.path.join(config.logdir, 'config.yaml'), 'w') as f:
        yaml.dump(vars(config), f, sort_keys=False)

    if config.gpu_growth:
        for gpu in tf.config.experimental.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)
    assert config.precision in (16, 32), config.precision
    if config.precision == 16:
        prec.set_policy(prec.Policy('mixed_float16'))
    config.steps = int(config.steps)
    config.logdir = logdir
    print('Logdir', config.logdir)

    # Create environments.
    datadir = config.logdir / 'episodes'
    writer = tf.summary.create_file_writer(
        str(config.logdir), max_queue=1000, flush_millis=20000)
    writer.set_as_default()

    # The lists are filled one environment at a time inside the try block so
    # that the finally clause closes whatever was created, even if a later
    # constructor or the training loop raises.
    train_envs = []
    test_envs = []
    try:
        for _ in range(config.envs):
            train_envs.append(wrappers.Async(lambda: make_env(
                config, writer, 'train', datadir, config.video_dir_train,
                store=True), config.parallel))
        for _ in range(config.envs):
            test_envs.append(wrappers.Async(lambda: make_env(
                config, writer, 'test', datadir, config.video_dir_test,
                store=False), config.parallel))
        actspace = train_envs[0].action_space

        # Prefill dataset with random episodes.
        step = count_steps(datadir, config)
        prefill = max(0, config.prefill - step)
        print(f'Prefill dataset with {prefill} steps.')

        def random_agent(o, d, _):
            # One sampled action per entry of `d`; no agent state is kept.
            return ([actspace.sample() for _ in d], None)

        tools.simulate(random_agent, train_envs, prefill / config.action_repeat)
        writer.flush()

        # Train and regularly evaluate the agent.
        step = count_steps(datadir, config)
        print(f'Simulating agent for {config.steps-step} steps.')
        DreamerModel = METHOD2DREAMER[method]
        agent = DreamerModel(config, datadir, actspace, writer)
        if (config.logdir / 'variables.pkl').exists():
            print('Load checkpoint.')
            agent.load(config.logdir / 'variables.pkl')
        state = None
        should_snapshot = tools.Every(config.snapshot_every)
        while step < config.steps:
            print('Start evaluation.')
            tools.simulate(
                functools.partial(agent, training=False), test_envs, episodes=1)
            writer.flush()
            print('Start collection.')
            steps = config.eval_every // config.action_repeat
            state = tools.simulate(agent, train_envs, steps, state=state)
            step = count_steps(datadir, config)
            # Overwrite the rolling checkpoint after every collection phase.
            agent.save(config.logdir / 'variables.pkl')
            if should_snapshot(step):
                agent.save(snapshot_dir / ('variables_' + str(step) + '.pkl'))
    finally:
        for env in train_envs + test_envs:
            env.close()
if __name__ == '__main__':
    # Stage 1: parse only --method and --configs; all other arguments are
    # collected untouched in `leftover` for the second parser.
    selector = argparse.ArgumentParser()
    selector.add_argument(
        '--method', type=str, choices=['dreamer', 'inverse', 'tia'], required=True)
    selector.add_argument('--configs', nargs='+', required=True)
    args, leftover = selector.parse_known_args()

    # Load the YAML file for the chosen method, located next to this script.
    config_path = 'train_configs/' + args.method + '.yaml'
    yaml_file = pathlib.Path(__file__).parent / config_path
    presets = yaml.safe_load(yaml_file.read_text())

    # Merge the requested sections in the order given; later ones override
    # keys set by earlier ones.
    defaults = {}
    for section in args.configs:
        defaults.update(presets[section])

    # Stage 2: every merged key becomes a command-line option whose default
    # is the value from the YAML file.
    overrides = argparse.ArgumentParser()
    for option, default in defaults.items():
        overrides.add_argument(
            f'--{option}', type=tools.args_type(default), default=default)

    main(args.method, overrides.parse_args(leftover))