From 80760bb6869e65e206633161dadd74976ddad0ba Mon Sep 17 00:00:00 2001 From: ved1 Date: Wed, 1 Feb 2023 19:36:09 +0100 Subject: [PATCH] Adding PPO --- ppo/functions.py | 47 +++++++++++++++ ppo/main.py | 49 +++++++++++++++ ppo/models.py | 147 +++++++++++++++++++++++++++++++++++++++++++++ ppo/ppo_trainer.py | 49 +++++++++++++++ 4 files changed, 292 insertions(+) create mode 100644 ppo/functions.py create mode 100644 ppo/main.py create mode 100644 ppo/models.py create mode 100644 ppo/ppo_trainer.py diff --git a/ppo/functions.py b/ppo/functions.py new file mode 100644 index 0000000..8554578 --- /dev/null +++ b/ppo/functions.py @@ -0,0 +1,47 @@ +import torch +import numpy as np +import collections + +def discount_rewards(rewards, gamma=0.99): + new_rewards = [float(rewards[-1])] + for i in reversed(range(len(rewards)-1)): + new_rewards.append(float(rewards[i]) + gamma * new_rewards[-1]) + return np.array(new_rewards[::-1]) + +def calculate_gaes(rewards, values, gamma=0.99, decay=0.97): + next_values = np.concatenate([values[1:], [0]]) + deltas = [rew + gamma * next_val - val for rew, val, next_val in zip(rewards, values, next_values)] + + gaes = [deltas[-1]] + for i in reversed(range(len(deltas)-1)): + gaes.append(deltas[i] + decay * gamma * gaes[-1]) + + return np.array(gaes[::-1]) + +def rollouts(env, actor_critic, max_steps): + obs = env.reset() + done = False + + obs_arr, action_arr, rewards, values, old_log_probs = [], [], [], [], [] + rollout = [obs_arr, action_arr, rewards, values, old_log_probs] + + for _ in range(max_steps): + actions, value = actor_critic(torch.FloatTensor(obs).to("cuda")) + action = actions.sample() + next_obs, reward, done, info = env.step(action.item()) + + obs_arr.append(obs) + action_arr.append(action.item()) + rewards.append(reward) + values.append(value.item()) + old_log_probs.append(actions.log_prob(action).item()) + + rollout = [obs_arr, action_arr, rewards, values, old_log_probs] + + if done: + break + obs = next_obs + + gaes = calculate_gaes(rewards, values) + rollout[3] = gaes + return rollout \ No newline at end of file diff --git a/ppo/main.py b/ppo/main.py new file mode 100644 index 0000000..76da755 --- /dev/null +++ b/ppo/main.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.distributions import Categorical + +import gym +import numpy as np + +from functions import rollouts, discount_rewards +from models import ICM, ActorCritic, ActorCriticNetwork +from ppo_trainer import PPO + +from torch.utils.tensorboard import SummaryWriter +writer = SummaryWriter() + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +env = gym.make('CartPole-v1') +ac = ActorCriticNetwork(env.observation_space.shape[0], env.action_space.n).to(device) + +state = env.reset() +done = False + +total_episodes = 1000 +max_steps = 1000 + +ppo = PPO(ac) +for episode in range(total_episodes): + rollout = rollouts(env, ac, max_steps=max_steps) + + # Shuffle + permute_idx = np.random.permutation(len(rollout[0])) + + # Policy data + obs = torch.tensor(np.asarray(rollout[0])[permute_idx], dtype=torch.float32).to(device) + actions = torch.tensor(np.asarray(rollout[1])[permute_idx], dtype=torch.float32).to(device) + old_log_probs = torch.tensor(np.asarray(rollout[4])[permute_idx], dtype=torch.float32).to(device) + gaes = torch.tensor(np.asarray(rollout[3])[permute_idx], dtype=torch.float32).to(device) + + # Value data + returns = discount_rewards(np.asarray(rollout[2]))[permute_idx] + returns = torch.tensor(returns, dtype=torch.float32).to(device) + + ppo.update_policy(obs, actions, old_log_probs, gaes, returns) + ppo.update_value(obs, returns) + + writer.add_scalar('Reward', sum(rollout[2]), episode) + print('Episode {} | Avg Reward {:.1f}'.format(episode, sum(rollout[2]))) diff --git a/ppo/models.py b/ppo/models.py new file mode 100644 index 0000000..de33b45 --- /dev/null +++ b/ppo/models.py @@ -0,0 +1,147 @@ +import torch +import torch.nn as nn +import torch.nn.functional as f +from torch.distributions import Categorical + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +class ICM(nn.Module): + def __init__(self, channels, encoded_state_size, action_size): + super(ICM, self).__init__() + self.channels = channels + self.encoded_state_size = encoded_state_size + self.action_size = action_size + + self.feature_encoder = nn.Sequential( + nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Flatten(), + nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size), + ).to(device) + + self.inverse_model = nn.Sequential( + nn.Linear(in_features=self.encoded_state_size*2, out_features=256), + nn.LeakyReLU(), + nn.Linear(in_features=256, out_features=self.action_size), + nn.Softmax(dim=-1) + ).to(device) + + self.forward_model = nn.Sequential( + nn.Linear(in_features=self.encoded_state_size+self.action_size, out_features=256), + nn.LeakyReLU(), + nn.Linear(in_features=256, out_features=self.encoded_state_size), + ).to(device) + + def forward(self, state, next_state, action): + if state.dim() == 3: + state = state.unsqueeze(0) + next_state = next_state.unsqueeze(0) + + encoded_state = self.feature_encoder(state) + next_encoded_state = self.feature_encoder(next_state) + action_pred = self.inverse_model(torch.cat((encoded_state, next_encoded_state), dim=-1)) + next_encoded_state_pred = self.forward_model(torch.cat((encoded_state, action), dim=-1)) + return encoded_state, next_encoded_state, action_pred, next_encoded_state_pred + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.zeros_(m.bias) + +class ActorCritic(nn.Module): + def __init__(self,encoded_state_size, action_size, state_size=4): + super(ActorCritic, self).__init__() + self.channels = state_size + self.encoded_state_size = encoded_state_size + self.action_size = action_size + + self.feature_encoder = nn.Sequential( + nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Flatten(), + nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size), + ).to(device) + + def actor(self,state): + policy = nn.Sequential( + nn.Linear(in_features=self.encoded_state_size , out_features=256), + nn.LeakyReLU(), + nn.Linear(in_features=256, out_features=self.action_size), + nn.Softmax(dim=-1) + ).to(device) + return policy(state) + + def critic(self,state): + value = nn.Sequential( + nn.Linear(in_features=self.encoded_state_size , out_features=256), + nn.LeakyReLU(), + nn.Linear(in_features=256, out_features=1), + ).to(device) + return value(state) + + def forward(self, state): + if state.dim() == 3: + state = state.unsqueeze(0) + state = self.feature_encoder(state) + value = self.critic(state) + policy = self.actor(state) + actions = Categorical(policy) + return actions, value + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.zeros_(m.bias) + + +# Policy and value model +class ActorCriticNetwork(nn.Module): + def __init__(self, obs_space_size, action_space_size): + super().__init__() + + self.shared_layers = nn.Sequential( + nn.Linear(obs_space_size, 64), + nn.ReLU(), + nn.Linear(64, 64), + nn.ReLU()).to(device) + + self.policy_layers = nn.Sequential( + nn.Linear(64, 64), + nn.ReLU(), + nn.Linear(64, action_space_size), + nn.Softmax(dim=-1)).to(device) + + self.value_layers = nn.Sequential( + nn.Linear(64, 64), + nn.ReLU(), + nn.Linear(64, 1)).to(device) + + def value(self, obs): + z = self.shared_layers(obs) + value = self.value_layers(z) + return value + + def policy(self, obs): + z = self.shared_layers(obs) + policy_logits = self.policy_layers(z) + return policy_logits + + def forward(self, obs): + z = self.shared_layers(obs) + policy_logits = self.policy_layers(z) + value = self.value_layers(z) + return Categorical(policy_logits), value \ No newline at end of file diff --git a/ppo/ppo_trainer.py b/ppo/ppo_trainer.py new file mode 100644 index 0000000..4e478df --- /dev/null +++ b/ppo/ppo_trainer.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.distributions import Categorical + + +class PPO(nn.Module): + def __init__(self, actor_critic, clip_param=0.2, ppo_epoch=40, policy_lr=3e-4, value_lr=1e-3): + super(PPO, self).__init__() + self.ac = actor_critic + self.clip_param = clip_param + self.ppo_epoch = ppo_epoch + + policy_params = list(self.ac.shared_layers.parameters()) + list(self.ac.policy_layers.parameters()) + self.policy_optim = optim.Adam(policy_params, lr=policy_lr) + + value_params = list(self.ac.shared_layers.parameters()) + list(self.ac.value_layers.parameters()) + self.value_optim = optim.Adam(value_params, lr=value_lr) + + def update_policy(self, obs, actions, old_log_probs, gaes, returns): + for _ in range(self.ppo_epoch): + self.policy_optim.zero_grad() + + new_probs = Categorical(self.ac.policy(obs)) + new_log_probs = new_probs.log_prob(actions) + ratio = torch.exp(new_log_probs - old_log_probs) + + surr_1 = ratio * gaes + surr_2 = torch.clamp(ratio,min=1-self.clip_param, max=1+self.clip_param) * gaes + loss = - torch.min(surr_1, surr_2).mean() + + loss.backward() + self.policy_optim.step() + + kl_div = (old_log_probs - new_log_probs).mean() + if kl_div >= 0.02: + break + + + def update_value(self, obs, returns): + for _ in range(self.ppo_epoch): + self.value_optim.zero_grad() + + value_loss = (returns - self.ac.value(obs)) ** 2 #F.mse_loss(self.ac.value(obs), returns.unsqueeze(1)).mean() + value_loss = value_loss.mean() + + value_loss.backward() + self.value_optim.step() \ No newline at end of file