import gym
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T

from models_a3c import *
from mario_env import create_mario_env
from nes_py.wrappers import JoypadSpace

from torch.utils.tensorboard import SummaryWriter
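
# TensorBoard writer, available for logging training curves (no scalars are
# written in this snippet)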
writer = SummaryWriter()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Make environment
#env = gym.make('SuperMarioBros-1-1-v0')
#env = create_mario_env(env)

# Loss functions
ce = nn.CrossEntropyLoss().to(device)
mse = nn.MSELoss().to(device)


# Training
def train(index, optimizer, global_ac, global_icm, beta=0.2, alpha=100, gamma=0.99, lamda=0.1, num_episodes=20000):
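    """A3C + ICM worker (single update).

    Builds local copies of the shared actor-critic and ICM networks, rolls out one
    episode of Super Mario Bros, accumulates the actor-critic loss plus the
    curiosity (ICM) loss, and finally copies the local gradients into the global
    models. `beta` trades off the ICM forward vs. inverse losses, `alpha` scales
    the intrinsic reward, `gamma` is the discount factor, and `lamda` weights the
    actor-critic loss against the ICM loss; `num_episodes` is unused while the
    episode loop below stays commented out.
    """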
    torch.manual_seed(123 + index)

    env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
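
    # Local copies of the shared (global) actor-critic and ICM networks; this worker
    # computes gradients on the local copies and pushes them to the global models below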
    local_ac = ActorCritic(256, env.action_space.n).to(device)
    local_icm = ICM(4, 256, env.action_space.n).to(device)
    local_ac.load_state_dict(global_ac.state_dict())
    local_icm.load_state_dict(global_icm.state_dict())
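
    # Roll out a single episode and accumulate the loss over its steps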
    #for episode in range(num_episodes):
    loss = 0
    observation = env.reset()
    reward = 0
    done = False
    while not done:
        #env.render()
        # Add a batch dimension when the observation is a single 3-D frame stack
        state = torch.tensor(observation).to(device).unsqueeze(0) if observation.ndim == 3 else torch.tensor(observation).to(device)
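
        # Sample an action from the current policy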
        action_probs, value = local_ac(state)
        action = action_probs.sample()
        action_one_hot = F.one_hot(action, num_classes=env.action_space.n).float()

        next_observation, reward, done, info = env.step(action.item())
        next_state = torch.tensor(next_observation).to(device).unsqueeze(0) if next_observation.ndim == 3 else torch.tensor(next_observation).to(device)
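
        # ICM: encode both states, predict the taken action from the state pair
        # (inverse model) and the next-state embedding from state + action (forward model)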
        encoded_state, next_encoded_state, action_pred, next_encoded_state_pred = local_icm(state, next_state, action_one_hot)
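
        # Curiosity bonus: the forward model's prediction error, scaled by alpha,
        # is added to the environment reward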
        intrinsic_reward = alpha * mse(next_encoded_state_pred, next_encoded_state.detach())
        extrinsic_reward = torch.tensor(reward).to(device)
        reward = intrinsic_reward + extrinsic_reward

        # ICM loss: forward-model error plus inverse-model error; the inverse model
        # is scored against the action actually taken
        forward_loss = mse(next_encoded_state_pred, next_encoded_state.detach())
        inverse_loss = ce(action_pred, action)
        icm_loss = beta * forward_loss + (1 - beta) * inverse_loss
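
        # One-step actor-critic update: delta is the TD error; it scales the
        # policy-gradient (actor) term and, squared, gives the critic loss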
        delta = reward + gamma * (local_ac(next_state)[1] * (1 - done)) - value
        actor_loss = -(action_probs.log_prob(action) + 1e-8) * delta
        critic_loss = delta ** 2
        ac_loss = actor_loss + critic_loss

        # Total step loss: weighted actor-critic loss plus the ICM loss
        loss += lamda * ac_loss + icm_loss

        observation = next_observation

    optimizer.zero_grad()
    loss.backward()

    # Update global model: copy this worker's gradients into the shared global
    # parameters (skipped if another worker has already filled them), then step
    # the shared optimizer
    for local_param, global_param in zip(local_ac.parameters(), global_ac.parameters()):
        if global_param.grad is not None:
            break
        global_param._grad = local_param.grad
    for local_param, global_param in zip(local_icm.parameters(), global_icm.parameters()):
        if global_param.grad is not None:
            break
        global_param._grad = local_param.grad

    optimizer.step()

    env.close()


if __name__ == '__main__':
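    # Minimal single-worker driver (a sketch, not the original launch code): `train`
    # expects a worker index, a shared optimizer, and the global ActorCritic / ICM
    # models. Constructor arguments mirror the local copies above; the choice of
    # Adam and the learning rate are assumptions. A full A3C run would instead
    # spawn several worker processes that share these global models.
    tmp_env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
    global_ac = ActorCritic(256, tmp_env.action_space.n).to(device)
    global_icm = ICM(4, 256, tmp_env.action_space.n).to(device)
    tmp_env.close()
    optimizer = torch.optim.Adam(
        list(global_ac.parameters()) + list(global_icm.parameters()), lr=1e-4)
    train(0, optimizer, global_ac, global_icm)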