This commit is contained in:
Denis Yarats 2019-09-23 11:38:55 -07:00
parent 681e13b12a
commit f2e39f7a68
10 changed files with 55 additions and 729 deletions

View File

@ -1,10 +1,20 @@
# Soft Actor-Critic implementaiton in PyTorch # SAC+AE implementaiton in PyTorch
## Requirements
## Running locally ## Instructions
To train SAC locally one can use provided `run_local.sh` script (change it to modify particular arguments): To train an SAC+AE agent on the `cheetah run` task from image-based observations run:
``` ```
./run_local.sh python train.py \
--domain_name cheetah \
--task_name run \
--encoder_type pixel \
--decoder_type pixel \
--action_repeat 4 \
--save_video \
--save_tb \
--work_dir ./runs/cheetah_run/sac_ae \
--seed 1
``` ```
This will produce a folder (`./save`) by default, where all the output is going to be stored including train/eval logs, tensorboard blobs, evaluation videos, and model snapshots. It is possible to attach tensorboard to a particular run using the following command: This will produce a folder (`./save`) by default, where all the output is going to be stored including train/eval logs, tensorboard blobs, evaluation videos, and model snapshots. It is possible to attach tensorboard to a particular run using the following command:
``` ```

209
ddpg.py
View File

@ -1,209 +0,0 @@
# Code is taken from https://github.com/sfujim/TD3 with slight modifications
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from encoder import make_encoder
LOG_FREQ = 10000
class Actor(nn.Module):
def __init__(
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
):
super().__init__()
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim
)
self.l1 = nn.Linear(self.encoder.feature_dim, 400)
self.l2 = nn.Linear(400, 300)
self.l3 = nn.Linear(300, action_shape[0])
self.outputs = dict()
def forward(self, obs, detach_encoder=False):
obs = self.encoder(obs, detach=detach_encoder)
h = F.relu(self.l1(obs))
h = F.relu(self.l2(h))
action = torch.tanh(self.l3(h))
self.outputs['mu'] = action
return action
def log(self, L, step, log_freq=LOG_FREQ):
if step % log_freq != 0:
return
for k, v in self.outputs.items():
L.log_histogram('train_actor/%s_hist' % k, v, step)
L.log_param('train_actor/fc1', self.l1, step)
L.log_param('train_actor/fc2', self.l2, step)
L.log_param('train_actor/fc3', self.l3, step)
class Critic(nn.Module):
def __init__(
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
):
super().__init__()
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim
)
self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
self.l2 = nn.Linear(400, 300)
self.l3 = nn.Linear(300, 1)
self.outputs = dict()
def forward(self, obs, action, detach_encoder=False):
obs = self.encoder(obs, detach=detach_encoder)
obs_action = torch.cat([obs, action], dim=1)
h = F.relu(self.l1(obs_action))
h = F.relu(self.l2(h))
q = self.l3(h)
self.outputs['q'] = q
return q
def log(self, L, step, log_freq=LOG_FREQ):
if step % log_freq != 0:
return
self.encoder.log(L, step, log_freq)
for k, v in self.outputs.items():
L.log_histogram('train_critic/%s_hist' % k, v, step)
L.log_param('train_critic/fc1', self.l1, step)
L.log_param('train_critic/fc2', self.l2, step)
L.log_param('train_critic/fc3', self.l3, step)
class DDPGAgent(object):
def __init__(
self,
obs_shape,
action_shape,
device,
discount=0.99,
tau=0.005,
actor_lr=1e-3,
critic_lr=1e-3,
encoder_type='identity',
encoder_feature_dim=50
):
self.device = device
self.discount = discount
self.tau = tau
# models
self.actor = Actor(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.critic = Critic(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
self.actor_target = Actor(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.actor_target.load_state_dict(self.actor.state_dict())
self.critic_target = Critic(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
# optimizers
self.actor_optimizer = torch.optim.Adam(
self.actor.parameters(), lr=actor_lr
)
self.critic_optimizer = torch.optim.Adam(
self.critic.parameters(), lr=critic_lr
)
self.train()
self.critic_target.train()
self.actor_target.train()
def train(self, training=True):
self.training = training
self.actor.train(training)
self.critic.train(training)
def select_action(self, obs):
with torch.no_grad():
obs = torch.FloatTensor(obs).to(self.device)
obs = obs.unsqueeze(0)
action = self.actor(obs)
return action.cpu().data.numpy().flatten()
def sample_action(self, obs):
return self.select_action(obs)
def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
with torch.no_grad():
target_Q = self.critic_target(
next_obs, self.actor_target(next_obs)
)
target_Q = reward + (not_done * self.discount * target_Q)
current_Q = self.critic(obs, action)
critic_loss = F.mse_loss(current_Q, target_Q)
L.log('train_critic/loss', critic_loss, step)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
self.critic.log(L, step)
def update_actor(self, obs, L, step):
action = self.actor(obs, detach_encoder=True)
actor_Q = self.critic(obs, action, detach_encoder=True)
actor_loss = -actor_Q.mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
self.actor.log(L, step)
def update(self, replay_buffer, L, step):
obs, action, reward, next_obs, not_done = replay_buffer.sample()
L.log('train/batch_reward', reward.mean(), step)
self.update_critic(obs, action, reward, next_obs, not_done, L, step)
self.update_actor(obs, L, step)
utils.soft_update_params(self.critic, self.critic_target, self.tau)
utils.soft_update_params(self.actor, self.actor_target, self.tau)
def save(self, model_dir, step):
torch.save(
self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
)
torch.save(
self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
)
def load(self, model_dir, step):
self.actor.load_state_dict(
torch.load('%s/actor_%s.pt' % (model_dir, step))
)
self.critic.load_state_dict(
torch.load('%s/critic_%s.pt' % (model_dir, step))
)

View File

@ -62,45 +62,13 @@ class PixelDecoder(nn.Module):
L.log_param('train_decoder/fc', self.fc, step) L.log_param('train_decoder/fc', self.fc, step)
class StateDecoder(nn.Module): _AVAILABLE_DECODERS = {'pixel': PixelDecoder}
def __init__(self, obs_shape, feature_dim):
super().__init__()
assert len(obs_shape) == 1
self.trunk = nn.Sequential(
nn.Linear(feature_dim, 1024), nn.ReLU(), nn.Linear(1024, 1024),
nn.ReLU(), nn.Linear(1024, obs_shape[0]), nn.ReLU()
)
self.outputs = dict()
def forward(self, obs, detach=False):
h = self.trunk(obs)
if detach:
h = h.detach()
self.outputs['h'] = h
return h
def log(self, L, step, log_freq):
if step % log_freq != 0:
return
L.log_param('train_encoder/fc1', self.trunk[0], step)
L.log_param('train_encoder/fc2', self.trunk[2], step)
for k, v in self.outputs.items():
L.log_histogram('train_encoder/%s_hist' % k, v, step)
_AVAILABLE_DECODERS = {'pixel': PixelDecoder, 'state': StateDecoder}
def make_decoder( def make_decoder(
decoder_type, obs_shape, feature_dim, num_layers, num_filters decoder_type, obs_shape, feature_dim, num_layers, num_filters
): ):
assert decoder_type in _AVAILABLE_DECODERS assert decoder_type in _AVAILABLE_DECODERS
if decoder_type == 'pixel':
return _AVAILABLE_DECODERS[decoder_type]( return _AVAILABLE_DECODERS[decoder_type](
obs_shape, feature_dim, num_layers, num_filters obs_shape, feature_dim, num_layers, num_filters
) )
return _AVAILABLE_DECODERS[decoder_type](obs_shape, feature_dim)

View File

@ -13,21 +13,13 @@ OUT_DIM = {2: 39, 4: 35, 6: 31}
class PixelEncoder(nn.Module): class PixelEncoder(nn.Module):
"""Convolutional encoder of pixels observations.""" """Convolutional encoder of pixels observations."""
def __init__( def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32):
self,
obs_shape,
feature_dim,
num_layers=2,
num_filters=32,
stochastic=False
):
super().__init__() super().__init__()
assert len(obs_shape) == 3 assert len(obs_shape) == 3
self.feature_dim = feature_dim self.feature_dim = feature_dim
self.num_layers = num_layers self.num_layers = num_layers
self.stochastic = stochastic
self.convs = nn.ModuleList( self.convs = nn.ModuleList(
[nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)] [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]
@ -39,13 +31,6 @@ class PixelEncoder(nn.Module):
self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim) self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim)
self.ln = nn.LayerNorm(self.feature_dim) self.ln = nn.LayerNorm(self.feature_dim)
if self.stochastic:
self.log_std_min = -10
self.log_std_max = 2
self.fc_log_std = nn.Linear(
num_filters * out_dim * out_dim, self.feature_dim
)
self.outputs = dict() self.outputs = dict()
def reparameterize(self, mu, logstd): def reparameterize(self, mu, logstd):
@ -80,17 +65,6 @@ class PixelEncoder(nn.Module):
self.outputs['ln'] = h_norm self.outputs['ln'] = h_norm
out = torch.tanh(h_norm) out = torch.tanh(h_norm)
if self.stochastic:
self.outputs['mu'] = out
log_std = torch.tanh(self.fc_log_std(h))
# normalize
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
) * (log_std + 1)
out = self.reparameterize(out, log_std)
self.outputs['log_std'] = log_std
self.outputs['tanh'] = out self.outputs['tanh'] = out
return out return out
@ -116,42 +90,8 @@ class PixelEncoder(nn.Module):
L.log_param('train_encoder/ln', self.ln, step) L.log_param('train_encoder/ln', self.ln, step)
class StateEncoder(nn.Module):
def __init__(self, obs_shape, feature_dim):
super().__init__()
assert len(obs_shape) == 1
self.feature_dim = feature_dim
self.trunk = nn.Sequential(
nn.Linear(obs_shape[0], 256), nn.ReLU(),
nn.Linear(256, feature_dim), nn.ReLU()
)
self.outputs = dict()
def forward(self, obs, detach=False):
h = self.trunk(obs)
if detach:
h = h.detach()
self.outputs['h'] = h
return h
def copy_conv_weights_from(self, source):
pass
def log(self, L, step, log_freq):
if step % log_freq != 0:
return
L.log_param('train_encoder/fc1', self.trunk[0], step)
L.log_param('train_encoder/fc2', self.trunk[2], step)
for k, v in self.outputs.items():
L.log_histogram('train_encoder/%s_hist' % k, v, step)
class IdentityEncoder(nn.Module): class IdentityEncoder(nn.Module):
def __init__(self, obs_shape, feature_dim): def __init__(self, obs_shape, feature_dim, num_layers, num_filters):
super().__init__() super().__init__()
assert len(obs_shape) == 1 assert len(obs_shape) == 1
@ -167,19 +107,13 @@ class IdentityEncoder(nn.Module):
pass pass
_AVAILABLE_ENCODERS = { _AVAILABLE_ENCODERS = {'pixel': PixelEncoder, 'identity': IdentityEncoder}
'pixel': PixelEncoder,
'state': StateEncoder,
'identity': IdentityEncoder
}
def make_encoder( def make_encoder(
encoder_type, obs_shape, feature_dim, num_layers, num_filters, stochastic encoder_type, obs_shape, feature_dim, num_layers, num_filters
): ):
assert encoder_type in _AVAILABLE_ENCODERS assert encoder_type in _AVAILABLE_ENCODERS
if encoder_type == 'pixel':
return _AVAILABLE_ENCODERS[encoder_type]( return _AVAILABLE_ENCODERS[encoder_type](
obs_shape, feature_dim, num_layers, num_filters, stochastic obs_shape, feature_dim, num_layers, num_filters
) )
return _AVAILABLE_ENCODERS[encoder_type](obs_shape, feature_dim)

View File

@ -8,19 +8,15 @@ import torchvision
import numpy as np import numpy as np
from termcolor import colored from termcolor import colored
FORMAT_CONFIG = { FORMAT_CONFIG = {
'rl': { 'rl': {
'train': [('episode', 'E', 'int'), 'train': [
('step', 'S', 'int'), ('episode', 'E', 'int'), ('step', 'S', 'int'),
('duration', 'D', 'time'), ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'),
('episode_reward', 'R', 'float'), ('batch_reward', 'BR', 'float'), ('actor_loss', 'ALOSS', 'float'),
('batch_reward', 'BR', 'float'), ('critic_loss', 'CLOSS', 'float'), ('ae_loss', 'RLOSS', 'float')
('actor_loss', 'ALOSS', 'float'), ],
('critic_loss', 'CLOSS', 'float'), 'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float')]
('ae_loss', 'RLOSS', 'float')],
'eval': [('step', 'S', 'int'),
('episode_reward', 'ER', 'float')]
} }
} }
@ -106,10 +102,12 @@ class Logger(object):
self._sw = None self._sw = None
self._train_mg = MetersGroup( self._train_mg = MetersGroup(
os.path.join(log_dir, 'train.log'), os.path.join(log_dir, 'train.log'),
formating=FORMAT_CONFIG[config]['train']) formating=FORMAT_CONFIG[config]['train']
)
self._eval_mg = MetersGroup( self._eval_mg = MetersGroup(
os.path.join(log_dir, 'eval.log'), os.path.join(log_dir, 'eval.log'),
formating=FORMAT_CONFIG[config]['eval']) formating=FORMAT_CONFIG[config]['eval']
)
def _try_sw_log(self, key, value, step): def _try_sw_log(self, key, value, step):
if self._sw is not None: if self._sw is not None:

21
run.sh
View File

@ -1,21 +0,0 @@
#!/bin/bash
DOMAIN=cheetah
TASK=run
ACTION_REPEAT=4
ENCODER_TYPE=pixel
ENCODER_TYPE=pixel
WORK_DIR=./runs
python train.py \
--domain_name ${DOMAIN} \
--task_name ${TASK} \
--encoder_type ${ENCODER_TYPE} \
--decoder_type ${DECODER_TYPE} \
--action_repeat ${ACTION_REPEAT} \
--save_video \
--save_tb \
--work_dir ${WORK_DIR}/${DOMAIN}_{TASK}/_ae_encoder_${ENCODER_TYPE}_decoder_{ENCODER_TYPE} \
--seed 1

View File

@ -215,8 +215,8 @@ class Critic(nn.Module):
L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step) L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step)
class SACAgent(object): class SacAeAgent(object):
"""Soft Actor-Critic algorithm.""" """SAC+AE algorithm."""
def __init__( def __init__(
self, self,
obs_shape, obs_shape,
@ -237,20 +237,17 @@ class SACAgent(object):
critic_beta=0.9, critic_beta=0.9,
critic_tau=0.005, critic_tau=0.005,
critic_target_update_freq=2, critic_target_update_freq=2,
encoder_type='identity', encoder_type='pixel',
encoder_feature_dim=50, encoder_feature_dim=50,
encoder_lr=1e-3, encoder_lr=1e-3,
encoder_tau=0.005, encoder_tau=0.005,
decoder_type='identity', decoder_type='pixel',
decoder_lr=1e-3, decoder_lr=1e-3,
decoder_update_freq=1, decoder_update_freq=1,
decoder_latent_lambda=0.0, decoder_latent_lambda=0.0,
decoder_weight_lambda=0.0, decoder_weight_lambda=0.0,
decoder_kl_lambda=0.0,
num_layers=4, num_layers=4,
num_filters=32, num_filters=32
freeze_encoder=False,
use_dynamics=False
): ):
self.device = device self.device = device
self.discount = discount self.discount = discount
@ -260,11 +257,6 @@ class SACAgent(object):
self.critic_target_update_freq = critic_target_update_freq self.critic_target_update_freq = critic_target_update_freq
self.decoder_update_freq = decoder_update_freq self.decoder_update_freq = decoder_update_freq
self.decoder_latent_lambda = decoder_latent_lambda self.decoder_latent_lambda = decoder_latent_lambda
self.decoder_kl_lambda = decoder_kl_lambda
self.decoder_type = decoder_type
self.use_dynamics = use_dynamics
stochastic = decoder_kl_lambda > 0.0
self.actor = Actor( self.actor = Actor(
obs_shape, action_shape, hidden_dim, encoder_type, obs_shape, action_shape, hidden_dim, encoder_type,
@ -420,9 +412,6 @@ class SACAgent(object):
self.log_alpha_optimizer.step() self.log_alpha_optimizer.step()
def update_decoder(self, obs, target_obs, L, step): def update_decoder(self, obs, target_obs, L, step):
if self.decoder is None:
return
h = self.critic.encoder(obs) h = self.critic.encoder(obs)
if target_obs.dim() == 4: if target_obs.dim() == 4:
@ -477,9 +466,8 @@ class SACAgent(object):
self.encoder_tau self.encoder_tau
) )
if step % self.decoder_update_freq == 0: if self.decoder is None and step % self.decoder_update_freq == 0:
target = obs if self.decoder_type == 'pixel' else state self.update_decoder(obs, obs, L, step)
self.update_decoder(obs, target, L, step)
def save(self, model_dir, step): def save(self, model_dir, step):
torch.save( torch.save(

259
td3.py
View File

@ -1,259 +0,0 @@
# Code is taken from https://github.com/sfujim/TD3 with slight modifications
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from encoder import make_encoder
LOG_FREQ = 10000
class Actor(nn.Module):
def __init__(
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
):
super().__init__()
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim
)
self.l1 = nn.Linear(self.encoder.feature_dim, 400)
self.l2 = nn.Linear(400, 300)
self.l3 = nn.Linear(300, action_shape[0])
self.outputs = dict()
def forward(self, obs, detach_encoder=False):
obs = self.encoder(obs, detach=detach_encoder)
h = F.relu(self.l1(obs))
h = F.relu(self.l2(h))
action = torch.tanh(self.l3(h))
self.outputs['mu'] = action
return action
def log(self, L, step, log_freq=LOG_FREQ):
if step % log_freq != 0:
return
for k, v in self.outputs.items():
L.log_histogram('train_actor/%s_hist' % k, v, step)
L.log_param('train_actor/fc1', self.l1, step)
L.log_param('train_actor/fc2', self.l2, step)
L.log_param('train_actor/fc3', self.l3, step)
class Critic(nn.Module):
def __init__(
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
):
super().__init__()
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim
)
# Q1 architecture
self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
self.l2 = nn.Linear(400, 300)
self.l3 = nn.Linear(300, 1)
# Q2 architecture
self.l4 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
self.l5 = nn.Linear(400, 300)
self.l6 = nn.Linear(300, 1)
self.outputs = dict()
def forward(self, obs, action, detach_encoder=False):
obs = self.encoder(obs, detach=detach_encoder)
obs_action = torch.cat([obs, action], 1)
h1 = F.relu(self.l1(obs_action))
h1 = F.relu(self.l2(h1))
q1 = self.l3(h1)
h2 = F.relu(self.l4(obs_action))
h2 = F.relu(self.l5(h2))
q2 = self.l6(h2)
self.outputs['q1'] = q1
self.outputs['q2'] = q2
return q1, q2
def Q1(self, obs, action, detach_encoder=False):
obs = self.encoder(obs, detach=detach_encoder)
obs_action = torch.cat([obs, action], 1)
h1 = F.relu(self.l1(obs_action))
h1 = F.relu(self.l2(h1))
q1 = self.l3(h1)
return q1
def log(self, L, step, log_freq=LOG_FREQ):
if step % log_freq != 0: return
self.encoder.log(L, step, log_freq)
for k, v in self.outputs.items():
L.log_histogram('train_critic/%s_hist' % k, v, step)
L.log_param('train_critic/q1_fc1', self.l1, step)
L.log_param('train_critic/q1_fc2', self.l2, step)
L.log_param('train_critic/q1_fc3', self.l3, step)
L.log_param('train_critic/q1_fc4', self.l4, step)
L.log_param('train_critic/q1_fc5', self.l5, step)
L.log_param('train_critic/q1_fc6', self.l6, step)
class TD3Agent(object):
def __init__(
self,
obs_shape,
action_shape,
device,
discount=0.99,
tau=0.005,
policy_noise=0.2,
noise_clip=0.5,
expl_noise=0.1,
actor_lr=1e-3,
critic_lr=1e-3,
encoder_type='identity',
encoder_feature_dim=50,
actor_update_freq=2,
target_update_freq=2,
):
self.device = device
self.discount = discount
self.tau = tau
self.policy_noise = policy_noise
self.noise_clip = noise_clip
self.expl_noise = expl_noise
self.actor_update_freq = actor_update_freq
self.target_update_freq = target_update_freq
# models
self.actor = Actor(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.critic = Critic(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
self.actor_target = Actor(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.actor_target.load_state_dict(self.actor.state_dict())
self.critic_target = Critic(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
# optimizers
self.actor_optimizer = torch.optim.Adam(
self.actor.parameters(), lr=actor_lr
)
self.critic_optimizer = torch.optim.Adam(
self.critic.parameters(), lr=critic_lr
)
self.train()
self.critic_target.train()
self.actor_target.train()
def train(self, training=True):
self.training = training
self.actor.train(training)
self.critic.train(training)
def select_action(self, obs):
with torch.no_grad():
obs = torch.FloatTensor(obs).to(self.device)
obs = obs.unsqueeze(0)
action = self.actor(obs)
return action.cpu().data.numpy().flatten()
def sample_action(self, obs):
with torch.no_grad():
obs = torch.FloatTensor(obs).to(self.device)
obs = obs.unsqueeze(0)
action = self.actor(obs)
noise = torch.randn_like(action) * self.expl_noise
action = (action + noise).clamp(-1.0, 1.0)
return action.cpu().data.numpy().flatten()
def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
with torch.no_grad():
noise = torch.randn_like(action).to(self.device) * self.policy_noise
noise = noise.clamp(-self.noise_clip, self.noise_clip)
next_action = self.actor_target(next_obs) + noise
next_action = next_action.clamp(-1.0, 1.0)
target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + (not_done * self.discount * target_Q)
current_Q1, current_Q2 = self.critic(obs, action)
critic_loss = F.mse_loss(current_Q1,
target_Q) + F.mse_loss(current_Q2, target_Q)
L.log('train_critic/loss', critic_loss, step)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
self.critic.log(L, step)
def update_actor(self, obs, L, step):
action = self.actor(obs, detach_encoder=True)
actor_Q = self.critic.Q1(obs, action, detach_encoder=True)
actor_loss = -actor_Q.mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
self.actor.log(L, step)
def update(self, replay_buffer, L, step):
obs, action, reward, next_obs, not_done = replay_buffer.sample()
L.log('train/batch_reward', reward.mean(), step)
self.update_critic(obs, action, reward, next_obs, not_done, L, step)
if step % self.actor_update_freq == 0:
self.update_actor(obs, L, step)
if step % self.target_update_freq == 0:
utils.soft_update_params(self.critic, self.critic_target, self.tau)
utils.soft_update_params(self.actor, self.actor_target, self.tau)
def save(self, model_dir, step):
torch.save(
self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
)
torch.save(
self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
)
def load(self, model_dir, step):
self.actor.load_state_dict(
torch.load('%s/actor_%s.pt' % (model_dir, step))
)
self.critic.load_state_dict(
torch.load('%s/critic_%s.pt' % (model_dir, step))
)

View File

@ -15,9 +15,7 @@ import utils
from logger import Logger from logger import Logger
from video import VideoRecorder from video import VideoRecorder
from sac import SACAgent from sac_ae import SacAeAgent
from td3 import TD3Agent
from ddpg import DDPGAgent
def parse_args(): def parse_args():
@ -31,7 +29,7 @@ def parse_args():
# replay buffer # replay buffer
parser.add_argument('--replay_buffer_capacity', default=1000000, type=int) parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
# train # train
parser.add_argument('--agent', default='sac', type=str) parser.add_argument('--agent', default='sac_ae', type=str)
parser.add_argument('--init_steps', default=1000, type=int) parser.add_argument('--init_steps', default=1000, type=int)
parser.add_argument('--num_train_steps', default=1000000, type=int) parser.add_argument('--num_train_steps', default=1000000, type=int)
parser.add_argument('--batch_size', default=512, type=int) parser.add_argument('--batch_size', default=512, type=int)
@ -51,30 +49,22 @@ def parse_args():
parser.add_argument('--actor_log_std_max', default=2, type=float) parser.add_argument('--actor_log_std_max', default=2, type=float)
parser.add_argument('--actor_update_freq', default=2, type=int) parser.add_argument('--actor_update_freq', default=2, type=int)
# encoder/decoder # encoder/decoder
parser.add_argument('--encoder_type', default='identity', type=str) parser.add_argument('--encoder_type', default='pixel', type=str)
parser.add_argument('--encoder_feature_dim', default=50, type=int) parser.add_argument('--encoder_feature_dim', default=50, type=int)
parser.add_argument('--encoder_lr', default=1e-3, type=float) parser.add_argument('--encoder_lr', default=1e-3, type=float)
parser.add_argument('--encoder_tau', default=0.005, type=float) parser.add_argument('--encoder_tau', default=0.005, type=float)
parser.add_argument('--decoder_type', default='identity', type=str) parser.add_argument('--decoder_type', default='pixel', type=str)
parser.add_argument('--decoder_lr', default=1e-3, type=float) parser.add_argument('--decoder_lr', default=1e-3, type=float)
parser.add_argument('--decoder_update_freq', default=1, type=int) parser.add_argument('--decoder_update_freq', default=1, type=int)
parser.add_argument('--decoder_latent_lambda', default=0.0, type=float) parser.add_argument('--decoder_latent_lambda', default=0.0, type=float)
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float) parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
parser.add_argument('--decoder_kl_lambda', default=0.0, type=float)
parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_layers', default=4, type=int)
parser.add_argument('--num_filters', default=32, type=int) parser.add_argument('--num_filters', default=32, type=int)
parser.add_argument('--freeze_encoder', default=False, action='store_true')
parser.add_argument('--use_dynamics', default=False, action='store_true')
# sac # sac
parser.add_argument('--discount', default=0.99, type=float) parser.add_argument('--discount', default=0.99, type=float)
parser.add_argument('--init_temperature', default=0.01, type=float) parser.add_argument('--init_temperature', default=0.01, type=float)
parser.add_argument('--alpha_lr', default=1e-3, type=float) parser.add_argument('--alpha_lr', default=1e-3, type=float)
parser.add_argument('--alpha_beta', default=0.9, type=float) parser.add_argument('--alpha_beta', default=0.9, type=float)
# td3
parser.add_argument('--policy_noise', default=0.2, type=float)
parser.add_argument('--expl_noise', default=0.1, type=float)
parser.add_argument('--noise_clip', default=0.5, type=float)
parser.add_argument('--tau', default=0.005, type=float)
# misc # misc
parser.add_argument('--seed', default=1, type=int) parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--work_dir', default='.', type=str) parser.add_argument('--work_dir', default='.', type=str)
@ -82,8 +72,6 @@ def parse_args():
parser.add_argument('--save_model', default=False, action='store_true') parser.add_argument('--save_model', default=False, action='store_true')
parser.add_argument('--save_buffer', default=False, action='store_true') parser.add_argument('--save_buffer', default=False, action='store_true')
parser.add_argument('--save_video', default=False, action='store_true') parser.add_argument('--save_video', default=False, action='store_true')
parser.add_argument('--pretrained_info', default=None, type=str)
parser.add_argument('--pretrained_decoder', default=False, action='store_true')
args = parser.parse_args() args = parser.parse_args()
return args return args
@ -108,8 +96,8 @@ def evaluate(env, agent, video, num_episodes, L, step):
def make_agent(obs_shape, state_shape, action_shape, args, device): def make_agent(obs_shape, state_shape, action_shape, args, device):
if args.agent == 'sac': if args.agent == 'sac_ae':
return SACAgent( return SacAeAgent(
obs_shape=obs_shape, obs_shape=obs_shape,
state_shape=state_shape, state_shape=state_shape,
action_shape=action_shape, action_shape=action_shape,
@ -137,63 +125,13 @@ def make_agent(obs_shape, state_shape, action_shape, args, device):
decoder_update_freq=args.decoder_update_freq, decoder_update_freq=args.decoder_update_freq,
decoder_latent_lambda=args.decoder_latent_lambda, decoder_latent_lambda=args.decoder_latent_lambda,
decoder_weight_lambda=args.decoder_weight_lambda, decoder_weight_lambda=args.decoder_weight_lambda,
decoder_kl_lambda=args.decoder_kl_lambda,
num_layers=args.num_layers, num_layers=args.num_layers,
num_filters=args.num_filters, num_filters=args.num_filters
freeze_encoder=args.freeze_encoder,
use_dynamics=args.use_dynamics
)
elif args.agent == 'td3':
return TD3Agent(
obs_shape=obs_shape,
action_shape=action_shape,
device=device,
discount=args.discount,
tau=args.tau,
policy_noise=args.policy_noise,
noise_clip=args.noise_clip,
expl_noise=args.expl_noise,
actor_lr=args.actor_lr,
critic_lr=args.critic_lr,
encoder_type=args.encoder_type,
encoder_feature_dim=args.encoder_feature_dim,
actor_update_freq=args.actor_update_freq,
target_update_freq=args.critic_target_update_freq
)
elif args.agent == 'ddpg':
return DDPGAgent(
obs_shape=obs_shape,
action_shape=action_shape,
device=device,
discount=args.discount,
tau=args.tau,
actor_lr=args.actor_lr,
critic_lr=args.critic_lr,
encoder_type=args.encoder_type,
encoder_feature_dim=args.encoder_feature_dim
) )
else: else:
assert 'agent is not supported: %s' % args.agent assert 'agent is not supported: %s' % args.agent
def load_pretrained_encoder(agent, pretrained_info, pretrained_decoder):
path, version = pretrained_info.split(':')
pretrained_agent = copy.deepcopy(agent)
pretrained_agent.load(path, int(version))
agent.critic.encoder.load_state_dict(
pretrained_agent.critic.encoder.state_dict()
)
agent.actor.encoder.load_state_dict(
pretrained_agent.actor.encoder.state_dict()
)
if pretrained_decoder:
agent.decoder.load_state_dict(pretrained_agent.decoder.state_dict())
return agent
def main(): def main():
args = parse_args() args = parse_args()
utils.set_seed_everywhere(args.seed) utils.set_seed_everywhere(args.seed)
@ -232,7 +170,6 @@ def main():
replay_buffer = utils.ReplayBuffer( replay_buffer = utils.ReplayBuffer(
obs_shape=env.observation_space.shape, obs_shape=env.observation_space.shape,
state_shape=env.state_space.shape,
action_shape=env.action_space.shape, action_shape=env.action_space.shape,
capacity=args.replay_buffer_capacity, capacity=args.replay_buffer_capacity,
batch_size=args.batch_size, batch_size=args.batch_size,
@ -241,17 +178,11 @@ def main():
agent = make_agent( agent = make_agent(
obs_shape=env.observation_space.shape, obs_shape=env.observation_space.shape,
state_shape=env.state_space.shape,
action_shape=env.action_space.shape, action_shape=env.action_space.shape,
args=args, args=args,
device=device device=device
) )
if args.pretrained_info is not None:
agent = load_pretrained_encoder(
agent, args.pretrained_info, args.pretrained_decoder
)
L = Logger(args.work_dir, use_tb=args.save_tb) L = Logger(args.work_dir, use_tb=args.save_tb)
episode, episode_reward, done = 0, 0, True episode, episode_reward, done = 0, 0, True
@ -295,9 +226,7 @@ def main():
for _ in range(num_updates): for _ in range(num_updates):
agent.update(replay_buffer, L, step) agent.update(replay_buffer, L, step)
state = env.env.env._current_state
next_obs, reward, done, _ = env.step(action) next_obs, reward, done, _ = env.step(action)
next_state = env.env.env._current_state.shape
# allow infinit bootstrap # allow infinit bootstrap
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float( done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
@ -305,7 +234,7 @@ def main():
) )
episode_reward += reward episode_reward += reward
replay_buffer.add(obs, action, reward, next_obs, done_bool, state, next_state) replay_buffer.add(obs, action, reward, next_obs, done_bool)
obs = next_obs obs = next_obs
episode_step += 1 episode_step += 1

View File

@ -67,10 +67,7 @@ def preprocess_obs(obs, bits=5):
class ReplayBuffer(object): class ReplayBuffer(object):
"""Buffer to store environment transitions.""" """Buffer to store environment transitions."""
def __init__( def __init__(self, obs_shape, action_shape, capacity, batch_size, device):
self, obs_shape, state_shape, action_shape, capacity, batch_size,
device
):
self.capacity = capacity self.capacity = capacity
self.batch_size = batch_size self.batch_size = batch_size
self.device = device self.device = device
@ -83,21 +80,17 @@ class ReplayBuffer(object):
self.actions = np.empty((capacity, *action_shape), dtype=np.float32) self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
self.rewards = np.empty((capacity, 1), dtype=np.float32) self.rewards = np.empty((capacity, 1), dtype=np.float32)
self.not_dones = np.empty((capacity, 1), dtype=np.float32) self.not_dones = np.empty((capacity, 1), dtype=np.float32)
self.states = np.empty((capacity, *state_shape), dtype=np.float32)
self.next_states = np.empty((capacity, *state_shape), dtype=np.float32)
self.idx = 0 self.idx = 0
self.last_save = 0 self.last_save = 0
self.full = False self.full = False
def add(self, obs, action, reward, next_obs, done, state, next_state): def add(self, obs, action, reward, next_obs, done):
np.copyto(self.obses[self.idx], obs) np.copyto(self.obses[self.idx], obs)
np.copyto(self.actions[self.idx], action) np.copyto(self.actions[self.idx], action)
np.copyto(self.rewards[self.idx], reward) np.copyto(self.rewards[self.idx], reward)
np.copyto(self.next_obses[self.idx], next_obs) np.copyto(self.next_obses[self.idx], next_obs)
np.copyto(self.not_dones[self.idx], not done) np.copyto(self.not_dones[self.idx], not done)
np.copyto(self.states[self.idx], state)
np.copyto(self.next_states[self.idx], next_state)
self.idx = (self.idx + 1) % self.capacity self.idx = (self.idx + 1) % self.capacity
self.full = self.full or self.idx == 0 self.full = self.full or self.idx == 0
@ -114,9 +107,8 @@ class ReplayBuffer(object):
self.next_obses[idxs], device=self.device self.next_obses[idxs], device=self.device
).float() ).float()
not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device) not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
states = torch.as_tensor(self.states[idxs], device=self.device)
return obses, actions, rewards, next_obses, not_dones, states return obses, actions, rewards, next_obses, not_dones
def save(self, save_dir): def save(self, save_dir):
if self.idx == self.last_save: if self.idx == self.last_save:
@ -127,9 +119,7 @@ class ReplayBuffer(object):
self.next_obses[self.last_save:self.idx], self.next_obses[self.last_save:self.idx],
self.actions[self.last_save:self.idx], self.actions[self.last_save:self.idx],
self.rewards[self.last_save:self.idx], self.rewards[self.last_save:self.idx],
self.not_dones[self.last_save:self.idx], self.not_dones[self.last_save:self.idx]
self.states[self.last_save:self.idx],
self.next_states[self.last_save:self.idx]
] ]
self.last_save = self.idx self.last_save = self.idx
torch.save(payload, path) torch.save(payload, path)
@ -147,8 +137,6 @@ class ReplayBuffer(object):
self.actions[start:end] = payload[2] self.actions[start:end] = payload[2]
self.rewards[start:end] = payload[3] self.rewards[start:end] = payload[3]
self.not_dones[start:end] = payload[4] self.not_dones[start:end] = payload[4]
self.states[start:end] = payload[5]
self.next_states[start:end] = payload[6]
self.idx = end self.idx = end