sac_ae_if/train.py

import numpy as np
import torch
import argparse
import os
import math
import gym
import sys
import random
import time
import json
import dmc2gym
import copy
import utils
from logger import Logger
from video import VideoRecorder
from sac import SACAgent
from td3 import TD3Agent
from ddpg import DDPGAgent
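
# Example invocation (a sketch only: every flag below is defined in parse_args(),
# but the particular values and the ./log working directory are illustrative):
#   python train.py --domain_name cheetah --task_name run --encoder_type pixel \
#       --decoder_type pixel --save_video --save_tb --work_dir ./log --seed 1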


def parse_args():
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
    # train
    parser.add_argument('--agent', default='sac', type=str)
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=1000000, type=int)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--hidden_dim', default=256, type=int)
    # eval
    parser.add_argument('--eval_freq', default=10000, type=int)
    parser.add_argument('--num_eval_episodes', default=10, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.005, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='identity', type=str)
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--decoder_type', default='identity', type=str)
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_latent_lambda', default=0.0, type=float)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--decoder_kl_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    parser.add_argument('--freeze_encoder', default=False, action='store_true')
    parser.add_argument('--use_dynamics', default=False, action='store_true')
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)
    # td3
    parser.add_argument('--policy_noise', default=0.2, type=float)
    parser.add_argument('--expl_noise', default=0.1, type=float)
    parser.add_argument('--noise_clip', default=0.5, type=float)
    parser.add_argument('--tau', default=0.005, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--pretrained_info', default=None, type=str)
    parser.add_argument('--pretrained_decoder', default=False, action='store_true')
    args = parser.parse_args()
    return args


def evaluate(env, agent, video, num_episodes, L, step):
    for i in range(num_episodes):
        obs = env.reset()
        video.init(enabled=(i == 0))
        done = False
        episode_reward = 0
        while not done:
            with utils.eval_mode(agent):
                action = agent.select_action(obs)
            obs, reward, done, _ = env.step(action)
            video.record(env)
            episode_reward += reward

        video.save('%d.mp4' % step)
        L.log('eval/episode_reward', episode_reward, step)
    L.dump(step)
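
# Note: evaluation uses agent.select_action (rather than the sample_action used for
# data collection in main) and records video only for the first evaluation episode.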


def make_agent(obs_shape, state_shape, action_shape, args, device):
    if args.agent == 'sac':
        return SACAgent(
            obs_shape=obs_shape,
            state_shape=state_shape,
            action_shape=action_shape,
            device=device,
            hidden_dim=args.hidden_dim,
            discount=args.discount,
            init_temperature=args.init_temperature,
            alpha_lr=args.alpha_lr,
            alpha_beta=args.alpha_beta,
            actor_lr=args.actor_lr,
            actor_beta=args.actor_beta,
            actor_log_std_min=args.actor_log_std_min,
            actor_log_std_max=args.actor_log_std_max,
            actor_update_freq=args.actor_update_freq,
            critic_lr=args.critic_lr,
            critic_beta=args.critic_beta,
            critic_tau=args.critic_tau,
            critic_target_update_freq=args.critic_target_update_freq,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            encoder_lr=args.encoder_lr,
            encoder_tau=args.encoder_tau,
            decoder_type=args.decoder_type,
            decoder_lr=args.decoder_lr,
            decoder_update_freq=args.decoder_update_freq,
            decoder_latent_lambda=args.decoder_latent_lambda,
            decoder_weight_lambda=args.decoder_weight_lambda,
            decoder_kl_lambda=args.decoder_kl_lambda,
            num_layers=args.num_layers,
            num_filters=args.num_filters,
            freeze_encoder=args.freeze_encoder,
            use_dynamics=args.use_dynamics
        )
    elif args.agent == 'td3':
        return TD3Agent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            discount=args.discount,
            tau=args.tau,
            policy_noise=args.policy_noise,
            noise_clip=args.noise_clip,
            expl_noise=args.expl_noise,
            actor_lr=args.actor_lr,
            critic_lr=args.critic_lr,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            actor_update_freq=args.actor_update_freq,
            target_update_freq=args.critic_target_update_freq
        )
    elif args.agent == 'ddpg':
        return DDPGAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            discount=args.discount,
            tau=args.tau,
            actor_lr=args.actor_lr,
            critic_lr=args.critic_lr,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim
        )
    else:
        # a bare string assert always passes; fail explicitly for unknown agents
        assert False, 'agent is not supported: %s' % args.agent


def load_pretrained_encoder(agent, pretrained_info, pretrained_decoder):
    path, version = pretrained_info.split(':')
    pretrained_agent = copy.deepcopy(agent)
    pretrained_agent.load(path, int(version))
    agent.critic.encoder.load_state_dict(
        pretrained_agent.critic.encoder.state_dict()
    )
    agent.actor.encoder.load_state_dict(
        pretrained_agent.actor.encoder.state_dict()
    )
    if pretrained_decoder:
        agent.decoder.load_state_dict(pretrained_agent.decoder.state_dict())
    return agent
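
# The --pretrained_info flag is parsed above as '<model_dir>:<step>'; for example a
# value like './log/model:100000' (the path and step are purely illustrative) would
# load that checkpoint and copy its actor/critic encoder weights into the new agent.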


def main():
    args = parse_args()
    utils.set_seed_everywhere(args.seed)

    env = dmc2gym.make(
        domain_name=args.domain_name,
        task_name=args.task_name,
        seed=args.seed,
        visualize_reward=False,
        from_pixels=(args.encoder_type == 'pixel'),
        height=args.image_size,
        width=args.image_size,
        frame_skip=args.action_repeat
    )
    env.seed(args.seed)

    # stack several consecutive frames together
    if args.encoder_type == 'pixel':
        env = utils.FrameStack(env, k=args.frame_stack)

    utils.make_dir(args.work_dir)
    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))

    video = VideoRecorder(video_dir if args.save_video else None)

    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the dmc2gym wrapper standardizes actions
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    replay_buffer = utils.ReplayBuffer(
        obs_shape=env.observation_space.shape,
        state_shape=env.state_space.shape,
        action_shape=env.action_space.shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device
    )

    agent = make_agent(
        obs_shape=env.observation_space.shape,
        state_shape=env.state_space.shape,
        action_shape=env.action_space.shape,
        args=args,
        device=device
    )

    if args.pretrained_info is not None:
        agent = load_pretrained_encoder(
            agent, args.pretrained_info, args.pretrained_decoder
        )

    L = Logger(args.work_dir, use_tb=args.save_tb)
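
    # Training loop: each iteration collects one environment step; random actions
    # are used for the first init_steps, after which the agent is updated once per
    # step (with a burst of init_steps updates at the changeover).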
    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    for step in range(args.num_train_steps):
        if done:
            if step > 0:
                L.log('train/duration', time.time() - start_time, step)
                start_time = time.time()
                L.dump(step)

            # evaluate agent periodically
            if step % args.eval_freq == 0:
                L.log('eval/episode', episode, step)
                evaluate(env, agent, video, args.num_eval_episodes, L, step)
                if args.save_model:
                    agent.save(model_dir, step)
                if args.save_buffer:
                    replay_buffer.save(buffer_dir)

            L.log('train/episode_reward', episode_reward, step)

            obs = env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1

            L.log('train/episode', episode, step)

        # sample action for data collection
        if step < args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs)

        # run training update
        if step >= args.init_steps:
            num_updates = args.init_steps if step == args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)
        state = env.env.env._current_state
        next_obs, reward, done, _ = env.step(action)
        next_state = env.env.env._current_state

        # allow infinite bootstrap: do not mark the transition terminal when the
        # episode ends only because the time limit was reached
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
            done
        )
        episode_reward += reward

        replay_buffer.add(
            obs, action, reward, next_obs, done_bool, state, next_state
        )

        obs = next_obs
        episode_step += 1


if __name__ == '__main__':
    main()