# Code is taken from https://github.com/sfujim/TD3 with slight modifications
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import utils
|
|
from encoder import make_encoder
|
|
|
|
# Default interval, in steps, used by the `log` methods in this file: they
# return early unless `step` is a multiple of this value.
LOG_FREQ = 10000
|
|
|
|
|
|
class Actor(nn.Module):
    """Deterministic policy network: an encoder followed by a 3-layer MLP.

    The MLP has hidden widths 400 and 300 and ends in a ``tanh``, so every
    action component lies in [-1, 1].
    """

    def __init__(
        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
    ):
        """Build the encoder and the policy head.

        Args:
            obs_shape: Observation shape, forwarded to ``make_encoder``.
            action_shape: Action shape; ``action_shape[0]`` is the number
                of action dimensions produced by the last layer.
            encoder_type: Encoder identifier, forwarded to ``make_encoder``.
            encoder_feature_dim: Requested encoder feature size, forwarded
                to ``make_encoder``.
        """
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim
        )

        # Layers are created in l1 -> l2 -> l3 order so that random weight
        # initialisation consumes the RNG in a fixed sequence.
        feature_dim = self.encoder.feature_dim
        num_actions = action_shape[0]
        self.l1 = nn.Linear(feature_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, num_actions)

        # Most recent forward-pass tensors, kept for histogram logging.
        self.outputs = {}

    def forward(self, obs, detach_encoder=False):
        """Return the tanh-squashed action for ``obs``.

        Args:
            obs: Observation batch handed to the encoder.
            detach_encoder: Passed to the encoder as ``detach``
                (presumably stops gradients at the encoder output --
                confirm against the encoder implementation).
        """
        features = self.encoder(obs, detach=detach_encoder)
        hidden = F.relu(self.l1(features))
        hidden = F.relu(self.l2(hidden))
        mu = torch.tanh(self.l3(hidden))
        self.outputs['mu'] = mu
        return mu

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log output histograms and layer parameters every ``log_freq`` steps.

        Args:
            L: Logger exposing ``log_histogram`` and ``log_param``.
            step: Current step; nothing is logged unless it is a multiple
                of ``log_freq``.
            log_freq: Logging interval.
        """
        if step % log_freq == 0:
            for name, tensor in self.outputs.items():
                L.log_histogram('train_actor/%s_hist' % name, tensor, step)

            tagged_layers = (
                ('train_actor/fc1', self.l1),
                ('train_actor/fc2', self.l2),
                ('train_actor/fc3', self.l3),
            )
            for tag, layer in tagged_layers:
                L.log_param(tag, layer, step)
|
|
|
|
|
|
class Critic(nn.Module):
    """Twin Q-network: one encoder feeding two independent 3-layer MLP heads.

    Each head maps the concatenated (encoded observation, action) vector
    through hidden widths 400 and 300 to a single Q-value.
    """

    def __init__(
        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
    ):
        """Build the encoder and both Q heads.

        Args:
            obs_shape: Observation shape, forwarded to ``make_encoder``.
            action_shape: Action shape; ``action_shape[0]`` is the number
                of action dimensions appended to the encoder features.
            encoder_type: Encoder identifier, forwarded to ``make_encoder``.
            encoder_feature_dim: Requested encoder feature size, forwarded
                to ``make_encoder``.
        """
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim
        )

        input_dim = self.encoder.feature_dim + action_shape[0]

        # Q1 architecture
        self.l1 = nn.Linear(input_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

        # Q2 architecture
        self.l4 = nn.Linear(input_dim, 400)
        self.l5 = nn.Linear(400, 300)
        self.l6 = nn.Linear(300, 1)

        # Most recent forward-pass tensors, kept for histogram logging.
        self.outputs = dict()

    def _encode_obs_action(self, obs, action, detach_encoder):
        """Encode ``obs`` and concatenate the features with ``action``."""
        features = self.encoder(obs, detach=detach_encoder)
        return torch.cat([features, action], 1)

    def _q1_head(self, obs_action):
        """Apply the first Q head (l1 -> l2 -> l3)."""
        h1 = F.relu(self.l1(obs_action))
        h1 = F.relu(self.l2(h1))
        return self.l3(h1)

    def _q2_head(self, obs_action):
        """Apply the second Q head (l4 -> l5 -> l6)."""
        h2 = F.relu(self.l4(obs_action))
        h2 = F.relu(self.l5(h2))
        return self.l6(h2)

    def forward(self, obs, action, detach_encoder=False):
        """Return ``(q1, q2)`` for the given observation/action batch.

        Args:
            obs: Observation batch handed to the encoder.
            action: Action batch, concatenated with the encoder features
                along dimension 1.
            detach_encoder: Passed to the encoder as ``detach``
                (presumably stops gradients at the encoder output --
                confirm against the encoder implementation).
        """
        obs_action = self._encode_obs_action(obs, action, detach_encoder)

        q1 = self._q1_head(obs_action)
        q2 = self._q2_head(obs_action)

        self.outputs['q1'] = q1
        self.outputs['q2'] = q2

        return q1, q2

    def Q1(self, obs, action, detach_encoder=False):
        """Return only the first head's Q-value.

        Unlike ``forward`` this does not touch ``self.outputs``.
        """
        obs_action = self._encode_obs_action(obs, action, detach_encoder)
        return self._q1_head(obs_action)

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log encoder stats, Q histograms and layer parameters.

        Args:
            L: Logger exposing ``log_histogram`` and ``log_param``.
            step: Current step; nothing is logged unless it is a multiple
                of ``log_freq``.
            log_freq: Logging interval, also forwarded to the encoder.
        """
        if step % log_freq != 0:
            return

        self.encoder.log(L, step, log_freq)

        for k, v in self.outputs.items():
            L.log_histogram('train_critic/%s_hist' % k, v, step)

        L.log_param('train_critic/q1_fc1', self.l1, step)
        L.log_param('train_critic/q1_fc2', self.l2, step)
        L.log_param('train_critic/q1_fc3', self.l3, step)
        # l4-l6 belong to the second Q head; they used to be logged under
        # 'q1_fc4'..'q1_fc6', which attributed them to the wrong network.
        L.log_param('train_critic/q2_fc1', self.l4, step)
        L.log_param('train_critic/q2_fc2', self.l5, step)
        L.log_param('train_critic/q2_fc3', self.l6, step)
|
|
|
|
|
|
class TD3Agent(object):
    """TD3 agent holding actor/critic networks, their targets and optimizers."""

    def __init__(
        self,
        obs_shape,
        action_shape,
        device,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        expl_noise=0.1,
        actor_lr=1e-3,
        critic_lr=1e-3,
        encoder_type='identity',
        encoder_feature_dim=50,
        actor_update_freq=2,
        target_update_freq=2,
    ):
        """Create the networks, target copies and Adam optimizers.

        Args:
            obs_shape: Observation shape, forwarded to ``Actor``/``Critic``.
            action_shape: Action shape, forwarded to ``Actor``/``Critic``.
            device: Torch device the networks are moved to.
            discount: Discount factor applied to the bootstrapped target Q.
            tau: Coefficient passed to ``utils.soft_update_params``.
            policy_noise: Scale of the Gaussian noise added to the target
                policy's action in ``update_critic``.
            noise_clip: Absolute bound the target-policy noise is clamped to.
            expl_noise: Scale of the Gaussian noise added in
                ``sample_action``.
            actor_lr: Learning rate of the actor optimizer.
            critic_lr: Learning rate of the critic optimizer.
            encoder_type: Encoder identifier for both networks.
            encoder_feature_dim: Encoder feature size for both networks.
            actor_update_freq: The actor is updated on steps divisible by
                this value.
            target_update_freq: Targets are soft-updated on steps divisible
                by this value.
        """
        self.device = device
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.expl_noise = expl_noise
        self.actor_update_freq = actor_update_freq
        self.target_update_freq = target_update_freq

        # models
        self.actor = Actor(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)

        self.critic = Critic(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)

        # NOTE(review): presumably shares or copies the critic encoder's
        # convolutional weights into the actor encoder -- confirm in encoder.py.
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

        self.actor_target = Actor(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())

        self.critic_target = Critic(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr
        )

        self.train()
        self.critic_target.train()
        self.actor_target.train()

    def train(self, training=True):
        """Switch the actor and critic between training and eval mode."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    def _policy_action(self, obs):
        """Run the actor on a single observation, returning a batch of one."""
        obs = torch.FloatTensor(obs).to(self.device)
        return self.actor(obs.unsqueeze(0))

    def select_action(self, obs):
        """Return the actor's action for ``obs`` as a flat NumPy array."""
        with torch.no_grad():
            action = self._policy_action(obs)
        return action.cpu().numpy().flatten()

    def sample_action(self, obs):
        """Return the actor's action plus Gaussian exploration noise.

        The noise is scaled by ``expl_noise`` and the result is clamped to
        [-1, 1] before being returned as a flat NumPy array.
        """
        with torch.no_grad():
            action = self._policy_action(obs)
            noise = torch.randn_like(action) * self.expl_noise
            action = (action + noise).clamp(-1.0, 1.0)
        return action.cpu().numpy().flatten()

    def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
        """Take one critic gradient step on the clipped double-Q target.

        The target is ``reward + not_done * discount * min(Q1', Q2')`` where
        the target critics are evaluated at the target actor's action
        perturbed by clamped Gaussian noise.
        """
        with torch.no_grad():
            # randn_like already allocates on ``action``'s device.
            noise = torch.randn_like(action) * self.policy_noise
            noise = noise.clamp(-self.noise_clip, self.noise_clip)
            next_action = self.actor_target(next_obs) + noise
            next_action = next_action.clamp(-1.0, 1.0)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (not_done * self.discount * target_Q)

        current_Q1, current_Q2 = self.critic(obs, action)

        critic_loss = (
            F.mse_loss(current_Q1, target_Q)
            + F.mse_loss(current_Q2, target_Q)
        )
        L.log('train_critic/loss', critic_loss, step)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(L, step)

    def update_actor(self, obs, L, step):
        """Take one actor gradient step that maximises the critic's Q1.

        Both networks are called with ``detach_encoder=True``.
        """
        action = self.actor(obs, detach_encoder=True)
        actor_Q = self.critic.Q1(obs, action, detach_encoder=True)
        actor_loss = -actor_Q.mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(L, step)

    def update(self, replay_buffer, L, step):
        """Sample a batch and run the critic/actor/target updates for ``step``.

        The critic is updated on every call; the actor and the target
        networks only on steps divisible by their respective frequencies.
        """
        obs, action, reward, next_obs, not_done = replay_buffer.sample()

        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor(obs, L, step)

        if step % self.target_update_freq == 0:
            utils.soft_update_params(self.critic, self.critic_target, self.tau)
            utils.soft_update_params(self.actor, self.actor_target, self.tau)

    def save(self, model_dir, step):
        """Write the actor and critic state dicts under ``model_dir``."""
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
        )

    def load(self, model_dir, step):
        """Restore the actor and critic written by ``save``.

        Tensors are mapped onto ``self.device`` so a checkpoint saved on a
        GPU can be loaded on a machine without one. ``save`` does not store
        the target networks, so they are reset to the loaded weights (the
        same initialisation ``__init__`` uses) instead of being left at
        whatever weights they held before the call.
        """
        self.actor.load_state_dict(
            torch.load(
                '%s/actor_%s.pt' % (model_dir, step),
                map_location=self.device,
            )
        )
        self.critic.load_state_dict(
            torch.load(
                '%s/critic_%s.pt' % (model_dir, step),
                map_location=self.device,
            )
        )
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())