# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import argparse
import os
import gym
import time
import json
import dmc2gym

import utils
from logger import Logger
from video import VideoRecorder

from agent.baseline_agent import BaselineAgent
from agent.bisim_agent import BisimAgent
from agent.deepmdp_agent import DeepMDPAgent
from agents.navigation.carla_env import CarlaEnv


def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str,
                        choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
    # train
    parser.add_argument('--agent', default='bisim', type=str,
                        choices=['baseline', 'bisim', 'deepmdp'])
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=1000000, type=int)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--hidden_dim', default=256, type=int)
    parser.add_argument('--k', default=3, type=int,
                        help='number of steps for inverse model')
    parser.add_argument('--bisim_coef', default=0.5, type=float,
                        help='coefficient for bisim terms')
    parser.add_argument('--load_encoder', default=None, type=str)
    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.005, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str,
                        choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--encoder_stride', default=1, type=int)
    parser.add_argument('--decoder_type', default='pixel', type=str,
                        choices=['pixel', 'identity', 'contrastive', 'reward',
                                 'inverse', 'reconstruction'])
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str,
                        choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--render', default=False, action='store_true')
    parser.add_argument('--port', default=2000, type=int)
    args = parser.parse_args()
    return args


def evaluate(env, agent, video, num_episodes, L, step, device=None,
             embed_viz_dir=None, do_carla_metrics=None):
    """Run ``num_episodes`` evaluation episodes and log their rewards.

    Args:
        env: Environment exposing ``reset()`` and a 4-tuple ``step(action)``.
        agent: Agent exposing ``select_action``; ``critic`` and
            ``critic.encoder`` are only used when ``embed_viz_dir`` is set.
        video: Recorder with ``init``/``record``/``save``; recording is
            enabled for the first episode only.
        num_episodes: Number of episodes to roll out.
        L: Logger with ``log(key, value, step)`` and ``dump(step)``.
        step: Training step used as the logging index and in file names.
        device: Torch device for the embedding-visualization tensors.
        embed_viz_dir: If truthy, observations, critic values and encoder
            embeddings are collected and saved there with ``torch.save``.
        do_carla_metrics: If truthy, driving metrics are accumulated from the
            ``info`` dict returned by ``env.step`` and printed at the end.
    """
    # carla metrics:
    reason_each_episode_ended = []
    distance_driven_each_episode = []
    crash_intensity = 0.
    steer = 0.
    brake = 0.
    count = 0

    # embedding visualization
    obses = []
    values = []
    embeddings = []

    for i in range(num_episodes):
        # carla metrics:
        dist_driven_this_episode = 0.

        obs = env.reset()
        video.init(enabled=(i == 0))
        done = False
        episode_reward = 0
        while not done:
            with utils.eval_mode(agent):
                action = agent.select_action(obs)

            if embed_viz_dir:
                obses.append(obs)
                with torch.no_grad():
                    # Batch dimension of 1 is added for the networks.
                    obs_batch = torch.Tensor(obs).to(device).unsqueeze(0)
                    action_batch = torch.Tensor(action).to(device).unsqueeze(0)
                    # min over whatever agent.critic returns
                    # (presumably the twin Q estimates -- TODO confirm).
                    values.append(min(agent.critic(obs_batch, action_batch)).item())
                    embeddings.append(
                        agent.critic.encoder(obs_batch).cpu().detach().numpy())

            obs, reward, done, info = env.step(action)

            # metrics:
            if do_carla_metrics:
                dist_driven_this_episode += info['distance']
                crash_intensity += info['crash_intensity']
                steer += abs(info['steer'])
                brake += info['brake']
                count += 1

            video.record(env)
            episode_reward += reward

        # metrics:
        if do_carla_metrics:
            # `info` here is the one returned by the final step of the episode.
            reason_each_episode_ended.append(info['reason_episode_ended'])
            distance_driven_each_episode.append(dist_driven_this_episode)

        video.save('%d.mp4' % step)
        L.log('eval/episode_reward', episode_reward, step)

    if embed_viz_dir:
        dataset = {'obs': obses, 'values': values, 'embeddings': embeddings}
        torch.save(dataset, os.path.join(embed_viz_dir,
                                         'train_dataset_{}.pt'.format(step)))

    L.dump(step)

    if do_carla_metrics:
        # Floor the denominators at 1 so that an evaluation with zero
        # episodes/steps prints 0.0 instead of raising ZeroDivisionError.
        episodes_denominator = max(num_episodes, 1)
        steps_denominator = max(count, 1)
        print('METRICS--------------------------')
        print("reason_each_episode_ended: {}".format(reason_each_episode_ended))
        print("distance_driven_each_episode: {}".format(distance_driven_each_episode))
        print('crash_intensity: {}'.format(crash_intensity / episodes_denominator))
        print('steer: {}'.format(steer / steps_denominator))
        print('brake: {}'.format(brake / steps_denominator))
        print('---------------------------------')


def make_agent(obs_shape, action_shape, args, device):
    """Build the agent selected by ``args.agent``.

    Args:
        obs_shape: Observation shape forwarded to the agent constructor.
        action_shape: Action shape forwarded to the agent constructor.
        args: Parsed command-line namespace (see ``parse_args``).
        device: Torch device forwarded to the agent; also used as the
            ``map_location`` when loading a pretrained encoder.

    Returns:
        A ``BaselineAgent``, ``BisimAgent`` or ``DeepMDPAgent``. When
        ``args.load_encoder`` is set, the actor and critic encoders are
        initialised from that checkpoint.

    Raises:
        ValueError: If ``args.agent`` is not one of the supported names.
    """
    # Keyword arguments shared by all three agent constructors.
    common_kwargs = dict(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        encoder_stride=args.encoder_stride,
        decoder_type=args.decoder_type,
        decoder_lr=args.decoder_lr,
        decoder_update_freq=args.decoder_update_freq,
        decoder_weight_lambda=args.decoder_weight_lambda,
        transition_model_type=args.transition_model_type,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
    )

    if args.agent == 'baseline':
        agent = BaselineAgent(**common_kwargs)
    elif args.agent == 'bisim':
        # The bisimulation agent is the only one taking `bisim_coef`.
        agent = BisimAgent(bisim_coef=args.bisim_coef, **common_kwargs)
    elif args.agent == 'deepmdp':
        agent = DeepMDPAgent(**common_kwargs)
    else:
        raise ValueError('unknown agent type: {!r}'.format(args.agent))

    if args.load_encoder:
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be loaded on the device this run actually uses.
        checkpoint = torch.load(args.load_encoder, map_location=device)
        # Keep only the encoder sub-module's entries and strip the prefix so
        # the keys match the encoder's own state dict.
        prefix = 'encoder.'
        encoder_dict = {
            key[len(prefix):]: value
            for key, value in checkpoint.items()
            if key.startswith(prefix)
        }
        agent.actor.encoder.load_state_dict(encoder_dict)
        agent.critic.encoder.load_state_dict(encoder_dict)

    return agent


def main():
    """Entry point: build environments, agent and replay buffer, then train.

    Evaluation (and optional model/buffer saving) runs at episode boundaries
    whenever ``episode % args.eval_freq == 0``.
    """
    args = parse_args()
    utils.set_seed_everywhere(args.seed)

    if args.domain_name == 'carla':
        env = CarlaEnv(
            render_display=args.render,  # for local debugging only
            display_text=args.render,  # for local debugging only
            changing_weather_speed=0.1,  # [0, +inf)
            rl_image_size=args.image_size,
            max_episode_steps=1000,
            frame_skip=args.action_repeat,
            is_other_cars=True,
            port=args.port
        )
        # TODO: implement env.seed(args.seed) ?

        # The same CARLA environment instance is used for evaluation.
        eval_env = env
    else:
        env = dmc2gym.make(
            domain_name=args.domain_name,
            task_name=args.task_name,
            resource_files=args.resource_files,
            img_source=args.img_source,
            total_frames=args.total_frames,
            seed=args.seed,
            visualize_reward=False,
            from_pixels=(args.encoder_type == 'pixel'),
            height=args.image_size,
            width=args.image_size,
            frame_skip=args.action_repeat
        )
        env.seed(args.seed)

        eval_env = dmc2gym.make(
            domain_name=args.domain_name,
            task_name=args.task_name,
            resource_files=args.eval_resource_files,
            img_source=args.img_source,
            total_frames=args.total_frames,
            seed=args.seed,
            visualize_reward=False,
            from_pixels=(args.encoder_type == 'pixel'),
            height=args.image_size,
            width=args.image_size,
            frame_skip=args.action_repeat
        )

    # stack several consecutive frames together
    if args.encoder_type.startswith('pixel'):
        env = utils.FrameStack(env, k=args.frame_stack)
        eval_env = utils.FrameStack(eval_env, k=args.frame_stack)

    utils.make_dir(args.work_dir)
    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))

    video = VideoRecorder(video_dir if args.save_video else None)

    # Persist the run configuration next to the outputs.
    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the dmc2gym wrapper standardizes actions
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    replay_buffer = utils.ReplayBuffer(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device
    )

    agent = make_agent(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        args=args,
        device=device
    )

    L = Logger(args.work_dir, use_tb=args.save_tb)

    # `done` starts True so the first iteration resets the environment.
    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    for step in range(args.num_train_steps):
        if done:
            if args.decoder_type == 'inverse':
                for i in range(1, args.k):  # fill k_obs with 0s if episode is done
                    replay_buffer.k_obses[replay_buffer.idx - i] = 0
            if step > 0:
                L.log('train/duration', time.time() - start_time, step)
                start_time = time.time()
                L.dump(step)

            # evaluate agent periodically
            if episode % args.eval_freq == 0:
                L.log('eval/episode', episode, step)
                evaluate(eval_env, agent, video, args.num_eval_episodes, L, step)
                if args.save_model:
                    agent.save(model_dir, step)
                if args.save_buffer:
                    replay_buffer.save(buffer_dir)

            L.log('train/episode_reward', episode_reward, step)

            obs = env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1
            reward = 0

            L.log('train/episode', episode, step)

        # sample action for data collection
        if step < args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs)

        # run training update
        if step >= args.init_steps:
            # The first update step runs `init_steps` updates in one go,
            # every later step runs a single update.
            num_updates = args.init_steps if step == args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)

        # Reward of the previous transition (0 at the start of an episode).
        curr_reward = reward
        next_obs, reward, done, _ = env.step(action)

        # allow infinite bootstrap: an episode that ends only because the
        # step limit was reached is stored as not done
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
            done
        )
        episode_reward += reward

        replay_buffer.add(obs, action, curr_reward, reward, next_obs, done_bool)
        np.copyto(replay_buffer.k_obses[replay_buffer.idx - args.k], next_obs)

        obs = next_obs
        episode_step += 1


def collect_data(env, agent, num_rollouts, path_length, checkpoint_path):
    """Roll out the agent and save the trajectories as a ``.mat`` file.

    Each rollout runs for exactly ``path_length`` steps; the ``done`` flag
    returned by ``env.step`` is not consulted.

    Args:
        env: Environment exposing ``reset()`` and a 4-tuple ``step(action)``.
        agent: Agent exposing ``sample_action(observation)``.
        num_rollouts: Number of rollouts to collect.
        path_length: Number of environment steps per rollout.
        checkpoint_path: Directory that receives ``dynamics-data.mat``.
    """
    rollouts = []
    for _ in range(num_rollouts):
        obses = []
        acs = []
        rews = []
        observation = env.reset()
        for _ in range(path_length):
            action = agent.sample_action(observation)
            next_observation, reward, _done, _ = env.step(action)
            obses.append(observation)
            acs.append(action)
            rews.append(reward)
            observation = next_observation
        # `observation` holds the last observation reached; using it rather
        # than `next_observation` keeps path_length == 0 from raising
        # NameError and is identical otherwise.
        obses.append(observation)
        rollouts.append((obses, acs, rews))

    # Imported lazily so scipy is only needed when data is collected.
    from scipy.io import savemat

    savemat(
        os.path.join(checkpoint_path, "dynamics-data.mat"),
        {
            "trajs": np.array([path[0] for path in rollouts]),
            "acs": np.array([path[1] for path in rollouts]),
            "rews": np.array([path[2] for path in rollouts])
        }
    )


if __name__ == '__main__':
    main()