175 lines
5.7 KiB
Python
175 lines
5.7 KiB
Python
|
import cv2
|
||
|
import numpy as np
|
||
|
import collections
|
||
|
|
||
|
import gym
|
||
|
from gym.spaces import Box
|
||
|
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
from torchvision import transforms as T
|
||
|
|
||
|
import gym_super_mario_bros
|
||
|
from nes_py.wrappers import JoypadSpace
|
||
|
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT, COMPLEX_MOVEMENT
|
||
|
|
||
|
|
||
|
class SkipFrame(gym.Wrapper):
    """Repeat each action for ``skip`` env steps and return only the last frame.

    Expects the wrapped env to use the 5-tuple step API
    ``(obs, reward, terminated, truncated, info)``. Rewards collected over the
    repeated steps are summed.
    """

    def __init__(self, env, skip):
        """Return only every `skip`-th frame.

        Args:
            env: Environment to wrap.
            skip: Number of consecutive env steps each action is repeated for;
                must be at least 1.

        Raises:
            ValueError: If ``skip`` is less than 1.
        """
        super().__init__(env)
        if skip < 1:
            # With skip < 1 the loop in step() never runs and `obs` is unbound.
            raise ValueError(f"skip must be >= 1, got {skip}")
        self._skip = skip

    def step(self, action):
        """Repeat action, and sum reward.

        Stops early as soon as the episode ends, whether by termination or by
        truncation, so no further steps are issued on a finished episode.

        Returns:
            ``(obs, total_reward, terminated, truncated, info)`` from the last
            step actually taken, with the reward summed over all taken steps.
        """
        total_reward = 0.0
        for _ in range(self._skip):
            # Accumulate reward and repeat the same action
            obs, reward, done, trunk, info = self.env.step(action)
            total_reward += reward
            # In the 5-tuple API a truncated episode is over too; the env must
            # be reset before it is stepped again.
            if done or trunk:
                break
        return obs, total_reward, done, trunk, info
|
||
|
|
||
|
|
||
|
class GrayScaleObservation(gym.ObservationWrapper):
    """Convert ``[H, W, C]`` image observations to grayscale torch tensors.

    Each frame is transposed to ``[C, H, W]``, converted to a float32 tensor
    and passed through ``torchvision.transforms.Grayscale`` (one output
    channel), giving a ``[1, H, W]`` tensor.
    """

    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        # NOTE(review): this declares a uint8 (H, W) space, but observation()
        # actually returns a float32 torch tensor of shape (1, H, W).
        # ResizeObservation builds its own space from this (H, W) shape, so it
        # is intentionally left as-is here.
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
        # Build the transform once rather than re-creating it for every frame.
        self._grayscale = T.Grayscale()

    def permute_orientation(self, observation):
        """Return ``observation`` as a float32 ``[C, H, W]`` torch tensor.

        Args:
            observation: Array laid out as ``[H, W, C]``.
        """
        # permute [H, W, C] array to [C, H, W] tensor
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        return observation

    def observation(self, observation):
        """Return the grayscale ``[1, H, W]`` float tensor for one frame."""
        observation = self.permute_orientation(observation)
        observation = self._grayscale(observation)
        return observation
|
||
|
|
||
|
|
||
|
class ResizeObservation(gym.ObservationWrapper):
    """Resize tensor observations to ``shape`` and scale them by 1/255.

    Presumably meant to sit after GrayScaleObservation, i.e. to receive a
    float ``[1, H, W]`` tensor (TODO confirm). The leading dimension is
    squeezed away, so a ``[1, H, W]`` input yields ``[h, w]``.
    """

    def __init__(self, env, shape):
        """
        Args:
            env: Environment to wrap.
            shape: Target size, either an int (square output) or an
                ``(height, width)`` pair.
        """
        super().__init__(env)
        if isinstance(shape, int):
            self.shape = (shape, shape)
        else:
            self.shape = tuple(shape)

        obs_shape = self.shape + self.observation_space.shape[2:]
        # NOTE(review): Normalize(0, 255) computes (x - 0) / 255, so emitted
        # values lie in [0, 1] for [0, 255] input, not in the uint8 [0, 255]
        # range declared here.
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
        # The pipeline is loop-invariant; build it once instead of per frame.
        self._transforms = T.Compose(
            [T.Resize(self.shape), T.Normalize(0, 255)]
        )

    def observation(self, observation):
        """Resize and rescale one frame, dropping the leading dimension."""
        observation = self._transforms(observation).squeeze(0)
        return observation
|
||
|
class MaxAndSkipEnv(gym.Wrapper):
    """
    Each action of the agent is repeated over ``skip`` frames and only every
    ``skip``-th observation is returned.

    The returned observation is the element-wise maximum over the (up to) two
    most recent raw frames; pooling like this is commonly used to hide sprite
    flicker (NOTE(review): confirm that is the intent here).

    Expects the wrapped env to use the 4-tuple step API
    ``(obs, reward, done, info)``.
    """

    def __init__(self, env=None, skip=4):
        """
        Args:
            env: Environment to wrap.
            skip: Number of env steps each action is repeated for; must be
                at least 1.

        Raises:
            ValueError: If ``skip`` is less than 1.
        """
        super().__init__(env)
        if skip < 1:
            # With skip < 1 step() never calls the env and `info` is unbound.
            raise ValueError(f"skip must be >= 1, got {skip}")
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Repeat ``action`` up to ``skip`` times, stopping early on ``done``.

        Returns:
            ``(max_frame, total_reward, done, info)`` where ``max_frame`` is the
            element-wise max of the buffered frames and ``total_reward`` sums
            the rewards of all steps taken.
        """
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        """Clear past frame buffer and init to first obs.

        Keyword arguments (e.g. ``seed``/``options`` in newer gym versions) are
        forwarded unchanged to the wrapped env's ``reset``.
        """
        self._obs_buffer.clear()
        obs = self.env.reset(**kwargs)
        self._obs_buffer.append(obs)
        return obs
|
||
|
|
||
|
|
||
|
class MarioRescale84x84(gym.ObservationWrapper):
    """
    Downsamples/Rescales each frame to size 84x84 with greyscale.

    Input frames must contain exactly 240 * 256 * 3 values; the output is a
    uint8 array of shape (84, 84, 1).
    """

    def __init__(self, env=None):
        super().__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return MarioRescale84x84.process(obs)

    @staticmethod
    def process(frame):
        """Convert one 240x256x3 frame into an 84x84x1 uint8 grayscale image.

        Raises:
            ValueError: If ``frame`` does not hold 240 * 256 * 3 values.
        """
        if frame.size != 240 * 256 * 3:
            # Explicit raise rather than `assert`: asserts are stripped under
            # `python -O`, which would leave `img` unbound below.
            raise ValueError("Unknown resolution.")
        img = np.reshape(frame, [240, 256, 3]).astype(np.float32)

        # RGB -> luminance using the ITU-R BT.601 luma weights (channel 0 is
        # treated as red, channel 2 as blue).
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114

        # cv2.resize takes dsize as (width, height): result is 110 rows x 84 cols.
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        # Keep rows 18..101 to obtain an 84x84 frame.
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)
|
||
|
|
||
|
|
||
|
class ImageToPyTorch(gym.ObservationWrapper):
    """
    Move the channel axis of every frame to the front: (H, W, C) -> (C, H, W).

    Values are not rescaled by this wrapper.
    """

    def __init__(self, env):
        super().__init__(env)
        src_shape = self.observation_space.shape
        height, width, channels = src_shape[0], src_shape[1], src_shape[-1]
        # NOTE(review): the space is declared as float32 in [0, 1] even though
        # this wrapper passes values through unscaled; downstream wrappers
        # build their spaces from these bounds.
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=(channels, height, width),
            dtype=np.float32,
        )

    def observation(self, observation):
        channels_first = np.moveaxis(observation, 2, 0)
        return channels_first
|
||
|
|
||
|
|
||
|
class BufferWrapper(gym.ObservationWrapper):
    """
    Stack the ``n_steps`` most recent observations along axis 0.

    Every observation this wrapper receives is pushed into a rolling buffer
    (oldest first, newest last). The buffer is zero-filled on ``reset``.

    NOTE: ``observation`` returns the same ``self.buffer`` array object on
    every call and mutates it in place; copy it if you need to keep an
    earlier stack around.
    """

    def __init__(self, env, n_steps, dtype=np.float32):
        """
        Args:
            env: Environment to wrap.
            n_steps: Number of observations to stack.
            dtype: dtype of the stacked buffer and of the declared space.
        """
        super().__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(
            old_space.low.repeat(n_steps, axis=0),
            old_space.high.repeat(n_steps, axis=0),
            dtype=dtype,
        )

    def reset(self, **kwargs):
        """Zero the buffer and push the first observation.

        Keyword arguments (e.g. ``seed``/``options`` in newer gym versions) are
        forwarded unchanged to the wrapped env's ``reset``.
        """
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        return self.observation(self.env.reset(**kwargs))

    def observation(self, observation):
        """Shift the buffer by one slot and write ``observation`` last."""
        # NumPy handles the overlapping source/destination slices correctly.
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer
|
||
|
|
||
|
|
||
|
class PixelNormalization(gym.ObservationWrapper):
    """
    Rescale pixel intensities from [0, 255] to [0, 1] as float32.
    """

    def observation(self, obs):
        pixels = np.array(obs).astype(np.float32)
        scaled = pixels / 255.0
        return scaled
|
||
|
|
||
|
|
||
|
def create_mario_env(env, actions=None, n_frames=4):
    """Wrap a raw Super Mario Bros env in this module's preprocessing stack.

    Wrappers are applied innermost first:
    MaxAndSkipEnv -> MarioRescale84x84 -> ImageToPyTorch -> BufferWrapper
    -> PixelNormalization, and finally JoypadSpace to restrict the controls.

    Args:
        env: Environment to wrap (presumably created with
            ``gym_super_mario_bros.make`` — confirm against callers).
        actions: Button combinations handed to ``JoypadSpace``; defaults to
            ``COMPLEX_MOVEMENT``. ``RIGHT_ONLY`` or ``SIMPLE_MOVEMENT`` can be
            passed for a smaller action set.
        n_frames: Number of consecutive processed frames stacked by
            ``BufferWrapper``.

    Returns:
        The fully wrapped environment.
    """
    # None sentinel avoids a shared mutable default argument.
    if actions is None:
        actions = COMPLEX_MOVEMENT
    env = MaxAndSkipEnv(env)
    env = MarioRescale84x84(env)
    env = ImageToPyTorch(env)
    env = BufferWrapper(env, n_frames)
    env = PixelNormalization(env)
    return JoypadSpace(env, actions)
|