2019-09-23 18:20:48 +00:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
import math
|
|
|
|
import gym
|
|
|
|
import sys
|
|
|
|
import random
|
|
|
|
import time
|
|
|
|
import json
|
|
|
|
import dmc2gym
|
|
|
|
import copy
|
|
|
|
|
|
|
|
import utils
|
|
|
|
from logger import Logger
|
|
|
|
from video import VideoRecorder
|
|
|
|
|
2019-09-23 18:38:55 +00:00
|
|
|
from sac_ae import SacAeAgent
|
2019-09-23 18:20:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse the training script's command-line options.

    Options are declared in a table of ``(flag, add_argument kwargs)`` pairs
    and registered in table order, so ``--help`` lists them in the same order
    as before.

    Returns:
        argparse.Namespace: one attribute per option; attribute names are the
        flags without their leading ``--``.
    """
    option_table = (
        # environment
        ('--domain_name', {'default': 'cheetah'}),
        ('--task_name', {'default': 'run'}),
        ('--image_size', {'default': 84, 'type': int}),
        ('--action_repeat', {'default': 1, 'type': int}),
        ('--frame_stack', {'default': 3, 'type': int}),
        ('--img_source', {'default': None, 'type': str,
                          'choices': ['color', 'noise', 'images', 'video', 'none']}),
        ('--resource_files', {'type': str}),
        ('--total_frames', {'default': 10000, 'type': int}),
        # replay buffer
        ('--replay_buffer_capacity', {'default': 1000000, 'type': int}),
        # train
        ('--agent', {'default': 'sac_ae', 'type': str}),
        ('--init_steps', {'default': 1000, 'type': int}),
        ('--num_train_steps', {'default': 1000000, 'type': int}),
        ('--batch_size', {'default': 512, 'type': int}),
        ('--hidden_dim', {'default': 1024, 'type': int}),
        # eval
        ('--eval_freq', {'default': 10000, 'type': int}),
        ('--num_eval_episodes', {'default': 10, 'type': int}),
        # critic
        ('--critic_lr', {'default': 1e-3, 'type': float}),
        ('--critic_beta', {'default': 0.9, 'type': float}),
        ('--critic_tau', {'default': 0.01, 'type': float}),
        ('--critic_target_update_freq', {'default': 2, 'type': int}),
        # actor
        ('--actor_lr', {'default': 1e-3, 'type': float}),
        ('--actor_beta', {'default': 0.9, 'type': float}),
        ('--actor_log_std_min', {'default': -10, 'type': float}),
        ('--actor_log_std_max', {'default': 2, 'type': float}),
        ('--actor_update_freq', {'default': 2, 'type': int}),
        # encoder/decoder
        ('--encoder_type', {'default': 'pixel', 'type': str}),
        ('--encoder_feature_dim', {'default': 50, 'type': int}),
        ('--encoder_lr', {'default': 1e-3, 'type': float}),
        ('--encoder_tau', {'default': 0.05, 'type': float}),
        ('--decoder_type', {'default': 'pixel', 'type': str}),
        ('--decoder_lr', {'default': 1e-3, 'type': float}),
        ('--decoder_update_freq', {'default': 1, 'type': int}),
        ('--decoder_latent_lambda', {'default': 1e-6, 'type': float}),
        ('--decoder_weight_lambda', {'default': 1e-7, 'type': float}),
        ('--num_layers', {'default': 4, 'type': int}),
        ('--num_filters', {'default': 32, 'type': int}),
        # sac
        ('--discount', {'default': 0.99, 'type': float}),
        ('--init_temperature', {'default': 0.1, 'type': float}),
        ('--alpha_lr', {'default': 1e-4, 'type': float}),
        ('--alpha_beta', {'default': 0.5, 'type': float}),
        # misc
        ('--seed', {'default': 1, 'type': int}),
        ('--work_dir', {'default': '.', 'type': str}),
        ('--save_tb', {'default': False, 'action': 'store_true'}),
        ('--save_model', {'default': False, 'action': 'store_true'}),
        ('--save_buffer', {'default': False, 'action': 'store_true'}),
        ('--save_video', {'default': False, 'action': 'store_true'}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate(env, agent, video, num_episodes, L, step):
    """Run evaluation episodes with ``agent.select_action`` and log their returns.

    Args:
        env: environment exposing ``reset()`` and a 4-tuple ``step(action)``.
        agent: agent whose ``select_action(obs)`` picks the evaluation action.
        video: recorder; only the first episode is recorded (``enabled`` is
            true only for episode index 0).
        num_episodes: number of full episodes to roll out.
        L: logger; one ``eval/episode_reward`` entry is logged per episode and
            the logger is dumped once at the end.
        step: current training step, used as the log step and video filename.
    """
    for episode_idx in range(num_episodes):
        observation = env.reset()
        video.init(enabled=(episode_idx == 0))
        finished = False
        total_reward = 0

        while not finished:
            # presumably switches the agent to eval mode for action selection
            with utils.eval_mode(agent):
                chosen_action = agent.select_action(observation)
            observation, step_reward, finished, _ = env.step(chosen_action)
            video.record(env)
            total_reward += step_reward

        # NOTE(review): called after every episode with the same filename;
        # presumably a no-op when recording is disabled - confirm in VideoRecorder.
        video.save('%d.mp4' % step)
        L.log('eval/episode_reward', total_reward, step)

    L.dump(step)
|
|
|
|
|
|
|
|
|
2019-09-23 19:24:30 +00:00
|
|
|
def make_agent(obs_shape, action_shape, args, device):
    """Construct the agent selected by ``args.agent``.

    Args:
        obs_shape: observation shape forwarded to the agent.
        action_shape: action shape forwarded to the agent.
        args: parsed command-line options (see ``parse_args``); every agent
            hyper-parameter is read from here.
        device: torch device the agent is created on.

    Returns:
        SacAeAgent: the configured agent when ``args.agent == 'sac_ae'``.

    Raises:
        ValueError: if ``args.agent`` names an unsupported agent.
    """
    if args.agent != 'sac_ae':
        # The previous code did `assert '<message>'`, which asserts a non-empty
        # (always truthy) string and so silently returned None for unknown agents.
        raise ValueError('agent is not supported: %s' % args.agent)

    return SacAeAgent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        decoder_type=args.decoder_type,
        decoder_lr=args.decoder_lr,
        decoder_update_freq=args.decoder_update_freq,
        decoder_latent_lambda=args.decoder_latent_lambda,
        decoder_weight_lambda=args.decoder_weight_lambda,
        num_layers=args.num_layers,
        num_filters=args.num_filters
    )
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Train an agent on a dmc2gym environment, evaluating it periodically.

    All settings come from the command line (see ``parse_args``). The parsed
    options are written to ``<work_dir>/args.json``. ``video``, ``model`` and
    ``buffer`` sub-directories are created under ``work_dir``. Videos are
    recorded only with ``--save_video``, checkpoints only with
    ``--save_model``, and replay-buffer dumps only with ``--save_buffer``.

    Every environment transition after the first one of an episode is stored
    in the replay buffer together with its predecessor, i.e. as
    ``(obs_t, a_t, r_t, obs_t1, done_t, a_t1, r_t1, obs_t2, done_t1)``.
    Exactly one environment step is taken per training step.
    """
    args = parse_args()
    utils.set_seed_everywhere(args.seed)

    env = dmc2gym.make(
        domain_name=args.domain_name,
        task_name=args.task_name,
        seed=args.seed,
        visualize_reward=False,
        from_pixels=(args.encoder_type == 'pixel'),
        height=args.image_size,
        width=args.image_size,
        frame_skip=args.action_repeat,
        img_source=args.img_source,
        resource_files=args.resource_files,
        total_frames=args.total_frames
    )
    env.seed(args.seed)

    # stack several consecutive frames together
    if args.encoder_type == 'pixel':
        env = utils.FrameStack(env, k=args.frame_stack)

    utils.make_dir(args.work_dir)
    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))

    video = VideoRecorder(video_dir if args.save_video else None)

    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the dmc2gym wrapper standardizes actions
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    replay_buffer = utils.ReplayBuffer(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device
    )

    agent = make_agent(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        args=args,
        device=device
    )

    L = Logger(args.work_dir, use_tb=args.save_tb)

    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    for step in range(args.num_train_steps):
        if done:
            if step > 0:
                L.log('train/duration', time.time() - start_time, step)
                start_time = time.time()
                L.dump(step)

            # evaluate agent periodically
            if step % args.eval_freq == 0:
                L.log('eval/episode', episode, step)
                evaluate(env, agent, video, args.num_eval_episodes, L, step)
                if args.save_model:
                    agent.save(model_dir, step)
                if args.save_buffer:
                    replay_buffer.save(buffer_dir)

            L.log('train/episode_reward', episode_reward, step)

            obs = env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1
            # The first transition of an episode has no predecessor, so the
            # first buffer entry is written after the second step.
            prev_transition = None

            L.log('train/episode', episode, step)

        # sample action for data collection (uniformly random during warm-up)
        if step < args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs)

        # run training update; a bulk update of init_steps iterations at
        # step == init_steps was deliberately disabled, so one update per step
        if step >= args.init_steps:
            agent.update(replay_buffer, L, step)

        next_obs, reward, done, _ = env.step(action)

        # allow infinite bootstrap: reaching the episode time limit is not a
        # true terminal state, so it is stored as not-done
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
        episode_reward += reward

        # store (previous transition, current transition) as one entry
        if prev_transition is not None:
            replay_buffer.add(*prev_transition, action, reward, next_obs, done_bool)
        prev_transition = (obs, action, reward, next_obs, done_bool)

        obs = next_obs
        episode_step += 1
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()
|