Add MoCo to introduce lower-bound loss

This commit is contained in:
Vedant Dave 2023-04-10 20:18:17 +02:00
parent 05dd20cdfa
commit de17cab9f5

View File

@ -10,14 +10,17 @@ import dmc2gym
import tqdm import tqdm
import wandb import wandb
import utils import utils
from utils import ReplayBuffer, make_env, save_image from utils import ReplayBuffer, FreezeParameters, make_env, soft_update_params, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample, Actor, ValueModel, RewardModel from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
from logger import Logger from logger import Logger
from video import VideoRecorder from video import VideoRecorder
from dmc2gym.wrappers import set_global_var from dmc2gym.wrappers import set_global_var
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T import torchvision.transforms as T
#from agent.baseline_agent import BaselineAgent #from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent #from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent #from agent.deepmdp_agent import DeepMDPAgent
@ -53,23 +56,27 @@ def parse_args():
parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models') parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models')
parser.add_argument('--load_encoder', default=None, type=str) parser.add_argument('--load_encoder', default=None, type=str)
parser.add_argument('--imagine_horizon', default=15, type=str) parser.add_argument('--imagine_horizon', default=15, type=str)
parser.add_argument('--grad_clip_norm', type=float, default=100.0, help='Gradient clipping norm')
# eval # eval
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000 parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
parser.add_argument('--num_eval_episodes', default=20, type=int) parser.add_argument('--num_eval_episodes', default=20, type=int)
# critic # value
parser.add_argument('--critic_lr', default=1e-3, type=float) parser.add_argument('--value_lr', default=1e-4, type=float)
parser.add_argument('--critic_beta', default=0.9, type=float) parser.add_argument('--value_beta', default=0.9, type=float)
parser.add_argument('--critic_tau', default=0.005, type=float) parser.add_argument('--value_tau', default=0.005, type=float)
parser.add_argument('--critic_target_update_freq', default=2, type=int) parser.add_argument('--value_target_update_freq', default=2, type=int)
# reward
parser.add_argument('--reward_lr', default=1e-4, type=float)
# actor # actor
parser.add_argument('--actor_lr', default=1e-3, type=float) parser.add_argument('--actor_lr', default=1e-4, type=float)
parser.add_argument('--actor_beta', default=0.9, type=float) parser.add_argument('--actor_beta', default=0.9, type=float)
parser.add_argument('--actor_log_std_min', default=-10, type=float) parser.add_argument('--actor_log_std_min', default=-10, type=float)
parser.add_argument('--actor_log_std_max', default=2, type=float) parser.add_argument('--actor_log_std_max', default=2, type=float)
parser.add_argument('--actor_update_freq', default=2, type=int) parser.add_argument('--actor_update_freq', default=2, type=int)
# encoder/decoder # world/encoder/decoder
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity']) parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
parser.add_argument('--encoder_feature_dim', default=50, type=int) parser.add_argument('--encoder_feature_dim', default=50, type=int)
parser.add_argument('--world_model_lr', default=1e-3, type=float)
parser.add_argument('--encoder_lr', default=1e-3, type=float) parser.add_argument('--encoder_lr', default=1e-3, type=float)
parser.add_argument('--encoder_tau', default=0.005, type=float) parser.add_argument('--encoder_tau', default=0.005, type=float)
parser.add_argument('--encoder_stride', default=1, type=int) parser.add_argument('--encoder_stride', default=1, type=int)
@ -79,6 +86,7 @@ def parse_args():
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float) parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_layers', default=4, type=int)
parser.add_argument('--num_filters', default=32, type=int) parser.add_argument('--num_filters', default=32, type=int)
parser.add_argument('--aug', action='store_true')
# sac # sac
parser.add_argument('--discount', default=0.99, type=float) parser.add_argument('--discount', default=0.99, type=float)
parser.add_argument('--init_temperature', default=0.01, type=float) parser.add_argument('--init_temperature', default=0.01, type=float)
@ -154,6 +162,7 @@ class DPI:
self.build_models(use_saved=False, saved_model_dir=self.model_dir) self.build_models(use_saved=False, saved_model_dir=self.model_dir)
def build_models(self, use_saved, saved_model_dir=None): def build_models(self, use_saved, saved_model_dir=None):
# World Models
self.obs_encoder = ObservationEncoder( self.obs_encoder = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
state_size=self.args.state_size # 128 state_size=self.args.state_size # 128
@ -176,12 +185,14 @@ class DPI:
history_size=self.args.history_size, # 128 history_size=self.args.history_size, # 128
) )
self.action_model = Actor( # Actor Model
self.actor_model = Actor(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256, hidden_size=self.args.hidden_size, # 256,
action_size=self.env.action_space.shape[0], # 6 action_size=self.env.action_space.shape[0], # 6
) )
# Value Models
self.value_model = ValueModel( self.value_model = ValueModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
@ -197,12 +208,38 @@ class DPI:
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
) )
# model parameters # Contrastive Models
self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \ self.prjoection_head = ProjectionHead(
list(self.obs_decoder.parameters()) + list(self.transition_model.parameters()) state_size=self.args.state_size, # 128
action_size=self.env.action_space.shape[0], # 6
hidden_size=self.args.hidden_size, # 256
)
# optimizer self.prjoection_head_momentum = ProjectionHead(
self.optimizer = torch.optim.Adam(self.model_parameters, lr=self.args.encoder_lr) state_size=self.args.state_size, # 128
action_size=self.env.action_space.shape[0], # 6
hidden_size=self.args.hidden_size, # 256
)
self.contrastive_head = ContrastiveHead(
hidden_size=self.args.hidden_size, # 256
)
# model parameters
self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_decoder.parameters()) + \
list(self.value_model.parameters()) + list(self.transition_model.parameters()) + \
list(self.prjoection_head.parameters())
# optimizers
self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr)
self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr)
self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr)
# Create Modules
self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.value_model, self.transition_model, self.prjoection_head]
self.value_modules = [self.value_model]
self.actor_modules = [self.actor_model]
if use_saved: if use_saved:
self._use_saved_models(saved_model_dir) self._use_saved_models(saved_model_dir)
@ -214,6 +251,8 @@ class DPI:
def collect_sequences(self, episodes): def collect_sequences(self, episodes):
obs = self.env.reset() obs = self.env.reset()
self.ob_mean = np.mean(obs, 0).astype(np.float32)
self.ob_std = np.std(obs, 0).mean().astype(np.float32)
#obs_clean = self.env_clean.reset() #obs_clean = self.env_clean.reset()
done = False done = False
@ -265,14 +304,15 @@ class DPI:
self.history = self.transition_model.prev_history # (N,128) self.history = self.transition_model.prev_history # (N,128)
# Train encoder # Train encoder
total_ub_loss = 0 step = 0
total_encoder_loss = 0 total_steps = 10000
while step < total_steps:
for i in range(self.args.episode_length-1): for i in range(self.args.episode_length-1):
if i > 0: if i > 0:
# Encode observations and next_observations # Encode observations and next_observations
self.last_states_dict = self.obs_encoder(last_observations[i]) self.last_states_dict = self.get_features(last_observations[i])
self.current_states_dict = self.obs_encoder(current_observations[i]) self.current_states_dict = self.get_features(current_observations[i])
self.next_states_dict = self.obs_encoder_momentum(next_observations[i]) self.next_states_dict = self.get_features(next_observations[i], momentum=True)
self.action = actions[i] # (N,6) self.action = actions[i] # (N,6)
history = self.transition_model.prev_history history = self.transition_model.prev_history
@ -287,6 +327,8 @@ class DPI:
predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history) predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
self.history = predicted_current_state_dict["history"] self.history = predicted_current_state_dict["history"]
# Calculate upper bound loss # Calculate upper bound loss
ub_loss = self._upper_bound_minimization(self.last_states_dict, ub_loss = self._upper_bound_minimization(self.last_states_dict,
self.current_states_dict, self.current_states_dict,
@ -298,12 +340,45 @@ class DPI:
encoder_loss = self._past_encoder_loss(self.current_states_dict, encoder_loss = self._past_encoder_loss(self.current_states_dict,
predicted_current_state_dict) predicted_current_state_dict)
total_ub_loss += ub_loss #total_ub_loss += ub_loss
total_encoder_loss += encoder_loss #total_encoder_loss += encoder_loss
# contrastive projection
vec_anchor = predicted_current_state_dict["sample"]
vec_positive = self.next_states_dict["sample"].detach()
z_anchor = self.prjoection_head(vec_anchor, self.action)
z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach()
# contrastive loss
logits = self.contrastive_head(z_anchor, z_positive)
labels = labels = torch.arange(logits.shape[0]).long()
lb_loss = F.cross_entropy(logits, labels)
# update models
world_model_loss = encoder_loss + 1e-1 * ub_loss + lb_loss #1e-1 * ub_loss + 1e-5 * encoder_loss + 1e-1 * lb_loss
print("ub_loss: {:.4f}, encoder_loss: {:.4f}, lb_loss: {:.4f}".format(ub_loss, encoder_loss, lb_loss))
print("world_model_loss: {:.4f}".format(world_model_loss))
self.world_model_opt.zero_grad()
world_model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step()
# behaviour learning
with FreezeParameters(self.world_model_modules):
imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i) imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"], self.action, self.history, imagine_horizon) imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(),
self.action, self.history.detach(),
imagine_horizon)
print(imagined_rollout["sample"].shape, imagined_rollout["distribution"][0].sample().shape)
#exit()
step += 1
if step>total_steps:
print("Training finished")
break
#exit() #exit()
#print(total_ub_loss, total_encoder_loss) #print(total_ub_loss, total_encoder_loss)
@ -315,7 +390,7 @@ class DPI:
current_states, current_states,
negative_current_states, negative_current_states,
predicted_current_states) predicted_current_states)
club_loss = club_sample() club_loss = club_sample.loglikeli()
return club_loss return club_loss
def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict): def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict):
@ -325,42 +400,27 @@ class DPI:
# predicted current state distribution # predicted current state distribution
predicted_curr_states_dist = predicted_curr_states_dict["distribution"] predicted_curr_states_dist = predicted_curr_states_dict["distribution"]
# KL divergence loss # KL divergence loss
loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean() loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean()
return loss return loss
"""
def _past_encoder_loss(self, states, next_states, states_dist, next_states_dist, actions, history, step):
# Imagine next state
if step == 0:
actions = torch.zeros(self.args.batch_size, self.env.action_space.shape[0]).float() # Zero action for first step
imagined_next_states = self.transition_model.imagine_step(states, actions, history)
self.history = imagined_next_states["history"]
else:
imagined_next_states = self.transition_model.imagine_step(states, actions, self.history) # (N,128)
# State Distribution
imagined_next_states_dist = imagined_next_states["distribution"]
# KL divergence loss
loss = torch.distributions.kl.kl_divergence(imagined_next_states_dist, next_states_dist["distribution"]).mean()
return loss
"""
def get_features(self, x, momentum=False): def get_features(self, x, momentum=False):
if self.aug: import torchvision.transforms.functional as fn
x = x/255.0 - 0.5 # Preprocessing
if self.args.aug:
x = T.RandomCrop((80, 80))(x) # (None,80,80,4) x = T.RandomCrop((80, 80))(x) # (None,80,80,4)
x = T.functional.pad(x, (4, 4, 4, 4), "symmetric") # (None,88,88,4) x = T.functional.pad(x, (4, 4, 4, 4), "symmetric") # (None,88,88,4)
x = T.RandomCrop((84, 84))(x) # (None,84,84,4) x = T.RandomCrop((84, 84))(x) # (None,84,84,4)
with torch.no_grad(): with torch.no_grad():
x = (x.float() - self.ob_mean) / self.ob_std
if momentum: if momentum:
x = self.obs_encoder(x).detach()
else:
x = self.obs_encoder_momentum(x) x = self.obs_encoder_momentum(x)
else:
x = self.obs_encoder(x)
return x return x
if __name__ == '__main__': if __name__ == '__main__':