diff --git a/conda_env.yml b/conda_env.yml new file mode 100644 index 0000000..2356ac2 --- /dev/null +++ b/conda_env.yml @@ -0,0 +1,15 @@ +name: pytorch_sac_ae +channels: + - defaults +dependencies: + - python=3.6 + - pytorch + - torchvision + - cudatoolkit=9.2 + - pip: + - colored + - absl-py + - git+git://github.com/deepmind/dm_control.git + - git+git://github.com/1nadequacy/dmc2gym.git + - tb-nightly + - imageio \ No newline at end of file diff --git a/sac_ae.py b/sac_ae.py index 2dcfdb8..38b5f15 100644 --- a/sac_ae.py +++ b/sac_ae.py @@ -49,19 +49,17 @@ class Actor(nn.Module): """MLP actor network.""" def __init__( self, obs_shape, action_shape, hidden_dim, encoder_type, - encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters, - freeze_encoder, stochastic + encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters ): super().__init__() self.encoder = make_encoder( encoder_type, obs_shape, encoder_feature_dim, num_layers, - num_filters, stochastic + num_filters ) self.log_std_min = log_std_min self.log_std_max = log_std_max - self.freeze_encoder = freeze_encoder self.trunk = nn.Sequential( nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(), @@ -77,13 +75,10 @@ class Actor(nn.Module): ): obs = self.encoder(obs, detach=detach_encoder) - if self.freeze_encoder: - obs = obs.detach() - mu, log_std = self.trunk(obs).chunk(2, dim=-1) # constrain log_std inside [log_std_min, log_std_max] - log_std = F.tanh(log_std) + log_std = torch.tanh(log_std) log_std = self.log_std_min + 0.5 * ( self.log_std_max - self.log_std_min ) * (log_std + 1) @@ -138,44 +133,20 @@ class QFunction(nn.Module): return self.trunk(obs_action) -class DynamicsModel(nn.Module): - def __init__(self, state_dim, action_dim, hidden_dim): - super().__init__() - - self.trunk = nn.Sequential( - nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(), - nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), - nn.Linear(hidden_dim, state_dim) - ) - - def forward(self, state, action): - assert state.size(0) == action.size(0) - - state_action = torch.cat([state, action], dim=1) - return self.trunk(state_action) - - class Critic(nn.Module): """Critic network, employes two q-functions.""" def __init__( self, obs_shape, action_shape, hidden_dim, encoder_type, - encoder_feature_dim, num_layers, num_filters, freeze_encoder, - use_dynamics, stochastic + encoder_feature_dim, num_layers, num_filters ): super().__init__() - self.freeze_encoder = freeze_encoder self.encoder = make_encoder( encoder_type, obs_shape, encoder_feature_dim, num_layers, - num_filters, stochastic + num_filters ) - if use_dynamics: - self.forward_model = DynamicsModel( - self.encoder.feature_dim, action_shape[0], hidden_dim - ) - self.Q1 = QFunction( self.encoder.feature_dim, action_shape[0], hidden_dim ) @@ -190,9 +161,6 @@ class Critic(nn.Module): # detach_encoder allows to stop gradient propogation to encoder obs = self.encoder(obs, detach=detach_encoder) - if self.freeze_encoder: - obs = obs.detach() - q1 = self.Q1(obs, action) q2 = self.Q2(obs, action) @@ -220,7 +188,6 @@ class SacAeAgent(object): def __init__( self, obs_shape, - state_shape, action_shape, device, hidden_dim=256, @@ -261,19 +228,17 @@ class SacAeAgent(object): self.actor = Actor( obs_shape, action_shape, hidden_dim, encoder_type, encoder_feature_dim, actor_log_std_min, actor_log_std_max, - num_layers, num_filters, freeze_encoder, stochastic + num_layers, num_filters ).to(device) self.critic = Critic( obs_shape, action_shape, hidden_dim, encoder_type, - encoder_feature_dim, num_layers, num_filters, freeze_encoder, - use_dynamics, stochastic + encoder_feature_dim, num_layers, num_filters ).to(device) self.critic_target = Critic( obs_shape, action_shape, hidden_dim, encoder_type, - encoder_feature_dim, num_layers, num_filters, freeze_encoder, - use_dynamics, stochastic + encoder_feature_dim, num_layers, num_filters ).to(device) self.critic_target.load_state_dict(self.critic.state_dict()) @@ -289,9 +254,8 @@ class SacAeAgent(object): self.decoder = None if decoder_type != 'identity': # create decoder - shape = obs_shape if decoder_type == 'pixel' else state_shape self.decoder = make_decoder( - decoder_type, shape, encoder_feature_dim, num_layers, + decoder_type, obs_shape, encoder_feature_dim, num_layers, num_filters ).to(device) self.decoder.apply(weight_init) @@ -365,15 +329,6 @@ class SacAeAgent(object): target_Q) + F.mse_loss(current_Q2, target_Q) L.log('train_critic/loss', critic_loss, step) - # update dynamics (optional) - if self.use_dynamics: - h_obs = self.critic.encoder.outputs['mu'] - with torch.no_grad(): - next_latent = self.critic.encoder(next_obs) - pred_next_latent = self.critic.forward_model(h_obs, action) - dynamics_loss = F.mse_loss(pred_next_latent, next_latent) - L.log('train_critic/dynamics_loss', dynamics_loss, step) - critic_loss += dynamics_loss # Optimize the critic self.critic_optimizer.zero_grad() @@ -424,16 +379,7 @@ class SacAeAgent(object): # see https://arxiv.org/pdf/1903.12436.pdf latent_loss = (0.5 * h.pow(2).sum(1)).mean() - # add KL penalty for VAE - if self.decoder_kl_lambda > 0.0: - log_std = self.critic.encoder.outputs['log_std'] - mu = self.critic.encoder.outputs['mu'] - kl_div = -0.5 * (1 + 2 * log_std - mu.pow(2) - (2 * log_std).exp()) - kl_div = kl_div.sum(1).mean(0, True) - else: - kl_div = 0.0 - loss = rec_loss + self.decoder_latent_lambda * latent_loss + self.decoder_kl_lambda * kl_div - + loss = rec_loss + self.decoder_latent_lambda * latent_loss self.encoder_optimizer.zero_grad() self.decoder_optimizer.zero_grad() loss.backward() @@ -445,7 +391,7 @@ class SacAeAgent(object): self.decoder.log(L, step, log_freq=LOG_FREQ) def update(self, replay_buffer, L, step): - obs, action, reward, next_obs, not_done, state = replay_buffer.sample() + obs, action, reward, next_obs, not_done = replay_buffer.sample() L.log('train/batch_reward', reward.mean(), step) diff --git a/train.py b/train.py index bd03f7f..35b740e 100644 --- a/train.py +++ b/train.py @@ -95,11 +95,10 @@ def evaluate(env, agent, video, num_episodes, L, step): L.dump(step) -def make_agent(obs_shape, state_shape, action_shape, args, device): +def make_agent(obs_shape, action_shape, args, device): if args.agent == 'sac_ae': return SacAeAgent( obs_shape=obs_shape, - state_shape=state_shape, action_shape=action_shape, device=device, hidden_dim=args.hidden_dim,