54 lines
1.3 KiB
Python
54 lines
1.3 KiB
Python
|
import gym
|
||
|
import numpy as np
|
||
|
import torch.multiprocessing as _mp
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
import torch.nn.functional as F
|
||
|
from torchvision import transforms as T
|
||
|
|
||
|
from models_a3c import *
|
||
|
from mario_env import *
|
||
|
from optimizer import GlobalAdam
|
||
|
from icm_mario import train
|
||
|
from mario_env import create_mario_env
|
||
|
|
||
|
|
||
|
from torch.utils.tensorboard import SummaryWriter
|
||
|
# TensorBoard SummaryWriter created at import time with default arguments.
# Nothing in the visible part of this file writes to it; presumably other
# modules or later code use it -- TODO confirm before removing.
writer = SummaryWriter()
|
||
|
|
||
|
# Use the first CUDA device when torch reports CUDA as available; otherwise
# fall back to the CPU. Read by train_a3c() when placing the global models.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
def train_a3c(num_processes=1):
    """Launch A3C + ICM training workers and wait for them to finish.

    Seeds torch, builds the global ``ActorCritic`` and ``ICM`` models on
    ``device``, wraps both parameter sets in a single ``GlobalAdam``
    optimizer, then starts ``num_processes`` worker processes running
    ``train`` and joins them.

    Args:
        num_processes: Number of worker processes to start. Defaults to 1,
            which matches the previous hard-coded behaviour.

    Raises:
        ValueError: If ``num_processes`` is smaller than 1.
    """
    if num_processes < 1:
        raise ValueError("num_processes must be at least 1")

    torch.manual_seed(123)

    # The environment is only used here to read the size of the action space.
    env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
    num_actions = env.action_space.n

    # Workers are started with the "spawn" start method.
    mp = _mp.get_context("spawn")

    # Global models; share_memory() is called so that the worker processes
    # can be handed the same parameter storage.
    global_ac_model = ActorCritic(256, num_actions).to(device)
    global_ac_model.share_memory()
    global_icm_model = ICM(4, 256, num_actions).to(device)
    global_icm_model.share_memory()

    # One optimizer over the parameters of both global models.
    shared_params = (
        list(global_ac_model.parameters())
        + list(global_icm_model.parameters())
    )
    optimizer = GlobalAdam(shared_params, lr=1e-4)

    processes = []
    for rank in range(num_processes):
        p = mp.Process(
            target=train,
            args=(rank, optimizer, global_ac_model, global_icm_model),
        )
        p.start()
        processes.append(p)

    # Block until every worker has exited.
    for p in processes:
        p.join()
|
||
|
|
||
|
|
||
|
# Script entry point. The guard keeps train_a3c() from running again when
# this module is imported rather than executed directly.
if __name__ == "__main__":
    train_a3c()
|