init
This commit is contained in:
commit
681e13b12a
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
__pycache__/
|
||||
.ipynb_checkpoints/
|
||||
runs
|
75
README.md
Normal file
75
README.md
Normal file
@ -0,0 +1,75 @@
|
||||
# Soft Actor-Critic implementation in PyTorch
|
||||
|
||||
|
||||
## Running locally
|
||||
To train SAC locally, use the provided `run_local.sh` script (edit it to change particular arguments):
|
||||
```
|
||||
./run_local.sh
|
||||
```
|
||||
This will produce a folder (`./save`) by default, where all the output is going to be stored including train/eval logs, tensorboard blobs, evaluation videos, and model snapshots. It is possible to attach tensorboard to a particular run using the following command:
|
||||
```
|
||||
tensorboard --logdir save
|
||||
```
|
||||
Then open up tensorboard in your browser.
|
||||
|
||||
You will also see some console output, something like this:
|
||||
```
|
||||
| train | E: 1 | S: 1000 | D: 0.8 s | R: 0.0000 | BR: 0.0000 | ALOSS: 0.0000 | CLOSS: 0.0000 | RLOSS: 0.0000
|
||||
```
|
||||
This line means:
|
||||
```
|
||||
train - training episode
|
||||
E - total number of episodes
|
||||
S - total number of environment steps
|
||||
D - duration in seconds to train 1 episode
|
||||
R - episode reward
|
||||
BR - average reward of sampled batch
|
||||
ALOSS - average loss of actor
|
||||
CLOSS - average loss of critic
|
||||
RLOSS - average reconstruction loss (only if trained from pixels with a decoder)
|
||||
```
|
||||
These are just the most important numbers; all other metrics can be found in tensorboard.
|
||||
Also, besides training, once in a while there is evaluation output, like this:
|
||||
```
|
||||
| eval | S: 0 | ER: 21.1676
|
||||
```
|
||||
This reports the expected reward `ER` obtained by evaluating the current policy after `S` steps. Note that `ER` is the average evaluation performance over `num_eval_episodes` episodes (usually 10).
|
||||
|
||||
## Running on the cluster
|
||||
You can find the `run_cluster.sh` script file that allows you to run training on the cluster. It is a simple bash script that is easy to modify. We usually run 10 different seeds for each configuration to get reliable results. For example, to schedule 10 runs of `walker walk`, simply do this:
|
||||
```
|
||||
./run_cluster.sh walker walk
|
||||
```
|
||||
This script will schedule 10 jobs and all the output will be stored under `./runs/walker_walk/{configuration_name}/seed_i`. The folder structure looks like this:
|
||||
```
|
||||
runs/
|
||||
walker_walk/
|
||||
sac_states/
|
||||
seed_1/
|
||||
id # slurm job id
|
||||
stdout # standard output of your job
|
||||
stderr # standard error of your jobs
|
||||
run.sh # starting script
|
||||
run.slrm # slurm script
|
||||
eval.log # log file for evaluation
|
||||
train.log # log file for training
|
||||
tb/ # folder that stores tensorboard output
|
||||
video/ # folder stores evaluation videos
|
||||
10000.mp4 # video of one episode after 10000 steps
|
||||
seed_2/
|
||||
...
|
||||
```
|
||||
Again, you can attach tensorboard to a particular configuration, for example:
|
||||
```
|
||||
tensorboard --logdir runs/walker_walk/sac_states
|
||||
```
|
||||
|
||||
For convenience, you can also use an iPython notebook to get results aggregated over 10 seeds. An example of such a notebook is `runs.ipynb`.
|
||||
|
||||
|
||||
## Run entire testbed
|
||||
Another script that allows you to run all 10 dm_control tasks on the cluster is here:
|
||||
```
|
||||
./run_all.sh
|
||||
```
|
||||
It will call `run_cluster.sh` for each task, so you only need to modify `run_cluster.sh` to change the hyper parameters.
|
209
ddpg.py
Normal file
209
ddpg.py
Normal file
@ -0,0 +1,209 @@
|
||||
# Code is taken from https://github.com/sfujim/TD3 with slight modifications
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import utils
|
||||
from encoder import make_encoder
|
||||
|
||||
# Default logging period, in steps: the log() methods in this module return
# early unless step % log_freq == 0.
LOG_FREQ = 10000
|
||||
|
||||
|
||||
class Actor(nn.Module):
    """Deterministic policy: an encoder followed by a three-layer MLP.

    The last layer is squashed with ``tanh``, so every action component
    lies in [-1, 1].
    """

    def __init__(
        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim
        )

        action_dim = action_shape[0]
        self.l1 = nn.Linear(self.encoder.feature_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)

        # Tensors from the latest forward pass, kept for histogram logging.
        self.outputs = dict()

    def forward(self, obs, detach_encoder=False):
        """Return the action for ``obs``.

        ``detach_encoder`` is forwarded to the encoder as ``detach``.
        """
        features = self.encoder(obs, detach=detach_encoder)
        hidden = F.relu(self.l1(features))
        hidden = F.relu(self.l2(hidden))
        mu = torch.tanh(self.l3(hidden))
        self.outputs['mu'] = mu
        return mu

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log output histograms and layer parameters every ``log_freq`` steps."""
        if step % log_freq:
            return

        for name, tensor in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % name, tensor, step)

        layers = (
            ('train_actor/fc1', self.l1),
            ('train_actor/fc2', self.l2),
            ('train_actor/fc3', self.l3),
        )
        for key, layer in layers:
            L.log_param(key, layer, step)
|
||||
|
||||
|
||||
class Critic(nn.Module):
    """Single Q-network: an encoder followed by a three-layer MLP.

    The encoded observation is concatenated with the action before the MLP.
    """

    def __init__(
        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim
        )

        input_dim = self.encoder.feature_dim + action_shape[0]
        self.l1 = nn.Linear(input_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

        # Tensors from the latest forward pass, kept for histogram logging.
        self.outputs = dict()

    def forward(self, obs, action, detach_encoder=False):
        """Return Q(obs, action).

        ``detach_encoder`` is forwarded to the encoder as ``detach``.
        """
        features = self.encoder(obs, detach=detach_encoder)
        hidden = F.relu(self.l1(torch.cat([features, action], dim=1)))
        hidden = F.relu(self.l2(hidden))
        q_value = self.l3(hidden)
        self.outputs['q'] = q_value
        return q_value

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log encoder stats, output histograms and layer parameters."""
        if step % log_freq:
            return

        self.encoder.log(L, step, log_freq)

        for name, tensor in self.outputs.items():
            L.log_histogram('train_critic/%s_hist' % name, tensor, step)

        layers = (
            ('train_critic/fc1', self.l1),
            ('train_critic/fc2', self.l2),
            ('train_critic/fc3', self.l3),
        )
        for key, layer in layers:
            L.log_param(key, layer, step)
|
||||
|
||||
|
||||
class DDPGAgent(object):
    """DDPG agent: an actor, a critic, and a slowly-updated target copy of each.

    Args:
        obs_shape: observation shape, passed to the networks' encoders.
        action_shape: action shape; ``action_shape[0]`` is the action size.
        device: torch device on which all networks are placed.
        discount: discount factor used in the TD target.
        tau: coefficient passed to ``utils.soft_update_params`` for the
            target networks.
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for the critic.
        encoder_type: encoder name passed to ``make_encoder``.
        encoder_feature_dim: encoder feature size passed to ``make_encoder``.
    """

    def __init__(
        self,
        obs_shape,
        action_shape,
        device,
        discount=0.99,
        tau=0.005,
        actor_lr=1e-3,
        critic_lr=1e-3,
        encoder_type='identity',
        encoder_feature_dim=50
    ):
        self.device = device
        self.discount = discount
        self.tau = tau

        # models
        self.actor = Actor(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)

        self.critic = Critic(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)

        # Let the actor's encoder reuse the critic encoder's conv layers.
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

        # Target networks start as exact copies of the online networks.
        self.actor_target = Actor(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())

        self.critic_target = Critic(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr
        )

        self.train()
        self.critic_target.train()
        self.actor_target.train()

    def train(self, training=True):
        """Switch the actor and critic between train and eval mode."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    def select_action(self, obs):
        """Return the actor's action for a single observation as a flat numpy array."""
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)  # add a batch dimension
            action = self.actor(obs)
            return action.cpu().data.numpy().flatten()

    def sample_action(self, obs):
        """Same as ``select_action``; this agent adds no exploration noise here."""
        return self.select_action(obs)

    def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
        """Take one critic gradient step on the MSE to the one-step TD target."""
        with torch.no_grad():
            target_Q = self.critic_target(
                next_obs, self.actor_target(next_obs)
            )
            target_Q = reward + (not_done * self.discount * target_Q)

        current_Q = self.critic(obs, action)

        critic_loss = F.mse_loss(current_Q, target_Q)
        L.log('train_critic/loss', critic_loss, step)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(L, step)

    def update_actor(self, obs, L, step):
        """Take one actor gradient step that maximises the critic's Q estimate."""
        # Both encoders are asked to detach, so this loss does not train them.
        action = self.actor(obs, detach_encoder=True)
        actor_Q = self.critic(obs, action, detach_encoder=True)
        actor_loss = -actor_Q.mean()
        # Log the loss like update_critic does; otherwise the 'actor_loss'
        # (ALOSS) console column is never populated for this agent.
        L.log('train_actor/loss', actor_loss, step)

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(L, step)

    def update(self, replay_buffer, L, step):
        """Sample a batch, update critic then actor, then soft-update the targets."""
        obs, action, reward, next_obs, not_done = replay_buffer.sample()

        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)
        self.update_actor(obs, L, step)

        utils.soft_update_params(self.critic, self.critic_target, self.tau)
        utils.soft_update_params(self.actor, self.actor_target, self.tau)

    def save(self, model_dir, step):
        """Write actor and critic state dicts to ``model_dir``, tagged with ``step``."""
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
        )

    def load(self, model_dir, step):
        """Load actor and critic state dicts previously written by ``save``."""
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored onto this agent's device.
        self.actor.load_state_dict(
            torch.load(
                '%s/actor_%s.pt' % (model_dir, step),
                map_location=self.device
            )
        )
        self.critic.load_state_dict(
            torch.load(
                '%s/critic_%s.pt' % (model_dir, step),
                map_location=self.device
            )
        )
|
106
decoder.py
Normal file
106
decoder.py
Normal file
@ -0,0 +1,106 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from encoder import OUT_DIM
|
||||
|
||||
|
||||
class PixelDecoder(nn.Module):
    """Transposed-convolution decoder mapping a feature vector to an image.

    A linear layer expands the features to a
    ``(num_filters, out_dim, out_dim)`` map, which then passes through
    ``num_layers`` transposed convolutions; the last one uses stride 2 and
    produces ``obs_shape[0]`` channels.
    """

    def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32):
        super().__init__()

        self.num_layers = num_layers
        self.num_filters = num_filters
        self.out_dim = OUT_DIM[num_layers]

        flat_dim = num_filters * self.out_dim * self.out_dim
        self.fc = nn.Linear(feature_dim, flat_dim)

        # (num_layers - 1) stride-1 layers followed by the stride-2 output layer.
        hidden_layers = [
            nn.ConvTranspose2d(num_filters, num_filters, 3, stride=1)
            for _ in range(self.num_layers - 1)
        ]
        output_layer = nn.ConvTranspose2d(
            num_filters, obs_shape[0], 3, stride=2, output_padding=1
        )
        self.deconvs = nn.ModuleList(hidden_layers + [output_layer])

        # Tensors from the latest forward pass, kept for logging.
        self.outputs = dict()

    def forward(self, h):
        """Decode feature batch ``h`` into an observation tensor."""
        h = torch.relu(self.fc(h))
        self.outputs['fc'] = h

        x = h.view(-1, self.num_filters, self.out_dim, self.out_dim)
        # NOTE: this key is overwritten by the first loop iteration below
        # whenever num_layers > 1.
        self.outputs['deconv1'] = x

        for idx in range(self.num_layers - 1):
            x = torch.relu(self.deconvs[idx](x))
            self.outputs['deconv%s' % (idx + 1)] = x

        obs = self.deconvs[-1](x)
        self.outputs['obs'] = obs

        return obs

    def log(self, L, step, log_freq):
        """Log activations (histograms and images) and layer parameters."""
        if step % log_freq:
            return

        for name, tensor in self.outputs.items():
            L.log_histogram('train_decoder/%s_hist' % name, tensor, step)
            # Only tensors with more than two dimensions are logged as images.
            if len(tensor.shape) > 2:
                L.log_image('train_decoder/%s_i' % name, tensor[0], step)

        for idx in range(self.num_layers):
            L.log_param(
                'train_decoder/deconv%s' % (idx + 1), self.deconvs[idx], step
            )
        L.log_param('train_decoder/fc', self.fc, step)
|
||||
|
||||
|
||||
class StateDecoder(nn.Module):
    """MLP decoder mapping a feature vector back to a flat state vector.

    Args:
        obs_shape: one-element shape; ``obs_shape[0]`` is the output size.
        feature_dim: size of the input feature vector.
    """

    def __init__(self, obs_shape, feature_dim):
        super().__init__()

        assert len(obs_shape) == 1

        # NOTE(review): the trailing ReLU means the reconstruction is never
        # negative — confirm that is intended for the decoded states.
        self.trunk = nn.Sequential(
            nn.Linear(feature_dim, 1024), nn.ReLU(), nn.Linear(1024, 1024),
            nn.ReLU(), nn.Linear(1024, obs_shape[0]), nn.ReLU()
        )

        # Tensors from the latest forward pass, kept for logging.
        self.outputs = dict()

    def forward(self, obs, detach=False):
        """Run the trunk on ``obs``; detach the result when ``detach`` is true."""
        h = self.trunk(obs)
        if detach:
            h = h.detach()
        self.outputs['h'] = h
        return h

    def log(self, L, step, log_freq):
        """Log the linear layers and latest output every ``log_freq`` steps.

        Keys use the ``train_decoder/`` prefix so they do not collide with
        the ``train_encoder/`` keys written by the state encoder.
        """
        if step % log_freq != 0:
            return

        # Indices 0, 2 and 4 are the Linear layers inside the Sequential.
        L.log_param('train_decoder/fc1', self.trunk[0], step)
        L.log_param('train_decoder/fc2', self.trunk[2], step)
        L.log_param('train_decoder/fc3', self.trunk[4], step)
        for k, v in self.outputs.items():
            L.log_histogram('train_decoder/%s_hist' % k, v, step)
|
||||
|
||||
|
||||
# Registry of decoder classes, keyed by the name accepted by make_decoder.
_AVAILABLE_DECODERS = dict(pixel=PixelDecoder, state=StateDecoder)


def make_decoder(
    decoder_type, obs_shape, feature_dim, num_layers, num_filters
):
    """Build the decoder registered under ``decoder_type``.

    ``num_layers`` and ``num_filters`` are only forwarded to the pixel
    decoder; the state decoder receives ``obs_shape`` and ``feature_dim``.
    """
    assert decoder_type in _AVAILABLE_DECODERS
    decoder_cls = _AVAILABLE_DECODERS[decoder_type]
    if decoder_type != 'pixel':
        return decoder_cls(obs_shape, feature_dim)
    return decoder_cls(obs_shape, feature_dim, num_layers, num_filters)
|
185
encoder.py
Normal file
185
encoder.py
Normal file
@ -0,0 +1,185 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def tie_weights(src, trg):
    """Make ``trg`` share ``src``'s ``weight`` and ``bias`` objects (no copy).

    Both layers must be of exactly the same type.
    """
    assert type(src) is type(trg)
    for attr in ('weight', 'bias'):
        setattr(trg, attr, getattr(src, attr))
|
||||
|
||||
|
||||
# Spatial side length of the final conv feature map, keyed by the number of
# conv layers; used to size the fully-connected layer that follows the convs.
# NOTE(review): the values are consistent with 84x84 inputs passed through one
# stride-2 and (n - 1) stride-1 3x3 convolutions — confirm the image size.
OUT_DIM = {2: 39, 4: 35, 6: 31}
|
||||
|
||||
|
||||
class PixelEncoder(nn.Module):
    """Convolutional encoder of pixel observations.

    A stack of 3x3 convolutions (the first with stride 2) is followed by a
    linear layer, layer normalisation and ``tanh``. When ``stochastic`` is
    true, a second linear head predicts a bounded log-std and the output is
    sampled around the ``tanh`` output.
    """

    def __init__(
        self,
        obs_shape,
        feature_dim,
        num_layers=2,
        num_filters=32,
        stochastic=False
    ):
        super().__init__()

        assert len(obs_shape) == 3

        self.feature_dim = feature_dim
        self.num_layers = num_layers
        self.stochastic = stochastic

        first_conv = nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)
        later_convs = [
            nn.Conv2d(num_filters, num_filters, 3, stride=1)
            for _ in range(num_layers - 1)
        ]
        self.convs = nn.ModuleList([first_conv] + later_convs)

        out_dim = OUT_DIM[num_layers]
        flat_dim = num_filters * out_dim * out_dim
        self.fc = nn.Linear(flat_dim, self.feature_dim)
        self.ln = nn.LayerNorm(self.feature_dim)

        if self.stochastic:
            self.log_std_min = -10
            self.log_std_max = 2
            self.fc_log_std = nn.Linear(flat_dim, self.feature_dim)

        # Tensors from the latest forward pass, kept for logging.
        self.outputs = dict()

    def reparameterize(self, mu, logstd):
        """Sample ``mu + eps * exp(logstd)`` with ``eps`` standard normal."""
        std = torch.exp(logstd)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward_conv(self, obs):
        """Scale ``obs`` by 1/255, apply the conv stack and flatten per sample."""
        obs = obs / 255.
        self.outputs['obs'] = obs

        x = torch.relu(self.convs[0](obs))
        self.outputs['conv1'] = x

        for idx in range(1, self.num_layers):
            x = torch.relu(self.convs[idx](x))
            self.outputs['conv%s' % (idx + 1)] = x

        return x.view(x.size(0), -1)

    def forward(self, obs, detach=False):
        """Encode ``obs``.

        With ``detach`` set, the conv output is detached, so only the layers
        after the conv stack receive gradients.
        """
        h = self.forward_conv(obs)

        if detach:
            h = h.detach()

        h_fc = self.fc(h)
        self.outputs['fc'] = h_fc

        h_norm = self.ln(h_fc)
        self.outputs['ln'] = h_norm

        out = torch.tanh(h_norm)

        if self.stochastic:
            self.outputs['mu'] = out
            log_std = torch.tanh(self.fc_log_std(h))
            # Rescale from [-1, 1] to [log_std_min, log_std_max].
            log_std = self.log_std_min + 0.5 * (
                self.log_std_max - self.log_std_min
            ) * (log_std + 1)
            out = self.reparameterize(out, log_std)
            self.outputs['log_std'] = log_std

        self.outputs['tanh'] = out

        return out

    def copy_conv_weights_from(self, source):
        """Tie this encoder's conv layers to those of ``source``."""
        # Only the conv layers are tied; fc and ln stay separate.
        for idx in range(self.num_layers):
            tie_weights(src=source.convs[idx], trg=self.convs[idx])

    def log(self, L, step, log_freq):
        """Log activations (histograms and images) and layer parameters."""
        if step % log_freq:
            return

        for name, tensor in self.outputs.items():
            L.log_histogram('train_encoder/%s_hist' % name, tensor, step)
            # Only tensors with more than two dimensions are logged as images.
            if len(tensor.shape) > 2:
                L.log_image('train_encoder/%s_img' % name, tensor[0], step)

        for idx in range(self.num_layers):
            L.log_param(
                'train_encoder/conv%s' % (idx + 1), self.convs[idx], step
            )
        L.log_param('train_encoder/fc', self.fc, step)
        L.log_param('train_encoder/ln', self.ln, step)
|
||||
|
||||
|
||||
class StateEncoder(nn.Module):
    """Two-layer MLP encoder for flat (one-dimensional) observations."""

    def __init__(self, obs_shape, feature_dim):
        super().__init__()

        assert len(obs_shape) == 1
        self.feature_dim = feature_dim

        layers = [
            nn.Linear(obs_shape[0], 256),
            nn.ReLU(),
            nn.Linear(256, feature_dim),
            nn.ReLU(),
        ]
        self.trunk = nn.Sequential(*layers)

        # Tensors from the latest forward pass, kept for logging.
        self.outputs = dict()

    def forward(self, obs, detach=False):
        """Encode ``obs``; the result is detached when ``detach`` is true."""
        features = self.trunk(obs)
        features = features.detach() if detach else features
        self.outputs['h'] = features
        return features

    def copy_conv_weights_from(self, source):
        """No-op: this encoder has no convolutional layers to share."""
        return None

    def log(self, L, step, log_freq):
        """Log both linear layers and the latest output every ``log_freq`` steps."""
        if step % log_freq:
            return

        # Indices 0 and 2 are the Linear layers inside the Sequential.
        L.log_param('train_encoder/fc1', self.trunk[0], step)
        L.log_param('train_encoder/fc2', self.trunk[2], step)
        for name, tensor in self.outputs.items():
            L.log_histogram('train_encoder/%s_hist' % name, tensor, step)
|
||||
|
||||
|
||||
class IdentityEncoder(nn.Module):
    """Pass-through encoder: returns flat observations unchanged.

    ``feature_dim`` is taken from ``obs_shape[0]``; the ``feature_dim``
    constructor argument is accepted but not used.
    """

    def __init__(self, obs_shape, feature_dim):
        super().__init__()

        assert len(obs_shape) == 1
        self.feature_dim = obs_shape[0]

    def forward(self, obs, detach=False):
        """Return ``obs`` as is; ``detach`` is ignored."""
        return obs

    def copy_conv_weights_from(self, source):
        """No-op: there are no layers to share."""
        return None

    def log(self, L, step, log_freq):
        """No-op: there is nothing to log."""
        return None
|
||||
|
||||
|
||||
# Registry of encoder classes, keyed by the name accepted by make_encoder.
_AVAILABLE_ENCODERS = {
    'pixel': PixelEncoder,
    'state': StateEncoder,
    'identity': IdentityEncoder
}


def make_encoder(
    encoder_type, obs_shape, feature_dim, num_layers=2, num_filters=32,
    stochastic=False
):
    """Build the encoder registered under ``encoder_type``.

    Args:
        encoder_type: one of the keys of ``_AVAILABLE_ENCODERS``.
        obs_shape: observation shape passed to the encoder.
        feature_dim: requested feature size passed to the encoder.
        num_layers: number of conv layers (pixel encoder only).
        num_filters: number of conv filters (pixel encoder only).
        stochastic: whether the pixel encoder samples its output
            (pixel encoder only).

    The last three arguments default to the pixel encoder's own defaults, so
    callers that only know ``encoder_type``, ``obs_shape`` and
    ``feature_dim`` can call this with three arguments.
    """
    assert encoder_type in _AVAILABLE_ENCODERS
    if encoder_type == 'pixel':
        return _AVAILABLE_ENCODERS[encoder_type](
            obs_shape, feature_dim, num_layers, num_filters, stochastic
        )
    return _AVAILABLE_ENCODERS[encoder_type](obs_shape, feature_dim)
|
165
logger.py
Normal file
165
logger.py
Normal file
@ -0,0 +1,165 @@
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import torch
|
||||
import torchvision
|
||||
import numpy as np
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
# Console layout per logger config. Each entry is
# (metric key, label printed on the console, format type), where the format
# type is one of 'int', 'float' or 'time'.
FORMAT_CONFIG = {
    'rl': {
        'train': [
            ('episode', 'E', 'int'),
            ('step', 'S', 'int'),
            ('duration', 'D', 'time'),
            ('episode_reward', 'R', 'float'),
            ('batch_reward', 'BR', 'float'),
            ('actor_loss', 'ALOSS', 'float'),
            ('critic_loss', 'CLOSS', 'float'),
            ('ae_loss', 'RLOSS', 'float'),
        ],
        'eval': [
            ('step', 'S', 'int'),
            ('episode_reward', 'ER', 'float'),
        ],
    },
}
|
||||
|
||||
|
||||
class AverageMeter(object):
    """Running average: accumulates a sum and a count.

    ``update`` adds ``value`` to the sum as is (it is not multiplied by
    ``n``), so ``value`` is treated as the total over ``n`` samples.
    """

    def __init__(self):
        self._sum, self._count = 0, 0

    def update(self, value, n=1):
        """Add ``value`` to the sum and ``n`` to the count."""
        self._count += n
        self._sum += value

    def value(self):
        """Return sum / count, using 1 as the divisor while count is below 1."""
        divisor = self._count if self._count > 1 else 1
        return self._sum / divisor
|
||||
|
||||
|
||||
class MetersGroup(object):
    """A set of averaged metrics dumped to a JSON-lines file and the console.

    Args:
        file_name: path of the log file; an existing file is removed.
        formating: list of ``(key, display_key, type)`` tuples describing
            the console columns.
    """

    def __init__(self, file_name, formating):
        self._file_name = file_name
        # Start every run with a fresh log file; dumps append to it.
        if os.path.exists(file_name):
            os.remove(file_name)
        self._formating = formating
        self._meters = defaultdict(AverageMeter)

    def log(self, key, value, n=1):
        """Accumulate ``value`` (covering ``n`` samples) under ``key``."""
        self._meters[key].update(value, n)

    def _prime_meters(self):
        """Return ``{short_key: average}`` for all meters.

        The leading ``train`` / ``eval`` prefix and the character after it
        are stripped, and remaining ``/`` are replaced by ``_``
        (e.g. ``train_critic/loss`` becomes ``critic_loss``).
        """
        data = dict()
        for key, meter in self._meters.items():
            if key.startswith('train'):
                key = key[len('train') + 1:]
            else:
                key = key[len('eval') + 1:]
            key = key.replace('/', '_')
            data[key] = meter.value()
        return data

    def _dump_to_file(self, data):
        """Append ``data`` as one JSON line to the log file."""
        with open(self._file_name, 'a') as f:
            f.write(json.dumps(data) + '\n')

    def _format(self, key, value, ty):
        """Format one console column.

        Raises:
            ValueError: if ``ty`` is not 'int', 'float' or 'time'.
        """
        template = '%s: '
        if ty == 'int':
            template += '%d'
        elif ty == 'float':
            template += '%.04f'
        elif ty == 'time':
            template += '%.01f s'
        else:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError('invalid format type: %s' % ty)
        return template % (key, value)

    def _dump_to_console(self, data, prefix):
        """Print one row; metrics missing from ``data`` are shown as 0."""
        prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
        pieces = ['{:5}'.format(prefix)]
        for key, disp_key, ty in self._formating:
            value = data.get(key, 0)
            pieces.append(self._format(disp_key, value, ty))
        print('| %s' % (' | '.join(pieces)))

    def dump(self, step, prefix):
        """Write the averages to file and console, then reset all meters.

        Does nothing when no metric has been logged since the last dump.
        """
        if len(self._meters) == 0:
            return
        data = self._prime_meters()
        data['step'] = step
        self._dump_to_file(data)
        self._dump_to_console(data, prefix)
        self._meters.clear()
|
||||
|
||||
|
||||
class Logger(object):
    """Logs training/evaluation metrics to TensorBoard, log files and console.

    Args:
        log_dir: directory for ``train.log``, ``eval.log`` and the ``tb``
            TensorBoard folder.
        use_tb: when false, no TensorBoard writer is created and all
            TensorBoard calls are skipped.
        config: key into ``FORMAT_CONFIG`` selecting the console layout.
    """

    def __init__(self, log_dir, use_tb=True, config='rl'):
        self._log_dir = log_dir
        if use_tb:
            tb_dir = os.path.join(log_dir, 'tb')
            # Remove TensorBoard output left over from a previous run.
            if os.path.exists(tb_dir):
                shutil.rmtree(tb_dir)
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        self._train_mg = MetersGroup(
            os.path.join(log_dir, 'train.log'),
            formating=FORMAT_CONFIG[config]['train'])
        self._eval_mg = MetersGroup(
            os.path.join(log_dir, 'eval.log'),
            formating=FORMAT_CONFIG[config]['eval'])

    def _try_sw_log(self, key, value, step):
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_image(self, key, image, step):
        if self._sw is not None:
            assert image.dim() == 3
            # Each slice along the first dimension becomes a one-channel
            # tile in the grid.
            grid = torchvision.utils.make_grid(image.unsqueeze(1))
            self._sw.add_image(key, grid, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            frames = frames.unsqueeze(0)  # add a leading batch dimension
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1):
        """Log a scalar.

        ``value`` is treated as a total over ``n`` samples: TensorBoard
        receives ``value / n`` and the meter accumulates ``value`` with
        count ``n``. ``key`` must start with 'train' or 'eval'.
        """
        assert key.startswith('train') or key.startswith('eval')
        # isinstance also covers Tensor subclasses such as nn.Parameter.
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step):
        """Log histograms of a layer's weight and bias, plus their gradients.

        Gradients are logged only when present. The bias is skipped when the
        layer has none (attribute missing or ``None``, e.g. ``bias=False``).
        """
        self.log_histogram(key + '_w', param.weight.data, step)
        if getattr(param.weight, 'grad', None) is not None:
            self.log_histogram(key + '_w_g', param.weight.grad.data, step)
        bias = getattr(param, 'bias', None)
        if bias is not None:
            self.log_histogram(key + '_b', bias.data, step)
            if getattr(bias, 'grad', None) is not None:
                self.log_histogram(key + '_b_g', bias.grad.data, step)

    def log_image(self, key, image, step):
        """Log a 3-dimensional image tensor to TensorBoard."""
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_image(key, image, step)

    def log_video(self, key, frames, step):
        """Log a sequence of frames to TensorBoard as a 30 fps video."""
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step):
        """Log a histogram of ``histogram`` to TensorBoard."""
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step):
        """Flush the train and eval meter groups to file and console."""
        self._train_mg.dump(step, 'train')
        self._eval_mg.dump(step, 'eval')
|
21
run.sh
Executable file
21
run.sh
Executable file
@ -0,0 +1,21 @@
|
||||
#!/bin/bash
# Launch a single local training run of train.py (seed 1) with video and
# tensorboard output enabled.

DOMAIN=cheetah
TASK=run
ACTION_REPEAT=4
ENCODER_TYPE=pixel
# DECODER_TYPE is passed to --decoder_type below, so it must be defined here.
DECODER_TYPE=pixel

WORK_DIR=./runs

# Every variable in --work_dir needs the leading '$'; without it the literal
# text (e.g. "{TASK}") ends up in the directory name.
python train.py \
    --domain_name ${DOMAIN} \
    --task_name ${TASK} \
    --encoder_type ${ENCODER_TYPE} \
    --decoder_type ${DECODER_TYPE} \
    --action_repeat ${ACTION_REPEAT} \
    --save_video \
    --save_tb \
    --work_dir ${WORK_DIR}/${DOMAIN}_${TASK}/_ae_encoder_${ENCODER_TYPE}_decoder_${DECODER_TYPE} \
    --seed 1
|
507
sac.py
Normal file
507
sac.py
Normal file
@ -0,0 +1,507 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import copy
|
||||
import math
|
||||
|
||||
import utils
|
||||
from encoder import make_encoder
|
||||
from decoder import make_decoder
|
||||
|
||||
# Default logging period, in steps: the log() methods in this module return
# early unless step % log_freq == 0.
LOG_FREQ = 10000
|
||||
|
||||
|
||||
def gaussian_logprob(noise, log_std):
    """Compute the log-density of a diagonal Gaussian sample.

    ``noise`` is the standardised sample and ``log_std`` the per-dimension
    log standard deviation. The result is summed over the last dimension,
    which is kept with size 1.
    """
    per_dim = -0.5 * noise.pow(2) - log_std
    summed = per_dim.sum(-1, keepdim=True)
    return summed - 0.5 * np.log(2 * np.pi) * noise.size(-1)
|
||||
|
||||
|
||||
def squash(mu, pi, log_pi):
    """Apply tanh squashing to the mean, the sample and its log-probability.

    See appendix C from https://arxiv.org/pdf/1812.05905.pdf.

    ``pi`` and ``log_pi`` may be ``None`` and are then returned as ``None``.
    ``log_pi`` is corrected in place.
    """
    squashed_mu = torch.tanh(mu)
    if pi is not None:
        pi = torch.tanh(pi)
    if log_pi is None:
        return squashed_mu, pi, log_pi
    # Change-of-variables term; relu and 1e-6 keep the log argument positive.
    correction = torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
    log_pi -= correction.sum(-1, keepdim=True)
    return squashed_mu, pi, log_pi
|
||||
|
||||
|
||||
def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    Linear layers get an orthogonal weight and zero bias. Conv2d and
    ConvTranspose2d layers are zeroed except for the centre tap of each
    kernel, which is initialised orthogonally. Other modules are untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        m.bias.data.fill_(0.0)
        return

    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return

    # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
    assert m.weight.size(2) == m.weight.size(3)
    m.weight.data.fill_(0.0)
    m.bias.data.fill_(0.0)
    centre = m.weight.size(2) // 2
    relu_gain = nn.init.calculate_gain('relu')
    nn.init.orthogonal_(m.weight.data[:, :, centre, centre], relu_gain)
|
||||
|
||||
|
||||
class Actor(nn.Module):
    """MLP actor network.

    An encoder feeds a three-layer MLP that outputs the mean and a bounded
    log-std of a diagonal Gaussian over actions; samples are squashed with
    ``tanh``.

    Args:
        obs_shape: observation shape, passed to the encoder.
        action_shape: action shape; ``action_shape[0]`` is the action size.
        hidden_dim: width of the two hidden layers.
        encoder_type, encoder_feature_dim, num_layers, num_filters,
            stochastic: forwarded to ``make_encoder``.
        log_std_min, log_std_max: bounds for the predicted log-std.
        freeze_encoder: when true, encoder features are detached before the
            MLP, so the actor loss does not train the encoder.
    """

    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters,
        freeze_encoder, stochastic
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, stochastic
        )

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.freeze_encoder = freeze_encoder

        self.trunk = nn.Sequential(
            nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2 * action_shape[0])
        )

        # Tensors from the latest forward pass, kept for histogram logging.
        self.outputs = dict()
        self.apply(weight_init)

    def forward(
        self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False
    ):
        """Return ``(mu, pi, log_pi, log_std)`` for ``obs``.

        ``mu`` is the squashed mean action. ``pi`` is a squashed sample, or
        ``None`` when ``compute_pi`` is false. ``log_pi`` is the sample's
        log-probability, or ``None`` when ``compute_log_pi`` is false.
        ``log_std`` is the bounded log standard deviation.
        """
        obs = self.encoder(obs, detach=detach_encoder)

        if self.freeze_encoder:
            obs = obs.detach()

        mu, log_std = self.trunk(obs).chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std = self.log_std_min + 0.5 * (
            self.log_std_max - self.log_std_min
        ) * (log_std + 1)

        self.outputs['mu'] = mu
        self.outputs['std'] = log_std.exp()

        # The log-probability is defined for a sampled action, so a sample is
        # drawn whenever either output is requested. Previously asking for
        # log_pi without pi failed on an undefined noise variable.
        if compute_pi or compute_log_pi:
            std = log_std.exp()
            noise = torch.randn_like(mu)
            pi = mu + noise * std
        else:
            noise = None
            pi = None

        if compute_log_pi:
            log_pi = gaussian_logprob(noise, log_std)
        else:
            log_pi = None

        mu, pi, log_pi = squash(mu, pi, log_pi)

        if not compute_pi:
            # The sample was only needed internally for log_pi.
            pi = None

        return mu, pi, log_pi, log_std

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log output histograms and layer parameters every ``log_freq`` steps."""
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % k, v, step)

        # Indices 0, 2 and 4 are the Linear layers inside the Sequential.
        L.log_param('train_actor/fc1', self.trunk[0], step)
        L.log_param('train_actor/fc2', self.trunk[2], step)
        L.log_param('train_actor/fc3', self.trunk[4], step)
|
||||
|
||||
|
||||
class QFunction(nn.Module):
    """MLP for q-function: maps a concatenated (obs, action) to one value."""

    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        layers = [
            nn.Linear(obs_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.trunk = nn.Sequential(*layers)

    def forward(self, obs, action):
        """Return the Q-value for each (obs, action) pair in the batch."""
        # Both inputs must have the same batch size.
        assert obs.size(0) == action.size(0)

        return self.trunk(torch.cat([obs, action], dim=1))
|
||||
|
||||
|
||||
class DynamicsModel(nn.Module):
    """MLP mapping a concatenated (state, action) to a ``state_dim`` vector."""

    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()

        layers = [
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim),
        ]
        self.trunk = nn.Sequential(*layers)

    def forward(self, state, action):
        """Return the model output for each (state, action) pair in the batch."""
        # Both inputs must have the same batch size.
        assert state.size(0) == action.size(0)

        return self.trunk(torch.cat([state, action], dim=1))
|
||||
|
||||
|
||||
class Critic(nn.Module):
    """Critic network that employs two Q-functions over a shared encoder.

    When ``use_dynamics`` is true, a ``DynamicsModel`` over the encoder
    features is also created as ``forward_model``.
    """

    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, num_layers, num_filters, freeze_encoder,
        use_dynamics, stochastic
    ):
        super().__init__()

        self.freeze_encoder = freeze_encoder

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, stochastic
        )

        feature_dim = self.encoder.feature_dim
        action_dim = action_shape[0]

        if use_dynamics:
            self.forward_model = DynamicsModel(
                feature_dim, action_dim, hidden_dim
            )

        self.Q1 = QFunction(feature_dim, action_dim, hidden_dim)
        self.Q2 = QFunction(feature_dim, action_dim, hidden_dim)

        # Tensors from the latest forward pass, kept for histogram logging.
        self.outputs = dict()
        self.apply(weight_init)

    def forward(self, obs, action, detach_encoder=False):
        """Return ``(q1, q2)`` for the given observations and actions."""
        # detach_encoder stops gradient propagation into the encoder
        features = self.encoder(obs, detach=detach_encoder)

        if self.freeze_encoder:
            features = features.detach()

        q1 = self.Q1(features, action)
        q2 = self.Q2(features, action)

        self.outputs['q1'] = q1
        self.outputs['q2'] = q2

        return q1, q2

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log encoder stats, Q histograms and the Q-networks' linear layers."""
        if step % log_freq:
            return

        self.encoder.log(L, step, log_freq)

        for name, tensor in self.outputs.items():
            L.log_histogram('train_critic/%s_hist' % name, tensor, step)

        # The Linear layers sit at even indices (0, 2, 4) of each trunk.
        for layer_idx in range(3):
            trunk_idx = layer_idx * 2
            L.log_param(
                'train_critic/q1_fc%d' % layer_idx,
                self.Q1.trunk[trunk_idx], step
            )
            L.log_param(
                'train_critic/q2_fc%d' % layer_idx,
                self.Q2.trunk[trunk_idx], step
            )
|
||||
|
||||
|
||||
class SACAgent(object):
    """Soft Actor-Critic algorithm.

    Owns an actor, a critic, a target critic, a learnable log-temperature
    and, when ``decoder_type`` is not ``'identity'``, a decoder that is
    trained to reconstruct observations (or states) from the critic
    encoder's output.
    """

    def __init__(
        self,
        obs_shape,
        state_shape,
        action_shape,
        device,
        hidden_dim=256,
        discount=0.99,
        init_temperature=0.01,
        alpha_lr=1e-3,
        alpha_beta=0.9,
        actor_lr=1e-3,
        actor_beta=0.9,
        actor_log_std_min=-10,
        actor_log_std_max=2,
        actor_update_freq=2,
        critic_lr=1e-3,
        critic_beta=0.9,
        critic_tau=0.005,
        critic_target_update_freq=2,
        encoder_type='identity',
        encoder_feature_dim=50,
        encoder_lr=1e-3,
        encoder_tau=0.005,
        decoder_type='identity',
        decoder_lr=1e-3,
        decoder_update_freq=1,
        decoder_latent_lambda=0.0,
        decoder_weight_lambda=0.0,
        decoder_kl_lambda=0.0,
        num_layers=4,
        num_filters=32,
        freeze_encoder=False,
        use_dynamics=False
    ):
        """Build all networks and optimizers and switch to training mode."""
        self.device = device
        self.discount = discount
        self.critic_tau = critic_tau
        self.encoder_tau = encoder_tau
        self.actor_update_freq = actor_update_freq
        self.critic_target_update_freq = critic_target_update_freq
        self.decoder_update_freq = decoder_update_freq
        self.decoder_latent_lambda = decoder_latent_lambda
        self.decoder_kl_lambda = decoder_kl_lambda
        self.decoder_type = decoder_type
        self.use_dynamics = use_dynamics

        # a positive KL weight switches the encoders to their stochastic mode
        stochastic = decoder_kl_lambda > 0.0

        self.actor = Actor(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, actor_log_std_min, actor_log_std_max,
            num_layers, num_filters, freeze_encoder, stochastic
        ).to(device)

        self.critic = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters, freeze_encoder,
            use_dynamics, stochastic
        ).to(device)

        self.critic_target = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters, freeze_encoder,
            use_dynamics, stochastic
        ).to(device)

        # the target critic starts as an exact copy of the critic
        self.critic_target.load_state_dict(self.critic.state_dict())

        # tie encoders between actor and critic
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -np.prod(action_shape)

        self.decoder = None
        if decoder_type != 'identity':
            # create decoder; a 'pixel' decoder reconstructs observations,
            # any other type reconstructs the state
            shape = obs_shape if decoder_type == 'pixel' else state_shape
            self.decoder = make_decoder(
                decoder_type, shape, encoder_feature_dim, num_layers,
                num_filters
            ).to(device)
            self.decoder.apply(weight_init)

            # optimizer for critic encoder for reconstruction loss
            self.encoder_optimizer = torch.optim.Adam(
                self.critic.encoder.parameters(), lr=encoder_lr
            )

            # optimizer for decoder
            self.decoder_optimizer = torch.optim.Adam(
                self.decoder.parameters(),
                lr=decoder_lr,
                weight_decay=decoder_weight_lambda
            )

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
        )

        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
        )

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Set the train/eval flag on actor, critic and (if any) decoder."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        if self.decoder is not None:
            self.decoder.train(training)

    @property
    def alpha(self):
        """Entropy temperature, i.e. ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def select_action(self, obs):
        """Return the actor's ``mu`` output for one observation as a flat array."""
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)  # add a batch dimension
            mu, _, _, _ = self.actor(
                obs, compute_pi=False, compute_log_pi=False
            )
            return mu.cpu().data.numpy().flatten()

    def sample_action(self, obs):
        """Return the actor's ``pi`` output for one observation as a flat array."""
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)  # add a batch dimension
            mu, pi, _, _ = self.actor(obs, compute_log_pi=False)
            return pi.cpu().data.numpy().flatten()

    def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
        """Run one critic optimisation step (plus optional dynamics loss)."""
        # the bootstrap target is computed without tracking gradients
        with torch.no_grad():
            _, policy_action, log_pi, _ = self.actor(next_obs)
            target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
            target_V = torch.min(target_Q1,
                                 target_Q2) - self.alpha.detach() * log_pi
            target_Q = reward + (not_done * self.discount * target_V)

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1,
                                 target_Q) + F.mse_loss(current_Q2, target_Q)
        L.log('train_critic/loss', critic_loss, step)

        # update dynamics (optional)
        if self.use_dynamics:
            # read the latent of `obs` before the encoder is called again on
            # `next_obs` (which would overwrite `outputs`)
            h_obs = self.critic.encoder.outputs['mu']
            with torch.no_grad():
                next_latent = self.critic.encoder(next_obs)
            pred_next_latent = self.critic.forward_model(h_obs, action)
            dynamics_loss = F.mse_loss(pred_next_latent, next_latent)
            L.log('train_critic/dynamics_loss', dynamics_loss, step)
            critic_loss += dynamics_loss

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(L, step)

    def update_actor_and_alpha(self, obs, L, step):
        """Run one optimisation step for the actor and one for ``log_alpha``."""
        # detach encoder, so we don't update it with the actor loss
        _, pi, log_pi, log_std = self.actor(obs, detach_encoder=True)
        actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()

        L.log('train_actor/loss', actor_loss, step)
        L.log('train_actor/target_entropy', self.target_entropy, step)
        # logged for monitoring only; not part of the loss
        entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)
                                            ) + log_std.sum(dim=-1)
        L.log('train_actor/entropy', entropy.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(L, step)

        # temperature update; the bracketed term is detached so only
        # log_alpha receives a gradient
        self.log_alpha_optimizer.zero_grad()
        alpha_loss = (self.alpha *
                      (-log_pi - self.target_entropy).detach()).mean()
        L.log('train_alpha/loss', alpha_loss, step)
        L.log('train_alpha/value', self.alpha, step)
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

    def update_decoder(self, obs, target_obs, L, step):
        """Train critic encoder and decoder on the reconstruction objective.

        Does nothing when the agent has no decoder.
        """
        if self.decoder is None:
            return

        h = self.critic.encoder(obs)

        if target_obs.dim() == 4:
            # preprocess images to be in [-0.5, 0.5] range
            target_obs = utils.preprocess_obs(target_obs)
        rec_obs = self.decoder(h)
        rec_loss = F.mse_loss(target_obs, rec_obs)

        # add L2 penalty on latent representation
        # see https://arxiv.org/pdf/1903.12436.pdf
        latent_loss = (0.5 * h.pow(2).sum(1)).mean()

        # add KL penalty for VAE
        if self.decoder_kl_lambda > 0.0:
            log_std = self.critic.encoder.outputs['log_std']
            mu = self.critic.encoder.outputs['mu']
            kl_div = -0.5 * (1 + 2 * log_std - mu.pow(2) - (2 * log_std).exp())
            kl_div = kl_div.sum(1).mean(0, True)
        else:
            kl_div = 0.0
        loss = (
            rec_loss + self.decoder_latent_lambda * latent_loss +
            self.decoder_kl_lambda * kl_div
        )

        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        loss.backward()

        self.encoder_optimizer.step()
        self.decoder_optimizer.step()
        L.log('train_ae/ae_loss', loss, step)

        self.decoder.log(L, step, log_freq=LOG_FREQ)

    def update(self, replay_buffer, L, step):
        """Sample one batch and run every update that is due at ``step``."""
        obs, action, reward, next_obs, not_done, state = replay_buffer.sample()

        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            # the Q-heads and the encoder use separate averaging rates
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder,
                self.encoder_tau
            )

        if step % self.decoder_update_freq == 0:
            # pixel decoders reconstruct the observation, others the state
            target = obs if self.decoder_type == 'pixel' else state
            self.update_decoder(obs, target, L, step)

    def save(self, model_dir, step):
        """Write actor, critic and (if any) decoder weights to ``model_dir``."""
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
        )
        if self.decoder is not None:
            torch.save(
                self.decoder.state_dict(),
                '%s/decoder_%s.pt' % (model_dir, step)
            )

    def load(self, model_dir, step):
        """Load weights written by :meth:`save`.

        Tensors are mapped onto ``self.device`` so that a checkpoint saved on
        a GPU machine can be restored on a CPU-only one (and vice versa).
        """
        self.actor.load_state_dict(
            torch.load(
                '%s/actor_%s.pt' % (model_dir, step),
                map_location=self.device
            )
        )
        self.critic.load_state_dict(
            torch.load(
                '%s/critic_%s.pt' % (model_dir, step),
                map_location=self.device
            )
        )
        if self.decoder is not None:
            self.decoder.load_state_dict(
                torch.load(
                    '%s/decoder_%s.pt' % (model_dir, step),
                    map_location=self.device
                )
            )
|
259
td3.py
Normal file
259
td3.py
Normal file
@ -0,0 +1,259 @@
|
||||
# Code is taken from https://github.com/sfujim/TD3 with slight modifications
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import utils
|
||||
from encoder import make_encoder
|
||||
|
||||
LOG_FREQ = 10000
|
||||
|
||||
|
||||
class Actor(nn.Module):
    """Deterministic policy: encoder followed by a 400-300 MLP with tanh output."""

    def __init__(
        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim
        )

        in_features = self.encoder.feature_dim
        self.l1 = nn.Linear(in_features, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_shape[0])

        # latest forward outputs, kept for histogram logging
        self.outputs = {}

    def forward(self, obs, detach_encoder=False):
        """Return the tanh-squashed action for ``obs``.

        ``detach_encoder`` is passed to the encoder as ``detach``.
        """
        features = self.encoder(obs, detach=detach_encoder)
        hidden = F.relu(self.l2(F.relu(self.l1(features))))
        action = torch.tanh(self.l3(hidden))
        self.outputs['mu'] = action
        return action

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log output histograms and the three linear layers.

        Nothing is logged unless ``step`` is a multiple of ``log_freq``.
        """
        if step % log_freq:
            return

        for name, value in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % name, value, step)

        tagged_layers = (
            ('train_actor/fc1', self.l1),
            ('train_actor/fc2', self.l2),
            ('train_actor/fc3', self.l3),
        )
        for tag, layer in tagged_layers:
            L.log_param(tag, layer, step)
|
||||
|
||||
|
||||
class Critic(nn.Module):
    """Twin Q-network for TD3: one shared encoder feeding two MLP heads.

    Layers ``l1``-``l3`` form the Q1 head and ``l4``-``l6`` the Q2 head; the
    attribute names are kept so existing state dicts stay loadable.
    """

    def __init__(
        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim
        )

        # Q1 architecture
        self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

        # Q2 architecture
        self.l4 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
        self.l5 = nn.Linear(400, 300)
        self.l6 = nn.Linear(300, 1)

        # latest q-values, kept for histogram logging
        self.outputs = dict()

    def _encode_with_action(self, obs, action, detach_encoder):
        """Encode ``obs`` and concatenate the action along dim 1."""
        latent = self.encoder(obs, detach=detach_encoder)
        return torch.cat([latent, action], 1)

    def _q1_head(self, obs_action):
        """Apply the Q1 MLP (l1 -> l2 -> l3)."""
        h1 = F.relu(self.l1(obs_action))
        h1 = F.relu(self.l2(h1))
        return self.l3(h1)

    def _q2_head(self, obs_action):
        """Apply the Q2 MLP (l4 -> l5 -> l6)."""
        h2 = F.relu(self.l4(obs_action))
        h2 = F.relu(self.l5(h2))
        return self.l6(h2)

    def forward(self, obs, action, detach_encoder=False):
        """Return ``(q1, q2)`` and remember them for logging."""
        obs_action = self._encode_with_action(obs, action, detach_encoder)

        q1 = self._q1_head(obs_action)
        q2 = self._q2_head(obs_action)

        self.outputs['q1'] = q1
        self.outputs['q2'] = q2

        return q1, q2

    def Q1(self, obs, action, detach_encoder=False):
        """Return only the Q1 estimate (does not touch ``self.outputs``)."""
        obs_action = self._encode_with_action(obs, action, detach_encoder)
        return self._q1_head(obs_action)

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log encoder stats, q-value histograms and all six linear layers.

        Nothing is logged unless ``step`` is a multiple of ``log_freq``.
        """
        if step % log_freq != 0:
            return

        self.encoder.log(L, step, log_freq)

        for k, v in self.outputs.items():
            L.log_histogram('train_critic/%s_hist' % k, v, step)

        L.log_param('train_critic/q1_fc1', self.l1, step)
        L.log_param('train_critic/q1_fc2', self.l2, step)
        L.log_param('train_critic/q1_fc3', self.l3, step)
        # l4-l6 belong to the Q2 head, so they are tagged as q2 layers
        # (they were previously mislabelled as q1_fc4..q1_fc6)
        L.log_param('train_critic/q2_fc1', self.l4, step)
        L.log_param('train_critic/q2_fc2', self.l5, step)
        L.log_param('train_critic/q2_fc3', self.l6, step)
|
||||
|
||||
|
||||
class TD3Agent(object):
    """TD3 agent: deterministic actor, twin critic and target copies of both."""

    def __init__(
        self,
        obs_shape,
        action_shape,
        device,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        expl_noise=0.1,
        actor_lr=1e-3,
        critic_lr=1e-3,
        encoder_type='identity',
        encoder_feature_dim=50,
        actor_update_freq=2,
        target_update_freq=2,
    ):
        """Build networks, target networks and optimizers."""
        self.device = device
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.expl_noise = expl_noise
        self.actor_update_freq = actor_update_freq
        self.target_update_freq = target_update_freq

        # models
        self.actor = Actor(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)

        self.critic = Critic(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)

        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

        # target networks start as exact copies of the online networks
        self.actor_target = Actor(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())

        self.critic_target = Critic(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr
        )

        self.train()
        self.critic_target.train()
        self.actor_target.train()

    def train(self, training=True):
        """Set the train/eval flag on the online actor and critic."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    def select_action(self, obs):
        """Return the actor's action for one observation as a flat array."""
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)  # add a batch dimension
            action = self.actor(obs)
            return action.cpu().data.numpy().flatten()

    def sample_action(self, obs):
        """Return the actor's action plus Gaussian exploration noise.

        The noise is scaled by ``expl_noise`` and the result is clamped to
        ``[-1, 1]``.
        """
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)  # add a batch dimension
            action = self.actor(obs)
            noise = torch.randn_like(action) * self.expl_noise
            action = (action + noise).clamp(-1.0, 1.0)
            return action.cpu().data.numpy().flatten()

    def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
        """Run one critic optimisation step against the smoothed target."""
        with torch.no_grad():
            # clipped noise added to the target action
            noise = torch.randn_like(action).to(self.device) * self.policy_noise
            noise = noise.clamp(-self.noise_clip, self.noise_clip)
            next_action = self.actor_target(next_obs) + noise
            next_action = next_action.clamp(-1.0, 1.0)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (not_done * self.discount * target_Q)

        current_Q1, current_Q2 = self.critic(obs, action)

        critic_loss = F.mse_loss(current_Q1,
                                 target_Q) + F.mse_loss(current_Q2, target_Q)
        L.log('train_critic/loss', critic_loss, step)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(L, step)

    def update_actor(self, obs, L, step):
        """Run one actor optimisation step that maximises the Q1 estimate."""
        # the encoder is detached in both passes, so this loss does not
        # update it
        action = self.actor(obs, detach_encoder=True)
        actor_Q = self.critic.Q1(obs, action, detach_encoder=True)
        actor_loss = -actor_Q.mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(L, step)

    def update(self, replay_buffer, L, step):
        """Sample one batch and run every update that is due at ``step``."""
        # Only the first five fields are used. Slicing keeps this working
        # when the buffer returns extra trailing fields, as SACAgent.update
        # expects from the same buffer type.
        # NOTE(review): assumes sample() returns a sequence - confirm.
        obs, action, reward, next_obs, not_done = replay_buffer.sample()[:5]

        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor(obs, L, step)

        if step % self.target_update_freq == 0:
            utils.soft_update_params(self.critic, self.critic_target, self.tau)
            utils.soft_update_params(self.actor, self.actor_target, self.tau)

    def save(self, model_dir, step):
        """Write actor and critic weights to ``model_dir``."""
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
        )

    def load(self, model_dir, step):
        """Load weights written by :meth:`save`.

        Tensors are mapped onto ``self.device`` so a checkpoint saved on a
        GPU machine can be restored on a CPU-only one (and vice versa).
        """
        self.actor.load_state_dict(
            torch.load(
                '%s/actor_%s.pt' % (model_dir, step),
                map_location=self.device
            )
        )
        self.critic.load_state_dict(
            torch.load(
                '%s/critic_%s.pt' % (model_dir, step),
                map_location=self.device
            )
        )
|
315
train.py
Normal file
315
train.py
Normal file
@ -0,0 +1,315 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import argparse
|
||||
import os
|
||||
import math
|
||||
import gym
|
||||
import sys
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
import dmc2gym
|
||||
import copy
|
||||
|
||||
import utils
|
||||
from logger import Logger
|
||||
from video import VideoRecorder
|
||||
|
||||
from sac import SACAgent
|
||||
from td3 import TD3Agent
|
||||
from ddpg import DDPGAgent
|
||||
|
||||
|
||||
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Options are registered in the same order and with the same defaults as
    before; only the registration code is more compact.
    """
    parser = argparse.ArgumentParser()

    def option(flag, default, kind=None):
        # valued option; `type` is only passed when one was specified
        if kind is None:
            parser.add_argument(flag, default=default)
        else:
            parser.add_argument(flag, default=default, type=kind)

    def switch(flag):
        # boolean flag that is off unless given on the command line
        parser.add_argument(flag, default=False, action='store_true')

    # environment
    option('--domain_name', 'cheetah')
    option('--task_name', 'run')
    option('--image_size', 84, int)
    option('--action_repeat', 1, int)
    option('--frame_stack', 3, int)
    # replay buffer
    option('--replay_buffer_capacity', 1000000, int)
    # train
    option('--agent', 'sac', str)
    option('--init_steps', 1000, int)
    option('--num_train_steps', 1000000, int)
    option('--batch_size', 512, int)
    option('--hidden_dim', 256, int)
    # eval
    option('--eval_freq', 10000, int)
    option('--num_eval_episodes', 10, int)
    # critic
    option('--critic_lr', 1e-3, float)
    option('--critic_beta', 0.9, float)
    option('--critic_tau', 0.005, float)
    option('--critic_target_update_freq', 2, int)
    # actor
    option('--actor_lr', 1e-3, float)
    option('--actor_beta', 0.9, float)
    option('--actor_log_std_min', -10, float)
    option('--actor_log_std_max', 2, float)
    option('--actor_update_freq', 2, int)
    # encoder/decoder
    option('--encoder_type', 'identity', str)
    option('--encoder_feature_dim', 50, int)
    option('--encoder_lr', 1e-3, float)
    option('--encoder_tau', 0.005, float)
    option('--decoder_type', 'identity', str)
    option('--decoder_lr', 1e-3, float)
    option('--decoder_update_freq', 1, int)
    option('--decoder_latent_lambda', 0.0, float)
    option('--decoder_weight_lambda', 0.0, float)
    option('--decoder_kl_lambda', 0.0, float)
    option('--num_layers', 4, int)
    option('--num_filters', 32, int)
    switch('--freeze_encoder')
    switch('--use_dynamics')
    # sac
    option('--discount', 0.99, float)
    option('--init_temperature', 0.01, float)
    option('--alpha_lr', 1e-3, float)
    option('--alpha_beta', 0.9, float)
    # td3
    option('--policy_noise', 0.2, float)
    option('--expl_noise', 0.1, float)
    option('--noise_clip', 0.5, float)
    option('--tau', 0.005, float)
    # misc
    option('--seed', 1, int)
    option('--work_dir', '.', str)
    switch('--save_tb')
    switch('--save_model')
    switch('--save_buffer')
    switch('--save_video')
    option('--pretrained_info', None, str)
    switch('--pretrained_decoder')

    return parser.parse_args()
|
||||
|
||||
|
||||
def evaluate(env, agent, video, num_episodes, L, step):
    """Roll out ``num_episodes`` episodes with ``agent.select_action``.

    Each episode's return is logged under ``eval/episode_reward``; video
    recording is enabled for the first episode only. The logger is dumped
    once after all episodes.
    """
    for episode_idx in range(num_episodes):
        obs = env.reset()
        video.init(enabled=(episode_idx == 0))
        finished = False
        total_reward = 0
        while not finished:
            # act with the agent temporarily switched to eval mode
            with utils.eval_mode(agent):
                action = agent.select_action(obs)
            obs, step_reward, finished, _ = env.step(action)
            video.record(env)
            total_reward += step_reward

        video.save('%d.mp4' % step)
        L.log('eval/episode_reward', total_reward, step)
    L.dump(step)
|
||||
|
||||
|
||||
def make_agent(obs_shape, state_shape, action_shape, args, device):
    """Construct the agent selected by ``args.agent``.

    Args:
        obs_shape: shape of an observation.
        state_shape: shape of the state (only used by the SAC agent).
        action_shape: shape of an action.
        args: parsed command-line namespace holding the hyper-parameters.
        device: device the networks are created on.

    Returns:
        A ``SACAgent``, ``TD3Agent`` or ``DDPGAgent``.

    Raises:
        ValueError: if ``args.agent`` is not ``'sac'``, ``'td3'`` or ``'ddpg'``.
    """
    if args.agent == 'sac':
        return SACAgent(
            obs_shape=obs_shape,
            state_shape=state_shape,
            action_shape=action_shape,
            device=device,
            hidden_dim=args.hidden_dim,
            discount=args.discount,
            init_temperature=args.init_temperature,
            alpha_lr=args.alpha_lr,
            alpha_beta=args.alpha_beta,
            actor_lr=args.actor_lr,
            actor_beta=args.actor_beta,
            actor_log_std_min=args.actor_log_std_min,
            actor_log_std_max=args.actor_log_std_max,
            actor_update_freq=args.actor_update_freq,
            critic_lr=args.critic_lr,
            critic_beta=args.critic_beta,
            critic_tau=args.critic_tau,
            critic_target_update_freq=args.critic_target_update_freq,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            encoder_lr=args.encoder_lr,
            encoder_tau=args.encoder_tau,
            decoder_type=args.decoder_type,
            decoder_lr=args.decoder_lr,
            decoder_update_freq=args.decoder_update_freq,
            decoder_latent_lambda=args.decoder_latent_lambda,
            decoder_weight_lambda=args.decoder_weight_lambda,
            decoder_kl_lambda=args.decoder_kl_lambda,
            num_layers=args.num_layers,
            num_filters=args.num_filters,
            freeze_encoder=args.freeze_encoder,
            use_dynamics=args.use_dynamics
        )
    elif args.agent == 'td3':
        return TD3Agent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            discount=args.discount,
            tau=args.tau,
            policy_noise=args.policy_noise,
            noise_clip=args.noise_clip,
            expl_noise=args.expl_noise,
            actor_lr=args.actor_lr,
            critic_lr=args.critic_lr,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            actor_update_freq=args.actor_update_freq,
            target_update_freq=args.critic_target_update_freq
        )
    elif args.agent == 'ddpg':
        return DDPGAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            discount=args.discount,
            tau=args.tau,
            actor_lr=args.actor_lr,
            critic_lr=args.critic_lr,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim
        )

    # The previous `assert '<message>'` tested a non-empty string, which is
    # always truthy, so an unknown agent silently produced None.
    raise ValueError('agent is not supported: %s' % args.agent)
|
||||
|
||||
|
||||
def load_pretrained_encoder(agent, pretrained_info, pretrained_decoder):
    """Copy pretrained encoder (and optionally decoder) weights into ``agent``.

    Args:
        agent: agent whose actor/critic encoders are overwritten in place.
        pretrained_info: string of the form ``'<model_dir>:<step>'``. The
            split happens on the last colon, so ``model_dir`` may itself
            contain colons (e.g. a Windows drive letter).
        pretrained_decoder: when true, the decoder weights are copied too.

    Returns:
        The same ``agent`` object.

    Raises:
        ValueError: if ``pretrained_info`` has no ``':'``, the step is not an
            integer, or the decoder is requested but ``agent`` has none.
    """
    path, separator, version = pretrained_info.rpartition(':')
    if not separator:
        raise ValueError(
            "pretrained_info must look like '<model_dir>:<step>', got %r" %
            (pretrained_info, )
        )

    if pretrained_decoder and getattr(agent, 'decoder', None) is None:
        raise ValueError(
            'pretrained decoder requested, but the agent has no decoder'
        )

    # load the checkpoint into a throw-away copy so that only the selected
    # sub-modules of `agent` are overwritten
    pretrained_agent = copy.deepcopy(agent)
    pretrained_agent.load(path, int(version))
    agent.critic.encoder.load_state_dict(
        pretrained_agent.critic.encoder.state_dict()
    )
    agent.actor.encoder.load_state_dict(
        pretrained_agent.actor.encoder.state_dict()
    )

    if pretrained_decoder:
        agent.decoder.load_state_dict(pretrained_agent.decoder.state_dict())

    return agent
|
||||
|
||||
|
||||
def main():
    """Entry point: build env, buffer and agent, then run the training loop.

    The loop interleaves data collection with agent updates. At episode
    boundaries it dumps the logger and, every ``eval_freq`` steps,
    evaluates the agent and optionally saves the model and replay buffer.
    """
    args = parse_args()
    utils.set_seed_everywhere(args.seed)

    env = dmc2gym.make(
        domain_name=args.domain_name,
        task_name=args.task_name,
        seed=args.seed,
        visualize_reward=False,
        from_pixels=(args.encoder_type == 'pixel'),
        height=args.image_size,
        width=args.image_size,
        frame_skip=args.action_repeat
    )
    env.seed(args.seed)

    # stack several consecutive frames together
    if args.encoder_type == 'pixel':
        env = utils.FrameStack(env, k=args.frame_stack)

    utils.make_dir(args.work_dir)
    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))

    video = VideoRecorder(video_dir if args.save_video else None)

    # keep a record of the exact configuration next to the results
    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the dmc2gym wrapper standardizes actions
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    replay_buffer = utils.ReplayBuffer(
        obs_shape=env.observation_space.shape,
        state_shape=env.state_space.shape,
        action_shape=env.action_space.shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device
    )

    agent = make_agent(
        obs_shape=env.observation_space.shape,
        state_shape=env.state_space.shape,
        action_shape=env.action_space.shape,
        args=args,
        device=device
    )

    if args.pretrained_info is not None:
        agent = load_pretrained_encoder(
            agent, args.pretrained_info, args.pretrained_decoder
        )

    L = Logger(args.work_dir, use_tb=args.save_tb)

    # `done` starts True so the first iteration resets the environment and
    # initialises the per-episode counters
    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    for step in range(args.num_train_steps):
        if done:
            if step > 0:
                L.log('train/duration', time.time() - start_time, step)
                start_time = time.time()
                L.dump(step)

            # evaluate agent periodically
            if step % args.eval_freq == 0:
                L.log('eval/episode', episode, step)
                evaluate(env, agent, video, args.num_eval_episodes, L, step)
                if args.save_model:
                    agent.save(model_dir, step)
                if args.save_buffer:
                    replay_buffer.save(buffer_dir)

            L.log('train/episode_reward', episode_reward, step)

            obs = env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1

            L.log('train/episode', episode, step)

        # sample action for data collection
        if step < args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs)

        # run training update
        if step >= args.init_steps:
            # catch up with `init_steps` updates on the first training step
            num_updates = args.init_steps if step == args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)

        # NOTE(review): assumes `_current_state` is rebound (not mutated in
        # place) by env.step, otherwise `state` aliases `next_state` - confirm.
        state = env.env.env._current_state
        next_obs, reward, done, _ = env.step(action)
        # store the state itself; this previously took `.shape`, which wrote
        # the shape tuple into the buffer's next_states
        next_state = env.env.env._current_state

        # allow infinite bootstrap: hitting the episode step limit is not
        # stored as a terminal transition
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
            done
        )
        episode_reward += reward

        replay_buffer.add(
            obs, action, reward, next_obs, done_bool, state, next_state
        )

        obs = next_obs
        episode_step += 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
182
utils.py
Normal file
182
utils.py
Normal file
@ -0,0 +1,182 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import gym
|
||||
import os
|
||||
from collections import deque
|
||||
import random
|
||||
|
||||
|
||||
class eval_mode(object):
|
||||
def __init__(self, *models):
|
||||
self.models = models
|
||||
|
||||
def __enter__(self):
|
||||
self.prev_states = []
|
||||
for model in self.models:
|
||||
self.prev_states.append(model.training)
|
||||
model.train(False)
|
||||
|
||||
def __exit__(self, *args):
|
||||
for model, state in zip(self.models, self.prev_states):
|
||||
model.train(state)
|
||||
return False
|
||||
|
||||
|
||||
def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``.
    Parameters are paired in iteration order.
    """
    pairs = zip(net.parameters(), target_net.parameters())
    for source, target in pairs:
        blended = tau * source.data + (1 - tau) * target.data
        target.data.copy_(blended)
|
||||
|
||||
|
||||
def set_seed_everywhere(seed):
    """Seed torch (CPU and, if available, all CUDA devices), numpy and random."""
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
    for seeder in (np.random.seed, random.seed):
        seeder(seed)
|
||||
|
||||
|
||||
def module_hash(module):
    """Return the sum of all entries of every tensor in ``module.state_dict()``.

    This is a cheap fingerprint of the weights, not a real hash.
    """
    total = 0
    state = module.state_dict()
    for key in state:
        # accumulate one tensor at a time, in state-dict order
        total = total + state[key].sum().item()
    return total
|
||||
|
||||
|
||||
def make_dir(dir_path):
    """Create ``dir_path`` (including missing parents) and return it.

    An already existing directory is not an error. Other failures, such as
    missing permissions or a regular file occupying the path, now raise
    ``OSError`` instead of being silently ignored.
    """
    os.makedirs(dir_path, exist_ok=True)
    return dir_path
|
||||
|
||||
|
||||
def preprocess_obs(obs, bits=5):
    """Preprocessing image, see https://arxiv.org/abs/1807.03039.

    For ``bits < 8`` the input is first floor-divided by ``2**(8 - bits)``.
    It is then divided by ``2**bits``, uniform noise of width ``1 / 2**bits``
    is added and the result is shifted down by 0.5. ``obs`` must be a
    float32 tensor.
    """
    levels = 2**bits
    assert obs.dtype == torch.float32
    if bits < 8:
        # drop the (8 - bits) least significant bits
        obs = torch.floor(obs / 2**(8 - bits))
    scaled = obs / levels
    dithered = scaled + torch.rand_like(scaled) / levels
    return dithered - 0.5
|
||||
|
||||
|
||||
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of environment transitions.

    Each transition consists of an observation, action, reward, next
    observation, "not done" flag, state and next state, held in
    preallocated numpy arrays. Once ``capacity`` entries have been written
    the oldest entries are overwritten.
    """

    def __init__(
        self, obs_shape, state_shape, action_shape, capacity, batch_size,
        device
    ):
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device

        # 1-D (proprioceptive) observations are kept as float32, anything
        # else (pixel observations) as uint8
        obs_dtype = np.uint8 if len(obs_shape) != 1 else np.float32

        def alloc(shape, dtype=np.float32):
            return np.empty((capacity,) + tuple(shape), dtype=dtype)

        self.obses = alloc(obs_shape, obs_dtype)
        self.next_obses = alloc(obs_shape, obs_dtype)
        self.actions = alloc(action_shape)
        self.rewards = alloc((1,))
        self.not_dones = alloc((1,))
        self.states = alloc(state_shape)
        self.next_states = alloc(state_shape)

        self.idx = 0          # next slot to write
        self.last_save = 0    # first slot not yet written to disk by save()
        self.full = False     # becomes True once the cursor has wrapped

    def _storage(self):
        """Return the backing arrays in the order used by the chunk files."""
        return (
            self.obses,
            self.next_obses,
            self.actions,
            self.rewards,
            self.not_dones,
            self.states,
            self.next_states,
        )

    def add(self, obs, action, reward, next_obs, done, state, next_state):
        """Write one transition at the cursor and advance the cursor."""
        slot = self.idx
        write = np.copyto
        write(self.obses[slot], obs)
        write(self.actions[slot], action)
        write(self.rewards[slot], reward)
        write(self.next_obses[slot], next_obs)
        # stored inverted: 1.0 while the episode continues, 0.0 at its end
        write(self.not_dones[slot], not done)
        write(self.states[slot], state)
        write(self.next_states[slot], next_state)

        self.idx = (slot + 1) % self.capacity
        if self.idx == 0:
            self.full = True

    def sample(self):
        """Draw ``batch_size`` transitions uniformly, with replacement.

        Returns:
            Tuple ``(obses, actions, rewards, next_obses, not_dones, states)``
            of tensors on ``self.device``; both observation tensors are cast
            to float. The stored next states are not part of the result.
        """
        filled = self.capacity if self.full else self.idx
        picks = np.random.randint(0, filled, size=self.batch_size)

        def fetch(array):
            return torch.as_tensor(array[picks], device=self.device)

        obses = fetch(self.obses).float()
        actions = fetch(self.actions)
        rewards = fetch(self.rewards)
        next_obses = fetch(self.next_obses).float()
        not_dones = fetch(self.not_dones)
        states = fetch(self.states)

        return obses, actions, rewards, next_obses, not_dones, states

    def save(self, save_dir):
        """Write the entries added since the previous save as one chunk file.

        The chunk is stored in ``save_dir`` as ``<first>_<end>.pt`` and holds
        the slots ``[last_save, idx)``. Nothing is written when the cursor
        has not moved since the last save.

        NOTE(review): if the cursor wrapped around since the last save
        (``idx < last_save``) the slices are empty, so the chunk carries no
        data - confirm whether wrap-around needs to be supported.
        """
        first, end = self.last_save, self.idx
        if first == end:
            return
        path = os.path.join(save_dir, '%d_%d.pt' % (first, end))
        payload = [array[first:end] for array in self._storage()]
        self.last_save = end
        torch.save(payload, path)

    def load(self, save_dir):
        """Restore the buffer from the chunk files found in ``save_dir``.

        Every file name must look like ``<first>_<end>.pt``; chunks are
        applied in ascending order of ``<first>`` and each one has to begin
        exactly where the previous one stopped. ``full`` and ``last_save``
        are left untouched.

        NOTE(review): ``torch.load`` unpickles the files, so only point this
        at directories written by a trusted process.
        """
        names = sorted(
            os.listdir(save_dir), key=lambda name: int(name.split('_')[0])
        )
        for name in names:
            start, end = (int(part) for part in name.split('.')[0].split('_'))
            path = os.path.join(save_dir, name)
            payload = torch.load(path)
            assert self.idx == start
            for position, array in enumerate(self._storage()):
                array[start:end] = payload[position]
            self.idx = end
|
||||
|
||||
|
||||
class FrameStack(gym.Wrapper):
    """Gym wrapper whose observation is the last ``k`` frames joined on axis 0.

    The advertised observation space has the wrapped environment's shape
    with its first dimension multiplied by ``k``.
    NOTE(review): the Box bounds are hard-coded to [0, 1] regardless of the
    wrapped environment's own bounds - confirm this is intended.
    """

    def __init__(self, env, k):
        super().__init__(env)
        self._k = k
        # bounded deque: appending a new frame drops the oldest one
        self._frames = deque([], maxlen=k)

        base_space = env.observation_space
        base_shape = base_space.shape
        stacked_shape = (base_shape[0] * k,) + base_shape[1:]
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=stacked_shape,
            dtype=base_space.dtype
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        """Reset the wrapped env and fill the stack with ``k`` copies of the
        first observation."""
        first = self.env.reset()
        self._frames.extend([first] * self._k)
        return self._get_obs()

    def step(self, action):
        """Step the wrapped env, push the new frame and return the stacked
        observation together with reward, done and info."""
        frame, reward, done, info = self.env.step(action)
        self._frames.append(frame)
        stacked = self._get_obs()
        return stacked, reward, done, info

    def _get_obs(self):
        """Concatenate the buffered frames along axis 0."""
        assert len(self._frames) == self._k
        return np.concatenate(tuple(self._frames), axis=0)
|
32
video.py
Normal file
32
video.py
Normal file
@ -0,0 +1,32 @@
|
||||
import imageio
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class VideoRecorder(object):
    """Collects rendered frames from an environment and writes them as a video.

    Typical use: call ``init()`` to start a recording, ``record(env)`` after
    each step, then ``save(file_name)``. The recorder is inactive until
    ``init()`` is called, and stays inactive when ``dir_name`` is ``None``.

    Args:
        dir_name: directory videos are written to, or ``None`` to disable
            recording altogether.
        height: frame height passed to ``env.render``.
        width: frame width passed to ``env.render``.
        camera_id: camera id passed to ``env.render``.
        fps: frame rate passed to ``imageio.mimsave``.
    """

    def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30):
        self.dir_name = dir_name
        self.height = height
        self.width = width
        self.camera_id = camera_id
        self.fps = fps
        self.frames = []
        # Defined up front so that record()/save() called before init() are
        # harmless no-ops instead of raising AttributeError.
        self.enabled = False

    def init(self, enabled=True):
        """Start a new recording, discarding any previously collected frames.

        Recording is active only if ``enabled`` is true and a target
        directory was configured.
        """
        self.frames = []
        self.enabled = self.dir_name is not None and enabled

    def record(self, env):
        """Render one RGB frame from ``env`` and append it, if recording."""
        if self.enabled:
            frame = env.render(
                mode='rgb_array',
                height=self.height,
                width=self.width,
                camera_id=self.camera_id
            )
            self.frames.append(frame)

    def save(self, file_name):
        """Write the collected frames to ``dir_name/file_name``, if recording."""
        if self.enabled:
            path = os.path.join(self.dir_name, file_name)
            imageio.mimsave(path, self.frames, fps=self.fps)
|
Loading…
Reference in New Issue
Block a user