import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import gym
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from functions import rollouts, discount_rewards
from models import ICM, ActorCritic, ActorCriticNetwork
from ppo_trainer import PPO

writer = SummaryWriter()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make('CartPole-v1')
ac = ActorCriticNetwork(env.observation_space.shape[0], env.action_space.n).to(device)

state = env.reset()
done = False
total_episodes = 1000
max_steps = 1000
ppo = PPO(ac)

for episode in range(total_episodes):
    # Collect one episode of experience with the current policy
    rollout = rollouts(env, ac, max_steps=max_steps)

    # Shuffle the transitions so the updates see them in random order
    permute_idx = np.random.permutation(len(rollout[0]))

    # Policy data: observations, actions, old log-probs, and GAE advantages
    obs = torch.tensor(np.asarray(rollout[0])[permute_idx], dtype=torch.float32).to(device)
    actions = torch.tensor(np.asarray(rollout[1])[permute_idx], dtype=torch.float32).to(device)
    old_log_probs = torch.tensor(np.asarray(rollout[4])[permute_idx], dtype=torch.float32).to(device)
    gaes = torch.tensor(np.asarray(rollout[3])[permute_idx], dtype=torch.float32).to(device)

    # Value data: discounted returns are the regression targets for the critic
    returns = discount_rewards(np.asarray(rollout[2]))[permute_idx]
    returns = torch.tensor(returns, dtype=torch.float32).to(device)

    # PPO updates: policy step, then value-function step
    ppo.update_policy(obs, actions, old_log_probs, gaes, returns)
    ppo.update_value(obs, returns)

    # Log the undiscounted episode return
    total_reward = sum(rollout[2])
    writer.add_scalar('Reward', total_reward, episode)
    print('Episode {} | Total Reward {:.1f}'.format(episode, total_reward))
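
# -----------------------------------------------------------------------------
# For context: a minimal sketch of what the imported `discount_rewards` helper
# is assumed to compute -- the standard discounted return
# G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ..., used above as the
# critic's regression target. This is an illustrative assumption, not the
# actual implementation in functions.py; the real helper's gamma value and
# signature may differ. The name `_discount_rewards_sketch` is hypothetical,
# chosen to avoid shadowing the imported function.
# -----------------------------------------------------------------------------
def _discount_rewards_sketch(rewards, gamma=0.99):
    """Discounted cumulative sum of a reward sequence (assumed behaviour)."""
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = 0.0
    # Walk the episode backwards, accumulating the discounted tail of rewards
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns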