Curiosity/a3c/icm_mario.py

91 lines
3.1 KiB
Python
Raw Permalink Normal View History

2023-01-31 14:58:50 +00:00
import gym
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T
from models_a3c import *
from mario_env import create_mario_env
from mario_env import *
from nes_py.wrappers import JoypadSpace
from torch.utils.tensorboard import SummaryWriter
# Module-level shared objects.
# NOTE(review): SummaryWriter() runs at import time, so every process that
# imports this module opens its own TensorBoard log directory (torch's default
# is under ./runs). train() does not reference `writer`.
writer = SummaryWriter()

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Loss criteria used by the training loop: cross-entropy for the ICM inverse
# loss, mean-squared error for the ICM forward loss / intrinsic reward.
ce = nn.CrossEntropyLoss().to(device)
mse = nn.MSELoss().to(device)
# Training
def _to_state_tensor(observation):
    """Return ``observation`` as a tensor on ``device``, adding a batch axis to 3-D input.

    Observations of any other rank are converted without reshaping.
    """
    array = np.asarray(observation)
    tensor = torch.tensor(array).to(device)
    return tensor.unsqueeze(0) if array.ndim == 3 else tensor


def _share_grads(local_model, global_model):
    """Point each global parameter's ``.grad`` at the matching local gradient.

    Call after ``loss.backward()`` and before ``optimizer.step()`` so the
    optimizer (which updates the global parameters) applies this worker's
    gradients. The assignment is unconditional because ``train`` builds a fresh
    local model on every call; an earlier call may have left the global grads
    pointing at a stale model's tensors.
    """
    for local_param, global_param in zip(local_model.parameters(), global_model.parameters()):
        global_param._grad = local_param.grad


def train(index, optimizer, global_ac, global_icm, beta=0.2, alpha=100, gamma=0.99, lamda=0.1, num_episodes=20000):
    """Run one A3C worker episode with ICM curiosity and push one update to the global models.

    The worker seeds torch with ``123 + index`` and creates its own Mario
    environment. It builds local copies of the actor-critic and ICM networks,
    initialised from the global weights, and plays one episode until ``done``.
    The loss of every step is accumulated, then a single optimizer step is
    applied using the local gradients.

    Args:
        index: Worker index; only used to derive the torch seed.
        optimizer: Optimizer that updates the global models. NOTE(review): assumed
            to hold the ``global_ac`` and ``global_icm`` parameters — confirm with caller.
        global_ac: Shared actor-critic. Its weights seed the local copy, and its
            ``.grad`` is pointed at the local gradients before the step.
        global_icm: Shared ICM model, handled the same way.
        beta: Weight of the ICM forward loss; the inverse loss gets ``1 - beta``.
        alpha: Scale applied to the forward-model error to form the intrinsic reward.
        gamma: Discount factor for the one-step TD target.
        lamda: Weight of the actor-critic loss relative to the ICM loss.
        num_episodes: Currently unused (the per-episode loop is disabled); kept
            so existing callers keep working.

    Note:
        The whole episode's computation graph stays alive until the single
        ``backward()`` call, so memory grows with episode length.
    """
    torch.manual_seed(123 + index)
    env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
    try:
        n_actions = env.action_space.n
        local_ac = ActorCritic(256, n_actions).to(device)
        local_icm = ICM(4, 256, n_actions).to(device)
        local_ac.load_state_dict(global_ac.state_dict())
        local_icm.load_state_dict(global_icm.state_dict())

        loss = 0
        observation = env.reset()
        done = False
        while not done:
            state = _to_state_tensor(observation)
            action_dist, value = local_ac(state)
            action = action_dist.sample()
            action_one_hot = F.one_hot(action, num_classes=n_actions).float()

            next_observation, reward, done, _ = env.step(action.item())
            next_state = _to_state_tensor(next_observation)

            _, next_encoded_state, action_pred, next_encoded_state_pred = local_icm(
                state, next_state, action_one_hot
            )

            # ICM forward loss: predict the next state's encoding (target detached).
            forward_loss = mse(next_encoded_state_pred, next_encoded_state.detach())
            # ICM inverse loss: recover the action actually taken from (s_t, s_{t+1}).
            # NOTE(review): assumes action_pred holds unnormalised logits of shape
            # (1, n_actions), as CrossEntropyLoss expects — confirm against ICM.
            inverse_loss = ce(action_pred, action)
            icm_loss = beta * forward_loss + (1 - beta) * inverse_loss

            # The curiosity bonus is a reward signal, not a training target for the
            # ICM. Detaching it keeps the actor-critic loss from pushing on the
            # forward model.
            intrinsic_reward = alpha * forward_loss.detach()
            total_reward = intrinsic_reward + float(reward)

            # Semi-gradient TD: the bootstrap value is a fixed target.
            with torch.no_grad():
                next_value = local_ac(next_state)[1] * (1 - done)
            delta = total_reward + gamma * next_value - value

            # The advantage is detached for the actor so the policy loss does not
            # train the critic; the critic is trained by the squared TD error.
            actor_loss = -action_dist.log_prob(action) * delta.detach()
            critic_loss = delta ** 2
            loss = loss + lamda * (actor_loss + critic_loss) + icm_loss

            observation = next_observation

        optimizer.zero_grad()
        loss.backward()
        # The gradients must exist (after backward) before they are handed to the
        # global models.
        _share_grads(local_ac, global_ac)
        _share_grads(local_icm, global_icm)
        optimizer.step()
    finally:
        env.close()
if __name__ == '__main__':
    # train() requires a worker index, an optimizer and the two global models,
    # so build a minimal single-worker setup that mirrors the model construction
    # inside train(). A throwaway env is opened only to read the action count.
    probe_env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
    try:
        n_actions = probe_env.action_space.n
    finally:
        probe_env.close()

    global_ac = ActorCritic(256, n_actions).to(device)
    global_icm = ICM(4, 256, n_actions).to(device)
    # NOTE(review): plain Adam with default hyper-parameters is a placeholder for
    # a single-process smoke run. A multi-process A3C launcher would normally use
    # shared-memory models and a shared optimizer instead.
    optimizer = torch.optim.Adam(
        list(global_ac.parameters()) + list(global_icm.parameters())
    )
    train(0, optimizer, global_ac, global_icm)