A3C+ICM
parent c3f6e9f281
commit 000b970a12

a3c/icm_mario.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import gym
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T

from models_a3c import *
from mario_env import create_mario_env
from nes_py.wrappers import JoypadSpace

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Make environment
#env = gym.make('SuperMarioBros-1-1-v0')
#env = create_mario_env(env)

# Loss functions
ce = nn.CrossEntropyLoss().to(device)
mse = nn.MSELoss().to(device)


# Training worker: keeps local copies of the actor-critic and ICM, runs one
# episode, then pushes its gradients into the shared (global) models.
def train(index, optimizer, global_ac, global_icm, beta=0.2, alpha=100, gamma=0.99, lamda=0.1, num_episodes=20000):
    torch.manual_seed(123 + index)

    env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')

    local_ac = ActorCritic(256, env.action_space.n).to(device)
    local_icm = ICM(4, 256, env.action_space.n).to(device)
    local_ac.load_state_dict(global_ac.state_dict())
    local_icm.load_state_dict(global_icm.state_dict())

    #for episode in range(num_episodes):
    loss = 0
    observation = env.reset()
    reward = 0
    done = False
    while not done:
        #env.render()
        state = torch.tensor(observation).float().to(device)
        if state.dim() == 3:
            state = state.unsqueeze(0)

        action_probs, value = local_ac(state)
        action = action_probs.sample()
        action_one_hot = F.one_hot(action, num_classes=env.action_space.n).float()

        next_observation, reward, done, info = env.step(action.item())
        next_state = torch.tensor(next_observation).float().to(device)
        if next_state.dim() == 3:
            next_state = next_state.unsqueeze(0)

        encoded_state, next_encoded_state, action_pred, next_encoded_state_pred = local_icm(state, next_state, action_one_hot)

        # Curiosity bonus: forward-model prediction error, used as a reward signal (not backpropagated)
        intrinsic_reward = alpha * mse(next_encoded_state_pred, next_encoded_state.detach()).detach()
        extrinsic_reward = torch.tensor(reward).to(device)
        reward = intrinsic_reward + extrinsic_reward

        forward_loss = mse(next_encoded_state_pred, next_encoded_state.detach())
        # Inverse loss: the inverse model should predict the action that was actually taken
        inverse_loss = ce(action_pred, action)
        icm_loss = beta * forward_loss + (1 - beta) * inverse_loss

        # One-step actor-critic update
        delta = reward + gamma * local_ac(next_state)[1].detach() * (1 - int(done)) - value
        actor_loss = -action_probs.log_prob(action) * delta.detach()
        critic_loss = delta ** 2
        ac_loss = (actor_loss + critic_loss).sum()  # reduce to a scalar so backward() works

        loss += lamda * ac_loss + icm_loss

        observation = next_observation

    optimizer.zero_grad()
    loss.backward()

    # Update global models: point the shared gradients at the local gradients
    for local_param, global_param in zip(local_ac.parameters(), global_ac.parameters()):
        if global_param.grad is not None:
            break
        global_param._grad = local_param.grad
    for local_param, global_param in zip(local_icm.parameters(), global_icm.parameters()):
        if global_param.grad is not None:
            break
        global_param._grad = local_param.grad

    optimizer.step()

    env.close()


if __name__ == '__main__':
    train()
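The train() call under the __main__ guard needs the shared models and optimizer that main.py builds; a standalone single-worker run would look roughly like the following sketch, which mirrors main.py rather than anything defined in this file:

from models_a3c import ActorCritic, ICM
from optimizer import GlobalAdam
from mario_env import create_mario_env
from icm_mario import train, device

env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
num_actions = env.action_space.n
env.close()

global_ac = ActorCritic(256, num_actions).to(device)
global_icm = ICM(4, 256, num_actions).to(device)
optimizer = GlobalAdam(list(global_ac.parameters()) + list(global_icm.parameters()), lr=1e-4)

train(0, optimizer, global_ac, global_icm)  # runs a single episode in this process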

a3c/main.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import gym
import numpy as np
import torch.multiprocessing as _mp

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T

from models_a3c import *
from mario_env import create_mario_env
from optimizer import GlobalAdam
from icm_mario import train


from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def train_a3c():
    torch.manual_seed(123)
    #env = gym.make('SuperMarioBros-1-1-v0')
    env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
    mp = _mp.get_context("spawn")

    # Global models live in shared memory so worker processes can update them
    global_ac_model = ActorCritic(256, env.action_space.n).to(device)
    global_ac_model.share_memory()
    global_icm_model = ICM(4, 256, env.action_space.n).to(device)
    global_icm_model.share_memory()

    # One shared optimizer over both the actor-critic and the ICM parameters
    optimizer = GlobalAdam(list(global_ac_model.parameters()) + list(global_icm_model.parameters()), lr=1e-4)
    processes = []

    counter = mp.Value('i', 0)  # shared counter and lock (not yet used by the worker)
    lock = mp.Lock()

    for rank in range(0, 1):
        p = mp.Process(target=train, args=(rank, optimizer, global_ac_model, global_icm_model))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


if __name__ == "__main__":
    train_a3c()
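The loop above spawns a single worker (range(0, 1)). A sketch of how the same loop would scale out to several A3C workers, each seeded by its rank (num_workers is an illustrative name, not defined in this commit):

    num_workers = 4  # illustrative value
    for rank in range(num_workers):
        p = mp.Process(target=train, args=(rank, optimizer, global_ac_model, global_icm_model))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()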

a3c/mario_env.py (new file, 191 lines)
@@ -0,0 +1,191 @@
import numpy as np
from collections import deque
import gym
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT, RIGHT_ONLY
SIMPLE_MOVEMENT = SIMPLE_MOVEMENT[1:]
from gym import spaces
from PIL import Image
import cv2

PALETTE_ACTIONS = [
    ['up'],
    ['down'],
    ['left'],
    ['left', 'A'],
    ['left', 'B'],
    ['left', 'A', 'B'],
    ['right'],
    ['right', 'A'],
    ['right', 'B'],
    ['right', 'A', 'B'],
    ['A'],
    ['B'],
    ['A', 'B']
]

def _process_frame_mario(frame):
    if frame is not None:  # for future meta implementation
        img = np.reshape(frame, [240, 256, 3]).astype(np.float32)
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        x_t = cv2.resize(img, (84, 84))
        x_t = np.reshape(x_t, [1, 84, 84]) / 255.0
        #x_t.astype(np.uint8)
    else:
        x_t = np.zeros((1, 84, 84))
    return x_t


class ProcessFrameMario(gym.Wrapper):
    def __init__(self, env=None, reward_type=None):
        super(ProcessFrameMario, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(1, 84, 84), dtype=np.uint8)
        self.prev_time = 400
        self.prev_stat = 0
        self.prev_score = 0
        self.prev_dist = 40
        self.reward_type = reward_type
        self.milestones = [i for i in range(150, 3150, 150)]
        self.counter = 0

    def step(self, action):
        '''
        Implementing custom rewards
            Time = -0.1
            Distance = +1 or 0
            Player Status = +/- 5
            Score = 2.5 x [Increase in Score]
            Done = +50 [Game Completed] or -50 [Game Incomplete]
        '''
        obs, _, done, info = self.env.step(action)

        if self.reward_type == 'sparse':
            reward = 0
            if (self.counter < len(self.milestones)) and (info['x_pos'] > self.milestones[self.counter]):
                reward = 10
                self.counter = self.counter + 1

            if done:
                if info['flag_get']:
                    reward = 50
                else:
                    reward = -10

        elif self.reward_type == 'dense':

            reward = max(min((info['x_pos'] - self.prev_dist - 0.05), 2), -2)
            self.prev_dist = info['x_pos']

            reward += (self.prev_time - info['time']) * -0.1
            self.prev_time = info['time']

            reward += (int(info['status'] != 'small') - self.prev_stat) * 5
            self.prev_stat = int(info['status'] != 'small')

            reward += (info['score'] - self.prev_score) * 0.025
            self.prev_score = info['score']

            if done:
                if info['flag_get']:
                    reward += 500
                else:
                    reward -= 50

        else:
            return None

        return _process_frame_mario(obs), reward / 10, done, info

    def reset(self):
        self.prev_time = 400
        self.prev_stat = 0
        self.prev_score = 0
        self.prev_dist = 40
        self.counter = 0
        return _process_frame_mario(self.env.reset())

    def change_level(self, level):
        self.env.change_level(level)


class BufferSkipFrames(gym.Wrapper):
    def __init__(self, env=None, skip=4, shape=(84, 84)):
        super(BufferSkipFrames, self).__init__(env)
        self.counter = 0
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(4, 84, 84), dtype=np.uint8)
        self.skip = skip
        self.buffer = deque(maxlen=self.skip)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        counter = 1
        total_reward = reward
        self.buffer.append(obs)

        for i in range(self.skip - 1):
            if not done:
                obs, reward, done, info = self.env.step(action)
                total_reward += reward
                counter += 1
                self.buffer.append(obs)
            else:
                self.buffer.append(obs)

        frame = np.stack(self.buffer, axis=0)
        frame = np.reshape(frame, (4, 84, 84))
        return frame, total_reward, done, info

    def reset(self):
        self.buffer.clear()
        obs = self.env.reset()
        for i in range(self.skip):
            self.buffer.append(obs)

        frame = np.stack(self.buffer, axis=0)
        frame = np.reshape(frame, (4, 84, 84))
        return frame

    def change_level(self, level):
        self.env.change_level(level)


class NormalizedEnv(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(NormalizedEnv, self).__init__(env)
        self.state_mean = 0
        self.state_std = 0
        self.alpha = 0.9999
        self.num_steps = 0

    def observation(self, observation):
        if observation is not None:  # for future meta implementation
            self.num_steps += 1
            self.state_mean = self.state_mean * self.alpha + \
                observation.mean() * (1 - self.alpha)
            self.state_std = self.state_std * self.alpha + \
                observation.std() * (1 - self.alpha)

            unbiased_mean = self.state_mean / (1 - pow(self.alpha, self.num_steps))
            unbiased_std = self.state_std / (1 - pow(self.alpha, self.num_steps))

            return (observation - unbiased_mean) / (unbiased_std + 1e-8)
        else:
            return observation

    def change_level(self, level):
        self.env.change_level(level)


def wrap_mario(env, reward_type):
    # assert 'SuperMarioBros' in env.spec.id
    env = ProcessFrameMario(env, reward_type)
    env = NormalizedEnv(env)
    env = BufferSkipFrames(env)
    return env


def create_mario_env(env_id, reward_type):
    env = gym_super_mario_bros.make(env_id)
    env = JoypadSpace(env, PALETTE_ACTIONS)
    env = wrap_mario(env, reward_type)
    return env
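A quick way to sanity-check the wrapper stack is to build the environment and look at one transition; observations come out as a 4 x 84 x 84 stack of normalized grayscale frames (this assumes gym-super-mario-bros and nes-py are installed):

    from mario_env import create_mario_env

    env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
    obs = env.reset()
    print(obs.shape)  # (4, 84, 84)
    obs, reward, done, info = env.step(env.action_space.sample())
    print(reward, done, info['x_pos'])
    env.close()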

a3c/models_a3c.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.distributions import Categorical

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ICM(nn.Module):
    def __init__(self, channels, encoded_state_size, action_size):
        super(ICM, self).__init__()
        self.channels = channels
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size

        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
        ).to(device)

        self.inverse_model = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size*2, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.action_size),
            nn.Softmax(dim=-1)
        ).to(device)

        self.forward_model = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size+self.action_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.encoded_state_size),
        ).to(device)

    def forward(self, state, next_state, action):
        if state.dim() == 3:
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)

        encoded_state = self.feature_encoder(state)
        next_encoded_state = self.feature_encoder(next_state)
        action_pred = self.inverse_model(torch.cat((encoded_state, next_encoded_state), dim=-1))
        next_encoded_state_pred = self.forward_model(torch.cat((encoded_state, action), dim=-1))
        return encoded_state, next_encoded_state, action_pred, next_encoded_state_pred

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)


class ActorCritic(nn.Module):
    def __init__(self, encoded_state_size, action_size, state_size=4):
        super(ActorCritic, self).__init__()
        self.channels = state_size
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size

        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
        ).to(device)

        # Register the policy and value heads once, in __init__, so their parameters
        # belong to the module and get optimized (building them inside a method on
        # every call would re-initialize them each step and they would never train).
        self.actor = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.action_size),
            nn.Softmax(dim=-1)
        ).to(device)

        self.critic = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=1),
        ).to(device)

    def forward(self, state):
        if state.dim() == 3:
            state = state.unsqueeze(0)

        state = self.feature_encoder(state)
        value = self.critic(state)
        policy = self.actor(state)
        actions = Categorical(policy)
        return actions, value

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)


#vec = torch.randn(4,84,84).to(device)
#ac = ActorCritic(256,12).to(device)
#a,v = ac(vec)
#print(a,v)
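The commented-out lines above sketch a shape check; a runnable version of that smoke test for both networks might look like this (assuming the module is importable as models_a3c):

    import torch
    from models_a3c import ActorCritic, ICM, device

    state = torch.randn(4, 84, 84).to(device)
    next_state = torch.randn(4, 84, 84).to(device)

    ac = ActorCritic(256, 12).to(device)
    dist, value = ac(state)
    print(dist.probs.shape, value.shape)  # torch.Size([1, 12]) torch.Size([1, 1])

    icm = ICM(4, 256, 12).to(device)
    action = torch.zeros(1, 12).to(device)
    action[0, 3] = 1.0  # one-hot action; index chosen arbitrarily
    phi, phi_next, a_pred, phi_next_pred = icm(state, next_state, action)
    print(phi.shape, a_pred.shape, phi_next_pred.shape)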

a3c/optimizer.py (new file, 14 lines)
@@ -0,0 +1,14 @@
import torch

class GlobalAdam(torch.optim.Adam):
    def __init__(self, params, lr):
        super(GlobalAdam, self).__init__(params, lr=lr)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)

                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()
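GlobalAdam pre-allocates the Adam moment buffers and moves them into shared memory; stock Adam creates exp_avg / exp_avg_sq lazily on the first step, which would leave each spawned worker with private copies. A small usage sketch (the Linear model is a stand-in, not part of the commit):

    import torch.nn as nn
    from optimizer import GlobalAdam

    model = nn.Linear(4, 2)
    model.share_memory()
    opt = GlobalAdam(model.parameters(), lr=1e-4)
    # the optimizer state tensors now live in shared memory and can be
    # handed to mp.Process workers alongside the shared model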