changes
This commit is contained in:
parent
681e13b12a
commit
f2e39f7a68
18
README.md
18
README.md
@ -1,10 +1,20 @@
|
|||||||
# Soft Actor-Critic implementaiton in PyTorch
|
# SAC+AE implementaiton in PyTorch
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
## Running locally
|
## Instructions
|
||||||
To train SAC locally one can use provided `run_local.sh` script (change it to modify particular arguments):
|
To train an SAC+AE agent on the `cheetah run` task from image-based observations run:
|
||||||
```
|
```
|
||||||
./run_local.sh
|
python train.py \
|
||||||
|
--domain_name cheetah \
|
||||||
|
--task_name run \
|
||||||
|
--encoder_type pixel \
|
||||||
|
--decoder_type pixel \
|
||||||
|
--action_repeat 4 \
|
||||||
|
--save_video \
|
||||||
|
--save_tb \
|
||||||
|
--work_dir ./runs/cheetah_run/sac_ae \
|
||||||
|
--seed 1
|
||||||
```
|
```
|
||||||
This will produce a folder (`./save`) by default, where all the output is going to be stored including train/eval logs, tensorboard blobs, evaluation videos, and model snapshots. It is possible to attach tensorboard to a particular run using the following command:
|
This will produce a folder (`./save`) by default, where all the output is going to be stored including train/eval logs, tensorboard blobs, evaluation videos, and model snapshots. It is possible to attach tensorboard to a particular run using the following command:
|
||||||
```
|
```
|
||||||
|
209
ddpg.py
209
ddpg.py
@ -1,209 +0,0 @@
|
|||||||
# Code is taken from https://github.com/sfujim/TD3 with slight modifications
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
import utils
|
|
||||||
from encoder import make_encoder
|
|
||||||
|
|
||||||
LOG_FREQ = 10000
|
|
||||||
|
|
||||||
|
|
||||||
class Actor(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.encoder = make_encoder(
|
|
||||||
encoder_type, obs_shape, encoder_feature_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
self.l1 = nn.Linear(self.encoder.feature_dim, 400)
|
|
||||||
self.l2 = nn.Linear(400, 300)
|
|
||||||
self.l3 = nn.Linear(300, action_shape[0])
|
|
||||||
|
|
||||||
self.outputs = dict()
|
|
||||||
|
|
||||||
def forward(self, obs, detach_encoder=False):
|
|
||||||
obs = self.encoder(obs, detach=detach_encoder)
|
|
||||||
h = F.relu(self.l1(obs))
|
|
||||||
h = F.relu(self.l2(h))
|
|
||||||
action = torch.tanh(self.l3(h))
|
|
||||||
self.outputs['mu'] = action
|
|
||||||
return action
|
|
||||||
|
|
||||||
def log(self, L, step, log_freq=LOG_FREQ):
|
|
||||||
if step % log_freq != 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
for k, v in self.outputs.items():
|
|
||||||
L.log_histogram('train_actor/%s_hist' % k, v, step)
|
|
||||||
|
|
||||||
L.log_param('train_actor/fc1', self.l1, step)
|
|
||||||
L.log_param('train_actor/fc2', self.l2, step)
|
|
||||||
L.log_param('train_actor/fc3', self.l3, step)
|
|
||||||
|
|
||||||
|
|
||||||
class Critic(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.encoder = make_encoder(
|
|
||||||
encoder_type, obs_shape, encoder_feature_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
|
|
||||||
self.l2 = nn.Linear(400, 300)
|
|
||||||
self.l3 = nn.Linear(300, 1)
|
|
||||||
|
|
||||||
self.outputs = dict()
|
|
||||||
|
|
||||||
def forward(self, obs, action, detach_encoder=False):
|
|
||||||
obs = self.encoder(obs, detach=detach_encoder)
|
|
||||||
obs_action = torch.cat([obs, action], dim=1)
|
|
||||||
h = F.relu(self.l1(obs_action))
|
|
||||||
h = F.relu(self.l2(h))
|
|
||||||
q = self.l3(h)
|
|
||||||
self.outputs['q'] = q
|
|
||||||
return q
|
|
||||||
|
|
||||||
def log(self, L, step, log_freq=LOG_FREQ):
|
|
||||||
if step % log_freq != 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.encoder.log(L, step, log_freq)
|
|
||||||
|
|
||||||
for k, v in self.outputs.items():
|
|
||||||
L.log_histogram('train_critic/%s_hist' % k, v, step)
|
|
||||||
|
|
||||||
L.log_param('train_critic/fc1', self.l1, step)
|
|
||||||
L.log_param('train_critic/fc2', self.l2, step)
|
|
||||||
L.log_param('train_critic/fc3', self.l3, step)
|
|
||||||
|
|
||||||
|
|
||||||
class DDPGAgent(object):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
obs_shape,
|
|
||||||
action_shape,
|
|
||||||
device,
|
|
||||||
discount=0.99,
|
|
||||||
tau=0.005,
|
|
||||||
actor_lr=1e-3,
|
|
||||||
critic_lr=1e-3,
|
|
||||||
encoder_type='identity',
|
|
||||||
encoder_feature_dim=50
|
|
||||||
):
|
|
||||||
self.device = device
|
|
||||||
self.discount = discount
|
|
||||||
self.tau = tau
|
|
||||||
|
|
||||||
# models
|
|
||||||
self.actor = Actor(
|
|
||||||
obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
self.critic = Critic(
|
|
||||||
obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
|
|
||||||
|
|
||||||
self.actor_target = Actor(
|
|
||||||
obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
).to(device)
|
|
||||||
self.actor_target.load_state_dict(self.actor.state_dict())
|
|
||||||
|
|
||||||
self.critic_target = Critic(
|
|
||||||
obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
).to(device)
|
|
||||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
|
||||||
|
|
||||||
# optimizers
|
|
||||||
self.actor_optimizer = torch.optim.Adam(
|
|
||||||
self.actor.parameters(), lr=actor_lr
|
|
||||||
)
|
|
||||||
|
|
||||||
self.critic_optimizer = torch.optim.Adam(
|
|
||||||
self.critic.parameters(), lr=critic_lr
|
|
||||||
)
|
|
||||||
|
|
||||||
self.train()
|
|
||||||
self.critic_target.train()
|
|
||||||
self.actor_target.train()
|
|
||||||
|
|
||||||
def train(self, training=True):
|
|
||||||
self.training = training
|
|
||||||
self.actor.train(training)
|
|
||||||
self.critic.train(training)
|
|
||||||
|
|
||||||
def select_action(self, obs):
|
|
||||||
with torch.no_grad():
|
|
||||||
obs = torch.FloatTensor(obs).to(self.device)
|
|
||||||
obs = obs.unsqueeze(0)
|
|
||||||
action = self.actor(obs)
|
|
||||||
return action.cpu().data.numpy().flatten()
|
|
||||||
|
|
||||||
def sample_action(self, obs):
|
|
||||||
return self.select_action(obs)
|
|
||||||
|
|
||||||
def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
|
|
||||||
with torch.no_grad():
|
|
||||||
target_Q = self.critic_target(
|
|
||||||
next_obs, self.actor_target(next_obs)
|
|
||||||
)
|
|
||||||
target_Q = reward + (not_done * self.discount * target_Q)
|
|
||||||
|
|
||||||
current_Q = self.critic(obs, action)
|
|
||||||
|
|
||||||
critic_loss = F.mse_loss(current_Q, target_Q)
|
|
||||||
L.log('train_critic/loss', critic_loss, step)
|
|
||||||
|
|
||||||
self.critic_optimizer.zero_grad()
|
|
||||||
critic_loss.backward()
|
|
||||||
self.critic_optimizer.step()
|
|
||||||
|
|
||||||
self.critic.log(L, step)
|
|
||||||
|
|
||||||
def update_actor(self, obs, L, step):
|
|
||||||
action = self.actor(obs, detach_encoder=True)
|
|
||||||
actor_Q = self.critic(obs, action, detach_encoder=True)
|
|
||||||
actor_loss = -actor_Q.mean()
|
|
||||||
|
|
||||||
self.actor_optimizer.zero_grad()
|
|
||||||
actor_loss.backward()
|
|
||||||
self.actor_optimizer.step()
|
|
||||||
|
|
||||||
self.actor.log(L, step)
|
|
||||||
|
|
||||||
def update(self, replay_buffer, L, step):
|
|
||||||
obs, action, reward, next_obs, not_done = replay_buffer.sample()
|
|
||||||
|
|
||||||
L.log('train/batch_reward', reward.mean(), step)
|
|
||||||
|
|
||||||
self.update_critic(obs, action, reward, next_obs, not_done, L, step)
|
|
||||||
self.update_actor(obs, L, step)
|
|
||||||
|
|
||||||
utils.soft_update_params(self.critic, self.critic_target, self.tau)
|
|
||||||
utils.soft_update_params(self.actor, self.actor_target, self.tau)
|
|
||||||
|
|
||||||
def save(self, model_dir, step):
|
|
||||||
torch.save(
|
|
||||||
self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
|
|
||||||
)
|
|
||||||
torch.save(
|
|
||||||
self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self, model_dir, step):
|
|
||||||
self.actor.load_state_dict(
|
|
||||||
torch.load('%s/actor_%s.pt' % (model_dir, step))
|
|
||||||
)
|
|
||||||
self.critic.load_state_dict(
|
|
||||||
torch.load('%s/critic_%s.pt' % (model_dir, step))
|
|
||||||
)
|
|
34
decoder.py
34
decoder.py
@ -62,45 +62,13 @@ class PixelDecoder(nn.Module):
|
|||||||
L.log_param('train_decoder/fc', self.fc, step)
|
L.log_param('train_decoder/fc', self.fc, step)
|
||||||
|
|
||||||
|
|
||||||
class StateDecoder(nn.Module):
|
_AVAILABLE_DECODERS = {'pixel': PixelDecoder}
|
||||||
def __init__(self, obs_shape, feature_dim):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
assert len(obs_shape) == 1
|
|
||||||
|
|
||||||
self.trunk = nn.Sequential(
|
|
||||||
nn.Linear(feature_dim, 1024), nn.ReLU(), nn.Linear(1024, 1024),
|
|
||||||
nn.ReLU(), nn.Linear(1024, obs_shape[0]), nn.ReLU()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.outputs = dict()
|
|
||||||
|
|
||||||
def forward(self, obs, detach=False):
|
|
||||||
h = self.trunk(obs)
|
|
||||||
if detach:
|
|
||||||
h = h.detach()
|
|
||||||
self.outputs['h'] = h
|
|
||||||
return h
|
|
||||||
|
|
||||||
def log(self, L, step, log_freq):
|
|
||||||
if step % log_freq != 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
L.log_param('train_encoder/fc1', self.trunk[0], step)
|
|
||||||
L.log_param('train_encoder/fc2', self.trunk[2], step)
|
|
||||||
for k, v in self.outputs.items():
|
|
||||||
L.log_histogram('train_encoder/%s_hist' % k, v, step)
|
|
||||||
|
|
||||||
|
|
||||||
_AVAILABLE_DECODERS = {'pixel': PixelDecoder, 'state': StateDecoder}
|
|
||||||
|
|
||||||
|
|
||||||
def make_decoder(
|
def make_decoder(
|
||||||
decoder_type, obs_shape, feature_dim, num_layers, num_filters
|
decoder_type, obs_shape, feature_dim, num_layers, num_filters
|
||||||
):
|
):
|
||||||
assert decoder_type in _AVAILABLE_DECODERS
|
assert decoder_type in _AVAILABLE_DECODERS
|
||||||
if decoder_type == 'pixel':
|
|
||||||
return _AVAILABLE_DECODERS[decoder_type](
|
return _AVAILABLE_DECODERS[decoder_type](
|
||||||
obs_shape, feature_dim, num_layers, num_filters
|
obs_shape, feature_dim, num_layers, num_filters
|
||||||
)
|
)
|
||||||
return _AVAILABLE_DECODERS[decoder_type](obs_shape, feature_dim)
|
|
||||||
|
76
encoder.py
76
encoder.py
@ -13,21 +13,13 @@ OUT_DIM = {2: 39, 4: 35, 6: 31}
|
|||||||
|
|
||||||
class PixelEncoder(nn.Module):
|
class PixelEncoder(nn.Module):
|
||||||
"""Convolutional encoder of pixels observations."""
|
"""Convolutional encoder of pixels observations."""
|
||||||
def __init__(
|
def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32):
|
||||||
self,
|
|
||||||
obs_shape,
|
|
||||||
feature_dim,
|
|
||||||
num_layers=2,
|
|
||||||
num_filters=32,
|
|
||||||
stochastic=False
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert len(obs_shape) == 3
|
assert len(obs_shape) == 3
|
||||||
|
|
||||||
self.feature_dim = feature_dim
|
self.feature_dim = feature_dim
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.stochastic = stochastic
|
|
||||||
|
|
||||||
self.convs = nn.ModuleList(
|
self.convs = nn.ModuleList(
|
||||||
[nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]
|
[nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]
|
||||||
@ -39,13 +31,6 @@ class PixelEncoder(nn.Module):
|
|||||||
self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim)
|
self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim)
|
||||||
self.ln = nn.LayerNorm(self.feature_dim)
|
self.ln = nn.LayerNorm(self.feature_dim)
|
||||||
|
|
||||||
if self.stochastic:
|
|
||||||
self.log_std_min = -10
|
|
||||||
self.log_std_max = 2
|
|
||||||
self.fc_log_std = nn.Linear(
|
|
||||||
num_filters * out_dim * out_dim, self.feature_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
self.outputs = dict()
|
self.outputs = dict()
|
||||||
|
|
||||||
def reparameterize(self, mu, logstd):
|
def reparameterize(self, mu, logstd):
|
||||||
@ -80,17 +65,6 @@ class PixelEncoder(nn.Module):
|
|||||||
self.outputs['ln'] = h_norm
|
self.outputs['ln'] = h_norm
|
||||||
|
|
||||||
out = torch.tanh(h_norm)
|
out = torch.tanh(h_norm)
|
||||||
|
|
||||||
if self.stochastic:
|
|
||||||
self.outputs['mu'] = out
|
|
||||||
log_std = torch.tanh(self.fc_log_std(h))
|
|
||||||
# normalize
|
|
||||||
log_std = self.log_std_min + 0.5 * (
|
|
||||||
self.log_std_max - self.log_std_min
|
|
||||||
) * (log_std + 1)
|
|
||||||
out = self.reparameterize(out, log_std)
|
|
||||||
self.outputs['log_std'] = log_std
|
|
||||||
|
|
||||||
self.outputs['tanh'] = out
|
self.outputs['tanh'] = out
|
||||||
|
|
||||||
return out
|
return out
|
||||||
@ -116,42 +90,8 @@ class PixelEncoder(nn.Module):
|
|||||||
L.log_param('train_encoder/ln', self.ln, step)
|
L.log_param('train_encoder/ln', self.ln, step)
|
||||||
|
|
||||||
|
|
||||||
class StateEncoder(nn.Module):
|
|
||||||
def __init__(self, obs_shape, feature_dim):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
assert len(obs_shape) == 1
|
|
||||||
self.feature_dim = feature_dim
|
|
||||||
|
|
||||||
self.trunk = nn.Sequential(
|
|
||||||
nn.Linear(obs_shape[0], 256), nn.ReLU(),
|
|
||||||
nn.Linear(256, feature_dim), nn.ReLU()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.outputs = dict()
|
|
||||||
|
|
||||||
def forward(self, obs, detach=False):
|
|
||||||
h = self.trunk(obs)
|
|
||||||
if detach:
|
|
||||||
h = h.detach()
|
|
||||||
self.outputs['h'] = h
|
|
||||||
return h
|
|
||||||
|
|
||||||
def copy_conv_weights_from(self, source):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def log(self, L, step, log_freq):
|
|
||||||
if step % log_freq != 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
L.log_param('train_encoder/fc1', self.trunk[0], step)
|
|
||||||
L.log_param('train_encoder/fc2', self.trunk[2], step)
|
|
||||||
for k, v in self.outputs.items():
|
|
||||||
L.log_histogram('train_encoder/%s_hist' % k, v, step)
|
|
||||||
|
|
||||||
|
|
||||||
class IdentityEncoder(nn.Module):
|
class IdentityEncoder(nn.Module):
|
||||||
def __init__(self, obs_shape, feature_dim):
|
def __init__(self, obs_shape, feature_dim, num_layers, num_filters):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert len(obs_shape) == 1
|
assert len(obs_shape) == 1
|
||||||
@ -167,19 +107,13 @@ class IdentityEncoder(nn.Module):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
_AVAILABLE_ENCODERS = {
|
_AVAILABLE_ENCODERS = {'pixel': PixelEncoder, 'identity': IdentityEncoder}
|
||||||
'pixel': PixelEncoder,
|
|
||||||
'state': StateEncoder,
|
|
||||||
'identity': IdentityEncoder
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def make_encoder(
|
def make_encoder(
|
||||||
encoder_type, obs_shape, feature_dim, num_layers, num_filters, stochastic
|
encoder_type, obs_shape, feature_dim, num_layers, num_filters
|
||||||
):
|
):
|
||||||
assert encoder_type in _AVAILABLE_ENCODERS
|
assert encoder_type in _AVAILABLE_ENCODERS
|
||||||
if encoder_type == 'pixel':
|
|
||||||
return _AVAILABLE_ENCODERS[encoder_type](
|
return _AVAILABLE_ENCODERS[encoder_type](
|
||||||
obs_shape, feature_dim, num_layers, num_filters, stochastic
|
obs_shape, feature_dim, num_layers, num_filters
|
||||||
)
|
)
|
||||||
return _AVAILABLE_ENCODERS[encoder_type](obs_shape, feature_dim)
|
|
||||||
|
24
logger.py
24
logger.py
@ -8,19 +8,15 @@ import torchvision
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
|
|
||||||
FORMAT_CONFIG = {
|
FORMAT_CONFIG = {
|
||||||
'rl': {
|
'rl': {
|
||||||
'train': [('episode', 'E', 'int'),
|
'train': [
|
||||||
('step', 'S', 'int'),
|
('episode', 'E', 'int'), ('step', 'S', 'int'),
|
||||||
('duration', 'D', 'time'),
|
('duration', 'D', 'time'), ('episode_reward', 'R', 'float'),
|
||||||
('episode_reward', 'R', 'float'),
|
('batch_reward', 'BR', 'float'), ('actor_loss', 'ALOSS', 'float'),
|
||||||
('batch_reward', 'BR', 'float'),
|
('critic_loss', 'CLOSS', 'float'), ('ae_loss', 'RLOSS', 'float')
|
||||||
('actor_loss', 'ALOSS', 'float'),
|
],
|
||||||
('critic_loss', 'CLOSS', 'float'),
|
'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float')]
|
||||||
('ae_loss', 'RLOSS', 'float')],
|
|
||||||
'eval': [('step', 'S', 'int'),
|
|
||||||
('episode_reward', 'ER', 'float')]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,10 +102,12 @@ class Logger(object):
|
|||||||
self._sw = None
|
self._sw = None
|
||||||
self._train_mg = MetersGroup(
|
self._train_mg = MetersGroup(
|
||||||
os.path.join(log_dir, 'train.log'),
|
os.path.join(log_dir, 'train.log'),
|
||||||
formating=FORMAT_CONFIG[config]['train'])
|
formating=FORMAT_CONFIG[config]['train']
|
||||||
|
)
|
||||||
self._eval_mg = MetersGroup(
|
self._eval_mg = MetersGroup(
|
||||||
os.path.join(log_dir, 'eval.log'),
|
os.path.join(log_dir, 'eval.log'),
|
||||||
formating=FORMAT_CONFIG[config]['eval'])
|
formating=FORMAT_CONFIG[config]['eval']
|
||||||
|
)
|
||||||
|
|
||||||
def _try_sw_log(self, key, value, step):
|
def _try_sw_log(self, key, value, step):
|
||||||
if self._sw is not None:
|
if self._sw is not None:
|
||||||
|
21
run.sh
21
run.sh
@ -1,21 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
DOMAIN=cheetah
|
|
||||||
TASK=run
|
|
||||||
ACTION_REPEAT=4
|
|
||||||
ENCODER_TYPE=pixel
|
|
||||||
ENCODER_TYPE=pixel
|
|
||||||
|
|
||||||
|
|
||||||
WORK_DIR=./runs
|
|
||||||
|
|
||||||
python train.py \
|
|
||||||
--domain_name ${DOMAIN} \
|
|
||||||
--task_name ${TASK} \
|
|
||||||
--encoder_type ${ENCODER_TYPE} \
|
|
||||||
--decoder_type ${DECODER_TYPE} \
|
|
||||||
--action_repeat ${ACTION_REPEAT} \
|
|
||||||
--save_video \
|
|
||||||
--save_tb \
|
|
||||||
--work_dir ${WORK_DIR}/${DOMAIN}_{TASK}/_ae_encoder_${ENCODER_TYPE}_decoder_{ENCODER_TYPE} \
|
|
||||||
--seed 1
|
|
@ -215,8 +215,8 @@ class Critic(nn.Module):
|
|||||||
L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step)
|
L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step)
|
||||||
|
|
||||||
|
|
||||||
class SACAgent(object):
|
class SacAeAgent(object):
|
||||||
"""Soft Actor-Critic algorithm."""
|
"""SAC+AE algorithm."""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
obs_shape,
|
obs_shape,
|
||||||
@ -237,20 +237,17 @@ class SACAgent(object):
|
|||||||
critic_beta=0.9,
|
critic_beta=0.9,
|
||||||
critic_tau=0.005,
|
critic_tau=0.005,
|
||||||
critic_target_update_freq=2,
|
critic_target_update_freq=2,
|
||||||
encoder_type='identity',
|
encoder_type='pixel',
|
||||||
encoder_feature_dim=50,
|
encoder_feature_dim=50,
|
||||||
encoder_lr=1e-3,
|
encoder_lr=1e-3,
|
||||||
encoder_tau=0.005,
|
encoder_tau=0.005,
|
||||||
decoder_type='identity',
|
decoder_type='pixel',
|
||||||
decoder_lr=1e-3,
|
decoder_lr=1e-3,
|
||||||
decoder_update_freq=1,
|
decoder_update_freq=1,
|
||||||
decoder_latent_lambda=0.0,
|
decoder_latent_lambda=0.0,
|
||||||
decoder_weight_lambda=0.0,
|
decoder_weight_lambda=0.0,
|
||||||
decoder_kl_lambda=0.0,
|
|
||||||
num_layers=4,
|
num_layers=4,
|
||||||
num_filters=32,
|
num_filters=32
|
||||||
freeze_encoder=False,
|
|
||||||
use_dynamics=False
|
|
||||||
):
|
):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.discount = discount
|
self.discount = discount
|
||||||
@ -260,11 +257,6 @@ class SACAgent(object):
|
|||||||
self.critic_target_update_freq = critic_target_update_freq
|
self.critic_target_update_freq = critic_target_update_freq
|
||||||
self.decoder_update_freq = decoder_update_freq
|
self.decoder_update_freq = decoder_update_freq
|
||||||
self.decoder_latent_lambda = decoder_latent_lambda
|
self.decoder_latent_lambda = decoder_latent_lambda
|
||||||
self.decoder_kl_lambda = decoder_kl_lambda
|
|
||||||
self.decoder_type = decoder_type
|
|
||||||
self.use_dynamics = use_dynamics
|
|
||||||
|
|
||||||
stochastic = decoder_kl_lambda > 0.0
|
|
||||||
|
|
||||||
self.actor = Actor(
|
self.actor = Actor(
|
||||||
obs_shape, action_shape, hidden_dim, encoder_type,
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
@ -420,9 +412,6 @@ class SACAgent(object):
|
|||||||
self.log_alpha_optimizer.step()
|
self.log_alpha_optimizer.step()
|
||||||
|
|
||||||
def update_decoder(self, obs, target_obs, L, step):
|
def update_decoder(self, obs, target_obs, L, step):
|
||||||
if self.decoder is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
h = self.critic.encoder(obs)
|
h = self.critic.encoder(obs)
|
||||||
|
|
||||||
if target_obs.dim() == 4:
|
if target_obs.dim() == 4:
|
||||||
@ -477,9 +466,8 @@ class SACAgent(object):
|
|||||||
self.encoder_tau
|
self.encoder_tau
|
||||||
)
|
)
|
||||||
|
|
||||||
if step % self.decoder_update_freq == 0:
|
if self.decoder is None and step % self.decoder_update_freq == 0:
|
||||||
target = obs if self.decoder_type == 'pixel' else state
|
self.update_decoder(obs, obs, L, step)
|
||||||
self.update_decoder(obs, target, L, step)
|
|
||||||
|
|
||||||
def save(self, model_dir, step):
|
def save(self, model_dir, step):
|
||||||
torch.save(
|
torch.save(
|
259
td3.py
259
td3.py
@ -1,259 +0,0 @@
|
|||||||
# Code is taken from https://github.com/sfujim/TD3 with slight modifications
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
import utils
|
|
||||||
from encoder import make_encoder
|
|
||||||
|
|
||||||
LOG_FREQ = 10000
|
|
||||||
|
|
||||||
|
|
||||||
class Actor(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.encoder = make_encoder(
|
|
||||||
encoder_type, obs_shape, encoder_feature_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
self.l1 = nn.Linear(self.encoder.feature_dim, 400)
|
|
||||||
self.l2 = nn.Linear(400, 300)
|
|
||||||
self.l3 = nn.Linear(300, action_shape[0])
|
|
||||||
|
|
||||||
self.outputs = dict()
|
|
||||||
|
|
||||||
def forward(self, obs, detach_encoder=False):
|
|
||||||
obs = self.encoder(obs, detach=detach_encoder)
|
|
||||||
h = F.relu(self.l1(obs))
|
|
||||||
h = F.relu(self.l2(h))
|
|
||||||
action = torch.tanh(self.l3(h))
|
|
||||||
self.outputs['mu'] = action
|
|
||||||
return action
|
|
||||||
|
|
||||||
def log(self, L, step, log_freq=LOG_FREQ):
|
|
||||||
if step % log_freq != 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
for k, v in self.outputs.items():
|
|
||||||
L.log_histogram('train_actor/%s_hist' % k, v, step)
|
|
||||||
|
|
||||||
L.log_param('train_actor/fc1', self.l1, step)
|
|
||||||
L.log_param('train_actor/fc2', self.l2, step)
|
|
||||||
L.log_param('train_actor/fc3', self.l3, step)
|
|
||||||
|
|
||||||
|
|
||||||
class Critic(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.encoder = make_encoder(
|
|
||||||
encoder_type, obs_shape, encoder_feature_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
# Q1 architecture
|
|
||||||
self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
|
|
||||||
self.l2 = nn.Linear(400, 300)
|
|
||||||
self.l3 = nn.Linear(300, 1)
|
|
||||||
|
|
||||||
# Q2 architecture
|
|
||||||
self.l4 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
|
|
||||||
self.l5 = nn.Linear(400, 300)
|
|
||||||
self.l6 = nn.Linear(300, 1)
|
|
||||||
|
|
||||||
self.outputs = dict()
|
|
||||||
|
|
||||||
def forward(self, obs, action, detach_encoder=False):
|
|
||||||
obs = self.encoder(obs, detach=detach_encoder)
|
|
||||||
|
|
||||||
obs_action = torch.cat([obs, action], 1)
|
|
||||||
|
|
||||||
h1 = F.relu(self.l1(obs_action))
|
|
||||||
h1 = F.relu(self.l2(h1))
|
|
||||||
q1 = self.l3(h1)
|
|
||||||
|
|
||||||
h2 = F.relu(self.l4(obs_action))
|
|
||||||
h2 = F.relu(self.l5(h2))
|
|
||||||
q2 = self.l6(h2)
|
|
||||||
|
|
||||||
self.outputs['q1'] = q1
|
|
||||||
self.outputs['q2'] = q2
|
|
||||||
|
|
||||||
return q1, q2
|
|
||||||
|
|
||||||
def Q1(self, obs, action, detach_encoder=False):
|
|
||||||
obs = self.encoder(obs, detach=detach_encoder)
|
|
||||||
|
|
||||||
obs_action = torch.cat([obs, action], 1)
|
|
||||||
|
|
||||||
h1 = F.relu(self.l1(obs_action))
|
|
||||||
h1 = F.relu(self.l2(h1))
|
|
||||||
q1 = self.l3(h1)
|
|
||||||
return q1
|
|
||||||
|
|
||||||
def log(self, L, step, log_freq=LOG_FREQ):
|
|
||||||
if step % log_freq != 0: return
|
|
||||||
|
|
||||||
self.encoder.log(L, step, log_freq)
|
|
||||||
|
|
||||||
for k, v in self.outputs.items():
|
|
||||||
L.log_histogram('train_critic/%s_hist' % k, v, step)
|
|
||||||
|
|
||||||
L.log_param('train_critic/q1_fc1', self.l1, step)
|
|
||||||
L.log_param('train_critic/q1_fc2', self.l2, step)
|
|
||||||
L.log_param('train_critic/q1_fc3', self.l3, step)
|
|
||||||
L.log_param('train_critic/q1_fc4', self.l4, step)
|
|
||||||
L.log_param('train_critic/q1_fc5', self.l5, step)
|
|
||||||
L.log_param('train_critic/q1_fc6', self.l6, step)
|
|
||||||
|
|
||||||
|
|
||||||
class TD3Agent(object):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
obs_shape,
|
|
||||||
action_shape,
|
|
||||||
device,
|
|
||||||
discount=0.99,
|
|
||||||
tau=0.005,
|
|
||||||
policy_noise=0.2,
|
|
||||||
noise_clip=0.5,
|
|
||||||
expl_noise=0.1,
|
|
||||||
actor_lr=1e-3,
|
|
||||||
critic_lr=1e-3,
|
|
||||||
encoder_type='identity',
|
|
||||||
encoder_feature_dim=50,
|
|
||||||
actor_update_freq=2,
|
|
||||||
target_update_freq=2,
|
|
||||||
):
|
|
||||||
self.device = device
|
|
||||||
self.discount = discount
|
|
||||||
self.tau = tau
|
|
||||||
self.policy_noise = policy_noise
|
|
||||||
self.noise_clip = noise_clip
|
|
||||||
self.expl_noise = expl_noise
|
|
||||||
self.actor_update_freq = actor_update_freq
|
|
||||||
self.target_update_freq = target_update_freq
|
|
||||||
|
|
||||||
# models
|
|
||||||
self.actor = Actor(
|
|
||||||
obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
self.critic = Critic(
|
|
||||||
obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
|
|
||||||
|
|
||||||
self.actor_target = Actor(
|
|
||||||
obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
).to(device)
|
|
||||||
self.actor_target.load_state_dict(self.actor.state_dict())
|
|
||||||
|
|
||||||
self.critic_target = Critic(
|
|
||||||
obs_shape, action_shape, encoder_type, encoder_feature_dim
|
|
||||||
).to(device)
|
|
||||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
|
||||||
|
|
||||||
# optimizers
|
|
||||||
self.actor_optimizer = torch.optim.Adam(
|
|
||||||
self.actor.parameters(), lr=actor_lr
|
|
||||||
)
|
|
||||||
|
|
||||||
self.critic_optimizer = torch.optim.Adam(
|
|
||||||
self.critic.parameters(), lr=critic_lr
|
|
||||||
)
|
|
||||||
|
|
||||||
self.train()
|
|
||||||
self.critic_target.train()
|
|
||||||
self.actor_target.train()
|
|
||||||
|
|
||||||
def train(self, training=True):
|
|
||||||
self.training = training
|
|
||||||
self.actor.train(training)
|
|
||||||
self.critic.train(training)
|
|
||||||
|
|
||||||
def select_action(self, obs):
|
|
||||||
with torch.no_grad():
|
|
||||||
obs = torch.FloatTensor(obs).to(self.device)
|
|
||||||
obs = obs.unsqueeze(0)
|
|
||||||
action = self.actor(obs)
|
|
||||||
return action.cpu().data.numpy().flatten()
|
|
||||||
|
|
||||||
def sample_action(self, obs):
|
|
||||||
with torch.no_grad():
|
|
||||||
obs = torch.FloatTensor(obs).to(self.device)
|
|
||||||
obs = obs.unsqueeze(0)
|
|
||||||
action = self.actor(obs)
|
|
||||||
noise = torch.randn_like(action) * self.expl_noise
|
|
||||||
action = (action + noise).clamp(-1.0, 1.0)
|
|
||||||
return action.cpu().data.numpy().flatten()
|
|
||||||
|
|
||||||
def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
|
|
||||||
with torch.no_grad():
|
|
||||||
noise = torch.randn_like(action).to(self.device) * self.policy_noise
|
|
||||||
noise = noise.clamp(-self.noise_clip, self.noise_clip)
|
|
||||||
next_action = self.actor_target(next_obs) + noise
|
|
||||||
next_action = next_action.clamp(-1.0, 1.0)
|
|
||||||
target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
|
|
||||||
target_Q = torch.min(target_Q1, target_Q2)
|
|
||||||
target_Q = reward + (not_done * self.discount * target_Q)
|
|
||||||
|
|
||||||
current_Q1, current_Q2 = self.critic(obs, action)
|
|
||||||
|
|
||||||
critic_loss = F.mse_loss(current_Q1,
|
|
||||||
target_Q) + F.mse_loss(current_Q2, target_Q)
|
|
||||||
L.log('train_critic/loss', critic_loss, step)
|
|
||||||
|
|
||||||
self.critic_optimizer.zero_grad()
|
|
||||||
critic_loss.backward()
|
|
||||||
self.critic_optimizer.step()
|
|
||||||
|
|
||||||
self.critic.log(L, step)
|
|
||||||
|
|
||||||
def update_actor(self, obs, L, step):
|
|
||||||
action = self.actor(obs, detach_encoder=True)
|
|
||||||
actor_Q = self.critic.Q1(obs, action, detach_encoder=True)
|
|
||||||
actor_loss = -actor_Q.mean()
|
|
||||||
|
|
||||||
self.actor_optimizer.zero_grad()
|
|
||||||
actor_loss.backward()
|
|
||||||
self.actor_optimizer.step()
|
|
||||||
|
|
||||||
self.actor.log(L, step)
|
|
||||||
|
|
||||||
def update(self, replay_buffer, L, step):
|
|
||||||
obs, action, reward, next_obs, not_done = replay_buffer.sample()
|
|
||||||
|
|
||||||
L.log('train/batch_reward', reward.mean(), step)
|
|
||||||
|
|
||||||
self.update_critic(obs, action, reward, next_obs, not_done, L, step)
|
|
||||||
|
|
||||||
if step % self.actor_update_freq == 0:
|
|
||||||
self.update_actor(obs, L, step)
|
|
||||||
|
|
||||||
if step % self.target_update_freq == 0:
|
|
||||||
utils.soft_update_params(self.critic, self.critic_target, self.tau)
|
|
||||||
utils.soft_update_params(self.actor, self.actor_target, self.tau)
|
|
||||||
|
|
||||||
def save(self, model_dir, step):
|
|
||||||
torch.save(
|
|
||||||
self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
|
|
||||||
)
|
|
||||||
torch.save(
|
|
||||||
self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self, model_dir, step):
|
|
||||||
self.actor.load_state_dict(
|
|
||||||
torch.load('%s/actor_%s.pt' % (model_dir, step))
|
|
||||||
)
|
|
||||||
self.critic.load_state_dict(
|
|
||||||
torch.load('%s/critic_%s.pt' % (model_dir, step))
|
|
||||||
)
|
|
87
train.py
87
train.py
@ -15,9 +15,7 @@ import utils
|
|||||||
from logger import Logger
|
from logger import Logger
|
||||||
from video import VideoRecorder
|
from video import VideoRecorder
|
||||||
|
|
||||||
from sac import SACAgent
|
from sac_ae import SacAeAgent
|
||||||
from td3 import TD3Agent
|
|
||||||
from ddpg import DDPGAgent
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -31,7 +29,7 @@ def parse_args():
|
|||||||
# replay buffer
|
# replay buffer
|
||||||
parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
|
parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
|
||||||
# train
|
# train
|
||||||
parser.add_argument('--agent', default='sac', type=str)
|
parser.add_argument('--agent', default='sac_ae', type=str)
|
||||||
parser.add_argument('--init_steps', default=1000, type=int)
|
parser.add_argument('--init_steps', default=1000, type=int)
|
||||||
parser.add_argument('--num_train_steps', default=1000000, type=int)
|
parser.add_argument('--num_train_steps', default=1000000, type=int)
|
||||||
parser.add_argument('--batch_size', default=512, type=int)
|
parser.add_argument('--batch_size', default=512, type=int)
|
||||||
@ -51,30 +49,22 @@ def parse_args():
|
|||||||
parser.add_argument('--actor_log_std_max', default=2, type=float)
|
parser.add_argument('--actor_log_std_max', default=2, type=float)
|
||||||
parser.add_argument('--actor_update_freq', default=2, type=int)
|
parser.add_argument('--actor_update_freq', default=2, type=int)
|
||||||
# encoder/decoder
|
# encoder/decoder
|
||||||
parser.add_argument('--encoder_type', default='identity', type=str)
|
parser.add_argument('--encoder_type', default='pixel', type=str)
|
||||||
parser.add_argument('--encoder_feature_dim', default=50, type=int)
|
parser.add_argument('--encoder_feature_dim', default=50, type=int)
|
||||||
parser.add_argument('--encoder_lr', default=1e-3, type=float)
|
parser.add_argument('--encoder_lr', default=1e-3, type=float)
|
||||||
parser.add_argument('--encoder_tau', default=0.005, type=float)
|
parser.add_argument('--encoder_tau', default=0.005, type=float)
|
||||||
parser.add_argument('--decoder_type', default='identity', type=str)
|
parser.add_argument('--decoder_type', default='pixel', type=str)
|
||||||
parser.add_argument('--decoder_lr', default=1e-3, type=float)
|
parser.add_argument('--decoder_lr', default=1e-3, type=float)
|
||||||
parser.add_argument('--decoder_update_freq', default=1, type=int)
|
parser.add_argument('--decoder_update_freq', default=1, type=int)
|
||||||
parser.add_argument('--decoder_latent_lambda', default=0.0, type=float)
|
parser.add_argument('--decoder_latent_lambda', default=0.0, type=float)
|
||||||
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
|
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
|
||||||
parser.add_argument('--decoder_kl_lambda', default=0.0, type=float)
|
|
||||||
parser.add_argument('--num_layers', default=4, type=int)
|
parser.add_argument('--num_layers', default=4, type=int)
|
||||||
parser.add_argument('--num_filters', default=32, type=int)
|
parser.add_argument('--num_filters', default=32, type=int)
|
||||||
parser.add_argument('--freeze_encoder', default=False, action='store_true')
|
|
||||||
parser.add_argument('--use_dynamics', default=False, action='store_true')
|
|
||||||
# sac
|
# sac
|
||||||
parser.add_argument('--discount', default=0.99, type=float)
|
parser.add_argument('--discount', default=0.99, type=float)
|
||||||
parser.add_argument('--init_temperature', default=0.01, type=float)
|
parser.add_argument('--init_temperature', default=0.01, type=float)
|
||||||
parser.add_argument('--alpha_lr', default=1e-3, type=float)
|
parser.add_argument('--alpha_lr', default=1e-3, type=float)
|
||||||
parser.add_argument('--alpha_beta', default=0.9, type=float)
|
parser.add_argument('--alpha_beta', default=0.9, type=float)
|
||||||
# td3
|
|
||||||
parser.add_argument('--policy_noise', default=0.2, type=float)
|
|
||||||
parser.add_argument('--expl_noise', default=0.1, type=float)
|
|
||||||
parser.add_argument('--noise_clip', default=0.5, type=float)
|
|
||||||
parser.add_argument('--tau', default=0.005, type=float)
|
|
||||||
# misc
|
# misc
|
||||||
parser.add_argument('--seed', default=1, type=int)
|
parser.add_argument('--seed', default=1, type=int)
|
||||||
parser.add_argument('--work_dir', default='.', type=str)
|
parser.add_argument('--work_dir', default='.', type=str)
|
||||||
@ -82,8 +72,6 @@ def parse_args():
|
|||||||
parser.add_argument('--save_model', default=False, action='store_true')
|
parser.add_argument('--save_model', default=False, action='store_true')
|
||||||
parser.add_argument('--save_buffer', default=False, action='store_true')
|
parser.add_argument('--save_buffer', default=False, action='store_true')
|
||||||
parser.add_argument('--save_video', default=False, action='store_true')
|
parser.add_argument('--save_video', default=False, action='store_true')
|
||||||
parser.add_argument('--pretrained_info', default=None, type=str)
|
|
||||||
parser.add_argument('--pretrained_decoder', default=False, action='store_true')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
@ -108,8 +96,8 @@ def evaluate(env, agent, video, num_episodes, L, step):
|
|||||||
|
|
||||||
|
|
||||||
def make_agent(obs_shape, state_shape, action_shape, args, device):
|
def make_agent(obs_shape, state_shape, action_shape, args, device):
|
||||||
if args.agent == 'sac':
|
if args.agent == 'sac_ae':
|
||||||
return SACAgent(
|
return SacAeAgent(
|
||||||
obs_shape=obs_shape,
|
obs_shape=obs_shape,
|
||||||
state_shape=state_shape,
|
state_shape=state_shape,
|
||||||
action_shape=action_shape,
|
action_shape=action_shape,
|
||||||
@ -137,63 +125,13 @@ def make_agent(obs_shape, state_shape, action_shape, args, device):
|
|||||||
decoder_update_freq=args.decoder_update_freq,
|
decoder_update_freq=args.decoder_update_freq,
|
||||||
decoder_latent_lambda=args.decoder_latent_lambda,
|
decoder_latent_lambda=args.decoder_latent_lambda,
|
||||||
decoder_weight_lambda=args.decoder_weight_lambda,
|
decoder_weight_lambda=args.decoder_weight_lambda,
|
||||||
decoder_kl_lambda=args.decoder_kl_lambda,
|
|
||||||
num_layers=args.num_layers,
|
num_layers=args.num_layers,
|
||||||
num_filters=args.num_filters,
|
num_filters=args.num_filters
|
||||||
freeze_encoder=args.freeze_encoder,
|
|
||||||
use_dynamics=args.use_dynamics
|
|
||||||
)
|
|
||||||
elif args.agent == 'td3':
|
|
||||||
return TD3Agent(
|
|
||||||
obs_shape=obs_shape,
|
|
||||||
action_shape=action_shape,
|
|
||||||
device=device,
|
|
||||||
discount=args.discount,
|
|
||||||
tau=args.tau,
|
|
||||||
policy_noise=args.policy_noise,
|
|
||||||
noise_clip=args.noise_clip,
|
|
||||||
expl_noise=args.expl_noise,
|
|
||||||
actor_lr=args.actor_lr,
|
|
||||||
critic_lr=args.critic_lr,
|
|
||||||
encoder_type=args.encoder_type,
|
|
||||||
encoder_feature_dim=args.encoder_feature_dim,
|
|
||||||
actor_update_freq=args.actor_update_freq,
|
|
||||||
target_update_freq=args.critic_target_update_freq
|
|
||||||
)
|
|
||||||
elif args.agent == 'ddpg':
|
|
||||||
return DDPGAgent(
|
|
||||||
obs_shape=obs_shape,
|
|
||||||
action_shape=action_shape,
|
|
||||||
device=device,
|
|
||||||
discount=args.discount,
|
|
||||||
tau=args.tau,
|
|
||||||
actor_lr=args.actor_lr,
|
|
||||||
critic_lr=args.critic_lr,
|
|
||||||
encoder_type=args.encoder_type,
|
|
||||||
encoder_feature_dim=args.encoder_feature_dim
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert 'agent is not supported: %s' % args.agent
|
assert 'agent is not supported: %s' % args.agent
|
||||||
|
|
||||||
|
|
||||||
def load_pretrained_encoder(agent, pretrained_info, pretrained_decoder):
|
|
||||||
path, version = pretrained_info.split(':')
|
|
||||||
|
|
||||||
pretrained_agent = copy.deepcopy(agent)
|
|
||||||
pretrained_agent.load(path, int(version))
|
|
||||||
agent.critic.encoder.load_state_dict(
|
|
||||||
pretrained_agent.critic.encoder.state_dict()
|
|
||||||
)
|
|
||||||
agent.actor.encoder.load_state_dict(
|
|
||||||
pretrained_agent.actor.encoder.state_dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
if pretrained_decoder:
|
|
||||||
agent.decoder.load_state_dict(pretrained_agent.decoder.state_dict())
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
utils.set_seed_everywhere(args.seed)
|
utils.set_seed_everywhere(args.seed)
|
||||||
@ -232,7 +170,6 @@ def main():
|
|||||||
|
|
||||||
replay_buffer = utils.ReplayBuffer(
|
replay_buffer = utils.ReplayBuffer(
|
||||||
obs_shape=env.observation_space.shape,
|
obs_shape=env.observation_space.shape,
|
||||||
state_shape=env.state_space.shape,
|
|
||||||
action_shape=env.action_space.shape,
|
action_shape=env.action_space.shape,
|
||||||
capacity=args.replay_buffer_capacity,
|
capacity=args.replay_buffer_capacity,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
@ -241,17 +178,11 @@ def main():
|
|||||||
|
|
||||||
agent = make_agent(
|
agent = make_agent(
|
||||||
obs_shape=env.observation_space.shape,
|
obs_shape=env.observation_space.shape,
|
||||||
state_shape=env.state_space.shape,
|
|
||||||
action_shape=env.action_space.shape,
|
action_shape=env.action_space.shape,
|
||||||
args=args,
|
args=args,
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.pretrained_info is not None:
|
|
||||||
agent = load_pretrained_encoder(
|
|
||||||
agent, args.pretrained_info, args.pretrained_decoder
|
|
||||||
)
|
|
||||||
|
|
||||||
L = Logger(args.work_dir, use_tb=args.save_tb)
|
L = Logger(args.work_dir, use_tb=args.save_tb)
|
||||||
|
|
||||||
episode, episode_reward, done = 0, 0, True
|
episode, episode_reward, done = 0, 0, True
|
||||||
@ -295,9 +226,7 @@ def main():
|
|||||||
for _ in range(num_updates):
|
for _ in range(num_updates):
|
||||||
agent.update(replay_buffer, L, step)
|
agent.update(replay_buffer, L, step)
|
||||||
|
|
||||||
state = env.env.env._current_state
|
|
||||||
next_obs, reward, done, _ = env.step(action)
|
next_obs, reward, done, _ = env.step(action)
|
||||||
next_state = env.env.env._current_state.shape
|
|
||||||
|
|
||||||
# allow infinit bootstrap
|
# allow infinit bootstrap
|
||||||
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
|
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
|
||||||
@ -305,7 +234,7 @@ def main():
|
|||||||
)
|
)
|
||||||
episode_reward += reward
|
episode_reward += reward
|
||||||
|
|
||||||
replay_buffer.add(obs, action, reward, next_obs, done_bool, state, next_state)
|
replay_buffer.add(obs, action, reward, next_obs, done_bool)
|
||||||
|
|
||||||
obs = next_obs
|
obs = next_obs
|
||||||
episode_step += 1
|
episode_step += 1
|
||||||
|
20
utils.py
20
utils.py
@ -67,10 +67,7 @@ def preprocess_obs(obs, bits=5):
|
|||||||
|
|
||||||
class ReplayBuffer(object):
|
class ReplayBuffer(object):
|
||||||
"""Buffer to store environment transitions."""
|
"""Buffer to store environment transitions."""
|
||||||
def __init__(
|
def __init__(self, obs_shape, action_shape, capacity, batch_size, device):
|
||||||
self, obs_shape, state_shape, action_shape, capacity, batch_size,
|
|
||||||
device
|
|
||||||
):
|
|
||||||
self.capacity = capacity
|
self.capacity = capacity
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -83,21 +80,17 @@ class ReplayBuffer(object):
|
|||||||
self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
|
self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
|
||||||
self.rewards = np.empty((capacity, 1), dtype=np.float32)
|
self.rewards = np.empty((capacity, 1), dtype=np.float32)
|
||||||
self.not_dones = np.empty((capacity, 1), dtype=np.float32)
|
self.not_dones = np.empty((capacity, 1), dtype=np.float32)
|
||||||
self.states = np.empty((capacity, *state_shape), dtype=np.float32)
|
|
||||||
self.next_states = np.empty((capacity, *state_shape), dtype=np.float32)
|
|
||||||
|
|
||||||
self.idx = 0
|
self.idx = 0
|
||||||
self.last_save = 0
|
self.last_save = 0
|
||||||
self.full = False
|
self.full = False
|
||||||
|
|
||||||
def add(self, obs, action, reward, next_obs, done, state, next_state):
|
def add(self, obs, action, reward, next_obs, done):
|
||||||
np.copyto(self.obses[self.idx], obs)
|
np.copyto(self.obses[self.idx], obs)
|
||||||
np.copyto(self.actions[self.idx], action)
|
np.copyto(self.actions[self.idx], action)
|
||||||
np.copyto(self.rewards[self.idx], reward)
|
np.copyto(self.rewards[self.idx], reward)
|
||||||
np.copyto(self.next_obses[self.idx], next_obs)
|
np.copyto(self.next_obses[self.idx], next_obs)
|
||||||
np.copyto(self.not_dones[self.idx], not done)
|
np.copyto(self.not_dones[self.idx], not done)
|
||||||
np.copyto(self.states[self.idx], state)
|
|
||||||
np.copyto(self.next_states[self.idx], next_state)
|
|
||||||
|
|
||||||
self.idx = (self.idx + 1) % self.capacity
|
self.idx = (self.idx + 1) % self.capacity
|
||||||
self.full = self.full or self.idx == 0
|
self.full = self.full or self.idx == 0
|
||||||
@ -114,9 +107,8 @@ class ReplayBuffer(object):
|
|||||||
self.next_obses[idxs], device=self.device
|
self.next_obses[idxs], device=self.device
|
||||||
).float()
|
).float()
|
||||||
not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
|
not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
|
||||||
states = torch.as_tensor(self.states[idxs], device=self.device)
|
|
||||||
|
|
||||||
return obses, actions, rewards, next_obses, not_dones, states
|
return obses, actions, rewards, next_obses, not_dones
|
||||||
|
|
||||||
def save(self, save_dir):
|
def save(self, save_dir):
|
||||||
if self.idx == self.last_save:
|
if self.idx == self.last_save:
|
||||||
@ -127,9 +119,7 @@ class ReplayBuffer(object):
|
|||||||
self.next_obses[self.last_save:self.idx],
|
self.next_obses[self.last_save:self.idx],
|
||||||
self.actions[self.last_save:self.idx],
|
self.actions[self.last_save:self.idx],
|
||||||
self.rewards[self.last_save:self.idx],
|
self.rewards[self.last_save:self.idx],
|
||||||
self.not_dones[self.last_save:self.idx],
|
self.not_dones[self.last_save:self.idx]
|
||||||
self.states[self.last_save:self.idx],
|
|
||||||
self.next_states[self.last_save:self.idx]
|
|
||||||
]
|
]
|
||||||
self.last_save = self.idx
|
self.last_save = self.idx
|
||||||
torch.save(payload, path)
|
torch.save(payload, path)
|
||||||
@ -147,8 +137,6 @@ class ReplayBuffer(object):
|
|||||||
self.actions[start:end] = payload[2]
|
self.actions[start:end] = payload[2]
|
||||||
self.rewards[start:end] = payload[3]
|
self.rewards[start:end] = payload[3]
|
||||||
self.not_dones[start:end] = payload[4]
|
self.not_dones[start:end] = payload[4]
|
||||||
self.states[start:end] = payload[5]
|
|
||||||
self.next_states[start:end] = payload[6]
|
|
||||||
self.idx = end
|
self.idx = end
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user