# DBC/train.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import argparse
import os
import gym
import time
import json
import dmc2gym
import utils
from logger import Logger
from video import VideoRecorder
from agent.baseline_agent import BaselineAgent
from agent.bisim_agent import BisimAgent
from agent.deepmdp_agent import DeepMDPAgent
from agents.navigation.carla_env import CarlaEnv


def parse_args():
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
    # train
    parser.add_argument('--agent', default='bisim', type=str, choices=['baseline', 'bisim', 'deepmdp'])
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=1000000, type=int)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--hidden_dim', default=256, type=int)
    parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model')
    parser.add_argument('--bisim_coef', default=0.5, type=float, help='coefficient for bisim terms')
    parser.add_argument('--load_encoder', default=None, type=str)
    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.005, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--encoder_stride', default=1, type=int)
    parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--render', default=False, action='store_true')
    parser.add_argument('--port', default=2000, type=int)
    args = parser.parse_args()
    return args
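

# Example invocation for a DMC task (illustrative only; the flags defined
# above are the source of truth):
#
#   python train.py --domain_name cheetah --task_name run \
#       --encoder_type pixel --decoder_type identity --action_repeat 4 \
#       --agent bisim --transition_model_type probabilistic \
#       --work_dir ./log --save_video --save_tb --seed 1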


def evaluate(env, agent, video, num_episodes, L, step, device=None, embed_viz_dir=None, do_carla_metrics=None):
    """Run `num_episodes` evaluation episodes and log the per-episode return.

    Optionally dumps (obs, value, embedding) data for embedding visualization
    and aggregates CARLA driving metrics.
    """
    # carla metrics:
    reason_each_episode_ended = []
    distance_driven_each_episode = []
    crash_intensity = 0.
    steer = 0.
    brake = 0.
    count = 0

    # embedding visualization
    obses = []
    values = []
    embeddings = []

    for i in range(num_episodes):
        # carla metrics:
        dist_driven_this_episode = 0.

        obs = env.reset()
        video.init(enabled=(i == 0))
        done = False
        episode_reward = 0
        while not done:
            with utils.eval_mode(agent):
                action = agent.select_action(obs)

            if embed_viz_dir:
                obses.append(obs)
                with torch.no_grad():
                    # the critic returns a pair of Q-estimates; keep the
                    # pessimistic (minimum) one as the value estimate
                    values.append(min(agent.critic(torch.Tensor(obs).to(device).unsqueeze(0), torch.Tensor(action).to(device).unsqueeze(0))).item())
                    embeddings.append(agent.critic.encoder(torch.Tensor(obs).unsqueeze(0).to(device)).cpu().detach().numpy())

            obs, reward, done, info = env.step(action)

            # metrics:
            if do_carla_metrics:
                dist_driven_this_episode += info['distance']
                crash_intensity += info['crash_intensity']
                steer += abs(info['steer'])
                brake += info['brake']
                count += 1

            video.record(env)
            episode_reward += reward

        # metrics:
        if do_carla_metrics:
            reason_each_episode_ended.append(info['reason_episode_ended'])
            distance_driven_each_episode.append(dist_driven_this_episode)

        video.save('%d.mp4' % step)
        L.log('eval/episode_reward', episode_reward, step)

    if embed_viz_dir:
        dataset = {'obs': obses, 'values': values, 'embeddings': embeddings}
        torch.save(dataset, os.path.join(embed_viz_dir, 'train_dataset_{}.pt'.format(step)))

    L.dump(step)

    if do_carla_metrics:
        print('METRICS--------------------------')
        print("reason_each_episode_ended: {}".format(reason_each_episode_ended))
        print("distance_driven_each_episode: {}".format(distance_driven_each_episode))
        print('crash_intensity: {}'.format(crash_intensity / num_episodes))
        print('steer: {}'.format(steer / count))
        print('brake: {}'.format(brake / count))
        print('---------------------------------')
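

# `utils.eval_mode(agent)` above is assumed to be the usual SAC-AE-style
# context manager that switches the agent's networks to eval mode and restores
# their previous mode on exit. A minimal sketch, for reference only (the
# repo's utils.py is authoritative):
#
#     class eval_mode(object):
#         def __init__(self, *models):
#             self.models = models
#
#         def __enter__(self):
#             self.prev_states = [m.training for m in self.models]
#             for m in self.models:
#                 m.train(False)
#
#         def __exit__(self, *args):
#             for m, state in zip(self.models, self.prev_states):
#                 m.train(state)
#             return False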


def make_agent(obs_shape, action_shape, args, device):
    """Build the agent selected by --agent: baseline SAC, bisim (DBC), or DeepMDP."""
    if args.agent == 'baseline':
        agent = BaselineAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            hidden_dim=args.hidden_dim,
            discount=args.discount,
            init_temperature=args.init_temperature,
            alpha_lr=args.alpha_lr,
            alpha_beta=args.alpha_beta,
            actor_lr=args.actor_lr,
            actor_beta=args.actor_beta,
            actor_log_std_min=args.actor_log_std_min,
            actor_log_std_max=args.actor_log_std_max,
            actor_update_freq=args.actor_update_freq,
            critic_lr=args.critic_lr,
            critic_beta=args.critic_beta,
            critic_tau=args.critic_tau,
            critic_target_update_freq=args.critic_target_update_freq,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            encoder_lr=args.encoder_lr,
            encoder_tau=args.encoder_tau,
            encoder_stride=args.encoder_stride,
            decoder_type=args.decoder_type,
            decoder_lr=args.decoder_lr,
            decoder_update_freq=args.decoder_update_freq,
            decoder_weight_lambda=args.decoder_weight_lambda,
            transition_model_type=args.transition_model_type,
            num_layers=args.num_layers,
            num_filters=args.num_filters
        )
    elif args.agent == 'bisim':
        agent = BisimAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            hidden_dim=args.hidden_dim,
            discount=args.discount,
            init_temperature=args.init_temperature,
            alpha_lr=args.alpha_lr,
            alpha_beta=args.alpha_beta,
            actor_lr=args.actor_lr,
            actor_beta=args.actor_beta,
            actor_log_std_min=args.actor_log_std_min,
            actor_log_std_max=args.actor_log_std_max,
            actor_update_freq=args.actor_update_freq,
            critic_lr=args.critic_lr,
            critic_beta=args.critic_beta,
            critic_tau=args.critic_tau,
            critic_target_update_freq=args.critic_target_update_freq,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            encoder_lr=args.encoder_lr,
            encoder_tau=args.encoder_tau,
            encoder_stride=args.encoder_stride,
            decoder_type=args.decoder_type,
            decoder_lr=args.decoder_lr,
            decoder_update_freq=args.decoder_update_freq,
            decoder_weight_lambda=args.decoder_weight_lambda,
            transition_model_type=args.transition_model_type,
            num_layers=args.num_layers,
            num_filters=args.num_filters,
            bisim_coef=args.bisim_coef
        )
    elif args.agent == 'deepmdp':
        agent = DeepMDPAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            hidden_dim=args.hidden_dim,
            discount=args.discount,
            init_temperature=args.init_temperature,
            alpha_lr=args.alpha_lr,
            alpha_beta=args.alpha_beta,
            actor_lr=args.actor_lr,
            actor_beta=args.actor_beta,
            actor_log_std_min=args.actor_log_std_min,
            actor_log_std_max=args.actor_log_std_max,
            actor_update_freq=args.actor_update_freq,
            encoder_stride=args.encoder_stride,
            critic_lr=args.critic_lr,
            critic_beta=args.critic_beta,
            critic_tau=args.critic_tau,
            critic_target_update_freq=args.critic_target_update_freq,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            encoder_lr=args.encoder_lr,
            encoder_tau=args.encoder_tau,
            decoder_type=args.decoder_type,
            decoder_lr=args.decoder_lr,
            decoder_update_freq=args.decoder_update_freq,
            decoder_weight_lambda=args.decoder_weight_lambda,
            transition_model_type=args.transition_model_type,
            num_layers=args.num_layers,
            num_filters=args.num_filters
        )
    if args.load_encoder:
        model_dict = agent.actor.encoder.state_dict()
        encoder_dict = torch.load(args.load_encoder)
        # keep only encoder weights and strip the 'encoder.' prefix from their keys
        encoder_dict = {k[8:]: v for k, v in encoder_dict.items() if 'encoder.' in k}
        agent.actor.encoder.load_state_dict(encoder_dict)
        agent.critic.encoder.load_state_dict(encoder_dict)

    return agent
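

# Illustration of the --load_encoder filtering above: a checkpoint saved from
# a full agent might hold a key such as 'encoder.convs.0.weight' (hypothetical
# name). Entries whose key contains 'encoder.' are kept and their first eight
# characters (len('encoder.')) are dropped, yielding 'convs.0.weight', which
# matches the bare encoder's state_dict keys.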


def main():
    args = parse_args()
    utils.set_seed_everywhere(args.seed)

    if args.domain_name == 'carla':
        env = CarlaEnv(
            render_display=args.render,  # for local debugging only
            display_text=args.render,  # for local debugging only
            changing_weather_speed=0.1,  # [0, +inf)
            rl_image_size=args.image_size,
            max_episode_steps=1000,
            frame_skip=args.action_repeat,
            is_other_cars=True,
            port=args.port
        )
        # TODO: implement env.seed(args.seed) ?
        eval_env = env
    else:
        env = dmc2gym.make(
            domain_name=args.domain_name,
            task_name=args.task_name,
            resource_files=args.resource_files,
            img_source=args.img_source,
            total_frames=args.total_frames,
            seed=args.seed,
            visualize_reward=False,
            from_pixels=(args.encoder_type == 'pixel'),
            height=args.image_size,
            width=args.image_size,
            frame_skip=args.action_repeat
        )
        env.seed(args.seed)
        eval_env = dmc2gym.make(
            domain_name=args.domain_name,
            task_name=args.task_name,
            resource_files=args.eval_resource_files,
            img_source=args.img_source,
            total_frames=args.total_frames,
            seed=args.seed,
            visualize_reward=False,
            from_pixels=(args.encoder_type == 'pixel'),
            height=args.image_size,
            width=args.image_size,
            frame_skip=args.action_repeat
        )

    # stack several consecutive frames together
    if args.encoder_type.startswith('pixel'):
        env = utils.FrameStack(env, k=args.frame_stack)
        eval_env = utils.FrameStack(eval_env, k=args.frame_stack)

    utils.make_dir(args.work_dir)
    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))

    video = VideoRecorder(video_dir if args.save_video else None)

    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the dmc2gym wrapper standardizes actions
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    replay_buffer = utils.ReplayBuffer(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device
    )
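
    # Interface this script assumes of utils.ReplayBuffer (see utils.py for
    # the actual implementation):
    #   .add(obs, action, curr_reward, reward, next_obs, done_bool)
    #   .k_obses / .idx  (k-step observation bookkeeping)
    #   .save(buffer_dir)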

    agent = make_agent(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        args=args,
        device=device
    )

    L = Logger(args.work_dir, use_tb=args.save_tb)

    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    for step in range(args.num_train_steps):
        if done:
            if args.decoder_type == 'inverse':
                for i in range(1, args.k):  # fill k_obs with 0s if episode is done
                    replay_buffer.k_obses[replay_buffer.idx - i] = 0
            if step > 0:
                L.log('train/duration', time.time() - start_time, step)
                start_time = time.time()
                L.dump(step)

            # evaluate agent periodically
            if episode % args.eval_freq == 0:
                L.log('eval/episode', episode, step)
                evaluate(eval_env, agent, video, args.num_eval_episodes, L, step)
                if args.save_model:
                    agent.save(model_dir, step)
                if args.save_buffer:
                    replay_buffer.save(buffer_dir)

            L.log('train/episode_reward', episode_reward, step)

            obs = env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1
            reward = 0

            L.log('train/episode', episode, step)

        # sample action for data collection
        if step < args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs)

        # run training update
        if step >= args.init_steps:
            # once seeding finishes, catch up with one update per seed step
            num_updates = args.init_steps if step == args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)

        curr_reward = reward
        next_obs, reward, done, _ = env.step(action)

        # allow infinite bootstrap: ignore the done signal when the episode
        # ends only because the time limit was reached
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
        episode_reward += reward

        replay_buffer.add(obs, action, curr_reward, reward, next_obs, done_bool)
        np.copyto(replay_buffer.k_obses[replay_buffer.idx - args.k], next_obs)

        obs = next_obs
        episode_step += 1


def collect_data(env, agent, num_rollouts, path_length, checkpoint_path):
    """Roll out the agent and dump (obs, action, reward) trajectories to a .mat file.

    Not called from main() in this script; kept as a standalone helper for
    collecting dynamics data.
    """
    rollouts = []
    for i in range(num_rollouts):
        obses = []
        acs = []
        rews = []
        observation = env.reset()
        for j in range(path_length):
            action = agent.sample_action(observation)
            next_observation, reward, done, _ = env.step(action)
            obses.append(observation)
            acs.append(action)
            rews.append(reward)
            observation = next_observation
        obses.append(next_observation)
        rollouts.append((obses, acs, rews))

    from scipy.io import savemat
    savemat(
        os.path.join(checkpoint_path, "dynamics-data.mat"),
        {
            "trajs": np.array([path[0] for path in rollouts]),
            "acs": np.array([path[1] for path in rollouts]),
            "rews": np.array([path[2] for path in rollouts])
        }
    )
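

# Hypothetical usage of collect_data (not wired into main(); the arguments
# below are placeholders):
#   collect_data(eval_env, agent, num_rollouts=10, path_length=100,
#                checkpoint_path=args.work_dir)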


if __name__ == '__main__':
    main()