import gym
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T
from models_a3c import *
from mario_env import create_mario_env
from nes_py.wrappers import JoypadSpace
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Make environment
#env = gym.make('SuperMarioBros-1-1-v0')
#env = create_mario_env(env)

# Loss functions
ce = nn.CrossEntropyLoss().to(device)
mse = nn.MSELoss().to(device)


# Training worker: runs episodes locally and pushes gradients to the global
# (shared) actor-critic and ICM networks.
def train(index, optimizer, global_ac, global_icm, beta=0.2, alpha=100,
          gamma=0.99, lamda=0.1, num_episodes=20000):
    torch.manual_seed(123 + index)
    env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')

    local_ac = ActorCritic(256, env.action_space.n).to(device)
    local_icm = ICM(4, 256, env.action_space.n).to(device)

    for episode in range(num_episodes):
        # Sync the local copies with the global networks before each rollout.
        local_ac.load_state_dict(global_ac.state_dict())
        local_icm.load_state_dict(global_icm.state_dict())

        loss = 0
        observation = env.reset()
        done = False

        while not done:
            #env.render()
            state = (torch.tensor(observation).to(device).unsqueeze(0)
                     if observation.ndim == 3 else torch.tensor(observation).to(device))

            # Sample an action from the current policy.
            action_probs, value = local_ac(state)
            action = action_probs.sample()
            action_one_hot = F.one_hot(action, num_classes=env.action_space.n).float()

            next_observation, reward, done, info = env.step(action.item())
            next_state = (torch.tensor(next_observation).to(device).unsqueeze(0)
                          if next_observation.ndim == 3 else torch.tensor(next_observation).to(device))

            # ICM: the curiosity (intrinsic) reward is the forward-model prediction error.
            encoded_state, next_encoded_state, action_pred, next_encoded_state_pred = \
                local_icm(state, next_state, action_one_hot)
            intrinsic_reward = alpha * mse(next_encoded_state_pred, next_encoded_state.detach())
            extrinsic_reward = torch.tensor(reward).to(device)
            # The intrinsic reward is a reward signal, not a training target,
            # so detach it before it enters the TD error.
            reward = intrinsic_reward.detach() + extrinsic_reward

            # ICM losses: the forward model predicts the next feature encoding,
            # the inverse model predicts the action actually taken (assuming
            # `action_pred` holds raw logits over the action space).
            forward_loss = mse(next_encoded_state_pred, next_encoded_state.detach())
            inverse_loss = ce(action_pred, action)
            icm_loss = beta * forward_loss + (1 - beta) * inverse_loss

            # One-step actor-critic losses; the bootstrapped next value and the
            # TD error used for the policy gradient are detached.
            delta = reward + gamma * local_ac(next_state)[1].detach() * (1 - int(done)) - value
            actor_loss = -action_probs.log_prob(action) * delta.detach()
            critic_loss = delta ** 2
            ac_loss = actor_loss + critic_loss

            loss += lamda * ac_loss + icm_loss
            observation = next_observation

        optimizer.zero_grad()
        loss.backward()

        # Push the worker's gradients to the global models. The global grads
        # only need to be pointed at the worker's grad tensors once; autograd
        # reuses those tensors on later backward passes.
        for local_param, global_param in zip(local_ac.parameters(), global_ac.parameters()):
            if global_param.grad is not None:
                break
            global_param._grad = local_param.grad
        for local_param, global_param in zip(local_icm.parameters(), global_icm.parameters()):
            if global_param.grad is not None:
                break
            global_param._grad = local_param.grad

        optimizer.step()

    env.close()


if __name__ == '__main__':
    # Minimal single-worker launch (a sketch): the original called train()
    # without arguments. The global networks and the plain Adam optimizer with
    # lr=1e-4 below are assumptions; a full A3C run would spawn several
    # workers (see the sketch at the end of the file).
    probe_env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
    n_actions = probe_env.action_space.n
    probe_env.close()

    global_ac = ActorCritic(256, n_actions).to(device)
    global_icm = ICM(4, 256, n_actions).to(device)
    optimizer = torch.optim.Adam(
        list(global_ac.parameters()) + list(global_icm.parameters()), lr=1e-4
    )

    train(0, optimizer, global_ac, global_icm)
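

# ---------------------------------------------------------------------------
# Optional multi-worker launch (hedged sketch, not called by default).
# A3C normally runs several copies of `train` in parallel processes that all
# update the same shared global networks. The helper below is one possible
# way to wire that up; the worker count, the learning rate, the use of a
# plain (non-shared) Adam optimizer, and the name `launch_workers` are
# assumptions rather than part of the original script. It also assumes the
# module-level `device` resolves to the CPU, since mixing CUDA worker models
# with CPU shared globals would need extra care.
# ---------------------------------------------------------------------------
import torch.multiprocessing as mp


def launch_workers(num_workers=4, lr=1e-4):
    """Spawn `num_workers` A3C/ICM workers sharing one pair of global models."""
    probe_env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
    n_actions = probe_env.action_space.n
    probe_env.close()

    global_ac = ActorCritic(256, n_actions)
    global_icm = ICM(4, 256, n_actions)
    # Place the global parameters in shared memory so every worker process
    # reads and writes the same weights.
    global_ac.share_memory()
    global_icm.share_memory()

    optimizer = torch.optim.Adam(
        list(global_ac.parameters()) + list(global_icm.parameters()), lr=lr
    )

    workers = [
        mp.Process(target=train, args=(idx, optimizer, global_ac, global_icm))
        for idx in range(num_workers)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()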