import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical


class PPO(nn.Module):
    """Clipped-surrogate PPO update with separate policy and value optimizers."""

    def __init__(self, actor_critic, clip_param=0.2, ppo_epoch=40,
                 policy_lr=3e-4, value_lr=1e-3):
        super().__init__()
        self.ac = actor_critic
        self.clip_param = clip_param
        self.ppo_epoch = ppo_epoch

        # The shared trunk is updated by both optimizers; each head is only
        # updated by its own optimizer.
        policy_params = list(self.ac.shared_layers.parameters()) + list(self.ac.policy_layers.parameters())
        self.policy_optim = optim.Adam(policy_params, lr=policy_lr)

        value_params = list(self.ac.shared_layers.parameters()) + list(self.ac.value_layers.parameters())
        self.value_optim = optim.Adam(value_params, lr=value_lr)

    def update_policy(self, obs, actions, old_log_probs, gaes, returns):
        for _ in range(self.ppo_epoch):
            self.policy_optim.zero_grad()

            # Log-probabilities of the rolled-out actions under the current
            # policy. This assumes ac.policy(obs) returns action probabilities;
            # pass logits=... instead if the policy head outputs raw logits.
            new_probs = Categorical(self.ac.policy(obs))
            new_log_probs = new_probs.log_prob(actions)

            # Clipped surrogate objective: take the pessimistic minimum of the
            # unclipped and clipped importance-weighted advantages.
            ratio = torch.exp(new_log_probs - old_log_probs)
            surr_1 = ratio * gaes
            surr_2 = torch.clamp(ratio, min=1 - self.clip_param, max=1 + self.clip_param) * gaes
            loss = -torch.min(surr_1, surr_2).mean()

            loss.backward()
            self.policy_optim.step()

            # Early stopping: end the epoch loop once the approximate KL
            # divergence from the rollout policy exceeds the 0.02 target.
            kl_div = (old_log_probs - new_log_probs).mean()
            if kl_div >= 0.02:
                break

    def update_value(self, obs, returns):
        for _ in range(self.ppo_epoch):
            self.value_optim.zero_grad()

            # Regress the value head onto the discounted returns. The squeeze
            # guards against a [N, 1] value output silently broadcasting
            # against [N] returns.
            values = self.ac.value(obs).squeeze(-1)
            value_loss = F.mse_loss(values, returns)

            value_loss.backward()
            self.value_optim.step()
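

# --- Usage sketch (illustrative, not from the original source) --------------
# A minimal example of how the PPO class above might be driven. The
# ActorCritic module below is a hypothetical stand-in: PPO only assumes the
# wrapped network exposes `shared_layers`, `policy_layers`, and `value_layers`
# submodules plus `policy(obs)` / `value(obs)` forward methods. The random
# "rollout" tensors stand in for data collected from an environment, with
# advantages and returns computed elsewhere (e.g. via GAE).

class ActorCritic(nn.Module):
    def __init__(self, obs_dim, n_actions, hidden=64):
        super().__init__()
        self.shared_layers = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh())
        self.policy_layers = nn.Sequential(nn.Linear(hidden, n_actions), nn.Softmax(dim=-1))
        self.value_layers = nn.Linear(hidden, 1)

    def policy(self, obs):
        return self.policy_layers(self.shared_layers(obs))

    def value(self, obs):
        return self.value_layers(self.shared_layers(obs))


if __name__ == "__main__":
    obs_dim, n_actions, batch = 4, 2, 32
    ac = ActorCritic(obs_dim, n_actions)
    ppo = PPO(ac)

    # Placeholder rollout batch; in practice these come from interacting with
    # the environment and computing GAE advantages / discounted returns.
    obs = torch.randn(batch, obs_dim)
    actions = torch.randint(0, n_actions, (batch,))
    with torch.no_grad():
        old_log_probs = Categorical(ac.policy(obs)).log_prob(actions)
    gaes = torch.randn(batch)
    returns = torch.randn(batch)

    ppo.update_policy(obs, actions, old_log_probs, gaes, returns)
    ppo.update_value(obs, returns)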