"""Launcher for A3C + ICM training on Super Mario Bros.

Builds the shared global models and optimizer in the parent process, then
starts worker processes that each run ``icm_mario.train``.
"""

# NOTE: import order is kept exactly as in the original file. The two wildcard
# imports below can shadow earlier names, so reordering could silently change
# which object a name refers to.
import gym
import numpy as np
import torch.multiprocessing as _mp
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T
from models_a3c import *
from mario_env import *
from optimizer import GlobalAdam
from icm_mario import train
from mario_env import create_mario_env
from torch.utils.tensorboard import SummaryWriter

# NOTE(review): this writer is created at import time and is not used anywhere
# in this file. With the "spawn" start method each worker re-imports the main
# module, so this line presumably runs once per worker as well -- confirm
# whether another module relies on `writer` before moving it into a function.
writer = SummaryWriter()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def train_a3c(num_processes=1):
    """Create the shared models/optimizer and run the training workers.

    A temporary environment is created only to read the size of its action
    space and is closed again before any worker starts. The global
    ``ActorCritic`` and ``ICM`` models are moved to ``device`` and put in
    shared memory, and a single ``GlobalAdam`` optimizer is built over the
    parameters of both. Each worker runs
    ``train(rank, optimizer, global_ac_model, global_icm_model)``; this
    function blocks until every worker has exited.

    Args:
        num_processes: Number of worker processes to start. Defaults to 1,
            which matches the previously hard-coded behavior.

    Raises:
        ValueError: If ``num_processes`` is smaller than 1.
    """
    if num_processes < 1:
        raise ValueError(f"num_processes must be >= 1, got {num_processes}")

    torch.manual_seed(123)

    # The parent only needs the action count; release the environment right
    # away instead of keeping it open for the whole training run.
    env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
    try:
        num_actions = env.action_space.n
    finally:
        env.close()

    mp = _mp.get_context("spawn")

    # 256 and 4 are the constructor arguments used by the original script;
    # presumably feature size and input channels -- verify in models_a3c.
    global_ac_model = ActorCritic(256, num_actions).to(device)
    global_ac_model.share_memory()
    global_icm_model = ICM(4, 256, num_actions).to(device)
    global_icm_model.share_memory()

    # One optimizer over both models so all workers update the same state.
    shared_params = (
        list(global_ac_model.parameters())
        + list(global_icm_model.parameters())
    )
    optimizer = GlobalAdam(shared_params, lr=1e-4)

    processes = []
    for rank in range(num_processes):
        worker = mp.Process(
            target=train,
            args=(rank, optimizer, global_ac_model, global_icm_model),
        )
        worker.start()
        processes.append(worker)

    for worker in processes:
        worker.join()


if __name__ == "__main__":
    train_a3c()