diff --git a/a3c/icm_mario.py b/a3c/icm_mario.py
new file mode 100644
index 0000000..a3f7164
--- /dev/null
+++ b/a3c/icm_mario.py
@@ -0,0 +1,91 @@
+import numpy as np
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models_a3c import ActorCritic, ICM
+from mario_env import create_mario_env
+
+from torch.utils.tensorboard import SummaryWriter
+writer = SummaryWriter()
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+# Make environment
+#env = gym.make('SuperMarioBros-1-1-v0')
+#env = create_mario_env(env)
+
+# Loss functions
+ce = nn.CrossEntropyLoss().to(device)
+mse = nn.MSELoss().to(device)
+
+
+# Training loop for a single A3C worker
+def train(index, optimizer, global_ac, global_icm, beta=0.2, alpha=100, gamma=0.99, lamda=0.1, num_episodes=20000):
+    torch.manual_seed(123 + index)
+
+    env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
+
+    local_ac = ActorCritic(256, env.action_space.n).to(device)
+    local_icm = ICM(4, 256, env.action_space.n).to(device)
+    local_ac.load_state_dict(global_ac.state_dict())
+    local_icm.load_state_dict(global_icm.state_dict())
+
+    #for episode in range(num_episodes):
+    loss = 0
+    observation = env.reset()
+    done = False
+    while not done:
+        #env.render()
+        # Frames from BufferSkipFrames arrive as (4, 84, 84); add a batch dimension.
+        state = torch.from_numpy(observation).float().unsqueeze(0).to(device)
+        action_probs, value = local_ac(state)
+        action = action_probs.sample()
+        action_one_hot = F.one_hot(action, num_classes=env.action_space.n).float()
+
+        next_observation, reward, done, info = env.step(action.item())
+        next_state = torch.from_numpy(next_observation).float().unsqueeze(0).to(device)
+
+        encoded_state, next_encoded_state, action_pred, next_encoded_state_pred = local_icm(state, next_state, action_one_hot)
+
+        # Curiosity bonus: forward-model prediction error, detached so it only acts as a reward signal.
+        intrinsic_reward = alpha * mse(next_encoded_state_pred, next_encoded_state).detach()
+        extrinsic_reward = torch.tensor(reward, dtype=torch.float32).to(device)
+        total_reward = intrinsic_reward + extrinsic_reward
+
+        forward_loss = mse(next_encoded_state_pred, next_encoded_state.detach())
+        # The inverse model is trained to recover the action that was actually taken.
+        inverse_loss = ce(action_pred, action)
+        icm_loss = beta * forward_loss + (1 - beta) * inverse_loss
+
+        # One-step TD error; the bootstrap value is detached so only the current value estimate is regressed.
+        delta = total_reward + gamma * local_ac(next_state)[1].detach() * (1 - int(done)) - value
+        actor_loss = -action_probs.log_prob(action) * delta.detach()
+        critic_loss = delta ** 2
+        ac_loss = actor_loss + critic_loss
+
+        loss += lamda * ac_loss + icm_loss
+
+        observation = next_observation
+
+    optimizer.zero_grad()
+    loss.backward()
+
+    # Copy the local gradients into the shared global models before the optimizer step.
+    for local_param, global_param in zip(local_ac.parameters(), global_ac.parameters()):
+        if global_param.grad is not None:
+            break
+        global_param._grad = local_param.grad
+    for local_param, global_param in zip(local_icm.parameters(), global_icm.parameters()):
+        if global_param.grad is not None:
+            break
+        global_param._grad = local_param.grad
+
+    optimizer.step()
+
+    env.close()
+
+
+if __name__ == '__main__':
+    # Workers are spawned from main.py, which supplies the shared models and optimizer.
+    pass
\ No newline at end of file
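A minimal, self-contained sketch of how the per-step ICM terms used in train() combine, with random tensors standing in for the encoder outputs (256-d features and, purely for illustration, a 12-way action space); the constants mirror the defaults beta=0.2 and alpha=100 above.

import torch
import torch.nn.functional as F

phi_next = torch.randn(1, 256)        # stand-in for phi(s') from the feature encoder
phi_next_pred = torch.randn(1, 256)   # stand-in for the forward-model prediction of phi(s')
action_logits = torch.randn(1, 12)    # stand-in for the inverse-model output (logits)
action = torch.tensor([3])            # the action actually taken

forward_loss = F.mse_loss(phi_next_pred, phi_next)
inverse_loss = F.cross_entropy(action_logits, action)
icm_loss = 0.2 * forward_loss + (1 - 0.2) * inverse_loss
intrinsic_reward = 100 * forward_loss.detach()   # curiosity bonus added to the TD target
print(icm_loss.item(), intrinsic_reward.item())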
diff --git a/a3c/main.py b/a3c/main.py
new file mode 100644
index 0000000..2d2eb2d
--- /dev/null
+++ b/a3c/main.py
@@ -0,0 +1,54 @@
+import torch
+import torch.multiprocessing as _mp
+
+from models_a3c import ActorCritic, ICM
+from mario_env import create_mario_env
+from optimizer import GlobalAdam
+from icm_mario import train
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+def train_a3c(num_workers=1):
+    torch.manual_seed(123)
+    #env = gym.make('SuperMarioBros-1-1-v0')
+    env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
+    mp = _mp.get_context("spawn")
+
+    # Shared global models; share_memory() lets every worker process update the same tensors.
+    global_ac_model = ActorCritic(256, env.action_space.n).to(device)
+    global_ac_model.share_memory()
+    global_icm_model = ICM(4, 256, env.action_space.n).to(device)
+    global_icm_model.share_memory()
+
+    optimizer = GlobalAdam(list(global_ac_model.parameters()) + list(global_icm_model.parameters()), lr=1e-4)
+
+    counter = mp.Value('i', 0)   # shared counter and lock (currently unused)
+    lock = mp.Lock()
+
+    processes = []
+    for rank in range(num_workers):
+        p = mp.Process(target=train, args=(rank, optimizer, global_ac_model, global_icm_model))
+        p.start()
+        processes.append(p)
+    for p in processes:
+        p.join()
+
+
+if __name__ == "__main__":
+    train_a3c()
\ No newline at end of file
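An illustrative sketch of the sharing pattern train_a3c() relies on, with a toy nn.Linear standing in for the ActorCritic and ICM models: share_memory() places the global parameters in shared memory, and each spawned worker hands its gradients to the shared model the same way icm_mario.train does.

import torch
from torch import nn
import torch.multiprocessing as _mp


def worker(rank, shared_model):
    # Local copy of the shared weights, as in icm_mario.train.
    local_model = nn.Linear(4, 2)
    local_model.load_state_dict(shared_model.state_dict())
    loss = local_model(torch.randn(1, 4)).sum()
    loss.backward()
    # Same gradient hand-off used by the worker above.
    for local_p, shared_p in zip(local_model.parameters(), shared_model.parameters()):
        shared_p._grad = local_p.grad


if __name__ == "__main__":
    mp = _mp.get_context("spawn")
    shared_model = nn.Linear(4, 2)
    shared_model.share_memory()
    procs = [mp.Process(target=worker, args=(rank, shared_model)) for rank in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()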
diff --git a/a3c/mario_env.py b/a3c/mario_env.py
new file mode 100644
index 0000000..db47c24
--- /dev/null
+++ b/a3c/mario_env.py
@@ -0,0 +1,191 @@
+import numpy as np
+from collections import deque
+import gym
+import gym_super_mario_bros
+from nes_py.wrappers import JoypadSpace
+from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT, RIGHT_ONLY
+SIMPLE_MOVEMENT = SIMPLE_MOVEMENT[1:]
+from gym import spaces
+from PIL import Image
+import cv2
+
+PALETTE_ACTIONS = [
+    ['up'],
+    ['down'],
+    ['left'],
+    ['left', 'A'],
+    ['left', 'B'],
+    ['left', 'A', 'B'],
+    ['right'],
+    ['right', 'A'],
+    ['right', 'B'],
+    ['right', 'A', 'B'],
+    ['A'],
+    ['B'],
+    ['A', 'B']
+]
+
+
+def _process_frame_mario(frame):
+    if frame is not None:  # for future meta implementation
+        img = np.reshape(frame, [240, 256, 3]).astype(np.float32)
+        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
+        x_t = cv2.resize(img, (84, 84))
+        x_t = np.reshape(x_t, [1, 84, 84]) / 255.0
+        #x_t.astype(np.uint8)
+    else:
+        x_t = np.zeros((1, 84, 84))
+    return x_t
+
+
+class ProcessFrameMario(gym.Wrapper):
+    def __init__(self, env=None, reward_type=None):
+        super(ProcessFrameMario, self).__init__(env)
+        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(1, 84, 84), dtype=np.uint8)
+        self.prev_time = 400
+        self.prev_stat = 0
+        self.prev_score = 0
+        self.prev_dist = 40
+        self.reward_type = reward_type
+        self.milestones = [i for i in range(150, 3150, 150)]
+        self.counter = 0
+
+    def step(self, action):
+        '''
+        Custom rewards (before the final division by 10):
+        Sparse: +10 per 150-pixel milestone, +50 on reaching the flag, -10 otherwise at episode end
+        Dense:  distance delta clipped to [-2, +2]
+                -0.1 per in-game second elapsed
+                +/-5 on power-up status change
+                +0.025 x increase in score
+                +500 on reaching the flag, -50 otherwise at episode end
+        '''
+        obs, _, done, info = self.env.step(action)
+
+        if self.reward_type == 'sparse':
+            reward = 0
+            if (self.counter < len(self.milestones)) and (info['x_pos'] > self.milestones[self.counter]):
+                reward = 10
+                self.counter = self.counter + 1
+
+            if done:
+                if info['flag_get']:
+                    reward = 50
+                else:
+                    reward = -10
+
+        elif self.reward_type == 'dense':
+            reward = max(min((info['x_pos'] - self.prev_dist - 0.05), 2), -2)
+            self.prev_dist = info['x_pos']
+
+            reward += (self.prev_time - info['time']) * -0.1
+            self.prev_time = info['time']
+
+            reward += (int(info['status'] != 'small') - self.prev_stat) * 5
+            self.prev_stat = int(info['status'] != 'small')
+
+            reward += (info['score'] - self.prev_score) * 0.025
+            self.prev_score = info['score']
+
+            if done:
+                if info['flag_get']:
+                    reward += 500
+                else:
+                    reward -= 50
+
+        else:
+            raise ValueError("reward_type must be 'sparse' or 'dense'")
+
+        return _process_frame_mario(obs), reward / 10, done, info
+
+    def reset(self):
+        self.prev_time = 400
+        self.prev_stat = 0
+        self.prev_score = 0
+        self.prev_dist = 40
+        self.counter = 0
+        return _process_frame_mario(self.env.reset())
+
+    def change_level(self, level):
+        self.env.change_level(level)
+
+
+class BufferSkipFrames(gym.Wrapper):
+    def __init__(self, env=None, skip=4, shape=(84, 84)):
+        super(BufferSkipFrames, self).__init__(env)
+        self.counter = 0
+        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(4, 84, 84), dtype=np.uint8)
+        self.skip = skip
+        self.buffer = deque(maxlen=self.skip)
+
+    def step(self, action):
+        obs, reward, done, info = self.env.step(action)
+        counter = 1
+        total_reward = reward
+        self.buffer.append(obs)
+
+        for i in range(self.skip - 1):
+            if not done:
+                obs, reward, done, info = self.env.step(action)
+                total_reward += reward
+                counter += 1
+                self.buffer.append(obs)
+            else:
+                self.buffer.append(obs)
+
+        frame = np.stack(self.buffer, axis=0)
+        frame = np.reshape(frame, (4, 84, 84))
+        return frame, total_reward, done, info
+
+    def reset(self):
+        self.buffer.clear()
+        obs = self.env.reset()
+        for i in range(self.skip):
+            self.buffer.append(obs)
+
+        frame = np.stack(self.buffer, axis=0)
+        frame = np.reshape(frame, (4, 84, 84))
+        return frame
+
+    def change_level(self, level):
+        self.env.change_level(level)
+
+
+class NormalizedEnv(gym.ObservationWrapper):
+    def __init__(self, env=None):
+        super(NormalizedEnv, self).__init__(env)
+        self.state_mean = 0
+        self.state_std = 0
+        self.alpha = 0.9999
+        self.num_steps = 0
+
+    def observation(self, observation):
+        if observation is not None:  # for future meta implementation
+            self.num_steps += 1
+            self.state_mean = self.state_mean * self.alpha + \
+                observation.mean() * (1 - self.alpha)
+            self.state_std = self.state_std * self.alpha + \
+                observation.std() * (1 - self.alpha)
+
+            unbiased_mean = self.state_mean / (1 - pow(self.alpha, self.num_steps))
+            unbiased_std = self.state_std / (1 - pow(self.alpha, self.num_steps))
+
+            return (observation - unbiased_mean) / (unbiased_std + 1e-8)
+        else:
+            return observation
+
+    def change_level(self, level):
+        self.env.change_level(level)
+
+
+def wrap_mario(env, reward_type):
+    # assert 'SuperMarioBros' in env.spec.id
+    env = ProcessFrameMario(env, reward_type)
+    env = NormalizedEnv(env)
+    env = BufferSkipFrames(env)
+    return env
+
+
+def create_mario_env(env_id, reward_type):
+    env = gym_super_mario_bros.make(env_id)
+    env = JoypadSpace(env, PALETTE_ACTIONS)
+    env = wrap_mario(env, reward_type)
+    return env
\ No newline at end of file
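A quick usage sketch of the wrapper stack above (assuming gym-super-mario-bros and nes-py are installed): observations come out as 4 stacked, normalized 84x84 frames and rewards are the shaped values divided by 10.

from mario_env import create_mario_env

env = create_mario_env('SuperMarioBros-1-1-v0', reward_type='dense')
obs = env.reset()
print(obs.shape)                      # (4, 84, 84)
obs, reward, done, info = env.step(env.action_space.sample())
print(reward, done, info['x_pos'])
env.close()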
diff --git a/a3c/models_a3c.py b/a3c/models_a3c.py
new file mode 100644
index 0000000..fe55de2
--- /dev/null
+++ b/a3c/models_a3c.py
@@ -0,0 +1,114 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as f
+from torch.distributions import Categorical
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class ICM(nn.Module):
+    def __init__(self, channels, encoded_state_size, action_size):
+        super(ICM, self).__init__()
+        self.channels = channels
+        self.encoded_state_size = encoded_state_size
+        self.action_size = action_size
+
+        self.feature_encoder = nn.Sequential(
+            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Flatten(),
+            nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
+        ).to(device)
+
+        # Predicts the action from (phi(s), phi(s')); outputs logits, since
+        # nn.CrossEntropyLoss applies the softmax internally.
+        self.inverse_model = nn.Sequential(
+            nn.Linear(in_features=self.encoded_state_size*2, out_features=256),
+            nn.LeakyReLU(),
+            nn.Linear(in_features=256, out_features=self.action_size),
+        ).to(device)
+
+        # Predicts phi(s') from (phi(s), one-hot action).
+        self.forward_model = nn.Sequential(
+            nn.Linear(in_features=self.encoded_state_size+self.action_size, out_features=256),
+            nn.LeakyReLU(),
+            nn.Linear(in_features=256, out_features=self.encoded_state_size),
+        ).to(device)
+
+    def forward(self, state, next_state, action):
+        if state.dim() == 3:
+            state = state.unsqueeze(0)
+            next_state = next_state.unsqueeze(0)
+
+        encoded_state = self.feature_encoder(state)
+        next_encoded_state = self.feature_encoder(next_state)
+        action_pred = self.inverse_model(torch.cat((encoded_state, next_encoded_state), dim=-1))
+        next_encoded_state_pred = self.forward_model(torch.cat((encoded_state, action), dim=-1))
+        return encoded_state, next_encoded_state, action_pred, next_encoded_state_pred
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                nn.init.zeros_(m.bias)
+
+
+class ActorCritic(nn.Module):
+    def __init__(self, encoded_state_size, action_size, state_size=4):
+        super(ActorCritic, self).__init__()
+        self.channels = state_size
+        self.encoded_state_size = encoded_state_size
+        self.action_size = action_size
+
+        self.feature_encoder = nn.Sequential(
+            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Flatten(),
+            nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
+        ).to(device)
+
+        # The policy and value heads are built once here so their weights are registered
+        # parameters and actually get trained; building them inside a method would
+        # re-initialize them on every forward pass.
+        self.actor = nn.Sequential(
+            nn.Linear(in_features=self.encoded_state_size, out_features=256),
+            nn.LeakyReLU(),
+            nn.Linear(in_features=256, out_features=self.action_size),
+            nn.Softmax(dim=-1)
+        ).to(device)
+
+        self.critic = nn.Sequential(
+            nn.Linear(in_features=self.encoded_state_size, out_features=256),
+            nn.LeakyReLU(),
+            nn.Linear(in_features=256, out_features=1),
+        ).to(device)
+
+    def forward(self, state):
+        if state.dim() == 3:
+            state = state.unsqueeze(0)
+
+        state = self.feature_encoder(state)
+        value = self.critic(state)
+        policy = self.actor(state)
+        actions = Categorical(policy)
+        return actions, value
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                nn.init.zeros_(m.bias)
+
+
+#vec = torch.randn(4,84,84).to(device)
+#ac = ActorCritic(256,12).to(device)
+#a,v = ac(vec)
+#print(a,v)
\ No newline at end of file
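A shape-check sketch for the two models, along the lines of the commented-out test at the end of models_a3c.py (using a 12-way action space for illustration): one batched 4x84x84 observation yields a Categorical policy, a scalar value, 256-d encodings, and the inverse-model logits.

import torch
import torch.nn.functional as F
from models_a3c import ActorCritic, ICM, device

state = torch.randn(1, 4, 84, 84).to(device)
next_state = torch.randn(1, 4, 84, 84).to(device)

ac = ActorCritic(256, 12).to(device)
dist, value = ac(state)
action = dist.sample()
print(value.shape, action.shape)      # torch.Size([1, 1]) torch.Size([1])

icm = ICM(4, 256, 12).to(device)
one_hot = F.one_hot(action, num_classes=12).float()
phi, phi_next, action_logits, phi_next_pred = icm(state, next_state, one_hot)
print(phi.shape, action_logits.shape, phi_next_pred.shape)   # (1, 256) (1, 12) (1, 256)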
diff --git a/a3c/optimizer.py b/a3c/optimizer.py
new file mode 100644
index 0000000..b6a045a
--- /dev/null
+++ b/a3c/optimizer.py
@@ -0,0 +1,14 @@
+import torch
+
+
+class GlobalAdam(torch.optim.Adam):
+    def __init__(self, params, lr):
+        super(GlobalAdam, self).__init__(params, lr=lr)
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['step'] = 0
+                state['exp_avg'] = torch.zeros_like(p.data)
+                state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+                # Move the moment buffers into shared memory so all worker processes update the same optimizer state.
+                state['exp_avg'].share_memory_()
+                state['exp_avg_sq'].share_memory_()
\ No newline at end of file
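A small sketch of what GlobalAdam sets up: the Adam moment buffers exist immediately after construction and live in shared memory, so every worker process updates the same running averages (a toy nn.Linear stands in for the Mario models).

import torch
from torch import nn
from optimizer import GlobalAdam

model = nn.Linear(4, 2)
model.share_memory()
optimizer = GlobalAdam(model.parameters(), lr=1e-4)

for p in model.parameters():
    state = optimizer.state[p]
    # Pre-initialized step counter and shared first/second moment buffers.
    print(state['step'], state['exp_avg'].is_shared(), state['exp_avg_sq'].is_shared())
    # 0 True True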