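# Train a PPO agent on CartPole-v1, logging per-episode reward to TensorBoard.
# `functions`, `models`, and `ppo_trainer` are project-local modules.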
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

import gym
import numpy as np

from functions import rollouts, discount_rewards
from models import ICM, ActorCritic, ActorCriticNetwork
from ppo_trainer import PPO

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

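# CartPole-v1: 4-dimensional observation space, 2 discrete actions.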
env = gym.make('CartPole-v1')
ac = ActorCriticNetwork(env.observation_space.shape[0], env.action_space.n).to(device)

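# Note: `state` and `done` are not referenced again below; environment
# stepping presumably happens inside the rollouts() helper.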
state = env.reset()
done = False

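# Training hyperparameters: number of episodes and max environment steps
# collected per rollout.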
total_episodes = 1000
max_steps = 1000

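# PPO (from the project-local ppo_trainer module) wraps the actor-critic
# network and exposes separate policy and value-function update steps.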
ppo = PPO(ac)
for episode in range(total_episodes):
    rollout = rollouts(env, ac, max_steps=max_steps)

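    # From the indexing below, rollouts() appears to return a tuple of
    # (observations, actions, rewards, GAEs, action log-probs).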
    # Shuffle
    permute_idx = np.random.permutation(len(rollout[0]))

    # Policy data
    obs = torch.tensor(np.asarray(rollout[0])[permute_idx], dtype=torch.float32).to(device)
    actions = torch.tensor(np.asarray(rollout[1])[permute_idx], dtype=torch.float32).to(device)
    old_log_probs = torch.tensor(np.asarray(rollout[4])[permute_idx], dtype=torch.float32).to(device)
    gaes = torch.tensor(np.asarray(rollout[3])[permute_idx], dtype=torch.float32).to(device)

    # Value data
    returns = discount_rewards(np.asarray(rollout[2]))[permute_idx]
    returns = torch.tensor(returns, dtype=torch.float32).to(device)

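    # PPO updates: a policy step on the shuffled batch, then a separate
    # value-function step fitting the critic to the discounted returns.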
    ppo.update_policy(obs, actions, old_log_probs, gaes, returns)
    ppo.update_value(obs, returns)

    writer.add_scalar('Reward', sum(rollout[2]), episode)
    print('Episode {} | Reward {:.1f}'.format(episode, sum(rollout[2])))
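
# Optional cleanup once training finishes:
# env.close()
# writer.close()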