# sac_ae_if/ddpg.py
# Denis Yarats 681e13b12a init
# 2019-09-23 11:20:48 -07:00

# 210 lines
# 6.3 KiB
# Python

# Code is taken from https://github.com/sfujim/TD3 with slight modifications
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from encoder import make_encoder
# Default logging period used by Actor.log / Critic.log: they only emit
# histograms and parameter summaries when ``step % log_freq == 0``.
LOG_FREQ = 10_000
class Actor(nn.Module):
    """Deterministic policy network: an encoder followed by a 3-layer MLP.

    Layout: encoder -> Linear(feature_dim, 400) -> ReLU -> Linear(400, 300)
    -> ReLU -> Linear(300, action_shape[0]) -> tanh, so every action
    component lies in [-1, 1].
    """

    def __init__(
        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
    ):
        super().__init__()
        # The encoder is built by the project's ``make_encoder`` factory and
        # must expose ``feature_dim`` (used to size the first linear layer).
        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim
        )
        self.l1 = nn.Linear(self.encoder.feature_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_shape[0])
        # Latest forward-pass activations, dumped as histograms by ``log``.
        self.outputs = dict()

    def forward(self, obs, detach_encoder=False):
        """Map observations to actions.

        ``detach_encoder`` is forwarded to the encoder as ``detach``;
        presumably it stops gradients into the encoder — confirm in encoder.py.
        """
        features = self.encoder(obs, detach=detach_encoder)
        hidden = F.relu(self.l2(F.relu(self.l1(features))))
        mu = torch.tanh(self.l3(hidden))
        self.outputs['mu'] = mu
        return mu

    def log(self, L, step, log_freq=LOG_FREQ):
        """Every ``log_freq`` steps, log output histograms and layer params.

        Unlike ``Critic.log`` this does not log the encoder.
        """
        if step % log_freq == 0:
            for name, value in self.outputs.items():
                L.log_histogram('train_actor/%s_hist' % name, value, step)
            for tag, layer in (
                ('train_actor/fc1', self.l1),
                ('train_actor/fc2', self.l2),
                ('train_actor/fc3', self.l3),
            ):
                L.log_param(tag, layer, step)
class Critic(nn.Module):
    """Q-network: encode the observation, append the action, regress one value.

    Layout: encoder -> concat(features, action) -> Linear(.., 400) -> ReLU
    -> Linear(400, 300) -> ReLU -> Linear(300, 1) (no output activation).
    """

    def __init__(
        self, obs_shape, action_shape, encoder_type, encoder_feature_dim
    ):
        super().__init__()
        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim
        )
        # Input width is the encoder features plus the action vector.
        input_width = self.encoder.feature_dim + action_shape[0]
        self.l1 = nn.Linear(input_width, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)
        # Latest forward-pass activations, dumped as histograms by ``log``.
        self.outputs = dict()

    def forward(self, obs, action, detach_encoder=False):
        """Return Q(obs, action), one value per row.

        Features and actions are concatenated along dim 1, so both are
        expected to be batched along dim 0.
        """
        features = self.encoder(obs, detach=detach_encoder)
        x = torch.cat([features, action], dim=1)
        for hidden_layer in (self.l1, self.l2):
            x = F.relu(hidden_layer(x))
        q_value = self.l3(x)
        self.outputs['q'] = q_value
        return q_value

    def log(self, L, step, log_freq=LOG_FREQ):
        """Every ``log_freq`` steps, log the encoder, histograms and params."""
        if step % log_freq == 0:
            self.encoder.log(L, step, log_freq)
            for name, value in self.outputs.items():
                L.log_histogram('train_critic/%s_hist' % name, value, step)
            for tag, layer in (
                ('train_critic/fc1', self.l1),
                ('train_critic/fc2', self.l2),
                ('train_critic/fc3', self.l3),
            ):
                L.log_param(tag, layer, step)
class DDPGAgent(object):
    """DDPG agent with target networks, adapted from the TD3 reference code.

    The online networks (``actor``, ``critic``) are trained by ``update``;
    the target networks (``actor_target``, ``critic_target``) follow them
    through ``utils.soft_update_params`` with coefficient ``tau``.
    """

    def __init__(
        self,
        obs_shape,
        action_shape,
        device,
        discount=0.99,
        tau=0.005,
        actor_lr=1e-3,
        critic_lr=1e-3,
        encoder_type='identity',
        encoder_feature_dim=50
    ):
        """Create online/target actor and critic plus their Adam optimizers.

        Args:
            obs_shape: observation shape, forwarded to the encoders.
            action_shape: action shape; ``action_shape[0]`` sizes the nets.
            device: device the networks are moved to and inputs are sent to.
            discount: discount factor applied in the critic's TD target.
            tau: coefficient passed to ``utils.soft_update_params``.
            actor_lr: Adam learning rate for the actor.
            critic_lr: Adam learning rate for the critic.
            encoder_type: encoder kind passed to ``make_encoder``.
            encoder_feature_dim: feature size passed to ``make_encoder``.
        """
        self.device = device
        self.discount = discount
        self.tau = tau

        # models
        self.actor = Actor(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        self.critic = Critic(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        # NOTE(review): presumably ties the actor's conv layers to the
        # critic's (see encoder.copy_conv_weights_from) — confirm.
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

        # Target networks start as exact copies of the online networks.
        self.actor_target = Actor(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target = Critic(
            obs_shape, action_shape, encoder_type, encoder_feature_dim
        ).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr
        )
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr
        )

        self.train()
        self.critic_target.train()
        self.actor_target.train()

    def train(self, training=True):
        """Put the online actor and critic into train (or eval) mode.

        The target networks' mode is not changed here.
        """
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    def select_action(self, obs):
        """Return the actor's action for a single (unbatched) observation.

        ``obs`` is converted to a float tensor on ``self.device`` and given a
        leading batch dimension of 1; the actor output is returned as a flat
        numpy array. No exploration noise is added.
        """
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            action = self.actor(obs)
            # Under no_grad the output does not require grad, so .numpy()
            # is safe without the legacy ``.data`` accessor.
            return action.cpu().numpy().flatten()

    def sample_action(self, obs):
        """Same as ``select_action``: this agent adds no exploration noise."""
        return self.select_action(obs)

    def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
        """Take one gradient step on the critic's TD (MSE) loss.

        Target: ``reward + not_done * discount * Q'(next_obs, mu'(next_obs))``,
        computed with the target networks and without gradient tracking.
        """
        with torch.no_grad():
            target_Q = self.critic_target(
                next_obs, self.actor_target(next_obs)
            )
            target_Q = reward + (not_done * self.discount * target_Q)

        current_Q = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q, target_Q)
        L.log('train_critic/loss', critic_loss, step)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(L, step)

    def update_actor(self, obs, L, step):
        """Take one gradient step on the actor by maximizing Q(obs, mu(obs)).

        Both encoders are called with ``detach_encoder=True``. The backward
        pass also accumulates gradients on the critic's head layers; only
        ``actor_optimizer`` steps here, and ``update_critic`` zeroes critic
        gradients before its own backward pass.
        """
        action = self.actor(obs, detach_encoder=True)
        actor_Q = self.critic(obs, action, detach_encoder=True)
        actor_loss = -actor_Q.mean()
        # Logged for parity with update_critic's 'train_critic/loss'.
        L.log('train_actor/loss', actor_loss, step)

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(L, step)

    def update(self, replay_buffer, L, step):
        """Sample one batch, update critic then actor, then soft-update targets."""
        obs, action, reward, next_obs, not_done = replay_buffer.sample()
        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)
        self.update_actor(obs, L, step)

        utils.soft_update_params(self.critic, self.critic_target, self.tau)
        utils.soft_update_params(self.actor, self.actor_target, self.tau)

    def save(self, model_dir, step):
        """Write the actor and critic state dicts into ``model_dir``.

        Files are ``actor_<step>.pt`` and ``critic_<step>.pt``. Target
        networks and optimizer state are not saved.
        """
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
        )

    def load(self, model_dir, step):
        """Restore actor and critic weights written by ``save``.

        Tensors are mapped onto ``self.device``, so a checkpoint written on a
        GPU can be restored on a CPU-only machine. ``save`` does not store
        the target networks, so they are re-synced to the loaded weights
        (as in ``__init__``). Otherwise a resumed run would bootstrap TD
        targets from freshly initialized networks. Optimizer state is not
        restored.
        """
        self.actor.load_state_dict(
            torch.load(
                '%s/actor_%s.pt' % (model_dir, step),
                map_location=self.device,
            )
        )
        self.critic.load_state_dict(
            torch.load(
                '%s/critic_%s.pt' % (model_dir, step),
                map_location=self.device,
            )
        )
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())