diff --git a/icm mario.py b/icm mario.py new file mode 100644 index 0000000..df5d085 --- /dev/null +++ b/icm mario.py @@ -0,0 +1,406 @@ +import torch +from torch import nn +from torchvision import transforms as T +from PIL import Image +import numpy as np +from pathlib import Path +from collections import deque +import random, datetime, os, copy +from torch.distributions import Categorical +import collections +import cv2 +import torch.nn.functional as f +from torch.utils.tensorboard import SummaryWriter +writer = SummaryWriter() + +# Gym is an OpenAI toolkit for RL +import gym +from gym.spaces import Box +from gym.wrappers import FrameStack + +# NES Emulator for OpenAI Gym +from nes_py.wrappers import JoypadSpace + +# Super Mario environment for OpenAI Gym +import gym_super_mario_bros +from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT, COMPLEX_MOVEMENT + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class SkipFrame(gym.Wrapper): + def __init__(self, env, skip): + """Return only every `skip`-th frame""" + super().__init__(env) + self._skip = skip + + def step(self, action): + """Repeat action, and sum reward""" + total_reward = 0.0 + for i in range(self._skip): + # Accumulate reward and repeat the same action + obs, reward, done, trunk, info = self.env.step(action) + total_reward += reward + if done: + break + return obs, total_reward, done, trunk, info + + +class GrayScaleObservation(gym.ObservationWrapper): + def __init__(self, env): + super().__init__(env) + obs_shape = self.observation_space.shape[:2] + self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8) + + def permute_orientation(self, observation): + # permute [H, W, C] array to [C, H, W] tensor + observation = np.transpose(observation, (2, 0, 1)) + observation = torch.tensor(observation.copy(), dtype=torch.float) + return observation + + def observation(self, observation): + observation = self.permute_orientation(observation) + transform = T.Grayscale() + observation = transform(observation) + return observation + + +class ResizeObservation(gym.ObservationWrapper): + def __init__(self, env, shape): + super().__init__(env) + if isinstance(shape, int): + self.shape = (shape, shape) + else: + self.shape = tuple(shape) + + obs_shape = self.shape + self.observation_space.shape[2:] + self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8) + + def observation(self, observation): + transforms = T.Compose( + [T.Resize(self.shape), T.Normalize(0, 255)] + ) + observation = transforms(observation).squeeze(0) + return observation +class MaxAndSkipEnv(gym.Wrapper): + """ + Each action of the agent is repeated over skip frames + return only every `skip`-th frame + """ + def __init__(self, env=None, skip=4): + super(MaxAndSkipEnv, self).__init__(env) + # most recent raw observations (for max pooling across time steps) + self._obs_buffer = collections.deque(maxlen=2) + self._skip = skip + + def step(self, action): + total_reward = 0.0 + done = None + for _ in range(self._skip): + obs, reward, done, info = self.env.step(action) + self._obs_buffer.append(obs) + total_reward += reward + if done: + break + max_frame = np.max(np.stack(self._obs_buffer), axis=0) + return max_frame, total_reward, done, info + + def reset(self): + """Clear past frame buffer and init to first obs""" + self._obs_buffer.clear() + obs = self.env.reset() + self._obs_buffer.append(obs) + return obs + + +class MarioRescale84x84(gym.ObservationWrapper): + """ + Downsamples/Rescales each frame to size 84x84 with greyscale + """ + def __init__(self, env=None): + super(MarioRescale84x84, self).__init__(env) + self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8) + + def observation(self, obs): + return MarioRescale84x84.process(obs) + + @staticmethod + def process(frame): + if frame.size == 240 * 256 * 3: + img = np.reshape(frame, [240, 256, 3]).astype(np.float32) + else: + assert False, "Unknown resolution." + # image normalization on RBG + img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114 + resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA) + x_t = resized_screen[18:102, :] + x_t = np.reshape(x_t, [84, 84, 1]) + return x_t.astype(np.uint8) + + +class ImageToPyTorch(gym.ObservationWrapper): + """ + Each frame is converted to PyTorch tensors + """ + def __init__(self, env): + super(ImageToPyTorch, self).__init__(env) + old_shape = self.observation_space.shape + self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32) + + def observation(self, observation): + return np.moveaxis(observation, 2, 0) + + +class BufferWrapper(gym.ObservationWrapper): + """ + Only every k-th frame is collected by the buffer + """ + def __init__(self, env, n_steps, dtype=np.float32): + super(BufferWrapper, self).__init__(env) + self.dtype = dtype + old_space = env.observation_space + self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0), + old_space.high.repeat(n_steps, axis=0), dtype=dtype) + + def reset(self): + self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype) + return self.observation(self.env.reset()) + + def observation(self, observation): + self.buffer[:-1] = self.buffer[1:] + self.buffer[-1] = observation + return self.buffer + + +class PixelNormalization(gym.ObservationWrapper): + """ + Normalize pixel values in frame --> 0 to 1 + """ + def observation(self, obs): + return np.array(obs).astype(np.float32) / 255.0 + + +def create_mario_env(env): + env = MaxAndSkipEnv(env) + env = MarioRescale84x84(env) + env = ImageToPyTorch(env) + env = BufferWrapper(env, 4) + env = PixelNormalization(env) + return JoypadSpace(env, COMPLEX_MOVEMENT) + +class ActorCritic(nn.Module): + def __init__(self, input_size, action_size=2): + super(ActorCritic, self).__init__() + self.input_size = input_size + self.action_size = action_size + + self.feature = nn.Sequential( + nn.Conv2d(in_channels=self.input_size[0], out_channels=32, kernel_size=8, stride=4), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), + nn.LeakyReLU(), + nn.Flatten(), + nn.Linear(in_features=7*7*64, out_features=512), + nn.LeakyReLU(), + ) + + def actor(self,state): + policy = nn.Sequential( + nn.Linear(in_features=state.shape[1], out_features=state.shape[1]), + nn.LeakyReLU(), + nn.Linear(in_features=state.shape[1], out_features=self.action_size), + nn.Softmax(dim=-1) + ).to(device) + return policy(state) + + def critic(self,state): + value = nn.Sequential( + nn.Linear(in_features=state.shape[1], out_features=state.shape[1]), + nn.LeakyReLU(), + nn.Linear(in_features=state.shape[1], out_features=1) + ).to(device) + return value(state) + + + def forward(self, state): + if state.dim() == 3: + state = state.unsqueeze(0) + state = self.feature(state) + value = self.critic(state) + policy = self.actor(state) + action_probs = Categorical(policy) + log_action_probs = torch.log(action_probs.probs) + return value, action_probs, log_action_probs + + +class Encoder(nn.Module): + def __init__(self, input_size, action_size=2): + super(Encoder, self).__init__() + self.input_size = input_size[0] + self.action_size = action_size + + self.feature_encoder = nn.Sequential( + nn.Conv2d(in_channels=self.input_size, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Flatten(), + nn.Linear(in_features=32*4*4, out_features=256), + ).to(device) + + def forward(self, state): + if state.dim() == 3: + state = state.unsqueeze(0) + state = self.feature_encoder(state) + return state + + +class InverseModel(nn.Module): + def __init__(self, input_size, action_size=2): + super(InverseModel, self).__init__() + self.input_size = input_size[0] + self.action_size = action_size + self.feature_encoder = Encoder(input_size, action_size) + + self.model = nn.Sequential( + nn.Linear(in_features=256*2, out_features=256), + nn.LeakyReLU(), + nn.Linear(in_features=256, out_features=self.action_size), + nn.Softmax(dim=-1) + ).to(device) + + def forward(self, state, next_state): + + encoded_state, next_encoded_state = torch.unsqueeze(state, dim=0), torch.unsqueeze(next_state, dim=0) + encoded_state, next_encoded_state = self.feature_encoder(encoded_state), self.feature_encoder(next_encoded_state) + encoded_states = torch.cat((encoded_state, next_encoded_state), dim=-1) + actions = Categorical(self.model(encoded_states)) + a = float(np.array(actions.sample().cpu())[0]) + action = torch.FloatTensor([a]) + one_hot_action = f.one_hot(action.to(torch.int64), self.action_size) + return one_hot_action, encoded_state, next_encoded_state + + +class ForwardModel(nn.Module): + def __init__(self, encoded_state_size, action_size): + super(ForwardModel, self).__init__() + self.state_size = encoded_state_size + self.action_size = action_size + + self.model = nn.Sequential( + nn.Linear(self.state_size + 1, 256), + nn.LeakyReLU(), + nn.Linear(256, encoded_state_size) + ).to(device) + + def forward(self, state, action): + if state.dim() == 3: + state = state.unsqueeze(0) + if action.dim() == 1: + action = action.unsqueeze(0) + state = torch.cat((state, action), dim=-1) + return self.model(state) + +class ICM(nn.Module): + def __init__(self, state_size, action_size, encoded_state_size=256): + super(ICM, self).__init__() + self.state_size = state_size + self.action_size = action_size + self.inverse_model = InverseModel(state_size, action_size) + self.forward_model = ForwardModel(encoded_state_size, action_size) + self.loss = nn.MSELoss().to(device) + + self.feature_encoder = nn.Sequential( + nn.Conv2d(in_channels=self.state_size[0], out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2), + nn.LeakyReLU(), + nn.Flatten(), + nn.Linear(in_features=32*4*4, out_features=256), + ).to(device) + + self.model = nn.Sequential( + nn.Linear(self.state_size[1] + 1, 256), + nn.LeakyReLU(), + nn.Linear(256, encoded_state_size) + ).to(device) + + def forward(self, state, next_state, action): + if state.dim() == 3: + state = state.unsqueeze(0) + next_state = next_state.unsqueeze(0) + encoded_state, next_encoded_state = self.feature_encoder(state), self.feature_encoder(next_state) + #_, encoded_state, next_encoded_state = self.inverse_model(state, next_state) + action = torch.tensor([action]).to(device) + predicted_next_state = self.forward_model(encoded_state, action) + + intrinsic_reward = 0.5 * self.loss(predicted_next_state, next_encoded_state.detach()) + return intrinsic_reward + +#env = gym.make('SuperMarioBros-1-1-v0') +#env = GrayScaleObservation(env) +#env = ResizeObservation(env, shape=84) +#env = FrameStack(env, num_stack=4) +env = gym.make('SuperMarioBros-1-1-v0') +env = create_mario_env(env) + +ce = nn.CrossEntropyLoss().to(device) +mse = nn.MSELoss().to(device) +icm = ICM(env.observation_space.shape, env.action_space.n).to(device) +ac = ActorCritic(env.observation_space.shape, env.action_space.n).to(device) +optimizer = torch.optim.Adam(list(icm.parameters()) + list(ac.parameters()), lr=0.001) +done = False +t = 0 +gamma = 0.99 +for episode in range(1000): + observation = env.reset() + total_reward = 0 + t_init = t + while not done: + #env.render() + value, actions, log_action_probs = ac(torch.FloatTensor(np.array(observation)).to(device)) + action = actions.sample().item() + + next_observation, reward, done, info = env.step(action) # feedback from environment + observation_array, next_observation_array = torch.FloatTensor(np.array(observation)).to(device), torch.FloatTensor(np.array(next_observation)).to(device) + + int_reward = icm(observation_array, next_observation_array, action) + + delta = torch.squeeze(int_reward + gamma * (ac(next_observation_array)[0]*(1-int(done))) - ac(observation_array)[0]) + actor_loss = -log_action_probs[0,action] * int_reward + critic_loss = delta**2 + + reward = torch.FloatTensor([reward]).to(device) + reward = int_reward + + one_hot_action = icm.inverse_model(observation_array, next_observation_array)[0].to(device) + inverse_loss = ce(one_hot_action.float(), actions.probs) + + + loss = actor_loss + critic_loss + inverse_loss + optimizer.zero_grad() + loss.backward() + optimizer.step() + writer.add_scalar("loss", loss, t) + + + observation = next_observation + + total_reward += reward + t += 1 + #print("timestep: ", t, "reward: ", reward, "loss: ", loss) + if done: + done = False + break + writer.add_scalar("reward", total_reward/(t-t_init), t) +env.close() \ No newline at end of file