# Curiosity/ppo/main.py

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import gym
import numpy as np
from functions import rollouts, discount_rewards
from models import ICM, ActorCritic, ActorCriticNetwork
from ppo_trainer import PPO
from torch.utils.tensorboard import SummaryWriter
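
# TensorBoard writer for logging the per-episode reward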
writer = SummaryWriter()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make('CartPole-v1')
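# Actor-critic network sized to the environment's observation and action spaces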
ac = ActorCriticNetwork(env.observation_space.shape[0], env.action_space.n).to(device)
state = env.reset()
done = False
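# Training hyperparameters: number of episodes and per-episode rollout cap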
total_episodes = 1000
max_steps = 1000
ppo = PPO(ac)
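
# Training loop: collect a rollout each episode, shuffle it, and run the PPO updates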
for episode in range(total_episodes):
    # Collect one on-policy rollout; entries are indexed below as
    # observations [0], actions [1], rewards [2], gaes [3], log_probs [4]
    rollout = rollouts(env, ac, max_steps=max_steps)
    # Shuffle the rollout so updates see decorrelated samples
    permute_idx = np.random.permutation(len(rollout[0]))
    # Policy data
    obs = torch.tensor(np.asarray(rollout[0])[permute_idx], dtype=torch.float32).to(device)
    actions = torch.tensor(np.asarray(rollout[1])[permute_idx], dtype=torch.float32).to(device)
    old_log_probs = torch.tensor(np.asarray(rollout[4])[permute_idx], dtype=torch.float32).to(device)
    gaes = torch.tensor(np.asarray(rollout[3])[permute_idx], dtype=torch.float32).to(device)
    # Value data: discounted returns as regression targets for the critic
    returns = discount_rewards(np.asarray(rollout[2]))[permute_idx]
    returns = torch.tensor(returns, dtype=torch.float32).to(device)
    # PPO policy update followed by the value-function update on the shuffled batch
    ppo.update_policy(obs, actions, old_log_probs, gaes, returns)
    ppo.update_value(obs, returns)
    # Log the undiscounted episode return
    writer.add_scalar('Reward', sum(rollout[2]), episode)
    print('Episode {} | Total Reward {:.1f}'.format(episode, sum(rollout[2])))