import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

import gym
import numpy as np

from functions import rollouts, discount_rewards
from models import ICM, ActorCritic, ActorCriticNetwork
from ppo_trainer import PPO

from torch.utils.tensorboard import SummaryWriter
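# Note: nn, optim, F, Categorical, ICM and ActorCritic are not referenced
# directly in this script; presumably they are used inside the imported
# modules or by a curiosity (ICM) variant of this training loop.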

writer = SummaryWriter()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make('CartPole-v1')
ac = ActorCriticNetwork(env.observation_space.shape[0], env.action_space.n).to(device)

state = env.reset()
done = False

total_episodes = 1000
max_steps = 1000

ppo = PPO(ac)
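
# Assumed layout of the tuple returned by rollouts() in functions.py,
# inferred from how it is indexed below:
#   rollout[0] -> observations, rollout[1] -> actions, rollout[2] -> rewards,
#   rollout[3] -> GAE advantages, rollout[4] -> action log-probabilities.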
for episode in range(total_episodes):
    rollout = rollouts(env, ac, max_steps=max_steps)

    # Shuffle the rollout indices so the updates see decorrelated samples
    permute_idx = np.random.permutation(len(rollout[0]))

    # Policy data
    obs = torch.tensor(np.asarray(rollout[0])[permute_idx], dtype=torch.float32).to(device)
    actions = torch.tensor(np.asarray(rollout[1])[permute_idx], dtype=torch.float32).to(device)
    old_log_probs = torch.tensor(np.asarray(rollout[4])[permute_idx], dtype=torch.float32).to(device)
    gaes = torch.tensor(np.asarray(rollout[3])[permute_idx], dtype=torch.float32).to(device)

    # Value data
    returns = discount_rewards(np.asarray(rollout[2]))[permute_idx]
    returns = torch.tensor(returns, dtype=torch.float32).to(device)

    ppo.update_policy(obs, actions, old_log_probs, gaes, returns)
    ppo.update_value(obs, returns)

    writer.add_scalar('Reward', sum(rollout[2]), episode)
    print('Episode {} | Reward {:.1f}'.format(episode, sum(rollout[2])))
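
# For reference: discount_rewards() lives in functions.py and is not shown
# here. A minimal sketch under the usual assumption (discounted reward-to-go
# with a fixed gamma) could look like the commented version below; the
# project's actual implementation may differ.
#
# def discount_rewards(rewards, gamma=0.99):
#     returns = np.zeros(len(rewards), dtype=np.float32)
#     running = 0.0
#     for t in reversed(range(len(rewards))):
#         running = rewards[t] + gamma * running
#         returns[t] = running
#     return returns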