sac_ae_if/train.py

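# Training entry point: trains a SAC+AE agent (SacAeAgent) on DeepMind Control
# tasks wrapped by dmc2gym, with optional alternative observation image sources
# (--img_source, --resource_files, --total_frames). Unlike the original script
# (see the commented-out single-transition add in the training loop), this
# variant stores pairs of consecutive transitions in the replay buffer.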
import numpy as np
import torch
import argparse
import os
import math
import gym
import sys
import random
import time
import json
import dmc2gym
import copy
import utils
from logger import Logger
from video import VideoRecorder

from sac_ae import SacAeAgent


def parse_args():
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--total_frames', default=10000, type=int)
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
    # train
    parser.add_argument('--agent', default='sac_ae', type=str)
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=1000000, type=int)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--hidden_dim', default=1024, type=int)
    # eval
    parser.add_argument('--eval_freq', default=10000, type=int)
    parser.add_argument('--num_eval_episodes', default=10, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.01, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str)
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.05, type=float)
    parser.add_argument('--decoder_type', default='pixel', type=str)
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_latent_lambda', default=1e-6, type=float)
    parser.add_argument('--decoder_weight_lambda', default=1e-7, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.1, type=float)
    parser.add_argument('--alpha_lr', default=1e-4, type=float)
    parser.add_argument('--alpha_beta', default=0.5, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    args = parser.parse_args()
    return args


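# Run num_episodes evaluation episodes with the current policy
# (agent.select_action), record a video of the first one, and log each
# episode's return.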
def evaluate(env, agent, video, num_episodes, L, step):
    for i in range(num_episodes):
        obs = env.reset()
        video.init(enabled=(i == 0))
        done = False
        episode_reward = 0
        while not done:
            with utils.eval_mode(agent):
                action = agent.select_action(obs)
            obs, reward, done, _ = env.step(action)
            video.record(env)
            episode_reward += reward

        video.save('%d.mp4' % step)
        L.log('eval/episode_reward', episode_reward, step)
    L.dump(step)


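# Build the agent from the parsed command-line arguments; only the 'sac_ae'
# agent type is supported here.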
def make_agent(obs_shape, action_shape, args, device):
    if args.agent == 'sac_ae':
        return SacAeAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            hidden_dim=args.hidden_dim,
            discount=args.discount,
            init_temperature=args.init_temperature,
            alpha_lr=args.alpha_lr,
            alpha_beta=args.alpha_beta,
            actor_lr=args.actor_lr,
            actor_beta=args.actor_beta,
            actor_log_std_min=args.actor_log_std_min,
            actor_log_std_max=args.actor_log_std_max,
            actor_update_freq=args.actor_update_freq,
            critic_lr=args.critic_lr,
            critic_beta=args.critic_beta,
            critic_tau=args.critic_tau,
            critic_target_update_freq=args.critic_target_update_freq,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            encoder_lr=args.encoder_lr,
            encoder_tau=args.encoder_tau,
            decoder_type=args.decoder_type,
            decoder_lr=args.decoder_lr,
            decoder_update_freq=args.decoder_update_freq,
            decoder_latent_lambda=args.decoder_latent_lambda,
            decoder_weight_lambda=args.decoder_weight_lambda,
            num_layers=args.num_layers,
            num_filters=args.num_filters
        )
    else:
        assert False, 'agent is not supported: %s' % args.agent


def main():
    args = parse_args()
    utils.set_seed_everywhere(args.seed)

    env = dmc2gym.make(
        domain_name=args.domain_name,
        task_name=args.task_name,
        seed=args.seed,
        visualize_reward=False,
        from_pixels=(args.encoder_type == 'pixel'),
        height=args.image_size,
        width=args.image_size,
        frame_skip=args.action_repeat,
        img_source=args.img_source,
        resource_files=args.resource_files,
        total_frames=args.total_frames
    )
    env.seed(args.seed)

    # stack several consecutive frames together
    if args.encoder_type == 'pixel':
        env = utils.FrameStack(env, k=args.frame_stack)

    utils.make_dir(args.work_dir)
    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))

    video = VideoRecorder(video_dir if args.save_video else None)

    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the dmc2gym wrapper standardizes actions
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    replay_buffer = utils.ReplayBuffer(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device
    )

    agent = make_agent(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        args=args,
        device=device
    )

    L = Logger(args.work_dir, use_tb=args.save_tb)

    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    for step in range(args.num_train_steps):
        if done:
            if step > 0:
                L.log('train/duration', time.time() - start_time, step)
                start_time = time.time()
                L.dump(step)

            # evaluate agent periodically
            if step % args.eval_freq == 0:
                L.log('eval/episode', episode, step)
                evaluate(env, agent, video, args.num_eval_episodes, L, step)
                if args.save_model:
                    agent.save(model_dir, step)
                if args.save_buffer:
                    replay_buffer.save(buffer_dir)

            L.log('train/episode_reward', episode_reward, step)

            obs = env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1

            L.log('train/episode', episode, step)

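        # At the first step of a new episode there is no previous transition
        # yet, so take one extra environment step here to prime
        # (last_obs, last_action, last_reward) before regular collection.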
        if episode_step == 0:
            last_obs = obs
            # sample action for data collection
            if step < args.init_steps:
                last_action = env.action_space.sample()
            else:
                with utils.eval_mode(agent):
                    last_action = agent.sample_action(last_obs)
            curr_obs, last_reward, last_done, _ = env.step(last_action)
            # allow infinite bootstrap
            last_done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(last_done)
            episode_reward += last_reward

            # sample action for data collection
            if step < args.init_steps:
                action = env.action_space.sample()
            else:
                with utils.eval_mode(agent):
                    action = agent.sample_action(curr_obs)
            next_obs, reward, done, _ = env.step(action)
            # allow infinite bootstrap
            done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
            episode_reward += reward
            replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)

        # shift the transition pair forward by one step
        last_obs = curr_obs
        last_action = action
        last_reward = reward
        last_done = done
        curr_obs = next_obs

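        # Regular per-step collection: act from curr_obs, optionally run a
        # training update, then store the pair of consecutive transitions.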
        # sample action for data collection
        if step < args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(curr_obs)

        # run training update
        if step >= args.init_steps:
            # num_updates = args.init_steps if step == args.init_steps else 1
            num_updates = 1  # a single gradient update per environment step
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)

        next_obs, reward, done, _ = env.step(action)

        # allow infinite bootstrap
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
        episode_reward += reward

        # replay_buffer.add(obs, action, reward, next_obs, done_bool)
        replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)

        obs = next_obs
        episode_step += 1


if __name__ == '__main__':
    main()
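

# Example invocation (flag values below are illustrative, not prescriptive):
#   python train.py \
#       --domain_name cheetah --task_name run \
#       --encoder_type pixel --decoder_type pixel \
#       --save_video --save_tb \
#       --work_dir ./runs/cheetah_run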