This commit is contained in:
ved1 2023-01-31 15:58:50 +01:00
parent c3f6e9f281
commit 000b970a12
5 changed files with 464 additions and 0 deletions

91
a3c/icm_mario.py Normal file
View File

@ -0,0 +1,91 @@
import gym
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T
from models_a3c import *
from mario_env import create_mario_env
from mario_env import *
from nes_py.wrappers import JoypadSpace
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Make environment
#env = gym.make('SuperMarioBros-1-1-v0')
#env = create_mario_env(env)
# Loss functions
ce = nn.CrossEntropyLoss().to(device)
mse = nn.MSELoss().to(device)
# Training
def train(index, optimizer, global_ac, global_icm, beta=0.2, alpha=100, gamma=0.99, lamda=0.1, num_episodes=20000):
torch.manual_seed(123 + index)
env = create_mario_env('SuperMarioBros-1-1-v0', reward_type = 'dense')
local_ac = ActorCritic(256, env.action_space.n).to(device)
local_icm = ICM(4, 256, env.action_space.n).to(device)
local_ac.load_state_dict(global_ac.state_dict())
local_icm.load_state_dict(global_icm.state_dict())
#for episode in range(num_episodes):
loss = 0
observation = env.reset()
reward = 0
done = False
while not done:
#env.render()
state = torch.tensor(observation).to(device).unsqueeze(0) if observation.ndim == 3 else torch.tensor(observation).to(device)
action_probs, value = local_ac(state)
action = action_probs.sample()
action_one_hot = F.one_hot(action, num_classes=env.action_space.n).float()
next_observation, reward, done, info = env.step(action.item())
next_state = torch.tensor(next_observation).to(device).unsqueeze(0) if next_observation.ndim == 3 else torch.tensor(next_observation).to(device)
encoded_state, next_encoded_state, action_pred, next_encoded_state_pred = local_icm(state, next_state, action_one_hot)
intrinsic_reward = alpha * mse(next_encoded_state_pred, next_encoded_state.detach())
extrinsic_reward = torch.tensor(reward).to(device)
reward = intrinsic_reward + extrinsic_reward
forward_loss = mse(next_encoded_state_pred, next_encoded_state.detach())
inverse_loss = ce(action_probs.probs,action_pred)
icm_loss = beta * forward_loss + (1-beta) * inverse_loss
delta = reward + gamma * (local_ac(next_state)[1]*(1-done)) - value
actor_loss = -(action_probs.log_prob(action) +1e-8) * delta
critic_loss = delta ** 2
ac_loss = actor_loss + critic_loss
loss += lamda * ac_loss + icm_loss
observation = next_observation
optimizer.zero_grad()
# Update global model
for local_param, global_param in zip(global_ac.parameters(), local_ac.parameters()):
if global_param.grad is not None:
break
global_param._grad = local_param.grad
for local_param, global_param in zip(global_icm.parameters(), local_icm.parameters()):
if global_param.grad is not None:
break
global_param._grad = local_param.grad
loss.backward()
optimizer.step()
env.close()
if __name__ == '__main__':
train()

54
a3c/main.py Normal file
View File

@ -0,0 +1,54 @@
import gym
import numpy as np
import torch.multiprocessing as _mp
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T
from models_a3c import *
from mario_env import *
from optimizer import GlobalAdam
from icm_mario import train
from mario_env import create_mario_env
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def train_a3c():
torch.manual_seed(123)
#env = gym.make('SuperMarioBros-1-1-v0')
env = create_mario_env('SuperMarioBros-1-1-v0', reward_type = 'dense')
mp = _mp.get_context("spawn")
global_ac_model = ActorCritic(256, env.action_space.n).to(device)
global_ac_model.share_memory()
global_icm_model = ICM(4, 256, env.action_space.n).to(device)
global_icm_model.share_memory()
optimizer = GlobalAdam(list(global_ac_model.parameters()) + list(global_icm_model.parameters()), lr=1e-4)
processes = []
processes = []
counter = mp.Value('i', 0)
lock = mp.Lock()
for rank in range(0,1):
p = mp.Process(target=train, args=(rank, optimizer, global_ac_model, global_icm_model))
p.start()
processes.append(p)
for p in processes:
p.join()
if __name__ == "__main__":
train_a3c()

191
a3c/mario_env.py Normal file
View File

@ -0,0 +1,191 @@
import numpy as np
from collections import deque
import gym
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT, RIGHT_ONLY
SIMPLE_MOVEMENT = SIMPLE_MOVEMENT[1:]
from gym import spaces
from PIL import Image
import cv2
PALETTE_ACTIONS = [
['up'],
['down'],
['left'],
['left', 'A'],
['left', 'B'],
['left', 'A', 'B'],
['right'],
['right', 'A'],
['right', 'B'],
['right', 'A', 'B'],
['A'],
['B'],
['A', 'B']
]
def _process_frame_mario(frame):
if frame is not None: # for future meta implementation
img = np.reshape(frame, [240, 256, 3]).astype(np.float32)
img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
x_t = cv2.resize(img, (84, 84))
x_t = np.reshape(x_t, [1, 84, 84])/255.0
#x_t.astype(np.uint8)
else:
x_t = np.zeros((1, 84, 84))
return x_t
class ProcessFrameMario(gym.Wrapper):
def __init__(self, env=None, reward_type=None):
super(ProcessFrameMario, self).__init__(env)
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(1, 84, 84), dtype=np.uint8)
self.prev_time = 400
self.prev_stat = 0
self.prev_score = 0
self.prev_dist = 40
self.reward_type = reward_type
self.milestones = [i for i in range(150,3150,150)]
self.counter = 0
def step(self, action):
'''
Implementing custom rewards
Time = -0.1
Distance = +1 or 0
Player Status = +/- 5
Score = 2.5 x [Increase in Score]
Done = +50 [Game Completed] or -50 [Game Incomplete]
'''
obs, _, done, info = self.env.step(action)
if self.reward_type == 'sparse':
reward = 0
if (self.counter < len(self.milestones)) and (info['x_pos'] > self.milestones[self.counter]) :
reward = 10
self.counter = self.counter + 1
if done :
if info['flag_get'] :
reward = 50
else:
reward = -10
elif self.reward_type == 'dense':
reward = max(min((info['x_pos'] - self.prev_dist - 0.05), 2), -2)
self.prev_dist = info['x_pos']
reward += (self.prev_time - info['time']) * -0.1
self.prev_time = info['time']
reward += (int(info['status']!='small') - self.prev_stat) * 5
self.prev_stat = int(info['status']!='small')
reward += (info['score'] - self.prev_score) * 0.025
self.prev_score = info['score']
if done:
if info['flag_get'] :
reward += 500
else:
reward -= 50
else : return None
return _process_frame_mario(obs), reward/10, done, info
def reset(self):
self.prev_time = 400
self.prev_stat = 0
self.prev_score = 0
self.prev_dist = 40
self.counter = 0
return _process_frame_mario(self.env.reset())
def change_level(self, level):
self.env.change_level(level)
class BufferSkipFrames(gym.Wrapper):
def __init__(self, env=None, skip=4, shape=(84, 84)):
super(BufferSkipFrames, self).__init__(env)
self.counter = 0
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(4, 84, 84), dtype=np.uint8)
self.skip = skip
self.buffer = deque(maxlen=self.skip)
def step(self, action):
obs, reward, done, info = self.env.step(action)
counter = 1
total_reward = reward
self.buffer.append(obs)
for i in range(self.skip - 1):
if not done:
obs, reward, done, info = self.env.step(action)
total_reward += reward
counter +=1
self.buffer.append(obs)
else:
self.buffer.append(obs)
frame = np.stack(self.buffer, axis=0)
frame = np.reshape(frame, (4, 84, 84))
return frame, total_reward, done, info
def reset(self):
self.buffer.clear()
obs = self.env.reset()
for i in range(self.skip):
self.buffer.append(obs)
frame = np.stack(self.buffer, axis=0)
frame = np.reshape(frame, (4, 84, 84))
return frame
def change_level(self, level):
self.env.change_level(level)
class NormalizedEnv(gym.ObservationWrapper):
def __init__(self, env=None):
super(NormalizedEnv, self).__init__(env)
self.state_mean = 0
self.state_std = 0
self.alpha = 0.9999
self.num_steps = 0
def observation(self, observation):
if observation is not None: # for future meta implementation
self.num_steps += 1
self.state_mean = self.state_mean * self.alpha + \
observation.mean() * (1 - self.alpha)
self.state_std = self.state_std * self.alpha + \
observation.std() * (1 - self.alpha)
unbiased_mean = self.state_mean / (1 - pow(self.alpha, self.num_steps))
unbiased_std = self.state_std / (1 - pow(self.alpha, self.num_steps))
return (observation - unbiased_mean) / (unbiased_std + 1e-8)
else:
return observation
def change_level(self, level):
self.env.change_level(level)
def wrap_mario(env, reward_type):
# assert 'SuperMarioBros' in env.spec.id
env = ProcessFrameMario(env, reward_type)
env = NormalizedEnv(env)
env = BufferSkipFrames(env)
return env
def create_mario_env(env_id, reward_type):
env = gym_super_mario_bros.make(env_id)
env = JoypadSpace(env, PALETTE_ACTIONS)
env = wrap_mario(env, reward_type)
return env

114
a3c/models_a3c.py Normal file
View File

@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.distributions import Categorical
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ICM(nn.Module):
def __init__(self, channels, encoded_state_size, action_size):
super(ICM, self).__init__()
self.channels = channels
self.encoded_state_size = encoded_state_size
self.action_size = action_size
self.feature_encoder = nn.Sequential(
nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Flatten(),
nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
).to(device)
self.inverse_model = nn.Sequential(
nn.Linear(in_features=self.encoded_state_size*2, out_features=256),
nn.LeakyReLU(),
nn.Linear(in_features=256, out_features=self.action_size),
nn.Softmax(dim=-1)
).to(device)
self.forward_model = nn.Sequential(
nn.Linear(in_features=self.encoded_state_size+self.action_size, out_features=256),
nn.LeakyReLU(),
nn.Linear(in_features=256, out_features=self.encoded_state_size),
).to(device)
def forward(self, state, next_state, action):
if state.dim() == 3:
state = state.unsqueeze(0)
next_state = next_state.unsqueeze(0)
encoded_state = self.feature_encoder(state)
next_encoded_state = self.feature_encoder(next_state)
action_pred = self.inverse_model(torch.cat((encoded_state, next_encoded_state), dim=-1))
next_encoded_state_pred = self.forward_model(torch.cat((encoded_state, action), dim=-1))
return encoded_state, next_encoded_state, action_pred, next_encoded_state_pred
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)
class ActorCritic(nn.Module):
def __init__(self,encoded_state_size, action_size, state_size=4):
super(ActorCritic, self).__init__()
self.channels = state_size
self.encoded_state_size = encoded_state_size
self.action_size = action_size
self.feature_encoder = nn.Sequential(
nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Flatten(),
nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
).to(device)
def actor(self,state):
policy = nn.Sequential(
nn.Linear(in_features=self.encoded_state_size , out_features=256),
nn.LeakyReLU(),
nn.Linear(in_features=256, out_features=self.action_size),
nn.Softmax(dim=-1)
).to(device)
return policy(state)
def critic(self,state):
value = nn.Sequential(
nn.Linear(in_features=self.encoded_state_size , out_features=256),
nn.LeakyReLU(),
nn.Linear(in_features=256, out_features=1),
).to(device)
return value(state)
def forward(self, state):
if state.dim() == 3:
state = state.unsqueeze(0)
state = self.feature_encoder(state)
value = self.critic(state)
policy = self.actor(state)
actions = Categorical(policy)
return actions, value
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)
#vec = torch.randn(4,84,84).to(device)
#ac = ActorCritic(256,12).to(device)
#a,v = ac(vec)
#print(a,v)

14
a3c/optimizer.py Normal file
View File

@ -0,0 +1,14 @@
import torch
class GlobalAdam(torch.optim.Adam):
def __init__(self, params, lr):
super(GlobalAdam, self).__init__(params, lr=lr)
for group in self.param_groups:
for p in group['params']:
state = self.state[p]
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p.data)
state['exp_avg_sq'] = torch.zeros_like(p.data)
state['exp_avg'].share_memory_()
state['exp_avg_sq'].share_memory_()