# Curiosity/DPI/train.py
#
# 349 lines
# 17 KiB
# Python

import numpy as np
import torch
import argparse
import os
import gym
import time
import json
import dmc2gym
import tqdm
import wandb
import utils
from utils import ReplayBuffer, make_env, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
from logger import Logger
from video import VideoRecorder
from dmc2gym.wrappers import set_global_var
import torchvision.transforms as T
#from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent
#from agents.navigation.carla_env import CarlaEnv
def parse_args(argv=None):
    """Parse the command-line options for DPI training.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding one attribute per option defined below
        (dashes in option names become underscores, e.g. ``num_units``).
    """
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--version', default=1, type=int)
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--channels', default=3, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)  # 10000
    parser.add_argument('--high_noise', action='store_true')
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=50000, type=int)  # 50000
    parser.add_argument('--episode_length', default=51, type=int)
    # train
    parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
    parser.add_argument('--init_steps', default=10000, type=int)
    parser.add_argument('--num_train_steps', default=10000, type=int)
    parser.add_argument('--batch_size', default=20, type=int)  # 512
    parser.add_argument('--state_size', default=256, type=int)
    parser.add_argument('--hidden_size', default=128, type=int)
    parser.add_argument('--history_size', default=128, type=int)
    parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models')
    parser.add_argument('--load_encoder', default=None, type=str)
    # type=int (was str): the horizon is compared numerically with the number
    # of remaining steps in DPI.train, which fails for a string value.
    parser.add_argument('--imagine_horizon', default=15, type=int)
    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.005, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--encoder_stride', default=1, type=int)
    parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--render', default=False, action='store_true')
    parser.add_argument('--port', default=2000, type=int)
    args = parser.parse_args(argv)
    return args
class DPI:
    """Trainer owning the environments, replay buffers and models for DPI.

    A noisy environment (configured from ``args``) and a noise-free copy
    (``version=2``, no image source, no resource files) are created. Only the
    noisy one is currently used to collect data in ``collect_sequences``.
    """

    def __init__(self, args):
        """Create environments, replay buffers, output directories and models.

        Args:
            args: namespace produced by ``parse_args``.
        """
        self.args = args
        # set environment noise
        set_global_var(self.args.high_noise)

        # noisy environment setup
        self.env = make_env(self.args)
        self.env.seed(self.args.seed)

        # Noiseless environment setup. The overrides go into a copy so that
        # self.args keeps describing the noisy environment built above
        # (previously they were written into self.args itself).
        clean_args = argparse.Namespace(**vars(self.args))
        clean_args.version = 2  # env_id changes to v2
        clean_args.img_source = None  # no image noise
        clean_args.resource_files = None
        self.env_clean = make_env(clean_args)
        self.env_clean.seed(self.args.seed)

        # stack several consecutive frames together
        if self.args.encoder_type.startswith('pixel'):
            self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
            self.env_clean = utils.FrameStack(self.env_clean, k=self.args.frame_stack)

        # one replay buffer per environment (data_buffer_clean is not filled yet)
        self.data_buffer = self._make_replay_buffer()
        self.data_buffer_clean = self._make_replay_buffer()

        # create work directory
        utils.make_dir(self.args.work_dir)
        self.video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video'))
        self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
        self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))

        # Settings read by get_features; nothing else in this class set them,
        # so get_features used to fail with AttributeError.
        self.aug = False  # random-shift augmentation disabled by default
        self.ob_mean = 0.0  # identity normalisation by default
        self.ob_std = 1.0

        # create models
        self.build_models(use_saved=False, saved_model_dir=self.model_dir)

    def _obs_shape(self):
        """Return the stacked observation shape (frame_stack * channels, image_size, image_size)."""
        return (self.args.frame_stack * self.args.channels,
                self.args.image_size,
                self.args.image_size)

    def _make_replay_buffer(self):
        """Build a ReplayBuffer sized for stacked observations and the env's actions."""
        return ReplayBuffer(size=self.args.replay_buffer_capacity,
                            obs_shape=self._obs_shape(),
                            action_size=self.env.action_space.shape[0],
                            seq_len=self.args.episode_length,
                            batch_size=self.args.batch_size,
                            args=self.args)

    def build_models(self, use_saved, saved_model_dir=None):
        """Create the encoders, decoder, transition model and their optimiser.

        Args:
            use_saved: if True, load weights from ``saved_model_dir`` afterwards.
            saved_model_dir: directory holding the ``*.pt`` state dicts.
        """
        obs_shape = self._obs_shape()
        self.obs_encoder = ObservationEncoder(obs_shape=obs_shape, state_size=self.args.state_size)
        self.obs_encoder_momentum = ObservationEncoder(obs_shape=obs_shape, state_size=self.args.state_size)
        self.obs_decoder = ObservationDecoder(state_size=self.args.state_size, output_shape=obs_shape)
        self.transition_model = TransitionModel(
            state_size=self.args.state_size,
            hidden_size=self.args.hidden_size,
            action_size=self.env.action_space.shape[0],
            history_size=self.args.history_size,
        )
        # model parameters
        # NOTE(review): the momentum encoder is optimised by gradient here too;
        # if it is meant to track obs_encoder (e.g. EMA), drop it from this list.
        self.model_parameters = (list(self.obs_encoder.parameters())
                                 + list(self.obs_encoder_momentum.parameters())
                                 + list(self.obs_decoder.parameters())
                                 + list(self.transition_model.parameters()))
        # optimizer
        self.optimizer = torch.optim.Adam(self.model_parameters, lr=self.args.encoder_lr)
        if use_saved:
            self._use_saved_models(saved_model_dir)

    def _use_saved_models(self, saved_model_dir):
        """Load obs_encoder / obs_decoder / transition_model weights from ``saved_model_dir``.

        Expects ``obs_encoder.pt``, ``obs_decoder.pt`` and ``transition_model.pt``.
        NOTE(review): ``obs_encoder_momentum`` is not restored.
        """
        for name in ('obs_encoder', 'obs_decoder', 'transition_model'):
            path = os.path.join(saved_model_dir, name + '.pt')
            # This class never moves the models off the CPU, so map the saved
            # tensors there as well (avoids failures for GPU-saved checkpoints).
            getattr(self, name).load_state_dict(torch.load(path, map_location='cpu'))

    def collect_sequences(self, episodes):
        """Fill ``self.data_buffer`` with ``episodes`` episodes of random actions.

        Each episode slot runs ``args.episode_length`` steps. The env is reset
        when it reports ``done`` or at the slot's last step, so a slot may span
        several env episodes. Transitions are tagged with the 1-based slot id.

        Args:
            episodes: number of episode slots to collect.
        """
        save_video = self.args.save_video  # was the module-global `args`
        obs = self.env.reset()
        for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
            if save_video:
                self.env.video.init(enabled=True)
            for step in range(self.args.episode_length):
                action = self.env.action_space.sample()
                next_obs, _, done, _ = self.env.step(action)
                self.data_buffer.add(obs, action, next_obs, episode_count + 1, done)
                if save_video:
                    self.env.video.record(self.env)
                if done or step == self.args.episode_length - 1:
                    obs = self.env.reset()
                else:
                    obs = next_obs
            if save_video:
                self.env.video.save('noisy/%d.mp4' % episode_count)
        # `episodes` instead of `episode_count + 1`: the latter is unbound when episodes == 0
        print("Collected {} random episodes".format(episodes))

    def _grouped_steps(self, key, obs=True):
        """Return ``self.data_buffer`` entries for ``key`` grouped by step, as a float tensor."""
        if obs:
            grouped = self.data_buffer.group_steps(self.data_buffer, key)
        else:
            grouped = self.data_buffer.group_steps(self.data_buffer, key, obs=False)
        return torch.Tensor(grouped).float()

    def train(self):
        """Collect random episodes and accumulate the DPI losses over their steps.

        Returns:
            dict with the summed ``"ub_loss"`` and ``"encoder_loss"``
            (the original returned nothing; callers ignoring it are unaffected).
        """
        # collect experience: one episode slot per batch element
        self.collect_sequences(self.args.batch_size)

        seq_len = self.args.episode_length - 1
        # Group observations and next_observations by steps from past to present
        observations = self._grouped_steps("observations")
        next_observations_all = self._grouped_steps("next_observations")
        actions_all = self._grouped_steps("actions", obs=False)
        last_observations = observations[:seq_len]
        current_observations = next_observations_all[:seq_len]
        next_observations = next_observations_all[1:]
        actions = actions_all[:seq_len]

        # Initialize transition model states
        self.transition_model.init_states(self.args.batch_size, device="cpu")
        self.history = self.transition_model.prev_history
        # int(): tolerate a string horizon (older parse_args declared type=str)
        horizon_cap = int(self.args.imagine_horizon)

        # NOTE(review): the losses are only summed. Nothing calls backward() or
        # self.optimizer.step(), so no parameters are updated yet.
        total_ub_loss = 0
        total_encoder_loss = 0
        # step 0 is skipped, matching the original `if i > 0` guard
        for i in range(1, seq_len):
            # Encode observations and next_observations
            self.last_states_dict = self.obs_encoder(last_observations[i])
            self.current_states_dict = self.obs_encoder(current_observations[i])
            self.next_states_dict = self.obs_encoder_momentum(next_observations[i])
            self.action = actions[i]

            # Negatives: the batch at a random time step, shuffled. randint's
            # upper bound is exclusive, so use the full length (the old bound
            # episode_length-2 could never pick the last step).
            perm = torch.randperm(current_observations[i].shape[0])
            random_time_index = torch.randint(0, current_observations.shape[0], (1,)).item()
            negative_current_observations = current_observations[random_time_index][perm]
            self.negative_current_states_dict = self.obs_encoder(negative_current_observations)

            # Predict current state from past state with transition model
            predicted_current_state_dict = self.transition_model.imagine_step(
                self.last_states_dict["sample"], self.action, self.history)
            self.history = predicted_current_state_dict["history"]

            ub_loss = self._upper_bound_minimization(self.last_states_dict,
                                                     self.current_states_dict,
                                                     self.negative_current_states_dict,
                                                     predicted_current_state_dict)
            encoder_loss = self._past_encoder_loss(self.current_states_dict,
                                                   predicted_current_state_dict)
            total_ub_loss += ub_loss
            total_encoder_loss += encoder_loss

            # NOTE(review): the imagined rollout is computed but not used yet.
            imagine_horizon = min(horizon_cap, seq_len - i)
            self.transition_model.imagine_rollout(self.current_states_dict["sample"],
                                                  self.action, self.history, imagine_horizon)
        return {"ub_loss": total_ub_loss, "encoder_loss": total_encoder_loss}

    def _upper_bound_minimization(self, last_states, current_states, negative_current_states, predicted_current_states):
        """Return the CLUBSample loss for the given state dicts.

        NOTE(review): a new CLUBSample is constructed on every call.
        """
        club_sample = CLUBSample(last_states,
                                 current_states,
                                 negative_current_states,
                                 predicted_current_states)
        return club_sample()

    def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict):
        """Mean KL(current state distribution || predicted current state distribution)."""
        curr_states_dist = curr_states_dict["distribution"]
        predicted_curr_states_dist = predicted_curr_states_dict["distribution"]
        return torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean()

    def get_features(self, x, momentum=False):
        """Normalise (and optionally augment) observations, then encode them.

        Args:
            x: observation batch; assumed channel-first (N, C, H, W) like the
                encoders' ``obs_shape`` — TODO confirm.
            momentum: if True, encode with ``obs_encoder_momentum`` without
                tracking gradients. Otherwise use the online ``obs_encoder``.
                The original had these two encoders swapped.

        Returns:
            Whatever the chosen encoder returns.
        """
        if self.aug:
            # Random shift: crop 4 px, pad back symmetrically, crop to input size
            # (84 -> 80 -> 88 -> 84 for the default image size).
            height, width = x.shape[-2], x.shape[-1]
            x = T.RandomCrop((height - 4, width - 4))(x)
            # padding_mode must be passed by keyword: the third positional
            # parameter of pad() is `fill`.
            x = T.functional.pad(x, (4, 4, 4, 4), padding_mode="symmetric")
            x = T.RandomCrop((height, width))(x)
        with torch.no_grad():
            x = (x.float() - self.ob_mean) / self.ob_std
        if momentum:
            # no_grad also works when the encoder returns a dict (as in train)
            with torch.no_grad():
                return self.obs_encoder_momentum(x)
        return self.obs_encoder(x)
if __name__ == '__main__':
    # Script entry point: parse the CLI, build the trainer, run training.
    # ``args`` stays a module-level name on purpose; code elsewhere in this
    # file may read it as a global.
    args = parse_args()
    DPI(args).train()