49 lines
1.8 KiB
Python
49 lines
1.8 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.optim as optim
|
||
|
import torch.nn.functional as F
|
||
|
from torch.distributions import Categorical
|
||
|
|
||
|
|
||
|
class PPO(nn.Module):
    """Proximal Policy Optimization updater wrapping an actor-critic network.

    Owns two Adam optimizers: one over the shared-trunk and policy-head
    parameters, and one over the shared-trunk and value-head parameters.

    Args:
        actor_critic: module exposing ``shared_layers``, ``policy_layers`` and
            ``value_layers`` sub-modules plus ``policy(obs)`` and ``value(obs)``
            methods.
        clip_param: clipping range epsilon; the probability ratio is clamped to
            ``[1 - clip_param, 1 + clip_param]``.
        ppo_epoch: maximum number of gradient steps per update call.
        policy_lr: Adam learning rate for the policy optimizer.
        value_lr: Adam learning rate for the value optimizer.
        target_kl: approximate-KL threshold at which ``update_policy`` stops
            early; ``None`` disables early stopping.
    """

    def __init__(self, actor_critic, clip_param=0.2, ppo_epoch=40, policy_lr=3e-4,
                 value_lr=1e-3, target_kl=0.02):
        super().__init__()
        self.ac = actor_critic
        self.clip_param = clip_param
        self.ppo_epoch = ppo_epoch
        self.target_kl = target_kl

        # NOTE: shared_layers appears in both parameter lists, so the shared
        # trunk is updated by both optimizers, each keeping its own Adam state.
        policy_params = (list(self.ac.shared_layers.parameters())
                         + list(self.ac.policy_layers.parameters()))
        self.policy_optim = optim.Adam(policy_params, lr=policy_lr)

        value_params = (list(self.ac.shared_layers.parameters())
                        + list(self.ac.value_layers.parameters()))
        self.value_optim = optim.Adam(value_params, lr=value_lr)

    def update_policy(self, obs, actions, old_log_probs, gaes, returns):
        """Run up to ``ppo_epoch`` clipped-surrogate policy gradient steps.

        Before each step the approximate KL divergence
        ``mean(old_log_probs - new_log_probs)`` of the current policy is
        computed; once it reaches ``self.target_kl`` the loop stops without
        taking that step.

        Args:
            obs: observations passed to ``self.ac.policy``.
            actions: actions taken during the rollout.
            old_log_probs: log-probabilities of ``actions`` under the rollout
                policy; treated as constants (detached).
            gaes: advantage estimates; treated as constants (detached).
            returns: unused here; kept for interface compatibility.
        """
        # Rollout quantities are targets, not parameters: detaching keeps
        # gradients out of them and avoids re-traversing a freed rollout graph
        # on the second epoch.
        old_log_probs = old_log_probs.detach()
        gaes = gaes.detach()

        for _ in range(self.ppo_epoch):
            # The policy output is passed as Categorical's first positional
            # argument (``probs``) -- assumes self.ac.policy returns probabilities.
            new_log_probs = Categorical(self.ac.policy(obs)).log_prob(actions)

            # Check KL of the *current* policy before stepping, so we never
            # move further once the threshold has been reached.
            approx_kl = (old_log_probs - new_log_probs).mean().item()
            if self.target_kl is not None and approx_kl >= self.target_kl:
                break

            ratio = torch.exp(new_log_probs - old_log_probs)
            surr_1 = ratio * gaes
            surr_2 = torch.clamp(ratio, min=1 - self.clip_param,
                                 max=1 + self.clip_param) * gaes
            loss = -torch.min(surr_1, surr_2).mean()

            self.policy_optim.zero_grad()
            loss.backward()
            self.policy_optim.step()

    def update_value(self, obs, returns):
        """Fit the value head to ``returns`` with ``ppo_epoch`` MSE steps.

        Args:
            obs: observations passed to ``self.ac.value``.
            returns: regression targets; treated as constants (detached).
        """
        returns = returns.detach()
        for _ in range(self.ppo_epoch):
            self.value_optim.zero_grad()

            values, targets = self._align_trailing_singleton(self.ac.value(obs), returns)
            value_loss = ((targets - values) ** 2).mean()

            value_loss.backward()
            self.value_optim.step()

    @staticmethod
    def _align_trailing_singleton(a, b):
        """Squeeze a trailing size-1 dim so ``a`` and ``b`` have equal rank.

        Prevents e.g. ``(B, 1) - (B,)`` from silently broadcasting to
        ``(B, B)``. Tensors whose ranks differ in any other way are returned
        unchanged.
        """
        if a.dim() == b.dim() + 1 and a.shape[-1] == 1:
            a = a.squeeze(-1)
        elif b.dim() == a.dim() + 1 and b.shape[-1] == 1:
            b = b.squeeze(-1)
        return a, b