Curiosity/a3c/main.py
2023-01-31 15:58:50 +01:00

54 lines
1.3 KiB
Python

import gym
import numpy as np
import torch.multiprocessing as _mp
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T
from models_a3c import *
from mario_env import *
from optimizer import GlobalAdam
from icm_mario import train
from mario_env import create_mario_env
from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer for this script, created with default arguments.
# NOTE(review): this runs at import time. train_a3c starts its workers with the
# "spawn" start method, which re-imports the main module in each child, so each
# worker presumably constructs its own SummaryWriter as well -- confirm that
# this is intended.
writer = SummaryWriter()

# Place models on the first CUDA GPU when CUDA is available, else on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def train_a3c(num_processes=1, seed=123):
    """Train the A3C actor-critic together with the ICM module.

    Builds the global ``ActorCritic`` and ``ICM`` models on ``device`` and
    moves their parameters to shared memory. It then launches
    ``num_processes`` worker processes using the ``spawn`` start method. Each
    worker runs ``train(rank, optimizer, global_ac_model, global_icm_model)``
    with one ``GlobalAdam`` optimizer built over both models' parameters.
    Blocks until every worker has exited.

    Args:
        num_processes: Number of worker processes to start. The default of 1
            reproduces the original single-worker setup.
        seed: Value passed to ``torch.manual_seed`` before the models are
            created.

    Raises:
        ValueError: If ``num_processes`` is less than 1.
        RuntimeError: If any worker process exits with a non-zero exit code.
    """
    if num_processes < 1:
        raise ValueError(f"num_processes must be >= 1, got {num_processes}")

    torch.manual_seed(seed)

    # The parent process only needs the action-space size to build the global
    # models, so the environment is closed as soon as it has been read.
    env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
    num_actions = env.action_space.n
    env.close()

    ctx = _mp.get_context("spawn")

    global_ac_model = ActorCritic(256, num_actions).to(device)
    global_ac_model.share_memory()
    global_icm_model = ICM(4, 256, num_actions).to(device)
    global_icm_model.share_memory()

    # One optimizer over the parameters of both global models; the same
    # optimizer object is handed to every worker.
    optimizer = GlobalAdam(
        list(global_ac_model.parameters()) + list(global_icm_model.parameters()),
        lr=1e-4,
    )

    processes = []
    for rank in range(num_processes):
        p = ctx.Process(
            target=train,
            args=(rank, optimizer, global_ac_model, global_icm_model),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Surface worker crashes instead of returning as if training succeeded.
    failed = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        raise RuntimeError(
            f"training worker(s) {failed} exited with a non-zero exit code"
        )


if __name__ == "__main__":
    train_a3c()