import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math

import utils
from encoder import make_encoder, club_loss, TransitionModel
from decoder import make_decoder

LOG_FREQ = 10000


def gaussian_logprob(noise, log_std):
    """Compute Gaussian log probability."""
    residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
    return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)


def squash(mu, pi, log_pi):
    """Apply squashing function.
    See appendix C from https://arxiv.org/pdf/1812.05905.pdf.
    """
    mu = torch.tanh(mu)
    if pi is not None:
        pi = torch.tanh(pi)
    if log_pi is not None:
        log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
    return mu, pi, log_pi


def weight_init(m):
    """Custom weight init for Conv2D and Linear layers."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        m.bias.data.fill_(0.0)
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


class Actor(nn.Module):
    """MLP actor network."""
    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters
        )

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.trunk = nn.Sequential(
            nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2 * action_shape[0])
        )

        self.outputs = dict()
        self.apply(weight_init)

    def forward(
        self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False
    ):
        obs, _, _ = self.encoder(obs, detach=detach_encoder)

        mu, log_std = self.trunk(obs).chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std = self.log_std_min + 0.5 * (
            self.log_std_max - self.log_std_min
        ) * (log_std + 1)

        self.outputs['mu'] = mu
        self.outputs['std'] = log_std.exp()

        if compute_pi:
            std = log_std.exp()
            noise = torch.randn_like(mu)
            pi = mu + noise * std
        else:
            pi = None
            entropy = None

        if compute_log_pi:
            log_pi = gaussian_logprob(noise, log_std)
        else:
            log_pi = None

        mu, pi, log_pi = squash(mu, pi, log_pi)

        return mu, pi, log_pi, log_std

    def log(self, L, step, log_freq=LOG_FREQ):
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % k, v, step)

        L.log_param('train_actor/fc1', self.trunk[0], step)
        L.log_param('train_actor/fc2', self.trunk[2], step)
        L.log_param('train_actor/fc3', self.trunk[4], step)


class QFunction(nn.Module):
    """MLP for q-function."""
    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        self.trunk = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, obs, action):
        assert obs.size(0) == action.size(0)

        obs_action = torch.cat([obs, action], dim=1)
        return self.trunk(obs_action)
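

# --- Illustrative sketch (not used by the agent) ---
# A minimal example of how Actor.forward combines gaussian_logprob and squash to
# draw a tanh-squashed Gaussian sample and its log-density (see appendix C of
# https://arxiv.org/pdf/1812.05905.pdf). The function name and the shapes below
# are hypothetical and exist only for illustration.
def _example_squashed_gaussian(batch_size=4, action_dim=2):
    mu = torch.zeros(batch_size, action_dim)
    log_std = torch.zeros(batch_size, action_dim)  # std = 1
    noise = torch.randn_like(mu)
    pi = mu + noise * log_std.exp()            # reparameterized sample
    log_pi = gaussian_logprob(noise, log_std)  # (B, 1) Gaussian log-density
    # tanh squash plus the change-of-variables correction on log_pi
    mu, pi, log_pi = squash(mu, pi, log_pi)
    return mu, pi, log_pi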
class Critic(nn.Module):
    """Critic network, employs two q-functions."""
    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, num_layers, num_filters
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters
        )

        self.Q1 = QFunction(
            self.encoder.feature_dim, action_shape[0], hidden_dim
        )
        self.Q2 = QFunction(
            self.encoder.feature_dim, action_shape[0], hidden_dim
        )

        self.outputs = dict()
        self.apply(weight_init)

    def forward(self, obs, action, detach_encoder=False):
        # detach_encoder stops gradient propagation into the encoder
        obs, _, _ = self.encoder(obs, detach=detach_encoder)

        q1 = self.Q1(obs, action)
        q2 = self.Q2(obs, action)

        self.outputs['q1'] = q1
        self.outputs['q2'] = q2

        return q1, q2

    def log(self, L, step, log_freq=LOG_FREQ):
        if step % log_freq != 0:
            return

        self.encoder.log(L, step, log_freq)

        for k, v in self.outputs.items():
            L.log_histogram('train_critic/%s_hist' % k, v, step)

        for i in range(3):
            L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step)
            L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step)


class CURL(nn.Module):
    """CURL contrastive module; shares the encoders of the critic and critic target."""
    def __init__(self, obs_shape, z_dim, a_dim, batch_size, critic,
                 critic_target, output_type="continuous"):
        super(CURL, self).__init__()
        self.batch_size = batch_size

        self.encoder = critic.encoder
        self.encoder_target = critic_target.encoder

        self.W = nn.Parameter(torch.rand(z_dim, z_dim))
        self.combine = nn.Linear(z_dim + a_dim, z_dim)
        self.output_type = output_type

    def encode(self, x, a=None, detach=False, ema=False):
        """
        Encoder: z_t = e(x_t)
        :param x: observation x_t
        :param a: action, concatenated with the target encoding when ema=True
        :return: latent code z_t
        """
        if ema:
            with torch.no_grad():
                z_out = self.encoder_target(x)[0]
                z_out = self.combine(torch.cat((z_out, a), dim=-1))
        else:
            z_out = self.encoder(x)[0]

        if detach:
            z_out = z_out.detach()
        return z_out

    def compute_logits(self, z_a, z_pos):
        """
        Uses logits trick for CURL:
        - compute (B,B) matrix z_a (W z_pos.T)
        - positives are all diagonal elements
        - negatives are all other elements
        - to compute loss use multiclass cross entropy with identity matrix for labels
        """
        Wz = torch.matmul(self.W, z_pos.T)  # (z_dim,B)
        logits = torch.matmul(z_a, Wz)  # (B,B)
        logits = logits - torch.max(logits, 1)[0][:, None]
        return logits
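

# --- Illustrative sketch (not used by the agent) ---
# How the CURL logits are turned into a contrastive (InfoNCE-style) loss in
# SacAeAgent.update_decoder below: row i of the (B, B) logits matrix should
# score column i (its positive pair) highest, so the labels are simply
# arange(B). The function name and arguments are hypothetical; the agent itself
# builds the same loss inline with self.cross_entropy_loss.
def _example_curl_loss(curl_module, anchors, positives):
    # anchors, positives: (B, z_dim) latent codes, e.g. from CURL.encode
    logits = curl_module.compute_logits(anchors, positives)  # (B, B)
    labels = torch.arange(logits.shape[0], device=logits.device).long()
    return nn.CrossEntropyLoss()(logits, labels)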
class SacAeAgent(object):
    """SAC+AE algorithm."""
    def __init__(
        self,
        obs_shape,
        action_shape,
        device,
        hidden_dim=256,
        discount=0.99,
        init_temperature=0.01,
        alpha_lr=1e-3,
        alpha_beta=0.9,
        actor_lr=1e-3,
        actor_beta=0.9,
        actor_log_std_min=-10,
        actor_log_std_max=2,
        actor_update_freq=2,
        critic_lr=1e-3,
        critic_beta=0.9,
        critic_tau=0.005,
        critic_target_update_freq=2,
        encoder_type='pixel',
        encoder_feature_dim=50,
        encoder_lr=1e-3,
        encoder_tau=0.005,
        decoder_type='pixel',
        decoder_lr=1e-3,
        decoder_update_freq=1,
        decoder_latent_lambda=0.0,
        decoder_weight_lambda=0.0,
        num_layers=4,
        num_filters=32
    ):
        self.device = device
        self.discount = discount
        self.critic_tau = critic_tau
        self.encoder_tau = encoder_tau
        self.actor_update_freq = actor_update_freq
        self.critic_target_update_freq = critic_target_update_freq
        self.decoder_update_freq = decoder_update_freq
        self.decoder_latent_lambda = decoder_latent_lambda

        self.transition_model = TransitionModel(
            encoder_feature_dim, hidden_dim, action_shape[0],
            encoder_feature_dim
        ).to(device)

        self.actor = Actor(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, actor_log_std_min, actor_log_std_max,
            num_layers, num_filters
        ).to(device)

        self.critic = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters
        ).to(device)

        self.critic_target = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters
        ).to(device)

        self.critic_target.load_state_dict(self.critic.state_dict())

        # tie encoders between actor and critic
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -np.prod(action_shape)

        self.CURL = CURL(
            obs_shape, encoder_feature_dim, action_shape[0], obs_shape[0],
            self.critic, self.critic_target, output_type='continuous'
        ).to(self.device)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

        self.decoder = None
        if decoder_type != 'identity':
            # create decoder
            self.decoder = make_decoder(
                decoder_type, obs_shape, encoder_feature_dim, num_layers,
                num_filters
            ).to(device)
            self.decoder.apply(weight_init)

            # optimizer for critic encoder for reconstruction loss
            self.encoder_optimizer = torch.optim.Adam(
                self.critic.encoder.parameters(), lr=encoder_lr
            )

            # optimizer for decoder
            self.decoder_optimizer = torch.optim.Adam(
                self.decoder.parameters(),
                lr=decoder_lr,
                weight_decay=decoder_weight_lambda
            )

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
        )

        self.cpc_optimizer = torch.optim.Adam(
            self.CURL.parameters(), lr=encoder_lr
        )

        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
        )

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        if self.decoder is not None:
            self.decoder.train(training)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def select_action(self, obs):
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            mu, _, _, _ = self.actor(
                obs, compute_pi=False, compute_log_pi=False
            )
            return mu.cpu().data.numpy().flatten()

    def sample_action(self, obs):
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            mu, pi, _, _ = self.actor(obs, compute_log_pi=False)
            return pi.cpu().data.numpy().flatten()

    def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
        with torch.no_grad():
            _, policy_action, log_pi, _ = self.actor(next_obs)
            target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
            target_V = torch.min(target_Q1,
                                 target_Q2) - self.alpha.detach() * log_pi
            target_Q = reward + (not_done * self.discount * target_V)

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1,
                                 target_Q) + F.mse_loss(current_Q2, target_Q)
        L.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(L, step)
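
    # For reference, update_critic above implements the standard SAC Bellman
    # backup (a sketch of the math, not additional code):
    #   a' ~ pi(.|s'),  V(s') = min(Q1_targ(s', a'), Q2_targ(s', a')) - alpha * log pi(a'|s')
    #   y = r + gamma * (1 - done) * V(s')
    #   L_critic = MSE(Q1(s, a), y) + MSE(Q2(s, a), y)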

    def update_actor_and_alpha(self, obs, L, step):
        # detach encoder, so we don't update it with the actor loss
        _, pi, log_pi, log_std = self.actor(obs, detach_encoder=True)
        actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()

        L.log('train_actor/loss', actor_loss, step)
        L.log('train_actor/target_entropy', self.target_entropy, step)
        entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)
                                            ) + log_std.sum(dim=-1)
        L.log('train_actor/entropy', entropy.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(L, step)

        self.log_alpha_optimizer.zero_grad()
        alpha_loss = (self.alpha *
                      (-log_pi - self.target_entropy).detach()).mean()
        L.log('train_alpha/loss', alpha_loss, step)
        L.log('train_alpha/value', self.alpha, step)
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

    def update_decoder(self, last_obs, last_action, last_reward, curr_obs,
                       last_not_done, action, reward, next_obs, not_done,
                       target_obs, L, step):
        h_curr, mu_h_curr, std_h_curr = self.critic.encoder(curr_obs)

        with torch.no_grad():
            h_last, _, _ = self.critic.encoder(last_obs)
            self.transition_model.init_states(last_obs.shape[0], self.device)
            curr_state = self.transition_model.transition_step(
                h_last, last_action, self.transition_model.prev_history,
                last_not_done
            )
            hist = curr_state["history"]

        next_state = self.transition_model.transition_step(
            h_curr, action, hist, not_done
        )
        next_state_mu = next_state["mean"]
        next_state_sigma = next_state["std"]
        next_state_sample = next_state["sample"]
        pred_dist = torch.distributions.Normal(next_state_mu, next_state_sigma)

        h, mu_h_next, logstd_h_next = self.critic.encoder(next_obs)
        std_h_next = torch.exp(logstd_h_next)
        enc_dist = torch.distributions.Normal(mu_h_next, std_h_next)
        enc_loss = torch.mean(
            torch.distributions.kl.kl_divergence(enc_dist, pred_dist)
        ) * 0.1

        z_pos = self.CURL.encode(next_obs, action.detach(), ema=True)
        logits = self.CURL.compute_logits(h_curr, z_pos)
        labels = torch.arange(logits.shape[0]).long().to(self.device)
        lb_loss = self.cross_entropy_loss(logits, labels) * 0.1
        ub_loss = club_loss(h, mu_h_next, logstd_h_next, next_state_sample) * 0.1

        if target_obs.dim() == 4:
            # preprocess images to be in [-0.5, 0.5] range
            target_obs = utils.preprocess_obs(target_obs)

        rec_obs = self.decoder(h)
        rec_loss = F.mse_loss(target_obs, rec_obs)

        # add L2 penalty on latent representation
        # see https://arxiv.org/pdf/1903.12436.pdf
        latent_loss = (0.5 * h.pow(2).sum(1)).mean()

        loss = rec_loss + enc_loss + lb_loss + ub_loss
        # + self.decoder_latent_lambda * latent_loss (latent L2 penalty currently disabled)
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        self.cpc_optimizer.zero_grad()
        loss.backward()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()
        self.cpc_optimizer.step()

        L.log('train_ae/ae_loss', loss, step)
        L.log('train_ae/lb_loss', lb_loss, step)
        L.log('train_ae/ub_loss', ub_loss, step)
        L.log('train_ae/enc_loss', enc_loss, step)
        L.log('train_ae/dec_loss', rec_loss, step)

        self.decoder.log(L, step, log_freq=LOG_FREQ)

    def update(self, replay_buffer, L, step):
        (last_obs, last_action, last_reward, curr_obs, last_not_done, action,
         reward, next_obs, not_done) = replay_buffer.sample()
        # obs, action, reward, next_obs, not_done = replay_buffer.sample()

        L.log('train/batch_reward', last_reward.mean(), step)

        # self.update_critic(last_obs, last_action, last_reward, curr_obs,
        #                    last_not_done, L, step)
        self.update_critic(curr_obs, action, reward, next_obs, not_done, L,
                           step)

        if step % self.actor_update_freq == 0:
            # self.update_actor_and_alpha(last_obs, L, step)
            self.update_actor_and_alpha(curr_obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder,
                self.encoder_tau
            )

        if self.decoder is not None and step % self.decoder_update_freq == 0:
            self.update_decoder(last_obs, last_action, last_reward, curr_obs,
                                last_not_done, action, reward, next_obs,
                                not_done, next_obs, L, step)
    def save(self, model_dir, step):
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
        )
        if self.decoder is not None:
            torch.save(
                self.decoder.state_dict(),
                '%s/decoder_%s.pt' % (model_dir, step)
            )

    def load(self, model_dir, step):
        self.actor.load_state_dict(
            torch.load('%s/actor_%s.pt' % (model_dir, step))
        )
        self.critic.load_state_dict(
            torch.load('%s/critic_%s.pt' % (model_dir, step))
        )
        if self.decoder is not None:
            self.decoder.load_state_dict(
                torch.load('%s/decoder_%s.pt' % (model_dir, step))
            )
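

# --- Reference sketch (assumption, not part of the agent) ---
# utils.soft_update_params is called in SacAeAgent.update as
# soft_update_params(net, target_net, tau). The sketch below shows the standard
# Polyak averaging it is assumed to perform,
#   theta_target <- tau * theta + (1 - tau) * theta_target,
# and is only an illustrative re-implementation; the actual utils module is the
# source of truth.
def _soft_update_params_sketch(net, target_net, tau):
    with torch.no_grad():
        for param, target_param in zip(net.parameters(),
                                       target_net.parameters()):
            target_param.data.copy_(
                tau * param.data + (1 - tau) * target_param.data
            )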