# Curiosity/DPI/train.py
import os
import gc
import copy
import tqdm
import wandb
import random
import argparse
import numpy as np
from collections import OrderedDict
import utils
from utils import FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image, shuffle_along_axis, Logger
from replay_buffer import ReplayBuffer  # this ReplayBuffer (not the one in utils) is the buffer used below
from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
from video import VideoRecorder
from dmc2gym.wrappers import set_global_var
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter
#from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent
#from agents.navigation.carla_env import CarlaEnv
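# Overview: this script trains the DPI agent end to end. DPI.train() seeds the replay
# buffer with random-policy sequences, then alternates between gradient updates
# (world model: encoder, transition model, CLUB and contrastive heads; plus separate
# reward, decoder, actor and value updates) and collecting new sequences with the
# current actor. Periodically the models are checkpointed and evaluated.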
def parse_args():
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--version', default=1, type=int)
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--channels', default=3, type=int)
    parser.add_argument('--action_repeat', default=2, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--collection_interval', default=100, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=5000, type=int)  # 10000
    parser.add_argument('--high_noise', action='store_true')
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=50000, type=int)
    parser.add_argument('--episode_length', default=51, type=int)
    # train
    parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
    parser.add_argument('--init_steps', default=10000, type=int)
    parser.add_argument('--num_train_steps', default=100000, type=int)
    parser.add_argument('--update_steps', default=100, type=int)
    parser.add_argument('--batch_size', default=64, type=int)  # 512
    parser.add_argument('--state_size', default=50, type=int)
    parser.add_argument('--hidden_size', default=512, type=int)
    parser.add_argument('--history_size', default=256, type=int)
    parser.add_argument('--episode_collection', default=5, type=int)
    parser.add_argument('--episodes_buffer', default=5, type=int, help='Initial number of episodes to store in the buffer')
    parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models')
    parser.add_argument('--load_encoder', default=None, type=str)
    parser.add_argument('--imagine_horizon', default=15, type=int)
    parser.add_argument('--grad_clip_norm', type=float, default=100.0, help='Gradient clipping norm')
    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    parser.add_argument('--evaluation_interval', default=10000, type=int)  # TODO: master had 10000
    # value
    parser.add_argument('--value_lr', default=1e-6, type=float)
    parser.add_argument('--value_target_update_freq', default=100, type=int)
    parser.add_argument('--td_lambda', default=0.95, type=float)
    # actor
    parser.add_argument('--actor_lr', default=1e-6, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # world/encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
    parser.add_argument('--world_model_lr', default=1e-5, type=float)
    parser.add_argument('--decoder_lr', default=1e-5, type=float)
    parser.add_argument('--reward_lr', default=1e-5, type=float)
    parser.add_argument('--encoder_tau', default=0.001, type=float)
    parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    parser.add_argument('--aug', action='store_true')
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--logging_freq', default=100, type=int)
    parser.add_argument('--saving_interval', default=2500, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--render', default=False, action='store_true')
    args = parser.parse_args()
    return args
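# Example invocation (a sketch: assumes the default dm_control cheetah-run setup;
# every flag used here is defined in parse_args above):
#   python train.py --domain_name cheetah --task_name run --encoder_type pixel \
#       --action_repeat 2 --frame_stack 3 --batch_size 64 --seed 1 --save_video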
class DPI:
    def __init__(self, args):
        # wandb config
        #run = wandb.init(project="dpi")
        self.args = args

        # set environment noise
        set_global_var(self.args.high_noise)

        # environment setup
        self.env = make_env(self.args)
        #self.args.seed = np.random.randint(0, 1000)
        self.env.seed(self.args.seed)
        self.global_episodes = 0

        # noiseless environment setup
        self.args.version = 2  # env_id changes to v2
        self.args.img_source = None  # no image noise
        self.args.resource_files = None

        # stack several consecutive frames together
        if self.args.encoder_type.startswith('pixel'):
            self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
        self.env = utils.ActionRepeat(self.env, self.args.action_repeat)
        self.env = utils.NormalizeActions(self.env)
        self.env = utils.TimeLimit(self.env, 1000 // args.action_repeat)

        # create replay buffer
        self.data_buffer = ReplayBuffer(self.args.replay_buffer_capacity,
                                        self.env.observation_space.shape,
                                        self.env.action_space.shape[0],
                                        self.args.episode_length,
                                        self.args.batch_size)

        # create work directory
        utils.make_dir(self.args.work_dir)
        self.video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video'))
        self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
        self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))

        # create models
        self.build_models(use_saved=False, saved_model_dir=self.model_dir)
    def build_models(self, use_saved, saved_model_dir=None):
        # World Models
        self.obs_encoder = ObservationEncoder(
            obs_shape=(self.args.frame_stack*self.args.channels, self.args.image_size, self.args.image_size),  # (9,84,84)
            state_size=self.args.state_size  # 50
        ).to(device)

        self.obs_encoder_momentum = ObservationEncoder(
            obs_shape=(self.args.frame_stack*self.args.channels, self.args.image_size, self.args.image_size),  # (9,84,84)
            state_size=self.args.state_size  # 50
        ).to(device)

        self.obs_decoder = ObservationDecoder(
            state_size=self.args.state_size,  # 50
            output_shape=(self.args.frame_stack*self.args.channels, self.args.image_size, self.args.image_size)  # (9,84,84); must match the stacked observation
        ).to(device)

        self.transition_model = TransitionModel(
            state_size=self.args.state_size,  # 50
            hidden_size=self.args.hidden_size,  # 512
            action_size=self.env.action_space.shape[0],  # 6
            history_size=self.args.history_size,  # 256
        ).to(device)

        # Actor Model
        self.actor_model = Actor(
            state_size=self.args.state_size,  # 50
            hidden_size=self.args.hidden_size,  # 512
            action_size=self.env.action_space.shape[0],  # 6
        ).to(device)
        #self.actor_model.apply(self.init_weights)

        # Value Models
        self.value_model = ValueModel(
            state_size=self.args.state_size,  # 50
            hidden_size=self.args.hidden_size,  # 512
        ).to(device)

        self.target_value_model = ValueModel(
            state_size=self.args.state_size,  # 50
            hidden_size=self.args.hidden_size,  # 512
        ).to(device)

        self.reward_model = RewardModel(
            state_size=self.args.state_size,  # 50
            hidden_size=self.args.hidden_size,  # 512
        ).to(device)

        # Contrastive Models
        self.prjoection_head = ProjectionHead(
            state_size=self.args.state_size,  # 50
            action_size=self.env.action_space.shape[0],  # 6
            hidden_size=self.args.hidden_size,  # 512
        ).to(device)

        self.prjoection_head_momentum = ProjectionHead(
            state_size=self.args.state_size,  # 50
            action_size=self.env.action_space.shape[0],  # 6
            hidden_size=self.args.hidden_size,  # 512
        ).to(device)

        self.contrastive_head = ContrastiveHead(
            hidden_size=self.args.hidden_size,  # 512
        ).to(device)

        self.club_sample = CLUBSample(
            x_dim=self.args.state_size,  # 50
            y_dim=self.args.state_size,  # 50
            hidden_size=self.args.hidden_size,  # 512
        ).to(device)

        # model parameters
        self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.prjoection_head.parameters()) + \
                                      list(self.transition_model.parameters()) + list(self.club_sample.parameters()) + \
                                      list(self.contrastive_head.parameters())
        self.past_transition_parameters = self.transition_model.parameters()

        # optimizers
        self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr, eps=1e-6)
        self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr, eps=1e-6)
        self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr, eps=1e-6)
        self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), self.args.decoder_lr, eps=1e-6)
        self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr, eps=1e-6)

        # create modules
        self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.club_sample, self.contrastive_head]
        self.value_modules = [self.value_model]
        self.actor_modules = [self.actor_model]
        self.decoder_modules = [self.obs_decoder]
        self.reward_modules = [self.reward_model]

        if use_saved:
            self._use_saved_models(saved_model_dir)
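    # Note: obs_encoder_momentum and prjoection_head_momentum are not attached to any
    # optimizer; they track their online counterparts through
    # soft_update_params(..., encoder_tau) in update() below, momentum-encoder style.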
    def _use_saved_models(self, saved_model_dir):
        self.obs_encoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_encoder.pt')))
        self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
        self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))
    def collect_random_sequences(self, seed_steps):
        obs = self.env.reset()
        done = False
        all_rews = []
        self.global_episodes += 1
        epi_reward = 0
        for _ in tqdm.tqdm(range(seed_steps), desc='Collecting episodes'):
            action = self.env.action_space.sample()
            next_obs, rew, done, _ = self.env.step(action)
            self.data_buffer.add(obs, action, next_obs, rew, done)
            obs = next_obs
            epi_reward += rew
            if done:
                obs = self.env.reset()
                done = False
                all_rews.append(epi_reward)
                epi_reward = 0
        return all_rews
    def collect_sequences(self, collect_steps, actor_model):
        obs = self.env.reset()
        done = False
        all_rews = []
        self.global_episodes += 1
        epi_reward = 0
        for episode_count in tqdm.tqdm(range(collect_steps), desc='Collecting episodes'):
            with torch.no_grad():
                obs_ = torch.tensor(obs.copy(), dtype=torch.float32)
                obs_ = preprocess_obs(obs_).to(device)
                state = self.get_features(obs_)["distribution"].rsample().unsqueeze(0)
                action = actor_model(state)
                action = actor_model.add_exploration(action)
                action = action.cpu().numpy()[0]
            next_obs, rew, done, _ = self.env.step(action)
            self.data_buffer.add(obs, action, next_obs, rew, done)
            epi_reward += rew  # count the final transition's reward as well
            if done:
                obs = self.env.reset()
                done = False
                all_rews.append(epi_reward)
                epi_reward = 0
            else:
                obs = next_obs
        return all_rews
    def train(self, step, total_steps):
        # logger
        logdir = os.path.dirname(os.path.realpath(__file__)) + "/log/logs/"
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        initial_logs = OrderedDict()
        logger = Logger(logdir)

        episodic_rews = self.collect_random_sequences(5000 // self.args.action_repeat)
        self.global_step = self.data_buffer.steps

        initial_logs.update({
            'train_avg_reward': np.mean(episodic_rews),
            'train_max_reward': np.max(episodic_rews),
            'train_min_reward': np.min(episodic_rews),
            'train_std_reward': np.std(episodic_rews),
        })
        logger.log_scalars(initial_logs, step=0)
        logger.flush()

        while self.global_step < total_steps:
            step += 1

            for update_steps in range(self.args.update_steps):
                model_loss, actor_loss, value_loss, actor_model = self.update((step - 1) * self.args.update_steps + update_steps)

            initial_logs.update({
                'model_loss': model_loss,
                'actor_loss': actor_loss,
                'value_loss': value_loss,
                'train_avg_reward': np.mean(episodic_rews),
                'train_max_reward': np.max(episodic_rews),
                'train_min_reward': np.min(episodic_rews),
                'train_std_reward': np.std(episodic_rews),
            })
            logger.log_scalars(initial_logs, self.global_step)

            print("########## Global Step:", self.global_step, " ##########")
            for key, value in initial_logs.items():
                print(key, " : ", value)

            episodic_rews = self.collect_sequences(1000 // self.args.action_repeat, actor_model)

            if self.global_step % 3150 == 0 and self.data_buffer.steps != 0:  # self.args.evaluation_interval == 0
                print("Saving model")
                path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
                self.save_models(path)
                self.evaluate()

            self.global_step = self.data_buffer.steps * self.args.action_repeat

        """
        # collect experience
        if step != 0:
            encoder = self.obs_encoder
            actor = self.actor_model
            all_rews = self.collect_sequences(self.args.episode_collection, actor_model=actor, encoder_model=encoder)
        """
    def collect_batch(self):
        obs_, acs_, nxt_obs_, rews_, terms_ = self.data_buffer.sample()

        obs = torch.tensor(obs_, dtype=torch.float32)[1:]
        last_obs = torch.tensor(obs_, dtype=torch.float32)[:-1]
        nxt_obs = torch.tensor(nxt_obs_, dtype=torch.float32)[1:]
        acs = torch.tensor(acs_, dtype=torch.float32)[:-1].to(device)
        nxt_acs = torch.tensor(acs_, dtype=torch.float32)[1:].to(device)
        rews = torch.tensor(rews_, dtype=torch.float32)[:-1].to(device).unsqueeze(-1)
        nonterms = torch.tensor((1.0 - terms_), dtype=torch.float32)[:-1].to(device).unsqueeze(-1)

        last_obs = preprocess_obs(last_obs).to(device)
        obs = preprocess_obs(obs).to(device)
        nxt_obs = preprocess_obs(nxt_obs).to(device)

        return last_obs, obs, nxt_obs, acs, rews, nxt_acs, nonterms
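    # Note: collect_batch() time-shifts one sampled sequence into aligned views:
    # index [:-1] gives the previous step (last_obs, acs, rews, nonterms) while
    # index [1:] gives the current step (obs, nxt_acs) and its successor (nxt_obs),
    # so the transition model can be rolled out one step ahead in update().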
    def update(self, step):
        last_observations, current_observations, next_observations, actions, rewards, next_actions, nonterms = self.collect_batch()
        #last_observations, current_observations, next_observations, actions, next_actions, rewards = self.select_one_batch()

        world_loss, enc_loss, rew_loss, dec_loss, ub_loss, lb_loss = self.world_model_losses(last_observations,
                                                                                             current_observations,
                                                                                             next_observations,
                                                                                             actions,
                                                                                             next_actions,
                                                                                             rewards,
                                                                                             nonterms)
        self.world_model_opt.zero_grad()
        world_loss.backward()
        nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
        self.world_model_opt.step()

        self.decoder_opt.zero_grad()
        dec_loss.backward()
        nn.utils.clip_grad_norm_(self.obs_decoder.parameters(), self.args.grad_clip_norm)
        self.decoder_opt.step()

        self.reward_opt.zero_grad()
        rew_loss.backward()
        nn.utils.clip_grad_norm_(self.reward_model.parameters(), self.args.grad_clip_norm)
        self.reward_opt.step()

        actor_loss = self.actor_model_losses()
        self.actor_opt.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
        self.actor_opt.step()

        value_loss = self.value_model_losses()
        self.value_opt.zero_grad()
        value_loss.backward()
        nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
        self.value_opt.step()

        # update momentum encoder and projection head
        soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
        soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)

        # update target value networks
        #if step % self.args.value_target_update_freq == 0:
        #    self.target_value_model = copy.deepcopy(self.value_model)

        if step % self.args.logging_freq == 0:
            writer.add_scalar('World Loss/World Loss', world_loss.detach().item(), step)
            writer.add_scalar('Main Models Loss/Encoder Loss', enc_loss.detach().item(), step)
            writer.add_scalar('Main Models Loss/Decoder Loss', dec_loss.detach().item(), step)
            writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), step)
            writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step)
            writer.add_scalar('Actor Critic Loss/Reward Loss', rew_loss.detach().item(), step)
            writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step)
            writer.add_scalar('Bound Loss/Lower Bound Loss', -lb_loss.detach().item(), step)

        return world_loss.item(), actor_loss.item(), value_loss.item(), self.actor_model
    def world_model_losses(self, last_obs, curr_obs, nxt_obs, actions, nxt_actions, rewards, nonterms):
        # get features
        self.last_state_feat = self.get_features(last_obs)
        self.curr_state_feat = self.get_features(curr_obs)
        self.nxt_state_feat = self.get_features(nxt_obs, momentum=True)

        # states
        self.last_state_enc = self.last_state_feat["sample"]
        self.curr_state_enc = self.curr_state_feat["sample"]
        self.nxt_state_enc = self.nxt_state_feat["sample"]

        # actions
        actions = actions.clone()
        nxt_actions = nxt_actions.clone()

        # rewards
        rewards = rewards.clone()

        # predict next states
        self.transition_model.init_states(self.args.batch_size, device)  # (N, history_size)
        self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history, nonterms)
        self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"])
        self.pred_curr_state_enc = self.pred_curr_state_dist.mean

        # encoder loss: KL between encoded and predicted current-state distributions
        enc_loss = self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist)

        # reward loss
        rew_dist = self.reward_model(self.curr_state_enc)
        rew_loss = -torch.mean(rew_dist.log_prob(rewards))

        # decoder loss
        dec_dist = self.obs_decoder(self.nxt_state_enc)
        dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs))

        # upper bound loss (CLUB mutual-information estimate)
        _, ub_loss = self._upper_bound_minimization(self.curr_state_enc,
                                                    self.pred_curr_state_enc)

        # lower bound loss
        # contrastive projection
        vec_anchor = self.pred_curr_state_enc.detach()
        vec_positive = self.nxt_state_enc.detach()
        z_anchor = self.prjoection_head(vec_anchor, nxt_actions)
        z_positive = self.prjoection_head_momentum(vec_positive, nxt_actions)

        # contrastive loss (InfoNCE accumulated over time steps)
        past_lb_loss = 0
        for i in range(z_anchor.shape[0]):
            logits = self.contrastive_head(z_anchor[i], z_positive[i])
            labels = torch.arange(logits.shape[0]).long().to(device)
            lb_loss = F.cross_entropy(logits, labels) + past_lb_loss
            past_lb_loss = lb_loss.detach().item()
        lb_loss = -0.1 * lb_loss / (z_anchor.shape[0])

        world_loss = enc_loss + ub_loss + lb_loss
        return world_loss, enc_loss, rew_loss, dec_loss, ub_loss, lb_loss
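    # Note: the combined world-model objective optimized in update() is
    #   world_loss = enc_loss + ub_loss + lb_loss
    # i.e. a KL consistency term between encoded and predicted state distributions,
    # a CLUB-based upper-bound penalty, and a negative (scaled) InfoNCE lower bound.
    # The reward and decoder negative log-likelihoods are trained by their own optimizers.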
    def actor_model_losses(self):
        with torch.no_grad():
            curr_state_enc = self.transition_model.seq_to_batch(self.curr_state_feat, "sample")["sample"]
            curr_state_hist = self.transition_model.seq_to_batch(self.observed_rollout, "history")["sample"]

        with FreezeParameters(self.world_model_modules + self.decoder_modules + self.reward_modules):
            imagine_horizon = self.args.imagine_horizon
            action = self.actor_model(curr_state_enc)
            self.imagined_rollout = self.transition_model.imagine_rollout(curr_state_enc,
                                                                          action, curr_state_hist,
                                                                          imagine_horizon)

        with FreezeParameters(self.world_model_modules + self.value_modules + self.decoder_modules + self.reward_modules):
            imag_rewards = self.reward_model(self.imagined_rollout["sample"]).mean
            imag_values = self.value_model(self.imagined_rollout["sample"]).mean

            discounts = self.args.discount * torch.ones_like(imag_rewards).detach()

            self.returns = self._compute_lambda_return(imag_rewards[:-1],
                                                       imag_values[:-1],
                                                       discounts[:-1],
                                                       self.args.td_lambda,
                                                       imag_values[-1])

            discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
            self.discounts = torch.cumprod(discounts, 0).detach()

        actor_loss = -torch.mean(self.discounts * self.returns)
        return actor_loss
    def value_model_losses(self):
        # value loss
        with torch.no_grad():
            value_feat = self.imagined_rollout["sample"][:-1].detach()
            value_targ = self.returns.detach()
        value_dist = self.value_model(value_feat)
        value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
        return value_loss
    def select_one_batch(self):
        # collect sequences
        non_zero_indices = np.nonzero(self.data_buffer.episode_count)[0]
        current_obs = self.data_buffer.observations[non_zero_indices]
        next_obs = self.data_buffer.next_observations[non_zero_indices]
        actions_raw = self.data_buffer.actions[non_zero_indices]
        rewards = self.data_buffer.rewards[non_zero_indices]
        self.terms = np.where(self.data_buffer.terminals[non_zero_indices] != False)[0]

        # group by episodes
        current_obs = self.grouped_arrays(current_obs)
        next_obs = self.grouped_arrays(next_obs)
        actions_raw = self.grouped_arrays(actions_raw)
        rewards_ = self.grouped_arrays(rewards)

        # select random chunks of episodes
        if current_obs.shape[0] < self.args.batch_size:
            random_episode_number = np.random.randint(0, current_obs.shape[0], self.args.batch_size)
        else:
            random_episode_number = random.sample(range(current_obs.shape[0]), self.args.batch_size)

        # select random starting points
        if current_obs[0].shape[0] - self.args.episode_length < self.args.batch_size:
            init_index = np.random.randint(0, current_obs[0].shape[0] - self.args.episode_length - 2, self.args.batch_size)
        else:
            init_index = np.asarray(random.sample(range(current_obs[0].shape[0] - self.args.episode_length), self.args.batch_size))

        # shuffle
        random.shuffle(random_episode_number)
        random.shuffle(init_index)

        # select first k elements
        last_observations = self.select_first_k(current_obs, init_index, random_episode_number)[:-1]
        current_observations = self.select_first_k(current_obs, init_index, random_episode_number)[1:]
        next_observations = self.select_first_k(next_obs, init_index, random_episode_number)[:-1]
        actions = self.select_first_k(actions_raw, init_index, random_episode_number)[:-1].to(device)
        next_actions = self.select_first_k(actions_raw, init_index, random_episode_number)[1:].to(device)
        rewards = self.select_first_k(rewards_, init_index, random_episode_number)[:-1].to(device)

        # preprocessing
        last_observations = preprocess_obs(last_observations).to(device)
        current_observations = preprocess_obs(current_observations).to(device)
        next_observations = preprocess_obs(next_observations).to(device)

        return last_observations, current_observations, next_observations, actions, next_actions, rewards
    def evaluate(self, eval_episodes=10):
        path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
        self.restore_checkpoint(path)

        obs = self.env.reset()
        done = False

        #video = VideoRecorder(self.video_dir, resource_files=self.args.resource_files)
        if self.args.save_video:
            self.env.video.init(enabled=True)

        episodic_rewards = []
        for episode in range(eval_episodes):
            rewards = 0
            done = False
            while not done:
                with torch.no_grad():
                    obs = torch.tensor(obs.copy(), dtype=torch.float32).unsqueeze(0)
                    obs_processed = preprocess_obs(obs).to(device)
                    state = self.get_features(obs_processed)["distribution"].rsample()
                    action = self.actor_model(state).cpu().detach().numpy().squeeze()
                next_obs, rew, done, _ = self.env.step(action)
                rewards += rew
                if self.args.save_video:
                    self.env.video.record(self.env)
                    self.env.video.save('/home/vedant/Curiosity/Curiosity/DPI/log/video/learned_model.mp4')
                obs = next_obs
            obs = self.env.reset()
            episodic_rewards.append(rewards)
        print("Episodic rewards: ", episodic_rewards)
        print("Average episodic reward: ", np.mean(episodic_rewards))
    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)
    def grouped_arrays(self, array):
        indices = [0] + self.terms.tolist()

        def subarrays():
            for start, end in zip(indices[:-1], indices[1:]):
                yield array[start:end]

        try:
            subarrays = np.stack(list(subarrays()), axis=0)
        except ValueError:
            subarrays = np.asarray(list(subarrays()))
        return subarrays
    def select_first_k(self, array, init_index, episode_number):
        term_index = init_index + self.args.episode_length
        array = array[episode_number]

        array_list = []
        for i in range(array.shape[0]):
            array_list.append(array[i][init_index[i]:term_index[i]])
        array = np.asarray(array_list)

        if array.ndim == 5:
            transposed_array = np.transpose(array, (1, 0, 2, 3, 4))
        elif array.ndim == 4:
            transposed_array = np.transpose(array, (1, 0, 2, 3))
        elif array.ndim == 3:
            transposed_array = np.transpose(array, (1, 0, 2))
        elif array.ndim == 2:
            transposed_array = np.transpose(array, (1, 0))
        else:
            transposed_array = np.expand_dims(array, axis=0)

        #return torch.tensor(array).float()
        return torch.tensor(transposed_array).float()
    def _upper_bound_minimization(self, current_states, predicted_current_states):
        current_negative_states = shuffle_along_axis(current_states.clone(), axis=0)
        current_negative_states = shuffle_along_axis(current_negative_states, axis=1)
        club_loss = self.club_sample(current_states, predicted_current_states, current_negative_states)
        likelihood_loss = 0
        return likelihood_loss, club_loss
    def _encoder_loss(self, curr_states_dist, predicted_curr_states_dist):
        # KL divergence loss
        loss = torch.mean(torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist))
        return loss
    def get_features(self, x, momentum=False):
        if self.args.aug:
            crop_transform = T.RandomCrop(size=80)
            cropped_x = torch.stack([crop_transform(x[i]) for i in range(x.size(0))])
            padding = (2, 2, 2, 2)
            x = F.pad(cropped_x, padding)  # random 80x80 crop, padded back to 84x84

        with torch.no_grad():
            if momentum:
                x = self.obs_encoder_momentum(x)
            else:
                x = self.obs_encoder(x)
        return x
    def _compute_lambda_return(self, rewards, values, discounts, td_lam, last_value):
        next_values = torch.cat([values[1:], last_value.unsqueeze(0)], 0)
        targets = rewards + discounts * next_values * (1 - td_lam)
        rets = []
        last_rew = last_value

        for t in range(rewards.shape[0] - 1, -1, -1):
            last_rew = targets[t] + discounts[t] * td_lam * last_rew
            rets.append(last_rew)

        returns = torch.flip(torch.stack(rets), [0])
        return returns
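    # Note: both lambda-return helpers implement the Dreamer-style recursion
    #   R_t = r_t + gamma_t * ((1 - lambda) * V(s_{t+1}) + lambda * R_{t+1}),
    # bootstrapped from the last predicted value; lambda = 1 recovers the discounted
    # Monte Carlo return and lambda = 0 the one-step TD target.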
    def lambda_return(self, imged_reward, value_pred, bootstrap, discount=0.99, lambda_=0.95):
        # Setting lambda=1 gives a discounted Monte Carlo return.
        # Setting lambda=0 gives a fixed 1-step return.
        next_values = torch.cat([value_pred[1:], bootstrap[None]], 0)
        discount_tensor = discount * torch.ones_like(imged_reward)  # pcont
        inputs = imged_reward + discount_tensor * next_values * (1 - lambda_)
        last = bootstrap
        indices = reversed(range(len(inputs)))
        outputs = []
        for index in indices:
            inp, disc = inputs[index], discount_tensor[index]
            last = inp + disc * lambda_ * last
            outputs.append(last)
        outputs = list(reversed(outputs))
        outputs = torch.stack(outputs, 0)
        returns = outputs
        return returns
    def save_models(self, save_path):
        torch.save(
            {'rssm': self.transition_model.state_dict(),
             'actor': self.actor_model.state_dict(),
             'reward_model': self.reward_model.state_dict(),
             'obs_encoder': self.obs_encoder.state_dict(),
             'obs_decoder': self.obs_decoder.state_dict(),
             'actor_optimizer': self.actor_opt.state_dict(),
             'value_optimizer': self.value_opt.state_dict(),
             'world_model_optimizer': self.world_model_opt.state_dict()}, save_path)
    def restore_checkpoint(self, ckpt_path):
        checkpoint = torch.load(ckpt_path)
        self.transition_model.load_state_dict(checkpoint['rssm'])
        self.actor_model.load_state_dict(checkpoint['actor'])
        self.reward_model.load_state_dict(checkpoint['reward_model'])
        self.obs_encoder.load_state_dict(checkpoint['obs_encoder'])
        self.obs_decoder.load_state_dict(checkpoint['obs_decoder'])
        self.world_model_opt.load_state_dict(checkpoint['world_model_optimizer'])
        self.actor_opt.load_state_dict(checkpoint['actor_optimizer'])
        self.value_opt.load_state_dict(checkpoint['value_optimizer'])
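# Example (hypothetical evaluation-only run; assumes a checkpoint written by
# save_models() already exists at the path hard-coded in evaluate(), and that
# writer/device are set up as in the __main__ block below):
#   dpi = DPI(parse_args())
#   dpi.evaluate(eval_episodes=10)  # restores the checkpoint internally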
if __name__ == '__main__':
    args = parse_args()
    writer = SummaryWriter()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    step = 0
    total_steps = 500000
    dpi = DPI(args)
    dpi.train(step, total_steps)
    dpi.evaluate()