# Curiosity/DPI/train.py
# (repository viewer metadata: 295 lines, 14 KiB, Python)

import numpy as np
import torch
import argparse
import os
import gym
import time
import json
import dmc2gym
import tqdm
import wandb
import utils
from utils import ReplayBuffer, make_env, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
from logger import Logger
from video import VideoRecorder
from dmc2gym.wrappers import set_global_var
#from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent
#from agents.navigation.carla_env import CarlaEnv
def parse_args(argv=None):
    """Build the command-line parser for DPI training and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what a plain ``parse_args()`` call has always done.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--version', default=1, type=int)
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--channels', default=3, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=4, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)  # 10000
    parser.add_argument('--high_noise', action='store_true')
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=50000, type=int)  # 50000
    parser.add_argument('--episode_length', default=50, type=int)
    # train
    parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=1000, type=int)
    parser.add_argument('--batch_size', default=200, type=int)  # 512
    parser.add_argument('--state_size', default=256, type=int)
    parser.add_argument('--hidden_size', default=128, type=int)
    parser.add_argument('--history_size', default=128, type=int)
    # '--num-units' is the historical spelling; '--num_units' is accepted as an
    # alias so the flag matches the underscore style of every other option.
    # Both store into ``args.num_units``.
    parser.add_argument('--num-units', '--num_units', type=int, default=200,
                        help='num hidden units for reward/value/discount models')
    parser.add_argument('--load_encoder', default=None, type=str)
    # The horizon is a step count: parse it as int so a value given on the
    # command line has the same type as the integer default.
    parser.add_argument('--imagination_horizon', default=15, type=int)
    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.005, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--encoder_stride', default=1, type=int)
    parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--render', default=False, action='store_true')
    parser.add_argument('--port', default=2000, type=int)
    return parser.parse_args(argv)
class DPI:
    """Collects random rollouts from a noisy and a clean environment and
    computes encoder / transition-model losses over them.

    The constructor builds both environments, two replay buffers, the work
    directories and the models. ``train`` then gathers random-action
    sequences and accumulates, step by step, a CLUB-based loss and a
    KL-divergence loss; the accumulated values are printed at every step.
    """

    def __init__(self, args):
        """Set up environments, replay buffers, output directories and models.

        Args:
            args: Namespace produced by ``parse_args``. NOTE: it is mutated
                in place (``version``, ``img_source`` and ``resource_files``
                are overwritten) to create the noiseless environment, so
                everything built afterwards sees the overwritten values.
        """
        # wandb config
        #run = wandb.init(project="dpi")
        self.args = args

        # set environment noise
        set_global_var(self.args.high_noise)

        # environment setup
        self.env = make_env(self.args)
        self.env.seed(self.args.seed)

        # noiseless environment setup
        self.args.version = 2  # env_id changes to v2
        self.args.img_source = None  # no image noise
        self.args.resource_files = None
        self.env_clean = make_env(self.args)
        self.env_clean.seed(self.args.seed)

        # stack several consecutive frames together
        if self.args.encoder_type.startswith('pixel'):
            self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
            self.env_clean = utils.FrameStack(self.env_clean, k=self.args.frame_stack)

        # create replay buffers (one per environment, identical configuration)
        self.data_buffer = self._make_replay_buffer()
        self.data_buffer_clean = self._make_replay_buffer()

        # create work directory
        utils.make_dir(self.args.work_dir)
        self.video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video'))
        self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
        self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))

        # create models
        self.build_models(use_saved=False, saved_model_dir=self.model_dir)

    def _observation_shape(self):
        """Return (frame_stack * channels, image_size, image_size)."""
        return (self.args.frame_stack * self.args.channels,
                self.args.image_size,
                self.args.image_size)

    def _make_replay_buffer(self):
        """Create a replay buffer sized from the current arguments."""
        return ReplayBuffer(size=self.args.replay_buffer_capacity,
                            obs_shape=self._observation_shape(),
                            action_size=self.env.action_space.shape[0],
                            seq_len=self.args.episode_length,
                            batch_size=self.args.batch_size,
                            args=self.args)

    def build_models(self, use_saved, saved_model_dir=None):
        """Instantiate the networks and the optimizer.

        Args:
            use_saved: When true, load weights from ``saved_model_dir`` after
                construction.
            saved_model_dir: Directory containing the ``*.pt`` state dicts;
                only read when ``use_saved`` is true.
        """
        obs_shape = self._observation_shape()

        self.obs_encoder = ObservationEncoder(
            obs_shape=obs_shape,
            state_size=self.args.state_size
        )

        self.obs_encoder_momentum = ObservationEncoder(
            obs_shape=obs_shape,
            state_size=self.args.state_size
        )

        self.obs_decoder = ObservationDecoder(
            state_size=self.args.state_size,
            output_shape=obs_shape
        )

        self.transition_model = TransitionModel(
            state_size=self.args.state_size,
            hidden_size=self.args.hidden_size,
            action_size=self.env.action_space.shape[0],
            history_size=self.args.history_size,
        )

        # Built once and reused by _upper_bound_minimization; constructing it
        # on every call would hand back a freshly initialised estimator each
        # time. NOTE(review): its parameters are not part of self.optimizer.
        self.club_sample = CLUBSample(self.args.state_size,
                                      self.args.state_size,
                                      self.args.hidden_size)

        # model parameters
        # NOTE(review): the momentum encoder is optimised directly here —
        # confirm that is intended.
        self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \
            list(self.obs_decoder.parameters()) + list(self.transition_model.parameters())

        # optimizer
        self.optimizer = torch.optim.Adam(self.model_parameters, lr=self.args.encoder_lr)

        if use_saved:
            self._use_saved_models(saved_model_dir)

    def _use_saved_models(self, saved_model_dir):
        """Load encoder, decoder and transition-model weights from disk.

        NOTE(review): the momentum encoder is not restored here.
        """
        self.obs_encoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_encoder.pt')))
        self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
        self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))

    def collect_sequences(self, episodes):
        """Fill both replay buffers with random-action rollouts.

        The same sampled action is applied to the noisy and the clean
        environment at every step, and the transitions are stored in
        ``data_buffer`` and ``data_buffer_clean`` respectively.

        Args:
            episodes: Number of sequences to collect; each sequence is
                ``args.episode_length`` steps long.
        """
        obs = self.env.reset()
        obs_clean = self.env_clean.reset()
        # Read the flag from the instance's own args rather than a
        # module-level global, so the class also works when imported.
        save_video = self.args.save_video
        #video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)

        for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
            if save_video:
                self.env.video.init(enabled=True)
                self.env_clean.video.init(enabled=True)

            for _ in range(self.args.episode_length):
                action = self.env.action_space.sample()

                next_obs, _, done_noisy, _ = self.env.step(action)
                next_obs_clean, _, done_clean, _ = self.env_clean.step(action)
                # Either environment terminating ends the episode for both,
                # so the two stay in lockstep.
                done = done_noisy or done_clean

                self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
                self.data_buffer_clean.add(obs_clean, action, next_obs_clean, episode_count+1, done)

                if save_video:
                    # NOTE(review): the noisy recorder is fed the clean env —
                    # confirm this is intended and not a copy/paste slip.
                    self.env.video.record(self.env_clean)
                    self.env_clean.video.record(self.env_clean)

                if done:
                    obs = self.env.reset()
                    obs_clean = self.env_clean.reset()
                else:
                    obs = next_obs
                    obs_clean = next_obs_clean

            if save_video:
                self.env.video.save('noisy/%d.mp4' % episode_count)
                self.env_clean.video.save('clean/%d.mp4' % episode_count)

        # Report the requested count directly: the loop variable does not
        # exist when ``episodes`` is 0.
        print("Collected {} random episodes".format(episodes))

    def train(self):
        """Collect rollouts and accumulate the per-step losses.

        Prints the running encoder loss and CLUB loss after every step. No
        backward pass or optimizer step is performed here.
        """
        # collect experience
        self.collect_sequences(self.args.batch_size)

        # Group observations and next_observations by steps
        observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer, "observations")).float()
        next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer, "next_observations")).float()
        actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer, "actions", obs=False)).float()

        # Initialize transition model states
        self.transition_model.init_states(self.args.batch_size, device="cpu")
        self.history = self.transition_model.prev_history

        # Train encoder
        previous_information_loss = 0
        previous_encoder_loss = 0
        for i in range(self.args.episode_length):
            # Encode observations and next_observations
            self.states_dist = self.obs_encoder(observations[i])
            self.next_states_dist = self.obs_encoder(next_observations[i])

            # Sample states and next_states
            self.states = self.states_dist["sample"]
            self.next_states = self.next_states_dist["sample"]
            self.actions = actions[i]

            # Calculate upper bound loss
            past_latent_loss = previous_information_loss + self._upper_bound_minimization(self.states, self.next_states)

            # Calculate encoder loss
            past_encoder_loss = previous_encoder_loss + self._past_encoder_loss(self.states, self.next_states,
                                                                                self.states_dist, self.next_states_dist,
                                                                                self.actions, self.history, i)

            print(past_encoder_loss, past_latent_loss)

            previous_information_loss = past_latent_loss
            previous_encoder_loss = past_encoder_loss

    def _upper_bound_minimization(self, states, next_states):
        """Return the CLUB estimator's value for (states, next_states).

        Uses the single ``CLUBSample`` instance created in ``build_models``.
        """
        club_loss = self.club_sample(states, next_states)
        return club_loss

    def _past_encoder_loss(self, states, next_states, states_dist, next_states_dist, actions, history, step):
        """KL divergence between the imagined and the encoded next state.

        Args:
            states: Sampled latent states for the current step.
            next_states: Sampled latent states for the next step (unused).
            states_dist: Encoder output for the current step (unused).
            next_states_dist: Encoder output for the next step; its
                ``"distribution"`` entry is the KL target.
            actions: Actions taken at this step; replaced by zeros at step 0.
            history: Recurrent history to use at step 0.
            step: Index of the current step within the sequence.

        Returns:
            Mean KL divergence as a scalar tensor.

        Side effects:
            ``self.history`` is replaced by the history returned from the
            transition model on every call, so the recurrent state advances
            along the whole sequence.
        """
        if step == 0:
            # Zero action for first step
            actions = torch.zeros(self.args.batch_size, self.env.action_space.shape[0]).float()
            prev_history = history
        else:
            prev_history = self.history

        # Imagine next state
        imagined_next_states = self.transition_model.imagine_step(states, actions, prev_history)

        # Carry the recurrent history forward for the next step.
        self.history = imagined_next_states["history"]

        # State Distribution
        imagined_next_states_dist = imagined_next_states["distribution"]

        # KL divergence loss
        loss = torch.distributions.kl.kl_divergence(imagined_next_states_dist, next_states_dist["distribution"]).mean()

        return loss
if __name__ == '__main__':
    # Script entry point: parse the command line, build the trainer, run it.
    # `args` and `dpi` stay bound at module level, as before, so any code
    # that looks them up as globals keeps working.
    args = parse_args()
    dpi = DPI(args=args)
    dpi.train()