# Curiosity/DPI/train.py
#
# Listing metadata from the source viewer: 227 lines, 10 KiB, Python.

import numpy as np
import torch
import argparse
import os
import gym
import time
import json
import dmc2gym
import wandb
import utils
from utils import ReplayBuffer, make_env, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
from logger import Logger
from video import VideoRecorder
#from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent
#from agents.navigation.carla_env import CarlaEnv
def parse_args(argv=None):
    """Build the experiment argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what a
            plain ``parse_args()`` call did before this parameter existed.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
    """
    parser = argparse.ArgumentParser()

    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--channels', default=3, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=4, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str,
                        choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)

    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=50000, type=int)  # previously 100000
    parser.add_argument('--episode_length', default=50, type=int)

    # train
    parser.add_argument('--agent', default='dpi', type=str,
                        choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=1000, type=int)
    parser.add_argument('--batch_size', default=200, type=int)  # previously 512
    parser.add_argument('--state_size', default=256, type=int)
    parser.add_argument('--hidden_size', default=128, type=int)
    parser.add_argument('--history_size', default=128, type=int)
    parser.add_argument('--load_encoder', default=None, type=str)
    # A horizon is a step count: parse it as int so that a value given on the
    # command line has the same type as the integer default.
    parser.add_argument('--imagination_horizon', default=15, type=int)

    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)

    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.005, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)

    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)

    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str,
                        choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--encoder_stride', default=1, type=int)
    parser.add_argument('--decoder_type', default='pixel', type=str,
                        choices=['pixel', 'identity', 'contrastive', 'reward',
                                 'inverse', 'reconstruction'])
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)

    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)

    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str,
                        choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--render', default=False, action='store_true')
    parser.add_argument('--port', default=2000, type=int)

    return parser.parse_args(argv)
class DPI:
    """Experiment driver built from the parsed command-line arguments.

    The constructor creates the environment, the replay buffer, the output
    directories and the encoder / decoder / transition models. ``train``
    fills the buffer with random-action episodes and accumulates a CLUB-based
    loss between encoded observations and encoded next observations.
    """

    def __init__(self, args):
        """Set up environment, replay buffer, work directories and models.

        Args:
            args: Namespace of options as produced by ``parse_args``.
        """
        # NOTE: wandb logging is currently disabled (no run is initialised).
        self.args = args

        # environment setup
        self.env = make_env(self.args)
        self.env.seed(self.args.seed)

        # stack several consecutive frames together
        if self.args.encoder_type.startswith('pixel'):
            self.env = utils.FrameStack(self.env, k=self.args.frame_stack)

        # create replay buffer
        self.data_buffer = ReplayBuffer(
            size=self.args.replay_buffer_capacity,
            obs_shape=self._obs_shape(),
            action_size=self.env.action_space.shape[0],
            seq_len=self.args.episode_length,
            batch_size=self.args.batch_size,
            args=self.args,
        )

        # create work directory and its sub-directories
        utils.make_dir(self.args.work_dir)
        self.video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video'))
        self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
        self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))

        # create models
        self.build_models(use_saved=False, saved_model_dir=self.model_dir)

    def _obs_shape(self):
        """Return the (frame_stack * channels, image_size, image_size) tuple."""
        return (
            self.args.frame_stack * self.args.channels,
            self.args.image_size,
            self.args.image_size,
        )

    def build_models(self, use_saved, saved_model_dir=None):
        """Create the models and their shared Adam optimizer.

        Args:
            use_saved: When true, load saved weights from ``saved_model_dir``
                after the models have been constructed.
            saved_model_dir: Directory holding the saved ``.pt`` state dicts;
                only read when ``use_saved`` is true.
        """
        obs_shape = self._obs_shape()

        self.obs_encoder = ObservationEncoder(
            obs_shape=obs_shape,
            state_size=self.args.state_size,
        )
        self.obs_decoder = ObservationDecoder(
            state_size=self.args.state_size,
            output_shape=obs_shape,
        )
        self.transition_model = TransitionModel(
            state_size=self.args.state_size,
            hidden_size=self.args.hidden_size,
            action_size=self.env.action_space.shape[0],
            history_size=self.args.history_size,
        )

        # model parameters
        self.model_parameters = (
            list(self.obs_encoder.parameters())
            + list(self.obs_decoder.parameters())
            + list(self.transition_model.parameters())
        )

        # optimizer
        self.optimizer = torch.optim.Adam(self.model_parameters, lr=self.args.encoder_lr)

        # Drop any cached CLUB estimator so that it is rebuilt with the
        # current sizes the next time upper_bound_minimization is called.
        self.club_sample = None

        if use_saved:
            self._use_saved_models(saved_model_dir)

    def _use_saved_models(self, saved_model_dir):
        """Load encoder, decoder and transition-model state dicts from disk.

        Reads ``obs_encoder.pt``, ``obs_decoder.pt`` and
        ``transition_model.pt`` from ``saved_model_dir``.
        """
        self.obs_encoder.load_state_dict(
            torch.load(os.path.join(saved_model_dir, 'obs_encoder.pt')))
        self.obs_decoder.load_state_dict(
            torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
        self.transition_model.load_state_dict(
            torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))

    def collect_random_episodes(self, episodes):
        """Fill the replay buffer with transitions from random actions.

        Runs ``episodes`` rollouts of ``self.args.episode_length`` steps each,
        sampling every action from the environment's action space. The
        environment is reset whenever a step reports ``done``. One video file
        per rollout is saved through ``VideoRecorder``.

        Args:
            episodes: Number of rollouts to collect.
        """
        obs = self.env.reset()

        for episode_count in range(episodes):
            # Use the options stored on the instance rather than a
            # module-level ``args`` global, so the class also works when this
            # file is imported instead of being run as a script.
            video = VideoRecorder(
                self.video_dir if self.args.save_video else None,
                resource_files=self.args.resource_files,
            )
            video.init(enabled=True)

            for _ in range(self.args.episode_length):
                action = self.env.action_space.sample()
                next_obs, _, done, _ = self.env.step(action)
                # The buffer is given a 1-based episode number.
                self.data_buffer.add(obs, action, next_obs, episode_count + 1, done)

                if self.args.save_video:
                    video.record(self.env)

                obs = self.env.reset() if done else next_obs

            video.save('%d.mp4' % episode_count)

        print("Collected {} random episodes".format(episodes))

    def train(self):
        """Collect random experience and accumulate the CLUB loss over steps.

        NOTE: no backward pass or optimizer step is performed here; the
        optimizer created in ``build_models`` is not used by this method.

        Returns:
            The loss accumulated over ``self.args.episode_length`` steps
            (``0`` when the episode length is zero).
        """
        # collect experience
        self.collect_random_episodes(self.args.batch_size)

        # Group observations and next_observations by steps
        observations = torch.Tensor(
            self.data_buffer.group_steps(self.data_buffer, "observations")).float()
        next_observations = torch.Tensor(
            self.data_buffer.group_steps(self.data_buffer, "next_observations")).float()

        # Accumulate the upper-bound loss step by step
        information_loss = 0
        for step in range(self.args.episode_length):
            # Encode observations and next_observations for this step
            # (presumably one row per collected episode -- TODO confirm).
            self.features = self.obs_encoder(observations[step])
            self.next_features = self.obs_encoder(next_observations[step])

            information_loss = information_loss + self.upper_bound_minimization(
                self.features, self.next_features)
            print("past_loss: ", information_loss)

        return information_loss

    def upper_bound_minimization(self, features, next_features):
        """Return the CLUB estimate for a batch of feature pairs.

        The ``CLUBSample`` estimator is created on first use and then reused,
        instead of constructing a fresh, randomly initialised estimator on
        every call.

        Args:
            features: Encoded observations.
            next_features: Encoded next observations.

        Returns:
            Whatever the estimator returns for ``(features, next_features)``.
        """
        club_sample = getattr(self, 'club_sample', None)
        if club_sample is None:
            club_sample = CLUBSample(self.args.state_size,
                                     self.args.state_size,
                                     self.args.hidden_size)
            self.club_sample = club_sample
        return club_sample(features, next_features)
# Script entry point: parse the command line, build the experiment, train.
if __name__ == "__main__":
    # Keep ``args`` as a module-level name: code elsewhere in this file may
    # read it as a global.
    args = parse_args()
    dpi = DPI(args)
    dpi.train()