updates
This commit is contained in:
parent
cb1aff9441
commit
6e6de8de9b
15
conda_env.yml
Normal file
15
conda_env.yml
Normal file
@ -0,0 +1,15 @@
|
||||
name: pytorch_sac_ae
|
||||
channels:
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.6
|
||||
- pytorch
|
||||
- torchvision
|
||||
- cudatoolkit=9.2
|
||||
- pip:
|
||||
- colored
|
||||
- absl-py
|
||||
- git+git://github.com/deepmind/dm_control.git
|
||||
- git+git://github.com/1nadequacy/dmc2gym.git
|
||||
- tb-nightly
|
||||
- imageio
|
76
sac_ae.py
76
sac_ae.py
@ -49,19 +49,17 @@ class Actor(nn.Module):
|
||||
"""MLP actor network."""
|
||||
def __init__(
|
||||
self, obs_shape, action_shape, hidden_dim, encoder_type,
|
||||
encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters,
|
||||
freeze_encoder, stochastic
|
||||
encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder = make_encoder(
|
||||
encoder_type, obs_shape, encoder_feature_dim, num_layers,
|
||||
num_filters, stochastic
|
||||
num_filters
|
||||
)
|
||||
|
||||
self.log_std_min = log_std_min
|
||||
self.log_std_max = log_std_max
|
||||
self.freeze_encoder = freeze_encoder
|
||||
|
||||
self.trunk = nn.Sequential(
|
||||
nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
|
||||
@ -77,13 +75,10 @@ class Actor(nn.Module):
|
||||
):
|
||||
obs = self.encoder(obs, detach=detach_encoder)
|
||||
|
||||
if self.freeze_encoder:
|
||||
obs = obs.detach()
|
||||
|
||||
mu, log_std = self.trunk(obs).chunk(2, dim=-1)
|
||||
|
||||
# constrain log_std inside [log_std_min, log_std_max]
|
||||
log_std = F.tanh(log_std)
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (
|
||||
self.log_std_max - self.log_std_min
|
||||
) * (log_std + 1)
|
||||
@ -138,42 +133,18 @@ class QFunction(nn.Module):
|
||||
return self.trunk(obs_action)
|
||||
|
||||
|
||||
class DynamicsModel(nn.Module):
|
||||
def __init__(self, state_dim, action_dim, hidden_dim):
|
||||
super().__init__()
|
||||
|
||||
self.trunk = nn.Sequential(
|
||||
nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
|
||||
nn.Linear(hidden_dim, state_dim)
|
||||
)
|
||||
|
||||
def forward(self, state, action):
|
||||
assert state.size(0) == action.size(0)
|
||||
|
||||
state_action = torch.cat([state, action], dim=1)
|
||||
return self.trunk(state_action)
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
"""Critic network, employes two q-functions."""
|
||||
def __init__(
|
||||
self, obs_shape, action_shape, hidden_dim, encoder_type,
|
||||
encoder_feature_dim, num_layers, num_filters, freeze_encoder,
|
||||
use_dynamics, stochastic
|
||||
encoder_feature_dim, num_layers, num_filters
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.freeze_encoder = freeze_encoder
|
||||
|
||||
self.encoder = make_encoder(
|
||||
encoder_type, obs_shape, encoder_feature_dim, num_layers,
|
||||
num_filters, stochastic
|
||||
)
|
||||
|
||||
if use_dynamics:
|
||||
self.forward_model = DynamicsModel(
|
||||
self.encoder.feature_dim, action_shape[0], hidden_dim
|
||||
num_filters
|
||||
)
|
||||
|
||||
self.Q1 = QFunction(
|
||||
@ -190,9 +161,6 @@ class Critic(nn.Module):
|
||||
# detach_encoder allows to stop gradient propogation to encoder
|
||||
obs = self.encoder(obs, detach=detach_encoder)
|
||||
|
||||
if self.freeze_encoder:
|
||||
obs = obs.detach()
|
||||
|
||||
q1 = self.Q1(obs, action)
|
||||
q2 = self.Q2(obs, action)
|
||||
|
||||
@ -220,7 +188,6 @@ class SacAeAgent(object):
|
||||
def __init__(
|
||||
self,
|
||||
obs_shape,
|
||||
state_shape,
|
||||
action_shape,
|
||||
device,
|
||||
hidden_dim=256,
|
||||
@ -261,19 +228,17 @@ class SacAeAgent(object):
|
||||
self.actor = Actor(
|
||||
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||
encoder_feature_dim, actor_log_std_min, actor_log_std_max,
|
||||
num_layers, num_filters, freeze_encoder, stochastic
|
||||
num_layers, num_filters
|
||||
).to(device)
|
||||
|
||||
self.critic = Critic(
|
||||
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||
encoder_feature_dim, num_layers, num_filters, freeze_encoder,
|
||||
use_dynamics, stochastic
|
||||
encoder_feature_dim, num_layers, num_filters
|
||||
).to(device)
|
||||
|
||||
self.critic_target = Critic(
|
||||
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||
encoder_feature_dim, num_layers, num_filters, freeze_encoder,
|
||||
use_dynamics, stochastic
|
||||
encoder_feature_dim, num_layers, num_filters
|
||||
).to(device)
|
||||
|
||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
@ -289,9 +254,8 @@ class SacAeAgent(object):
|
||||
self.decoder = None
|
||||
if decoder_type != 'identity':
|
||||
# create decoder
|
||||
shape = obs_shape if decoder_type == 'pixel' else state_shape
|
||||
self.decoder = make_decoder(
|
||||
decoder_type, shape, encoder_feature_dim, num_layers,
|
||||
decoder_type, obs_shape, encoder_feature_dim, num_layers,
|
||||
num_filters
|
||||
).to(device)
|
||||
self.decoder.apply(weight_init)
|
||||
@ -365,15 +329,6 @@ class SacAeAgent(object):
|
||||
target_Q) + F.mse_loss(current_Q2, target_Q)
|
||||
L.log('train_critic/loss', critic_loss, step)
|
||||
|
||||
# update dynamics (optional)
|
||||
if self.use_dynamics:
|
||||
h_obs = self.critic.encoder.outputs['mu']
|
||||
with torch.no_grad():
|
||||
next_latent = self.critic.encoder(next_obs)
|
||||
pred_next_latent = self.critic.forward_model(h_obs, action)
|
||||
dynamics_loss = F.mse_loss(pred_next_latent, next_latent)
|
||||
L.log('train_critic/dynamics_loss', dynamics_loss, step)
|
||||
critic_loss += dynamics_loss
|
||||
|
||||
# Optimize the critic
|
||||
self.critic_optimizer.zero_grad()
|
||||
@ -424,16 +379,7 @@ class SacAeAgent(object):
|
||||
# see https://arxiv.org/pdf/1903.12436.pdf
|
||||
latent_loss = (0.5 * h.pow(2).sum(1)).mean()
|
||||
|
||||
# add KL penalty for VAE
|
||||
if self.decoder_kl_lambda > 0.0:
|
||||
log_std = self.critic.encoder.outputs['log_std']
|
||||
mu = self.critic.encoder.outputs['mu']
|
||||
kl_div = -0.5 * (1 + 2 * log_std - mu.pow(2) - (2 * log_std).exp())
|
||||
kl_div = kl_div.sum(1).mean(0, True)
|
||||
else:
|
||||
kl_div = 0.0
|
||||
loss = rec_loss + self.decoder_latent_lambda * latent_loss + self.decoder_kl_lambda * kl_div
|
||||
|
||||
loss = rec_loss + self.decoder_latent_lambda * latent_loss
|
||||
self.encoder_optimizer.zero_grad()
|
||||
self.decoder_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
@ -445,7 +391,7 @@ class SacAeAgent(object):
|
||||
self.decoder.log(L, step, log_freq=LOG_FREQ)
|
||||
|
||||
def update(self, replay_buffer, L, step):
|
||||
obs, action, reward, next_obs, not_done, state = replay_buffer.sample()
|
||||
obs, action, reward, next_obs, not_done = replay_buffer.sample()
|
||||
|
||||
L.log('train/batch_reward', reward.mean(), step)
|
||||
|
||||
|
3
train.py
3
train.py
@ -95,11 +95,10 @@ def evaluate(env, agent, video, num_episodes, L, step):
|
||||
L.dump(step)
|
||||
|
||||
|
||||
def make_agent(obs_shape, state_shape, action_shape, args, device):
|
||||
def make_agent(obs_shape, action_shape, args, device):
|
||||
if args.agent == 'sac_ae':
|
||||
return SacAeAgent(
|
||||
obs_shape=obs_shape,
|
||||
state_shape=state_shape,
|
||||
action_shape=action_shape,
|
||||
device=device,
|
||||
hidden_dim=args.hidden_dim,
|
||||
|
Loading…
Reference in New Issue
Block a user