Adding model

Vedant Dave 2023-04-20 14:55:54 +02:00
parent 3fa5e8e74a
commit e7f5533ee6


@@ -6,11 +6,12 @@ import wandb
 import random
 import argparse
 import numpy as np
+from collections import OrderedDict

 import utils
 from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image
+from replay_buffer import ReplayBuffer
 from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
-from logger import Logger
 from video import VideoRecorder
 from dmc2gym.wrappers import set_global_var
@@ -40,24 +41,25 @@ def parse_args():
     parser.add_argument('--resource_files', type=str)
     parser.add_argument('--eval_resource_files', type=str)
     parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
-    parser.add_argument('--total_frames', default=1000, type=int) # 10000
+    parser.add_argument('--total_frames', default=5000, type=int) # 10000
     parser.add_argument('--high_noise', action='store_true')
     # replay buffer
     parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #50000
-    parser.add_argument('--episode_length', default=21, type=int)
+    parser.add_argument('--episode_length', default=51, type=int)
     # train
     parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
     parser.add_argument('--init_steps', default=10000, type=int)
     parser.add_argument('--num_train_steps', default=100000, type=int)
-    parser.add_argument('--batch_size', default=128, type=int) #512
-    parser.add_argument('--state_size', default=30, type=int)
-    parser.add_argument('--hidden_size', default=256, type=int)
+    parser.add_argument('--update_steps', default=1, type=int)
+    parser.add_argument('--batch_size', default=64, type=int) #512
+    parser.add_argument('--state_size', default=50, type=int)
+    parser.add_argument('--hidden_size', default=512, type=int)
     parser.add_argument('--history_size', default=128, type=int)
     parser.add_argument('--episode_collection', default=5, type=int)
-    parser.add_argument('--episodes_buffer', default=20, type=int)
+    parser.add_argument('--episodes_buffer', default=5, type=int, help='Initial number of episodes to store in the buffer')
     parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models')
     parser.add_argument('--load_encoder', default=None, type=str)
-    parser.add_argument('--imagine_horizon', default=10, type=str)
+    parser.add_argument('--imagine_horizon', default=15, type=str)
     parser.add_argument('--grad_clip_norm', type=float, default=100.0, help='Gradient clipping norm')
     # eval
     parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
@@ -65,8 +67,6 @@ def parse_args():
     parser.add_argument('--evaluation_interval', default=10000, type=int) # TODO: master had 10000
     # value
     parser.add_argument('--value_lr', default=8e-6, type=float)
-    parser.add_argument('--value_beta', default=0.9, type=float)
-    parser.add_argument('--value_tau', default=0.005, type=float)
     parser.add_argument('--value_target_update_freq', default=100, type=int)
     parser.add_argument('--td_lambda', default=0.95, type=int)
     # actor
@@ -78,13 +78,13 @@ def parse_args():
     # world/encoder/decoder
     parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
     parser.add_argument('--world_model_lr', default=1e-5, type=float)
-    parser.add_argument('--encoder_tau', default=0.005, type=float)
+    parser.add_argument('--encoder_tau', default=0.001, type=float)
     parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
     parser.add_argument('--num_layers', default=4, type=int)
     parser.add_argument('--num_filters', default=32, type=int)
     parser.add_argument('--aug', action='store_true')
     # sac
-    parser.add_argument('--discount', default=0.99, type=float)
+    parser.add_argument('--discount', default=0.95, type=float)
     # misc
     parser.add_argument('--seed', default=1, type=int)
     parser.add_argument('--logging_freq', default=100, type=int)
@@ -131,6 +131,7 @@ class DPI:
         self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
         self.env = utils.ActionRepeat(self.env, self.args.action_repeat)
         self.env = utils.NormalizeActions(self.env)
+        self.env = utils.TimeLimit(self.env, 1000 / args.action_repeat)

         # create replay buffer
         self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity,
@@ -250,7 +251,7 @@ class DPI:
         self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
         self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))

-    def collect_sequences(self, episodes, random=True, actor_model=None, encoder_model=None):
+    def collect_random_sequences(self, episodes):
         obs = self.env.reset()
         done = False
@@ -259,239 +260,278 @@
             self.global_episodes += 1
             epi_reward = 0
             while not done:
-                if random:
-                    action = self.env.action_space.sample()
-                else:
-                    with torch.no_grad():
-                        obs = torch.tensor(obs.copy(), dtype=torch.float32).unsqueeze(0)
-                        obs_processed = preprocess_obs(obs).to(device)
-                        state = self.obs_encoder(obs_processed)["distribution"].sample()
-                        action = self.actor_model(state).cpu().numpy().squeeze()
-                        #action = self.env.action_space.sample()
+                action = self.env.action_space.sample()
+                next_obs, rew, done, _ = self.env.step(action)
+                self.data_buffer.add(obs, action, next_obs, rew, done, self.global_episodes)
+                obs = next_obs
+                epi_reward += rew
+            obs = self.env.reset()
+            done=False
+            all_rews.append(epi_reward)
+        return all_rews
+
+    def collect_sequences(self, episodes, actor_model):
+        obs = self.env.reset()
+        done = False
+        all_rews = []
+        for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
+            self.global_episodes += 1
+            epi_reward = 0
+            while not done:
+                with torch.no_grad():
+                    obs = torch.tensor(obs.copy(), dtype=torch.float32).to(device).unsqueeze(0)
+                    state = self.get_features(obs)["distribution"].rsample()
+                    action = self.actor_model(state)
+                    action = actor_model.add_exploration(action).cpu().numpy()[0]
+                    print(action)
+                    obs = obs.cpu().numpy()[0]
                 next_obs, rew, done, _ = self.env.step(action)
                 self.data_buffer.add(obs, action, next_obs, rew, done, self.global_episodes)
                 obs = next_obs
                 epi_reward += rew
             obs = self.env.reset()
             done=False
             all_rews.append(epi_reward)
         return all_rews

     def train(self, step, total_steps):
-        counter = 0
-        import matplotlib.pyplot as plt
-        fig, ax = plt.subplots()
-        while step < total_steps:
-            # collect experience
-            if step !=0:
-                encoder = self.obs_encoder
-                actor = self.actor_model
-                all_rews = self.collect_sequences(self.args.episode_collection, random=False, actor_model=actor, encoder_model=encoder)
-            else:
-                all_rews = self.collect_sequences(self.args.episodes_buffer, random=True)
-
-            # collect sequences
-            non_zero_indices = np.nonzero(self.data_buffer.episode_count)[0]
-            current_obs = self.data_buffer.observations[non_zero_indices]
-            next_obs = self.data_buffer.next_observations[non_zero_indices]
-            actions_raw = self.data_buffer.actions[non_zero_indices]
-            rewards = self.data_buffer.rewards[non_zero_indices]
-            self.terms = np.where(self.data_buffer.terminals[non_zero_indices]!=0)[0]
-
-            # Group by episodes
-            current_obs = self.grouped_arrays(current_obs)
-            next_obs = self.grouped_arrays(next_obs)
-            actions_raw = self.grouped_arrays(actions_raw)
-            rewards_ = self.grouped_arrays(rewards)
-
-            # Train encoder
-            if step == 0:
-                step += 1
-
-            update_steps = 1 if step > 1 else 1
-            #for _ in range(self.args.collection_interval // self.args.episode_length+1):
-            for _ in range(update_steps):
-                counter += 1
-
-                # Select random chunks of episodes
-                if current_obs.shape[0] < self.args.batch_size:
-                    random_episode_number = np.random.randint(0, current_obs.shape[0], self.args.batch_size)
-                else:
-                    random_episode_number = random.sample(range(current_obs.shape[0]), self.args.batch_size)
-
-                if current_obs[0].shape[0]-self.args.episode_length < self.args.batch_size:
-                    init_index = np.random.randint(0, current_obs[0].shape[0]-self.args.episode_length-2, self.args.batch_size)
-                else:
-                    init_index = np.asarray(random.sample(range(current_obs[0].shape[0]-self.args.episode_length), self.args.batch_size))
-
-                random.shuffle(random_episode_number)
-                random.shuffle(init_index)
-
-                last_observations = self.select_first_k(current_obs, init_index, random_episode_number)[:-1]
-                current_observations = self.select_first_k(current_obs, init_index, random_episode_number)[1:]
-                next_observations = self.select_first_k(next_obs, init_index, random_episode_number)[:-1]
-                actions = self.select_first_k(actions_raw, init_index, random_episode_number)[:-1].to(device)
-                next_actions = self.select_first_k(actions_raw, init_index, random_episode_number)[1:].to(device)
-                rewards = self.select_first_k(rewards_, init_index, random_episode_number)[1:].to(device)
-
-                # Preprocessing
-                last_observations = preprocess_obs(last_observations).to(device)
-                current_observations = preprocess_obs(current_observations).to(device)
-                next_observations = preprocess_obs(next_observations).to(device)
-
-                # Initialize transition model states
-                self.transition_model.init_states(self.args.batch_size, device) # (N,128)
-                self.history = self.transition_model.prev_history # (N,128)
-
-                past_world_model_loss = 0
-                past_action_loss = 0
-                past_value_loss = 0
-                for i in range(self.args.episode_length-1):
-                    if i > 0:
-                        # Encode observations and next_observations
-                        self.last_states_dict = self.get_features(last_observations[i])
-                        self.current_states_dict = self.get_features(current_observations[i])
-                        self.next_states_dict = self.get_features(next_observations[i], momentum=True)
-                        self.action = actions[i] # (N,6)
-                        self.next_action = next_actions[i] # (N,6)
-                        history = self.transition_model.prev_history
-
-                        # Encode negative observations fro upper bound loss
-                        idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch
-                        random_time_index = torch.randint(0, current_observations.shape[0]-2, (1,)).item() # random time index
-                        negative_current_observations = current_observations[random_time_index][idx]
-                        self.negative_current_states_dict = self.get_features(negative_current_observations)
-
-                        # Predict current state from past state with transition model
-                        last_states_sample = self.last_states_dict["sample"]
-                        predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
-                        self.history = predicted_current_state_dict["history"]
-
-                        # Calculate upper bound loss
-                        likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict["sample"].detach(),
-                                                                              self.current_states_dict["sample"].detach(),
-                                                                              self.negative_current_states_dict["sample"].detach(),
-                                                                              predicted_current_state_dict["sample"].detach(),
-                                                                              )
-
-                        # Calculate encoder loss
-                        encoder_loss = self._past_encoder_loss(self.current_states_dict, predicted_current_state_dict)
-
-                        # decoder loss
-                        horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
-                        nxt_obs = next_observations[i:i+horizon].reshape(-1,9,84,84)
-                        next_states_encodings = self.get_features(nxt_obs)["sample"].view(horizon,self.args.batch_size, -1)
-                        obs_dist = self.obs_decoder(next_states_encodings)
-                        decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon]))
-
-                        # contrastive projection
-                        vec_anchor = predicted_current_state_dict["sample"].detach()
-                        vec_positive = self.next_states_dict["sample"].detach()
-                        z_anchor = self.prjoection_head(vec_anchor, self.action)
-                        z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach()
-
-                        # contrastive loss
-                        logits = self.contrastive_head(z_anchor, z_positive)
-                        labels = torch.arange(logits.shape[0]).long().to(device)
-                        lb_loss = F.cross_entropy(logits, labels)
-
-                        # reward loss
-                        reward_dist = self.reward_model(self.current_states_dict["sample"])
-                        reward_loss = -torch.mean(reward_dist.log_prob(rewards[i]))
-
-                        # world model loss
-                        world_model_loss = (10*encoder_loss + 10*ub_loss + 1e-1*lb_loss + reward_loss + 1e-3*decoder_loss + past_world_model_loss) * 1e-3
-                        past_world_model_loss = world_model_loss.item()
-
-                        # actor loss
-                        with FreezeParameters(self.world_model_modules):
-                            imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
-                            action = self.actor_model(self.current_states_dict["sample"])
-                            imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"],
-                                                                                     action, self.history,
-                                                                                     imagine_horizon)
-                        with FreezeParameters(self.world_model_modules + self.value_modules):
-                            imag_rewards = self.reward_model(imagined_rollout["sample"]).mean
-                            imag_values = self.target_value_model(imagined_rollout["sample"]).mean
-                            discounts = self.args.discount * torch.ones_like(imag_rewards).detach()
-                            self.returns = self._compute_lambda_return(imag_rewards[:-1],
-                                                                       imag_values[:-1],
-                                                                       discounts[:-1],
-                                                                       self.args.td_lambda,
-                                                                       imag_values[-1])
-                            discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
-                            self.discounts = torch.cumprod(discounts, 0).detach()
-                        actor_loss = -torch.mean(self.discounts * self.returns) + past_action_loss
-                        past_action_loss = actor_loss.item()
-
-                        # value loss
-                        with torch.no_grad():
-                            value_feat = imagined_rollout["sample"][:-1].detach()
-                            value_targ = self.returns.detach()
-                        value_dist = self.value_model(value_feat)
-                        value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1)) + past_value_loss
-                        past_value_loss = value_loss.item()
-
-                        # update target value
-                        if step % self.args.value_target_update_freq == 0:
-                            self.target_value_model = copy.deepcopy(self.value_model)
-
-                        # counter for reward
-                        #count = np.arange((counter-1) * (self.args.batch_size), (counter) * (self.args.batch_size))
-                        count = (counter-1) * (self.args.batch_size)
-
-                        if step % self.args.logging_freq:
-                            writer.add_scalar('World Loss/World Loss', world_model_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, self.data_buffer.steps)
-                            writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), self.data_buffer.steps)
-                            writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), self.data_buffer.steps)
-
-                        step += 1
-                        print(world_model_loss, actor_loss, value_loss)
-
-                        # update actor model
-                        self.actor_opt.zero_grad()
-                        actor_loss.backward()
-                        nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
-                        self.actor_opt.step()
-
-                        # update world model
-                        self.world_model_opt.zero_grad()
-                        world_model_loss.backward()
-                        nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
-                        self.world_model_opt.step()
-
-                        # update value model
-                        self.value_opt.zero_grad()
-                        value_loss.backward()
-                        nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
-                        self.value_opt.step()
-
-                        # update momentum encoder and projection head
-                        soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
-                        soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
-
-                rew_len = np.arange(count, count+self.args.episode_collection) if count != 0 else np.arange(0, self.args.batch_size)
-                for j in range(len(all_rews)):
-                    writer.add_scalar('Rewards/Rewards', all_rews[j], rew_len[j])
-
-            print(step)
-            if step % 2850 == 0 and self.data_buffer.steps!=0: #self.args.evaluation_interval == 0:
+        episodic_rews = self.collect_random_sequences(self.args.episodes_buffer)
+        global_step = self.data_buffer.steps
+
+        # logger
+        logs = OrderedDict()
+
+        while global_step < total_steps:
+            step += 1
+            for update_steps in range(self.args.update_steps):
+                model_loss, actor_loss, value_loss = self.update((step-1)*args.update_steps + update_steps)
+            episodic_rews = self.collect_sequences(self.args.episode_collection, actor_model=self.actor_model, encoder_model=self.obs_encoder)
+
+            logs.update({
+                'model_loss' : model_loss,
+                'actor_loss': actor_loss,
+                'value_loss': value_loss,
+                'train_avg_reward':np.mean(episodic_rews),
+                'train_max_reward': np.max(episodic_rews),
+                'train_min_reward': np.min(episodic_rews),
+                'train_std_reward':np.std(episodic_rews),
+            })
+            print("########## Global Step: ", global_step, " ##########")
+            for key, value in logs.items():
+                print(key, " : ", value)
+
+            print(global_step)
+            if global_step % 3150 == 0 and self.data_buffer.steps!=0: #self.args.evaluation_interval == 0:
                 print("Saving model")
                 path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
                 self.save_models(path)
                 self.evaluate()
+
+            global_step = self.data_buffer.steps
+
+            # collect experience
+            if step !=0:
+                encoder = self.obs_encoder
+                actor = self.actor_model
+                all_rews = self.collect_sequences(self.args.episode_collection, actor_model=actor, encoder_model=encoder)
+
+    def update(self, step):
+        last_observations, current_observations, next_observations, actions, next_actions, rewards = self.select_one_batch()
+
+        world_loss, enc_loss, rew_loss, dec_loss, ub_loss, lb_loss = self.world_model_losses(last_observations,
+                                                                                             current_observations,
+                                                                                             next_observations,
+                                                                                             actions,
+                                                                                             next_actions,
+                                                                                             rewards)
+        self.world_model_opt.zero_grad()
+        world_loss.backward()
+        nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
+        self.world_model_opt.step()
+
+        actor_loss = self.actor_model_losses()
+        self.actor_opt.zero_grad()
+        actor_loss.backward()
+        nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
+        self.actor_opt.step()
+
+        value_loss = self.value_model_losses()
+        self.value_opt.zero_grad()
+        value_loss.backward()
+        nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
+        self.value_opt.step()
+
+        # update momentum encoder and projection head
+        soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
+        soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
+
+        # update target value networks
+        if step % self.args.value_target_update_freq == 0:
+            self.target_value_model = copy.deepcopy(self.value_model)
+
+        if step % self.args.logging_freq:
+            writer.add_scalar('World Loss/World Loss', world_loss.detach().item(), step)
+            writer.add_scalar('Main Models Loss/Encoder Loss', enc_loss.detach().item(), step)
+            writer.add_scalar('Main Models Loss/Decoder Loss', dec_loss, step)
+            writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), step)
+            writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step)
+            writer.add_scalar('Actor Critic Loss/Reward Loss', rew_loss.detach().item(), step)
+            writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step)
+            writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), step)
+
+        return world_loss.item(), actor_loss.item(), value_loss.item()
+
+    def world_model_losses(self, last_obs, curr_obs, nxt_obs, actions, nxt_actions, rewards):
+        self.last_state_feat = self.get_features(last_obs)
+        self.curr_state_feat = self.get_features(curr_obs)
+        self.nxt_state_feat = self.get_features(nxt_obs)
+
+        # states
+        self.last_state_enc = self.last_state_feat["sample"]
+        self.curr_state_enc = self.curr_state_feat["sample"]
+        self.nxt_state_enc = self.nxt_state_feat["sample"]
+
+        # actions
+        actions = actions
+        nxt_actions = nxt_actions
+
+        # rewards
+        rewards = rewards
+
+        # predict next states
+        self.transition_model.init_states(self.args.batch_size, device) # (N,128)
+        self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history)
+        self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"])
+        self.pred_curr_state_enc = self.pred_curr_state_dist.mean
+        #print(torch.nn.MSELoss()(self.curr_state_enc, self.pred_curr_state_enc))
+        #print(torch.distributions.kl_divergence(self.curr_state_feat["distribution"], self.pred_curr_state_dist).mean(),0)
+
+        # encoder loss
+        enc_loss = torch.nn.MSELoss()(self.curr_state_enc, self.pred_curr_state_enc)
+        #self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist)
+
+        # reward loss
+        rew_dist = self.reward_model(self.curr_state_enc)
+        rew_loss = -torch.mean(rew_dist.log_prob(rewards.unsqueeze(-1)))
+
+        # decoder loss
+        dec_dist = self.obs_decoder(self.nxt_state_enc)
+        dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs))
+
+        # upper bound loss
+        likelihood_loss, ub_loss = self._upper_bound_minimization(self.curr_state_enc,
+                                                                  self.pred_curr_state_enc)
+
+        # lower bound loss
+        # contrastive projection
+        vec_anchor = self.pred_curr_state_enc
+        vec_positive = self.nxt_state_enc
+        z_anchor = self.prjoection_head(vec_anchor, nxt_actions)
+        z_positive = self.prjoection_head_momentum(vec_positive, nxt_actions)
+
+        # contrastive loss
+        past_lb_loss = 0
+        for i in range(z_anchor.shape[0]):
+            logits = self.contrastive_head(z_anchor[i], z_positive[i])
+            labels = torch.arange(logits.shape[0]).long().to(device)
+            lb_loss = F.cross_entropy(logits, labels) + past_lb_loss
+            past_lb_loss = lb_loss.detach().item()
+        lb_loss = lb_loss/(z_anchor.shape[0])
+
+        world_loss = enc_loss + rew_loss + dec_loss * 1e-4 + ub_loss * 10 + lb_loss
+        return world_loss, enc_loss, rew_loss, dec_loss * 1e-4, ub_loss * 10, lb_loss
+
+    def actor_model_losses(self):
+        with torch.no_grad():
+            curr_state_enc = self.transition_model.seq_to_batch(self.curr_state_feat, "sample")["sample"]
+            curr_state_hist = self.transition_model.seq_to_batch(self.observed_rollout, "history")["sample"]
+
+        with FreezeParameters(self.world_model_modules):
+            imagine_horizon = self.args.imagine_horizon
+            action = self.actor_model(curr_state_enc)
+            self.imagined_rollout = self.transition_model.imagine_rollout(curr_state_enc,
+                                                                          action, curr_state_hist,
+                                                                          imagine_horizon)
+
+        with FreezeParameters(self.world_model_modules + self.value_modules):
+            imag_rewards = self.reward_model(self.imagined_rollout["sample"]).mean
+            imag_values = self.target_value_model(self.imagined_rollout["sample"]).mean
+
+            discounts = self.args.discount * torch.ones_like(imag_rewards).detach()
+
+            self.returns = self._compute_lambda_return(imag_rewards[:-1],
+                                                       imag_values[:-1],
+                                                       discounts[:-1],
+                                                       self.args.td_lambda,
+                                                       imag_values[-1])
+
+            discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
+            self.discounts = torch.cumprod(discounts, 0).detach()
+        actor_loss = -torch.mean(self.discounts * self.returns)
+        return actor_loss
+
+    def value_model_losses(self):
+        # value loss
+        with torch.no_grad():
+            value_feat = self.imagined_rollout["sample"][:-1].detach()
+            value_targ = self.returns.detach()
+        value_dist = self.value_model(value_feat)
+        value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
+        return value_loss
+
+    def select_one_batch(self):
+        # collect sequences
+        non_zero_indices = np.nonzero(self.data_buffer.episode_count)[0]
+        current_obs = self.data_buffer.observations[non_zero_indices]
+        next_obs = self.data_buffer.next_observations[non_zero_indices]
+        actions_raw = self.data_buffer.actions[non_zero_indices]
+        rewards = self.data_buffer.rewards[non_zero_indices]
+        self.terms = np.where(self.data_buffer.terminals[non_zero_indices]!=False)[0]
+
+        # group by episodes
+        current_obs = self.grouped_arrays(current_obs)
+        next_obs = self.grouped_arrays(next_obs)
+        actions_raw = self.grouped_arrays(actions_raw)
+        rewards_ = self.grouped_arrays(rewards)
+
+        # select random chunks of episodes
+        if current_obs.shape[0] < self.args.batch_size:
+            random_episode_number = np.random.randint(0, current_obs.shape[0], self.args.batch_size)
+        else:
+            random_episode_number = random.sample(range(current_obs.shape[0]), self.args.batch_size)
+
+        # select random starting points
+        if current_obs[0].shape[0]-self.args.episode_length < self.args.batch_size:
+            init_index = np.random.randint(0, current_obs[0].shape[0]-self.args.episode_length-2, self.args.batch_size)
+        else:
+            init_index = np.asarray(random.sample(range(current_obs[0].shape[0]-self.args.episode_length), self.args.batch_size))
+
+        # shuffle
+        random.shuffle(random_episode_number)
+        random.shuffle(init_index)
+
+        # select first k elements
+        last_observations = self.select_first_k(current_obs, init_index, random_episode_number)[:-1]
+        current_observations = self.select_first_k(current_obs, init_index, random_episode_number)[1:]
+        next_observations = self.select_first_k(next_obs, init_index, random_episode_number)[:-1]
+        actions = self.select_first_k(actions_raw, init_index, random_episode_number)[:-1].to(device)
+        next_actions = self.select_first_k(actions_raw, init_index, random_episode_number)[1:].to(device)
+        rewards = self.select_first_k(rewards_, init_index, random_episode_number)[:-1].to(device)
+
+        # preprocessing
+        last_observations = preprocess_obs(last_observations).to(device)
+        current_observations = preprocess_obs(current_observations).to(device)
+        next_observations = preprocess_obs(next_observations).to(device)
+
+        return last_observations, current_observations, next_observations, actions, next_actions, rewards
+
     def evaluate(self, eval_episodes=10):
         path = path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
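
Note: both the removed training loop and the new actor_model_losses rely on self._compute_lambda_return(rewards, values, discounts, td_lambda, bootstrap), whose definition is outside this diff. A minimal sketch of the Dreamer-style TD(lambda) recursion such a helper typically implements (only the argument order is taken from the call site; the body below is an assumption, not the repository's code):

import torch

def compute_lambda_return(rewards, values, discounts, td_lambda, last_value):
    # R_t = r_t + d_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}), bootstrapped with last_value.
    # rewards, values, discounts: (horizon, batch, 1); last_value: (batch, 1).
    next_values = torch.cat([values[1:], last_value.unsqueeze(0)], dim=0)
    returns, running = [], last_value
    for t in reversed(range(rewards.shape[0])):
        running = rewards[t] + discounts[t] * ((1 - td_lambda) * next_values[t] + td_lambda * running)
        returns.append(running)
    return torch.stack(returns[::-1], dim=0)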
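
soft_update_params (imported from utils) is likewise not shown in this diff; in this family of codebases it is usually the Polyak/EMA update applied here to the momentum encoder and projection head. A sketch under that assumption:

def soft_update_params(net, target_net, tau):
    # target <- tau * online + (1 - tau) * target, applied parameter-wise.
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
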
@@ -511,7 +551,7 @@ class DPI:
                 with torch.no_grad():
                     obs = torch.tensor(obs.copy(), dtype=torch.float32).unsqueeze(0)
                     obs_processed = preprocess_obs(obs).to(device)
-                    state = self.obs_encoder(obs_processed)["distribution"].sample()
+                    state = self.get_features(obs_processed)["distribution"].rsample()
                     action = self.actor_model(state).cpu().detach().numpy().squeeze()

                 next_obs, rew, done, _ = self.env.step(action)
@@ -534,11 +574,9 @@ class DPI:
     def grouped_arrays(self,array):
         indices = [0] + self.terms.tolist()

         def subarrays():
             for start, end in zip(indices[:-1], indices[1:]):
                 yield array[start:end]

         try:
             subarrays = np.stack(list(subarrays()), axis=0)
         except ValueError:
@@ -548,13 +586,13 @@ class DPI:
     def select_first_k(self, array, init_index, episode_number):
         term_index = init_index + self.args.episode_length
         array = array[episode_number]

         array_list = []
         for i in range(array.shape[0]):
             array_list.append(array[i][init_index[i]:term_index[i]])
         array = np.asarray(array_list)

         if array.ndim == 5:
             transposed_array = np.transpose(array, (1, 0, 2, 3, 4))
         elif array.ndim == 4:
@@ -565,20 +603,16 @@ class DPI:
             transposed_array = np.transpose(array, (1, 0))
         else:
             transposed_array = np.expand_dims(array, axis=0)
-        #return torch.tensor(array).float()
         return torch.tensor(transposed_array).float()

-    def _upper_bound_minimization(self, last_states, current_states, negative_current_states, predicted_current_states):
-        club_loss = self.club_sample(current_states, predicted_current_states, negative_current_states)
+    def _upper_bound_minimization(self, current_states, predicted_current_states):
+        club_loss = self.club_sample(current_states, predicted_current_states, current_states)
         likelihood_loss = 0
         return likelihood_loss, club_loss

-    def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict):
-        # current state distribution
-        curr_states_dist = curr_states_dict["distribution"]
-        # predicted current state distribution
-        predicted_curr_states_dist = predicted_curr_states_dict["distribution"]
+    def _encoder_loss(self, curr_states_dist, predicted_curr_states_dist):
         # KL divergence loss
         loss = torch.mean(torch.distributions.kl.kl_divergence(curr_states_dist,predicted_curr_states_dist))
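
Note: self.club_sample is an instance of CLUBSample (imported from models), whose implementation is also outside this diff. For orientation, a sketch of the sampled CLUB mutual-information upper bound (Cheng et al., 2020) that such a module typically estimates; the layer sizes and method layout below are illustrative assumptions, not the repository's code:

import torch
import torch.nn as nn

class CLUBSampleSketch(nn.Module):
    # Variational network q(y|x) = N(mu(x), diag(exp(logvar(x)))) giving the bound
    # I(x; y) <= E[log q(y|x)] - E[log q(y_neg|x)], with y_neg drawn from the marginal of y.
    def __init__(self, x_dim, y_dim, hidden_dim=128):
        super().__init__()
        self.mu = nn.Sequential(nn.Linear(x_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, y_dim))
        self.logvar = nn.Sequential(nn.Linear(x_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, y_dim), nn.Tanh())

    def forward(self, x, y, y_negative):
        mu, logvar = self.mu(x), self.logvar(x)
        log_q_pos = (-((y - mu) ** 2) / logvar.exp()).sum(dim=-1)
        log_q_neg = (-((y_negative - mu) ** 2) / logvar.exp()).sum(dim=-1)
        return (log_q_pos - log_q_neg).mean() / 2.0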