# sac_ae_if/train.py

import numpy as np
import torch
import argparse
import os
import math
import gym
import sys
import random
import time
import json
import dmc2gym
import copy
import utils
from logger import Logger
from video import VideoRecorder
from sac_ae import SacAeAgent


def parse_args():
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--resource_files_test', type=str)
    parser.add_argument('--total_frames', default=10000, type=int)
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=100000, type=int)
    # train
    parser.add_argument('--agent', default='sac_ae', type=str)
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=2000000, type=int)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--hidden_dim', default=1024, type=int)
    # eval
    parser.add_argument('--eval_freq', default=10000, type=int)
    parser.add_argument('--num_eval_episodes', default=10, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-4, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.01, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-4, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str)
    parser.add_argument('--encoder_feature_dim', default=250, type=int)
    parser.add_argument('--encoder_lr', default=1e-4, type=float)
    parser.add_argument('--encoder_tau', default=0.05, type=float)
    parser.add_argument('--decoder_type', default='pixel', type=str)
    parser.add_argument('--decoder_lr', default=1e-4, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_latent_lambda', default=1e-6, type=float)
    parser.add_argument('--decoder_weight_lambda', default=1e-7, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.1, type=float)
    parser.add_argument('--alpha_lr', default=1e-4, type=float)
    parser.add_argument('--alpha_beta', default=0.5, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    args = parser.parse_args()
    return args

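
# Roll out the agent for num_episodes evaluation episodes, record a video of
# the first episode, and log the per-episode reward at the current step.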
def evaluate(env, agent, video, num_episodes, L, step):
    for i in range(num_episodes):
        obs = env.reset()
        video.init(enabled=(i == 0))
        done = False
        episode_reward = 0
        while not done:
            with utils.eval_mode(agent):
                action = agent.select_action(obs)
            obs, reward, done, _ = env.step(action)
            video.record(env)
            episode_reward += reward

        video.save('%d.mp4' % step)
        L.log('eval/episode_reward', episode_reward, step)
    L.dump(step)

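
# Build the SAC+AE agent from the parsed command-line hyper-parameters.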
def make_agent(obs_shape, action_shape, args, device):
    if args.agent == 'sac_ae':
        return SacAeAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            hidden_dim=args.hidden_dim,
            discount=args.discount,
            init_temperature=args.init_temperature,
            alpha_lr=args.alpha_lr,
            alpha_beta=args.alpha_beta,
            actor_lr=args.actor_lr,
            actor_beta=args.actor_beta,
            actor_log_std_min=args.actor_log_std_min,
            actor_log_std_max=args.actor_log_std_max,
            actor_update_freq=args.actor_update_freq,
            critic_lr=args.critic_lr,
            critic_beta=args.critic_beta,
            critic_tau=args.critic_tau,
            critic_target_update_freq=args.critic_target_update_freq,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            encoder_lr=args.encoder_lr,
            encoder_tau=args.encoder_tau,
            decoder_type=args.decoder_type,
            decoder_lr=args.decoder_lr,
            decoder_update_freq=args.decoder_update_freq,
            decoder_latent_lambda=args.decoder_latent_lambda,
            decoder_weight_lambda=args.decoder_weight_lambda,
            num_layers=args.num_layers,
            num_filters=args.num_filters
        )
    else:
        assert False, 'agent is not supported: %s' % args.agent

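
# Entry point: build the train/test environments, replay buffer, agent and
# logger, then run the training loop with periodic evaluation.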
def main():
    args = parse_args()
    utils.set_seed_everywhere(args.seed)

    env = dmc2gym.make(
        domain_name=args.domain_name,
        task_name=args.task_name,
        seed=args.seed,
        visualize_reward=False,
        from_pixels=(args.encoder_type == 'pixel'),
        height=args.image_size,
        width=args.image_size,
        frame_skip=args.action_repeat,
        img_source=args.img_source,
        resource_files=args.resource_files,
        total_frames=args.total_frames
    )
    env.seed(args.seed)
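
    # Separate evaluation environment; it uses resource_files_test, presumably
    # a held-out set of background resources for the chosen img_source.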
    env_test = dmc2gym.make(
        domain_name=args.domain_name,
        task_name=args.task_name,
        seed=args.seed,
        visualize_reward=False,
        from_pixels=(args.encoder_type == 'pixel'),
        height=args.image_size,
        width=args.image_size,
        frame_skip=args.action_repeat,
        img_source=args.img_source,
        resource_files=args.resource_files_test,
        total_frames=args.total_frames
    )
    env_test.seed(args.seed)

    # stack several consecutive frames together
    if args.encoder_type == 'pixel':
        env = utils.FrameStack(env, k=args.frame_stack)
        env_test = utils.FrameStack(env_test, k=args.frame_stack)

    utils.make_dir(args.work_dir)
    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))

    video = VideoRecorder(video_dir if args.save_video else None)

    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the dmc2gym wrapper standardizes actions
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    replay_buffer = utils.ReplayBuffer(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device
    )

    agent = make_agent(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        args=args,
        device=device
    )

    L = Logger(args.work_dir, use_tb=args.save_tb)

    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
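    # Training loop: collect transitions with the current policy (random
    # actions for the first init_steps), update the agent, and evaluate on
    # env_test every eval_freq steps.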
    for step in range(args.num_train_steps):
        if done:
            if step > 0:
                L.log('train/duration', time.time() - start_time, step)
                start_time = time.time()
                L.dump(step)

            # evaluate agent periodically
            if step % args.eval_freq == 0:
                L.log('eval/episode', episode, step)
                evaluate(env_test, agent, video, args.num_eval_episodes, L, step)
                if args.save_model:
                    agent.save(model_dir, step)
                if args.save_buffer:
                    replay_buffer.save(buffer_dir)

            L.log('train/episode_reward', episode_reward, step)

            obs = env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1

            L.log('train/episode', episode, step)

        # sample action for data collection
        if step < args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs)

        # run training update
        if step >= args.init_steps:
            num_updates = args.init_steps if step == args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)

        next_obs, reward, done, _ = env.step(action)

        # allow infinite bootstrap: episodes cut off by the time limit are not
        # stored as true terminal states
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
            done
        )
        episode_reward += reward

        replay_buffer.add(obs, action, reward, next_obs, done_bool)

        obs = next_obs
        episode_step += 1


if __name__ == '__main__':
    main()