# sac_ae_if/td3.py

# Code is taken from https://github.com/sfujim/TD3 with slight modifications
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from encoder import make_encoder

LOG_FREQ = 10000
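
# make_encoder (from encoder.py) is expected to return a module exposing
# feature_dim, a detach-aware forward(obs, detach=False),
# copy_conv_weights_from(source) and log(L, step, log_freq) -- the interface
# the classes below rely on.
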
class Actor(nn.Module):
    def __init__(
        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim
        )

        self.l1 = nn.Linear(self.encoder.feature_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_shape[0])

        self.outputs = dict()

    def forward(self, obs, detach_encoder=False):
        obs = self.encoder(obs, detach=detach_encoder)

        h = F.relu(self.l1(obs))
        h = F.relu(self.l2(h))
        # tanh squashes the deterministic action into [-1, 1]
        action = torch.tanh(self.l3(h))
        self.outputs['mu'] = action
        return action

    def log(self, L, step, log_freq=LOG_FREQ):
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % k, v, step)

        L.log_param('train_actor/fc1', self.l1, step)
        L.log_param('train_actor/fc2', self.l2, step)
        L.log_param('train_actor/fc3', self.l3, step)


class Critic(nn.Module):
    def __init__(
        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim
        )

        # Q1 architecture
        self.l1 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

        # Q2 architecture
        self.l4 = nn.Linear(self.encoder.feature_dim + action_shape[0], 400)
        self.l5 = nn.Linear(400, 300)
        self.l6 = nn.Linear(300, 1)

        self.outputs = dict()

    def forward(self, obs, action, detach_encoder=False):
        obs = self.encoder(obs, detach=detach_encoder)
        obs_action = torch.cat([obs, action], 1)

        h1 = F.relu(self.l1(obs_action))
        h1 = F.relu(self.l2(h1))
        q1 = self.l3(h1)

        h2 = F.relu(self.l4(obs_action))
        h2 = F.relu(self.l5(h2))
        q2 = self.l6(h2)

        self.outputs['q1'] = q1
        self.outputs['q2'] = q2
        return q1, q2

    def Q1(self, obs, action, detach_encoder=False):
        # evaluates only the first Q-head; used by the actor update below
        obs = self.encoder(obs, detach=detach_encoder)
        obs_action = torch.cat([obs, action], 1)

        h1 = F.relu(self.l1(obs_action))
        h1 = F.relu(self.l2(h1))
        q1 = self.l3(h1)
        return q1

    def log(self, L, step, log_freq=LOG_FREQ):
        if step % log_freq != 0:
            return

        self.encoder.log(L, step, log_freq)

        for k, v in self.outputs.items():
            L.log_histogram('train_critic/%s_hist' % k, v, step)

        L.log_param('train_critic/q1_fc1', self.l1, step)
        L.log_param('train_critic/q1_fc2', self.l2, step)
        L.log_param('train_critic/q1_fc3', self.l3, step)
        L.log_param('train_critic/q2_fc1', self.l4, step)
        L.log_param('train_critic/q2_fc2', self.l5, step)
        L.log_param('train_critic/q2_fc3', self.l6, step)


class TD3Agent(object):
    """Twin Delayed DDPG (TD3) with an optional image encoder."""
    def __init__(
        self,
        obs_shape,
        action_shape,
        device,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        expl_noise=0.1,
        actor_lr=1e-3,
        critic_lr=1e-3,
        encoder_type='identity',
        encoder_feature_dim=50,
        actor_update_freq=2,
        target_update_freq=2,
    ):
        self.device = device
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.expl_noise = expl_noise
        self.actor_update_freq = actor_update_freq
        self.target_update_freq = target_update_freq

        # models
        self.actor = Actor(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)

        self.critic = Critic(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)

        # tie the actor's encoder conv weights to the critic's, so both read
        # the same image features (only the critic trains the encoder)
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

        self.actor_target = Actor(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())

        self.critic_target = Critic(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr
        )
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr
        )

        self.train()
        self.critic_target.train()
        self.actor_target.train()

    def train(self, training=True):
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    def select_action(self, obs):
        # deterministic policy output, used for evaluation
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            action = self.actor(obs)
            return action.cpu().data.numpy().flatten()

    def sample_action(self, obs):
        # exploration: perturb the deterministic action with Gaussian noise
        # and clamp back into the valid action range
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            action = self.actor(obs)
            noise = torch.randn_like(action) * self.expl_noise
            action = (action + noise).clamp(-1.0, 1.0)
            return action.cpu().data.numpy().flatten()
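
    # The three TD3 ingredients (Fujimoto et al., 2018) all appear below:
    # clipped double-Q learning and target policy smoothing in update_critic,
    # and delayed actor/target updates in update.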
    def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
        with torch.no_grad():
            # target policy smoothing: clipped noise on the target action
            noise = torch.randn_like(action) * self.policy_noise
            noise = noise.clamp(-self.noise_clip, self.noise_clip)
            next_action = self.actor_target(next_obs) + noise
            next_action = next_action.clamp(-1.0, 1.0)

            # clipped double-Q: bootstrap from the smaller of the two heads
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (not_done * self.discount * target_Q)

        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q
        )
        L.log('train_critic/loss', critic_loss, step)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(L, step)

    def update_actor(self, obs, L, step):
        # detach the shared encoder so the policy gradient does not move the
        # conv weights; maximize Q1 by minimizing its negation
        action = self.actor(obs, detach_encoder=True)
        actor_Q = self.critic.Q1(obs, action, detach_encoder=True)
        actor_loss = -actor_Q.mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(L, step)

    def update(self, replay_buffer, L, step):
        obs, action, reward, next_obs, not_done = replay_buffer.sample()

        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        # delayed updates: the actor and the targets move less often than the
        # critic (every actor_update_freq / target_update_freq steps)
        if step % self.actor_update_freq == 0:
            self.update_actor(obs, L, step)

        if step % self.target_update_freq == 0:
            utils.soft_update_params(self.critic, self.critic_target, self.tau)
            utils.soft_update_params(self.actor, self.actor_target, self.tau)
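
    # utils.soft_update_params is assumed to apply Polyak averaging to each
    # parameter pair, roughly:
    #
    #   for p, p_targ in zip(net.parameters(), target_net.parameters()):
    #       p_targ.data.copy_(tau * p.data + (1 - tau) * p_targ.data)
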
    def save(self, model_dir, step):
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
        )

    def load(self, model_dir, step):
        self.actor.load_state_dict(
            torch.load('%s/actor_%s.pt' % (model_dir, step))
        )
        self.critic.load_state_dict(
            torch.load('%s/critic_%s.pt' % (model_dir, step))
        )
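

# The block below is a minimal smoke-test sketch, not part of the original
# file. It assumes the default encoder_type='identity' passes a flat
# observation straight through (feature_dim == obs dim) and that the identity
# encoder's copy_conv_weights_from is a no-op; the logger and replay buffer
# here are stand-in stubs, not the repo's utils.Logger / utils.ReplayBuffer.
if __name__ == '__main__':
    obs_dim, action_dim, batch_size = 8, 2, 32

    class StubLogger(object):
        # no-op logger matching the L.log / log_histogram / log_param calls
        def log(self, key, value, step):
            pass

        def log_histogram(self, key, value, step):
            pass

        def log_param(self, key, param, step):
            pass

    class StubBuffer(object):
        # random batch shaped (obs, action, reward, next_obs, not_done)
        def sample(self):
            return (
                torch.randn(batch_size, obs_dim),
                torch.rand(batch_size, action_dim) * 2 - 1,
                torch.randn(batch_size, 1),
                torch.randn(batch_size, obs_dim),
                torch.ones(batch_size, 1),
            )

    agent = TD3Agent(
        obs_shape=(obs_dim,),
        action_shape=(action_dim,),
        device=torch.device('cpu'),
        encoder_feature_dim=obs_dim,
    )
    L, buffer = StubLogger(), StubBuffer()
    # a few update steps exercise the critic, actor and target updates
    for step in range(1, 5):
        agent.update(buffer, L, step)
    print(agent.select_action(np.zeros(obs_dim, dtype=np.float32)))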