commit 681e13b12a42497d087384fe076ff1e1fe7da716
Author: Denis Yarats
Date:   Mon Sep 23 11:20:48 2019 -0700

    init

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..37a0c9a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+__pycache__/
+.ipynb_checkpoints/
+runs
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..5573476
--- /dev/null
+++ b/README.md
@@ -0,0 +1,75 @@
+# Soft Actor-Critic implementation in PyTorch
+
+
+## Running locally
+To train SAC locally, use the provided `run_local.sh` script (edit it to change particular arguments):
+```
+./run_local.sh
+```
+This produces a folder (`./save` by default) where all output is stored, including train/eval logs, tensorboard blobs, evaluation videos, and model snapshots. It is possible to attach tensorboard to a particular run using the following command:
+```
+tensorboard --logdir save
+```
+Then open up tensorboard in your browser.
+
+You will also see some console output, something like this:
+```
+| train | E: 1 | S: 1000 | D: 0.8 s | R: 0.0000 | BR: 0.0000 | ALOSS: 0.0000 | CLOSS: 0.0000 | RLOSS: 0.0000
+```
+This line means:
+```
+train - training episode
+E - total number of episodes
+S - total number of environment steps
+D - duration in seconds to train 1 episode
+R - episode reward
+BR - average reward of sampled batch
+ALOSS - average loss of actor
+CLOSS - average loss of critic
+RLOSS - average reconstruction loss (only when training from pixels with a decoder)
+```
+These are just the most important numbers; all other metrics can be found in tensorboard.
+Besides the training output, evaluation output appears once in a while, like this:
+```
+| eval | S: 0 | ER: 21.1676
+```
+It reports the expected reward `ER` of the current policy after `S` steps. Note that `ER` is the average evaluation performance over `num_eval_episodes` episodes (usually 10).
+
+## Running on the cluster
+The `run_cluster.sh` script allows you to run training on the cluster. It is a simple bash script that is easy to modify. We usually run 10 different seeds for each configuration to get reliable results. For example, to schedule 10 runs of `walker walk`, simply do this:
+```
+./run_cluster.sh walker walk
+```
+This script will schedule 10 jobs, and all the output will be stored under `./runs/walker_walk/{configuration_name}/seed_i`. The folder structure looks like this:
+```
+runs/
+  walker_walk/
+    sac_states/
+      seed_1/
+        id # slurm job id
+        stdout # standard output of your job
+        stderr # standard error of your job
+        run.sh # starting script
+        run.slrm # slurm script
+        eval.log # log file for evaluation
+        train.log # log file for training
+        tb/ # folder that stores tensorboard output
+        video/ # folder that stores evaluation videos
+          10000.mp4 # video of one episode after 10000 steps
+      seed_2/
+      ...
+```
+Again, you can attach tensorboard to a particular configuration, for example:
+```
+tensorboard --logdir runs/walker_walk/sac_states
+```
+
+For convenience, you can also use an IPython notebook to aggregate results over the 10 seeds. An example of such a notebook is `runs.ipynb`.
+
+
+## Run entire testbed
+Another script that allows you to run all 10 dm_control tasks on the cluster:
+```
+./run_all.sh
+```
+It will call `run_cluster.sh` for each task, so you only need to modify `run_cluster.sh` to change the hyperparameters.
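For reference, here is a minimal sketch of the kind of per-seed aggregation that `runs.ipynb` performs; it is not part of the repository and assumes the cluster folder layout above plus the JSON-lines format that `logger.py` writes to `eval.log` (one object per dump, with `step` and `episode_reward` keys):
```
import glob
import json

import numpy as np


def load_eval_curves(config_dir):
    # pool evaluation returns over every seed_* folder of one configuration
    curves = {}
    for path in glob.glob('%s/seed_*/eval.log' % config_dir):
        with open(path) as f:
            for line in f:
                entry = json.loads(line)
                curves.setdefault(entry['step'], []).append(entry['episode_reward'])
    return curves


curves = load_eval_curves('runs/walker_walk/sac_states')
for step in sorted(curves):
    print(step, np.mean(curves[step]), np.std(curves[step]))
```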
diff --git a/ddpg.py b/ddpg.py new file mode 100644 index 0000000..2b03972 --- /dev/null +++ b/ddpg.py @@ -0,0 +1,209 @@ +# Code is taken from https://github.com/sfujim/TD3 with slight modifications + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import utils +from encoder import make_encoder + +LOG_FREQ = 10000 + + +class Actor(nn.Module): + def __init__( + self, obs_shape, action_shape, encoder_type, encoder_feature_dim + ): + super().__init__() + + self.encoder = make_encoder( + encoder_type, obs_shape, encoder_feature_dim + ) + + self.l1 = nn.Linear(self.encoder.feature_dim, 400) + self.l2 = nn.Linear(400, 300) + self.l3 = nn.Linear(300, action_shape[0]) + + self.outputs = dict() + + def forward(self, obs, detach_encoder=False): + obs = self.encoder(obs, detach=detach_encoder) + h = F.relu(self.l1(obs)) + h = F.relu(self.l2(h)) + action = torch.tanh(self.l3(h)) + self.outputs['mu'] = action + return action + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: + return + + for k, v in self.outputs.items(): + L.log_histogram('train_actor/%s_hist' % k, v, step) + + L.log_param('train_actor/fc1', self.l1, step) + L.log_param('train_actor/fc2', self.l2, step) + L.log_param('train_actor/fc3', self.l3, step) + + +class Critic(nn.Module): + def __init__( + self, obs_shape, action_shape, encoder_type, encoder_feature_dim + ): + super().__init__() + + self.encoder = make_encoder( + encoder_type, obs_shape, encoder_feature_dim + ) + + self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400) + self.l2 = nn.Linear(400, 300) + self.l3 = nn.Linear(300, 1) + + self.outputs = dict() + + def forward(self, obs, action, detach_encoder=False): + obs = self.encoder(obs, detach=detach_encoder) + obs_action = torch.cat([obs, action], dim=1) + h = F.relu(self.l1(obs_action)) + h = F.relu(self.l2(h)) + q = self.l3(h) + self.outputs['q'] = q + return q + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: + return + + self.encoder.log(L, step, log_freq) + + for k, v in self.outputs.items(): + L.log_histogram('train_critic/%s_hist' % k, v, step) + + L.log_param('train_critic/fc1', self.l1, step) + L.log_param('train_critic/fc2', self.l2, step) + L.log_param('train_critic/fc3', self.l3, step) + + +class DDPGAgent(object): + def __init__( + self, + obs_shape, + action_shape, + device, + discount=0.99, + tau=0.005, + actor_lr=1e-3, + critic_lr=1e-3, + encoder_type='identity', + encoder_feature_dim=50 + ): + self.device = device + self.discount = discount + self.tau = tau + + # models + self.actor = Actor( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + + self.critic = Critic( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + + self.actor.encoder.copy_conv_weights_from(self.critic.encoder) + + self.actor_target = Actor( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + self.actor_target.load_state_dict(self.actor.state_dict()) + + self.critic_target = Critic( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + self.critic_target.load_state_dict(self.critic.state_dict()) + + # optimizers + self.actor_optimizer = torch.optim.Adam( + self.actor.parameters(), lr=actor_lr + ) + + self.critic_optimizer = torch.optim.Adam( + self.critic.parameters(), lr=critic_lr + ) + + self.train() + self.critic_target.train() + self.actor_target.train() + + def train(self, training=True): + self.training = training + 
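# --- illustrative aside, not part of ddpg.py ---------------------------------
# The update methods below follow standard DDPG: the critic regresses onto the
# bootstrapped target
#     y = r + gamma * (1 - done) * Q_target(s', mu_target(s')),
# the actor minimizes -Q(s, mu(s)).mean(), and both target networks slowly
# track the online networks via Polyak averaging
# (utils.soft_update_params, tau=0.005 by default).
# -----------------------------------------------------------------------------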
self.actor.train(training) + self.critic.train(training) + + def select_action(self, obs): + with torch.no_grad(): + obs = torch.FloatTensor(obs).to(self.device) + obs = obs.unsqueeze(0) + action = self.actor(obs) + return action.cpu().data.numpy().flatten() + + def sample_action(self, obs): + return self.select_action(obs) + + def update_critic(self, obs, action, reward, next_obs, not_done, L, step): + with torch.no_grad(): + target_Q = self.critic_target( + next_obs, self.actor_target(next_obs) + ) + target_Q = reward + (not_done * self.discount * target_Q) + + current_Q = self.critic(obs, action) + + critic_loss = F.mse_loss(current_Q, target_Q) + L.log('train_critic/loss', critic_loss, step) + + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + self.critic.log(L, step) + + def update_actor(self, obs, L, step): + action = self.actor(obs, detach_encoder=True) + actor_Q = self.critic(obs, action, detach_encoder=True) + actor_loss = -actor_Q.mean() + + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + self.actor.log(L, step) + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done = replay_buffer.sample() + + L.log('train/batch_reward', reward.mean(), step) + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + self.update_actor(obs, L, step) + + utils.soft_update_params(self.critic, self.critic_target, self.tau) + utils.soft_update_params(self.actor, self.actor_target, self.tau) + + def save(self, model_dir, step): + torch.save( + self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step) + ) + torch.save( + self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step) + ) + + def load(self, model_dir, step): + self.actor.load_state_dict( + torch.load('%s/actor_%s.pt' % (model_dir, step)) + ) + self.critic.load_state_dict( + torch.load('%s/critic_%s.pt' % (model_dir, step)) + ) diff --git a/decoder.py b/decoder.py new file mode 100644 index 0000000..3fbb4a7 --- /dev/null +++ b/decoder.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn + +from encoder import OUT_DIM + + +class PixelDecoder(nn.Module): + def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32): + super().__init__() + + self.num_layers = num_layers + self.num_filters = num_filters + self.out_dim = OUT_DIM[num_layers] + + self.fc = nn.Linear( + feature_dim, num_filters * self.out_dim * self.out_dim + ) + + self.deconvs = nn.ModuleList() + + for i in range(self.num_layers - 1): + self.deconvs.append( + nn.ConvTranspose2d(num_filters, num_filters, 3, stride=1) + ) + self.deconvs.append( + nn.ConvTranspose2d( + num_filters, obs_shape[0], 3, stride=2, output_padding=1 + ) + ) + + self.outputs = dict() + + def forward(self, h): + h = torch.relu(self.fc(h)) + self.outputs['fc'] = h + + deconv = h.view(-1, self.num_filters, self.out_dim, self.out_dim) + self.outputs['deconv1'] = deconv + + for i in range(0, self.num_layers - 1): + deconv = torch.relu(self.deconvs[i](deconv)) + self.outputs['deconv%s' % (i + 1)] = deconv + + obs = self.deconvs[-1](deconv) + self.outputs['obs'] = obs + + return obs + + def log(self, L, step, log_freq): + if step % log_freq != 0: + return + + for k, v in self.outputs.items(): + L.log_histogram('train_decoder/%s_hist' % k, v, step) + if len(v.shape) > 2: + L.log_image('train_decoder/%s_i' % k, v[0], step) + + for i in range(self.num_layers): + L.log_param( + 'train_decoder/deconv%s' % (i + 1), self.deconvs[i], step + ) + 
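# --- illustrative aside, not part of decoder.py ------------------------------
# Shape check, assuming 84x84 inputs (the train.py default) and num_layers=4:
# the fc output is reshaped to (num_filters, 35, 35) since OUT_DIM[4] == 35;
# each stride-1 ConvTranspose2d(k=3) grows the map by 2 (35 -> 37 -> 39 -> 41),
# and the final stride-2 layer with output_padding=1 gives
# (41 - 1) * 2 + 3 + 1 = 84, exactly undoing the encoder's downsampling.
# -----------------------------------------------------------------------------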
L.log_param('train_decoder/fc', self.fc, step) + + +class StateDecoder(nn.Module): + def __init__(self, obs_shape, feature_dim): + super().__init__() + + assert len(obs_shape) == 1 + + self.trunk = nn.Sequential( + nn.Linear(feature_dim, 1024), nn.ReLU(), nn.Linear(1024, 1024), + nn.ReLU(), nn.Linear(1024, obs_shape[0]), nn.ReLU() + ) + + self.outputs = dict() + + def forward(self, obs, detach=False): + h = self.trunk(obs) + if detach: + h = h.detach() + self.outputs['h'] = h + return h + + def log(self, L, step, log_freq): + if step % log_freq != 0: + return + + L.log_param('train_encoder/fc1', self.trunk[0], step) + L.log_param('train_encoder/fc2', self.trunk[2], step) + for k, v in self.outputs.items(): + L.log_histogram('train_encoder/%s_hist' % k, v, step) + + +_AVAILABLE_DECODERS = {'pixel': PixelDecoder, 'state': StateDecoder} + + +def make_decoder( + decoder_type, obs_shape, feature_dim, num_layers, num_filters +): + assert decoder_type in _AVAILABLE_DECODERS + if decoder_type == 'pixel': + return _AVAILABLE_DECODERS[decoder_type]( + obs_shape, feature_dim, num_layers, num_filters + ) + return _AVAILABLE_DECODERS[decoder_type](obs_shape, feature_dim) diff --git a/encoder.py b/encoder.py new file mode 100644 index 0000000..0f2e581 --- /dev/null +++ b/encoder.py @@ -0,0 +1,185 @@ +import torch +import torch.nn as nn + + +def tie_weights(src, trg): + assert type(src) == type(trg) + trg.weight = src.weight + trg.bias = src.bias + + +OUT_DIM = {2: 39, 4: 35, 6: 31} + + +class PixelEncoder(nn.Module): + """Convolutional encoder of pixels observations.""" + def __init__( + self, + obs_shape, + feature_dim, + num_layers=2, + num_filters=32, + stochastic=False + ): + super().__init__() + + assert len(obs_shape) == 3 + + self.feature_dim = feature_dim + self.num_layers = num_layers + self.stochastic = stochastic + + self.convs = nn.ModuleList( + [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)] + ) + for i in range(num_layers - 1): + self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1)) + + out_dim = OUT_DIM[num_layers] + self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim) + self.ln = nn.LayerNorm(self.feature_dim) + + if self.stochastic: + self.log_std_min = -10 + self.log_std_max = 2 + self.fc_log_std = nn.Linear( + num_filters * out_dim * out_dim, self.feature_dim + ) + + self.outputs = dict() + + def reparameterize(self, mu, logstd): + std = torch.exp(logstd) + eps = torch.randn_like(std) + return mu + eps * std + + def forward_conv(self, obs): + obs = obs / 255. 
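# --- illustrative aside, not part of encoder.py ------------------------------
# Where OUT_DIM comes from, assuming 84x84 observations (the train.py default):
# the first Conv2d(k=3, stride=2) maps 84 -> (84 - 3) // 2 + 1 = 41, and each
# additional Conv2d(k=3, stride=1) subtracts 2, so 2 layers -> 39, 4 -> 35,
# 6 -> 31, matching OUT_DIM = {2: 39, 4: 35, 6: 31}.
# -----------------------------------------------------------------------------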
+ self.outputs['obs'] = obs + + conv = torch.relu(self.convs[0](obs)) + self.outputs['conv1'] = conv + + for i in range(1, self.num_layers): + conv = torch.relu(self.convs[i](conv)) + self.outputs['conv%s' % (i + 1)] = conv + + h = conv.view(conv.size(0), -1) + return h + + def forward(self, obs, detach=False): + h = self.forward_conv(obs) + + if detach: + h = h.detach() + + h_fc = self.fc(h) + self.outputs['fc'] = h_fc + + h_norm = self.ln(h_fc) + self.outputs['ln'] = h_norm + + out = torch.tanh(h_norm) + + if self.stochastic: + self.outputs['mu'] = out + log_std = torch.tanh(self.fc_log_std(h)) + # normalize + log_std = self.log_std_min + 0.5 * ( + self.log_std_max - self.log_std_min + ) * (log_std + 1) + out = self.reparameterize(out, log_std) + self.outputs['log_std'] = log_std + + self.outputs['tanh'] = out + + return out + + def copy_conv_weights_from(self, source): + """Tie convolutional layers""" + # only tie conv layers + for i in range(self.num_layers): + tie_weights(src=source.convs[i], trg=self.convs[i]) + + def log(self, L, step, log_freq): + if step % log_freq != 0: + return + + for k, v in self.outputs.items(): + L.log_histogram('train_encoder/%s_hist' % k, v, step) + if len(v.shape) > 2: + L.log_image('train_encoder/%s_img' % k, v[0], step) + + for i in range(self.num_layers): + L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step) + L.log_param('train_encoder/fc', self.fc, step) + L.log_param('train_encoder/ln', self.ln, step) + + +class StateEncoder(nn.Module): + def __init__(self, obs_shape, feature_dim): + super().__init__() + + assert len(obs_shape) == 1 + self.feature_dim = feature_dim + + self.trunk = nn.Sequential( + nn.Linear(obs_shape[0], 256), nn.ReLU(), + nn.Linear(256, feature_dim), nn.ReLU() + ) + + self.outputs = dict() + + def forward(self, obs, detach=False): + h = self.trunk(obs) + if detach: + h = h.detach() + self.outputs['h'] = h + return h + + def copy_conv_weights_from(self, source): + pass + + def log(self, L, step, log_freq): + if step % log_freq != 0: + return + + L.log_param('train_encoder/fc1', self.trunk[0], step) + L.log_param('train_encoder/fc2', self.trunk[2], step) + for k, v in self.outputs.items(): + L.log_histogram('train_encoder/%s_hist' % k, v, step) + + +class IdentityEncoder(nn.Module): + def __init__(self, obs_shape, feature_dim): + super().__init__() + + assert len(obs_shape) == 1 + self.feature_dim = obs_shape[0] + + def forward(self, obs, detach=False): + return obs + + def copy_conv_weights_from(self, source): + pass + + def log(self, L, step, log_freq): + pass + + +_AVAILABLE_ENCODERS = { + 'pixel': PixelEncoder, + 'state': StateEncoder, + 'identity': IdentityEncoder +} + + +def make_encoder( + encoder_type, obs_shape, feature_dim, num_layers, num_filters, stochastic +): + assert encoder_type in _AVAILABLE_ENCODERS + if encoder_type == 'pixel': + return _AVAILABLE_ENCODERS[encoder_type]( + obs_shape, feature_dim, num_layers, num_filters, stochastic + ) + return _AVAILABLE_ENCODERS[encoder_type](obs_shape, feature_dim) diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..93ff12e --- /dev/null +++ b/logger.py @@ -0,0 +1,165 @@ +from torch.utils.tensorboard import SummaryWriter +from collections import defaultdict +import json +import os +import shutil +import torch +import torchvision +import numpy as np +from termcolor import colored + + +FORMAT_CONFIG = { + 'rl': { + 'train': [('episode', 'E', 'int'), + ('step', 'S', 'int'), + ('duration', 'D', 'time'), + ('episode_reward', 'R', 'float'), + 
('batch_reward', 'BR', 'float'), + ('actor_loss', 'ALOSS', 'float'), + ('critic_loss', 'CLOSS', 'float'), + ('ae_loss', 'RLOSS', 'float')], + 'eval': [('step', 'S', 'int'), + ('episode_reward', 'ER', 'float')] + } +} + + +class AverageMeter(object): + def __init__(self): + self._sum = 0 + self._count = 0 + + def update(self, value, n=1): + self._sum += value + self._count += n + + def value(self): + return self._sum / max(1, self._count) + + +class MetersGroup(object): + def __init__(self, file_name, formating): + self._file_name = file_name + if os.path.exists(file_name): + os.remove(file_name) + self._formating = formating + self._meters = defaultdict(AverageMeter) + + def log(self, key, value, n=1): + self._meters[key].update(value, n) + + def _prime_meters(self): + data = dict() + for key, meter in self._meters.items(): + if key.startswith('train'): + key = key[len('train') + 1:] + else: + key = key[len('eval') + 1:] + key = key.replace('/', '_') + data[key] = meter.value() + return data + + def _dump_to_file(self, data): + with open(self._file_name, 'a') as f: + f.write(json.dumps(data) + '\n') + + def _format(self, key, value, ty): + template = '%s: ' + if ty == 'int': + template += '%d' + elif ty == 'float': + template += '%.04f' + elif ty == 'time': + template += '%.01f s' + else: + raise 'invalid format type: %s' % ty + return template % (key, value) + + def _dump_to_console(self, data, prefix): + prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') + pieces = ['{:5}'.format(prefix)] + for key, disp_key, ty in self._formating: + value = data.get(key, 0) + pieces.append(self._format(disp_key, value, ty)) + print('| %s' % (' | '.join(pieces))) + + def dump(self, step, prefix): + if len(self._meters) == 0: + return + data = self._prime_meters() + data['step'] = step + self._dump_to_file(data) + self._dump_to_console(data, prefix) + self._meters.clear() + + +class Logger(object): + def __init__(self, log_dir, use_tb=True, config='rl'): + self._log_dir = log_dir + if use_tb: + tb_dir = os.path.join(log_dir, 'tb') + if os.path.exists(tb_dir): + shutil.rmtree(tb_dir) + self._sw = SummaryWriter(tb_dir) + else: + self._sw = None + self._train_mg = MetersGroup( + os.path.join(log_dir, 'train.log'), + formating=FORMAT_CONFIG[config]['train']) + self._eval_mg = MetersGroup( + os.path.join(log_dir, 'eval.log'), + formating=FORMAT_CONFIG[config]['eval']) + + def _try_sw_log(self, key, value, step): + if self._sw is not None: + self._sw.add_scalar(key, value, step) + + def _try_sw_log_image(self, key, image, step): + if self._sw is not None: + assert image.dim() == 3 + grid = torchvision.utils.make_grid(image.unsqueeze(1)) + self._sw.add_image(key, grid, step) + + def _try_sw_log_video(self, key, frames, step): + if self._sw is not None: + frames = torch.from_numpy(np.array(frames)) + frames = frames.unsqueeze(0) + self._sw.add_video(key, frames, step, fps=30) + + def _try_sw_log_histogram(self, key, histogram, step): + if self._sw is not None: + self._sw.add_histogram(key, histogram, step) + + def log(self, key, value, step, n=1): + assert key.startswith('train') or key.startswith('eval') + if type(value) == torch.Tensor: + value = value.item() + self._try_sw_log(key, value / n, step) + mg = self._train_mg if key.startswith('train') else self._eval_mg + mg.log(key, value, n) + + def log_param(self, key, param, step): + self.log_histogram(key + '_w', param.weight.data, step) + if hasattr(param.weight, 'grad') and param.weight.grad is not None: + self.log_histogram(key + '_w_g', 
param.weight.grad.data, step)
+        if hasattr(param, 'bias'):
+            self.log_histogram(key + '_b', param.bias.data, step)
+            if hasattr(param.bias, 'grad') and param.bias.grad is not None:
+                self.log_histogram(key + '_b_g', param.bias.grad.data, step)
+
+    def log_image(self, key, image, step):
+        assert key.startswith('train') or key.startswith('eval')
+        self._try_sw_log_image(key, image, step)
+
+    def log_video(self, key, frames, step):
+        assert key.startswith('train') or key.startswith('eval')
+        self._try_sw_log_video(key, frames, step)
+
+    def log_histogram(self, key, histogram, step):
+        assert key.startswith('train') or key.startswith('eval')
+        self._try_sw_log_histogram(key, histogram, step)
+
+    def dump(self, step):
+        self._train_mg.dump(step, 'train')
+        self._eval_mg.dump(step, 'eval')
diff --git a/run.sh b/run.sh
new file mode 100755
index 0000000..78d2e41
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+DOMAIN=cheetah
+TASK=run
+ACTION_REPEAT=4
+ENCODER_TYPE=pixel
+DECODER_TYPE=pixel
+
+
+WORK_DIR=./runs
+
+python train.py \
+    --domain_name ${DOMAIN} \
+    --task_name ${TASK} \
+    --encoder_type ${ENCODER_TYPE} \
+    --decoder_type ${DECODER_TYPE} \
+    --action_repeat ${ACTION_REPEAT} \
+    --save_video \
+    --save_tb \
+    --work_dir ${WORK_DIR}/${DOMAIN}_${TASK}/_ae_encoder_${ENCODER_TYPE}_decoder_${DECODER_TYPE} \
+    --seed 1
diff --git a/sac.py b/sac.py
new file mode 100644
index 0000000..d1a18f6
--- /dev/null
+++ b/sac.py
@@ -0,0 +1,507 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+import math
+
+import utils
+from encoder import make_encoder
+from decoder import make_decoder
+
+LOG_FREQ = 10000
+
+
+def gaussian_logprob(noise, log_std):
+    """Compute Gaussian log probability."""
+    residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
+    return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)
+
+
+def squash(mu, pi, log_pi):
+    """Apply squashing function.
+    See appendix C from https://arxiv.org/pdf/1812.05905.pdf.
+ """ + mu = torch.tanh(mu) + if pi is not None: + pi = torch.tanh(pi) + if log_pi is not None: + log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) + return mu, pi, log_pi + + +def weight_init(m): + """Custom weight init for Conv2D and Linear layers.""" + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + m.bias.data.fill_(0.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf + assert m.weight.size(2) == m.weight.size(3) + m.weight.data.fill_(0.0) + m.bias.data.fill_(0.0) + mid = m.weight.size(2) // 2 + gain = nn.init.calculate_gain('relu') + nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) + + +class Actor(nn.Module): + """MLP actor network.""" + def __init__( + self, obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters, + freeze_encoder, stochastic + ): + super().__init__() + + self.encoder = make_encoder( + encoder_type, obs_shape, encoder_feature_dim, num_layers, + num_filters, stochastic + ) + + self.log_std_min = log_std_min + self.log_std_max = log_std_max + self.freeze_encoder = freeze_encoder + + self.trunk = nn.Sequential( + nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, 2 * action_shape[0]) + ) + + self.outputs = dict() + self.apply(weight_init) + + def forward( + self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False + ): + obs = self.encoder(obs, detach=detach_encoder) + + if self.freeze_encoder: + obs = obs.detach() + + mu, log_std = self.trunk(obs).chunk(2, dim=-1) + + # constrain log_std inside [log_std_min, log_std_max] + log_std = F.tanh(log_std) + log_std = self.log_std_min + 0.5 * ( + self.log_std_max - self.log_std_min + ) * (log_std + 1) + + self.outputs['mu'] = mu + self.outputs['std'] = log_std.exp() + + if compute_pi: + std = log_std.exp() + noise = torch.randn_like(mu) + pi = mu + noise * std + else: + pi = None + entropy = None + + if compute_log_pi: + log_pi = gaussian_logprob(noise, log_std) + else: + log_pi = None + + mu, pi, log_pi = squash(mu, pi, log_pi) + + return mu, pi, log_pi, log_std + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: + return + + for k, v in self.outputs.items(): + L.log_histogram('train_actor/%s_hist' % k, v, step) + + L.log_param('train_actor/fc1', self.trunk[0], step) + L.log_param('train_actor/fc2', self.trunk[2], step) + L.log_param('train_actor/fc3', self.trunk[4], step) + + +class QFunction(nn.Module): + """MLP for q-function.""" + def __init__(self, obs_dim, action_dim, hidden_dim): + super().__init__() + + self.trunk = nn.Sequential( + nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, 1) + ) + + def forward(self, obs, action): + assert obs.size(0) == action.size(0) + + obs_action = torch.cat([obs, action], dim=1) + return self.trunk(obs_action) + + +class DynamicsModel(nn.Module): + def __init__(self, state_dim, action_dim, hidden_dim): + super().__init__() + + self.trunk = nn.Sequential( + nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, state_dim) + ) + + def forward(self, state, action): + assert state.size(0) == action.size(0) + + state_action = torch.cat([state, action], dim=1) + return self.trunk(state_action) + + +class Critic(nn.Module): + 
"""Critic network, employes two q-functions.""" + def __init__( + self, obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, num_layers, num_filters, freeze_encoder, + use_dynamics, stochastic + ): + super().__init__() + + self.freeze_encoder = freeze_encoder + + self.encoder = make_encoder( + encoder_type, obs_shape, encoder_feature_dim, num_layers, + num_filters, stochastic + ) + + if use_dynamics: + self.forward_model = DynamicsModel( + self.encoder.feature_dim, action_shape[0], hidden_dim + ) + + self.Q1 = QFunction( + self.encoder.feature_dim, action_shape[0], hidden_dim + ) + self.Q2 = QFunction( + self.encoder.feature_dim, action_shape[0], hidden_dim + ) + + self.outputs = dict() + self.apply(weight_init) + + def forward(self, obs, action, detach_encoder=False): + # detach_encoder allows to stop gradient propogation to encoder + obs = self.encoder(obs, detach=detach_encoder) + + if self.freeze_encoder: + obs = obs.detach() + + q1 = self.Q1(obs, action) + q2 = self.Q2(obs, action) + + self.outputs['q1'] = q1 + self.outputs['q2'] = q2 + + return q1, q2 + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: + return + + self.encoder.log(L, step, log_freq) + + for k, v in self.outputs.items(): + L.log_histogram('train_critic/%s_hist' % k, v, step) + + for i in range(3): + L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step) + L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step) + + +class SACAgent(object): + """Soft Actor-Critic algorithm.""" + def __init__( + self, + obs_shape, + state_shape, + action_shape, + device, + hidden_dim=256, + discount=0.99, + init_temperature=0.01, + alpha_lr=1e-3, + alpha_beta=0.9, + actor_lr=1e-3, + actor_beta=0.9, + actor_log_std_min=-10, + actor_log_std_max=2, + actor_update_freq=2, + critic_lr=1e-3, + critic_beta=0.9, + critic_tau=0.005, + critic_target_update_freq=2, + encoder_type='identity', + encoder_feature_dim=50, + encoder_lr=1e-3, + encoder_tau=0.005, + decoder_type='identity', + decoder_lr=1e-3, + decoder_update_freq=1, + decoder_latent_lambda=0.0, + decoder_weight_lambda=0.0, + decoder_kl_lambda=0.0, + num_layers=4, + num_filters=32, + freeze_encoder=False, + use_dynamics=False + ): + self.device = device + self.discount = discount + self.critic_tau = critic_tau + self.encoder_tau = encoder_tau + self.actor_update_freq = actor_update_freq + self.critic_target_update_freq = critic_target_update_freq + self.decoder_update_freq = decoder_update_freq + self.decoder_latent_lambda = decoder_latent_lambda + self.decoder_kl_lambda = decoder_kl_lambda + self.decoder_type = decoder_type + self.use_dynamics = use_dynamics + + stochastic = decoder_kl_lambda > 0.0 + + self.actor = Actor( + obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, actor_log_std_min, actor_log_std_max, + num_layers, num_filters, freeze_encoder, stochastic + ).to(device) + + self.critic = Critic( + obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, num_layers, num_filters, freeze_encoder, + use_dynamics, stochastic + ).to(device) + + self.critic_target = Critic( + obs_shape, action_shape, hidden_dim, encoder_type, + encoder_feature_dim, num_layers, num_filters, freeze_encoder, + use_dynamics, stochastic + ).to(device) + + self.critic_target.load_state_dict(self.critic.state_dict()) + + # tie encoders between actor and critic + self.actor.encoder.copy_conv_weights_from(self.critic.encoder) + + self.log_alpha = torch.tensor(np.log(init_temperature)).to(device) + 
self.log_alpha.requires_grad = True + # set target entropy to -|A| + self.target_entropy = -np.prod(action_shape) + + self.decoder = None + if decoder_type != 'identity': + # create decoder + shape = obs_shape if decoder_type == 'pixel' else state_shape + self.decoder = make_decoder( + decoder_type, shape, encoder_feature_dim, num_layers, + num_filters + ).to(device) + self.decoder.apply(weight_init) + + # optimizer for critic encoder for reconstruction loss + self.encoder_optimizer = torch.optim.Adam( + self.critic.encoder.parameters(), lr=encoder_lr + ) + + # optimizer for decoder + self.decoder_optimizer = torch.optim.Adam( + self.decoder.parameters(), + lr=decoder_lr, + weight_decay=decoder_weight_lambda + ) + + # optimizers + self.actor_optimizer = torch.optim.Adam( + self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999) + ) + + self.critic_optimizer = torch.optim.Adam( + self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999) + ) + + self.log_alpha_optimizer = torch.optim.Adam( + [self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999) + ) + + self.train() + self.critic_target.train() + + def train(self, training=True): + self.training = training + self.actor.train(training) + self.critic.train(training) + if self.decoder is not None: + self.decoder.train(training) + + @property + def alpha(self): + return self.log_alpha.exp() + + def select_action(self, obs): + with torch.no_grad(): + obs = torch.FloatTensor(obs).to(self.device) + obs = obs.unsqueeze(0) + mu, _, _, _ = self.actor( + obs, compute_pi=False, compute_log_pi=False + ) + return mu.cpu().data.numpy().flatten() + + def sample_action(self, obs): + with torch.no_grad(): + obs = torch.FloatTensor(obs).to(self.device) + obs = obs.unsqueeze(0) + mu, pi, _, _ = self.actor(obs, compute_log_pi=False) + return pi.cpu().data.numpy().flatten() + + def update_critic(self, obs, action, reward, next_obs, not_done, L, step): + with torch.no_grad(): + _, policy_action, log_pi, _ = self.actor(next_obs) + target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) + target_V = torch.min(target_Q1, + target_Q2) - self.alpha.detach() * log_pi + target_Q = reward + (not_done * self.discount * target_V) + + # get current Q estimates + current_Q1, current_Q2 = self.critic(obs, action) + critic_loss = F.mse_loss(current_Q1, + target_Q) + F.mse_loss(current_Q2, target_Q) + L.log('train_critic/loss', critic_loss, step) + + # update dynamics (optional) + if self.use_dynamics: + h_obs = self.critic.encoder.outputs['mu'] + with torch.no_grad(): + next_latent = self.critic.encoder(next_obs) + pred_next_latent = self.critic.forward_model(h_obs, action) + dynamics_loss = F.mse_loss(pred_next_latent, next_latent) + L.log('train_critic/dynamics_loss', dynamics_loss, step) + critic_loss += dynamics_loss + + # Optimize the critic + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + self.critic.log(L, step) + + def update_actor_and_alpha(self, obs, L, step): + # detach encoder, so we don't update it with the actor loss + _, pi, log_pi, log_std = self.actor(obs, detach_encoder=True) + actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True) + + actor_Q = torch.min(actor_Q1, actor_Q2) + actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean() + + L.log('train_actor/loss', actor_loss, step) + L.log('train_actor/target_entropy', self.target_entropy, step) + entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi) + ) + log_std.sum(dim=-1) + L.log('train_actor/entropy', 
entropy.mean(), step) + + # optimize the actor + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + self.actor.log(L, step) + + self.log_alpha_optimizer.zero_grad() + alpha_loss = (self.alpha * + (-log_pi - self.target_entropy).detach()).mean() + L.log('train_alpha/loss', alpha_loss, step) + L.log('train_alpha/value', self.alpha, step) + alpha_loss.backward() + self.log_alpha_optimizer.step() + + def update_decoder(self, obs, target_obs, L, step): + if self.decoder is None: + return + + h = self.critic.encoder(obs) + + if target_obs.dim() == 4: + # preprocess images to be in [-0.5, 0.5] range + target_obs = utils.preprocess_obs(target_obs) + rec_obs = self.decoder(h) + rec_loss = F.mse_loss(target_obs, rec_obs) + + # add L2 penalty on latent representation + # see https://arxiv.org/pdf/1903.12436.pdf + latent_loss = (0.5 * h.pow(2).sum(1)).mean() + + # add KL penalty for VAE + if self.decoder_kl_lambda > 0.0: + log_std = self.critic.encoder.outputs['log_std'] + mu = self.critic.encoder.outputs['mu'] + kl_div = -0.5 * (1 + 2 * log_std - mu.pow(2) - (2 * log_std).exp()) + kl_div = kl_div.sum(1).mean(0, True) + else: + kl_div = 0.0 + loss = rec_loss + self.decoder_latent_lambda * latent_loss + self.decoder_kl_lambda * kl_div + + self.encoder_optimizer.zero_grad() + self.decoder_optimizer.zero_grad() + loss.backward() + + self.encoder_optimizer.step() + self.decoder_optimizer.step() + L.log('train_ae/ae_loss', loss, step) + + self.decoder.log(L, step, log_freq=LOG_FREQ) + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done, state = replay_buffer.sample() + + L.log('train/batch_reward', reward.mean(), step) + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + + if step % self.actor_update_freq == 0: + self.update_actor_and_alpha(obs, L, step) + + if step % self.critic_target_update_freq == 0: + utils.soft_update_params( + self.critic.Q1, self.critic_target.Q1, self.critic_tau + ) + utils.soft_update_params( + self.critic.Q2, self.critic_target.Q2, self.critic_tau + ) + utils.soft_update_params( + self.critic.encoder, self.critic_target.encoder, + self.encoder_tau + ) + + if step % self.decoder_update_freq == 0: + target = obs if self.decoder_type == 'pixel' else state + self.update_decoder(obs, target, L, step) + + def save(self, model_dir, step): + torch.save( + self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step) + ) + torch.save( + self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step) + ) + if self.decoder is not None: + torch.save( + self.decoder.state_dict(), + '%s/decoder_%s.pt' % (model_dir, step) + ) + + def load(self, model_dir, step): + self.actor.load_state_dict( + torch.load('%s/actor_%s.pt' % (model_dir, step)) + ) + self.critic.load_state_dict( + torch.load('%s/critic_%s.pt' % (model_dir, step)) + ) + if self.decoder is not None: + self.decoder.load_state_dict( + torch.load('%s/decoder_%s.pt' % (model_dir, step)) + ) diff --git a/td3.py b/td3.py new file mode 100644 index 0000000..db6aaf5 --- /dev/null +++ b/td3.py @@ -0,0 +1,259 @@ +# Code is taken from https://github.com/sfujim/TD3 with slight modifications + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import utils +from encoder import make_encoder + +LOG_FREQ = 10000 + + +class Actor(nn.Module): + def __init__( + self, obs_shape, action_shape, encoder_type, encoder_feature_dim + ): + super().__init__() + + self.encoder = make_encoder( + encoder_type, obs_shape, 
encoder_feature_dim + ) + + self.l1 = nn.Linear(self.encoder.feature_dim, 400) + self.l2 = nn.Linear(400, 300) + self.l3 = nn.Linear(300, action_shape[0]) + + self.outputs = dict() + + def forward(self, obs, detach_encoder=False): + obs = self.encoder(obs, detach=detach_encoder) + h = F.relu(self.l1(obs)) + h = F.relu(self.l2(h)) + action = torch.tanh(self.l3(h)) + self.outputs['mu'] = action + return action + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: + return + + for k, v in self.outputs.items(): + L.log_histogram('train_actor/%s_hist' % k, v, step) + + L.log_param('train_actor/fc1', self.l1, step) + L.log_param('train_actor/fc2', self.l2, step) + L.log_param('train_actor/fc3', self.l3, step) + + +class Critic(nn.Module): + def __init__( + self, obs_shape, action_shape, encoder_type, encoder_feature_dim + ): + super().__init__() + + self.encoder = make_encoder( + encoder_type, obs_shape, encoder_feature_dim + ) + + # Q1 architecture + self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400) + self.l2 = nn.Linear(400, 300) + self.l3 = nn.Linear(300, 1) + + # Q2 architecture + self.l4 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400) + self.l5 = nn.Linear(400, 300) + self.l6 = nn.Linear(300, 1) + + self.outputs = dict() + + def forward(self, obs, action, detach_encoder=False): + obs = self.encoder(obs, detach=detach_encoder) + + obs_action = torch.cat([obs, action], 1) + + h1 = F.relu(self.l1(obs_action)) + h1 = F.relu(self.l2(h1)) + q1 = self.l3(h1) + + h2 = F.relu(self.l4(obs_action)) + h2 = F.relu(self.l5(h2)) + q2 = self.l6(h2) + + self.outputs['q1'] = q1 + self.outputs['q2'] = q2 + + return q1, q2 + + def Q1(self, obs, action, detach_encoder=False): + obs = self.encoder(obs, detach=detach_encoder) + + obs_action = torch.cat([obs, action], 1) + + h1 = F.relu(self.l1(obs_action)) + h1 = F.relu(self.l2(h1)) + q1 = self.l3(h1) + return q1 + + def log(self, L, step, log_freq=LOG_FREQ): + if step % log_freq != 0: return + + self.encoder.log(L, step, log_freq) + + for k, v in self.outputs.items(): + L.log_histogram('train_critic/%s_hist' % k, v, step) + + L.log_param('train_critic/q1_fc1', self.l1, step) + L.log_param('train_critic/q1_fc2', self.l2, step) + L.log_param('train_critic/q1_fc3', self.l3, step) + L.log_param('train_critic/q1_fc4', self.l4, step) + L.log_param('train_critic/q1_fc5', self.l5, step) + L.log_param('train_critic/q1_fc6', self.l6, step) + + +class TD3Agent(object): + def __init__( + self, + obs_shape, + action_shape, + device, + discount=0.99, + tau=0.005, + policy_noise=0.2, + noise_clip=0.5, + expl_noise=0.1, + actor_lr=1e-3, + critic_lr=1e-3, + encoder_type='identity', + encoder_feature_dim=50, + actor_update_freq=2, + target_update_freq=2, + ): + self.device = device + self.discount = discount + self.tau = tau + self.policy_noise = policy_noise + self.noise_clip = noise_clip + self.expl_noise = expl_noise + self.actor_update_freq = actor_update_freq + self.target_update_freq = target_update_freq + + # models + self.actor = Actor( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + + self.critic = Critic( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + + self.actor.encoder.copy_conv_weights_from(self.critic.encoder) + + self.actor_target = Actor( + obs_shape, action_shape, encoder_type, encoder_feature_dim + ).to(device) + self.actor_target.load_state_dict(self.actor.state_dict()) + + self.critic_target = Critic( + obs_shape, action_shape, encoder_type, 
encoder_feature_dim + ).to(device) + self.critic_target.load_state_dict(self.critic.state_dict()) + + # optimizers + self.actor_optimizer = torch.optim.Adam( + self.actor.parameters(), lr=actor_lr + ) + + self.critic_optimizer = torch.optim.Adam( + self.critic.parameters(), lr=critic_lr + ) + + self.train() + self.critic_target.train() + self.actor_target.train() + + def train(self, training=True): + self.training = training + self.actor.train(training) + self.critic.train(training) + + def select_action(self, obs): + with torch.no_grad(): + obs = torch.FloatTensor(obs).to(self.device) + obs = obs.unsqueeze(0) + action = self.actor(obs) + return action.cpu().data.numpy().flatten() + + def sample_action(self, obs): + with torch.no_grad(): + obs = torch.FloatTensor(obs).to(self.device) + obs = obs.unsqueeze(0) + action = self.actor(obs) + noise = torch.randn_like(action) * self.expl_noise + action = (action + noise).clamp(-1.0, 1.0) + return action.cpu().data.numpy().flatten() + + def update_critic(self, obs, action, reward, next_obs, not_done, L, step): + with torch.no_grad(): + noise = torch.randn_like(action).to(self.device) * self.policy_noise + noise = noise.clamp(-self.noise_clip, self.noise_clip) + next_action = self.actor_target(next_obs) + noise + next_action = next_action.clamp(-1.0, 1.0) + target_Q1, target_Q2 = self.critic_target(next_obs, next_action) + target_Q = torch.min(target_Q1, target_Q2) + target_Q = reward + (not_done * self.discount * target_Q) + + current_Q1, current_Q2 = self.critic(obs, action) + + critic_loss = F.mse_loss(current_Q1, + target_Q) + F.mse_loss(current_Q2, target_Q) + L.log('train_critic/loss', critic_loss, step) + + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + self.critic.log(L, step) + + def update_actor(self, obs, L, step): + action = self.actor(obs, detach_encoder=True) + actor_Q = self.critic.Q1(obs, action, detach_encoder=True) + actor_loss = -actor_Q.mean() + + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + self.actor.log(L, step) + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done = replay_buffer.sample() + + L.log('train/batch_reward', reward.mean(), step) + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + + if step % self.actor_update_freq == 0: + self.update_actor(obs, L, step) + + if step % self.target_update_freq == 0: + utils.soft_update_params(self.critic, self.critic_target, self.tau) + utils.soft_update_params(self.actor, self.actor_target, self.tau) + + def save(self, model_dir, step): + torch.save( + self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step) + ) + torch.save( + self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step) + ) + + def load(self, model_dir, step): + self.actor.load_state_dict( + torch.load('%s/actor_%s.pt' % (model_dir, step)) + ) + self.critic.load_state_dict( + torch.load('%s/critic_%s.pt' % (model_dir, step)) + ) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..2af856d --- /dev/null +++ b/train.py @@ -0,0 +1,315 @@ +import numpy as np +import torch +import argparse +import os +import math +import gym +import sys +import random +import time +import json +import dmc2gym +import copy + +import utils +from logger import Logger +from video import VideoRecorder + +from sac import SACAgent +from td3 import TD3Agent +from ddpg import DDPGAgent + + +def parse_args(): + parser = argparse.ArgumentParser() + 
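# --- illustrative aside, not part of train.py --------------------------------
# A typical pixel-based invocation, mirroring run.sh:
#     python train.py --domain_name cheetah --task_name run \
#         --encoder_type pixel --decoder_type pixel --action_repeat 4 \
#         --save_video --save_tb --work_dir ./runs/cheetah_run --seed 1
# With the default encoder_type='identity', the agent trains directly on
# proprioceptive state instead of images.
# -----------------------------------------------------------------------------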
# environment + parser.add_argument('--domain_name', default='cheetah') + parser.add_argument('--task_name', default='run') + parser.add_argument('--image_size', default=84, type=int) + parser.add_argument('--action_repeat', default=1, type=int) + parser.add_argument('--frame_stack', default=3, type=int) + # replay buffer + parser.add_argument('--replay_buffer_capacity', default=1000000, type=int) + # train + parser.add_argument('--agent', default='sac', type=str) + parser.add_argument('--init_steps', default=1000, type=int) + parser.add_argument('--num_train_steps', default=1000000, type=int) + parser.add_argument('--batch_size', default=512, type=int) + parser.add_argument('--hidden_dim', default=256, type=int) + # eval + parser.add_argument('--eval_freq', default=10000, type=int) + parser.add_argument('--num_eval_episodes', default=10, type=int) + # critic + parser.add_argument('--critic_lr', default=1e-3, type=float) + parser.add_argument('--critic_beta', default=0.9, type=float) + parser.add_argument('--critic_tau', default=0.005, type=float) + parser.add_argument('--critic_target_update_freq', default=2, type=int) + # actor + parser.add_argument('--actor_lr', default=1e-3, type=float) + parser.add_argument('--actor_beta', default=0.9, type=float) + parser.add_argument('--actor_log_std_min', default=-10, type=float) + parser.add_argument('--actor_log_std_max', default=2, type=float) + parser.add_argument('--actor_update_freq', default=2, type=int) + # encoder/decoder + parser.add_argument('--encoder_type', default='identity', type=str) + parser.add_argument('--encoder_feature_dim', default=50, type=int) + parser.add_argument('--encoder_lr', default=1e-3, type=float) + parser.add_argument('--encoder_tau', default=0.005, type=float) + parser.add_argument('--decoder_type', default='identity', type=str) + parser.add_argument('--decoder_lr', default=1e-3, type=float) + parser.add_argument('--decoder_update_freq', default=1, type=int) + parser.add_argument('--decoder_latent_lambda', default=0.0, type=float) + parser.add_argument('--decoder_weight_lambda', default=0.0, type=float) + parser.add_argument('--decoder_kl_lambda', default=0.0, type=float) + parser.add_argument('--num_layers', default=4, type=int) + parser.add_argument('--num_filters', default=32, type=int) + parser.add_argument('--freeze_encoder', default=False, action='store_true') + parser.add_argument('--use_dynamics', default=False, action='store_true') + # sac + parser.add_argument('--discount', default=0.99, type=float) + parser.add_argument('--init_temperature', default=0.01, type=float) + parser.add_argument('--alpha_lr', default=1e-3, type=float) + parser.add_argument('--alpha_beta', default=0.9, type=float) + # td3 + parser.add_argument('--policy_noise', default=0.2, type=float) + parser.add_argument('--expl_noise', default=0.1, type=float) + parser.add_argument('--noise_clip', default=0.5, type=float) + parser.add_argument('--tau', default=0.005, type=float) + # misc + parser.add_argument('--seed', default=1, type=int) + parser.add_argument('--work_dir', default='.', type=str) + parser.add_argument('--save_tb', default=False, action='store_true') + parser.add_argument('--save_model', default=False, action='store_true') + parser.add_argument('--save_buffer', default=False, action='store_true') + parser.add_argument('--save_video', default=False, action='store_true') + parser.add_argument('--pretrained_info', default=None, type=str) + parser.add_argument('--pretrained_decoder', default=False, action='store_true') + + args 
= parser.parse_args() + return args + + +def evaluate(env, agent, video, num_episodes, L, step): + for i in range(num_episodes): + obs = env.reset() + video.init(enabled=(i == 0)) + done = False + episode_reward = 0 + while not done: + with utils.eval_mode(agent): + action = agent.select_action(obs) + obs, reward, done, _ = env.step(action) + video.record(env) + episode_reward += reward + + video.save('%d.mp4' % step) + L.log('eval/episode_reward', episode_reward, step) + L.dump(step) + + +def make_agent(obs_shape, state_shape, action_shape, args, device): + if args.agent == 'sac': + return SACAgent( + obs_shape=obs_shape, + state_shape=state_shape, + action_shape=action_shape, + device=device, + hidden_dim=args.hidden_dim, + discount=args.discount, + init_temperature=args.init_temperature, + alpha_lr=args.alpha_lr, + alpha_beta=args.alpha_beta, + actor_lr=args.actor_lr, + actor_beta=args.actor_beta, + actor_log_std_min=args.actor_log_std_min, + actor_log_std_max=args.actor_log_std_max, + actor_update_freq=args.actor_update_freq, + critic_lr=args.critic_lr, + critic_beta=args.critic_beta, + critic_tau=args.critic_tau, + critic_target_update_freq=args.critic_target_update_freq, + encoder_type=args.encoder_type, + encoder_feature_dim=args.encoder_feature_dim, + encoder_lr=args.encoder_lr, + encoder_tau=args.encoder_tau, + decoder_type=args.decoder_type, + decoder_lr=args.decoder_lr, + decoder_update_freq=args.decoder_update_freq, + decoder_latent_lambda=args.decoder_latent_lambda, + decoder_weight_lambda=args.decoder_weight_lambda, + decoder_kl_lambda=args.decoder_kl_lambda, + num_layers=args.num_layers, + num_filters=args.num_filters, + freeze_encoder=args.freeze_encoder, + use_dynamics=args.use_dynamics + ) + elif args.agent == 'td3': + return TD3Agent( + obs_shape=obs_shape, + action_shape=action_shape, + device=device, + discount=args.discount, + tau=args.tau, + policy_noise=args.policy_noise, + noise_clip=args.noise_clip, + expl_noise=args.expl_noise, + actor_lr=args.actor_lr, + critic_lr=args.critic_lr, + encoder_type=args.encoder_type, + encoder_feature_dim=args.encoder_feature_dim, + actor_update_freq=args.actor_update_freq, + target_update_freq=args.critic_target_update_freq + ) + elif args.agent == 'ddpg': + return DDPGAgent( + obs_shape=obs_shape, + action_shape=action_shape, + device=device, + discount=args.discount, + tau=args.tau, + actor_lr=args.actor_lr, + critic_lr=args.critic_lr, + encoder_type=args.encoder_type, + encoder_feature_dim=args.encoder_feature_dim + ) + else: + assert 'agent is not supported: %s' % args.agent + + +def load_pretrained_encoder(agent, pretrained_info, pretrained_decoder): + path, version = pretrained_info.split(':') + + pretrained_agent = copy.deepcopy(agent) + pretrained_agent.load(path, int(version)) + agent.critic.encoder.load_state_dict( + pretrained_agent.critic.encoder.state_dict() + ) + agent.actor.encoder.load_state_dict( + pretrained_agent.actor.encoder.state_dict() + ) + + if pretrained_decoder: + agent.decoder.load_state_dict(pretrained_agent.decoder.state_dict()) + + return agent + + +def main(): + args = parse_args() + utils.set_seed_everywhere(args.seed) + + env = dmc2gym.make( + domain_name=args.domain_name, + task_name=args.task_name, + seed=args.seed, + visualize_reward=False, + from_pixels=(args.encoder_type == 'pixel'), + height=args.image_size, + width=args.image_size, + frame_skip=args.action_repeat + ) + env.seed(args.seed) + + # stack several consecutive frames together + if args.encoder_type == 'pixel': + env = 
utils.FrameStack(env, k=args.frame_stack)
+
+    utils.make_dir(args.work_dir)
+    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
+    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
+    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))
+
+    video = VideoRecorder(video_dir if args.save_video else None)
+
+    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
+        json.dump(vars(args), f, sort_keys=True, indent=4)
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # the dmc2gym wrapper standardizes actions
+    assert env.action_space.low.min() >= -1
+    assert env.action_space.high.max() <= 1
+
+    replay_buffer = utils.ReplayBuffer(
+        obs_shape=env.observation_space.shape,
+        state_shape=env.state_space.shape,
+        action_shape=env.action_space.shape,
+        capacity=args.replay_buffer_capacity,
+        batch_size=args.batch_size,
+        device=device
+    )
+
+    agent = make_agent(
+        obs_shape=env.observation_space.shape,
+        state_shape=env.state_space.shape,
+        action_shape=env.action_space.shape,
+        args=args,
+        device=device
+    )
+
+    if args.pretrained_info is not None:
+        agent = load_pretrained_encoder(
+            agent, args.pretrained_info, args.pretrained_decoder
+        )
+
+    L = Logger(args.work_dir, use_tb=args.save_tb)
+
+    episode, episode_reward, done = 0, 0, True
+    start_time = time.time()
+    for step in range(args.num_train_steps):
+        if done:
+            if step > 0:
+                L.log('train/duration', time.time() - start_time, step)
+                start_time = time.time()
+                L.dump(step)
+
+            # evaluate agent periodically
+            if step % args.eval_freq == 0:
+                L.log('eval/episode', episode, step)
+                evaluate(env, agent, video, args.num_eval_episodes, L, step)
+                if args.save_model:
+                    agent.save(model_dir, step)
+                if args.save_buffer:
+                    replay_buffer.save(buffer_dir)
+
+            L.log('train/episode_reward', episode_reward, step)
+
+            obs = env.reset()
+            done = False
+            episode_reward = 0
+            episode_step = 0
+            episode += 1
+
+            L.log('train/episode', episode, step)
+
+        # sample action for data collection
+        if step < args.init_steps:
+            action = env.action_space.sample()
+        else:
+            with utils.eval_mode(agent):
+                action = agent.sample_action(obs)
+
+        # run training update
+        if step >= args.init_steps:
+            num_updates = args.init_steps if step == args.init_steps else 1
+            for _ in range(num_updates):
+                agent.update(replay_buffer, L, step)
+
+        state = env.env.env._current_state
+        next_obs, reward, done, _ = env.step(action)
+        next_state = env.env.env._current_state
+
+        # allow infinite bootstrap
+        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
+            done
+        )
+        episode_reward += reward
+
+        replay_buffer.add(obs, action, reward, next_obs, done_bool, state, next_state)
+
+        obs = next_obs
+        episode_step += 1
+
+
+if __name__ == '__main__':
+    main()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..4f0c6c1
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,182 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import gym
+import os
+from collections import deque
+import random
+
+
+class eval_mode(object):
+    def __init__(self, *models):
+        self.models = models
+
+    def __enter__(self):
+        self.prev_states = []
+        for model in self.models:
+            self.prev_states.append(model.training)
+            model.train(False)
+
+    def __exit__(self, *args):
+        for model, state in zip(self.models, self.prev_states):
+            model.train(state)
+        return False
+
+
+def soft_update_params(net, target_net, tau):
+    for param, target_param in zip(net.parameters(), target_net.parameters()):
+        target_param.data.copy_(
tau * param.data + (1 - tau) * target_param.data + ) + + +def set_seed_everywhere(seed): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def module_hash(module): + result = 0 + for tensor in module.state_dict().values(): + result += tensor.sum().item() + return result + + +def make_dir(dir_path): + try: + os.mkdir(dir_path) + except OSError: + pass + return dir_path + + +def preprocess_obs(obs, bits=5): + """Preprocessing image, see https://arxiv.org/abs/1807.03039.""" + bins = 2**bits + assert obs.dtype == torch.float32 + if bits < 8: + obs = torch.floor(obs / 2**(8 - bits)) + obs = obs / bins + obs = obs + torch.rand_like(obs) / bins + obs = obs - 0.5 + return obs + + +class ReplayBuffer(object): + """Buffer to store environment transitions.""" + def __init__( + self, obs_shape, state_shape, action_shape, capacity, batch_size, + device + ): + self.capacity = capacity + self.batch_size = batch_size + self.device = device + + # the proprioceptive obs is stored as float32, pixels obs as uint8 + obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8 + + self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype) + self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype) + self.actions = np.empty((capacity, *action_shape), dtype=np.float32) + self.rewards = np.empty((capacity, 1), dtype=np.float32) + self.not_dones = np.empty((capacity, 1), dtype=np.float32) + self.states = np.empty((capacity, *state_shape), dtype=np.float32) + self.next_states = np.empty((capacity, *state_shape), dtype=np.float32) + + self.idx = 0 + self.last_save = 0 + self.full = False + + def add(self, obs, action, reward, next_obs, done, state, next_state): + np.copyto(self.obses[self.idx], obs) + np.copyto(self.actions[self.idx], action) + np.copyto(self.rewards[self.idx], reward) + np.copyto(self.next_obses[self.idx], next_obs) + np.copyto(self.not_dones[self.idx], not done) + np.copyto(self.states[self.idx], state) + np.copyto(self.next_states[self.idx], next_state) + + self.idx = (self.idx + 1) % self.capacity + self.full = self.full or self.idx == 0 + + def sample(self): + idxs = np.random.randint( + 0, self.capacity if self.full else self.idx, size=self.batch_size + ) + + obses = torch.as_tensor(self.obses[idxs], device=self.device).float() + actions = torch.as_tensor(self.actions[idxs], device=self.device) + rewards = torch.as_tensor(self.rewards[idxs], device=self.device) + next_obses = torch.as_tensor( + self.next_obses[idxs], device=self.device + ).float() + not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device) + states = torch.as_tensor(self.states[idxs], device=self.device) + + return obses, actions, rewards, next_obses, not_dones, states + + def save(self, save_dir): + if self.idx == self.last_save: + return + path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx)) + payload = [ + self.obses[self.last_save:self.idx], + self.next_obses[self.last_save:self.idx], + self.actions[self.last_save:self.idx], + self.rewards[self.last_save:self.idx], + self.not_dones[self.last_save:self.idx], + self.states[self.last_save:self.idx], + self.next_states[self.last_save:self.idx] + ] + self.last_save = self.idx + torch.save(payload, path) + + def load(self, save_dir): + chunks = os.listdir(save_dir) + chucks = sorted(chunks, key=lambda x: int(x.split('_')[0])) + for chunk in chucks: + start, end = [int(x) for x in chunk.split('.')[0].split('_')] + path = os.path.join(save_dir, 
chunk) + payload = torch.load(path) + assert self.idx == start + self.obses[start:end] = payload[0] + self.next_obses[start:end] = payload[1] + self.actions[start:end] = payload[2] + self.rewards[start:end] = payload[3] + self.not_dones[start:end] = payload[4] + self.states[start:end] = payload[5] + self.next_states[start:end] = payload[6] + self.idx = end + + +class FrameStack(gym.Wrapper): + def __init__(self, env, k): + gym.Wrapper.__init__(self, env) + self._k = k + self._frames = deque([], maxlen=k) + shp = env.observation_space.shape + self.observation_space = gym.spaces.Box( + low=0, + high=1, + shape=((shp[0] * k,) + shp[1:]), + dtype=env.observation_space.dtype + ) + self._max_episode_steps = env._max_episode_steps + + def reset(self): + obs = self.env.reset() + for _ in range(self._k): + self._frames.append(obs) + return self._get_obs() + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self._frames.append(obs) + return self._get_obs(), reward, done, info + + def _get_obs(self): + assert len(self._frames) == self._k + return np.concatenate(list(self._frames), axis=0) diff --git a/video.py b/video.py new file mode 100644 index 0000000..36588e6 --- /dev/null +++ b/video.py @@ -0,0 +1,32 @@ +import imageio +import os +import numpy as np + + +class VideoRecorder(object): + def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30): + self.dir_name = dir_name + self.height = height + self.width = width + self.camera_id = camera_id + self.fps = fps + self.frames = [] + + def init(self, enabled=True): + self.frames = [] + self.enabled = self.dir_name is not None and enabled + + def record(self, env): + if self.enabled: + frame = env.render( + mode='rgb_array', + height=self.height, + width=self.width, + camera_id=self.camera_id + ) + self.frames.append(frame) + + def save(self, file_name): + if self.enabled: + path = os.path.join(self.dir_name, file_name) + imageio.mimsave(path, self.frames, fps=self.fps)
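To close the loop, here is a minimal, hypothetical usage sketch of `VideoRecorder` with a random policy standing in for a trained agent; it mirrors the `evaluate()` loop in `train.py` and assumes `dmc2gym` is installed and that the environment supports `rgb_array` rendering:
```
import os

import dmc2gym

from video import VideoRecorder

os.makedirs('./video', exist_ok=True)
env = dmc2gym.make(
    domain_name='cheetah', task_name='run', seed=1, visualize_reward=False
)
video = VideoRecorder('./video')  # pass dir_name=None to disable recording

obs = env.reset()
video.init(enabled=True)
done = False
while not done:
    action = env.action_space.sample()  # stand-in for agent.select_action(obs)
    obs, reward, done, _ = env.step(action)
    video.record(env)
video.save('episode.mp4')
```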