From 681e13b12a42497d087384fe076ff1e1fe7da716 Mon Sep 17 00:00:00 2001
From: Denis Yarats
Date: Mon, 23 Sep 2019 11:20:48 -0700
Subject: [PATCH] init

---
 .gitignore |   3 +
 README.md  |  75 ++++++++
 ddpg.py    | 209 ++++++++++++++++++++++
 decoder.py | 106 +++++++++++
 encoder.py | 185 +++++++++++++++++++
 logger.py  | 165 +++++++++++++++++
 run.sh     |  21 +++
 sac.py     | 507 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 td3.py     | 259 +++++++++++++++++++++++++++
 train.py   | 315 +++++++++++++++++++++++++++++++++
 utils.py   | 182 +++++++++++++++++++
 video.py   |  32 ++++
 12 files changed, 2059 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 README.md
 create mode 100644 ddpg.py
 create mode 100644 decoder.py
 create mode 100644 encoder.py
 create mode 100644 logger.py
 create mode 100755 run.sh
 create mode 100644 sac.py
 create mode 100644 td3.py
 create mode 100644 train.py
 create mode 100644 utils.py
 create mode 100644 video.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..37a0c9a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+__pycache__/
+.ipynb_checkpoints/
+runs
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..5573476
--- /dev/null
+++ b/README.md
@@ -0,0 +1,75 @@
+# Soft Actor-Critic implementation in PyTorch
+
+
+## Running locally
+To train SAC locally one can use the provided `run_local.sh` script (change it to modify particular arguments):
+```
+./run_local.sh
+```
+This will produce a folder (`./save` by default) where all the output is going to be stored, including train/eval logs, tensorboard blobs, evaluation videos, and model snapshots. It is possible to attach tensorboard to a particular run using the following command:
+```
+tensorboard --logdir save
+```
+Then open up tensorboard in your browser.
+
+You will also see some console output, something like this:
+```
+| train | E: 1 | S: 1000 | D: 0.8 s | R: 0.0000 | BR: 0.0000 | ALOSS: 0.0000 | CLOSS: 0.0000 | RLOSS: 0.0000
+```
+This line means:
+```
+train - training episode
+E - total number of episodes
+S - total number of environment steps
+D - duration in seconds to train 1 episode
+R - episode reward
+BR - average reward of sampled batch
+ALOSS - average loss of actor
+CLOSS - average loss of critic
+RLOSS - average reconstruction loss (only when training from pixels with a decoder)
+```
+These are just the most important numbers; all other metrics can be found in tensorboard.
+Also, besides training, once in a while there is evaluation output, like this:
+```
+| eval | S: 0 | ER: 21.1676
+```
+This reports the expected reward `ER` of the current policy after `S` steps. Note that `ER` is the average evaluation performance over `num_eval_episodes` episodes (usually 10).
+
+## Running on the cluster
+You can find the `run_cluster.sh` script that allows you to run training on the cluster. It is a simple bash script that is easy to modify. We usually run 10 different seeds for each configuration to get reliable results. For example, to schedule 10 runs of `walker walk` simply do this:
+```
+./run_cluster.sh walker walk
+```
+This script will schedule 10 jobs and all the output will be stored under `./runs/walker_walk/{configuration_name}/seed_i`.
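+Both `train.log` and `eval.log` are written as one JSON object per line (see `logger.py`), so it is easy to aggregate results over seeds with a few lines of Python. A minimal sketch (it assumes the folder layout described below; `aggregate_eval` is a hypothetical helper, not part of this repo):
+```
+import glob
+import json
+
+import numpy as np
+
+
+# average evaluation returns over all seeds of one configuration
+def aggregate_eval(config_dir):
+    curves = []
+    for path in sorted(glob.glob('%s/seed_*/eval.log' % config_dir)):
+        with open(path) as f:
+            rows = [json.loads(line) for line in f]
+        curves.append([row['episode_reward'] for row in rows])
+    # truncate to the shortest run so the seeds can be stacked
+    n = min(len(curve) for curve in curves)
+    return np.mean([curve[:n] for curve in curves], axis=0)
+
+
+print(aggregate_eval('runs/walker_walk/sac_states'))
+```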
The folder structure looks like this: +``` +runs/ + walker_walk/ + sac_states/ + seed_1/ + id # slurm job id + stdout # standard output of your job + stderr # standard error of your jobs + run.sh # starting script + run.slrm # slurm script + eval.log # log file for evaluation + train.log # log file for training + tb/ # folder that stores tensorboard output + video/ # folder stores evaluation videos + 10000.mp4 # video of one episode after 10000 steps + seed_2/ + ... +``` +Again, you can attach tensorboard to a particular configuration, for example: +``` +tensorboard --logdir runs/walker_walk/sac_states +``` + +For convinience, you can also use an iPython notebook to get aggregated over 10 seeds results. An example of such notebook is `runs.ipynb` + + +## Run entire testbed +Another scirpt that allow to run all 10 dm_control task on the cluster is here: +``` +./run_all.sh +``` +It will call `run_cluster.sh` for each task, so you only need to modify `run_cluster.sh` to change the hyper parameters. diff --git a/ddpg.py b/ddpg.py new file mode 100644 index 0000000..2b03972 --- /dev/null +++ b/ddpg.py @@ -0,0 +1,209 @@ +# Code is taken from https://github.com/sfujim/TD3 with slight modifications + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import utils +from encoder import make_encoder + +LOG_FREQ = 10000 + + +class Actor(nn.Module): + def __init__( + self, obs_shape, action_shape, encoder_type, encoder_feature_dim + ): + super().__init__() + + self.encoder = make_encoder( + encoder_type, obs_shape, encoder_feature_dim + ) + + self.l1 = nn.Linear(self.encoder.feature_dim, 400) + self.l2 = nn.Linear(400, 300) + self.l3 = nn.Linear(300, action_shape[0]) + + self.outputs = dict() + + def forward(self, obs, detach_encoder=False): + obs = self.encoder(obs, detach=detach_encoder) + h = F.relu(self.l1(obs)) + h = F.relu(self.l2(h)) + action = torch.tanh(self.l3(h)) + self.outputs['mu'] = action + return action + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: + return + + for k, v in self.outputs.items(): + L.log_histogram('train_actor/%s_hist' % k, v, step) + + L.log_param('train_actor/fc1', self.l1, step) + L.log_param('train_actor/fc2', self.l2, step) + L.log_param('train_actor/fc3', self.l3, step) + + +class Critic(nn.Module): + def __init__( + self, obs_shape, action_shape, encoder_type, encoder_feature_dim + ): + super().__init__() + + self.encoder = make_encoder( + encoder_type, obs_shape, encoder_feature_dim + ) + + self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400) + self.l2 = nn.Linear(400, 300) + self.l3 = nn.Linear(300, 1) + + self.outputs = dict() + + def forward(self, obs, action, detach_encoder=False): + obs = self.encoder(obs, detach=detach_encoder) + obs_action = torch.cat([obs, action], dim=1) + h = F.relu(self.l1(obs_action)) + h = F.relu(self.l2(h)) + q = self.l3(h) + self.outputs['q'] = q + return q + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: + return + + self.encoder.log(L, step, log_freq) + + for k, v in self.outputs.items(): + L.log_histogram('train_critic/%s_hist' % k, v, step) + + L.log_param('train_critic/fc1', self.l1, step) + L.log_param('train_critic/fc2', self.l2, step) + L.log_param('train_critic/fc3', self.l3, step) + + +class DDPGAgent(object): + def __init__( + self, + obs_shape, + action_shape, + device, + discount=0.99, + tau=0.005, + actor_lr=1e-3, + critic_lr=1e-3, + encoder_type='identity', + encoder_feature_dim=50 + ): + self.device = device 
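+        # discount is the TD-target discount factor; tau is the Polyak coefficient
+        # used for the soft target-network updates (see utils.soft_update_params)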
+ self.discount = discount + self.tau = tau + + # models + self.actor = Actor( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + + self.critic = Critic( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + + self.actor.encoder.copy_conv_weights_from(self.critic.encoder) + + self.actor_target = Actor( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + self.actor_target.load_state_dict(self.actor.state_dict()) + + self.critic_target = Critic( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + self.critic_target.load_state_dict(self.critic.state_dict()) + + # optimizers + self.actor_optimizer = torch.optim.Adam( + self.actor.parameters(), lr=actor_lr + ) + + self.critic_optimizer = torch.optim.Adam( + self.critic.parameters(), lr=critic_lr + ) + + self.train() + self.critic_target.train() + self.actor_target.train() + + def train(self, training=True): + self.training = training + self.actor.train(training) + self.critic.train(training) + + def select_action(self, obs): + with torch.no_grad(): + obs = torch.FloatTensor(obs).to(self.device) + obs = obs.unsqueeze(0) + action = self.actor(obs) + return action.cpu().data.numpy().flatten() + + def sample_action(self, obs): + return self.select_action(obs) + + def update_critic(self, obs, action, reward, next_obs, not_done, L, step): + with torch.no_grad(): + target_Q = self.critic_target( + next_obs, self.actor_target(next_obs) + ) + target_Q = reward + (not_done * self.discount * target_Q) + + current_Q = self.critic(obs, action) + + critic_loss = F.mse_loss(current_Q, target_Q) + L.log('train_critic/loss', critic_loss, step) + + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + self.critic.log(L, step) + + def update_actor(self, obs, L, step): + action = self.actor(obs, detach_encoder=True) + actor_Q = self.critic(obs, action, detach_encoder=True) + actor_loss = -actor_Q.mean() + + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + self.actor.log(L, step) + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done = replay_buffer.sample() + + L.log('train/batch_reward', reward.mean(), step) + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + self.update_actor(obs, L, step) + + utils.soft_update_params(self.critic, self.critic_target, self.tau) + utils.soft_update_params(self.actor, self.actor_target, self.tau) + + def save(self, model_dir, step): + torch.save( + self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step) + ) + torch.save( + self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step) + ) + + def load(self, model_dir, step): + self.actor.load_state_dict( + torch.load('%s/actor_%s.pt' % (model_dir, step)) + ) + self.critic.load_state_dict( + torch.load('%s/critic_%s.pt' % (model_dir, step)) + ) diff --git a/decoder.py b/decoder.py new file mode 100644 index 0000000..3fbb4a7 --- /dev/null +++ b/decoder.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn + +from encoder import OUT_DIM + + +class PixelDecoder(nn.Module): + def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32): + super().__init__() + + self.num_layers = num_layers + self.num_filters = num_filters + self.out_dim = OUT_DIM[num_layers] + + self.fc = nn.Linear( + feature_dim, num_filters * self.out_dim * self.out_dim + ) + + self.deconvs = nn.ModuleList() + + for i in range(self.num_layers - 1): + 
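+            # each stride-1 ConvTranspose2d grows the spatial map by 2 pixels,
+            # undoing the encoder's stride-1 convs before the final upsampling layer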
self.deconvs.append( + nn.ConvTranspose2d(num_filters, num_filters, 3, stride=1) + ) + self.deconvs.append( + nn.ConvTranspose2d( + num_filters, obs_shape[0], 3, stride=2, output_padding=1 + ) + ) + + self.outputs = dict() + + def forward(self, h): + h = torch.relu(self.fc(h)) + self.outputs['fc'] = h + + deconv = h.view(-1, self.num_filters, self.out_dim, self.out_dim) + self.outputs['deconv1'] = deconv + + for i in range(0, self.num_layers - 1): + deconv = torch.relu(self.deconvs[i](deconv)) + self.outputs['deconv%s' % (i + 1)] = deconv + + obs = self.deconvs[-1](deconv) + self.outputs['obs'] = obs + + return obs + + def log(self, L, step, log_freq): + if step % log_freq != 0: + return + + for k, v in self.outputs.items(): + L.log_histogram('train_decoder/%s_hist' % k, v, step) + if len(v.shape) > 2: + L.log_image('train_decoder/%s_i' % k, v[0], step) + + for i in range(self.num_layers): + L.log_param( + 'train_decoder/deconv%s' % (i + 1), self.deconvs[i], step + ) + L.log_param('train_decoder/fc', self.fc, step) + + +class StateDecoder(nn.Module): + def __init__(self, obs_shape, feature_dim): + super().__init__() + + assert len(obs_shape) == 1 + + self.trunk = nn.Sequential( + nn.Linear(feature_dim, 1024), nn.ReLU(), nn.Linear(1024, 1024), + nn.ReLU(), nn.Linear(1024, obs_shape[0]), nn.ReLU() + ) + + self.outputs = dict() + + def forward(self, obs, detach=False): + h = self.trunk(obs) + if detach: + h = h.detach() + self.outputs['h'] = h + return h + + def log(self, L, step, log_freq): + if step % log_freq != 0: + return + + L.log_param('train_encoder/fc1', self.trunk[0], step) + L.log_param('train_encoder/fc2', self.trunk[2], step) + for k, v in self.outputs.items(): + L.log_histogram('train_encoder/%s_hist' % k, v, step) + + +_AVAILABLE_DECODERS = {'pixel': PixelDecoder, 'state': StateDecoder} + + +def make_decoder( + decoder_type, obs_shape, feature_dim, num_layers, num_filters +): + assert decoder_type in _AVAILABLE_DECODERS + if decoder_type == 'pixel': + return _AVAILABLE_DECODERS[decoder_type]( + obs_shape, feature_dim, num_layers, num_filters + ) + return _AVAILABLE_DECODERS[decoder_type](obs_shape, feature_dim) diff --git a/encoder.py b/encoder.py new file mode 100644 index 0000000..0f2e581 --- /dev/null +++ b/encoder.py @@ -0,0 +1,185 @@ +import torch +import torch.nn as nn + + +def tie_weights(src, trg): + assert type(src) == type(trg) + trg.weight = src.weight + trg.bias = src.bias + + +OUT_DIM = {2: 39, 4: 35, 6: 31} + + +class PixelEncoder(nn.Module): + """Convolutional encoder of pixels observations.""" + def __init__( + self, + obs_shape, + feature_dim, + num_layers=2, + num_filters=32, + stochastic=False + ): + super().__init__() + + assert len(obs_shape) == 3 + + self.feature_dim = feature_dim + self.num_layers = num_layers + self.stochastic = stochastic + + self.convs = nn.ModuleList( + [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)] + ) + for i in range(num_layers - 1): + self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1)) + + out_dim = OUT_DIM[num_layers] + self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim) + self.ln = nn.LayerNorm(self.feature_dim) + + if self.stochastic: + self.log_std_min = -10 + self.log_std_max = 2 + self.fc_log_std = nn.Linear( + num_filters * out_dim * out_dim, self.feature_dim + ) + + self.outputs = dict() + + def reparameterize(self, mu, logstd): + std = torch.exp(logstd) + eps = torch.randn_like(std) + return mu + eps * std + + def forward_conv(self, obs): + obs = obs / 255. 
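+        # scale raw uint8 pixel values into [0, 1] before the conv stack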
+ self.outputs['obs'] = obs + + conv = torch.relu(self.convs[0](obs)) + self.outputs['conv1'] = conv + + for i in range(1, self.num_layers): + conv = torch.relu(self.convs[i](conv)) + self.outputs['conv%s' % (i + 1)] = conv + + h = conv.view(conv.size(0), -1) + return h + + def forward(self, obs, detach=False): + h = self.forward_conv(obs) + + if detach: + h = h.detach() + + h_fc = self.fc(h) + self.outputs['fc'] = h_fc + + h_norm = self.ln(h_fc) + self.outputs['ln'] = h_norm + + out = torch.tanh(h_norm) + + if self.stochastic: + self.outputs['mu'] = out + log_std = torch.tanh(self.fc_log_std(h)) + # normalize + log_std = self.log_std_min + 0.5 * ( + self.log_std_max - self.log_std_min + ) * (log_std + 1) + out = self.reparameterize(out, log_std) + self.outputs['log_std'] = log_std + + self.outputs['tanh'] = out + + return out + + def copy_conv_weights_from(self, source): + """Tie convolutional layers""" + # only tie conv layers + for i in range(self.num_layers): + tie_weights(src=source.convs[i], trg=self.convs[i]) + + def log(self, L, step, log_freq): + if step % log_freq != 0: + return + + for k, v in self.outputs.items(): + L.log_histogram('train_encoder/%s_hist' % k, v, step) + if len(v.shape) > 2: + L.log_image('train_encoder/%s_img' % k, v[0], step) + + for i in range(self.num_layers): + L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step) + L.log_param('train_encoder/fc', self.fc, step) + L.log_param('train_encoder/ln', self.ln, step) + + +class StateEncoder(nn.Module): + def __init__(self, obs_shape, feature_dim): + super().__init__() + + assert len(obs_shape) == 1 + self.feature_dim = feature_dim + + self.trunk = nn.Sequential( + nn.Linear(obs_shape[0], 256), nn.ReLU(), + nn.Linear(256, feature_dim), nn.ReLU() + ) + + self.outputs = dict() + + def forward(self, obs, detach=False): + h = self.trunk(obs) + if detach: + h = h.detach() + self.outputs['h'] = h + return h + + def copy_conv_weights_from(self, source): + pass + + def log(self, L, step, log_freq): + if step % log_freq != 0: + return + + L.log_param('train_encoder/fc1', self.trunk[0], step) + L.log_param('train_encoder/fc2', self.trunk[2], step) + for k, v in self.outputs.items(): + L.log_histogram('train_encoder/%s_hist' % k, v, step) + + +class IdentityEncoder(nn.Module): + def __init__(self, obs_shape, feature_dim): + super().__init__() + + assert len(obs_shape) == 1 + self.feature_dim = obs_shape[0] + + def forward(self, obs, detach=False): + return obs + + def copy_conv_weights_from(self, source): + pass + + def log(self, L, step, log_freq): + pass + + +_AVAILABLE_ENCODERS = { + 'pixel': PixelEncoder, + 'state': StateEncoder, + 'identity': IdentityEncoder +} + + +def make_encoder( + encoder_type, obs_shape, feature_dim, num_layers, num_filters, stochastic +): + assert encoder_type in _AVAILABLE_ENCODERS + if encoder_type == 'pixel': + return _AVAILABLE_ENCODERS[encoder_type]( + obs_shape, feature_dim, num_layers, num_filters, stochastic + ) + return _AVAILABLE_ENCODERS[encoder_type](obs_shape, feature_dim) diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..93ff12e --- /dev/null +++ b/logger.py @@ -0,0 +1,165 @@ +from torch.utils.tensorboard import SummaryWriter +from collections import defaultdict +import json +import os +import shutil +import torch +import torchvision +import numpy as np +from termcolor import colored + + +FORMAT_CONFIG = { + 'rl': { + 'train': [('episode', 'E', 'int'), + ('step', 'S', 'int'), + ('duration', 'D', 'time'), + ('episode_reward', 'R', 'float'), + 
('batch_reward', 'BR', 'float'), + ('actor_loss', 'ALOSS', 'float'), + ('critic_loss', 'CLOSS', 'float'), + ('ae_loss', 'RLOSS', 'float')], + 'eval': [('step', 'S', 'int'), + ('episode_reward', 'ER', 'float')] + } +} + + +class AverageMeter(object): + def __init__(self): + self._sum = 0 + self._count = 0 + + def update(self, value, n=1): + self._sum += value + self._count += n + + def value(self): + return self._sum / max(1, self._count) + + +class MetersGroup(object): + def __init__(self, file_name, formating): + self._file_name = file_name + if os.path.exists(file_name): + os.remove(file_name) + self._formating = formating + self._meters = defaultdict(AverageMeter) + + def log(self, key, value, n=1): + self._meters[key].update(value, n) + + def _prime_meters(self): + data = dict() + for key, meter in self._meters.items(): + if key.startswith('train'): + key = key[len('train') + 1:] + else: + key = key[len('eval') + 1:] + key = key.replace('/', '_') + data[key] = meter.value() + return data + + def _dump_to_file(self, data): + with open(self._file_name, 'a') as f: + f.write(json.dumps(data) + '\n') + + def _format(self, key, value, ty): + template = '%s: ' + if ty == 'int': + template += '%d' + elif ty == 'float': + template += '%.04f' + elif ty == 'time': + template += '%.01f s' + else: + raise 'invalid format type: %s' % ty + return template % (key, value) + + def _dump_to_console(self, data, prefix): + prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') + pieces = ['{:5}'.format(prefix)] + for key, disp_key, ty in self._formating: + value = data.get(key, 0) + pieces.append(self._format(disp_key, value, ty)) + print('| %s' % (' | '.join(pieces))) + + def dump(self, step, prefix): + if len(self._meters) == 0: + return + data = self._prime_meters() + data['step'] = step + self._dump_to_file(data) + self._dump_to_console(data, prefix) + self._meters.clear() + + +class Logger(object): + def __init__(self, log_dir, use_tb=True, config='rl'): + self._log_dir = log_dir + if use_tb: + tb_dir = os.path.join(log_dir, 'tb') + if os.path.exists(tb_dir): + shutil.rmtree(tb_dir) + self._sw = SummaryWriter(tb_dir) + else: + self._sw = None + self._train_mg = MetersGroup( + os.path.join(log_dir, 'train.log'), + formating=FORMAT_CONFIG[config]['train']) + self._eval_mg = MetersGroup( + os.path.join(log_dir, 'eval.log'), + formating=FORMAT_CONFIG[config]['eval']) + + def _try_sw_log(self, key, value, step): + if self._sw is not None: + self._sw.add_scalar(key, value, step) + + def _try_sw_log_image(self, key, image, step): + if self._sw is not None: + assert image.dim() == 3 + grid = torchvision.utils.make_grid(image.unsqueeze(1)) + self._sw.add_image(key, grid, step) + + def _try_sw_log_video(self, key, frames, step): + if self._sw is not None: + frames = torch.from_numpy(np.array(frames)) + frames = frames.unsqueeze(0) + self._sw.add_video(key, frames, step, fps=30) + + def _try_sw_log_histogram(self, key, histogram, step): + if self._sw is not None: + self._sw.add_histogram(key, histogram, step) + + def log(self, key, value, step, n=1): + assert key.startswith('train') or key.startswith('eval') + if type(value) == torch.Tensor: + value = value.item() + self._try_sw_log(key, value / n, step) + mg = self._train_mg if key.startswith('train') else self._eval_mg + mg.log(key, value, n) + + def log_param(self, key, param, step): + self.log_histogram(key + '_w', param.weight.data, step) + if hasattr(param.weight, 'grad') and param.weight.grad is not None: + self.log_histogram(key + '_w_g', 
param.weight.grad.data, step)
+        if hasattr(param, 'bias'):
+            self.log_histogram(key + '_b', param.bias.data, step)
+            if hasattr(param.bias, 'grad') and param.bias.grad is not None:
+                self.log_histogram(key + '_b_g', param.bias.grad.data, step)
+
+    def log_image(self, key, image, step):
+        assert key.startswith('train') or key.startswith('eval')
+        self._try_sw_log_image(key, image, step)
+
+    def log_video(self, key, frames, step):
+        assert key.startswith('train') or key.startswith('eval')
+        self._try_sw_log_video(key, frames, step)
+
+    def log_histogram(self, key, histogram, step):
+        assert key.startswith('train') or key.startswith('eval')
+        self._try_sw_log_histogram(key, histogram, step)
+
+    def dump(self, step):
+        self._train_mg.dump(step, 'train')
+        self._eval_mg.dump(step, 'eval')
diff --git a/run.sh b/run.sh
new file mode 100755
index 0000000..78d2e41
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+DOMAIN=cheetah
+TASK=run
+ACTION_REPEAT=4
+ENCODER_TYPE=pixel
+DECODER_TYPE=pixel
+
+
+WORK_DIR=./runs
+
+python train.py \
+    --domain_name ${DOMAIN} \
+    --task_name ${TASK} \
+    --encoder_type ${ENCODER_TYPE} \
+    --decoder_type ${DECODER_TYPE} \
+    --action_repeat ${ACTION_REPEAT} \
+    --save_video \
+    --save_tb \
+    --work_dir ${WORK_DIR}/${DOMAIN}_${TASK}/_ae_encoder_${ENCODER_TYPE}_decoder_${DECODER_TYPE} \
+    --seed 1
diff --git a/sac.py b/sac.py
new file mode 100644
index 0000000..d1a18f6
--- /dev/null
+++ b/sac.py
@@ -0,0 +1,507 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+import math
+
+import utils
+from encoder import make_encoder
+from decoder import make_decoder
+
+LOG_FREQ = 10000
+
+
+def gaussian_logprob(noise, log_std):
+    """Compute Gaussian log probability."""
+    residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
+    return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)
+
+
+def squash(mu, pi, log_pi):
+    """Apply squashing function.
+    See appendix C from https://arxiv.org/pdf/1812.05905.pdf.
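+    The tanh change of variables subtracts sum_i log(1 - tanh(u_i)^2) from the
+    Gaussian log-probability; the F.relu and 1e-6 below keep the log argument positive.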
+ """ + mu = torch.tanh(mu) + if pi is not None: + pi = torch.tanh(pi) + if log_pi is not None: + log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) + return mu, pi, log_pi + + +def weight_init(m): + """Custom weight init for Conv2D and Linear layers.""" + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + m.bias.data.fill_(0.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf + assert m.weight.size(2) == m.weight.size(3) + m.weight.data.fill_(0.0) + m.bias.data.fill_(0.0) + mid = m.weight.size(2) // 2 + gain = nn.init.calculate_gain('relu') + nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) + + +class Actor(nn.Module): + """MLP actor network.""" + def __init__( + self, obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters, + freeze_encoder, stochastic + ): + super().__init__() + + self.encoder = make_encoder( + encoder_type, obs_shape, encoder_feature_dim, num_layers, + num_filters, stochastic + ) + + self.log_std_min = log_std_min + self.log_std_max = log_std_max + self.freeze_encoder = freeze_encoder + + self.trunk = nn.Sequential( + nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, 2 * action_shape[0]) + ) + + self.outputs = dict() + self.apply(weight_init) + + def forward( + self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False + ): + obs = self.encoder(obs, detach=detach_encoder) + + if self.freeze_encoder: + obs = obs.detach() + + mu, log_std = self.trunk(obs).chunk(2, dim=-1) + + # constrain log_std inside [log_std_min, log_std_max] + log_std = F.tanh(log_std) + log_std = self.log_std_min + 0.5 * ( + self.log_std_max - self.log_std_min + ) * (log_std + 1) + + self.outputs['mu'] = mu + self.outputs['std'] = log_std.exp() + + if compute_pi: + std = log_std.exp() + noise = torch.randn_like(mu) + pi = mu + noise * std + else: + pi = None + entropy = None + + if compute_log_pi: + log_pi = gaussian_logprob(noise, log_std) + else: + log_pi = None + + mu, pi, log_pi = squash(mu, pi, log_pi) + + return mu, pi, log_pi, log_std + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: + return + + for k, v in self.outputs.items(): + L.log_histogram('train_actor/%s_hist' % k, v, step) + + L.log_param('train_actor/fc1', self.trunk[0], step) + L.log_param('train_actor/fc2', self.trunk[2], step) + L.log_param('train_actor/fc3', self.trunk[4], step) + + +class QFunction(nn.Module): + """MLP for q-function.""" + def __init__(self, obs_dim, action_dim, hidden_dim): + super().__init__() + + self.trunk = nn.Sequential( + nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, 1) + ) + + def forward(self, obs, action): + assert obs.size(0) == action.size(0) + + obs_action = torch.cat([obs, action], dim=1) + return self.trunk(obs_action) + + +class DynamicsModel(nn.Module): + def __init__(self, state_dim, action_dim, hidden_dim): + super().__init__() + + self.trunk = nn.Sequential( + nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, state_dim) + ) + + def forward(self, state, action): + assert state.size(0) == action.size(0) + + state_action = torch.cat([state, action], dim=1) + return self.trunk(state_action) + + +class Critic(nn.Module): + 
"""Critic network, employes two q-functions.""" + def __init__( + self, obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, num_layers, num_filters, freeze_encoder, + use_dynamics, stochastic + ): + super().__init__() + + self.freeze_encoder = freeze_encoder + + self.encoder = make_encoder( + encoder_type, obs_shape, encoder_feature_dim, num_layers, + num_filters, stochastic + ) + + if use_dynamics: + self.forward_model = DynamicsModel( + self.encoder.feature_dim, action_shape[0], hidden_dim + ) + + self.Q1 = QFunction( + self.encoder.feature_dim, action_shape[0], hidden_dim + ) + self.Q2 = QFunction( + self.encoder.feature_dim, action_shape[0], hidden_dim + ) + + self.outputs = dict() + self.apply(weight_init) + + def forward(self, obs, action, detach_encoder=False): + # detach_encoder allows to stop gradient propogation to encoder + obs = self.encoder(obs, detach=detach_encoder) + + if self.freeze_encoder: + obs = obs.detach() + + q1 = self.Q1(obs, action) + q2 = self.Q2(obs, action) + + self.outputs['q1'] = q1 + self.outputs['q2'] = q2 + + return q1, q2 + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: + return + + self.encoder.log(L, step, log_freq) + + for k, v in self.outputs.items(): + L.log_histogram('train_critic/%s_hist' % k, v, step) + + for i in range(3): + L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step) + L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step) + + +class SACAgent(object): + """Soft Actor-Critic algorithm.""" + def __init__( + self, + obs_shape, + state_shape, + action_shape, + device, + hidden_dim=256, + discount=0.99, + init_temperature=0.01, + alpha_lr=1e-3, + alpha_beta=0.9, + actor_lr=1e-3, + actor_beta=0.9, + actor_log_std_min=-10, + actor_log_std_max=2, + actor_update_freq=2, + critic_lr=1e-3, + critic_beta=0.9, + critic_tau=0.005, + critic_target_update_freq=2, + encoder_type='identity', + encoder_feature_dim=50, + encoder_lr=1e-3, + encoder_tau=0.005, + decoder_type='identity', + decoder_lr=1e-3, + decoder_update_freq=1, + decoder_latent_lambda=0.0, + decoder_weight_lambda=0.0, + decoder_kl_lambda=0.0, + num_layers=4, + num_filters=32, + freeze_encoder=False, + use_dynamics=False + ): + self.device = device + self.discount = discount + self.critic_tau = critic_tau + self.encoder_tau = encoder_tau + self.actor_update_freq = actor_update_freq + self.critic_target_update_freq = critic_target_update_freq + self.decoder_update_freq = decoder_update_freq + self.decoder_latent_lambda = decoder_latent_lambda + self.decoder_kl_lambda = decoder_kl_lambda + self.decoder_type = decoder_type + self.use_dynamics = use_dynamics + + stochastic = decoder_kl_lambda > 0.0 + + self.actor = Actor( + obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, actor_log_std_min, actor_log_std_max, + num_layers, num_filters, freeze_encoder, stochastic + ).to(device) + + self.critic = Critic( + obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, num_layers, num_filters, freeze_encoder, + use_dynamics, stochastic + ).to(device) + + self.critic_target = Critic( + obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, num_layers, num_filters, freeze_encoder, + use_dynamics, stochastic + ).to(device) + + self.critic_target.load_state_dict(self.critic.state_dict()) + + # tie encoders between actor and critic + self.actor.encoder.copy_conv_weights_from(self.critic.encoder) + + self.log_alpha = torch.tensor(np.log(init_temperature)).to(device) + 
self.log_alpha.requires_grad = True + # set target entropy to -|A| + self.target_entropy = -np.prod(action_shape) + + self.decoder = None + if decoder_type != 'identity': + # create decoder + shape = obs_shape if decoder_type == 'pixel' else state_shape + self.decoder = make_decoder( + decoder_type, shape, encoder_feature_dim, num_layers, + num_filters + ).to(device) + self.decoder.apply(weight_init) + + # optimizer for critic encoder for reconstruction loss + self.encoder_optimizer = torch.optim.Adam( + self.critic.encoder.parameters(), lr=encoder_lr + ) + + # optimizer for decoder + self.decoder_optimizer = torch.optim.Adam( + self.decoder.parameters(), + lr=decoder_lr, + weight_decay=decoder_weight_lambda + ) + + # optimizers + self.actor_optimizer = torch.optim.Adam( + self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999) + ) + + self.critic_optimizer = torch.optim.Adam( + self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999) + ) + + self.log_alpha_optimizer = torch.optim.Adam( + [self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999) + ) + + self.train() + self.critic_target.train() + + def train(self, training=True): + self.training = training + self.actor.train(training) + self.critic.train(training) + if self.decoder is not None: + self.decoder.train(training) + + @property + def alpha(self): + return self.log_alpha.exp() + + def select_action(self, obs): + with torch.no_grad(): + obs = torch.FloatTensor(obs).to(self.device) + obs = obs.unsqueeze(0) + mu, _, _, _ = self.actor( + obs, compute_pi=False, compute_log_pi=False + ) + return mu.cpu().data.numpy().flatten() + + def sample_action(self, obs): + with torch.no_grad(): + obs = torch.FloatTensor(obs).to(self.device) + obs = obs.unsqueeze(0) + mu, pi, _, _ = self.actor(obs, compute_log_pi=False) + return pi.cpu().data.numpy().flatten() + + def update_critic(self, obs, action, reward, next_obs, not_done, L, step): + with torch.no_grad(): + _, policy_action, log_pi, _ = self.actor(next_obs) + target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) + target_V = torch.min(target_Q1, + target_Q2) - self.alpha.detach() * log_pi + target_Q = reward + (not_done * self.discount * target_V) + + # get current Q estimates + current_Q1, current_Q2 = self.critic(obs, action) + critic_loss = F.mse_loss(current_Q1, + target_Q) + F.mse_loss(current_Q2, target_Q) + L.log('train_critic/loss', critic_loss, step) + + # update dynamics (optional) + if self.use_dynamics: + h_obs = self.critic.encoder.outputs['mu'] + with torch.no_grad(): + next_latent = self.critic.encoder(next_obs) + pred_next_latent = self.critic.forward_model(h_obs, action) + dynamics_loss = F.mse_loss(pred_next_latent, next_latent) + L.log('train_critic/dynamics_loss', dynamics_loss, step) + critic_loss += dynamics_loss + + # Optimize the critic + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + self.critic.log(L, step) + + def update_actor_and_alpha(self, obs, L, step): + # detach encoder, so we don't update it with the actor loss + _, pi, log_pi, log_std = self.actor(obs, detach_encoder=True) + actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True) + + actor_Q = torch.min(actor_Q1, actor_Q2) + actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean() + + L.log('train_actor/loss', actor_loss, step) + L.log('train_actor/target_entropy', self.target_entropy, step) + entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi) + ) + log_std.sum(dim=-1) + L.log('train_actor/entropy', 
entropy.mean(), step) + + # optimize the actor + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + self.actor.log(L, step) + + self.log_alpha_optimizer.zero_grad() + alpha_loss = (self.alpha * + (-log_pi - self.target_entropy).detach()).mean() + L.log('train_alpha/loss', alpha_loss, step) + L.log('train_alpha/value', self.alpha, step) + alpha_loss.backward() + self.log_alpha_optimizer.step() + + def update_decoder(self, obs, target_obs, L, step): + if self.decoder is None: + return + + h = self.critic.encoder(obs) + + if target_obs.dim() == 4: + # preprocess images to be in [-0.5, 0.5] range + target_obs = utils.preprocess_obs(target_obs) + rec_obs = self.decoder(h) + rec_loss = F.mse_loss(target_obs, rec_obs) + + # add L2 penalty on latent representation + # see https://arxiv.org/pdf/1903.12436.pdf + latent_loss = (0.5 * h.pow(2).sum(1)).mean() + + # add KL penalty for VAE + if self.decoder_kl_lambda > 0.0: + log_std = self.critic.encoder.outputs['log_std'] + mu = self.critic.encoder.outputs['mu'] + kl_div = -0.5 * (1 + 2 * log_std - mu.pow(2) - (2 * log_std).exp()) + kl_div = kl_div.sum(1).mean(0, True) + else: + kl_div = 0.0 + loss = rec_loss + self.decoder_latent_lambda * latent_loss + self.decoder_kl_lambda * kl_div + + self.encoder_optimizer.zero_grad() + self.decoder_optimizer.zero_grad() + loss.backward() + + self.encoder_optimizer.step() + self.decoder_optimizer.step() + L.log('train_ae/ae_loss', loss, step) + + self.decoder.log(L, step, log_freq=LOG_FREQ) + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done, state = replay_buffer.sample() + + L.log('train/batch_reward', reward.mean(), step) + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + + if step % self.actor_update_freq == 0: + self.update_actor_and_alpha(obs, L, step) + + if step % self.critic_target_update_freq == 0: + utils.soft_update_params( + self.critic.Q1, self.critic_target.Q1, self.critic_tau + ) + utils.soft_update_params( + self.critic.Q2, self.critic_target.Q2, self.critic_tau + ) + utils.soft_update_params( + self.critic.encoder, self.critic_target.encoder, + self.encoder_tau + ) + + if step % self.decoder_update_freq == 0: + target = obs if self.decoder_type == 'pixel' else state + self.update_decoder(obs, target, L, step) + + def save(self, model_dir, step): + torch.save( + self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step) + ) + torch.save( + self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step) + ) + if self.decoder is not None: + torch.save( + self.decoder.state_dict(), + '%s/decoder_%s.pt' % (model_dir, step) + ) + + def load(self, model_dir, step): + self.actor.load_state_dict( + torch.load('%s/actor_%s.pt' % (model_dir, step)) + ) + self.critic.load_state_dict( + torch.load('%s/critic_%s.pt' % (model_dir, step)) + ) + if self.decoder is not None: + self.decoder.load_state_dict( + torch.load('%s/decoder_%s.pt' % (model_dir, step)) + ) diff --git a/td3.py b/td3.py new file mode 100644 index 0000000..db6aaf5 --- /dev/null +++ b/td3.py @@ -0,0 +1,259 @@ +# Code is taken from https://github.com/sfujim/TD3 with slight modifications + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import utils +from encoder import make_encoder + +LOG_FREQ = 10000 + + +class Actor(nn.Module): + def __init__( + self, obs_shape, action_shape, encoder_type, encoder_feature_dim + ): + super().__init__() + + self.encoder = make_encoder( + encoder_type, obs_shape, 
encoder_feature_dim + ) + + self.l1 = nn.Linear(self.encoder.feature_dim, 400) + self.l2 = nn.Linear(400, 300) + self.l3 = nn.Linear(300, action_shape[0]) + + self.outputs = dict() + + def forward(self, obs, detach_encoder=False): + obs = self.encoder(obs, detach=detach_encoder) + h = F.relu(self.l1(obs)) + h = F.relu(self.l2(h)) + action = torch.tanh(self.l3(h)) + self.outputs['mu'] = action + return action + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: + return + + for k, v in self.outputs.items(): + L.log_histogram('train_actor/%s_hist' % k, v, step) + + L.log_param('train_actor/fc1', self.l1, step) + L.log_param('train_actor/fc2', self.l2, step) + L.log_param('train_actor/fc3', self.l3, step) + + +class Critic(nn.Module): + def __init__( + self, obs_shape, action_shape, encoder_type, encoder_feature_dim + ): + super().__init__() + + self.encoder = make_encoder( + encoder_type, obs_shape, encoder_feature_dim + ) + + # Q1 architecture + self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400) + self.l2 = nn.Linear(400, 300) + self.l3 = nn.Linear(300, 1) + + # Q2 architecture + self.l4 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400) + self.l5 = nn.Linear(400, 300) + self.l6 = nn.Linear(300, 1) + + self.outputs = dict() + + def forward(self, obs, action, detach_encoder=False): + obs = self.encoder(obs, detach=detach_encoder) + + obs_action = torch.cat([obs, action], 1) + + h1 = F.relu(self.l1(obs_action)) + h1 = F.relu(self.l2(h1)) + q1 = self.l3(h1) + + h2 = F.relu(self.l4(obs_action)) + h2 = F.relu(self.l5(h2)) + q2 = self.l6(h2) + + self.outputs['q1'] = q1 + self.outputs['q2'] = q2 + + return q1, q2 + + def Q1(self, obs, action, detach_encoder=False): + obs = self.encoder(obs, detach=detach_encoder) + + obs_action = torch.cat([obs, action], 1) + + h1 = F.relu(self.l1(obs_action)) + h1 = F.relu(self.l2(h1)) + q1 = self.l3(h1) + return q1 + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: return + + self.encoder.log(L, step, log_freq) + + for k, v in self.outputs.items(): + L.log_histogram('train_critic/%s_hist' % k, v, step) + + L.log_param('train_critic/q1_fc1', self.l1, step) + L.log_param('train_critic/q1_fc2', self.l2, step) + L.log_param('train_critic/q1_fc3', self.l3, step) + L.log_param('train_critic/q1_fc4', self.l4, step) + L.log_param('train_critic/q1_fc5', self.l5, step) + L.log_param('train_critic/q1_fc6', self.l6, step) + + +class TD3Agent(object): + def __init__( + self, + obs_shape, + action_shape, + device, + discount=0.99, + tau=0.005, + policy_noise=0.2, + noise_clip=0.5, + expl_noise=0.1, + actor_lr=1e-3, + critic_lr=1e-3, + encoder_type='identity', + encoder_feature_dim=50, + actor_update_freq=2, + target_update_freq=2, + ): + self.device = device + self.discount = discount + self.tau = tau + self.policy_noise = policy_noise + self.noise_clip = noise_clip + self.expl_noise = expl_noise + self.actor_update_freq = actor_update_freq + self.target_update_freq = target_update_freq + + # models + self.actor = Actor( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + + self.critic = Critic( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + + self.actor.encoder.copy_conv_weights_from(self.critic.encoder) + + self.actor_target = Actor( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + self.actor_target.load_state_dict(self.actor.state_dict()) + + self.critic_target = Critic( + obs_shape, action_shape, encoder_type, 
encoder_feature_dim + ).to(device) + self.critic_target.load_state_dict(self.critic.state_dict()) + + # optimizers + self.actor_optimizer = torch.optim.Adam( + self.actor.parameters(), lr=actor_lr + ) + + self.critic_optimizer = torch.optim.Adam( + self.critic.parameters(), lr=critic_lr + ) + + self.train() + self.critic_target.train() + self.actor_target.train() + + def train(self, training=True): + self.training = training + self.actor.train(training) + self.critic.train(training) + + def select_action(self, obs): + with torch.no_grad(): + obs = torch.FloatTensor(obs).to(self.device) + obs = obs.unsqueeze(0) + action = self.actor(obs) + return action.cpu().data.numpy().flatten() + + def sample_action(self, obs): + with torch.no_grad(): + obs = torch.FloatTensor(obs).to(self.device) + obs = obs.unsqueeze(0) + action = self.actor(obs) + noise = torch.randn_like(action) * self.expl_noise + action = (action + noise).clamp(-1.0, 1.0) + return action.cpu().data.numpy().flatten() + + def update_critic(self, obs, action, reward, next_obs, not_done, L, step): + with torch.no_grad(): + noise = torch.randn_like(action).to(self.device) * self.policy_noise + noise = noise.clamp(-self.noise_clip, self.noise_clip) + next_action = self.actor_target(next_obs) + noise + next_action = next_action.clamp(-1.0, 1.0) + target_Q1, target_Q2 = self.critic_target(next_obs, next_action) + target_Q = torch.min(target_Q1, target_Q2) + target_Q = reward + (not_done * self.discount * target_Q) + + current_Q1, current_Q2 = self.critic(obs, action) + + critic_loss = F.mse_loss(current_Q1, + target_Q) + F.mse_loss(current_Q2, target_Q) + L.log('train_critic/loss', critic_loss, step) + + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + self.critic.log(L, step) + + def update_actor(self, obs, L, step): + action = self.actor(obs, detach_encoder=True) + actor_Q = self.critic.Q1(obs, action, detach_encoder=True) + actor_loss = -actor_Q.mean() + + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + self.actor.log(L, step) + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done = replay_buffer.sample() + + L.log('train/batch_reward', reward.mean(), step) + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + + if step % self.actor_update_freq == 0: + self.update_actor(obs, L, step) + + if step % self.target_update_freq == 0: + utils.soft_update_params(self.critic, self.critic_target, self.tau) + utils.soft_update_params(self.actor, self.actor_target, self.tau) + + def save(self, model_dir, step): + torch.save( + self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step) + ) + torch.save( + self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step) + ) + + def load(self, model_dir, step): + self.actor.load_state_dict( + torch.load('%s/actor_%s.pt' % (model_dir, step)) + ) + self.critic.load_state_dict( + torch.load('%s/critic_%s.pt' % (model_dir, step)) + ) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..2af856d --- /dev/null +++ b/train.py @@ -0,0 +1,315 @@ +import numpy as np +import torch +import argparse +import os +import math +import gym +import sys +import random +import time +import json +import dmc2gym +import copy + +import utils +from logger import Logger +from video import VideoRecorder + +from sac import SACAgent +from td3 import TD3Agent +from ddpg import DDPGAgent + + +def parse_args(): + parser = argparse.ArgumentParser() + 
# environment + parser.add_argument('--domain_name', default='cheetah') + parser.add_argument('--task_name', default='run') + parser.add_argument('--image_size', default=84, type=int) + parser.add_argument('--action_repeat', default=1, type=int) + parser.add_argument('--frame_stack', default=3, type=int) + # replay buffer + parser.add_argument('--replay_buffer_capacity', default=1000000, type=int) + # train + parser.add_argument('--agent', default='sac', type=str) + parser.add_argument('--init_steps', default=1000, type=int) + parser.add_argument('--num_train_steps', default=1000000, type=int) + parser.add_argument('--batch_size', default=512, type=int) + parser.add_argument('--hidden_dim', default=256, type=int) + # eval + parser.add_argument('--eval_freq', default=10000, type=int) + parser.add_argument('--num_eval_episodes', default=10, type=int) + # critic + parser.add_argument('--critic_lr', default=1e-3, type=float) + parser.add_argument('--critic_beta', default=0.9, type=float) + parser.add_argument('--critic_tau', default=0.005, type=float) + parser.add_argument('--critic_target_update_freq', default=2, type=int) + # actor + parser.add_argument('--actor_lr', default=1e-3, type=float) + parser.add_argument('--actor_beta', default=0.9, type=float) + parser.add_argument('--actor_log_std_min', default=-10, type=float) + parser.add_argument('--actor_log_std_max', default=2, type=float) + parser.add_argument('--actor_update_freq', default=2, type=int) + # encoder/decoder + parser.add_argument('--encoder_type', default='identity', type=str) + parser.add_argument('--encoder_feature_dim', default=50, type=int) + parser.add_argument('--encoder_lr', default=1e-3, type=float) + parser.add_argument('--encoder_tau', default=0.005, type=float) + parser.add_argument('--decoder_type', default='identity', type=str) + parser.add_argument('--decoder_lr', default=1e-3, type=float) + parser.add_argument('--decoder_update_freq', default=1, type=int) + parser.add_argument('--decoder_latent_lambda', default=0.0, type=float) + parser.add_argument('--decoder_weight_lambda', default=0.0, type=float) + parser.add_argument('--decoder_kl_lambda', default=0.0, type=float) + parser.add_argument('--num_layers', default=4, type=int) + parser.add_argument('--num_filters', default=32, type=int) + parser.add_argument('--freeze_encoder', default=False, action='store_true') + parser.add_argument('--use_dynamics', default=False, action='store_true') + # sac + parser.add_argument('--discount', default=0.99, type=float) + parser.add_argument('--init_temperature', default=0.01, type=float) + parser.add_argument('--alpha_lr', default=1e-3, type=float) + parser.add_argument('--alpha_beta', default=0.9, type=float) + # td3 + parser.add_argument('--policy_noise', default=0.2, type=float) + parser.add_argument('--expl_noise', default=0.1, type=float) + parser.add_argument('--noise_clip', default=0.5, type=float) + parser.add_argument('--tau', default=0.005, type=float) + # misc + parser.add_argument('--seed', default=1, type=int) + parser.add_argument('--work_dir', default='.', type=str) + parser.add_argument('--save_tb', default=False, action='store_true') + parser.add_argument('--save_model', default=False, action='store_true') + parser.add_argument('--save_buffer', default=False, action='store_true') + parser.add_argument('--save_video', default=False, action='store_true') + parser.add_argument('--pretrained_info', default=None, type=str) + parser.add_argument('--pretrained_decoder', default=False, action='store_true') + + args 
= parser.parse_args() + return args + + +def evaluate(env, agent, video, num_episodes, L, step): + for i in range(num_episodes): + obs = env.reset() + video.init(enabled=(i == 0)) + done = False + episode_reward = 0 + while not done: + with utils.eval_mode(agent): + action = agent.select_action(obs) + obs, reward, done, _ = env.step(action) + video.record(env) + episode_reward += reward + + video.save('%d.mp4' % step) + L.log('eval/episode_reward', episode_reward, step) + L.dump(step) + + +def make_agent(obs_shape, state_shape, action_shape, args, device): + if args.agent == 'sac': + return SACAgent( + obs_shape=obs_shape, + state_shape=state_shape, + action_shape=action_shape, + device=device, + hidden_dim=args.hidden_dim, + discount=args.discount, + init_temperature=args.init_temperature, + alpha_lr=args.alpha_lr, + alpha_beta=args.alpha_beta, + actor_lr=args.actor_lr, + actor_beta=args.actor_beta, + actor_log_std_min=args.actor_log_std_min, + actor_log_std_max=args.actor_log_std_max, + actor_update_freq=args.actor_update_freq, + critic_lr=args.critic_lr, + critic_beta=args.critic_beta, + critic_tau=args.critic_tau, + critic_target_update_freq=args.critic_target_update_freq, + encoder_type=args.encoder_type, + encoder_feature_dim=args.encoder_feature_dim, + encoder_lr=args.encoder_lr, + encoder_tau=args.encoder_tau, + decoder_type=args.decoder_type, + decoder_lr=args.decoder_lr, + decoder_update_freq=args.decoder_update_freq, + decoder_latent_lambda=args.decoder_latent_lambda, + decoder_weight_lambda=args.decoder_weight_lambda, + decoder_kl_lambda=args.decoder_kl_lambda, + num_layers=args.num_layers, + num_filters=args.num_filters, + freeze_encoder=args.freeze_encoder, + use_dynamics=args.use_dynamics + ) + elif args.agent == 'td3': + return TD3Agent( + obs_shape=obs_shape, + action_shape=action_shape, + device=device, + discount=args.discount, + tau=args.tau, + policy_noise=args.policy_noise, + noise_clip=args.noise_clip, + expl_noise=args.expl_noise, + actor_lr=args.actor_lr, + critic_lr=args.critic_lr, + encoder_type=args.encoder_type, + encoder_feature_dim=args.encoder_feature_dim, + actor_update_freq=args.actor_update_freq, + target_update_freq=args.critic_target_update_freq + ) + elif args.agent == 'ddpg': + return DDPGAgent( + obs_shape=obs_shape, + action_shape=action_shape, + device=device, + discount=args.discount, + tau=args.tau, + actor_lr=args.actor_lr, + critic_lr=args.critic_lr, + encoder_type=args.encoder_type, + encoder_feature_dim=args.encoder_feature_dim + ) + else: + assert 'agent is not supported: %s' % args.agent + + +def load_pretrained_encoder(agent, pretrained_info, pretrained_decoder): + path, version = pretrained_info.split(':') + + pretrained_agent = copy.deepcopy(agent) + pretrained_agent.load(path, int(version)) + agent.critic.encoder.load_state_dict( + pretrained_agent.critic.encoder.state_dict() + ) + agent.actor.encoder.load_state_dict( + pretrained_agent.actor.encoder.state_dict() + ) + + if pretrained_decoder: + agent.decoder.load_state_dict(pretrained_agent.decoder.state_dict()) + + return agent + + +def main(): + args = parse_args() + utils.set_seed_everywhere(args.seed) + + env = dmc2gym.make( + domain_name=args.domain_name, + task_name=args.task_name, + seed=args.seed, + visualize_reward=False, + from_pixels=(args.encoder_type == 'pixel'), + height=args.image_size, + width=args.image_size, + frame_skip=args.action_repeat + ) + env.seed(args.seed) + + # stack several consecutive frames together + if args.encoder_type == 'pixel': + env = 
utils.FrameStack(env, k=args.frame_stack)
+
+    utils.make_dir(args.work_dir)
+    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
+    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
+    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))
+
+    video = VideoRecorder(video_dir if args.save_video else None)
+
+    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
+        json.dump(vars(args), f, sort_keys=True, indent=4)
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # the dmc2gym wrapper standardizes actions
+    assert env.action_space.low.min() >= -1
+    assert env.action_space.high.max() <= 1
+
+    replay_buffer = utils.ReplayBuffer(
+        obs_shape=env.observation_space.shape,
+        state_shape=env.state_space.shape,
+        action_shape=env.action_space.shape,
+        capacity=args.replay_buffer_capacity,
+        batch_size=args.batch_size,
+        device=device
+    )
+
+    agent = make_agent(
+        obs_shape=env.observation_space.shape,
+        state_shape=env.state_space.shape,
+        action_shape=env.action_space.shape,
+        args=args,
+        device=device
+    )
+
+    if args.pretrained_info is not None:
+        agent = load_pretrained_encoder(
+            agent, args.pretrained_info, args.pretrained_decoder
+        )
+
+    L = Logger(args.work_dir, use_tb=args.save_tb)
+
+    episode, episode_reward, done = 0, 0, True
+    start_time = time.time()
+    for step in range(args.num_train_steps):
+        if done:
+            if step > 0:
+                L.log('train/duration', time.time() - start_time, step)
+                start_time = time.time()
+                L.dump(step)
+
+            # evaluate agent periodically
+            if step % args.eval_freq == 0:
+                L.log('eval/episode', episode, step)
+                evaluate(env, agent, video, args.num_eval_episodes, L, step)
+                if args.save_model:
+                    agent.save(model_dir, step)
+                if args.save_buffer:
+                    replay_buffer.save(buffer_dir)
+
+            L.log('train/episode_reward', episode_reward, step)
+
+            obs = env.reset()
+            done = False
+            episode_reward = 0
+            episode_step = 0
+            episode += 1
+
+            L.log('train/episode', episode, step)
+
+        # sample action for data collection
+        if step < args.init_steps:
+            action = env.action_space.sample()
+        else:
+            with utils.eval_mode(agent):
+                action = agent.sample_action(obs)
+
+        # run training update
+        if step >= args.init_steps:
+            num_updates = args.init_steps if step == args.init_steps else 1
+            for _ in range(num_updates):
+                agent.update(replay_buffer, L, step)
+
+        state = env.env.env._current_state
+        next_obs, reward, done, _ = env.step(action)
+        next_state = env.env.env._current_state
+
+        # allow infinite bootstrap
+        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
+            done
+        )
+        episode_reward += reward
+
+        replay_buffer.add(obs, action, reward, next_obs, done_bool, state, next_state)
+
+        obs = next_obs
+        episode_step += 1
+
+
+if __name__ == '__main__':
+    main()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..4f0c6c1
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,182 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import gym
+import os
+from collections import deque
+import random
+
+
+class eval_mode(object):
+    def __init__(self, *models):
+        self.models = models
+
+    def __enter__(self):
+        self.prev_states = []
+        for model in self.models:
+            self.prev_states.append(model.training)
+            model.train(False)
+
+    def __exit__(self, *args):
+        for model, state in zip(self.models, self.prev_states):
+            model.train(state)
+        return False
+
+
+def soft_update_params(net, target_net, tau):
+    for param, target_param in zip(net.parameters(), target_net.parameters()):
+        target_param.data.copy_(
tau * param.data + (1 - tau) * target_param.data + ) + + +def set_seed_everywhere(seed): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def module_hash(module): + result = 0 + for tensor in module.state_dict().values(): + result += tensor.sum().item() + return result + + +def make_dir(dir_path): + try: + os.mkdir(dir_path) + except OSError: + pass + return dir_path + + +def preprocess_obs(obs, bits=5): + """Preprocessing image, see https://arxiv.org/abs/1807.03039.""" + bins = 2**bits + assert obs.dtype == torch.float32 + if bits < 8: + obs = torch.floor(obs / 2**(8 - bits)) + obs = obs / bins + obs = obs + torch.rand_like(obs) / bins + obs = obs - 0.5 + return obs + + +class ReplayBuffer(object): + """Buffer to store environment transitions.""" + def __init__( + self, obs_shape, state_shape, action_shape, capacity, batch_size, + device + ): + self.capacity = capacity + self.batch_size = batch_size + self.device = device + + # the proprioceptive obs is stored as float32, pixels obs as uint8 + obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8 + + self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype) + self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype) + self.actions = np.empty((capacity, *action_shape), dtype=np.float32) + self.rewards = np.empty((capacity, 1), dtype=np.float32) + self.not_dones = np.empty((capacity, 1), dtype=np.float32) + self.states = np.empty((capacity, *state_shape), dtype=np.float32) + self.next_states = np.empty((capacity, *state_shape), dtype=np.float32) + + self.idx = 0 + self.last_save = 0 + self.full = False + + def add(self, obs, action, reward, next_obs, done, state, next_state): + np.copyto(self.obses[self.idx], obs) + np.copyto(self.actions[self.idx], action) + np.copyto(self.rewards[self.idx], reward) + np.copyto(self.next_obses[self.idx], next_obs) + np.copyto(self.not_dones[self.idx], not done) + np.copyto(self.states[self.idx], state) + np.copyto(self.next_states[self.idx], next_state) + + self.idx = (self.idx + 1) % self.capacity + self.full = self.full or self.idx == 0 + + def sample(self): + idxs = np.random.randint( + 0, self.capacity if self.full else self.idx, size=self.batch_size + ) + + obses = torch.as_tensor(self.obses[idxs], device=self.device).float() + actions = torch.as_tensor(self.actions[idxs], device=self.device) + rewards = torch.as_tensor(self.rewards[idxs], device=self.device) + next_obses = torch.as_tensor( + self.next_obses[idxs], device=self.device + ).float() + not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device) + states = torch.as_tensor(self.states[idxs], device=self.device) + + return obses, actions, rewards, next_obses, not_dones, states + + def save(self, save_dir): + if self.idx == self.last_save: + return + path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx)) + payload = [ + self.obses[self.last_save:self.idx], + self.next_obses[self.last_save:self.idx], + self.actions[self.last_save:self.idx], + self.rewards[self.last_save:self.idx], + self.not_dones[self.last_save:self.idx], + self.states[self.last_save:self.idx], + self.next_states[self.last_save:self.idx] + ] + self.last_save = self.idx + torch.save(payload, path) + + def load(self, save_dir): + chunks = os.listdir(save_dir) + chucks = sorted(chunks, key=lambda x: int(x.split('_')[0])) + for chunk in chucks: + start, end = [int(x) for x in chunk.split('.')[0].split('_')] + path = os.path.join(save_dir, 
chunk) + payload = torch.load(path) + assert self.idx == start + self.obses[start:end] = payload[0] + self.next_obses[start:end] = payload[1] + self.actions[start:end] = payload[2] + self.rewards[start:end] = payload[3] + self.not_dones[start:end] = payload[4] + self.states[start:end] = payload[5] + self.next_states[start:end] = payload[6] + self.idx = end + + +class FrameStack(gym.Wrapper): + def __init__(self, env, k): + gym.Wrapper.__init__(self, env) + self._k = k + self._frames = deque([], maxlen=k) + shp = env.observation_space.shape + self.observation_space = gym.spaces.Box( + low=0, + high=1, + shape=((shp[0] * k,) + shp[1:]), + dtype=env.observation_space.dtype + ) + self._max_episode_steps = env._max_episode_steps + + def reset(self): + obs = self.env.reset() + for _ in range(self._k): + self._frames.append(obs) + return self._get_obs() + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self._frames.append(obs) + return self._get_obs(), reward, done, info + + def _get_obs(self): + assert len(self._frames) == self._k + return np.concatenate(list(self._frames), axis=0) diff --git a/video.py b/video.py new file mode 100644 index 0000000..36588e6 --- /dev/null +++ b/video.py @@ -0,0 +1,32 @@ +import imageio +import os +import numpy as np + + +class VideoRecorder(object): + def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30): + self.dir_name = dir_name + self.height = height + self.width = width + self.camera_id = camera_id + self.fps = fps + self.frames = [] + + def init(self, enabled=True): + self.frames = [] + self.enabled = self.dir_name is not None and enabled + + def record(self, env): + if self.enabled: + frame = env.render( + mode='rgb_array', + height=self.height, + width=self.width, + camera_id=self.camera_id + ) + self.frames.append(frame) + + def save(self, file_name): + if self.enabled: + path = os.path.join(self.dir_name, file_name) + imageio.mimsave(path, self.frames, fps=self.fps)