import gym
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T

from models import Actor, Critic, Encoder, InverseModel, ForwardModel
from mario_env import create_mario_env

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Make environment
env = gym.make('SuperMarioBros-1-1-v0')
env = create_mario_env(env)

# Models
encoder = Encoder(channels=4, encoded_state_size=256).to(device)
inverse_model = InverseModel(encoded_state_size=256, action_size=env.action_space.n).to(device)
forward_model = ForwardModel(encoded_state_size=256, action_size=env.action_space.n).to(device)
actor = Actor(encoded_state_size=256, action_size=env.action_space.n).to(device)
critic = Critic(encoded_state_size=256).to(device)

# Optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=0.0001)
critic_optim = torch.optim.Adam(critic.parameters(), lr=0.001)
icm_params = list(encoder.parameters()) + list(forward_model.parameters()) + list(inverse_model.parameters())
icm_optim = torch.optim.Adam(icm_params, lr=0.0001)

# Loss functions
ce = nn.CrossEntropyLoss().to(device)
mse = nn.MSELoss().to(device)

# Hyperparameters
beta = 0.2    # weight of the forward-model loss vs. the inverse-model loss in the ICM loss
alpha = 100   # scaling factor for the intrinsic (curiosity) reward
gamma = 0.99  # discount factor
lamda = 0.1   # weight of the actor-critic loss in the combined loss ("lambda" is reserved in Python)
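
# The training loop below combines the actor-critic and ICM objectives into a single loss,
#   loss = lamda * (actor_loss + critic_loss) + beta * forward_loss + (1 - beta) * inverse_loss
# and adds an intrinsic (curiosity) reward proportional to the forward model's
# prediction error in the encoded state space, scaled by alpha.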

# Training Parameters
render = False
num_episodes = 1000

# Training
def train():
    t = 0
    for episode in range(num_episodes):
        observation = env.reset()
        total_reward = 0
        done = False
        while not done:
            if render:
                env.render()

            # Sample an action from the current policy.
            state = torch.tensor(observation).to(device).unsqueeze(0) if observation.ndim == 3 else torch.tensor(observation).to(device)
            action_probs = actor(state)
            action = action_probs.sample()
            action_one_hot = F.one_hot(action, num_classes=env.action_space.n).float()

            next_observation, reward, done, info = env.step(action.item())
            next_state = torch.tensor(next_observation).to(device).unsqueeze(0) if next_observation.ndim == 3 else torch.tensor(next_observation).to(device)

            # ICM passes: encode both observations, then run the forward and inverse models.
            encoded_state = encoder(state)
            next_encoded_state = encoder(next_state)
            predicted_next_state = forward_model(encoded_state, action_one_hot)
            predicted_action = inverse_model(encoded_state, next_encoded_state)
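
            # Curiosity bonus: the forward model's prediction error, scaled by alpha
            # and added to the environment (extrinsic) reward.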
            intrinsic_reward = alpha * mse(predicted_next_state, next_encoded_state.detach())
            extrinsic_reward = torch.tensor(reward).to(device)
            reward = intrinsic_reward + extrinsic_reward

            forward_loss = mse(predicted_next_state, next_encoded_state.detach())
            inverse_loss = ce(action_probs.probs, predicted_action.probs)
            icm_loss = beta * forward_loss + (1 - beta) * inverse_loss

            # One-step actor-critic: the TD error delta drives both the policy and value losses.
            delta = reward + gamma * (critic(next_state) * (1 - done)) - critic(state)
            actor_loss = -(action_probs.log_prob(action) + 1e-6) * delta
            critic_loss = delta ** 2
            ac_loss = actor_loss + critic_loss

            loss = lamda * ac_loss + icm_loss
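
            # Single combined objective; one backward pass fills the gradients for the
            # policy, value, and ICM parameters, each updated by its own optimizer.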
            actor_optim.zero_grad()
            critic_optim.zero_grad()
            icm_optim.zero_grad()
            loss.backward()
            actor_optim.step()
            critic_optim.step()
            icm_optim.step()

            observation = next_observation
            total_reward += reward.item()

            t += 1
            writer.add_scalar('Loss/Actor Loss', actor_loss.item(), t)
            writer.add_scalar('Loss/Critic Loss', critic_loss.item(), t)
            writer.add_scalar('Loss/Forward Loss', forward_loss.item(), t)
            writer.add_scalar('Loss/Inverse Loss', inverse_loss.item(), t)

        writer.add_scalar('Reward/Episodic Reward', total_reward, episode)

        if episode % 50 == 0:
            torch.save(actor.state_dict(), 'saved_models/actor.pth')
            torch.save(critic.state_dict(), 'saved_models/critic.pth')
            torch.save(encoder.state_dict(), 'saved_models/encoder.pth')
            torch.save(inverse_model.state_dict(), 'saved_models/inverse_model.pth')
            torch.save(forward_model.state_dict(), 'saved_models/forward_model.pth')

    env.close()


def test():
    actor.load_state_dict(torch.load('saved_models/actor.pth'))
    critic.load_state_dict(torch.load('saved_models/critic.pth'))
    encoder.load_state_dict(torch.load('saved_models/encoder.pth'))
    inverse_model.load_state_dict(torch.load('saved_models/inverse_model.pth'))
    forward_model.load_state_dict(torch.load('saved_models/forward_model.pth'))

    observation = env.reset()

    while True:
        env.render()
        state = torch.tensor(observation).to(device).unsqueeze(0) if observation.ndim == 3 else torch.tensor(observation).to(device)
        action_probs = actor(state)
        action = action_probs.sample()
        observation, reward, done, info = env.step(action.item())
        if done:
            observation = env.reset()


if __name__ == '__main__':
    train()
    exit()


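# Note: the ICM module below appears to be unused leftover code; the training loop above
# uses the separate Encoder / ForwardModel / InverseModel components instead.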
class ICM(nn.Module):
    def __init__(self, state_size, action_size, inverse_model, forward_model, encoded_state_size=256):
        super(ICM, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        # Keep references to the supplied sub-models (forward() relies on them).
        self.inverse_model = inverse_model
        self.forward_model = forward_model

        self.loss = nn.MSELoss().to(device)

        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.state_size[0], out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32*4*4, out_features=256),
        ).to(device)

        self.model = nn.Sequential(
            nn.Linear(self.state_size[1] + 1, 256),
            nn.LeakyReLU(),
            nn.Linear(256, encoded_state_size)
        ).to(device)

    def forward(self, state, next_state, action):
        if state.dim() == 3:
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
        encoded_state, next_encoded_state = self.feature_encoder(state), self.feature_encoder(next_state)
        #_, encoded_state, next_encoded_state = self.inverse_model(state, next_state)
        action = torch.tensor([action]).to(device)
        predicted_next_state = self.forward_model(encoded_state, action)

        intrinsic_reward = 0.5 * self.loss(predicted_next_state, next_encoded_state.detach())
        return intrinsic_reward