Adding PPO
parent 000b970a12
commit 80760bb686
47 ppo/functions.py Normal file
@@ -0,0 +1,47 @@
import torch
import numpy as np

# Use the GPU when available so the rollout matches the rest of the code.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def discount_rewards(rewards, gamma=0.99):
    # Discounted returns, accumulated backwards from the last reward.
    new_rewards = [float(rewards[-1])]
    for i in reversed(range(len(rewards) - 1)):
        new_rewards.append(float(rewards[i]) + gamma * new_rewards[-1])
    return np.array(new_rewards[::-1])


def calculate_gaes(rewards, values, gamma=0.99, decay=0.97):
    # Generalized Advantage Estimation over a single rollout.
    next_values = np.concatenate([values[1:], [0]])
    deltas = [rew + gamma * next_val - val
              for rew, val, next_val in zip(rewards, values, next_values)]

    gaes = [deltas[-1]]
    for i in reversed(range(len(deltas) - 1)):
        gaes.append(deltas[i] + decay * gamma * gaes[-1])

    return np.array(gaes[::-1])


def rollouts(env, actor_critic, max_steps):
    # Collect one episode of at most max_steps transitions.
    obs = env.reset()

    obs_arr, action_arr, rewards, values, old_log_probs = [], [], [], [], []

    for _ in range(max_steps):
        actions, value = actor_critic(torch.FloatTensor(obs).to(device))
        action = actions.sample()
        next_obs, reward, done, info = env.step(action.item())

        obs_arr.append(obs)
        action_arr.append(action.item())
        rewards.append(reward)
        values.append(value.item())
        old_log_probs.append(actions.log_prob(action).item())

        if done:
            break
        obs = next_obs

    # Replace the raw value estimates with GAE advantages before returning.
    rollout = [obs_arr, action_arr, rewards, values, old_log_probs]
    rollout[3] = calculate_gaes(rewards, values)
    return rollout
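The two helpers above are easy to check by hand. A minimal sanity check, not part of the commit and assuming functions.py is importable from the working directory:

import numpy as np
from functions import discount_rewards, calculate_gaes

rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.4, 0.3]

returns = discount_rewards(rewards)           # [2.9701, 1.99, 1.0]
advantages = calculate_gaes(rewards, values)

# First return is r0 + gamma*r1 + gamma^2*r2 with gamma = 0.99.
assert np.isclose(returns[0], 1.0 + 0.99 * 1.0 + 0.99 ** 2 * 1.0)
assert returns.shape == advantages.shape == (3,)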
49 ppo/main.py Normal file
@@ -0,0 +1,49 @@
import torch
import numpy as np
import gym

from torch.utils.tensorboard import SummaryWriter

from functions import rollouts, discount_rewards
from models import ActorCriticNetwork
from ppo_trainer import PPO

writer = SummaryWriter()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make('CartPole-v1')
ac = ActorCriticNetwork(env.observation_space.shape[0], env.action_space.n).to(device)

total_episodes = 1000
max_steps = 1000

ppo = PPO(ac)
for episode in range(total_episodes):
    rollout = rollouts(env, ac, max_steps=max_steps)

    # Shuffle the transitions so the PPO epochs do not see them in time order.
    permute_idx = np.random.permutation(len(rollout[0]))

    # Policy data
    obs = torch.tensor(np.asarray(rollout[0])[permute_idx], dtype=torch.float32).to(device)
    actions = torch.tensor(np.asarray(rollout[1])[permute_idx], dtype=torch.int64).to(device)
    old_log_probs = torch.tensor(np.asarray(rollout[4])[permute_idx], dtype=torch.float32).to(device)
    gaes = torch.tensor(np.asarray(rollout[3])[permute_idx], dtype=torch.float32).to(device)

    # Value data: discounted returns are the regression targets for the critic.
    returns = discount_rewards(np.asarray(rollout[2]))[permute_idx]
    returns = torch.tensor(returns, dtype=torch.float32).to(device)

    ppo.update_policy(obs, actions, old_log_probs, gaes, returns)
    ppo.update_value(obs, returns)

    writer.add_scalar('Reward', sum(rollout[2]), episode)
    print('Episode {} | Total Reward {:.1f}'.format(episode, sum(rollout[2])))
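After training, the policy can be checked by running it greedily. A hedged sketch, not part of the commit, reusing env, ac, and device from above and picking the most probable action instead of sampling:

obs = env.reset()
done = False
total_reward = 0.0
while not done:
    with torch.no_grad():
        dist, _ = ac(torch.tensor(obs, dtype=torch.float32).to(device))
    obs, reward, done, info = env.step(torch.argmax(dist.probs).item())
    total_reward += reward
print('Greedy evaluation reward: {:.1f}'.format(total_reward))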
147 ppo/models.py Normal file
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from torch.distributions import Categorical

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ICM(nn.Module):
    # Intrinsic Curiosity Module: shared encoder plus inverse and forward dynamics models.
    def __init__(self, channels, encoded_state_size, action_size):
        super(ICM, self).__init__()
        self.channels = channels
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size

        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32 * 4 * 4, out_features=self.encoded_state_size),
        ).to(device)

        # Predicts the action taken between two encoded states.
        self.inverse_model = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size * 2, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.action_size),
            nn.Softmax(dim=-1)
        ).to(device)

        # Predicts the next encoded state from the current one plus the action.
        self.forward_model = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size + self.action_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.encoded_state_size),
        ).to(device)

    def forward(self, state, next_state, action):
        if state.dim() == 3:
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)

        encoded_state = self.feature_encoder(state)
        next_encoded_state = self.feature_encoder(next_state)
        action_pred = self.inverse_model(torch.cat((encoded_state, next_encoded_state), dim=-1))
        next_encoded_state_pred = self.forward_model(torch.cat((encoded_state, action), dim=-1))
        return encoded_state, next_encoded_state, action_pred, next_encoded_state_pred

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)


class ActorCritic(nn.Module):
    # Actor-critic over image observations, sharing a convolutional encoder.
    def __init__(self, encoded_state_size, action_size, state_size=4):
        super(ActorCritic, self).__init__()
        self.channels = state_size
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size

        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32 * 4 * 4, out_features=self.encoded_state_size),
        ).to(device)

        # The heads are built once here; constructing them inside actor()/critic()
        # would re-initialize (and never train) their weights on every call.
        self.actor_head = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.action_size),
            nn.Softmax(dim=-1)
        ).to(device)

        self.critic_head = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=1),
        ).to(device)

    def actor(self, state):
        return self.actor_head(state)

    def critic(self, state):
        return self.critic_head(state)

    def forward(self, state):
        if state.dim() == 3:
            state = state.unsqueeze(0)
        state = self.feature_encoder(state)
        value = self.critic(state)
        policy = self.actor(state)
        actions = Categorical(policy)
        return actions, value

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)


# Policy and value model
class ActorCriticNetwork(nn.Module):
    def __init__(self, obs_space_size, action_space_size):
        super().__init__()

        self.shared_layers = nn.Sequential(
            nn.Linear(obs_space_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()).to(device)

        # Softmax output: the policy head returns action probabilities.
        self.policy_layers = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, action_space_size),
            nn.Softmax(dim=-1)).to(device)

        self.value_layers = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)).to(device)

    def value(self, obs):
        z = self.shared_layers(obs)
        value = self.value_layers(z)
        return value

    def policy(self, obs):
        z = self.shared_layers(obs)
        policy_probs = self.policy_layers(z)
        return policy_probs

    def forward(self, obs):
        z = self.shared_layers(obs)
        policy_probs = self.policy_layers(z)
        value = self.value_layers(z)
        return Categorical(policy_probs), value
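ActorCriticNetwork is the only model main.py actually instantiates. An illustrative shape check, not part of the commit; the batch size of 8 is arbitrary:

import torch
from models import ActorCriticNetwork

net = ActorCriticNetwork(obs_space_size=4, action_space_size=2)  # CartPole sizes
obs = torch.randn(8, 4).to(next(net.parameters()).device)
dist, value = net(obs)

assert dist.probs.shape == (8, 2)   # one Categorical per observation
assert value.shape == (8, 1)        # one scalar value per observation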
49 ppo/ppo_trainer.py Normal file
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical


class PPO(nn.Module):
    def __init__(self, actor_critic, clip_param=0.2, ppo_epoch=40, policy_lr=3e-4, value_lr=1e-3):
        super(PPO, self).__init__()
        self.ac = actor_critic
        self.clip_param = clip_param
        self.ppo_epoch = ppo_epoch

        policy_params = list(self.ac.shared_layers.parameters()) + list(self.ac.policy_layers.parameters())
        self.policy_optim = optim.Adam(policy_params, lr=policy_lr)

        value_params = list(self.ac.shared_layers.parameters()) + list(self.ac.value_layers.parameters())
        self.value_optim = optim.Adam(value_params, lr=value_lr)

    def update_policy(self, obs, actions, old_log_probs, gaes, returns):
        for _ in range(self.ppo_epoch):
            self.policy_optim.zero_grad()

            new_probs = Categorical(self.ac.policy(obs))
            new_log_probs = new_probs.log_prob(actions)
            ratio = torch.exp(new_log_probs - old_log_probs)

            # Clipped surrogate objective.
            surr_1 = ratio * gaes
            surr_2 = torch.clamp(ratio, min=1 - self.clip_param, max=1 + self.clip_param) * gaes
            loss = -torch.min(surr_1, surr_2).mean()

            loss.backward()
            self.policy_optim.step()

            # Stop early once the policy drifts too far from the one that collected the data.
            kl_div = (old_log_probs - new_log_probs).mean()
            if kl_div >= 0.02:
                break

    def update_value(self, obs, returns):
        for _ in range(self.ppo_epoch):
            self.value_optim.zero_grad()

            # Squeeze the (N, 1) critic output so the loss does not broadcast
            # against the (N,) returns and produce an (N, N) matrix.
            values = self.ac.value(obs).squeeze(-1)
            value_loss = F.mse_loss(values, returns)

            value_loss.backward()
            self.value_optim.step()
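A minimal usage sketch with random tensors, not part of the commit, just to exercise the two update methods; in training these inputs come from the rollout loop in main.py:

import torch
from models import ActorCriticNetwork
from ppo_trainer import PPO

ac = ActorCriticNetwork(obs_space_size=4, action_space_size=2)
device = next(ac.parameters()).device
trainer = PPO(ac, ppo_epoch=2)

n = 16
obs = torch.randn(n, 4).to(device)
actions = torch.randint(0, 2, (n,)).to(device)
with torch.no_grad():
    dist, _ = ac(obs)
old_log_probs = dist.log_prob(actions)
gaes = torch.randn(n).to(device)
returns = torch.randn(n).to(device)

trainer.update_policy(obs, actions, old_log_probs, gaes, returns)
trainer.update_value(obs, returns)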