Denis Yarats 2019-09-23 11:20:48 -07:00
commit 681e13b12a
12 changed files with 2059 additions and 0 deletions

3
.gitignore vendored Normal file

@@ -0,0 +1,3 @@
__pycache__/
.ipynb_checkpoints/
runs

75
README.md Normal file

@@ -0,0 +1,75 @@
# Soft Actor-Critic implementation in PyTorch
## Running locally
To train SAC locally, use the provided `run_local.sh` script (edit it to change particular arguments):
```
./run_local.sh
```
By default this will create a `./save` folder, where all the output is stored, including train/eval logs, tensorboard event files, evaluation videos, and model snapshots. You can attach tensorboard to a particular run using the following command:
```
tensorboard --logdir save
```
Then open up tensorboard in your browser.
You will also see console output, something like this:
```
| train | E: 1 | S: 1000 | D: 0.8 s | R: 0.0000 | BR: 0.0000 | ALOSS: 0.0000 | CLOSS: 0.0000 | RLOSS: 0.0000
```
This line means:
```
train - training episode
E - total number of episodes
S - total number of environment steps
D - duration in seconds to train 1 episode
R - episode reward
BR - average reward of sampled batch
ALOSS - average loss of actor
CLOSS - average loss of critic
RLOSS - average reconstruction loss (only when training from pixels with a decoder)
```
These are just the most important numbers; all the other metrics can be found in tensorboard.
Besides the training output, evaluation output is printed periodically, like this:
```
| eval | S: 0 | ER: 21.1676
```
This reports the expected reward `ER` of the current policy after `S` environment steps. Note that `ER` is the average evaluation performance over `num_eval_episodes` episodes (usually 10).
## Running on the cluster
You can find the `run_cluster.sh` script that allows you to run training on the cluster. It is a simple bash script that is easy to modify. We usually run 10 different seeds for each configuration to get reliable results. For example, to schedule 10 runs of `walker walk`, simply do this:
```
./run_cluster.sh walker walk
```
This script will schedule 10 jobs and all the output will be stored under `./runs/walker_walk/{configuration_name}/seed_i`. The folder structure looks like this:
```
runs/
walker_walk/
sac_states/
seed_1/
id # slurm job id
stdout # standard output of your job
stderr # standard error of your jobs
run.sh # starting script
run.slrm # slurm script
eval.log # log file for evaluation
train.log # log file for training
tb/ # folder that stores tensorboard output
video/ # folder stores evaluation videos
10000.mp4 # video of one episode after 10000 steps
seed_2/
...
```
Again, you can attach tensorboard to a particular configuration, for example:
```
tensorboard --logdir runs/walker_walk/sac_states
```
For convenience, you can also use an IPython notebook to get results aggregated over 10 seeds. An example of such a notebook is `runs.ipynb`.
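If you want to aggregate results outside the notebook, a minimal sketch along these lines works too (it assumes the folder layout above and that each `eval.log` contains one JSON object per line with `step` and `episode_reward` fields, as written by `logger.py`):
```
import glob
import json
from collections import defaultdict

import numpy as np

# step -> list of evaluation returns, one entry per seed
returns = defaultdict(list)
for log_path in glob.glob('runs/walker_walk/sac_states/seed_*/eval.log'):
    with open(log_path) as f:
        for line in f:
            entry = json.loads(line)
            returns[entry['step']].append(entry['episode_reward'])

for step in sorted(returns):
    rs = returns[step]
    print('step %d: %.1f +/- %.1f (n=%d)' % (step, np.mean(rs), np.std(rs), len(rs)))
```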
## Run entire testbed
Another script that allows running all 10 dm_control tasks on the cluster is here:
```
./run_all.sh
```
It will call `run_cluster.sh` for each task, so you only need to modify `run_cluster.sh` to change the hyperparameters.

209
ddpg.py Normal file

@@ -0,0 +1,209 @@
# Code is taken from https://github.com/sfujim/TD3 with slight modifications
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from encoder import make_encoder
LOG_FREQ = 10000
class Actor(nn.Module):
def __init__(
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
):
super().__init__()
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim
)
self.l1 = nn.Linear(self.encoder.feature_dim, 400)
self.l2 = nn.Linear(400, 300)
self.l3 = nn.Linear(300, action_shape[0])
self.outputs = dict()
def forward(self, obs, detach_encoder=False):
obs = self.encoder(obs, detach=detach_encoder)
h = F.relu(self.l1(obs))
h = F.relu(self.l2(h))
action = torch.tanh(self.l3(h))
self.outputs['mu'] = action
return action
def log(self, L, step, log_freq=LOG_FREQ):
if step % log_freq != 0:
return
for k, v in self.outputs.items():
L.log_histogram('train_actor/%s_hist' % k, v, step)
L.log_param('train_actor/fc1', self.l1, step)
L.log_param('train_actor/fc2', self.l2, step)
L.log_param('train_actor/fc3', self.l3, step)
class Critic(nn.Module):
def __init__(
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
):
super().__init__()
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim
)
self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
self.l2 = nn.Linear(400, 300)
self.l3 = nn.Linear(300, 1)
self.outputs = dict()
def forward(self, obs, action, detach_encoder=False):
obs = self.encoder(obs, detach=detach_encoder)
obs_action = torch.cat([obs, action], dim=1)
h = F.relu(self.l1(obs_action))
h = F.relu(self.l2(h))
q = self.l3(h)
self.outputs['q'] = q
return q
def log(self, L, step, log_freq=LOG_FREQ):
if step % log_freq != 0:
return
self.encoder.log(L, step, log_freq)
for k, v in self.outputs.items():
L.log_histogram('train_critic/%s_hist' % k, v, step)
L.log_param('train_critic/fc1', self.l1, step)
L.log_param('train_critic/fc2', self.l2, step)
L.log_param('train_critic/fc3', self.l3, step)
class DDPGAgent(object):
def __init__(
self,
obs_shape,
action_shape,
device,
discount=0.99,
tau=0.005,
actor_lr=1e-3,
critic_lr=1e-3,
encoder_type='identity',
encoder_feature_dim=50
):
self.device = device
self.discount = discount
self.tau = tau
# models
self.actor = Actor(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.critic = Critic(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
self.actor_target = Actor(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.actor_target.load_state_dict(self.actor.state_dict())
self.critic_target = Critic(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
# optimizers
self.actor_optimizer = torch.optim.Adam(
self.actor.parameters(), lr=actor_lr
)
self.critic_optimizer = torch.optim.Adam(
self.critic.parameters(), lr=critic_lr
)
self.train()
self.critic_target.train()
self.actor_target.train()
def train(self, training=True):
self.training = training
self.actor.train(training)
self.critic.train(training)
def select_action(self, obs):
with torch.no_grad():
obs = torch.FloatTensor(obs).to(self.device)
obs = obs.unsqueeze(0)
action = self.actor(obs)
return action.cpu().data.numpy().flatten()
def sample_action(self, obs):
return self.select_action(obs)
def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
with torch.no_grad():
target_Q = self.critic_target(
next_obs, self.actor_target(next_obs)
)
target_Q = reward + (not_done * self.discount * target_Q)
current_Q = self.critic(obs, action)
critic_loss = F.mse_loss(current_Q, target_Q)
L.log('train_critic/loss', critic_loss, step)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
self.critic.log(L, step)
def update_actor(self, obs, L, step):
action = self.actor(obs, detach_encoder=True)
actor_Q = self.critic(obs, action, detach_encoder=True)
actor_loss = -actor_Q.mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
self.actor.log(L, step)
def update(self, replay_buffer, L, step):
obs, action, reward, next_obs, not_done = replay_buffer.sample()
L.log('train/batch_reward', reward.mean(), step)
self.update_critic(obs, action, reward, next_obs, not_done, L, step)
self.update_actor(obs, L, step)
utils.soft_update_params(self.critic, self.critic_target, self.tau)
utils.soft_update_params(self.actor, self.actor_target, self.tau)
def save(self, model_dir, step):
torch.save(
self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
)
torch.save(
self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
)
def load(self, model_dir, step):
self.actor.load_state_dict(
torch.load('%s/actor_%s.pt' % (model_dir, step))
)
self.critic.load_state_dict(
torch.load('%s/critic_%s.pt' % (model_dir, step))
)

106
decoder.py Normal file

@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
from encoder import OUT_DIM
class PixelDecoder(nn.Module):
def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32):
super().__init__()
self.num_layers = num_layers
self.num_filters = num_filters
self.out_dim = OUT_DIM[num_layers]
self.fc = nn.Linear(
feature_dim, num_filters * self.out_dim * self.out_dim
)
self.deconvs = nn.ModuleList()
for i in range(self.num_layers - 1):
self.deconvs.append(
nn.ConvTranspose2d(num_filters, num_filters, 3, stride=1)
)
self.deconvs.append(
nn.ConvTranspose2d(
num_filters, obs_shape[0], 3, stride=2, output_padding=1
)
)
self.outputs = dict()
def forward(self, h):
h = torch.relu(self.fc(h))
self.outputs['fc'] = h
deconv = h.view(-1, self.num_filters, self.out_dim, self.out_dim)
self.outputs['deconv1'] = deconv
for i in range(0, self.num_layers - 1):
deconv = torch.relu(self.deconvs[i](deconv))
self.outputs['deconv%s' % (i + 1)] = deconv
obs = self.deconvs[-1](deconv)
self.outputs['obs'] = obs
return obs
def log(self, L, step, log_freq):
if step % log_freq != 0:
return
for k, v in self.outputs.items():
L.log_histogram('train_decoder/%s_hist' % k, v, step)
if len(v.shape) > 2:
L.log_image('train_decoder/%s_i' % k, v[0], step)
for i in range(self.num_layers):
L.log_param(
'train_decoder/deconv%s' % (i + 1), self.deconvs[i], step
)
L.log_param('train_decoder/fc', self.fc, step)
class StateDecoder(nn.Module):
def __init__(self, obs_shape, feature_dim):
super().__init__()
assert len(obs_shape) == 1
self.trunk = nn.Sequential(
nn.Linear(feature_dim, 1024), nn.ReLU(), nn.Linear(1024, 1024),
nn.ReLU(), nn.Linear(1024, obs_shape[0]), nn.ReLU()
)
self.outputs = dict()
def forward(self, obs, detach=False):
h = self.trunk(obs)
if detach:
h = h.detach()
self.outputs['h'] = h
return h
def log(self, L, step, log_freq):
if step % log_freq != 0:
return
L.log_param('train_decoder/fc1', self.trunk[0], step)
L.log_param('train_decoder/fc2', self.trunk[2], step)
for k, v in self.outputs.items():
L.log_histogram('train_decoder/%s_hist' % k, v, step)
_AVAILABLE_DECODERS = {'pixel': PixelDecoder, 'state': StateDecoder}
def make_decoder(
decoder_type, obs_shape, feature_dim, num_layers, num_filters
):
assert decoder_type in _AVAILABLE_DECODERS
if decoder_type == 'pixel':
return _AVAILABLE_DECODERS[decoder_type](
obs_shape, feature_dim, num_layers, num_filters
)
return _AVAILABLE_DECODERS[decoder_type](obs_shape, feature_dim)

185
encoder.py Normal file

@@ -0,0 +1,185 @@
import torch
import torch.nn as nn
def tie_weights(src, trg):
assert type(src) == type(trg)
trg.weight = src.weight
trg.bias = src.bias
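# spatial output size of the conv stack, keyed by num_layers (for the default 84x84 pixel observations)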
OUT_DIM = {2: 39, 4: 35, 6: 31}
class PixelEncoder(nn.Module):
"""Convolutional encoder of pixels observations."""
def __init__(
self,
obs_shape,
feature_dim,
num_layers=2,
num_filters=32,
stochastic=False
):
super().__init__()
assert len(obs_shape) == 3
self.feature_dim = feature_dim
self.num_layers = num_layers
self.stochastic = stochastic
self.convs = nn.ModuleList(
[nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]
)
for i in range(num_layers - 1):
self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))
out_dim = OUT_DIM[num_layers]
self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim)
self.ln = nn.LayerNorm(self.feature_dim)
if self.stochastic:
self.log_std_min = -10
self.log_std_max = 2
self.fc_log_std = nn.Linear(
num_filters * out_dim * out_dim, self.feature_dim
)
self.outputs = dict()
def reparameterize(self, mu, logstd):
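# reparameterization trick: sample mu + std * eps with eps ~ N(0, I)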
std = torch.exp(logstd)
eps = torch.randn_like(std)
return mu + eps * std
def forward_conv(self, obs):
obs = obs / 255.
self.outputs['obs'] = obs
conv = torch.relu(self.convs[0](obs))
self.outputs['conv1'] = conv
for i in range(1, self.num_layers):
conv = torch.relu(self.convs[i](conv))
self.outputs['conv%s' % (i + 1)] = conv
h = conv.view(conv.size(0), -1)
return h
def forward(self, obs, detach=False):
h = self.forward_conv(obs)
if detach:
h = h.detach()
h_fc = self.fc(h)
self.outputs['fc'] = h_fc
h_norm = self.ln(h_fc)
self.outputs['ln'] = h_norm
out = torch.tanh(h_norm)
if self.stochastic:
self.outputs['mu'] = out
log_std = torch.tanh(self.fc_log_std(h))
# normalize
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
) * (log_std + 1)
out = self.reparameterize(out, log_std)
self.outputs['log_std'] = log_std
self.outputs['tanh'] = out
return out
def copy_conv_weights_from(self, source):
"""Tie convolutional layers"""
# only tie conv layers
for i in range(self.num_layers):
tie_weights(src=source.convs[i], trg=self.convs[i])
def log(self, L, step, log_freq):
if step % log_freq != 0:
return
for k, v in self.outputs.items():
L.log_histogram('train_encoder/%s_hist' % k, v, step)
if len(v.shape) > 2:
L.log_image('train_encoder/%s_img' % k, v[0], step)
for i in range(self.num_layers):
L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step)
L.log_param('train_encoder/fc', self.fc, step)
L.log_param('train_encoder/ln', self.ln, step)
class StateEncoder(nn.Module):
def __init__(self, obs_shape, feature_dim):
super().__init__()
assert len(obs_shape) == 1
self.feature_dim = feature_dim
self.trunk = nn.Sequential(
nn.Linear(obs_shape[0], 256), nn.ReLU(),
nn.Linear(256, feature_dim), nn.ReLU()
)
self.outputs = dict()
def forward(self, obs, detach=False):
h = self.trunk(obs)
if detach:
h = h.detach()
self.outputs['h'] = h
return h
def copy_conv_weights_from(self, source):
pass
def log(self, L, step, log_freq):
if step % log_freq != 0:
return
L.log_param('train_encoder/fc1', self.trunk[0], step)
L.log_param('train_encoder/fc2', self.trunk[2], step)
for k, v in self.outputs.items():
L.log_histogram('train_encoder/%s_hist' % k, v, step)
class IdentityEncoder(nn.Module):
def __init__(self, obs_shape, feature_dim):
super().__init__()
assert len(obs_shape) == 1
self.feature_dim = obs_shape[0]
def forward(self, obs, detach=False):
return obs
def copy_conv_weights_from(self, source):
pass
def log(self, L, step, log_freq):
pass
_AVAILABLE_ENCODERS = {
'pixel': PixelEncoder,
'state': StateEncoder,
'identity': IdentityEncoder
}
def make_encoder(
encoder_type, obs_shape, feature_dim, num_layers, num_filters, stochastic
):
assert encoder_type in _AVAILABLE_ENCODERS
if encoder_type == 'pixel':
return _AVAILABLE_ENCODERS[encoder_type](
obs_shape, feature_dim, num_layers, num_filters, stochastic
)
return _AVAILABLE_ENCODERS[encoder_type](obs_shape, feature_dim)

165
logger.py Normal file

@@ -0,0 +1,165 @@
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
import json
import os
import shutil
import torch
import torchvision
import numpy as np
from termcolor import colored
FORMAT_CONFIG = {
'rl': {
'train': [('episode', 'E', 'int'),
('step', 'S', 'int'),
('duration', 'D', 'time'),
('episode_reward', 'R', 'float'),
('batch_reward', 'BR', 'float'),
('actor_loss', 'ALOSS', 'float'),
('critic_loss', 'CLOSS', 'float'),
('ae_loss', 'RLOSS', 'float')],
'eval': [('step', 'S', 'int'),
('episode_reward', 'ER', 'float')]
}
}
class AverageMeter(object):
def __init__(self):
self._sum = 0
self._count = 0
def update(self, value, n=1):
self._sum += value
self._count += n
def value(self):
return self._sum / max(1, self._count)
class MetersGroup(object):
def __init__(self, file_name, formating):
self._file_name = file_name
if os.path.exists(file_name):
os.remove(file_name)
self._formating = formating
self._meters = defaultdict(AverageMeter)
def log(self, key, value, n=1):
self._meters[key].update(value, n)
def _prime_meters(self):
data = dict()
for key, meter in self._meters.items():
if key.startswith('train'):
key = key[len('train') + 1:]
else:
key = key[len('eval') + 1:]
key = key.replace('/', '_')
data[key] = meter.value()
return data
def _dump_to_file(self, data):
with open(self._file_name, 'a') as f:
f.write(json.dumps(data) + '\n')
def _format(self, key, value, ty):
template = '%s: '
if ty == 'int':
template += '%d'
elif ty == 'float':
template += '%.04f'
elif ty == 'time':
template += '%.01f s'
else:
raise ValueError('invalid format type: %s' % ty)
return template % (key, value)
def _dump_to_console(self, data, prefix):
prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
pieces = ['{:5}'.format(prefix)]
for key, disp_key, ty in self._formating:
value = data.get(key, 0)
pieces.append(self._format(disp_key, value, ty))
print('| %s' % (' | '.join(pieces)))
def dump(self, step, prefix):
if len(self._meters) == 0:
return
data = self._prime_meters()
data['step'] = step
self._dump_to_file(data)
self._dump_to_console(data, prefix)
self._meters.clear()
class Logger(object):
def __init__(self, log_dir, use_tb=True, config='rl'):
self._log_dir = log_dir
if use_tb:
tb_dir = os.path.join(log_dir, 'tb')
if os.path.exists(tb_dir):
shutil.rmtree(tb_dir)
self._sw = SummaryWriter(tb_dir)
else:
self._sw = None
self._train_mg = MetersGroup(
os.path.join(log_dir, 'train.log'),
formating=FORMAT_CONFIG[config]['train'])
self._eval_mg = MetersGroup(
os.path.join(log_dir, 'eval.log'),
formating=FORMAT_CONFIG[config]['eval'])
def _try_sw_log(self, key, value, step):
if self._sw is not None:
self._sw.add_scalar(key, value, step)
def _try_sw_log_image(self, key, image, step):
if self._sw is not None:
assert image.dim() == 3
grid = torchvision.utils.make_grid(image.unsqueeze(1))
self._sw.add_image(key, grid, step)
def _try_sw_log_video(self, key, frames, step):
if self._sw is not None:
frames = torch.from_numpy(np.array(frames))
frames = frames.unsqueeze(0)
self._sw.add_video(key, frames, step, fps=30)
def _try_sw_log_histogram(self, key, histogram, step):
if self._sw is not None:
self._sw.add_histogram(key, histogram, step)
def log(self, key, value, step, n=1):
assert key.startswith('train') or key.startswith('eval')
if type(value) == torch.Tensor:
value = value.item()
self._try_sw_log(key, value / n, step)
mg = self._train_mg if key.startswith('train') else self._eval_mg
mg.log(key, value, n)
def log_param(self, key, param, step):
self.log_histogram(key + '_w', param.weight.data, step)
if hasattr(param.weight, 'grad') and param.weight.grad is not None:
self.log_histogram(key + '_w_g', param.weight.grad.data, step)
if hasattr(param, 'bias'):
self.log_histogram(key + '_b', param.bias.data, step)
if hasattr(param.bias, 'grad') and param.bias.grad is not None:
self.log_histogram(key + '_b_g', param.bias.grad.data, step)
def log_image(self, key, image, step):
assert key.startswith('train') or key.startswith('eval')
self._try_sw_log_image(key, image, step)
def log_video(self, key, frames, step):
assert key.startswith('train') or key.startswith('eval')
self._try_sw_log_video(key, frames, step)
def log_histogram(self, key, histogram, step):
assert key.startswith('train') or key.startswith('eval')
self._try_sw_log_histogram(key, histogram, step)
def dump(self, step):
self._train_mg.dump(step, 'train')
self._eval_mg.dump(step, 'eval')

21
run.sh Executable file

@@ -0,0 +1,21 @@
#!/bin/bash
DOMAIN=cheetah
TASK=run
ACTION_REPEAT=4
ENCODER_TYPE=pixel
DECODER_TYPE=pixel
WORK_DIR=./runs
python train.py \
--domain_name ${DOMAIN} \
--task_name ${TASK} \
--encoder_type ${ENCODER_TYPE} \
--decoder_type ${DECODER_TYPE} \
--action_repeat ${ACTION_REPEAT} \
--save_video \
--save_tb \
--work_dir ${WORK_DIR}/${DOMAIN}_${TASK}/_ae_encoder_${ENCODER_TYPE}_decoder_${DECODER_TYPE} \
--seed 1

507
sac.py Normal file

@@ -0,0 +1,507 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import utils
from encoder import make_encoder
from decoder import make_decoder
LOG_FREQ = 10000
def gaussian_logprob(noise, log_std):
"""Compute Gaussian log probability."""
residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)
def squash(mu, pi, log_pi):
"""Apply squashing function.
See appendix C from https://arxiv.org/pdf/1812.05905.pdf.
"""
mu = torch.tanh(mu)
if pi is not None:
pi = torch.tanh(pi)
if log_pi is not None:
log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
return mu, pi, log_pi
def weight_init(m):
"""Custom weight init for Conv2D and Linear layers."""
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight.data)
m.bias.data.fill_(0.0)
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
# delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
assert m.weight.size(2) == m.weight.size(3)
m.weight.data.fill_(0.0)
m.bias.data.fill_(0.0)
mid = m.weight.size(2) // 2
gain = nn.init.calculate_gain('relu')
nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
class Actor(nn.Module):
"""MLP actor network."""
def __init__(
self, obs_shape, action_shape, hidden_dim, encoder_type,
encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters,
freeze_encoder, stochastic
):
super().__init__()
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim, num_layers,
num_filters, stochastic
)
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.freeze_encoder = freeze_encoder
self.trunk = nn.Sequential(
nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, 2 * action_shape[0])
)
self.outputs = dict()
self.apply(weight_init)
def forward(
self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False
):
obs = self.encoder(obs, detach=detach_encoder)
if self.freeze_encoder:
obs = obs.detach()
mu, log_std = self.trunk(obs).chunk(2, dim=-1)
# constrain log_std inside [log_std_min, log_std_max]
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
) * (log_std + 1)
self.outputs['mu'] = mu
self.outputs['std'] = log_std.exp()
if compute_pi:
std = log_std.exp()
noise = torch.randn_like(mu)
pi = mu + noise * std
else:
pi = None
entropy = None
if compute_log_pi:
log_pi = gaussian_logprob(noise, log_std)
else:
log_pi = None
mu, pi, log_pi = squash(mu, pi, log_pi)
return mu, pi, log_pi, log_std
def log(self, L, step, log_freq=LOG_FREQ):
if step % log_freq != 0:
return
for k, v in self.outputs.items():
L.log_histogram('train_actor/%s_hist' % k, v, step)
L.log_param('train_actor/fc1', self.trunk[0], step)
L.log_param('train_actor/fc2', self.trunk[2], step)
L.log_param('train_actor/fc3', self.trunk[4], step)
class QFunction(nn.Module):
"""MLP for q-function."""
def __init__(self, obs_dim, action_dim, hidden_dim):
super().__init__()
self.trunk = nn.Sequential(
nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, obs, action):
assert obs.size(0) == action.size(0)
obs_action = torch.cat([obs, action], dim=1)
return self.trunk(obs_action)
class DynamicsModel(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim):
super().__init__()
self.trunk = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, state_dim)
)
def forward(self, state, action):
assert state.size(0) == action.size(0)
state_action = torch.cat([state, action], dim=1)
return self.trunk(state_action)
class Critic(nn.Module):
"""Critic network, employes two q-functions."""
def __init__(
self, obs_shape, action_shape, hidden_dim, encoder_type,
encoder_feature_dim, num_layers, num_filters, freeze_encoder,
use_dynamics, stochastic
):
super().__init__()
self.freeze_encoder = freeze_encoder
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim, num_layers,
num_filters, stochastic
)
if use_dynamics:
self.forward_model = DynamicsModel(
self.encoder.feature_dim, action_shape[0], hidden_dim
)
self.Q1 = QFunction(
self.encoder.feature_dim, action_shape[0], hidden_dim
)
self.Q2 = QFunction(
self.encoder.feature_dim, action_shape[0], hidden_dim
)
self.outputs = dict()
self.apply(weight_init)
def forward(self, obs, action, detach_encoder=False):
# detach_encoder stops gradient propagation into the encoder
obs = self.encoder(obs, detach=detach_encoder)
if self.freeze_encoder:
obs = obs.detach()
q1 = self.Q1(obs, action)
q2 = self.Q2(obs, action)
self.outputs['q1'] = q1
self.outputs['q2'] = q2
return q1, q2
def log(self, L, step, log_freq=LOG_FREQ):
if step % log_freq != 0:
return
self.encoder.log(L, step, log_freq)
for k, v in self.outputs.items():
L.log_histogram('train_critic/%s_hist' % k, v, step)
for i in range(3):
L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step)
L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step)
class SACAgent(object):
"""Soft Actor-Critic algorithm."""
def __init__(
self,
obs_shape,
state_shape,
action_shape,
device,
hidden_dim=256,
discount=0.99,
init_temperature=0.01,
alpha_lr=1e-3,
alpha_beta=0.9,
actor_lr=1e-3,
actor_beta=0.9,
actor_log_std_min=-10,
actor_log_std_max=2,
actor_update_freq=2,
critic_lr=1e-3,
critic_beta=0.9,
critic_tau=0.005,
critic_target_update_freq=2,
encoder_type='identity',
encoder_feature_dim=50,
encoder_lr=1e-3,
encoder_tau=0.005,
decoder_type='identity',
decoder_lr=1e-3,
decoder_update_freq=1,
decoder_latent_lambda=0.0,
decoder_weight_lambda=0.0,
decoder_kl_lambda=0.0,
num_layers=4,
num_filters=32,
freeze_encoder=False,
use_dynamics=False
):
self.device = device
self.discount = discount
self.critic_tau = critic_tau
self.encoder_tau = encoder_tau
self.actor_update_freq = actor_update_freq
self.critic_target_update_freq = critic_target_update_freq
self.decoder_update_freq = decoder_update_freq
self.decoder_latent_lambda = decoder_latent_lambda
self.decoder_kl_lambda = decoder_kl_lambda
self.decoder_type = decoder_type
self.use_dynamics = use_dynamics
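# use a stochastic (VAE-style) encoder only when a KL penalty is applied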
stochastic = decoder_kl_lambda > 0.0
self.actor = Actor(
obs_shape, action_shape, hidden_dim, encoder_type,
encoder_feature_dim, actor_log_std_min, actor_log_std_max,
num_layers, num_filters, freeze_encoder, stochastic
).to(device)
self.critic = Critic(
obs_shape, action_shape, hidden_dim, encoder_type,
encoder_feature_dim, num_layers, num_filters, freeze_encoder,
use_dynamics, stochastic
).to(device)
self.critic_target = Critic(
obs_shape, action_shape, hidden_dim, encoder_type,
encoder_feature_dim, num_layers, num_filters, freeze_encoder,
use_dynamics, stochastic
).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
# tie encoders between actor and critic
self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
self.log_alpha.requires_grad = True
# set target entropy to -|A|
self.target_entropy = -np.prod(action_shape)
self.decoder = None
if decoder_type != 'identity':
# create decoder
shape = obs_shape if decoder_type == 'pixel' else state_shape
self.decoder = make_decoder(
decoder_type, shape, encoder_feature_dim, num_layers,
num_filters
).to(device)
self.decoder.apply(weight_init)
# optimizer for critic encoder for reconstruction loss
self.encoder_optimizer = torch.optim.Adam(
self.critic.encoder.parameters(), lr=encoder_lr
)
# optimizer for decoder
self.decoder_optimizer = torch.optim.Adam(
self.decoder.parameters(),
lr=decoder_lr,
weight_decay=decoder_weight_lambda
)
# optimizers
self.actor_optimizer = torch.optim.Adam(
self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
)
self.critic_optimizer = torch.optim.Adam(
self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
)
self.log_alpha_optimizer = torch.optim.Adam(
[self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
)
self.train()
self.critic_target.train()
def train(self, training=True):
self.training = training
self.actor.train(training)
self.critic.train(training)
if self.decoder is not None:
self.decoder.train(training)
@property
def alpha(self):
return self.log_alpha.exp()
def select_action(self, obs):
with torch.no_grad():
obs = torch.FloatTensor(obs).to(self.device)
obs = obs.unsqueeze(0)
mu, _, _, _ = self.actor(
obs, compute_pi=False, compute_log_pi=False
)
return mu.cpu().data.numpy().flatten()
def sample_action(self, obs):
with torch.no_grad():
obs = torch.FloatTensor(obs).to(self.device)
obs = obs.unsqueeze(0)
mu, pi, _, _ = self.actor(obs, compute_log_pi=False)
return pi.cpu().data.numpy().flatten()
def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
with torch.no_grad():
_, policy_action, log_pi, _ = self.actor(next_obs)
target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
target_V = torch.min(target_Q1,
target_Q2) - self.alpha.detach() * log_pi
target_Q = reward + (not_done * self.discount * target_V)
# get current Q estimates
current_Q1, current_Q2 = self.critic(obs, action)
critic_loss = F.mse_loss(current_Q1,
target_Q) + F.mse_loss(current_Q2, target_Q)
L.log('train_critic/loss', critic_loss, step)
# update dynamics (optional)
if self.use_dynamics:
h_obs = self.critic.encoder.outputs['mu']
with torch.no_grad():
next_latent = self.critic.encoder(next_obs)
pred_next_latent = self.critic.forward_model(h_obs, action)
dynamics_loss = F.mse_loss(pred_next_latent, next_latent)
L.log('train_critic/dynamics_loss', dynamics_loss, step)
critic_loss += dynamics_loss
# Optimize the critic
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
self.critic.log(L, step)
def update_actor_and_alpha(self, obs, L, step):
# detach encoder, so we don't update it with the actor loss
_, pi, log_pi, log_std = self.actor(obs, detach_encoder=True)
actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True)
actor_Q = torch.min(actor_Q1, actor_Q2)
actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()
L.log('train_actor/loss', actor_loss, step)
L.log('train_actor/target_entropy', self.target_entropy, step)
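# entropy of a diagonal Gaussian policy: 0.5 * D * (1 + ln(2*pi)) + sum(log_std)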
entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)
) + log_std.sum(dim=-1)
L.log('train_actor/entropy', entropy.mean(), step)
# optimize the actor
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
self.actor.log(L, step)
self.log_alpha_optimizer.zero_grad()
alpha_loss = (self.alpha *
(-log_pi - self.target_entropy).detach()).mean()
L.log('train_alpha/loss', alpha_loss, step)
L.log('train_alpha/value', self.alpha, step)
alpha_loss.backward()
self.log_alpha_optimizer.step()
def update_decoder(self, obs, target_obs, L, step):
if self.decoder is None:
return
h = self.critic.encoder(obs)
if target_obs.dim() == 4:
# preprocess images to be in [-0.5, 0.5] range
target_obs = utils.preprocess_obs(target_obs)
rec_obs = self.decoder(h)
rec_loss = F.mse_loss(target_obs, rec_obs)
# add L2 penalty on latent representation
# see https://arxiv.org/pdf/1903.12436.pdf
latent_loss = (0.5 * h.pow(2).sum(1)).mean()
# add KL penalty for VAE
if self.decoder_kl_lambda > 0.0:
log_std = self.critic.encoder.outputs['log_std']
mu = self.critic.encoder.outputs['mu']
kl_div = -0.5 * (1 + 2 * log_std - mu.pow(2) - (2 * log_std).exp())
kl_div = kl_div.sum(1).mean(0, True)
else:
kl_div = 0.0
loss = rec_loss + self.decoder_latent_lambda * latent_loss + self.decoder_kl_lambda * kl_div
self.encoder_optimizer.zero_grad()
self.decoder_optimizer.zero_grad()
loss.backward()
self.encoder_optimizer.step()
self.decoder_optimizer.step()
L.log('train_ae/ae_loss', loss, step)
self.decoder.log(L, step, log_freq=LOG_FREQ)
def update(self, replay_buffer, L, step):
obs, action, reward, next_obs, not_done, state = replay_buffer.sample()
L.log('train/batch_reward', reward.mean(), step)
self.update_critic(obs, action, reward, next_obs, not_done, L, step)
if step % self.actor_update_freq == 0:
self.update_actor_and_alpha(obs, L, step)
if step % self.critic_target_update_freq == 0:
utils.soft_update_params(
self.critic.Q1, self.critic_target.Q1, self.critic_tau
)
utils.soft_update_params(
self.critic.Q2, self.critic_target.Q2, self.critic_tau
)
utils.soft_update_params(
self.critic.encoder, self.critic_target.encoder,
self.encoder_tau
)
if step % self.decoder_update_freq == 0:
target = obs if self.decoder_type == 'pixel' else state
self.update_decoder(obs, target, L, step)
def save(self, model_dir, step):
torch.save(
self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
)
torch.save(
self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
)
if self.decoder is not None:
torch.save(
self.decoder.state_dict(),
'%s/decoder_%s.pt' % (model_dir, step)
)
def load(self, model_dir, step):
self.actor.load_state_dict(
torch.load('%s/actor_%s.pt' % (model_dir, step))
)
self.critic.load_state_dict(
torch.load('%s/critic_%s.pt' % (model_dir, step))
)
if self.decoder is not None:
self.decoder.load_state_dict(
torch.load('%s/decoder_%s.pt' % (model_dir, step))
)

259
td3.py Normal file

@@ -0,0 +1,259 @@
# Code is taken from https://github.com/sfujim/TD3 with slight modifications
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from encoder import make_encoder
LOG_FREQ = 10000
class Actor(nn.Module):
def __init__(
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
):
super().__init__()
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim
)
self.l1 = nn.Linear(self.encoder.feature_dim, 400)
self.l2 = nn.Linear(400, 300)
self.l3 = nn.Linear(300, action_shape[0])
self.outputs = dict()
def forward(self, obs, detach_encoder=False):
obs = self.encoder(obs, detach=detach_encoder)
h = F.relu(self.l1(obs))
h = F.relu(self.l2(h))
action = torch.tanh(self.l3(h))
self.outputs['mu'] = action
return action
def log(self, L, step, log_freq=LOG_FREQ):
if step % log_freq != 0:
return
for k, v in self.outputs.items():
L.log_histogram('train_actor/%s_hist' % k, v, step)
L.log_param('train_actor/fc1', self.l1, step)
L.log_param('train_actor/fc2', self.l2, step)
L.log_param('train_actor/fc3', self.l3, step)
class Critic(nn.Module):
def __init__(
self, obs_shape, action_shape, encoder_type, encoder_feature_dim
):
super().__init__()
self.encoder = make_encoder(
encoder_type, obs_shape, encoder_feature_dim
)
# Q1 architecture
self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
self.l2 = nn.Linear(400, 300)
self.l3 = nn.Linear(300, 1)
# Q2 architecture
self.l4 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
self.l5 = nn.Linear(400, 300)
self.l6 = nn.Linear(300, 1)
self.outputs = dict()
def forward(self, obs, action, detach_encoder=False):
obs = self.encoder(obs, detach=detach_encoder)
obs_action = torch.cat([obs, action], 1)
h1 = F.relu(self.l1(obs_action))
h1 = F.relu(self.l2(h1))
q1 = self.l3(h1)
h2 = F.relu(self.l4(obs_action))
h2 = F.relu(self.l5(h2))
q2 = self.l6(h2)
self.outputs['q1'] = q1
self.outputs['q2'] = q2
return q1, q2
def Q1(self, obs, action, detach_encoder=False):
obs = self.encoder(obs, detach=detach_encoder)
obs_action = torch.cat([obs, action], 1)
h1 = F.relu(self.l1(obs_action))
h1 = F.relu(self.l2(h1))
q1 = self.l3(h1)
return q1
def log(self, L, step, log_freq=LOG_FREQ):
if step % log_freq != 0:
return
self.encoder.log(L, step, log_freq)
for k, v in self.outputs.items():
L.log_histogram('train_critic/%s_hist' % k, v, step)
L.log_param('train_critic/q1_fc1', self.l1, step)
L.log_param('train_critic/q1_fc2', self.l2, step)
L.log_param('train_critic/q1_fc3', self.l3, step)
L.log_param('train_critic/q2_fc1', self.l4, step)
L.log_param('train_critic/q2_fc2', self.l5, step)
L.log_param('train_critic/q2_fc3', self.l6, step)
class TD3Agent(object):
def __init__(
self,
obs_shape,
action_shape,
device,
discount=0.99,
tau=0.005,
policy_noise=0.2,
noise_clip=0.5,
expl_noise=0.1,
actor_lr=1e-3,
critic_lr=1e-3,
encoder_type='identity',
encoder_feature_dim=50,
actor_update_freq=2,
target_update_freq=2,
):
self.device = device
self.discount = discount
self.tau = tau
self.policy_noise = policy_noise
self.noise_clip = noise_clip
self.expl_noise = expl_noise
self.actor_update_freq = actor_update_freq
self.target_update_freq = target_update_freq
# models
self.actor = Actor(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.critic = Critic(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
self.actor_target = Actor(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.actor_target.load_state_dict(self.actor.state_dict())
self.critic_target = Critic(
obs_shape, action_shape, encoder_type, encoder_feature_dim
).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
# optimizers
self.actor_optimizer = torch.optim.Adam(
self.actor.parameters(), lr=actor_lr
)
self.critic_optimizer = torch.optim.Adam(
self.critic.parameters(), lr=critic_lr
)
self.train()
self.critic_target.train()
self.actor_target.train()
def train(self, training=True):
self.training = training
self.actor.train(training)
self.critic.train(training)
def select_action(self, obs):
with torch.no_grad():
obs = torch.FloatTensor(obs).to(self.device)
obs = obs.unsqueeze(0)
action = self.actor(obs)
return action.cpu().data.numpy().flatten()
def sample_action(self, obs):
with torch.no_grad():
obs = torch.FloatTensor(obs).to(self.device)
obs = obs.unsqueeze(0)
action = self.actor(obs)
noise = torch.randn_like(action) * self.expl_noise
action = (action + noise).clamp(-1.0, 1.0)
return action.cpu().data.numpy().flatten()
def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
with torch.no_grad():
noise = torch.randn_like(action).to(self.device) * self.policy_noise
noise = noise.clamp(-self.noise_clip, self.noise_clip)
next_action = self.actor_target(next_obs) + noise
next_action = next_action.clamp(-1.0, 1.0)
target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + (not_done * self.discount * target_Q)
current_Q1, current_Q2 = self.critic(obs, action)
critic_loss = F.mse_loss(current_Q1,
target_Q) + F.mse_loss(current_Q2, target_Q)
L.log('train_critic/loss', critic_loss, step)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
self.critic.log(L, step)
def update_actor(self, obs, L, step):
action = self.actor(obs, detach_encoder=True)
actor_Q = self.critic.Q1(obs, action, detach_encoder=True)
actor_loss = -actor_Q.mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
self.actor.log(L, step)
def update(self, replay_buffer, L, step):
obs, action, reward, next_obs, not_done = replay_buffer.sample()
L.log('train/batch_reward', reward.mean(), step)
self.update_critic(obs, action, reward, next_obs, not_done, L, step)
if step % self.actor_update_freq == 0:
self.update_actor(obs, L, step)
if step % self.target_update_freq == 0:
utils.soft_update_params(self.critic, self.critic_target, self.tau)
utils.soft_update_params(self.actor, self.actor_target, self.tau)
def save(self, model_dir, step):
torch.save(
self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
)
torch.save(
self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
)
def load(self, model_dir, step):
self.actor.load_state_dict(
torch.load('%s/actor_%s.pt' % (model_dir, step))
)
self.critic.load_state_dict(
torch.load('%s/critic_%s.pt' % (model_dir, step))
)

315
train.py Normal file

@@ -0,0 +1,315 @@
import numpy as np
import torch
import argparse
import os
import math
import gym
import sys
import random
import time
import json
import dmc2gym
import copy
import utils
from logger import Logger
from video import VideoRecorder
from sac import SACAgent
from td3 import TD3Agent
from ddpg import DDPGAgent
def parse_args():
parser = argparse.ArgumentParser()
# environment
parser.add_argument('--domain_name', default='cheetah')
parser.add_argument('--task_name', default='run')
parser.add_argument('--image_size', default=84, type=int)
parser.add_argument('--action_repeat', default=1, type=int)
parser.add_argument('--frame_stack', default=3, type=int)
# replay buffer
parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
# train
parser.add_argument('--agent', default='sac', type=str)
parser.add_argument('--init_steps', default=1000, type=int)
parser.add_argument('--num_train_steps', default=1000000, type=int)
parser.add_argument('--batch_size', default=512, type=int)
parser.add_argument('--hidden_dim', default=256, type=int)
# eval
parser.add_argument('--eval_freq', default=10000, type=int)
parser.add_argument('--num_eval_episodes', default=10, type=int)
# critic
parser.add_argument('--critic_lr', default=1e-3, type=float)
parser.add_argument('--critic_beta', default=0.9, type=float)
parser.add_argument('--critic_tau', default=0.005, type=float)
parser.add_argument('--critic_target_update_freq', default=2, type=int)
# actor
parser.add_argument('--actor_lr', default=1e-3, type=float)
parser.add_argument('--actor_beta', default=0.9, type=float)
parser.add_argument('--actor_log_std_min', default=-10, type=float)
parser.add_argument('--actor_log_std_max', default=2, type=float)
parser.add_argument('--actor_update_freq', default=2, type=int)
# encoder/decoder
parser.add_argument('--encoder_type', default='identity', type=str)
parser.add_argument('--encoder_feature_dim', default=50, type=int)
parser.add_argument('--encoder_lr', default=1e-3, type=float)
parser.add_argument('--encoder_tau', default=0.005, type=float)
parser.add_argument('--decoder_type', default='identity', type=str)
parser.add_argument('--decoder_lr', default=1e-3, type=float)
parser.add_argument('--decoder_update_freq', default=1, type=int)
parser.add_argument('--decoder_latent_lambda', default=0.0, type=float)
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
parser.add_argument('--decoder_kl_lambda', default=0.0, type=float)
parser.add_argument('--num_layers', default=4, type=int)
parser.add_argument('--num_filters', default=32, type=int)
parser.add_argument('--freeze_encoder', default=False, action='store_true')
parser.add_argument('--use_dynamics', default=False, action='store_true')
# sac
parser.add_argument('--discount', default=0.99, type=float)
parser.add_argument('--init_temperature', default=0.01, type=float)
parser.add_argument('--alpha_lr', default=1e-3, type=float)
parser.add_argument('--alpha_beta', default=0.9, type=float)
# td3
parser.add_argument('--policy_noise', default=0.2, type=float)
parser.add_argument('--expl_noise', default=0.1, type=float)
parser.add_argument('--noise_clip', default=0.5, type=float)
parser.add_argument('--tau', default=0.005, type=float)
# misc
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--work_dir', default='.', type=str)
parser.add_argument('--save_tb', default=False, action='store_true')
parser.add_argument('--save_model', default=False, action='store_true')
parser.add_argument('--save_buffer', default=False, action='store_true')
parser.add_argument('--save_video', default=False, action='store_true')
parser.add_argument('--pretrained_info', default=None, type=str)
parser.add_argument('--pretrained_decoder', default=False, action='store_true')
args = parser.parse_args()
return args
def evaluate(env, agent, video, num_episodes, L, step):
for i in range(num_episodes):
obs = env.reset()
video.init(enabled=(i == 0))
done = False
episode_reward = 0
while not done:
with utils.eval_mode(agent):
action = agent.select_action(obs)
obs, reward, done, _ = env.step(action)
video.record(env)
episode_reward += reward
video.save('%d.mp4' % step)
L.log('eval/episode_reward', episode_reward, step)
L.dump(step)
def make_agent(obs_shape, state_shape, action_shape, args, device):
if args.agent == 'sac':
return SACAgent(
obs_shape=obs_shape,
state_shape=state_shape,
action_shape=action_shape,
device=device,
hidden_dim=args.hidden_dim,
discount=args.discount,
init_temperature=args.init_temperature,
alpha_lr=args.alpha_lr,
alpha_beta=args.alpha_beta,
actor_lr=args.actor_lr,
actor_beta=args.actor_beta,
actor_log_std_min=args.actor_log_std_min,
actor_log_std_max=args.actor_log_std_max,
actor_update_freq=args.actor_update_freq,
critic_lr=args.critic_lr,
critic_beta=args.critic_beta,
critic_tau=args.critic_tau,
critic_target_update_freq=args.critic_target_update_freq,
encoder_type=args.encoder_type,
encoder_feature_dim=args.encoder_feature_dim,
encoder_lr=args.encoder_lr,
encoder_tau=args.encoder_tau,
decoder_type=args.decoder_type,
decoder_lr=args.decoder_lr,
decoder_update_freq=args.decoder_update_freq,
decoder_latent_lambda=args.decoder_latent_lambda,
decoder_weight_lambda=args.decoder_weight_lambda,
decoder_kl_lambda=args.decoder_kl_lambda,
num_layers=args.num_layers,
num_filters=args.num_filters,
freeze_encoder=args.freeze_encoder,
use_dynamics=args.use_dynamics
)
elif args.agent == 'td3':
return TD3Agent(
obs_shape=obs_shape,
action_shape=action_shape,
device=device,
discount=args.discount,
tau=args.tau,
policy_noise=args.policy_noise,
noise_clip=args.noise_clip,
expl_noise=args.expl_noise,
actor_lr=args.actor_lr,
critic_lr=args.critic_lr,
encoder_type=args.encoder_type,
encoder_feature_dim=args.encoder_feature_dim,
actor_update_freq=args.actor_update_freq,
target_update_freq=args.critic_target_update_freq
)
elif args.agent == 'ddpg':
return DDPGAgent(
obs_shape=obs_shape,
action_shape=action_shape,
device=device,
discount=args.discount,
tau=args.tau,
actor_lr=args.actor_lr,
critic_lr=args.critic_lr,
encoder_type=args.encoder_type,
encoder_feature_dim=args.encoder_feature_dim
)
else:
assert False, 'agent is not supported: %s' % args.agent
def load_pretrained_encoder(agent, pretrained_info, pretrained_decoder):
path, version = pretrained_info.split(':')
pretrained_agent = copy.deepcopy(agent)
pretrained_agent.load(path, int(version))
agent.critic.encoder.load_state_dict(
pretrained_agent.critic.encoder.state_dict()
)
agent.actor.encoder.load_state_dict(
pretrained_agent.actor.encoder.state_dict()
)
if pretrained_decoder:
agent.decoder.load_state_dict(pretrained_agent.decoder.state_dict())
return agent
def main():
args = parse_args()
utils.set_seed_everywhere(args.seed)
env = dmc2gym.make(
domain_name=args.domain_name,
task_name=args.task_name,
seed=args.seed,
visualize_reward=False,
from_pixels=(args.encoder_type == 'pixel'),
height=args.image_size,
width=args.image_size,
frame_skip=args.action_repeat
)
env.seed(args.seed)
# stack several consecutive frames together
if args.encoder_type == 'pixel':
env = utils.FrameStack(env, k=args.frame_stack)
utils.make_dir(args.work_dir)
video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))
video = VideoRecorder(video_dir if args.save_video else None)
with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
json.dump(vars(args), f, sort_keys=True, indent=4)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# the dmc2gym wrapper standardizes actions
assert env.action_space.low.min() >= -1
assert env.action_space.high.max() <= 1
replay_buffer = utils.ReplayBuffer(
obs_shape=env.observation_space.shape,
state_shape=env.state_space.shape,
action_shape=env.action_space.shape,
capacity=args.replay_buffer_capacity,
batch_size=args.batch_size,
device=device
)
agent = make_agent(
obs_shape=env.observation_space.shape,
state_shape=env.state_space.shape,
action_shape=env.action_space.shape,
args=args,
device=device
)
if args.pretrained_info is not None:
agent = load_pretrained_encoder(
agent, args.pretrained_info, args.pretrained_decoder
)
L = Logger(args.work_dir, use_tb=args.save_tb)
episode, episode_reward, done = 0, 0, True
start_time = time.time()
for step in range(args.num_train_steps):
if done:
if step > 0:
L.log('train/duration', time.time() - start_time, step)
start_time = time.time()
L.dump(step)
# evaluate agent periodically
if step % args.eval_freq == 0:
L.log('eval/episode', episode, step)
evaluate(env, agent, video, args.num_eval_episodes, L, step)
if args.save_model:
agent.save(model_dir, step)
if args.save_buffer:
replay_buffer.save(buffer_dir)
L.log('train/episode_reward', episode_reward, step)
obs = env.reset()
done = False
episode_reward = 0
episode_step = 0
episode += 1
L.log('train/episode', episode, step)
# sample action for data collection
if step < args.init_steps:
action = env.action_space.sample()
else:
with utils.eval_mode(agent):
action = agent.sample_action(obs)
# run training update
if step >= args.init_steps:
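# when the warm-up phase ends (step == init_steps) do a burst of init_steps updates, then one update per environment step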
num_updates = args.init_steps if step == args.init_steps else 1
for _ in range(num_updates):
agent.update(replay_buffer, L, step)
state = env.env.env._current_state
next_obs, reward, done, _ = env.step(action)
next_state = env.env.env._current_state
# allow infinite bootstrap
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
done
)
episode_reward += reward
replay_buffer.add(obs, action, reward, next_obs, done_bool, state, next_state)
obs = next_obs
episode_step += 1
if __name__ == '__main__':
main()

182
utils.py Normal file

@@ -0,0 +1,182 @@
import torch
import numpy as np
import torch.nn as nn
import gym
import os
from collections import deque
import random
class eval_mode(object):
def __init__(self, *models):
self.models = models
def __enter__(self):
self.prev_states = []
for model in self.models:
self.prev_states.append(model.training)
model.train(False)
def __exit__(self, *args):
for model, state in zip(self.models, self.prev_states):
model.train(state)
return False
def soft_update_params(net, target_net, tau):
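# Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target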
for param, target_param in zip(net.parameters(), target_net.parameters()):
target_param.data.copy_(
tau * param.data + (1 - tau) * target_param.data
)
def set_seed_everywhere(seed):
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def module_hash(module):
result = 0
for tensor in module.state_dict().values():
result += tensor.sum().item()
return result
def make_dir(dir_path):
try:
os.mkdir(dir_path)
except OSError:
pass
return dir_path
def preprocess_obs(obs, bits=5):
"""Preprocessing image, see https://arxiv.org/abs/1807.03039."""
bins = 2**bits
assert obs.dtype == torch.float32
if bits < 8:
obs = torch.floor(obs / 2**(8 - bits))
obs = obs / bins
obs = obs + torch.rand_like(obs) / bins
obs = obs - 0.5
return obs
class ReplayBuffer(object):
"""Buffer to store environment transitions."""
def __init__(
self, obs_shape, state_shape, action_shape, capacity, batch_size,
device
):
self.capacity = capacity
self.batch_size = batch_size
self.device = device
# the proprioceptive obs is stored as float32, pixels obs as uint8
obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8
self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
self.rewards = np.empty((capacity, 1), dtype=np.float32)
self.not_dones = np.empty((capacity, 1), dtype=np.float32)
self.states = np.empty((capacity, *state_shape), dtype=np.float32)
self.next_states = np.empty((capacity, *state_shape), dtype=np.float32)
self.idx = 0
self.last_save = 0
self.full = False
def add(self, obs, action, reward, next_obs, done, state, next_state):
np.copyto(self.obses[self.idx], obs)
np.copyto(self.actions[self.idx], action)
np.copyto(self.rewards[self.idx], reward)
np.copyto(self.next_obses[self.idx], next_obs)
np.copyto(self.not_dones[self.idx], not done)
np.copyto(self.states[self.idx], state)
np.copyto(self.next_states[self.idx], next_state)
self.idx = (self.idx + 1) % self.capacity
self.full = self.full or self.idx == 0
def sample(self):
idxs = np.random.randint(
0, self.capacity if self.full else self.idx, size=self.batch_size
)
obses = torch.as_tensor(self.obses[idxs], device=self.device).float()
actions = torch.as_tensor(self.actions[idxs], device=self.device)
rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
next_obses = torch.as_tensor(
self.next_obses[idxs], device=self.device
).float()
not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
states = torch.as_tensor(self.states[idxs], device=self.device)
return obses, actions, rewards, next_obses, not_dones, states
def save(self, save_dir):
if self.idx == self.last_save:
return
path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
payload = [
self.obses[self.last_save:self.idx],
self.next_obses[self.last_save:self.idx],
self.actions[self.last_save:self.idx],
self.rewards[self.last_save:self.idx],
self.not_dones[self.last_save:self.idx],
self.states[self.last_save:self.idx],
self.next_states[self.last_save:self.idx]
]
self.last_save = self.idx
torch.save(payload, path)
def load(self, save_dir):
chunks = os.listdir(save_dir)
chunks = sorted(chunks, key=lambda x: int(x.split('_')[0]))
for chunk in chunks:
start, end = [int(x) for x in chunk.split('.')[0].split('_')]
path = os.path.join(save_dir, chunk)
payload = torch.load(path)
assert self.idx == start
self.obses[start:end] = payload[0]
self.next_obses[start:end] = payload[1]
self.actions[start:end] = payload[2]
self.rewards[start:end] = payload[3]
self.not_dones[start:end] = payload[4]
self.states[start:end] = payload[5]
self.next_states[start:end] = payload[6]
self.idx = end
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
gym.Wrapper.__init__(self, env)
self._k = k
self._frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = gym.spaces.Box(
low=0,
high=1,
shape=((shp[0] * k,) + shp[1:]),
dtype=env.observation_space.dtype
)
self._max_episode_steps = env._max_episode_steps
def reset(self):
obs = self.env.reset()
for _ in range(self._k):
self._frames.append(obs)
return self._get_obs()
def step(self, action):
obs, reward, done, info = self.env.step(action)
self._frames.append(obs)
return self._get_obs(), reward, done, info
def _get_obs(self):
assert len(self._frames) == self._k
return np.concatenate(list(self._frames), axis=0)

32
video.py Normal file

@@ -0,0 +1,32 @@
import imageio
import os
import numpy as np
class VideoRecorder(object):
def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30):
self.dir_name = dir_name
self.height = height
self.width = width
self.camera_id = camera_id
self.fps = fps
self.frames = []
def init(self, enabled=True):
self.frames = []
self.enabled = self.dir_name is not None and enabled
def record(self, env):
if self.enabled:
frame = env.render(
mode='rgb_array',
height=self.height,
width=self.width,
camera_id=self.camera_id
)
self.frames.append(frame)
def save(self, file_name):
if self.enabled:
path = os.path.join(self.dir_name, file_name)
imageio.mimsave(path, self.frames, fps=self.fps)