2019-09-23 18:20:48 +00:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import copy
|
|
|
|
import math
|
|
|
|
|
|
|
|
import utils
|
2023-05-16 10:40:47 +00:00
|
|
|
from encoder import make_encoder, club_loss, TransitionModel
|
2019-09-23 18:20:48 +00:00
|
|
|
from decoder import make_decoder
|
|
|
|
|
|
|
|
# Default logging period (in training steps) for the `log` methods below:
# they return early unless `step % log_freq == 0`.
LOG_FREQ = 10000
|
|
|
|
|
|
|
|
|
|
|
|
def gaussian_logprob(noise, log_std):
    """Compute Gaussian log probability.

    Given the standardized noise ``eps`` used to draw ``x = mu + eps * std``
    and the per-dimension ``log_std``, return ``log N(x; mu, std)`` summed over
    the last dimension (kept as a size-1 axis).
    """
    per_dim = -0.5 * noise.pow(2) - log_std
    # constant term of the diagonal Gaussian density, one copy per dimension
    normalizer = 0.5 * np.log(2 * np.pi) * noise.size(-1)
    return per_dim.sum(-1, keepdim=True) - normalizer
|
|
|
|
|
|
|
|
|
|
|
|
def squash(mu, pi, log_pi):
    """Apply squashing function.

    Passes ``mu`` and (if given) ``pi`` through ``tanh`` and, if ``log_pi`` is
    given, subtracts the tanh change-of-variables correction from it in place.
    See appendix C from https://arxiv.org/pdf/1812.05905.pdf.
    """
    squashed_mu = torch.tanh(mu)
    squashed_pi = None if pi is None else torch.tanh(pi)

    if log_pi is not None:
        # log|d tanh(u)/du| = log(1 - tanh(u)^2); relu + 1e-6 keeps the log finite
        jacobian = torch.log(F.relu(1 - squashed_pi.pow(2)) + 1e-6)
        log_pi -= jacobian.sum(-1, keepdim=True)

    return squashed_mu, squashed_pi, log_pi
|
|
|
|
|
|
|
|
|
|
|
|
def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    Meant to be used as ``module.apply(weight_init)``; modules of any other
    type are left untouched.

    - ``nn.Linear``: orthogonal weight, zero bias.
    - ``nn.Conv2d`` / ``nn.ConvTranspose2d``: delta-orthogonal init -- every
      kernel tap is zeroed except the spatial centre, which receives an
      orthogonal matrix scaled by the ReLU gain; zero bias.

    Layers built with ``bias=False`` (whose ``bias`` is ``None``) are
    supported; previously they raised ``AttributeError``.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        # (requires square kernels so a single centre tap exists)
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
|
|
|
|
|
|
|
|
|
|
|
|
class Actor(nn.Module):
    """MLP actor network.

    Encodes observations with the encoder built by ``make_encoder`` and maps
    the features through a 3-layer MLP to the mean and log-std of a diagonal
    Gaussian whose samples are tanh-squashed (see ``squash``).
    """

    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters
        )

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        # Last layer emits [mu, log_std] concatenated along the last axis.
        self.trunk = nn.Sequential(
            nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2 * action_shape[0])
        )

        # Tensors from the latest forward pass, dumped as histograms by `log`.
        self.outputs = dict()
        self.apply(weight_init)

    def forward(self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False):
        """Return ``(mu, pi, log_pi, log_std)`` for the given observations.

        ``mu`` and ``pi`` are tanh-squashed; ``pi`` is ``None`` when
        ``compute_pi`` is False and ``log_pi`` is ``None`` when
        ``compute_log_pi`` is False. ``log_std`` is the unsquashed,
        range-constrained log standard deviation.

        Raises:
            ValueError: if ``compute_log_pi`` is requested without
                ``compute_pi`` (the log-prob needs the sampling noise).
        """
        if compute_log_pi and not compute_pi:
            # Previously this combination failed with a NameError on `noise`.
            raise ValueError("compute_log_pi=True requires compute_pi=True")

        obs, _, _ = self.encoder(obs, detach=detach_encoder)

        mu, log_std = self.trunk(obs).chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std = self.log_std_min + 0.5 * (
            self.log_std_max - self.log_std_min
        ) * (log_std + 1)
        std = log_std.exp()

        self.outputs['mu'] = mu
        self.outputs['std'] = std

        if compute_pi:
            # reparameterised sample: pi = mu + eps * std, eps ~ N(0, I)
            noise = torch.randn_like(mu)
            pi = mu + noise * std
        else:
            pi = None

        if compute_log_pi:
            log_pi = gaussian_logprob(noise, log_std)
        else:
            log_pi = None

        mu, pi, log_pi = squash(mu, pi, log_pi)

        return mu, pi, log_pi, log_std

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log output histograms and trunk layer params every ``log_freq`` steps."""
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % k, v, step)

        L.log_param('train_actor/fc1', self.trunk[0], step)
        L.log_param('train_actor/fc2', self.trunk[2], step)
        L.log_param('train_actor/fc3', self.trunk[4], step)
|
|
|
|
|
|
|
|
|
|
|
|
class QFunction(nn.Module):
    """MLP for q-function.

    Scores a (feature, action) pair: the two inputs are concatenated along
    dim 1 and passed through a 3-layer MLP with a single scalar output.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        layers = [
            nn.Linear(obs_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.trunk = nn.Sequential(*layers)

    def forward(self, obs, action):
        # batch sizes must line up before concatenation
        assert obs.size(0) == action.size(0)
        return self.trunk(torch.cat((obs, action), dim=1))
|
|
|
|
|
|
|
|
|
|
|
|
class Critic(nn.Module):
    """Critic network, employes two q-functions.

    Both Q-heads share one encoder built by ``make_encoder``; ``forward``
    returns the pair ``(q1, q2)`` so callers can take the minimum (clipped
    double-Q).
    """

    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, num_layers, num_filters
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters
        )

        feature_dim = self.encoder.feature_dim
        action_dim = action_shape[0]
        self.Q1 = QFunction(feature_dim, action_dim, hidden_dim)
        self.Q2 = QFunction(feature_dim, action_dim, hidden_dim)

        # Latest Q estimates, dumped as histograms by `log`.
        self.outputs = {}
        self.apply(weight_init)

    def forward(self, obs, action, detach_encoder=False):
        # detach_encoder stops gradients from flowing back into the encoder
        features, _, _ = self.encoder(obs, detach=detach_encoder)

        q_values = (self.Q1(features, action), self.Q2(features, action))
        self.outputs['q1'], self.outputs['q2'] = q_values
        return q_values

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log encoder, Q histograms and Q-layer params every ``log_freq`` steps."""
        if step % log_freq:
            return

        self.encoder.log(L, step, log_freq)

        for name, value in self.outputs.items():
            L.log_histogram('train_critic/%s_hist' % name, value, step)

        # Linear layers sit at even indices (0, 2, 4) of each trunk.
        for idx, layer_pos in enumerate((0, 2, 4)):
            L.log_param('train_critic/q1_fc%d' % idx, self.Q1.trunk[layer_pos], step)
            L.log_param('train_critic/q2_fc%d' % idx, self.Q2.trunk[layer_pos], step)
|
|
|
|
|
2023-05-16 10:40:47 +00:00
|
|
|
class CURL(nn.Module):
    """
    CURL-style contrastive head built on top of the critic encoders.

    Reuses ``critic.encoder`` (online) and ``critic_target.encoder`` (target)
    and adds a bilinear similarity matrix ``W`` plus a ``combine`` projection
    that fuses a target embedding with an action.
    """

    def __init__(self, obs_shape, z_dim, a_dim, batch_size, critic, critic_target, output_type="continuous"):
        super(CURL, self).__init__()
        # NOTE(review): stored but not used by this class.
        self.batch_size = batch_size

        self.encoder = critic.encoder

        self.encoder_target = critic_target.encoder

        # bilinear similarity used by compute_logits
        self.W = nn.Parameter(torch.rand(z_dim, z_dim))
        # maps concat(target embedding, action) back to z_dim
        self.combine = nn.Linear(z_dim + a_dim, z_dim)
        self.output_type = output_type

    def encode(self, x, a=None, detach=False, ema=False):
        """
        Embed observations ``x``: z = e(x).

        :param x: batch of observations fed to the encoder
        :param a: optional action batch; only used when ``ema`` is True, in
            which case the target embedding is concatenated with it and passed
            through ``combine``. When ``a`` is None the raw target embedding is
            returned (previously this crashed in the concatenation).
        :param detach: detach the returned embedding from the graph
        :param ema: use the target encoder (run under ``torch.no_grad``)
            instead of the online encoder
        :return: embedding of shape (B, z_dim)
        """
        if ema:
            with torch.no_grad():
                z_out = self.encoder_target(x)[0]
            # NOTE(review): the original nesting was lost. ``combine`` is kept
            # outside no_grad so its parameters (registered with the CURL
            # optimizer) actually receive gradients; move it under the
            # no_grad block if the projection is meant to stay frozen.
            if a is not None:
                z_out = self.combine(torch.cat((z_out, a), dim=-1))
        else:
            z_out = self.encoder(x)[0]

        if detach:
            z_out = z_out.detach()
        return z_out

    def compute_logits(self, z_a, z_pos):
        """
        Uses logits trick for CURL:
        - compute (B,B) matrix z_a (W z_pos.T)
        - positives are all diagonal elements
        - negatives are all other elements
        - to compute loss use multiclass cross entropy with identity matrix for labels
        """
        Wz = torch.matmul(self.W, z_pos.T)  # (z_dim,B)
        logits = torch.matmul(z_a, Wz)  # (B,B)
        # subtract the row max for numerical stability (softmax-invariant)
        logits = logits - torch.max(logits, 1)[0][:, None]
        return logits
|
|
|
|
|
2019-09-23 18:38:55 +00:00
|
|
|
class SacAeAgent(object):
    """SAC+AE algorithm.

    Soft Actor-Critic whose critic encoder is additionally trained in
    ``update_decoder`` with a pixel reconstruction loss, a KL term against a
    transition model, a CURL-style contrastive loss and a CLUB loss.
    """

    def __init__(
        self,
        obs_shape,
        action_shape,
        device,
        hidden_dim=256,
        discount=0.99,
        init_temperature=0.01,
        alpha_lr=1e-3,
        alpha_beta=0.9,
        actor_lr=1e-3,
        actor_beta=0.9,
        actor_log_std_min=-10,
        actor_log_std_max=2,
        actor_update_freq=2,
        critic_lr=1e-3,
        critic_beta=0.9,
        critic_tau=0.005,
        critic_target_update_freq=2,
        encoder_type='pixel',
        encoder_feature_dim=50,
        encoder_lr=1e-3,
        encoder_tau=0.005,
        decoder_type='pixel',
        decoder_lr=1e-3,
        decoder_update_freq=1,
        decoder_latent_lambda=0.0,
        decoder_weight_lambda=0.0,
        num_layers=4,
        num_filters=32
    ):
        self.device = device
        self.discount = discount
        self.critic_tau = critic_tau
        self.encoder_tau = encoder_tau
        self.actor_update_freq = actor_update_freq
        self.critic_target_update_freq = critic_target_update_freq
        self.decoder_update_freq = decoder_update_freq
        # NOTE(review): stored but currently unused -- the latent L2 penalty
        # is not part of the loss in update_decoder.
        self.decoder_latent_lambda = decoder_latent_lambda

        # Used in update_decoder to predict a Gaussian over the next latent.
        self.transition_model = TransitionModel(
            encoder_feature_dim,
            hidden_dim,
            action_shape[0],
            encoder_feature_dim).to(device)

        self.actor = Actor(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, actor_log_std_min, actor_log_std_max,
            num_layers, num_filters
        ).to(device)

        self.critic = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters
        ).to(device)

        self.critic_target = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters
        ).to(device)

        self.critic_target.load_state_dict(self.critic.state_dict())

        # tie encoders between actor and critic
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -np.prod(action_shape)

        # Contrastive head sharing the critic / critic-target encoders.
        # NOTE(review): the 4th argument is CURL's `batch_size`, but
        # obs_shape[0] is passed; CURL stores it without using it.
        self.CURL = CURL(obs_shape, encoder_feature_dim, action_shape[0],
                         obs_shape[0], self.critic, self.critic_target,
                         output_type='continuous').to(self.device)

        self.cross_entropy_loss = nn.CrossEntropyLoss()

        self.decoder = None
        if decoder_type != 'identity':
            # create decoder
            self.decoder = make_decoder(
                decoder_type, obs_shape, encoder_feature_dim, num_layers,
                num_filters
            ).to(device)
            self.decoder.apply(weight_init)

            # optimizer for critic encoder for reconstruction loss
            self.encoder_optimizer = torch.optim.Adam(
                self.critic.encoder.parameters(), lr=encoder_lr
            )

            # optimizer for decoder
            self.decoder_optimizer = torch.optim.Adam(
                self.decoder.parameters(),
                lr=decoder_lr,
                weight_decay=decoder_weight_lambda
            )

            # optimizer for the transition model: enc_loss in update_decoder
            # back-propagates into it, but it previously belonged to no
            # optimizer, so it never learned and its .grad kept accumulating.
            self.transition_optimizer = torch.optim.Adam(
                self.transition_model.parameters(), lr=decoder_lr
            )

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
        )

        # CURL parameters include the shared critic encoder, so the encoder is
        # stepped by both this and encoder_optimizer in update_decoder.
        self.cpc_optimizer = torch.optim.Adam(
            self.CURL.parameters(), lr=encoder_lr
        )

        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
        )

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Put actor, critic and (if present) decoder into train/eval mode."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        if self.decoder is not None:
            self.decoder.train(training)

    @property
    def alpha(self):
        """Entropy temperature, exp(log_alpha)."""
        return self.log_alpha.exp()

    def select_action(self, obs):
        """Return the deterministic (squashed mean) action for one observation."""
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            mu, _, _, _ = self.actor(
                obs, compute_pi=False, compute_log_pi=False
            )
            return mu.cpu().data.numpy().flatten()

    def sample_action(self, obs):
        """Return a stochastic (sampled, squashed) action for one observation."""
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            mu, pi, _, _ = self.actor(obs, compute_log_pi=False)
            return pi.cpu().data.numpy().flatten()

    def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
        """One soft Bellman update of both Q-heads (and the critic encoder)."""
        with torch.no_grad():
            _, policy_action, log_pi, _ = self.actor(next_obs)
            target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
            target_V = torch.min(target_Q1,
                                 target_Q2) - self.alpha.detach() * log_pi
            target_Q = reward + (not_done * self.discount * target_V)

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1,
                                 target_Q) + F.mse_loss(current_Q2, target_Q)
        L.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(L, step)

    def update_actor_and_alpha(self, obs, L, step):
        """Update the policy and the entropy temperature alpha."""
        # detach encoder, so we don't update it with the actor loss
        _, pi, log_pi, log_std = self.actor(obs, detach_encoder=True)
        actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()

        L.log('train_actor/loss', actor_loss, step)
        L.log('train_actor/target_entropy', self.target_entropy, step)
        # entropy of the (pre-tanh) diagonal Gaussian
        entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)
                                            ) + log_std.sum(dim=-1)
        L.log('train_actor/entropy', entropy.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(L, step)

        self.log_alpha_optimizer.zero_grad()
        alpha_loss = (self.alpha *
                      (-log_pi - self.target_entropy).detach()).mean()
        L.log('train_alpha/loss', alpha_loss, step)
        L.log('train_alpha/value', self.alpha, step)
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

    def update_decoder(self, last_obs, last_action, last_reward, curr_obs, last_not_done, action, reward, next_obs, not_done, target_obs, L, step):
        """Representation update for the critic encoder, decoder, CURL head and transition model.

        Loss = reconstruction of ``target_obs`` from the encoding of
        ``next_obs`` + 0.1 * KL(encoder || transition prediction)
        + 0.1 * contrastive loss + 0.1 * CLUB loss.
        ``last_reward`` and ``reward`` are accepted but not used.
        """
        h_curr, mu_h_curr, std_h_curr = self.critic.encoder(curr_obs)

        # Roll the transition model one step over the previous transition
        # (no gradients) to obtain the history for the current step.
        with torch.no_grad():
            h_last, _, _ = self.critic.encoder(last_obs)
            self.transition_model.init_states(last_obs.shape[0], self.device)
            curr_state = self.transition_model.transition_step(
                h_last, last_action, self.transition_model.prev_history,
                last_not_done)

        hist = curr_state["history"]
        next_state = self.transition_model.transition_step(h_curr, action, hist, not_done)

        next_state_mu = next_state["mean"]
        next_state_sigma = next_state["std"]
        next_state_sample = next_state["sample"]
        pred_dist = torch.distributions.Normal(next_state_mu, next_state_sigma)

        h, mu_h_next, logstd_h_next = self.critic.encoder(next_obs)
        std_h_next = torch.exp(logstd_h_next)
        enc_dist = torch.distributions.Normal(mu_h_next, std_h_next)
        enc_loss = torch.mean(torch.distributions.kl.kl_divergence(enc_dist, pred_dist)) * 0.1

        # contrastive (lower-bound) term: match h_curr to the target embedding
        # of next_obs combined with the action; positives on the diagonal
        z_pos = self.CURL.encode(next_obs, action.detach(), ema=True)
        logits = self.CURL.compute_logits(h_curr, z_pos)
        labels = torch.arange(logits.shape[0]).long().to(self.device)
        lb_loss = self.cross_entropy_loss(logits, labels) * 0.1

        ub_loss = club_loss(h, mu_h_next, logstd_h_next, next_state_sample) * 0.1

        if target_obs.dim() == 4:
            # preprocess images to be in [-0.5, 0.5] range
            target_obs = utils.preprocess_obs(target_obs)

        rec_obs = self.decoder(h)
        rec_loss = F.mse_loss(target_obs, rec_obs)

        # NOTE(review): the SAC-AE latent L2 penalty
        # (decoder_latent_lambda * 0.5 * ||h||^2, see
        # https://arxiv.org/pdf/1903.12436.pdf) is deliberately not included.
        loss = rec_loss + enc_loss + lb_loss + ub_loss

        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        self.cpc_optimizer.zero_grad()
        self.transition_optimizer.zero_grad()
        loss.backward()

        self.encoder_optimizer.step()
        self.decoder_optimizer.step()
        self.cpc_optimizer.step()
        self.transition_optimizer.step()

        L.log('train_ae/ae_loss', loss, step)
        L.log('train_ae/lb_loss', lb_loss, step)
        L.log('train_ae/ub_loss', ub_loss, step)
        L.log('train_ae/enc_loss', enc_loss, step)
        L.log('train_ae/dec_loss', rec_loss, step)

        self.decoder.log(L, step, log_freq=LOG_FREQ)

    def update(self, replay_buffer, L, step):
        """Run one training step on a batch sampled from ``replay_buffer``.

        SAC is trained on the (curr_obs, action, reward, next_obs) transition;
        the representation update additionally uses the preceding
        (last_obs, last_action, ...) transition -- presumably the buffer
        returns consecutive steps; confirm against the replay buffer.
        """
        (last_obs, last_action, last_reward, curr_obs, last_not_done,
         action, reward, next_obs, not_done) = replay_buffer.sample()

        # log the rewards the critic is actually trained on (was last_reward)
        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(curr_obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(curr_obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder,
                self.encoder_tau
            )

        if self.decoder is not None and step % self.decoder_update_freq == 0:
            self.update_decoder(last_obs, last_action, last_reward, curr_obs, last_not_done, action, reward, next_obs, not_done, next_obs, L, step)

    def save(self, model_dir, step):
        """Save actor, critic and (if present) decoder state dicts."""
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
        )
        if self.decoder is not None:
            torch.save(
                self.decoder.state_dict(),
                '%s/decoder_%s.pt' % (model_dir, step)
            )

    def load(self, model_dir, step):
        """Load actor, critic and (if present) decoder state dicts."""
        self.actor.load_state_dict(
            torch.load('%s/actor_%s.pt' % (model_dir, step))
        )
        self.critic.load_state_dict(
            torch.load('%s/critic_%s.pt' % (model_dir, step))
        )
        if self.decoder is not None:
            self.decoder.load_state_dict(
                torch.load('%s/decoder_%s.pt' % (model_dir, step))
            )
|