Adding after some changes

This commit is contained in:
Vedant Dave 2023-04-22 13:07:22 +02:00
parent e7f5533ee6
commit 02a66cfb33
2 changed files with 333 additions and 128 deletions

View File

@ -9,7 +9,7 @@ import numpy as np
from collections import OrderedDict from collections import OrderedDict
import utils import utils
from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image, shuffle_along_axis, Logger
from replay_buffer import ReplayBuffer from replay_buffer import ReplayBuffer
from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
from video import VideoRecorder from video import VideoRecorder
@ -50,11 +50,11 @@ def parse_args():
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc']) parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
parser.add_argument('--init_steps', default=10000, type=int) parser.add_argument('--init_steps', default=10000, type=int)
parser.add_argument('--num_train_steps', default=100000, type=int) parser.add_argument('--num_train_steps', default=100000, type=int)
parser.add_argument('--update_steps', default=1, type=int) parser.add_argument('--update_steps', default=100, type=int)
parser.add_argument('--batch_size', default=64, type=int) #512 parser.add_argument('--batch_size', default=64, type=int) #512
parser.add_argument('--state_size', default=50, type=int) parser.add_argument('--state_size', default=50, type=int)
parser.add_argument('--hidden_size', default=512, type=int) parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--history_size', default=128, type=int) parser.add_argument('--history_size', default=256, type=int)
parser.add_argument('--episode_collection', default=5, type=int) parser.add_argument('--episode_collection', default=5, type=int)
parser.add_argument('--episodes_buffer', default=5, type=int, help='Initial number of episodes to store in the buffer') parser.add_argument('--episodes_buffer', default=5, type=int, help='Initial number of episodes to store in the buffer')
parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models') parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models')
@ -66,11 +66,11 @@ def parse_args():
parser.add_argument('--num_eval_episodes', default=20, type=int) parser.add_argument('--num_eval_episodes', default=20, type=int)
parser.add_argument('--evaluation_interval', default=10000, type=int) # TODO: master had 10000 parser.add_argument('--evaluation_interval', default=10000, type=int) # TODO: master had 10000
# value # value
parser.add_argument('--value_lr', default=8e-6, type=float) parser.add_argument('--value_lr', default=1e-6, type=float)
parser.add_argument('--value_target_update_freq', default=100, type=int) parser.add_argument('--value_target_update_freq', default=100, type=int)
parser.add_argument('--td_lambda', default=0.95, type=int) parser.add_argument('--td_lambda', default=0.95, type=int)
# actor # actor
parser.add_argument('--actor_lr', default=8e-6, type=float) parser.add_argument('--actor_lr', default=1e-6, type=float)
parser.add_argument('--actor_beta', default=0.9, type=float) parser.add_argument('--actor_beta', default=0.9, type=float)
parser.add_argument('--actor_log_std_min', default=-10, type=float) parser.add_argument('--actor_log_std_min', default=-10, type=float)
parser.add_argument('--actor_log_std_max', default=2, type=float) parser.add_argument('--actor_log_std_max', default=2, type=float)
@ -78,13 +78,15 @@ def parse_args():
# world/encoder/decoder # world/encoder/decoder
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity']) parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
parser.add_argument('--world_model_lr', default=1e-5, type=float) parser.add_argument('--world_model_lr', default=1e-5, type=float)
parser.add_argument('--encoder_tau', default=0.001 , type=float) parser.add_argument('--decoder_lr', default=1e-5, type=float)
parser.add_argument('--reward_lr', default=1e-5, type=float)
parser.add_argument('--encoder_tau', default=0.001, type=float)
parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction']) parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_layers', default=4, type=int)
parser.add_argument('--num_filters', default=32, type=int) parser.add_argument('--num_filters', default=32, type=int)
parser.add_argument('--aug', action='store_true') parser.add_argument('--aug', action='store_true')
# sac # sac
parser.add_argument('--discount', default=0.95, type=float) parser.add_argument('--discount', default=0.99, type=float)
# misc # misc
parser.add_argument('--seed', default=1, type=int) parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--logging_freq', default=100, type=int) parser.add_argument('--logging_freq', default=100, type=int)
@ -131,15 +133,14 @@ class DPI:
self.env = utils.FrameStack(self.env, k=self.args.frame_stack) self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
self.env = utils.ActionRepeat(self.env, self.args.action_repeat) self.env = utils.ActionRepeat(self.env, self.args.action_repeat)
self.env = utils.NormalizeActions(self.env) self.env = utils.NormalizeActions(self.env)
self.env = utils.TimeLimit(self.env, 1000 / args.action_repeat) self.env = utils.TimeLimit(self.env, 1000 // args.action_repeat)
# create replay buffer # create replay buffer
self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity, self.data_buffer = ReplayBuffer(self.args.replay_buffer_capacity,
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), self.env.observation_space.shape,
action_size=self.env.action_space.shape[0], self.env.action_space.shape[0],
seq_len=self.args.episode_length, self.args.episode_length,
batch_size=args.batch_size, self.args.batch_size)
args=self.args)
# create work directory # create work directory
utils.make_dir(self.args.work_dir) utils.make_dir(self.args.work_dir)
@ -180,7 +181,7 @@ class DPI:
hidden_size=self.args.hidden_size, # 256, hidden_size=self.args.hidden_size, # 256,
action_size=self.env.action_space.shape[0], # 6 action_size=self.env.action_space.shape[0], # 6
).to(device) ).to(device)
self.actor_model.apply(self.init_weights) #self.actor_model.apply(self.init_weights)
# Value Models # Value Models
@ -225,22 +226,23 @@ class DPI:
# model parameters # model parameters
self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.prjoection_head.parameters()) + \ self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.prjoection_head.parameters()) + \
list(self.transition_model.parameters()) + list(self.obs_decoder.parameters()) + \ list(self.transition_model.parameters()) + list(self.club_sample.parameters()) + \
list(self.reward_model.parameters()) + list(self.club_sample.parameters()) list(self.contrastive_head.parameters())
self.past_transition_parameters = self.transition_model.parameters() self.past_transition_parameters = self.transition_model.parameters()
# optimizers # optimizers
self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr) self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr,eps=1e-6)
self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr) self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr,eps=1e-6)
self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr) self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr,eps=1e-6)
#self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), 1e-5) self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), self.args.decoder_lr,eps=1e-6)
#self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), 1e-4) self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), self.args.reward_lr,eps=1e-6)
# Create Modules # Create Modules
self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.obs_decoder, self.reward_model, self.club_sample] self.world_model_modules = [self.obs_encoder, self.prjoection_head, self.transition_model, self.club_sample, self.contrastive_head]
self.value_modules = [self.value_model] self.value_modules = [self.value_model]
self.actor_modules = [self.actor_model] self.actor_modules = [self.actor_model]
#self.reward_modules = [self.reward_model] self.decoder_modules = [self.obs_decoder]
self.reward_modules = [self.reward_model]
#self.decoder_modules = [self.obs_decoder] #self.decoder_modules = [self.obs_decoder]
if use_saved: if use_saved:
@ -251,63 +253,80 @@ class DPI:
self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt'))) self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt'))) self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))
def collect_random_sequences(self, episodes): def collect_random_sequences(self, seed_steps):
obs = self.env.reset() obs = self.env.reset()
done = False done = False
all_rews = [] all_rews = []
for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
self.global_episodes += 1 self.global_episodes += 1
epi_reward = 0 epi_reward = 0
while not done: for _ in tqdm.tqdm(range(seed_steps), desc='Collecting episodes'):
action = self.env.action_space.sample() action = self.env.action_space.sample()
next_obs, rew, done, _ = self.env.step(action) next_obs, rew, done, _ = self.env.step(action)
self.data_buffer.add(obs, action, next_obs, rew, done, self.global_episodes) self.data_buffer.add(obs, action, next_obs, rew, done)
obs = next_obs obs = next_obs
epi_reward += rew epi_reward += rew
if done:
obs = self.env.reset() obs = self.env.reset()
done=False done=False
all_rews.append(epi_reward) all_rews.append(epi_reward)
epi_reward = 0
return all_rews return all_rews
def collect_sequences(self, episodes, actor_model): def collect_sequences(self, collect_steps, actor_model):
obs = self.env.reset() obs = self.env.reset()
done = False done = False
all_rews = [] all_rews = []
for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
self.global_episodes += 1 self.global_episodes += 1
epi_reward = 0 epi_reward = 0
while not done: for episode_count in tqdm.tqdm(range(collect_steps), desc='Collecting episodes'):
with torch.no_grad(): with torch.no_grad():
obs = torch.tensor(obs.copy(), dtype=torch.float32).to(device).unsqueeze(0) obs_ = torch.tensor(obs.copy(), dtype=torch.float32)
state = self.get_features(obs)["distribution"].rsample() obs_ = preprocess_obs(obs_).to(device)
action = self.actor_model(state) state = self.get_features(obs_)["distribution"].rsample().unsqueeze(0)
action = actor_model.add_exploration(action).cpu().numpy()[0] action = actor_model(state)
print(action) action = actor_model.add_exploration(action)
obs = obs.cpu().numpy()[0] action = action.cpu().numpy()[0]
next_obs, rew, done, _ = self.env.step(action) next_obs, rew, done, _ = self.env.step(action)
self.data_buffer.add(obs, action, next_obs, rew, done, self.global_episodes) self.data_buffer.add(obs, action, next_obs, rew, done)
if done:
obs = self.env.reset()
done = False
all_rews.append(epi_reward)
epi_reward = 0
else:
obs = next_obs obs = next_obs
epi_reward += rew epi_reward += rew
obs = self.env.reset()
done=False
all_rews.append(epi_reward)
return all_rews return all_rews
def train(self, step, total_steps): def train(self, step, total_steps):
episodic_rews = self.collect_random_sequences(self.args.episodes_buffer)
global_step = self.data_buffer.steps
# logger # logger
logs = OrderedDict() logdir = os.path.dirname(os.path.realpath(__file__)) + "/log/logs/"
if not(os.path.exists(logdir)):
os.makedirs(logdir)
initial_logs = OrderedDict()
logger = Logger(logdir)
while global_step < total_steps: episodic_rews = self.collect_random_sequences(5000//args.action_repeat)
self.global_step = self.data_buffer.steps
initial_logs.update({
'train_avg_reward':np.mean(episodic_rews),
'train_max_reward': np.max(episodic_rews),
'train_min_reward': np.min(episodic_rews),
'train_std_reward':np.std(episodic_rews),
})
logger.log_scalars(initial_logs, step=0)
logger.flush()
while self.global_step < total_steps:
logs = OrderedDict()
step += 1 step += 1
for update_steps in range(self.args.update_steps): for update_steps in range(self.args.update_steps):
model_loss, actor_loss, value_loss = self.update((step-1)*args.update_steps + update_steps) model_loss, actor_loss, value_loss, actor_model = self.update((step-1)*args.update_steps + update_steps)
episodic_rews = self.collect_sequences(self.args.episode_collection, actor_model=self.actor_model, encoder_model=self.obs_encoder)
logs.update({ initial_logs.update({
'model_loss' : model_loss, 'model_loss' : model_loss,
'actor_loss': actor_loss, 'actor_loss': actor_loss,
'value_loss': value_loss, 'value_loss': value_loss,
@ -316,28 +335,51 @@ class DPI:
'train_min_reward': np.min(episodic_rews), 'train_min_reward': np.min(episodic_rews),
'train_std_reward':np.std(episodic_rews), 'train_std_reward':np.std(episodic_rews),
}) })
logger.log_scalars(logs, self.global_step)
print("########## Global Step: ", global_step, " ##########")
for key, value in logs.items(): print("########## Global Step:", self.global_step, " ##########")
for key, value in initial_logs.items():
print(key, " : ", value) print(key, " : ", value)
print(global_step) episodic_rews = self.collect_sequences(1000//self.args.action_repeat, actor_model)
if global_step % 3150 == 0 and self.data_buffer.steps!=0: #self.args.evaluation_interval == 0:
if self.global_step % 3150 == 0 and self.data_buffer.steps!=0: #self.args.evaluation_interval == 0:
print("Saving model") print("Saving model")
path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth" path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
self.save_models(path) self.save_models(path)
self.evaluate() self.evaluate()
global_step = self.data_buffer.steps self.global_step = self.data_buffer.steps * self.args.action_repeat
"""
# collect experience # collect experience
if step !=0: if step !=0:
encoder = self.obs_encoder encoder = self.obs_encoder
actor = self.actor_model actor = self.actor_model
all_rews = self.collect_sequences(self.args.episode_collection, actor_model=actor, encoder_model=encoder) all_rews = self.collect_sequences(self.args.episode_collection, actor_model=actor, encoder_model=encoder)
"""
def collect_batch(self):
obs_, acs_, nxt_obs_, rews_, terms_ = self.data_buffer.sample()
obs = torch.tensor(obs_, dtype=torch.float32)[1:]
last_obs = torch.tensor(obs_, dtype=torch.float32)[:-1]
nxt_obs = torch.tensor(nxt_obs_, dtype=torch.float32)[1:]
acs = torch.tensor(acs_, dtype=torch.float32)[:-1].to(device)
nxt_acs = torch.tensor(acs_, dtype=torch.float32)[1:].to(device)
rews = torch.tensor(rews_, dtype=torch.float32)[:-1].to(device).unsqueeze(-1)
nonterms = torch.tensor((1.0-terms_), dtype=torch.float32)[:-1].to(device).unsqueeze(-1)
last_obs = preprocess_obs(last_obs).to(device)
obs = preprocess_obs(obs).to(device)
nxt_obs = preprocess_obs(nxt_obs).to(device)
return last_obs, obs, nxt_obs, acs, rews, nxt_acs, nonterms
def update(self, step): def update(self, step):
last_observations, current_observations, next_observations, actions, next_actions, rewards = self.select_one_batch() last_observations, current_observations, next_observations, actions, rewards, next_actions, nonterms = self.collect_batch()
#last_observations, current_observations, next_observations, actions, next_actions, rewards = self.select_one_batch()
world_loss, enc_loss, rew_loss, dec_loss, ub_loss, lb_loss = self.world_model_losses(last_observations, world_loss, enc_loss, rew_loss, dec_loss, ub_loss, lb_loss = self.world_model_losses(last_observations,
@ -345,12 +387,23 @@ class DPI:
next_observations, next_observations,
actions, actions,
next_actions, next_actions,
rewards) rewards,
nonterms)
self.world_model_opt.zero_grad() self.world_model_opt.zero_grad()
world_loss.backward() world_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm) nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step() self.world_model_opt.step()
self.decoder_opt.zero_grad()
dec_loss.backward()
nn.utils.clip_grad_norm_(self.obs_decoder.parameters(), self.args.grad_clip_norm)
self.decoder_opt.step()
self.reward_opt.zero_grad()
rew_loss.backward()
nn.utils.clip_grad_norm_(self.reward_model.parameters(), self.args.grad_clip_norm)
self.reward_opt.step()
actor_loss = self.actor_model_losses() actor_loss = self.actor_model_losses()
self.actor_opt.zero_grad() self.actor_opt.zero_grad()
actor_loss.backward() actor_loss.backward()
@ -368,8 +421,8 @@ class DPI:
soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau) soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
# update target value networks # update target value networks
if step % self.args.value_target_update_freq == 0: #if step % self.args.value_target_update_freq == 0:
self.target_value_model = copy.deepcopy(self.value_model) # self.target_value_model = copy.deepcopy(self.value_model)
if step % self.args.logging_freq: if step % self.args.logging_freq:
writer.add_scalar('World Loss/World Loss', world_loss.detach().item(), step) writer.add_scalar('World Loss/World Loss', world_loss.detach().item(), step)
@ -379,14 +432,15 @@ class DPI:
writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step) writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step)
writer.add_scalar('Actor Critic Loss/Reward Loss', rew_loss.detach().item(), step) writer.add_scalar('Actor Critic Loss/Reward Loss', rew_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step) writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), step) writer.add_scalar('Bound Loss/Lower Bound Loss', -lb_loss.detach().item(), step)
return world_loss.item(), actor_loss.item(), value_loss.item() return world_loss.item(), actor_loss.item(), value_loss.item(), self.actor_model
def world_model_losses(self, last_obs, curr_obs, nxt_obs, actions, nxt_actions, rewards): def world_model_losses(self, last_obs, curr_obs, nxt_obs, actions, nxt_actions, rewards, nonterms):
# get features
self.last_state_feat = self.get_features(last_obs) self.last_state_feat = self.get_features(last_obs)
self.curr_state_feat = self.get_features(curr_obs) self.curr_state_feat = self.get_features(curr_obs)
self.nxt_state_feat = self.get_features(nxt_obs) self.nxt_state_feat = self.get_features(nxt_obs, momentum=True)
# states # states
self.last_state_enc = self.last_state_feat["sample"] self.last_state_enc = self.last_state_feat["sample"]
@ -394,42 +448,37 @@ class DPI:
self.nxt_state_enc = self.nxt_state_feat["sample"] self.nxt_state_enc = self.nxt_state_feat["sample"]
# actions # actions
actions = actions actions = actions.clone()
nxt_actions = nxt_actions nxt_actions = nxt_actions.clone()
# rewards # rewards
rewards = rewards rewards = rewards.clone()
# predict next states # predict next states
self.transition_model.init_states(self.args.batch_size, device) # (N,128) self.transition_model.init_states(self.args.batch_size, device) # (N,128)
self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history) self.observed_rollout = self.transition_model.observe_rollout(self.last_state_enc, actions, self.transition_model.prev_history, nonterms)
self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"]) self.pred_curr_state_dist = self.transition_model.get_dist(self.observed_rollout["mean"], self.observed_rollout["std"])
self.pred_curr_state_enc = self.pred_curr_state_dist.mean self.pred_curr_state_enc = self.pred_curr_state_dist.mean
#print(torch.nn.MSELoss()(self.curr_state_enc, self.pred_curr_state_enc))
#print(torch.distributions.kl_divergence(self.curr_state_feat["distribution"], self.pred_curr_state_dist).mean(),0)
# encoder loss # encoder loss
enc_loss = torch.nn.MSELoss()(self.curr_state_enc, self.pred_curr_state_enc) enc_loss = self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist)
#self._encoder_loss(self.curr_state_feat["distribution"], self.pred_curr_state_dist)
# reward loss # reward loss
rew_dist = self.reward_model(self.curr_state_enc) rew_dist = self.reward_model(self.curr_state_enc)
rew_loss = -torch.mean(rew_dist.log_prob(rewards.unsqueeze(-1))) rew_loss = -torch.mean(rew_dist.log_prob(rewards))
# decoder loss # decoder loss
dec_dist = self.obs_decoder(self.nxt_state_enc) dec_dist = self.obs_decoder(self.nxt_state_enc)
dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs)) dec_loss = -torch.mean(dec_dist.log_prob(nxt_obs))
# upper bound loss # upper bound loss
likelihood_loss, ub_loss = self._upper_bound_minimization(self.curr_state_enc, _, ub_loss = self._upper_bound_minimization(self.curr_state_enc,
self.pred_curr_state_enc) self.pred_curr_state_enc)
# lower bound loss # lower bound loss
# contrastive projection # contrastive projection
vec_anchor = self.pred_curr_state_enc vec_anchor = self.pred_curr_state_enc.detach()
vec_positive = self.nxt_state_enc vec_positive = self.nxt_state_enc.detach()
z_anchor = self.prjoection_head(vec_anchor, nxt_actions) z_anchor = self.prjoection_head(vec_anchor, nxt_actions)
z_positive = self.prjoection_head_momentum(vec_positive, nxt_actions) z_positive = self.prjoection_head_momentum(vec_positive, nxt_actions)
@ -440,40 +489,39 @@ class DPI:
labels = torch.arange(logits.shape[0]).long().to(device) labels = torch.arange(logits.shape[0]).long().to(device)
lb_loss = F.cross_entropy(logits, labels) + past_lb_loss lb_loss = F.cross_entropy(logits, labels) + past_lb_loss
past_lb_loss = lb_loss.detach().item() past_lb_loss = lb_loss.detach().item()
lb_loss = lb_loss/(z_anchor.shape[0]) lb_loss = -0.1 * lb_loss/(z_anchor.shape[0])
world_loss = enc_loss + rew_loss + dec_loss * 1e-4 + ub_loss * 10 + lb_loss world_loss = enc_loss + ub_loss + lb_loss
return world_loss, enc_loss , rew_loss, dec_loss * 1e-4, ub_loss * 10, lb_loss return world_loss, enc_loss , rew_loss, dec_loss, ub_loss, lb_loss
def actor_model_losses(self): def actor_model_losses(self):
with torch.no_grad(): with torch.no_grad():
curr_state_enc = self.transition_model.seq_to_batch(self.curr_state_feat, "sample")["sample"] curr_state_enc = self.transition_model.seq_to_batch(self.curr_state_feat, "sample")["sample"]
curr_state_hist = self.transition_model.seq_to_batch(self.observed_rollout, "history")["sample"] curr_state_hist = self.transition_model.seq_to_batch(self.observed_rollout, "history")["sample"]
with FreezeParameters(self.world_model_modules): with FreezeParameters(self.world_model_modules + self.decoder_modules + self.reward_modules):
imagine_horizon = self.args.imagine_horizon imagine_horizon = self.args.imagine_horizon
action = self.actor_model(curr_state_enc) action = self.actor_model(curr_state_enc)
self.imagined_rollout = self.transition_model.imagine_rollout(curr_state_enc, self.imagined_rollout = self.transition_model.imagine_rollout(curr_state_enc,
action, curr_state_hist, action, curr_state_hist,
imagine_horizon) imagine_horizon)
with FreezeParameters(self.world_model_modules + self.value_modules): with FreezeParameters(self.world_model_modules + self.value_modules + self.decoder_modules + self.reward_modules):
imag_rewards = self.reward_model(self.imagined_rollout["sample"]).mean imag_rewards = self.reward_model(self.imagined_rollout["sample"]).mean
imag_values = self.target_value_model(self.imagined_rollout["sample"]).mean imag_values = self.value_model(self.imagined_rollout["sample"]).mean
discounts = self.args.discount * torch.ones_like(imag_rewards).detach() discounts = self.args.discount * torch.ones_like(imag_rewards).detach()
self.returns = self._compute_lambda_return(imag_rewards[:-1], self.returns = self._compute_lambda_return(imag_rewards[:-1],
imag_values[:-1], imag_values[:-1],
discounts[:-1] , discounts[:-1] ,
self.args.td_lambda, self.args.td_lambda,
imag_values[-1]) imag_values[-1])
discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0) discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
self.discounts = torch.cumprod(discounts, 0).detach() self.discounts = torch.cumprod(discounts, 0).detach()
actor_loss = -torch.mean(self.discounts * self.returns) actor_loss = -torch.mean(self.discounts * self.returns)
return actor_loss return actor_loss
def value_model_losses(self): def value_model_losses(self):
# value loss # value loss
with torch.no_grad(): with torch.no_grad():
@ -483,7 +531,6 @@ class DPI:
value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1)) value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
return value_loss return value_loss
def select_one_batch(self): def select_one_batch(self):
# collect sequences # collect sequences
non_zero_indices = np.nonzero(self.data_buffer.episode_count)[0] non_zero_indices = np.nonzero(self.data_buffer.episode_count)[0]
@ -530,9 +577,6 @@ class DPI:
return last_observations, current_observations, next_observations, actions, next_actions, rewards return last_observations, current_observations, next_observations, actions, next_actions, rewards
def evaluate(self, eval_episodes=10): def evaluate(self, eval_episodes=10):
path = path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth" path = path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
self.restore_checkpoint(path) self.restore_checkpoint(path)
@ -608,7 +652,9 @@ class DPI:
return torch.tensor(transposed_array).float() return torch.tensor(transposed_array).float()
def _upper_bound_minimization(self, current_states, predicted_current_states): def _upper_bound_minimization(self, current_states, predicted_current_states):
club_loss = self.club_sample(current_states, predicted_current_states, current_states) current_negative_states = shuffle_along_axis(current_states.clone(), axis=0)
current_negative_states = shuffle_along_axis(current_negative_states, axis=1)
club_loss = self.club_sample(current_states, predicted_current_states, current_negative_states)
likelihood_loss = 0 likelihood_loss = 0
return likelihood_loss, club_loss return likelihood_loss, club_loss
@ -645,6 +691,24 @@ class DPI:
returns = torch.flip(torch.stack(rets), [0]) returns = torch.flip(torch.stack(rets), [0])
return returns return returns
def lambda_return(self,imged_reward, value_pred, bootstrap, discount=0.99, lambda_=0.95):
# Setting lambda=1 gives a discounted Monte Carlo return.
# Setting lambda=0 gives a fixed 1-step return.
next_values = torch.cat([value_pred[1:], bootstrap[None]], 0)
discount_tensor = discount * torch.ones_like(imged_reward) # pcont
inputs = imged_reward + discount_tensor * next_values * (1 - lambda_)
last = bootstrap
indices = reversed(range(len(inputs)))
outputs = []
for index in indices:
inp, disc = inputs[index], discount_tensor[index]
last = inp + disc * lambda_ * last
outputs.append(last)
outputs = list(reversed(outputs))
outputs = torch.stack(outputs, 0)
returns = outputs
return returns
def save_models(self, save_path): def save_models(self, save_path):
torch.save( torch.save(
{'rssm' : self.transition_model.state_dict(), {'rssm' : self.transition_model.state_dict(),

View File

@ -1,10 +1,13 @@
import os import os
import random import random
import pickle
import numpy as np import numpy as np
from collections import deque from collections import deque
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import gym import gym
import dmc2gym import dmc2gym
@ -144,8 +147,90 @@ class NormalizeActions:
original = np.where(self._mask, original, action) original = np.where(self._mask, original, action)
return self._env.step(original) return self._env.step(original)
class TimeLimit:
def __init__(self, env, duration):
self._env = env
self._duration = duration
self._step = None
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
assert self._step is not None, 'Must reset environment.'
obs, reward, done, info = self._env.step(action)
self._step += 1
if self._step >= self._duration:
done = True
if 'discount' not in info:
info['discount'] = np.array(1.0).astype(np.float32)
self._step = None
return obs, reward, done, info
def reset(self):
self._step = 0
return self._env.reset()
class ReplayBuffer: class ReplayBuffer:
def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
self.size = size
self.obs_shape = obs_shape
self.action_size = action_size
self.seq_len = seq_len
self.batch_size = batch_size
self.idx = 0
self.full = False
self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.actions = np.empty((size, action_size), dtype=np.float32)
self.rewards = np.empty((size,), dtype=np.float32)
self.terminals = np.empty((size,), dtype=np.float32)
self.steps, self.episodes = 0, 0
self.episode_count = np.zeros((size,), dtype=np.int32)
def add(self, obs, ac, next_obs, rew, done, episode_count):
self.observations[self.idx] = obs
self.next_observations[self.idx] = next_obs
self.actions[self.idx] = ac
self.rewards[self.idx] = rew
self.terminals[self.idx] = done
self.full = self.full or self.idx == 0
self.steps += 1
self.episodes = self.episodes + (1 if done else 0)
self.episode_count[self.idx] = episode_count
self.idx = (self.idx + 1) % self.size
def _sample_idx(self, L):
valid_idx = False
while not valid_idx:
idx = np.random.randint(0, self.size if self.full else self.idx - L)
idxs = np.arange(idx, idx + L) % self.size
valid_idx = not self.idx in idxs[1:]
return idxs
def _retrieve_batch(self, idxs, n, L):
vec_idxs = idxs.transpose().reshape(-1) # Unroll indices
observations = self.observations[vec_idxs]
next_obs = self.next_observations[vec_idxs]
obs = observations.reshape(L, n, *observations.shape[1:])
next_obs = next_obs.reshape(L, n, *next_obs.shape[1:])
acs = self.actions[vec_idxs].reshape(L, n, -1)
rew = self.rewards[vec_idxs].reshape(L, n)
term = self.terminals[vec_idxs].reshape(L, n)
return obs, acs, next_obs, rew, term
def sample(self):
n = self.batch_size
l = self.seq_len
obs,acs,next_obs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
return obs,acs,next_obs,rews,terms
class ReplayBuffer1:
def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args): def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
self.size = size self.size = size
self.obs_shape = obs_shape self.obs_shape = obs_shape
@ -199,7 +284,10 @@ class ReplayBuffer:
def group_steps(self, buffer, variable, obs=True): def group_steps(self, buffer, variable, obs=True):
variable = getattr(buffer, variable) variable = getattr(buffer, variable)
non_zero_indices = np.nonzero(buffer.episode_count)[0] non_zero_indices = np.nonzero(buffer.episode_count)[0]
print(buffer.episode_count)
variable = variable[non_zero_indices] variable = variable[non_zero_indices]
print(variable.shape)
exit()
if obs: if obs:
variable = variable.reshape(-1, self.args.episode_length, variable = variable.reshape(-1, self.args.episode_length,
@ -215,8 +303,9 @@ class ReplayBuffer:
self.args.image_size,self.args.image_size) self.args.image_size,self.args.image_size)
return variable return variable
def sample_random_idx(self, buffer_length): def sample_random_idx(self, buffer_length, last=False):
random_indices = random.sample(range(0, buffer_length), self.args.batch_size) init = 0 if last else buffer_length - self.args.batch_size
random_indices = random.sample(range(init, buffer_length), self.args.batch_size)
return random_indices return random_indices
def group_and_sample_random_batch(self, buffer, variable_name, device, random_indices, is_obs=True, offset=0): def group_and_sample_random_batch(self, buffer, variable_name, device, random_indices, is_obs=True, offset=0):
@ -248,19 +337,23 @@ def make_env(args):
) )
return env return env
def shuffle_along_axis(a, axis):
idx = np.random.rand(*a.shape).argsort(axis=axis)
return np.take_along_axis(a,idx,axis=axis)
def preprocess_obs(obs): def preprocess_obs(obs):
obs = obs/255.0 - 0.5 obs = (obs/255.0) - 0.5
return obs return obs
def soft_update_params(net, target_net, tau): def soft_update_params(net, target_net, tau):
for param, target_param in zip(net.parameters(), target_net.parameters()): for param, target_param in zip(net.parameters(), target_net.parameters()):
target_param.data.copy_( target_param.data.copy_(
tau * param.data + (1 - tau) * target_param.data tau * param.detach().data + (1 - tau) * target_param.data
) )
def save_image(array, filename): def save_image(array, filename):
array = array.transpose(1, 2, 0) array = array.transpose(1, 2, 0)
array = (array * 255).astype(np.uint8) array = ((array+0.5) * 255).astype(np.uint8)
image = Image.fromarray(array) image = Image.fromarray(array)
image.save(filename) image.save(filename)
@ -354,3 +447,51 @@ class FreezeParameters:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
for i, param in enumerate(get_parameters(self.modules)): for i, param in enumerate(get_parameters(self.modules)):
param.requires_grad = self.param_states[i] param.requires_grad = self.param_states[i]
class Logger:
def __init__(self, log_dir, n_logged_samples=10, summary_writer=None):
self._log_dir = log_dir
print('########################')
print('logging outputs to ', log_dir)
print('########################')
self._n_logged_samples = n_logged_samples
self._summ_writer = SummaryWriter(log_dir, flush_secs=1, max_queue=1)
def log_scalar(self, scalar, name, step_):
self._summ_writer.add_scalar('{}'.format(name), scalar, step_)
def log_scalars(self, scalar_dict, step):
for key, value in scalar_dict.items():
print('{} : {}'.format(key, value))
self.log_scalar(value, key, step)
self.dump_scalars_to_pickle(scalar_dict, step)
def log_videos(self, videos, step, max_videos_to_save=1, fps=20, video_title='video'):
# max rollout length
max_videos_to_save = np.min([max_videos_to_save, videos.shape[0]])
max_length = videos[0].shape[0]
for i in range(max_videos_to_save):
if videos[i].shape[0]>max_length:
max_length = videos[i].shape[0]
# pad rollouts to all be same length
for i in range(max_videos_to_save):
if videos[i].shape[0]<max_length:
padding = np.tile([videos[i][-1]], (max_length-videos[i].shape[0],1,1,1))
videos[i] = np.concatenate([videos[i], padding], 0)
clip = mpy.ImageSequenceClip(list(videos[i]), fps=fps)
new_video_title = video_title+'{}_{}'.format(step, i) + '.gif'
filename = os.path.join(self._log_dir, new_video_title)
video.write_gif(filename, fps =fps)
def dump_scalars_to_pickle(self, metrics, step, log_title=None):
log_path = os.path.join(self._log_dir, "scalar_data.pkl" if log_title is None else log_title)
with open(log_path, 'ab') as f:
pickle.dump({'step': step, **dict(metrics)}, f)
def flush(self):
self._summ_writer.flush()