Curiosity/ppo/functions.py
import torch
import numpy as np
import collections


def discount_rewards(rewards, gamma=0.99):
    """Return the discounted return-to-go for each step of an episode."""
    new_rewards = [float(rewards[-1])]
    # Walk backwards so each step accumulates the discounted future rewards.
    for i in reversed(range(len(rewards) - 1)):
        new_rewards.append(float(rewards[i]) + gamma * new_rewards[-1])
    return np.array(new_rewards[::-1])
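
# Illustrative check (numbers are not from the original file): with the default
# gamma=0.99, discount_rewards([1.0, 0.0, 1.0]) returns [1.9801, 0.99, 1.0];
# each entry is that step's reward plus the discounted sum of all later rewards.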


def calculate_gaes(rewards, values, gamma=0.99, decay=0.97):
    """Compute Generalized Advantage Estimates (GAE) for one trajectory."""
    # The bootstrap value after the final step is treated as 0.
    next_values = np.concatenate([values[1:], [0]])
    # One-step TD residuals for every step of the trajectory.
    deltas = [rew + gamma * next_val - val
              for rew, val, next_val in zip(rewards, values, next_values)]
    # Accumulate the residuals backwards into advantages.
    gaes = [deltas[-1]]
    for i in reversed(range(len(deltas) - 1)):
        gaes.append(deltas[i] + decay * gamma * gaes[-1])
    return np.array(gaes[::-1])
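
# Illustrative check (numbers are not from the original file): with the default
# gamma=0.99 and decay=0.97, rewards=[1, 1] and values=[0.5, 0.4] give one-step
# deltas of [0.896, 0.6] and advantages of roughly [1.4722, 0.6], following the
# recursion A_t = delta_t + gamma * decay * A_{t+1}.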


def rollouts(env, actor_critic, max_steps):
    """Collect one episode (up to max_steps) and return the training data."""
    # The original code assumed a CUDA device; fall back to CPU if unavailable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    obs = env.reset()  # classic Gym API: reset() returns only the observation
    obs_arr, action_arr, rewards, values, old_log_probs = [], [], [], [], []
    rollout = [obs_arr, action_arr, rewards, values, old_log_probs]
    for _ in range(max_steps):
        # The actor-critic maps an observation to an action distribution and a state value.
        actions, value = actor_critic(torch.FloatTensor(obs).to(device))
        action = actions.sample()
        next_obs, reward, done, info = env.step(action.item())
        obs_arr.append(obs)
        action_arr.append(action.item())
        rewards.append(reward)
        values.append(value.item())
        old_log_probs.append(actions.log_prob(action).item())
        if done:
            break
        obs = next_obs
    # Replace the raw value estimates with GAE advantages for the PPO update.
    gaes = calculate_gaes(rewards, values)
    rollout[3] = gaes
    return rollout
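

# Hypothetical usage sketch, not part of the original file. It assumes the
# classic Gym API (gym < 0.26, where reset() returns only the observation and
# step() returns a 4-tuple) and an ActorCritic module defined elsewhere in this
# repository that maps an observation tensor to (action_distribution, value).
# Both the import path and the constructor signature below are guesses.
if __name__ == "__main__":
    import gym

    env = gym.make("CartPole-v1")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    from ppo.model import ActorCritic  # hypothetical module/class name
    actor_critic = ActorCritic(env.observation_space.shape[0],
                               env.action_space.n).to(device)

    obs_arr, action_arr, rewards, gaes, old_log_probs = rollouts(
        env, actor_critic, max_steps=200)
    print(f"collected {len(rewards)} steps, return = {sum(rewards):.1f}")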