# Curiosity/icm/mario.py
import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy
from torch.distributions import Categorical
import collections
import cv2
import torch.nn.functional as f
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
# Gym is an OpenAI toolkit for RL
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack
# NES Emulator for OpenAI Gym
from nes_py.wrappers import JoypadSpace
# Super Mario environment for OpenAI Gym
import gym_super_mario_bros
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT, COMPLEX_MOVEMENT
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
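# Curiosity-driven exploration (ICM): the agent receives an intrinsic reward equal to the
# prediction error of a learned forward model in feature space. An inverse model (predicting
# the action taken between two consecutive states) and a forward model (predicting the next
# encoded state from the current encoding and the action) are defined below and trained
# jointly with an actor-critic policy.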
class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        """Return only every `skip`-th frame"""
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        """Repeat action, and sum reward"""
        total_reward = 0.0
        for i in range(self._skip):
            # Accumulate reward and repeat the same action
            obs, reward, done, trunk, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, trunk, info
class GrayScaleObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def permute_orientation(self, observation):
        # permute [H, W, C] array to [C, H, W] tensor
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        return observation

    def observation(self, observation):
        observation = self.permute_orientation(observation)
        transform = T.Grayscale()
        observation = transform(observation)
        return observation
class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, shape):
        super().__init__(env)
        if isinstance(shape, int):
            self.shape = (shape, shape)
        else:
            self.shape = tuple(shape)
        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        transforms = T.Compose(
            [T.Resize(self.shape), T.Normalize(0, 255)]
        )
        observation = transforms(observation).squeeze(0)
        return observation
class MaxAndSkipEnv(gym.Wrapper):
    """
    Each action of the agent is repeated over `skip` frames;
    only the max-pooled frame over the last two is returned.
    """
    def __init__(self, env=None, skip=4):
        super(MaxAndSkipEnv, self).__init__(env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self):
        """Clear past frame buffer and init to first obs"""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs
class MarioRescale84x84(gym.ObservationWrapper):
    """
    Downsamples/rescales each frame to 84x84 greyscale
    """
    def __init__(self, env=None):
        super(MarioRescale84x84, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return MarioRescale84x84.process(obs)

    @staticmethod
    def process(frame):
        if frame.size == 240 * 256 * 3:
            img = np.reshape(frame, [240, 256, 3]).astype(np.float32)
        else:
            assert False, "Unknown resolution."
        # convert RGB to greyscale using luminance weights
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        # resize to 84x110, then crop out the central 84x84 playing area
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)
class ImageToPyTorch(gym.ObservationWrapper):
    """
    Each frame is converted to a channel-first (C, H, W) array for PyTorch
    """
    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0,
                                                shape=(old_shape[-1], old_shape[0], old_shape[1]),
                                                dtype=np.float32)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)
class BufferWrapper(gym.ObservationWrapper):
    """
    Keeps a rolling buffer of the last `n_steps` observations, stacked along the channel axis
    """
    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0),
                                                old_space.high.repeat(n_steps, axis=0), dtype=dtype)

    def reset(self):
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        return self.observation(self.env.reset())

    def observation(self, observation):
        # shift the buffer and append the newest observation
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer
class PixelNormalization(gym.ObservationWrapper):
    """
    Normalize pixel values in frame --> 0 to 1
    """
    def observation(self, obs):
        return np.array(obs).astype(np.float32) / 255.0
def create_mario_env(env):
    env = MaxAndSkipEnv(env)
    env = MarioRescale84x84(env)
    env = ImageToPyTorch(env)
    env = BufferWrapper(env, 4)
    env = PixelNormalization(env)
    return JoypadSpace(env, COMPLEX_MOVEMENT)
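# Preprocessing pipeline used for training: repeat each action over 4 frames and max-pool the
# last two, rescale to 84x84 greyscale, move channels first, stack the last 4 frames, scale
# pixels to [0, 1], and restrict the controller to the COMPLEX_MOVEMENT button combinations.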
class ActorCritic(nn.Module):
    def __init__(self, input_size, action_size=2):
        super(ActorCritic, self).__init__()
        self.input_size = input_size
        self.action_size = action_size
        self.feature = nn.Sequential(
            nn.Conv2d(in_channels=self.input_size[0], out_channels=32, kernel_size=8, stride=4),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=7*7*64, out_features=512),
            nn.LeakyReLU(),
        )

    def actor(self, state):
        policy = nn.Sequential(
            nn.Linear(in_features=state.shape[1], out_features=state.shape[1]),
            nn.LeakyReLU(),
            nn.Linear(in_features=state.shape[1], out_features=self.action_size),
            nn.Softmax(dim=-1)
        ).to(device)
        return policy(state)

    def critic(self, state):
        value = nn.Sequential(
            nn.Linear(in_features=state.shape[1], out_features=state.shape[1]),
            nn.LeakyReLU(),
            nn.Linear(in_features=state.shape[1], out_features=1)
        ).to(device)
        return value(state)

    def forward(self, state):
        if state.dim() == 3:
            state = state.unsqueeze(0)
        state = self.feature(state)
        value = self.critic(state)
        policy = self.actor(state)
        action_probs = Categorical(policy)
        log_action_probs = torch.log(action_probs.probs)
        return value, action_probs, log_action_probs
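# Note: the actor and critic heads are built inside actor()/critic() on every call, so their
# weights are re-initialized at each forward pass and are not registered as parameters of this
# module; only the convolutional feature extractor defined in __init__ is seen by the optimizer.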
class Encoder(nn.Module):
    def __init__(self, input_size, action_size=2):
        super(Encoder, self).__init__()
        self.input_size = input_size[0]
        self.action_size = action_size
        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.input_size, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32*4*4, out_features=256),
        ).to(device)

    def forward(self, state):
        if state.dim() == 3:
            state = state.unsqueeze(0)
        state = self.feature_encoder(state)
        return state
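# The Encoder maps a stacked 4x84x84 observation to a 256-dimensional feature vector; the same
# architecture is duplicated as ICM.feature_encoder below.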
class InverseModel(nn.Module):
    def __init__(self, input_size, action_size=2):
        super(InverseModel, self).__init__()
        self.input_size = input_size[0]
        self.action_size = action_size
        self.feature_encoder = Encoder(input_size, action_size)
        self.model = nn.Sequential(
            nn.Linear(in_features=256*2, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.action_size),
            nn.Softmax(dim=-1)
        ).to(device)

    def forward(self, state, next_state):
        encoded_state, next_encoded_state = torch.unsqueeze(state, dim=0), torch.unsqueeze(next_state, dim=0)
        encoded_state, next_encoded_state = self.feature_encoder(encoded_state), self.feature_encoder(next_encoded_state)
        encoded_states = torch.cat((encoded_state, next_encoded_state), dim=-1)
        actions = Categorical(self.model(encoded_states))
        a = float(np.array(actions.sample().cpu())[0])
        action = torch.FloatTensor([a])
        one_hot_action = f.one_hot(action.to(torch.int64), self.action_size)
        return one_hot_action, encoded_state, next_encoded_state
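# The inverse model samples a predicted action from p(a | phi(s), phi(s')) and returns it one-hot
# encoded together with both encodings; predicting the action is what keeps the learned features
# focused on the parts of the observation the agent can actually influence.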
class ForwardModel(nn.Module):
    def __init__(self, encoded_state_size, action_size):
        super(ForwardModel, self).__init__()
        self.state_size = encoded_state_size
        self.action_size = action_size
        self.model = nn.Sequential(
            nn.Linear(self.state_size + 1, 256),
            nn.LeakyReLU(),
            nn.Linear(256, encoded_state_size)
        ).to(device)

    def forward(self, state, action):
        if state.dim() == 3:
            state = state.unsqueeze(0)
        if action.dim() == 1:
            action = action.unsqueeze(0)
        state = torch.cat((state, action), dim=-1)
        return self.model(state)
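# The forward model predicts the next feature encoding from the current encoding concatenated
# with the scalar action index.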
class ICM(nn.Module):
    def __init__(self, state_size, action_size, encoded_state_size=256):
        super(ICM, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.inverse_model = InverseModel(state_size, action_size)
        self.forward_model = ForwardModel(encoded_state_size, action_size)
        self.loss = nn.MSELoss().to(device)
        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.state_size[0], out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32*4*4, out_features=256),
        ).to(device)
        # note: this head is not used in forward()
        self.model = nn.Sequential(
            nn.Linear(self.state_size[1] + 1, 256),
            nn.LeakyReLU(),
            nn.Linear(256, encoded_state_size)
        ).to(device)

    def forward(self, state, next_state, action):
        if state.dim() == 3:
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
        encoded_state, next_encoded_state = self.feature_encoder(state), self.feature_encoder(next_state)
        # _, encoded_state, next_encoded_state = self.inverse_model(state, next_state)
        action = torch.tensor([action], dtype=torch.float32).to(device)
        predicted_next_state = self.forward_model(encoded_state, action)
        intrinsic_reward = 0.5 * self.loss(predicted_next_state, next_encoded_state.detach())
        return intrinsic_reward
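# The intrinsic (curiosity) reward is the forward model's prediction error: half the mean squared
# error between the predicted and the actual encoding of the next state.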
# Alternative preprocessing with the torchvision-based wrappers defined above:
# env = gym.make('SuperMarioBros-1-1-v0')
# env = GrayScaleObservation(env)
# env = ResizeObservation(env, shape=84)
# env = FrameStack(env, num_stack=4)
env = gym.make('SuperMarioBros-1-1-v0')
env = create_mario_env(env)

ce = nn.CrossEntropyLoss().to(device)
mse = nn.MSELoss().to(device)
icm = ICM(env.observation_space.shape, env.action_space.n).to(device)
ac = ActorCritic(env.observation_space.shape, env.action_space.n).to(device)
optimizer = torch.optim.Adam(list(icm.parameters()) + list(ac.parameters()), lr=0.001)

done = False
t = 0
gamma = 0.99
for episode in range(1000):
    observation = env.reset()
    total_reward = 0
    t_init = t
    while not done:
        # env.render()
        value, actions, log_action_probs = ac(torch.FloatTensor(np.array(observation)).to(device))
        action = actions.sample().item()
        next_observation, reward, done, info = env.step(action)  # feedback from environment
        observation_array = torch.FloatTensor(np.array(observation)).to(device)
        next_observation_array = torch.FloatTensor(np.array(next_observation)).to(device)
        int_reward = icm(observation_array, next_observation_array, action)
        # one-step TD error built from the intrinsic reward and the critic's value estimates
        delta = torch.squeeze(int_reward + gamma * (ac(next_observation_array)[0] * (1 - int(done))) - ac(observation_array)[0])
        actor_loss = -log_action_probs[0, action] * int_reward
        critic_loss = delta ** 2
        # the extrinsic environment reward is discarded; only the intrinsic reward is tracked
        reward = torch.FloatTensor([reward]).to(device)
        reward = int_reward
        one_hot_action = icm.inverse_model(observation_array, next_observation_array)[0].to(device)
        inverse_loss = ce(one_hot_action.float(), actions.probs)
        loss = actor_loss + critic_loss + inverse_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        writer.add_scalar("loss", loss, t)
        observation = next_observation
        total_reward += reward
        t += 1
        # print("timestep: ", t, "reward: ", reward, "loss: ", loss)
        if done:
            done = False
            break
    # log the average per-step reward for the finished episode
    writer.add_scalar("reward", total_reward / (t - t_init), t)
env.close()
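# SummaryWriter logs to the default ./runs directory; the curves can be inspected with
# `tensorboard --logdir runs`.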