47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
|
import torch
|
||
|
import numpy as np
|
||
|
import collections
|
||
|
|
||
|
def discount_rewards(rewards, gamma=0.99):
    """Return the discounted reward-to-go for every step of a trajectory.

    Element ``t`` of the result is ``rewards[t] + gamma * result[t + 1]``
    (i.e. ``sum_k gamma**k * rewards[t + k]``), computed in one backward pass.

    Args:
        rewards: Sequence of per-step rewards (each value must be accepted
            by ``float()``).
        gamma: Per-step discount factor.

    Returns:
        1-D float ``np.ndarray`` with one entry per reward. An empty
        ``rewards`` sequence yields an empty array instead of raising
        ``IndexError``.
    """
    running = 0.0
    discounted = []
    # Walk backwards so each step can reuse the already-discounted tail.
    for reward in reversed(rewards):
        running = float(reward) + gamma * running
        discounted.append(running)
    return np.array(discounted[::-1], dtype=float)
|
||
|
|
||
|
def calculate_gaes(rewards, values, gamma=0.99, decay=0.97, last_value=0.0):
    """Compute Generalized Advantage Estimates (GAE) for one trajectory.

    With ``delta_t = rewards[t] + gamma * V[t + 1] - V[t]``, the advantage is
    ``A_t = delta_t + decay * gamma * A_{t + 1}`` (``decay`` acts as the GAE
    lambda).

    Args:
        rewards: Per-step rewards.
        values: Per-step value estimates ``V[t]``; must have the same length
            as ``rewards``.
        gamma: Discount factor.
        decay: GAE smoothing factor (lambda).
        last_value: Value used as ``V[T]`` after the final step. The default
            ``0.0`` treats the trajectory as terminated. Pass the critic's
            estimate of the next state to bootstrap a truncated trajectory.

    Returns:
        1-D ``np.ndarray`` of advantages, one per step. Empty inputs yield an
        empty array.

    Raises:
        ValueError: If ``rewards`` and ``values`` differ in length. Without
            this check ``zip`` would silently drop the extra steps.
    """
    if len(rewards) != len(values):
        raise ValueError(
            f"rewards and values must have the same length "
            f"(got {len(rewards)} and {len(values)})"
        )
    if len(rewards) == 0:
        return np.array([], dtype=float)

    # V[t + 1] for every step; the step after the last one uses last_value.
    next_values = np.concatenate([values[1:], [last_value]])
    deltas = [
        rew + gamma * next_val - val
        for rew, val, next_val in zip(rewards, values, next_values)
    ]

    gaes = []
    running = 0.0
    # Backward recursion: A_t = delta_t + decay * gamma * A_{t+1}.
    for delta in reversed(deltas):
        running = delta + decay * gamma * running
        gaes.append(running)
    return np.array(gaes[::-1])
|
||
|
|
||
|
def rollouts(env, actor_critic, max_steps, device="cuda"):
    """Run one episode of at most ``max_steps`` steps and collect training data.

    Args:
        env: Environment exposing ``reset() -> obs`` and
            ``step(action) -> (next_obs, reward, done, info)``. This is the
            4-tuple API called below; newer Gym/Gymnasium versions return
            different tuples.
        actor_critic: Callable mapping an observation tensor to
            ``(action_distribution, value)``. The distribution must support
            ``sample()`` and ``log_prob(action)``, and ``value`` and the
            sampled action must each hold a single element (``.item()`` is
            called on them).
        max_steps: Maximum number of environment steps to take.
        device: Torch device the observation tensor is placed on before
            calling ``actor_critic``. Defaults to ``"cuda"``, the previously
            hard-coded device.

    Returns:
        ``[observations, actions, rewards, gaes, log_probs]``.
        ``observations``, ``actions``, ``rewards`` and ``log_probs`` are
        Python lists with one entry per step taken. ``gaes`` is the result
        of ``calculate_gaes(rewards, values)``; it occupies index 3 in place
        of the raw value estimates. If no step was taken
        (``max_steps <= 0``), every entry is empty.
    """
    obs_arr, action_arr, rewards, values, old_log_probs = [], [], [], [], []

    obs = env.reset()
    for _ in range(max_steps):
        # as_tensor handles scalars correctly. torch.FloatTensor(n) would
        # allocate an uninitialised tensor of size n instead.
        obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=device)

        # Only scalar values are kept, so skip building an autograd graph.
        with torch.no_grad():
            dist, value = actor_critic(obs_tensor)
            action = dist.sample()
            log_prob = dist.log_prob(action).item()

        action_item = action.item()
        next_obs, reward, done, _info = env.step(action_item)

        obs_arr.append(obs)
        action_arr.append(action_item)
        rewards.append(reward)
        values.append(value.item())
        old_log_probs.append(log_prob)

        if done:
            break
        obs = next_obs

    # NOTE(review): if the loop ends because max_steps was reached (done is
    # still False), the last step is not bootstrapped with V(next_obs);
    # calculate_gaes uses 0 for the value after the final step.
    if rewards:
        gaes = calculate_gaes(rewards, values)
    else:
        # calculate_gaes cannot handle an empty trajectory in its original
        # form, so avoid calling it when no step was taken.
        gaes = np.array([], dtype=float)

    return [obs_arr, action_arr, rewards, gaes, old_log_probs]
|