# Curiosity/DPI/train.py
#
# 196 lines
# 8.8 KiB
# Python

import numpy as np
import torch
import argparse
import os
import gym
import time
import json
import dmc2gym
import wandb
import utils
from utils import ReplayBuffer, make_env
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
from logger import Logger
from video import VideoRecorder
#from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent
#from agents.navigation.carla_env import CarlaEnv
def parse_args(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            the script entry point relies on. Passing an explicit list
            makes the function usable from tests and notebooks.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--channels', default=3, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=4, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=100000, type=int)
    parser.add_argument('--episode_length', default=1000, type=int)
    # train
    parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=1000, type=int)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--state_size', default=256, type=int)
    parser.add_argument('--hidden_size', default=128, type=int)
    parser.add_argument('--history_size', default=128, type=int)
    parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model')
    parser.add_argument('--load_encoder', default=None, type=str)
    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.005, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--encoder_stride', default=1, type=int)
    parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--seed_steps', default=5000, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--render', default=False, action='store_true')
    parser.add_argument('--port', default=2000, type=int)
    # argv=None makes argparse fall back to sys.argv[1:], i.e. the old behavior.
    args = parser.parse_args(argv)
    return args
class DPI:
    """Owns the environment, replay buffer and models used by the DPI trainer.

    Construction builds the environment via ``make_env``, wraps it in
    ``utils.FrameStack`` when the encoder type starts with ``'pixel'``,
    allocates the replay buffer, creates the work directories and
    instantiates the encoder, decoder and transition model together with
    one shared Adam optimizer.
    """

    def __init__(self, args):
        """Set up environment, replay buffer, work directories and models.

        Args:
            args: Parsed command-line namespace (see ``parse_args``).
        """
        # wandb config
        #run = wandb.init(project="dpi")
        self.args = args

        # environment setup
        self.env = make_env(self.args)
        self.env.seed(self.args.seed)

        # stack several consecutive frames together
        if self.args.encoder_type.startswith('pixel'):
            self.env = utils.FrameStack(self.env, k=self.args.frame_stack)

        # create replay buffer
        self.data_buffer = ReplayBuffer(
            size=self.args.replay_buffer_capacity,
            obs_shape=self._obs_shape(),
            action_size=self.env.action_space.shape[0],
            seq_len=self.args.episode_length,
            batch_size=self.args.batch_size,
        )

        # create work directory and its sub-directories; only the model
        # directory path is needed afterwards, the others are created for
        # their side effect.
        utils.make_dir(self.args.work_dir)
        utils.make_dir(os.path.join(self.args.work_dir, 'video'))
        model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
        utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))

        # create video recorder
        #video = VideoRecorder(video_dir if args.save_video else None, resource_files=args.resource_files)
        #video.init(enabled=True)

        # create models
        self.build_models(use_saved=False, saved_model_dir=model_dir)

    def _obs_shape(self):
        """Return ``(frame_stack * channels, image_size, image_size)``.

        With the default CLI values this is ``(12, 84, 84)``.
        """
        return (
            self.args.frame_stack * self.args.channels,
            self.args.image_size,
            self.args.image_size,
        )

    def build_models(self, use_saved, saved_model_dir=None):
        """Instantiate the models and their shared optimizer.

        Args:
            use_saved: When true, load weights from ``saved_model_dir``
                after the models have been constructed.
            saved_model_dir: Directory containing ``obs_encoder.pt``,
                ``obs_decoder.pt`` and ``transition_model.pt``. Required
                when ``use_saved`` is true.

        Raises:
            ValueError: If ``use_saved`` is true and ``saved_model_dir``
                is ``None``.
        """
        obs_shape = self._obs_shape()
        action_size = self.env.action_space.shape[0]

        self.obs_encoder = ObservationEncoder(
            obs_shape=obs_shape,
            state_size=self.args.state_size,
        )

        self.obs_decoder = ObservationDecoder(
            state_size=self.args.state_size,
            output_shape=obs_shape,
        )

        self.transition_model = TransitionModel(
            state_size=self.args.state_size,
            hidden_size=self.args.hidden_size,
            action_size=action_size,
            history_size=self.args.history_size,
        )

        # model parameters: all three networks are trained by one optimizer
        self.model_parameters = (
            list(self.obs_encoder.parameters())
            + list(self.obs_decoder.parameters())
            + list(self.transition_model.parameters())
        )

        # optimizer
        self.optimizer = torch.optim.Adam(self.model_parameters, lr=self.args.encoder_lr)

        if use_saved:
            self._use_saved_models(saved_model_dir)

    def _use_saved_models(self, saved_model_dir):
        """Load encoder, decoder and transition-model weights from disk.

        Args:
            saved_model_dir: Directory holding the three ``.pt`` files.

        Raises:
            ValueError: If ``saved_model_dir`` is ``None``.
        """
        if saved_model_dir is None:
            # Fail with a clear message instead of the TypeError that
            # os.path.join(None, ...) would raise.
            raise ValueError("saved_model_dir must be provided to load saved models")

        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        self.obs_encoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_encoder.pt')))
        self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
        self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))

    def collect_random_episodes(self, episodes):
        """Fill the replay buffer with transitions from random actions.

        For each of ``episodes`` rounds, ``args.episode_length`` steps are
        taken with actions sampled from the action space. Every transition
        is stored with the 1-based round number. The environment is reset
        once up front and again whenever a step reports ``done``.

        Args:
            episodes: Number of rounds to collect.
        """
        obs = self.env.reset()
        for episode_number in range(1, episodes + 1):
            for _step in range(self.args.episode_length):
                action = self.env.action_space.sample()
                next_obs, _reward, done, _info = self.env.step(action)
                self.data_buffer.add(obs, action, next_obs, episode_number, done)
                # Start from a fresh observation after a terminal step,
                # otherwise continue from where the step left off.
                obs = self.env.reset() if done else next_obs
            print("Collected {} random episodes".format(episode_number))

        #if args.save_video:
        #    video.record(env)
        #video.save('%d.mp4' % step)
        #video.close()

    def upper_bound_minimization(self):
        """Placeholder; not implemented yet."""
        pass
if __name__ == '__main__':
    # Script entry point: read the CLI options, build the trainer (which
    # sets up the environment, replay buffer and models), then seed the
    # replay buffer with five rounds of random-action experience.
    args = parse_args()
    dpi = DPI(args)
    dpi.collect_random_episodes(episodes=5)