import cv2
import numpy as np
import collections
import gym
from gym.spaces import Box
import torch
from torchvision import transforms as T

import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT, COMPLEX_MOVEMENT


class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        """Return only every `skip`-th frame."""
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        """Repeat the action over `skip` frames and sum the rewards."""
        total_reward = 0.0
        done = False
        info = {}
        for _ in range(self._skip):
            # Accumulate reward and repeat the same action
            # (classic Gym 4-tuple step API, matching the rest of this file)
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, info


class GrayScaleObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def permute_orientation(self, observation):
        # Permute [H, W, C] array to [C, H, W] tensor
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        return observation

    def observation(self, observation):
        observation = self.permute_orientation(observation)
        transform = T.Grayscale()
        observation = transform(observation)
        return observation


class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, shape):
        super().__init__(env)
        if isinstance(shape, int):
            self.shape = (shape, shape)
        else:
            self.shape = tuple(shape)
        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        # Resize, then rescale pixel values from [0, 255] into [0, 1]
        transforms = T.Compose([T.Resize(self.shape), T.Normalize(0, 255)])
        observation = transforms(observation).squeeze(0)
        return observation


class MaxAndSkipEnv(gym.Wrapper):
    """
    Repeat each action over `skip` frames and return the pixel-wise max of the
    last two raw frames (removes NES sprite flicker).
    """

    def __init__(self, env=None, skip=4):
        super(MaxAndSkipEnv, self).__init__(env)
        # Most recent raw observations (for max pooling across time steps)
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        done = False
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self):
        """Clear the frame buffer and initialize it with the first observation."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs


class MarioRescale84x84(gym.ObservationWrapper):
    """
    Converts each RGB frame to greyscale and downsamples it to 84x84.
    """

    def __init__(self, env=None):
        super(MarioRescale84x84, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return MarioRescale84x84.process(obs)

    @staticmethod
    def process(frame):
        if frame.size == 240 * 256 * 3:
            img = np.reshape(frame, [240, 256, 3]).astype(np.float32)
        else:
            assert False, "Unknown resolution."
        # RGB to greyscale using ITU-R 601 luminance weights
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        # Resize to 84x110, then crop the central 84 rows to get an 84x84 frame
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)


class ImageToPyTorch(gym.ObservationWrapper):
    """
    Moves the channel axis to the front ([H, W, C] -> [C, H, W]) so frames
    match PyTorch's expected layout.
    """

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0,
                                                shape=(old_shape[-1], old_shape[0], old_shape[1]),
                                                dtype=np.float32)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)


class BufferWrapper(gym.ObservationWrapper):
    """
    Stacks the last `n_steps` observations along the channel axis so the agent
    can infer motion from a single observation.
    """

    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0),
                                                old_space.high.repeat(n_steps, axis=0),
                                                dtype=dtype)

    def reset(self):
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        return self.observation(self.env.reset())

    def observation(self, observation):
        # Shift the buffer back one step and append the newest frame
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer


class PixelNormalization(gym.ObservationWrapper):
    """
    Normalize pixel values in each frame from [0, 255] to [0, 1].
    """

    def observation(self, obs):
        return np.array(obs).astype(np.float32) / 255.0


def create_mario_env(env):
    env = MaxAndSkipEnv(env)       # frame skipping + flicker removal
    env = MarioRescale84x84(env)   # greyscale 84x84 frames
    env = ImageToPyTorch(env)      # channel-first layout
    env = BufferWrapper(env, 4)    # stack of 4 consecutive frames
    env = PixelNormalization(env)  # pixel values in [0, 1]
    return JoypadSpace(env, COMPLEX_MOVEMENT)
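

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal smoke test assuming the classic Gym (<0.26) step/reset API used
# above; the env id "SuperMarioBros-1-1-v0" and the random-action rollout below
# are assumptions about how the wrapped env might be driven, not a prescribed
# training setup.
if __name__ == "__main__":
    raw_env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0")
    env = create_mario_env(raw_env)

    obs = env.reset()
    print("observation shape:", obs.shape)  # (4, 84, 84), float32 in [0, 1]

    done = False
    total_reward = 0.0
    while not done:
        action = env.action_space.sample()  # random index into COMPLEX_MOVEMENT
        obs, reward, done, info = env.step(action)
        total_reward += reward
    print("episode reward:", total_reward)
    env.close()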