import os

import gym
import gym_super_mario_bros  # noqa: F401 -- assumed dependency; importing it registers the SuperMarioBros-* envs with gym
import torch
from torch import nn
import torch.nn.functional as F

from models import Actor, Critic, Encoder, InverseModel, ForwardModel
from mario_env import create_mario_env
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Make environment
env = gym.make('SuperMarioBros-1-1-v0')
env = create_mario_env(env)

# Models
encoder = Encoder(channels=4, encoded_state_size=256).to(device)
inverse_model = InverseModel(encoded_state_size=256, action_size=env.action_space.n).to(device)
forward_model = ForwardModel(encoded_state_size=256, action_size=env.action_space.n).to(device)
actor = Actor(encoded_state_size=256, action_size=env.action_space.n).to(device)
critic = Critic(encoded_state_size=256).to(device)

# Optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=0.0001)
critic_optim = torch.optim.Adam(critic.parameters(), lr=0.001)
icm_params = list(encoder.parameters()) + list(forward_model.parameters()) + list(inverse_model.parameters())
icm_optim = torch.optim.Adam(icm_params, lr=0.0001)

# Loss functions
ce = nn.CrossEntropyLoss().to(device)
mse = nn.MSELoss().to(device)

# Hyperparameters
beta = 0.2    # weighting between forward and inverse ICM losses
alpha = 100   # scaling of the intrinsic (curiosity) reward
gamma = 0.99  # discount factor
lamda = 0.1   # weight of the actor-critic loss relative to the ICM loss

# Training parameters
render = False
num_episodes = 1000


def to_tensor(observation):
    """Convert an observation to a batched float tensor on the target device."""
    state = torch.as_tensor(observation, dtype=torch.float32, device=device)
    return state.unsqueeze(0) if state.dim() == 3 else state


# Training
def train():
    os.makedirs('saved_models', exist_ok=True)
    t = 0
    for episode in range(num_episodes):
        observation = env.reset()
        total_reward = 0
        done = False
        while not done:
            if render:
                env.render()

            state = to_tensor(observation)
            action_probs = actor(state)  # Categorical distribution over actions
            action = action_probs.sample()
            action_one_hot = F.one_hot(action, num_classes=env.action_space.n).float()

            next_observation, reward, done, info = env.step(action.item())
            next_state = to_tensor(next_observation)

            # ICM: encode both states, predict the next encoding and the taken action
            encoded_state = encoder(state)
            next_encoded_state = encoder(next_state)
            predicted_next_state = forward_model(encoded_state, action_one_hot)
            predicted_action = inverse_model(encoded_state, next_encoded_state)

            # Curiosity bonus: forward-model prediction error. It is treated as a
            # reward signal only, so no gradient should flow through it (detach).
            intrinsic_reward = (alpha * mse(predicted_next_state, next_encoded_state)).detach()
            extrinsic_reward = torch.tensor(reward, dtype=torch.float32, device=device)
            reward = intrinsic_reward + extrinsic_reward

            # ICM losses. The inverse model should recover the action that was
            # actually taken; predicted_action is assumed to be a Categorical, so its
            # logits are the cross-entropy input and the sampled action is the target.
            forward_loss = mse(predicted_next_state, next_encoded_state.detach())
            inverse_loss = ce(predicted_action.logits, action)
            icm_loss = beta * forward_loss + (1 - beta) * inverse_loss

            # Actor-critic losses: one-step TD error as the advantage. The bootstrap
            # target is detached for the critic, and the advantage is detached for
            # the policy gradient.
            delta = reward + gamma * critic(next_state).detach() * (1 - int(done)) - critic(state)
            actor_loss = -action_probs.log_prob(action) * delta.detach()
            critic_loss = delta ** 2
            ac_loss = actor_loss + critic_loss

            loss = lamda * ac_loss + icm_loss

            actor_optim.zero_grad()
            critic_optim.zero_grad()
            icm_optim.zero_grad()
            loss.backward()
            actor_optim.step()
            critic_optim.step()
            icm_optim.step()

            observation = next_observation
            total_reward += reward.item()
            t += 1

            writer.add_scalar('Loss/Actor Loss', actor_loss.item(), t)
            writer.add_scalar('Loss/Critic Loss', critic_loss.item(), t)
            writer.add_scalar('Loss/Forward Loss', forward_loss.item(), t)
            writer.add_scalar('Loss/Inverse Loss', inverse_loss.item(), t)

        writer.add_scalar('Reward/Episodic Reward', total_reward, episode)

        if episode % 50 == 0:
            torch.save(actor.state_dict(), 'saved_models/actor.pth')
            torch.save(critic.state_dict(), 'saved_models/critic.pth')
            torch.save(encoder.state_dict(), 'saved_models/encoder.pth')
            torch.save(inverse_model.state_dict(), 'saved_models/inverse_model.pth')
            torch.save(forward_model.state_dict(), 'saved_models/forward_model.pth')

    env.close()


def test():
    actor.load_state_dict(torch.load('saved_models/actor.pth', map_location=device))
    critic.load_state_dict(torch.load('saved_models/critic.pth', map_location=device))
    encoder.load_state_dict(torch.load('saved_models/encoder.pth', map_location=device))
    inverse_model.load_state_dict(torch.load('saved_models/inverse_model.pth', map_location=device))
    forward_model.load_state_dict(torch.load('saved_models/forward_model.pth', map_location=device))

    observation = env.reset()
    while True:
        env.render()
        state = to_tensor(observation)
        with torch.no_grad():
            action_probs = actor(state)
        action = action_probs.sample()
        observation, reward, done, info = env.step(action.item())
        if done:
            observation = env.reset()


class ICM(nn.Module):
    """Self-contained intrinsic curiosity module; not used by train() or test() above."""

    def __init__(self, state_size, action_size, inverse_model, forward_model, encoded_state_size=256):
        super(ICM, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.inverse_model = inverse_model
        self.forward_model = forward_model
        self.loss = nn.MSELoss().to(device)
        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.state_size[0], out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32 * 4 * 4, out_features=256),
        ).to(device)
        # Unused by forward(); kept from the original definition.
        self.model = nn.Sequential(
            nn.Linear(self.state_size[1] + 1, 256),
            nn.LeakyReLU(),
            nn.Linear(256, encoded_state_size),
        ).to(device)

    def forward(self, state, next_state, action):
        if state.dim() == 3:
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
        encoded_state = self.feature_encoder(state)
        next_encoded_state = self.feature_encoder(next_state)
        # NOTE: the raw action index is passed here, whereas train() feeds the
        # forward model a one-hot encoding.
        action = torch.tensor([action]).to(device)
        predicted_next_state = self.forward_model(encoded_state, action)
        intrinsic_reward = 0.5 * self.loss(predicted_next_state, next_encoded_state.detach())
        return intrinsic_reward


if __name__ == '__main__':
    train()
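

# ---------------------------------------------------------------------------
# Illustrative sketch only (hypothetical, not part of the project): the script
# above assumes specific interfaces from models.py -- Actor and InverseModel
# return a torch.distributions.Categorical, Encoder maps a 4-frame stack to a
# 256-dim feature vector, and ForwardModel consumes a one-hot action. The
# commented stubs below merely document those assumed shapes and return types;
# the real implementations live in models.py and may differ.
#
# class Encoder(nn.Module):
#     def forward(self, frames):                      # frames: (N, 4, H, W)
#         return self.conv_stack(frames)              # -> (N, 256) features
#
# class Actor(nn.Module):
#     def forward(self, state):
#         logits = self.net(state)                    # -> (N, action_size)
#         return torch.distributions.Categorical(logits=logits)
#
# class ForwardModel(nn.Module):
#     def forward(self, encoded_state, action_one_hot):
#         x = torch.cat([encoded_state, action_one_hot], dim=1)
#         return self.net(x)                          # -> predicted next features
#
# class InverseModel(nn.Module):
#     def forward(self, encoded_state, next_encoded_state):
#         x = torch.cat([encoded_state, next_encoded_state], dim=1)
#         return torch.distributions.Categorical(logits=self.net(x))
# ---------------------------------------------------------------------------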