# Curiosity/icm mario.py
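"""Actor-critic training with an Intrinsic Curiosity Module (ICM) on
SuperMarioBros-1-1-v0.

An encoder, an inverse model and a forward model (see models.py) provide an
intrinsic reward that is added to the environment reward. Losses and episodic
rewards are logged to TensorBoard, and model weights are checkpointed to
saved_models/ every 50 episodes.
"""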
import os
import gym
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T
from models import Actor, Critic, Encoder, InverseModel, ForwardModel
from mario_env import create_mario_env
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Make environment
env = gym.make('SuperMarioBros-1-1-v0')
env = create_mario_env(env)
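# create_mario_env (mario_env.py) is assumed to apply the usual preprocessing
# wrappers (frame skip, grayscale/resize and a 4-frame stack, consistent with
# Encoder(channels=4) below); adjust this note if the wrapper stack differs.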
# Models
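# ICM components: the encoder maps stacked frames to a 256-dim feature vector,
# the inverse model predicts the taken action from (phi(s), phi(s')), and the
# forward model predicts phi(s') from (phi(s), a). The actor and critic form
# the policy/value pair trained on the extrinsic + intrinsic reward.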
encoder = Encoder(channels=4, encoded_state_size=256).to(device)
inverse_model = InverseModel(encoded_state_size=256, action_size=env.action_space.n).to(device)
forward_model = ForwardModel(encoded_state_size=256, action_size=env.action_space.n).to(device)
actor = Actor(encoded_state_size=256, action_size=env.action_space.n).to(device)
critic = Critic(encoded_state_size=256).to(device)
# Optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=0.0001)
critic_optim = torch.optim.Adam(critic.parameters(), lr=0.001)
icm_params = list(encoder.parameters()) + list(forward_model.parameters()) + list(inverse_model.parameters())
icm_optim = torch.optim.Adam(icm_params, lr=0.0001)
# Loss functions
ce = nn.CrossEntropyLoss().to(device)
mse = nn.MSELoss().to(device)
# Hyperparameters
beta = 0.2
alpha = 100
gamma = 0.99
lamda = 0.1
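# Roles in the objective below: beta trades off the forward-model loss against
# the inverse-model loss, alpha scales the intrinsic (curiosity) reward, gamma
# is the discount factor, and lamda weights the actor-critic loss against the
# ICM loss in the combined update (cf. beta, eta and lambda in Pathak et al.,
# "Curiosity-driven Exploration by Self-supervised Prediction", 2017).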
# Training Parameters
render = False
num_episodes = 1000
# Training
def train():
    os.makedirs('saved_models', exist_ok=True)  # checkpoint directory
    t = 0  # global step counter for TensorBoard
    for episode in range(num_episodes):
        observation = env.reset()
        total_reward = 0
        done = False
        while not done:
            if render:
                env.render()
            # Convert the observation to a batched tensor on the training device.
            state = torch.tensor(observation, device=device)
            if state.ndim == 3:
                state = state.unsqueeze(0)
            action_probs = actor(state)
            action = action_probs.sample()
            action_one_hot = F.one_hot(action, num_classes=env.action_space.n).float()
            next_observation, reward, done, info = env.step(action.item())
            next_state = torch.tensor(next_observation, device=device)
            if next_state.ndim == 3:
                next_state = next_state.unsqueeze(0)
            # ICM: encode both states, predict the next encoding from (phi(s), a)
            # and the taken action from (phi(s), phi(s')).
            encoded_state = encoder(state)
            next_encoded_state = encoder(next_state)
            predicted_next_state = forward_model(encoded_state, action_one_hot)
            predicted_action = inverse_model(encoded_state, next_encoded_state)
            # The forward-model error doubles as the curiosity signal; it is
            # detached so the reward does not backpropagate into the ICM.
            forward_loss = mse(predicted_next_state, next_encoded_state.detach())
            intrinsic_reward = alpha * forward_loss.detach()
            extrinsic_reward = torch.tensor(reward).to(device)
            combined_reward = intrinsic_reward + extrinsic_reward
            # Inverse loss: cross-entropy between the inverse model's prediction
            # and the action that was actually taken (standard ICM inverse loss).
            inverse_loss = ce(predicted_action.logits, action)
            icm_loss = beta * forward_loss + (1 - beta) * inverse_loss
            # One-step TD error; the bootstrap target is not backpropagated through.
            value = critic(state)
            with torch.no_grad():
                next_value = critic(next_state)
            delta = combined_reward + gamma * next_value * (1 - int(done)) - value
            actor_loss = -action_probs.log_prob(action) * delta.detach()
            critic_loss = delta ** 2
            ac_loss = actor_loss + critic_loss
            loss = lamda * ac_loss + icm_loss
            actor_optim.zero_grad()
            critic_optim.zero_grad()
            icm_optim.zero_grad()
            loss.backward()
            actor_optim.step()
            critic_optim.step()
            icm_optim.step()
            observation = next_observation
            total_reward += combined_reward.item()
            t += 1
            writer.add_scalar('Loss/Actor Loss', actor_loss.item(), t)
            writer.add_scalar('Loss/Critic Loss', critic_loss.item(), t)
            writer.add_scalar('Loss/Forward Loss', forward_loss.item(), t)
            writer.add_scalar('Loss/Inverse Loss', inverse_loss.item(), t)
        writer.add_scalar('Reward/Episodic Reward', total_reward, episode)
        if episode % 50 == 0:
            torch.save(actor.state_dict(), 'saved_models/actor.pth')
            torch.save(critic.state_dict(), 'saved_models/critic.pth')
            torch.save(encoder.state_dict(), 'saved_models/encoder.pth')
            torch.save(inverse_model.state_dict(), 'saved_models/inverse_model.pth')
            torch.save(forward_model.state_dict(), 'saved_models/forward_model.pth')
    env.close()

def test():
    actor.load_state_dict(torch.load('saved_models/actor.pth'))
    critic.load_state_dict(torch.load('saved_models/critic.pth'))
    encoder.load_state_dict(torch.load('saved_models/encoder.pth'))
    inverse_model.load_state_dict(torch.load('saved_models/inverse_model.pth'))
    forward_model.load_state_dict(torch.load('saved_models/forward_model.pth'))
    observation = env.reset()
    while True:
        env.render()
        state = torch.tensor(observation, device=device)
        if state.ndim == 3:
            state = state.unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            action_probs = actor(state)
        action = action_probs.sample()
        observation, reward, done, info = env.step(action.item())
        if done:
            observation = env.reset()

if __name__ == '__main__':
    train()
    exit()

# Standalone ICM module: it is not referenced by train()/test() above, which
# use the separate encoder/inverse_model/forward_model instances instead.
class ICM(nn.Module):
    def __init__(self, state_size, action_size, inverse_model, forward_model, encoded_state_size=256):
        super(ICM, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        # keep references to the sub-models; forward() relies on self.forward_model
        self.inverse_model = inverse_model
        self.forward_model = forward_model
        self.loss = nn.MSELoss().to(device)
        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.state_size[0], out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32*4*4, out_features=256),
        ).to(device)
        # note: this head is not referenced by forward() below
        self.model = nn.Sequential(
            nn.Linear(self.state_size[1] + 1, 256),
            nn.LeakyReLU(),
            nn.Linear(256, encoded_state_size)
        ).to(device)
    def forward(self, state, next_state, action):
        if state.dim() == 3:
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
        encoded_state, next_encoded_state = self.feature_encoder(state), self.feature_encoder(next_state)
        #_, encoded_state, next_encoded_state = self.inverse_model(state, next_state)
        # encode the action as a one-hot vector, matching how forward_model is
        # called in train() above
        action = F.one_hot(torch.tensor([action], device=device), num_classes=self.action_size).float()
        predicted_next_state = self.forward_model(encoded_state, action)
        intrinsic_reward = 0.5 * self.loss(predicted_next_state, next_encoded_state.detach())
        return intrinsic_reward
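
# A minimal usage sketch for the standalone ICM module above; the (4, 84, 84)
# frame shape is an assumption inferred from the 32*4*4 flatten size of its
# conv encoder, not something stated elsewhere in this file:
#
#   icm = ICM(state_size=(4, 84, 84), action_size=env.action_space.n,
#             inverse_model=inverse_model, forward_model=forward_model)
#   r_intrinsic = icm(state, next_state, action.item())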