updates
This commit is contained in:
parent
cb1aff9441
commit
6e6de8de9b
15
conda_env.yml
Normal file
15
conda_env.yml
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
name: pytorch_sac_ae
|
||||||
|
channels:
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- python=3.6
|
||||||
|
- pytorch
|
||||||
|
- torchvision
|
||||||
|
- cudatoolkit=9.2
|
||||||
|
- pip:
|
||||||
|
- colored
|
||||||
|
- absl-py
|
||||||
|
- git+git://github.com/deepmind/dm_control.git
|
||||||
|
- git+git://github.com/1nadequacy/dmc2gym.git
|
||||||
|
- tb-nightly
|
||||||
|
- imageio
|
76
sac_ae.py
76
sac_ae.py
@ -49,19 +49,17 @@ class Actor(nn.Module):
|
|||||||
"""MLP actor network."""
|
"""MLP actor network."""
|
||||||
def __init__(
|
def __init__(
|
||||||
self, obs_shape, action_shape, hidden_dim, encoder_type,
|
self, obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters,
|
encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters
|
||||||
freeze_encoder, stochastic
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.encoder = make_encoder(
|
self.encoder = make_encoder(
|
||||||
encoder_type, obs_shape, encoder_feature_dim, num_layers,
|
encoder_type, obs_shape, encoder_feature_dim, num_layers,
|
||||||
num_filters, stochastic
|
num_filters
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log_std_min = log_std_min
|
self.log_std_min = log_std_min
|
||||||
self.log_std_max = log_std_max
|
self.log_std_max = log_std_max
|
||||||
self.freeze_encoder = freeze_encoder
|
|
||||||
|
|
||||||
self.trunk = nn.Sequential(
|
self.trunk = nn.Sequential(
|
||||||
nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
|
nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
|
||||||
@ -77,13 +75,10 @@ class Actor(nn.Module):
|
|||||||
):
|
):
|
||||||
obs = self.encoder(obs, detach=detach_encoder)
|
obs = self.encoder(obs, detach=detach_encoder)
|
||||||
|
|
||||||
if self.freeze_encoder:
|
|
||||||
obs = obs.detach()
|
|
||||||
|
|
||||||
mu, log_std = self.trunk(obs).chunk(2, dim=-1)
|
mu, log_std = self.trunk(obs).chunk(2, dim=-1)
|
||||||
|
|
||||||
# constrain log_std inside [log_std_min, log_std_max]
|
# constrain log_std inside [log_std_min, log_std_max]
|
||||||
log_std = F.tanh(log_std)
|
log_std = torch.tanh(log_std)
|
||||||
log_std = self.log_std_min + 0.5 * (
|
log_std = self.log_std_min + 0.5 * (
|
||||||
self.log_std_max - self.log_std_min
|
self.log_std_max - self.log_std_min
|
||||||
) * (log_std + 1)
|
) * (log_std + 1)
|
||||||
@ -138,42 +133,18 @@ class QFunction(nn.Module):
|
|||||||
return self.trunk(obs_action)
|
return self.trunk(obs_action)
|
||||||
|
|
||||||
|
|
||||||
class DynamicsModel(nn.Module):
|
|
||||||
def __init__(self, state_dim, action_dim, hidden_dim):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.trunk = nn.Sequential(
|
|
||||||
nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
|
|
||||||
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
|
|
||||||
nn.Linear(hidden_dim, state_dim)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, state, action):
|
|
||||||
assert state.size(0) == action.size(0)
|
|
||||||
|
|
||||||
state_action = torch.cat([state, action], dim=1)
|
|
||||||
return self.trunk(state_action)
|
|
||||||
|
|
||||||
|
|
||||||
class Critic(nn.Module):
|
class Critic(nn.Module):
|
||||||
"""Critic network, employs two q-functions."""
|
"""Critic network, employs two q-functions."""
|
||||||
def __init__(
|
def __init__(
|
||||||
self, obs_shape, action_shape, hidden_dim, encoder_type,
|
self, obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
encoder_feature_dim, num_layers, num_filters, freeze_encoder,
|
encoder_feature_dim, num_layers, num_filters
|
||||||
use_dynamics, stochastic
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.freeze_encoder = freeze_encoder
|
|
||||||
|
|
||||||
self.encoder = make_encoder(
|
self.encoder = make_encoder(
|
||||||
encoder_type, obs_shape, encoder_feature_dim, num_layers,
|
encoder_type, obs_shape, encoder_feature_dim, num_layers,
|
||||||
num_filters, stochastic
|
num_filters
|
||||||
)
|
|
||||||
|
|
||||||
if use_dynamics:
|
|
||||||
self.forward_model = DynamicsModel(
|
|
||||||
self.encoder.feature_dim, action_shape[0], hidden_dim
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.Q1 = QFunction(
|
self.Q1 = QFunction(
|
||||||
@ -190,9 +161,6 @@ class Critic(nn.Module):
|
|||||||
# detach_encoder allows stopping gradient propagation to the encoder
|
# detach_encoder allows stopping gradient propagation to the encoder
|
||||||
obs = self.encoder(obs, detach=detach_encoder)
|
obs = self.encoder(obs, detach=detach_encoder)
|
||||||
|
|
||||||
if self.freeze_encoder:
|
|
||||||
obs = obs.detach()
|
|
||||||
|
|
||||||
q1 = self.Q1(obs, action)
|
q1 = self.Q1(obs, action)
|
||||||
q2 = self.Q2(obs, action)
|
q2 = self.Q2(obs, action)
|
||||||
|
|
||||||
@ -220,7 +188,6 @@ class SacAeAgent(object):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
obs_shape,
|
obs_shape,
|
||||||
state_shape,
|
|
||||||
action_shape,
|
action_shape,
|
||||||
device,
|
device,
|
||||||
hidden_dim=256,
|
hidden_dim=256,
|
||||||
@ -261,19 +228,17 @@ class SacAeAgent(object):
|
|||||||
self.actor = Actor(
|
self.actor = Actor(
|
||||||
obs_shape, action_shape, hidden_dim, encoder_type,
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
encoder_feature_dim, actor_log_std_min, actor_log_std_max,
|
encoder_feature_dim, actor_log_std_min, actor_log_std_max,
|
||||||
num_layers, num_filters, freeze_encoder, stochastic
|
num_layers, num_filters
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
self.critic = Critic(
|
self.critic = Critic(
|
||||||
obs_shape, action_shape, hidden_dim, encoder_type,
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
encoder_feature_dim, num_layers, num_filters, freeze_encoder,
|
encoder_feature_dim, num_layers, num_filters
|
||||||
use_dynamics, stochastic
|
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
self.critic_target = Critic(
|
self.critic_target = Critic(
|
||||||
obs_shape, action_shape, hidden_dim, encoder_type,
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
encoder_feature_dim, num_layers, num_filters, freeze_encoder,
|
encoder_feature_dim, num_layers, num_filters
|
||||||
use_dynamics, stochastic
|
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||||
@ -289,9 +254,8 @@ class SacAeAgent(object):
|
|||||||
self.decoder = None
|
self.decoder = None
|
||||||
if decoder_type != 'identity':
|
if decoder_type != 'identity':
|
||||||
# create decoder
|
# create decoder
|
||||||
shape = obs_shape if decoder_type == 'pixel' else state_shape
|
|
||||||
self.decoder = make_decoder(
|
self.decoder = make_decoder(
|
||||||
decoder_type, shape, encoder_feature_dim, num_layers,
|
decoder_type, obs_shape, encoder_feature_dim, num_layers,
|
||||||
num_filters
|
num_filters
|
||||||
).to(device)
|
).to(device)
|
||||||
self.decoder.apply(weight_init)
|
self.decoder.apply(weight_init)
|
||||||
@ -365,15 +329,6 @@ class SacAeAgent(object):
|
|||||||
target_Q) + F.mse_loss(current_Q2, target_Q)
|
target_Q) + F.mse_loss(current_Q2, target_Q)
|
||||||
L.log('train_critic/loss', critic_loss, step)
|
L.log('train_critic/loss', critic_loss, step)
|
||||||
|
|
||||||
# update dynamics (optional)
|
|
||||||
if self.use_dynamics:
|
|
||||||
h_obs = self.critic.encoder.outputs['mu']
|
|
||||||
with torch.no_grad():
|
|
||||||
next_latent = self.critic.encoder(next_obs)
|
|
||||||
pred_next_latent = self.critic.forward_model(h_obs, action)
|
|
||||||
dynamics_loss = F.mse_loss(pred_next_latent, next_latent)
|
|
||||||
L.log('train_critic/dynamics_loss', dynamics_loss, step)
|
|
||||||
critic_loss += dynamics_loss
|
|
||||||
|
|
||||||
# Optimize the critic
|
# Optimize the critic
|
||||||
self.critic_optimizer.zero_grad()
|
self.critic_optimizer.zero_grad()
|
||||||
@ -424,16 +379,7 @@ class SacAeAgent(object):
|
|||||||
# see https://arxiv.org/pdf/1903.12436.pdf
|
# see https://arxiv.org/pdf/1903.12436.pdf
|
||||||
latent_loss = (0.5 * h.pow(2).sum(1)).mean()
|
latent_loss = (0.5 * h.pow(2).sum(1)).mean()
|
||||||
|
|
||||||
# add KL penalty for VAE
|
loss = rec_loss + self.decoder_latent_lambda * latent_loss
|
||||||
if self.decoder_kl_lambda > 0.0:
|
|
||||||
log_std = self.critic.encoder.outputs['log_std']
|
|
||||||
mu = self.critic.encoder.outputs['mu']
|
|
||||||
kl_div = -0.5 * (1 + 2 * log_std - mu.pow(2) - (2 * log_std).exp())
|
|
||||||
kl_div = kl_div.sum(1).mean(0, True)
|
|
||||||
else:
|
|
||||||
kl_div = 0.0
|
|
||||||
loss = rec_loss + self.decoder_latent_lambda * latent_loss + self.decoder_kl_lambda * kl_div
|
|
||||||
|
|
||||||
self.encoder_optimizer.zero_grad()
|
self.encoder_optimizer.zero_grad()
|
||||||
self.decoder_optimizer.zero_grad()
|
self.decoder_optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
@ -445,7 +391,7 @@ class SacAeAgent(object):
|
|||||||
self.decoder.log(L, step, log_freq=LOG_FREQ)
|
self.decoder.log(L, step, log_freq=LOG_FREQ)
|
||||||
|
|
||||||
def update(self, replay_buffer, L, step):
|
def update(self, replay_buffer, L, step):
|
||||||
obs, action, reward, next_obs, not_done, state = replay_buffer.sample()
|
obs, action, reward, next_obs, not_done = replay_buffer.sample()
|
||||||
|
|
||||||
L.log('train/batch_reward', reward.mean(), step)
|
L.log('train/batch_reward', reward.mean(), step)
|
||||||
|
|
||||||
|
3
train.py
3
train.py
@ -95,11 +95,10 @@ def evaluate(env, agent, video, num_episodes, L, step):
|
|||||||
L.dump(step)
|
L.dump(step)
|
||||||
|
|
||||||
|
|
||||||
def make_agent(obs_shape, state_shape, action_shape, args, device):
|
def make_agent(obs_shape, action_shape, args, device):
|
||||||
if args.agent == 'sac_ae':
|
if args.agent == 'sac_ae':
|
||||||
return SacAeAgent(
|
return SacAeAgent(
|
||||||
obs_shape=obs_shape,
|
obs_shape=obs_shape,
|
||||||
state_shape=state_shape,
|
|
||||||
action_shape=action_shape,
|
action_shape=action_shape,
|
||||||
device=device,
|
device=device,
|
||||||
hidden_dim=args.hidden_dim,
|
hidden_dim=args.hidden_dim,
|
||||||
|
Loading…
Reference in New Issue
Block a user