From f2e39f7a683a4b1491120ab733db0e5abf97d3f2 Mon Sep 17 00:00:00 2001
From: Denis Yarats
Date: Mon, 23 Sep 2019 11:38:55 -0700
Subject: [PATCH] changes

---
 README.md           |  18 ++-
 ddpg.py             | 209 -----------------------------------
 decoder.py          |  40 +------
 encoder.py          |  80 ++------------
 logger.py           |  24 ++--
 run.sh              |  21 ----
 sac.py => sac_ae.py |  26 ++---
 td3.py              | 259 --------------------------------------------
 train.py            |  87 ++-------------
 utils.py            |  20 +---
 10 files changed, 55 insertions(+), 729 deletions(-)
 delete mode 100644 ddpg.py
 delete mode 100755 run.sh
 rename sac.py => sac_ae.py (96%)
 delete mode 100644 td3.py

diff --git a/README.md b/README.md
index 5573476..b859b41 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,20 @@
-# Soft Actor-Critic implementaiton in PyTorch
+# SAC+AE implementation in PyTorch
 
+## Requirements
 
-## Running locally
-To train SAC locally one can use provided `run_local.sh` script (change it to modify particular arguments):
+## Instructions
+To train an SAC+AE agent on the `cheetah run` task from image-based observations run:
 ```
-./run_local.sh
+python train.py \
+    --domain_name cheetah \
+    --task_name run \
+    --encoder_type pixel \
+    --decoder_type pixel \
+    --action_repeat 4 \
+    --save_video \
+    --save_tb \
+    --work_dir ./runs/cheetah_run/sac_ae \
+    --seed 1
 ```
 This will produce a folder (`./save`) by default, where all the output is going to be stored including train/eval logs, tensorboard blobs, evaluation videos, and model snapshots. It is possible to attach tensorboard to a particular run using the following command:
 ```
diff --git a/ddpg.py b/ddpg.py
deleted file mode 100644
index 2b03972..0000000
--- a/ddpg.py
+++ /dev/null
@@ -1,209 +0,0 @@
-# Code is taken from https://github.com/sfujim/TD3 with slight modifications
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import utils
-from encoder import make_encoder
-
-LOG_FREQ = 10000
-
-
-class Actor(nn.Module):
-    def __init__(
-        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
-    ):
-        super().__init__()
-
-        self.encoder = make_encoder(
-            encoder_type, obs_shape, encoder_feature_dim
-        )
-
-        self.l1 = nn.Linear(self.encoder.feature_dim, 400)
-        self.l2 = nn.Linear(400, 300)
-        self.l3 = nn.Linear(300, action_shape[0])
-
-        self.outputs = dict()
-
-    def forward(self, obs, detach_encoder=False):
-        obs = self.encoder(obs, detach=detach_encoder)
-        h = F.relu(self.l1(obs))
-        h = F.relu(self.l2(h))
-        action = torch.tanh(self.l3(h))
-        self.outputs['mu'] = action
-        return action
-
-    def log(self, L, step, log_freq=LOG_FREQ):
-        if step % log_freq != 0:
-            return
-
-        for k, v in self.outputs.items():
-            L.log_histogram('train_actor/%s_hist' % k, v, step)
-
-        L.log_param('train_actor/fc1', self.l1, step)
-        L.log_param('train_actor/fc2', self.l2, step)
-        L.log_param('train_actor/fc3', self.l3, step)
-
-
-class Critic(nn.Module):
-    def __init__(
-        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
-    ):
-        super().__init__()
-
-        self.encoder = make_encoder(
-            encoder_type, obs_shape, encoder_feature_dim
-        )
-
-        self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
-        self.l2 = nn.Linear(400, 300)
-        self.l3 = nn.Linear(300, 1)
-
-        self.outputs = dict()
-
-    def forward(self, obs, action, detach_encoder=False):
-        obs = self.encoder(obs, detach=detach_encoder)
-        obs_action = torch.cat([obs, action], dim=1)
-        h = F.relu(self.l1(obs_action))
-        h = F.relu(self.l2(h))
-        q = self.l3(h)
-        self.outputs['q'] = q
-        return q
-
-    def log(self, L, step,
log_freq=LOG_FREQ): - if step % log_freq != 0: - return - - self.encoder.log(L, step, log_freq) - - for k, v in self.outputs.items(): - L.log_histogram('train_critic/%s_hist' % k, v, step) - - L.log_param('train_critic/fc1', self.l1, step) - L.log_param('train_critic/fc2', self.l2, step) - L.log_param('train_critic/fc3', self.l3, step) - - -class DDPGAgent(object): - def __init__( - self, - obs_shape, - action_shape, - device, - discount=0.99, - tau=0.005, - actor_lr=1e-3, - critic_lr=1e-3, - encoder_type='identity', - encoder_feature_dim=50 - ): - self.device = device - self.discount = discount - self.tau = tau - - # models - self.actor = Actor( - obs_shape, action_shape, encoder_type, encoder_feature_dim - ).to(device) - - self.critic = Critic( - obs_shape, action_shape, encoder_type, encoder_feature_dim - ).to(device) - - self.actor.encoder.copy_conv_weights_from(self.critic.encoder) - - self.actor_target = Actor( - obs_shape, action_shape, encoder_type, encoder_feature_dim - ).to(device) - self.actor_target.load_state_dict(self.actor.state_dict()) - - self.critic_target = Critic( - obs_shape, action_shape, encoder_type, encoder_feature_dim - ).to(device) - self.critic_target.load_state_dict(self.critic.state_dict()) - - # optimizers - self.actor_optimizer = torch.optim.Adam( - self.actor.parameters(), lr=actor_lr - ) - - self.critic_optimizer = torch.optim.Adam( - self.critic.parameters(), lr=critic_lr - ) - - self.train() - self.critic_target.train() - self.actor_target.train() - - def train(self, training=True): - self.training = training - self.actor.train(training) - self.critic.train(training) - - def select_action(self, obs): - with torch.no_grad(): - obs = torch.FloatTensor(obs).to(self.device) - obs = obs.unsqueeze(0) - action = self.actor(obs) - return action.cpu().data.numpy().flatten() - - def sample_action(self, obs): - return self.select_action(obs) - - def update_critic(self, obs, action, reward, next_obs, not_done, L, step): - with torch.no_grad(): - target_Q = self.critic_target( - next_obs, self.actor_target(next_obs) - ) - target_Q = reward + (not_done * self.discount * target_Q) - - current_Q = self.critic(obs, action) - - critic_loss = F.mse_loss(current_Q, target_Q) - L.log('train_critic/loss', critic_loss, step) - - self.critic_optimizer.zero_grad() - critic_loss.backward() - self.critic_optimizer.step() - - self.critic.log(L, step) - - def update_actor(self, obs, L, step): - action = self.actor(obs, detach_encoder=True) - actor_Q = self.critic(obs, action, detach_encoder=True) - actor_loss = -actor_Q.mean() - - self.actor_optimizer.zero_grad() - actor_loss.backward() - self.actor_optimizer.step() - - self.actor.log(L, step) - - def update(self, replay_buffer, L, step): - obs, action, reward, next_obs, not_done = replay_buffer.sample() - - L.log('train/batch_reward', reward.mean(), step) - - self.update_critic(obs, action, reward, next_obs, not_done, L, step) - self.update_actor(obs, L, step) - - utils.soft_update_params(self.critic, self.critic_target, self.tau) - utils.soft_update_params(self.actor, self.actor_target, self.tau) - - def save(self, model_dir, step): - torch.save( - self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step) - ) - torch.save( - self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step) - ) - - def load(self, model_dir, step): - self.actor.load_state_dict( - torch.load('%s/actor_%s.pt' % (model_dir, step)) - ) - self.critic.load_state_dict( - torch.load('%s/critic_%s.pt' % (model_dir, step)) - ) diff --git a/decoder.py 
b/decoder.py index 3fbb4a7..7e501c8 100644 --- a/decoder.py +++ b/decoder.py @@ -62,45 +62,13 @@ class PixelDecoder(nn.Module): L.log_param('train_decoder/fc', self.fc, step) -class StateDecoder(nn.Module): - def __init__(self, obs_shape, feature_dim): - super().__init__() - - assert len(obs_shape) == 1 - - self.trunk = nn.Sequential( - nn.Linear(feature_dim, 1024), nn.ReLU(), nn.Linear(1024, 1024), - nn.ReLU(), nn.Linear(1024, obs_shape[0]), nn.ReLU() - ) - - self.outputs = dict() - - def forward(self, obs, detach=False): - h = self.trunk(obs) - if detach: - h = h.detach() - self.outputs['h'] = h - return h - - def log(self, L, step, log_freq): - if step % log_freq != 0: - return - - L.log_param('train_encoder/fc1', self.trunk[0], step) - L.log_param('train_encoder/fc2', self.trunk[2], step) - for k, v in self.outputs.items(): - L.log_histogram('train_encoder/%s_hist' % k, v, step) - - -_AVAILABLE_DECODERS = {'pixel': PixelDecoder, 'state': StateDecoder} +_AVAILABLE_DECODERS = {'pixel': PixelDecoder} def make_decoder( decoder_type, obs_shape, feature_dim, num_layers, num_filters ): assert decoder_type in _AVAILABLE_DECODERS - if decoder_type == 'pixel': - return _AVAILABLE_DECODERS[decoder_type]( - obs_shape, feature_dim, num_layers, num_filters - ) - return _AVAILABLE_DECODERS[decoder_type](obs_shape, feature_dim) + return _AVAILABLE_DECODERS[decoder_type]( + obs_shape, feature_dim, num_layers, num_filters + ) diff --git a/encoder.py b/encoder.py index 0f2e581..b137a62 100644 --- a/encoder.py +++ b/encoder.py @@ -13,21 +13,13 @@ OUT_DIM = {2: 39, 4: 35, 6: 31} class PixelEncoder(nn.Module): """Convolutional encoder of pixels observations.""" - def __init__( - self, - obs_shape, - feature_dim, - num_layers=2, - num_filters=32, - stochastic=False - ): + def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32): super().__init__() assert len(obs_shape) == 3 self.feature_dim = feature_dim self.num_layers = num_layers - self.stochastic = stochastic self.convs = nn.ModuleList( [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)] @@ -39,13 +31,6 @@ class PixelEncoder(nn.Module): self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim) self.ln = nn.LayerNorm(self.feature_dim) - if self.stochastic: - self.log_std_min = -10 - self.log_std_max = 2 - self.fc_log_std = nn.Linear( - num_filters * out_dim * out_dim, self.feature_dim - ) - self.outputs = dict() def reparameterize(self, mu, logstd): @@ -80,17 +65,6 @@ class PixelEncoder(nn.Module): self.outputs['ln'] = h_norm out = torch.tanh(h_norm) - - if self.stochastic: - self.outputs['mu'] = out - log_std = torch.tanh(self.fc_log_std(h)) - # normalize - log_std = self.log_std_min + 0.5 * ( - self.log_std_max - self.log_std_min - ) * (log_std + 1) - out = self.reparameterize(out, log_std) - self.outputs['log_std'] = log_std - self.outputs['tanh'] = out return out @@ -116,42 +90,8 @@ class PixelEncoder(nn.Module): L.log_param('train_encoder/ln', self.ln, step) -class StateEncoder(nn.Module): - def __init__(self, obs_shape, feature_dim): - super().__init__() - - assert len(obs_shape) == 1 - self.feature_dim = feature_dim - - self.trunk = nn.Sequential( - nn.Linear(obs_shape[0], 256), nn.ReLU(), - nn.Linear(256, feature_dim), nn.ReLU() - ) - - self.outputs = dict() - - def forward(self, obs, detach=False): - h = self.trunk(obs) - if detach: - h = h.detach() - self.outputs['h'] = h - return h - - def copy_conv_weights_from(self, source): - pass - - def log(self, L, step, log_freq): - if step % log_freq != 0: - return - - 
L.log_param('train_encoder/fc1', self.trunk[0], step) - L.log_param('train_encoder/fc2', self.trunk[2], step) - for k, v in self.outputs.items(): - L.log_histogram('train_encoder/%s_hist' % k, v, step) - - class IdentityEncoder(nn.Module): - def __init__(self, obs_shape, feature_dim): + def __init__(self, obs_shape, feature_dim, num_layers, num_filters): super().__init__() assert len(obs_shape) == 1 @@ -167,19 +107,13 @@ class IdentityEncoder(nn.Module): pass -_AVAILABLE_ENCODERS = { - 'pixel': PixelEncoder, - 'state': StateEncoder, - 'identity': IdentityEncoder -} +_AVAILABLE_ENCODERS = {'pixel': PixelEncoder, 'identity': IdentityEncoder} def make_encoder( - encoder_type, obs_shape, feature_dim, num_layers, num_filters, stochastic + encoder_type, obs_shape, feature_dim, num_layers, num_filters ): assert encoder_type in _AVAILABLE_ENCODERS - if encoder_type == 'pixel': - return _AVAILABLE_ENCODERS[encoder_type]( - obs_shape, feature_dim, num_layers, num_filters, stochastic - ) - return _AVAILABLE_ENCODERS[encoder_type](obs_shape, feature_dim) + return _AVAILABLE_ENCODERS[encoder_type]( + obs_shape, feature_dim, num_layers, num_filters + ) diff --git a/logger.py b/logger.py index 93ff12e..8e31fd4 100644 --- a/logger.py +++ b/logger.py @@ -8,19 +8,15 @@ import torchvision import numpy as np from termcolor import colored - FORMAT_CONFIG = { 'rl': { - 'train': [('episode', 'E', 'int'), - ('step', 'S', 'int'), - ('duration', 'D', 'time'), - ('episode_reward', 'R', 'float'), - ('batch_reward', 'BR', 'float'), - ('actor_loss', 'ALOSS', 'float'), - ('critic_loss', 'CLOSS', 'float'), - ('ae_loss', 'RLOSS', 'float')], - 'eval': [('step', 'S', 'int'), - ('episode_reward', 'ER', 'float')] + 'train': [ + ('episode', 'E', 'int'), ('step', 'S', 'int'), + ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'), + ('batch_reward', 'BR', 'float'), ('actor_loss', 'ALOSS', 'float'), + ('critic_loss', 'CLOSS', 'float'), ('ae_loss', 'RLOSS', 'float') + ], + 'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float')] } } @@ -106,10 +102,12 @@ class Logger(object): self._sw = None self._train_mg = MetersGroup( os.path.join(log_dir, 'train.log'), - formating=FORMAT_CONFIG[config]['train']) + formating=FORMAT_CONFIG[config]['train'] + ) self._eval_mg = MetersGroup( os.path.join(log_dir, 'eval.log'), - formating=FORMAT_CONFIG[config]['eval']) + formating=FORMAT_CONFIG[config]['eval'] + ) def _try_sw_log(self, key, value, step): if self._sw is not None: diff --git a/run.sh b/run.sh deleted file mode 100755 index 78d2e41..0000000 --- a/run.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -DOMAIN=cheetah -TASK=run -ACTION_REPEAT=4 -ENCODER_TYPE=pixel -ENCODER_TYPE=pixel - - -WORK_DIR=./runs - -python train.py \ - --domain_name ${DOMAIN} \ - --task_name ${TASK} \ - --encoder_type ${ENCODER_TYPE} \ - --decoder_type ${DECODER_TYPE} \ - --action_repeat ${ACTION_REPEAT} \ - --save_video \ - --save_tb \ - --work_dir ${WORK_DIR}/${DOMAIN}_{TASK}/_ae_encoder_${ENCODER_TYPE}_decoder_{ENCODER_TYPE} \ - --seed 1 diff --git a/sac.py b/sac_ae.py similarity index 96% rename from sac.py rename to sac_ae.py index d1a18f6..2dcfdb8 100644 --- a/sac.py +++ b/sac_ae.py @@ -215,8 +215,8 @@ class Critic(nn.Module): L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step) -class SACAgent(object): - """Soft Actor-Critic algorithm.""" +class SacAeAgent(object): + """SAC+AE algorithm.""" def __init__( self, obs_shape, @@ -237,20 +237,17 @@ class SACAgent(object): critic_beta=0.9, critic_tau=0.005, 
         critic_target_update_freq=2,
-        encoder_type='identity',
+        encoder_type='pixel',
         encoder_feature_dim=50,
         encoder_lr=1e-3,
         encoder_tau=0.005,
-        decoder_type='identity',
+        decoder_type='pixel',
         decoder_lr=1e-3,
         decoder_update_freq=1,
         decoder_latent_lambda=0.0,
         decoder_weight_lambda=0.0,
-        decoder_kl_lambda=0.0,
         num_layers=4,
-        num_filters=32,
-        freeze_encoder=False,
-        use_dynamics=False
+        num_filters=32
     ):
         self.device = device
         self.discount = discount
@@ -260,11 +257,6 @@ class SACAgent(object):
         self.critic_target_update_freq = critic_target_update_freq
         self.decoder_update_freq = decoder_update_freq
         self.decoder_latent_lambda = decoder_latent_lambda
-        self.decoder_kl_lambda = decoder_kl_lambda
-        self.decoder_type = decoder_type
-        self.use_dynamics = use_dynamics
-
-        stochastic = decoder_kl_lambda > 0.0
 
         self.actor = Actor(
             obs_shape, action_shape, hidden_dim, encoder_type,
@@ -420,9 +412,6 @@ class SACAgent(object):
         self.log_alpha_optimizer.step()
 
     def update_decoder(self, obs, target_obs, L, step):
-        if self.decoder is None:
-            return
-
         h = self.critic.encoder(obs)
 
         if target_obs.dim() == 4:
@@ -477,9 +466,8 @@ class SACAgent(object):
                 self.encoder_tau
             )
 
-        if step % self.decoder_update_freq == 0:
-            target = obs if self.decoder_type == 'pixel' else state
-            self.update_decoder(obs, target, L, step)
+        if self.decoder is not None and step % self.decoder_update_freq == 0:
+            self.update_decoder(obs, obs, L, step)
 
     def save(self, model_dir, step):
         torch.save(
diff --git a/td3.py b/td3.py
deleted file mode 100644
index db6aaf5..0000000
--- a/td3.py
+++ /dev/null
@@ -1,259 +0,0 @@
-# Code is taken from https://github.com/sfujim/TD3 with slight modifications
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import utils
-from encoder import make_encoder
-
-LOG_FREQ = 10000
-
-
-class Actor(nn.Module):
-    def __init__(
-        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
-    ):
-        super().__init__()
-
-        self.encoder = make_encoder(
-            encoder_type, obs_shape, encoder_feature_dim
-        )
-
-        self.l1 = nn.Linear(self.encoder.feature_dim, 400)
-        self.l2 = nn.Linear(400, 300)
-        self.l3 = nn.Linear(300, action_shape[0])
-
-        self.outputs = dict()
-
-    def forward(self, obs, detach_encoder=False):
-        obs = self.encoder(obs, detach=detach_encoder)
-        h = F.relu(self.l1(obs))
-        h = F.relu(self.l2(h))
-        action = torch.tanh(self.l3(h))
-        self.outputs['mu'] = action
-        return action
-
-    def log(self, L, step, log_freq=LOG_FREQ):
-        if step % log_freq != 0:
-            return
-
-        for k, v in self.outputs.items():
-            L.log_histogram('train_actor/%s_hist' % k, v, step)
-
-        L.log_param('train_actor/fc1', self.l1, step)
-        L.log_param('train_actor/fc2', self.l2, step)
-        L.log_param('train_actor/fc3', self.l3, step)
-
-
-class Critic(nn.Module):
-    def __init__(
-        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
-    ):
-        super().__init__()
-
-        self.encoder = make_encoder(
-            encoder_type, obs_shape, encoder_feature_dim
-        )
-
-        # Q1 architecture
-        self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
-        self.l2 = nn.Linear(400, 300)
-        self.l3 = nn.Linear(300, 1)
-
-        # Q2 architecture
-        self.l4 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
-        self.l5 = nn.Linear(400, 300)
-        self.l6 = nn.Linear(300, 1)
-
-        self.outputs = dict()
-
-    def forward(self, obs, action, detach_encoder=False):
-        obs = self.encoder(obs, detach=detach_encoder)
-
-        obs_action = torch.cat([obs, action], 1)
-
-        h1 = F.relu(self.l1(obs_action))
-        h1 = F.relu(self.l2(h1))
-        q1 = self.l3(h1)
-
- h2 = F.relu(self.l4(obs_action)) - h2 = F.relu(self.l5(h2)) - q2 = self.l6(h2) - - self.outputs['q1'] = q1 - self.outputs['q2'] = q2 - - return q1, q2 - - def Q1(self, obs, action, detach_encoder=False): - obs = self.encoder(obs, detach=detach_encoder) - - obs_action = torch.cat([obs, action], 1) - - h1 = F.relu(self.l1(obs_action)) - h1 = F.relu(self.l2(h1)) - q1 = self.l3(h1) - return q1 - - def log(self, L, step, log_freq=LOG_FREQ): - if step % log_freq != 0: return - - self.encoder.log(L, step, log_freq) - - for k, v in self.outputs.items(): - L.log_histogram('train_critic/%s_hist' % k, v, step) - - L.log_param('train_critic/q1_fc1', self.l1, step) - L.log_param('train_critic/q1_fc2', self.l2, step) - L.log_param('train_critic/q1_fc3', self.l3, step) - L.log_param('train_critic/q1_fc4', self.l4, step) - L.log_param('train_critic/q1_fc5', self.l5, step) - L.log_param('train_critic/q1_fc6', self.l6, step) - - -class TD3Agent(object): - def __init__( - self, - obs_shape, - action_shape, - device, - discount=0.99, - tau=0.005, - policy_noise=0.2, - noise_clip=0.5, - expl_noise=0.1, - actor_lr=1e-3, - critic_lr=1e-3, - encoder_type='identity', - encoder_feature_dim=50, - actor_update_freq=2, - target_update_freq=2, - ): - self.device = device - self.discount = discount - self.tau = tau - self.policy_noise = policy_noise - self.noise_clip = noise_clip - self.expl_noise = expl_noise - self.actor_update_freq = actor_update_freq - self.target_update_freq = target_update_freq - - # models - self.actor = Actor( - obs_shape, action_shape, encoder_type, encoder_feature_dim - ).to(device) - - self.critic = Critic( - obs_shape, action_shape, encoder_type, encoder_feature_dim - ).to(device) - - self.actor.encoder.copy_conv_weights_from(self.critic.encoder) - - self.actor_target = Actor( - obs_shape, action_shape, encoder_type, encoder_feature_dim - ).to(device) - self.actor_target.load_state_dict(self.actor.state_dict()) - - self.critic_target = Critic( - obs_shape, action_shape, encoder_type, encoder_feature_dim - ).to(device) - self.critic_target.load_state_dict(self.critic.state_dict()) - - # optimizers - self.actor_optimizer = torch.optim.Adam( - self.actor.parameters(), lr=actor_lr - ) - - self.critic_optimizer = torch.optim.Adam( - self.critic.parameters(), lr=critic_lr - ) - - self.train() - self.critic_target.train() - self.actor_target.train() - - def train(self, training=True): - self.training = training - self.actor.train(training) - self.critic.train(training) - - def select_action(self, obs): - with torch.no_grad(): - obs = torch.FloatTensor(obs).to(self.device) - obs = obs.unsqueeze(0) - action = self.actor(obs) - return action.cpu().data.numpy().flatten() - - def sample_action(self, obs): - with torch.no_grad(): - obs = torch.FloatTensor(obs).to(self.device) - obs = obs.unsqueeze(0) - action = self.actor(obs) - noise = torch.randn_like(action) * self.expl_noise - action = (action + noise).clamp(-1.0, 1.0) - return action.cpu().data.numpy().flatten() - - def update_critic(self, obs, action, reward, next_obs, not_done, L, step): - with torch.no_grad(): - noise = torch.randn_like(action).to(self.device) * self.policy_noise - noise = noise.clamp(-self.noise_clip, self.noise_clip) - next_action = self.actor_target(next_obs) + noise - next_action = next_action.clamp(-1.0, 1.0) - target_Q1, target_Q2 = self.critic_target(next_obs, next_action) - target_Q = torch.min(target_Q1, target_Q2) - target_Q = reward + (not_done * self.discount * target_Q) - - current_Q1, current_Q2 = self.critic(obs, 
action) - - critic_loss = F.mse_loss(current_Q1, - target_Q) + F.mse_loss(current_Q2, target_Q) - L.log('train_critic/loss', critic_loss, step) - - self.critic_optimizer.zero_grad() - critic_loss.backward() - self.critic_optimizer.step() - - self.critic.log(L, step) - - def update_actor(self, obs, L, step): - action = self.actor(obs, detach_encoder=True) - actor_Q = self.critic.Q1(obs, action, detach_encoder=True) - actor_loss = -actor_Q.mean() - - self.actor_optimizer.zero_grad() - actor_loss.backward() - self.actor_optimizer.step() - - self.actor.log(L, step) - - def update(self, replay_buffer, L, step): - obs, action, reward, next_obs, not_done = replay_buffer.sample() - - L.log('train/batch_reward', reward.mean(), step) - - self.update_critic(obs, action, reward, next_obs, not_done, L, step) - - if step % self.actor_update_freq == 0: - self.update_actor(obs, L, step) - - if step % self.target_update_freq == 0: - utils.soft_update_params(self.critic, self.critic_target, self.tau) - utils.soft_update_params(self.actor, self.actor_target, self.tau) - - def save(self, model_dir, step): - torch.save( - self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step) - ) - torch.save( - self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step) - ) - - def load(self, model_dir, step): - self.actor.load_state_dict( - torch.load('%s/actor_%s.pt' % (model_dir, step)) - ) - self.critic.load_state_dict( - torch.load('%s/critic_%s.pt' % (model_dir, step)) - ) \ No newline at end of file diff --git a/train.py b/train.py index 2af856d..bd03f7f 100644 --- a/train.py +++ b/train.py @@ -15,9 +15,7 @@ import utils from logger import Logger from video import VideoRecorder -from sac import SACAgent -from td3 import TD3Agent -from ddpg import DDPGAgent +from sac_ae import SacAeAgent def parse_args(): @@ -31,7 +29,7 @@ def parse_args(): # replay buffer parser.add_argument('--replay_buffer_capacity', default=1000000, type=int) # train - parser.add_argument('--agent', default='sac', type=str) + parser.add_argument('--agent', default='sac_ae', type=str) parser.add_argument('--init_steps', default=1000, type=int) parser.add_argument('--num_train_steps', default=1000000, type=int) parser.add_argument('--batch_size', default=512, type=int) @@ -51,30 +49,22 @@ def parse_args(): parser.add_argument('--actor_log_std_max', default=2, type=float) parser.add_argument('--actor_update_freq', default=2, type=int) # encoder/decoder - parser.add_argument('--encoder_type', default='identity', type=str) + parser.add_argument('--encoder_type', default='pixel', type=str) parser.add_argument('--encoder_feature_dim', default=50, type=int) parser.add_argument('--encoder_lr', default=1e-3, type=float) parser.add_argument('--encoder_tau', default=0.005, type=float) - parser.add_argument('--decoder_type', default='identity', type=str) + parser.add_argument('--decoder_type', default='pixel', type=str) parser.add_argument('--decoder_lr', default=1e-3, type=float) parser.add_argument('--decoder_update_freq', default=1, type=int) parser.add_argument('--decoder_latent_lambda', default=0.0, type=float) parser.add_argument('--decoder_weight_lambda', default=0.0, type=float) - parser.add_argument('--decoder_kl_lambda', default=0.0, type=float) parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_filters', default=32, type=int) - parser.add_argument('--freeze_encoder', default=False, action='store_true') - parser.add_argument('--use_dynamics', default=False, action='store_true') # sac 
parser.add_argument('--discount', default=0.99, type=float) parser.add_argument('--init_temperature', default=0.01, type=float) parser.add_argument('--alpha_lr', default=1e-3, type=float) parser.add_argument('--alpha_beta', default=0.9, type=float) - # td3 - parser.add_argument('--policy_noise', default=0.2, type=float) - parser.add_argument('--expl_noise', default=0.1, type=float) - parser.add_argument('--noise_clip', default=0.5, type=float) - parser.add_argument('--tau', default=0.005, type=float) # misc parser.add_argument('--seed', default=1, type=int) parser.add_argument('--work_dir', default='.', type=str) @@ -82,8 +72,6 @@ def parse_args(): parser.add_argument('--save_model', default=False, action='store_true') parser.add_argument('--save_buffer', default=False, action='store_true') parser.add_argument('--save_video', default=False, action='store_true') - parser.add_argument('--pretrained_info', default=None, type=str) - parser.add_argument('--pretrained_decoder', default=False, action='store_true') args = parser.parse_args() return args @@ -108,8 +96,8 @@ def evaluate(env, agent, video, num_episodes, L, step): def make_agent(obs_shape, state_shape, action_shape, args, device): - if args.agent == 'sac': - return SACAgent( + if args.agent == 'sac_ae': + return SacAeAgent( obs_shape=obs_shape, state_shape=state_shape, action_shape=action_shape, @@ -137,63 +125,13 @@ def make_agent(obs_shape, state_shape, action_shape, args, device): decoder_update_freq=args.decoder_update_freq, decoder_latent_lambda=args.decoder_latent_lambda, decoder_weight_lambda=args.decoder_weight_lambda, - decoder_kl_lambda=args.decoder_kl_lambda, num_layers=args.num_layers, - num_filters=args.num_filters, - freeze_encoder=args.freeze_encoder, - use_dynamics=args.use_dynamics - ) - elif args.agent == 'td3': - return TD3Agent( - obs_shape=obs_shape, - action_shape=action_shape, - device=device, - discount=args.discount, - tau=args.tau, - policy_noise=args.policy_noise, - noise_clip=args.noise_clip, - expl_noise=args.expl_noise, - actor_lr=args.actor_lr, - critic_lr=args.critic_lr, - encoder_type=args.encoder_type, - encoder_feature_dim=args.encoder_feature_dim, - actor_update_freq=args.actor_update_freq, - target_update_freq=args.critic_target_update_freq - ) - elif args.agent == 'ddpg': - return DDPGAgent( - obs_shape=obs_shape, - action_shape=action_shape, - device=device, - discount=args.discount, - tau=args.tau, - actor_lr=args.actor_lr, - critic_lr=args.critic_lr, - encoder_type=args.encoder_type, - encoder_feature_dim=args.encoder_feature_dim + num_filters=args.num_filters ) else: assert 'agent is not supported: %s' % args.agent -def load_pretrained_encoder(agent, pretrained_info, pretrained_decoder): - path, version = pretrained_info.split(':') - - pretrained_agent = copy.deepcopy(agent) - pretrained_agent.load(path, int(version)) - agent.critic.encoder.load_state_dict( - pretrained_agent.critic.encoder.state_dict() - ) - agent.actor.encoder.load_state_dict( - pretrained_agent.actor.encoder.state_dict() - ) - - if pretrained_decoder: - agent.decoder.load_state_dict(pretrained_agent.decoder.state_dict()) - - return agent - - def main(): args = parse_args() utils.set_seed_everywhere(args.seed) @@ -232,7 +170,6 @@ def main(): replay_buffer = utils.ReplayBuffer( obs_shape=env.observation_space.shape, - state_shape=env.state_space.shape, action_shape=env.action_space.shape, capacity=args.replay_buffer_capacity, batch_size=args.batch_size, @@ -241,17 +178,11 @@ def main(): agent = make_agent( 
obs_shape=env.observation_space.shape, - state_shape=env.state_space.shape, action_shape=env.action_space.shape, args=args, device=device ) - if args.pretrained_info is not None: - agent = load_pretrained_encoder( - agent, args.pretrained_info, args.pretrained_decoder - ) - L = Logger(args.work_dir, use_tb=args.save_tb) episode, episode_reward, done = 0, 0, True @@ -295,9 +226,7 @@ def main(): for _ in range(num_updates): agent.update(replay_buffer, L, step) - state = env.env.env._current_state next_obs, reward, done, _ = env.step(action) - next_state = env.env.env._current_state.shape # allow infinit bootstrap done_bool = 0 if episode_step + 1 == env._max_episode_steps else float( @@ -305,7 +234,7 @@ def main(): ) episode_reward += reward - replay_buffer.add(obs, action, reward, next_obs, done_bool, state, next_state) + replay_buffer.add(obs, action, reward, next_obs, done_bool) obs = next_obs episode_step += 1 diff --git a/utils.py b/utils.py index 4f0c6c1..067715c 100644 --- a/utils.py +++ b/utils.py @@ -67,10 +67,7 @@ def preprocess_obs(obs, bits=5): class ReplayBuffer(object): """Buffer to store environment transitions.""" - def __init__( - self, obs_shape, state_shape, action_shape, capacity, batch_size, - device - ): + def __init__(self, obs_shape, action_shape, capacity, batch_size, device): self.capacity = capacity self.batch_size = batch_size self.device = device @@ -83,21 +80,17 @@ class ReplayBuffer(object): self.actions = np.empty((capacity, *action_shape), dtype=np.float32) self.rewards = np.empty((capacity, 1), dtype=np.float32) self.not_dones = np.empty((capacity, 1), dtype=np.float32) - self.states = np.empty((capacity, *state_shape), dtype=np.float32) - self.next_states = np.empty((capacity, *state_shape), dtype=np.float32) self.idx = 0 self.last_save = 0 self.full = False - def add(self, obs, action, reward, next_obs, done, state, next_state): + def add(self, obs, action, reward, next_obs, done): np.copyto(self.obses[self.idx], obs) np.copyto(self.actions[self.idx], action) np.copyto(self.rewards[self.idx], reward) np.copyto(self.next_obses[self.idx], next_obs) np.copyto(self.not_dones[self.idx], not done) - np.copyto(self.states[self.idx], state) - np.copyto(self.next_states[self.idx], next_state) self.idx = (self.idx + 1) % self.capacity self.full = self.full or self.idx == 0 @@ -114,9 +107,8 @@ class ReplayBuffer(object): self.next_obses[idxs], device=self.device ).float() not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device) - states = torch.as_tensor(self.states[idxs], device=self.device) - return obses, actions, rewards, next_obses, not_dones, states + return obses, actions, rewards, next_obses, not_dones def save(self, save_dir): if self.idx == self.last_save: @@ -127,9 +119,7 @@ class ReplayBuffer(object): self.next_obses[self.last_save:self.idx], self.actions[self.last_save:self.idx], self.rewards[self.last_save:self.idx], - self.not_dones[self.last_save:self.idx], - self.states[self.last_save:self.idx], - self.next_states[self.last_save:self.idx] + self.not_dones[self.last_save:self.idx] ] self.last_save = self.idx torch.save(payload, path) @@ -147,8 +137,6 @@ class ReplayBuffer(object): self.actions[start:end] = payload[2] self.rewards[start:end] = payload[3] self.not_dones[start:end] = payload[4] - self.states[start:end] = payload[5] - self.next_states[start:end] = payload[6] self.idx = end