Curiosity/ppo/ppo_trainer.py

49 lines
1.8 KiB
Python
Raw Normal View History

2023-02-01 18:36:09 +00:00
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
class PPO(nn.Module):
    """Proximal Policy Optimization (clipped surrogate) updater for an actor-critic.

    ``actor_critic`` must expose ``shared_layers``, ``policy_layers`` and
    ``value_layers`` sub-modules plus ``policy(obs)`` and ``value(obs)``
    methods. ``policy(obs)`` is passed to ``Categorical`` as its ``probs``
    argument, so it must return action probabilities (not logits).

    The policy and value heads are trained by two separate Adam optimizers.
    Both optimizers include the parameters of ``shared_layers``. The shared
    trunk is therefore updated by both losses, and each optimizer keeps its
    own Adam moment estimates for those parameters.
    """

    def __init__(self, actor_critic, clip_param=0.2, ppo_epoch=40,
                 policy_lr=3e-4, value_lr=1e-3, target_kl=0.02):
        """
        Args:
            actor_critic: model described in the class docstring; registered
                as the sub-module ``self.ac``.
            clip_param: epsilon of the clipped surrogate; the probability
                ratio is clamped to ``[1 - clip_param, 1 + clip_param]``.
            ppo_epoch: maximum number of gradient steps per call to
                ``update_policy`` and exact number per ``update_value``.
            policy_lr: Adam learning rate for the policy optimizer.
            value_lr: Adam learning rate for the value optimizer.
            target_kl: ``update_policy`` stops early once the approximate KL
                between the data-collecting policy and the current policy
                reaches this value.
        """
        super().__init__()
        self.ac = actor_critic
        self.clip_param = clip_param
        self.ppo_epoch = ppo_epoch
        self.target_kl = target_kl

        shared = list(self.ac.shared_layers.parameters())
        policy_params = shared + list(self.ac.policy_layers.parameters())
        value_params = shared + list(self.ac.value_layers.parameters())
        self.policy_optim = optim.Adam(policy_params, lr=policy_lr)
        self.value_optim = optim.Adam(value_params, lr=value_lr)

    def update_policy(self, obs, actions, old_log_probs, gaes, returns):
        """Run up to ``ppo_epoch`` clipped-surrogate gradient steps on the policy.

        Args:
            obs: observations fed to ``self.ac.policy``.
            actions: actions that were taken, scored via ``Categorical.log_prob``.
            old_log_probs: log-probabilities of ``actions`` under the policy
                that collected the data. Treated as constants (detached).
            gaes: advantage estimates, multiplied elementwise with the
                probability ratio. Treated as constants (detached).
                Presumably the same shape as ``old_log_probs``; confirm
                against the caller.
            returns: unused here; kept for interface compatibility.
        """
        # Rollout quantities are fixed targets. Detaching keeps gradients out
        # of whatever graph produced them and prevents a "backward through the
        # graph a second time" error on the second iteration.
        old_log_probs = old_log_probs.detach()
        gaes = gaes.detach()
        low, high = 1 - self.clip_param, 1 + self.clip_param

        for _ in range(self.ppo_epoch):
            new_log_probs = Categorical(self.ac.policy(obs)).log_prob(actions)

            # Approximate KL(old || current). Check it BEFORE stepping so no
            # further update is taken once the policy has already drifted past
            # the target.
            approx_kl = (old_log_probs - new_log_probs.detach()).mean().item()
            if approx_kl >= self.target_kl:
                break

            ratio = torch.exp(new_log_probs - old_log_probs)
            surr_1 = ratio * gaes
            surr_2 = torch.clamp(ratio, low, high) * gaes
            loss = -torch.min(surr_1, surr_2).mean()

            self.policy_optim.zero_grad()
            loss.backward()
            self.policy_optim.step()

    def update_value(self, obs, returns):
        """Run ``ppo_epoch`` mean-squared-error gradient steps on the value head.

        Args:
            obs: observations fed to ``self.ac.value``.
            returns: regression targets, treated as constants (detached).

        Raises:
            RuntimeError: if the value prediction and ``returns`` have
                different numbers of elements (raised by ``reshape_as``).
        """
        targets = returns.detach()
        for _ in range(self.ppo_epoch):
            # Align shapes explicitly. Otherwise an (N, 1) prediction minus
            # (N,) targets broadcasts to an (N, N) matrix and the loss compares
            # every value with every return.
            values = self.ac.value(obs).reshape_as(targets)
            # Manual MSE (rather than F.mse_loss) keeps PyTorch's usual dtype
            # promotion if targets and predictions differ in precision.
            value_loss = ((targets - values) ** 2).mean()

            self.value_optim.zero_grad()
            value_loss.backward()
            self.value_optim.step()