Denis Yarats 2019-09-23 12:24:30 -07:00
parent cb1aff9441
commit 6e6de8de9b
3 changed files with 27 additions and 67 deletions

conda_env.yml (new file, 15 additions)

@@ -0,0 +1,15 @@
name: pytorch_sac_ae
channels:
  - defaults
dependencies:
  - python=3.6
  - pytorch
  - torchvision
  - cudatoolkit=9.2
  - pip:
    - colored
    - absl-py
    - git+git://github.com/deepmind/dm_control.git
    - git+git://github.com/1nadequacy/dmc2gym.git
    - tb-nightly
    - imageio
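Assuming a standard conda workflow (not part of this commit), the environment above would typically be created with "conda env create -f conda_env.yml" and then activated with "conda activate pytorch_sac_ae".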


@@ -49,19 +49,17 @@ class Actor(nn.Module):
"""MLP actor network."""
def __init__(
self, obs_shape, action_shape, hidden_dim, encoder_type,
- encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters,
- freeze_encoder, stochastic
+ encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters
):
super().__init__()
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim, num_layers,
- num_filters, stochastic
+ num_filters
)
self.log_std_min = log_std_min
self.log_std_max = log_std_max
- self.freeze_encoder = freeze_encoder
self.trunk = nn.Sequential(
nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
@@ -77,13 +75,10 @@ class Actor(nn.Module):
):
obs = self.encoder(obs, detach=detach_encoder)
- if self.freeze_encoder:
-     obs = obs.detach()
mu, log_std = self.trunk(obs).chunk(2, dim=-1)
# constrain log_std inside [log_std_min, log_std_max]
- log_std = F.tanh(log_std)
+ log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
) * (log_std + 1)
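As a side note, a minimal standalone sketch of the rescaling above (the bound values are illustrative assumptions, not taken from this commit): torch.tanh first squashes the raw output into (-1, 1), and the affine map then places it inside [log_std_min, log_std_max].

import torch

log_std_min, log_std_max = -10.0, 2.0    # assumed illustrative bounds
raw = torch.tensor([-5.0, 0.0, 5.0])     # unconstrained network output
log_std = torch.tanh(raw)                # squash into (-1, 1)
log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)
print(log_std)                           # every value now lies inside (-10, 2)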
@@ -138,44 +133,20 @@ class QFunction(nn.Module):
return self.trunk(obs_action)
- class DynamicsModel(nn.Module):
-     def __init__(self, state_dim, action_dim, hidden_dim):
-         super().__init__()
-         self.trunk = nn.Sequential(
-             nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
-             nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
-             nn.Linear(hidden_dim, state_dim)
-         )
-     def forward(self, state, action):
-         assert state.size(0) == action.size(0)
-         state_action = torch.cat([state, action], dim=1)
-         return self.trunk(state_action)
class Critic(nn.Module):
"""Critic network, employes two q-functions."""
def __init__(
self, obs_shape, action_shape, hidden_dim, encoder_type,
- encoder_feature_dim, num_layers, num_filters, freeze_encoder,
- use_dynamics, stochastic
+ encoder_feature_dim, num_layers, num_filters
):
super().__init__()
- self.freeze_encoder = freeze_encoder
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim, num_layers,
- num_filters, stochastic
+ num_filters
)
- if use_dynamics:
-     self.forward_model = DynamicsModel(
-         self.encoder.feature_dim, action_shape[0], hidden_dim
-     )
self.Q1 = QFunction(
self.encoder.feature_dim, action_shape[0], hidden_dim
)
@@ -190,9 +161,6 @@ class Critic(nn.Module):
# detach_encoder allows to stop gradient propogation to encoder
obs = self.encoder(obs, detach=detach_encoder)
- if self.freeze_encoder:
-     obs = obs.detach()
q1 = self.Q1(obs, action)
q2 = self.Q2(obs, action)
@@ -220,7 +188,6 @@ class SacAeAgent(object):
def __init__(
self,
obs_shape,
- state_shape,
action_shape,
device,
hidden_dim=256,
@@ -261,19 +228,17 @@ class SacAeAgent(object):
self.actor = Actor(
obs_shape, action_shape, hidden_dim, encoder_type,
encoder_feature_dim, actor_log_std_min, actor_log_std_max,
- num_layers, num_filters, freeze_encoder, stochastic
+ num_layers, num_filters
).to(device)
self.critic = Critic(
obs_shape, action_shape, hidden_dim, encoder_type,
- encoder_feature_dim, num_layers, num_filters, freeze_encoder,
- use_dynamics, stochastic
+ encoder_feature_dim, num_layers, num_filters
).to(device)
self.critic_target = Critic(
obs_shape, action_shape, hidden_dim, encoder_type,
- encoder_feature_dim, num_layers, num_filters, freeze_encoder,
- use_dynamics, stochastic
+ encoder_feature_dim, num_layers, num_filters
).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
@@ -289,9 +254,8 @@ class SacAeAgent(object):
self.decoder = None
if decoder_type != 'identity':
# create decoder
- shape = obs_shape if decoder_type == 'pixel' else state_shape
self.decoder = make_decoder(
- decoder_type, shape, encoder_feature_dim, num_layers,
+ decoder_type, obs_shape, encoder_feature_dim, num_layers,
num_filters
).to(device)
self.decoder.apply(weight_init)
@@ -365,15 +329,6 @@ class SacAeAgent(object):
target_Q) + F.mse_loss(current_Q2, target_Q)
L.log('train_critic/loss', critic_loss, step)
- # update dynamics (optional)
- if self.use_dynamics:
-     h_obs = self.critic.encoder.outputs['mu']
-     with torch.no_grad():
-         next_latent = self.critic.encoder(next_obs)
-     pred_next_latent = self.critic.forward_model(h_obs, action)
-     dynamics_loss = F.mse_loss(pred_next_latent, next_latent)
-     L.log('train_critic/dynamics_loss', dynamics_loss, step)
-     critic_loss += dynamics_loss
# Optimize the critic
self.critic_optimizer.zero_grad()
@@ -424,16 +379,7 @@ class SacAeAgent(object):
# see https://arxiv.org/pdf/1903.12436.pdf
latent_loss = (0.5 * h.pow(2).sum(1)).mean()
- # add KL penalty for VAE
- if self.decoder_kl_lambda > 0.0:
-     log_std = self.critic.encoder.outputs['log_std']
-     mu = self.critic.encoder.outputs['mu']
-     kl_div = -0.5 * (1 + 2 * log_std - mu.pow(2) - (2 * log_std).exp())
-     kl_div = kl_div.sum(1).mean(0, True)
- else:
-     kl_div = 0.0
- loss = rec_loss + self.decoder_latent_lambda * latent_loss + self.decoder_kl_lambda * kl_div
+ loss = rec_loss + self.decoder_latent_lambda * latent_loss
self.encoder_optimizer.zero_grad()
self.decoder_optimizer.zero_grad()
loss.backward()
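For context, a minimal sketch of the objective that remains after this change (tensor shapes and the decoder_latent_lambda value are illustrative assumptions; the L2 penalty on the latent follows the RAE formulation from https://arxiv.org/pdf/1903.12436.pdf):

import torch
import torch.nn.functional as F

decoder_latent_lambda = 1e-6                  # assumed illustrative weight
h = torch.randn(32, 50)                       # latent codes from the encoder
target_obs = torch.rand(32, 3, 84, 84)        # preprocessed observations
rec_obs = torch.rand(32, 3, 84, 84)           # decoder reconstruction
rec_loss = F.mse_loss(rec_obs, target_obs)
latent_loss = (0.5 * h.pow(2).sum(1)).mean()  # 0.5 * ||h||^2, averaged over the batch
loss = rec_loss + decoder_latent_lambda * latent_loss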
@@ -445,7 +391,7 @@ class SacAeAgent(object):
self.decoder.log(L, step, log_freq=LOG_FREQ)
def update(self, replay_buffer, L, step):
- obs, action, reward, next_obs, not_done, state = replay_buffer.sample()
+ obs, action, reward, next_obs, not_done = replay_buffer.sample()
L.log('train/batch_reward', reward.mean(), step)


@@ -95,11 +95,10 @@ def evaluate(env, agent, video, num_episodes, L, step):
L.dump(step)
- def make_agent(obs_shape, state_shape, action_shape, args, device):
+ def make_agent(obs_shape, action_shape, args, device):
if args.agent == 'sac_ae':
return SacAeAgent(
obs_shape=obs_shape,
- state_shape=state_shape,
action_shape=action_shape,
device=device,
hidden_dim=args.hidden_dim,