541 lines
17 KiB
Python
541 lines
17 KiB
Python
import atexit
|
|
import functools
|
|
import sys
|
|
import threading
|
|
import traceback
|
|
|
|
import gym
|
|
import numpy as np
|
|
from PIL import Image
|
|
from collections import deque
|
|
|
|
from numpy.core import overrides
|
|
|
|
|
|
class DMC2GYMWrapper:
    """Expose an image-returning environment through dict observations.

    The wrapped environment's ``reset`` result, and the first element of
    its ``step`` result, are placed under the ``'image'`` key. Everything
    else is passed through untouched.
    """

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Attributes not defined on the wrapper are looked up on the env.
        return getattr(self._env, name)

    @property
    def observation_space(self):
        """Dict space holding one uint8 ``'image'`` box of (height, width, 3)."""
        image_shape = (self._env._height, self._env._width, 3,)
        image_space = gym.spaces.Box(0, 255, image_shape, dtype=np.uint8)
        return gym.spaces.Dict({'image': image_space})

    def step(self, action):
        """Step the wrapped env and wrap the returned image in a dict."""
        image, reward, done, info = self._env.step(action)
        return {'image': image}, reward, done, info

    def reset(self):
        """Reset the wrapped env and wrap the returned image in a dict."""
        return {'image': self._env.reset()}
|
|
|
|
|
|
class DeepMindControl:
    """Gym-style front end for a dm_control suite task.

    ``name`` has the form ``'<domain>_<task>'``; it is split on the first
    underscore, with the ``'cup'`` prefix expanded to ``'ball_in_cup'``.
    Observations are the task's observation dict plus a rendered
    ``'image'`` entry.
    """

    def __init__(self, name, size=(64, 64), camera=None):
        domain, task = name.split('_', 1)
        if domain == 'cup':  # Only domain with multiple words.
            domain = 'ball_in_cup'
        if not isinstance(domain, str):
            assert task is None
            self._env = domain()
        else:
            # Imported lazily so dm_control is only needed when used.
            from dm_control import suite
            self._env = suite.load(domain, task)
        self._size = size
        # Default camera: id 2 for the quadruped domain, id 0 otherwise.
        self._camera = (
            dict(quadruped=2).get(domain, 0) if camera is None else camera)

    @property
    def observation_space(self):
        """Dict space: one float32 box per spec entry plus a uint8 image."""
        spaces = {
            key: gym.spaces.Box(-np.inf, np.inf, spec.shape, dtype=np.float32)
            for key, spec in self._env.observation_spec().items()
        }
        image_shape = self._size + (3,)
        spaces['image'] = gym.spaces.Box(0, 255, image_shape, dtype=np.uint8)
        return gym.spaces.Dict(spaces)

    @property
    def action_space(self):
        """Float32 box bounded by the action spec's minimum and maximum."""
        action_spec = self._env.action_spec()
        lower, upper = action_spec.minimum, action_spec.maximum
        return gym.spaces.Box(lower, upper, dtype=np.float32)

    def step(self, action):
        """Advance one step; returns ``(obs, reward, done, info)``.

        A falsy reward (such as ``None``) is reported as ``0``. ``info``
        carries the time step's discount as a float32 array.
        """
        time_step = self._env.step(action)
        obs = dict(time_step.observation)
        obs.update(image=self.render())
        reward = time_step.reward or 0
        done = time_step.last()
        discount = np.array(time_step.discount, np.float32)
        return obs, reward, done, {'discount': discount}

    def reset(self):
        """Reset the task and return the first observation dict."""
        first = self._env.reset()
        obs = dict(first.observation)
        obs.update(image=self.render())
        return obs

    def render(self, *args, **kwargs):
        """Render the physics at the configured size and camera.

        Raises:
            ValueError: if a ``mode`` keyword other than ``'rgb_array'``
                is given.
        """
        requested_mode = kwargs.get('mode', 'rgb_array')
        if requested_mode != 'rgb_array':
            raise ValueError("Only render mode 'rgb_array' is supported.")
        return self._env.physics.render(*self._size, camera_id=self._camera)
|
|
|
|
|
|
class Atari:
    """Atari game with action repeat, frame max-pooling and resizing.

    Observations are dicts with a single ``'image'`` entry: the screen
    resized to ``size``, with a trailing channel axis of length 1 when
    ``grayscale`` is set.

    Args:
        name: Snake-case game name; converted to CamelCase and combined
            with ``NoFrameskip-v{0|4}`` to form the gym id.
        action_repeat: Number of emulator steps per ``step`` call.
        size: Output image size passed to the PIL resize.
        grayscale: Read grayscale screens instead of RGB.
        noops: Upper bound (inclusive) on the random number of no-op
            steps taken at reset; ``0`` disables them.
        life_done: Report ``done`` when the life counter decreases.
        sticky_actions: Selects the ``-v0`` variant when true, else ``-v4``.
    """

    # Shared by all instances and held around gym.make() and env.reset().
    # NOTE(review): presumably those calls are not thread-safe - confirm.
    LOCK = threading.Lock()

    def __init__(
            self, name, action_repeat=4, size=(84, 84), grayscale=True,
            noops=30, life_done=False, sticky_actions=True):
        import gym
        version = 0 if sticky_actions else 4
        name = ''.join(word.title() for word in name.split('_'))
        with self.LOCK:
            self._env = gym.make('{}NoFrameskip-v{}'.format(name, version))
        self._action_repeat = action_repeat
        self._size = size
        self._grayscale = grayscale
        self._noops = noops
        self._life_done = life_done
        self._lives = None
        shape = self._env.observation_space.shape[:2] + \
            (() if grayscale else (3,))
        # Two screen buffers: the last two frames of each repeat window.
        self._buffers = [np.empty(shape, dtype=np.uint8) for _ in range(2)]
        self._random = np.random.RandomState(seed=None)

    @property
    def observation_space(self):
        """Dict space with one uint8 ``'image'`` box."""
        shape = self._size + (1 if self._grayscale else 3,)
        space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8)
        return gym.spaces.Dict({'image': space})

    @property
    def action_space(self):
        """Action space of the underlying gym environment."""
        return self._env.action_space

    def close(self):
        """Close the underlying gym environment."""
        return self._env.close()

    def reset(self):
        """Reset the game, apply random no-ops, and return the first obs."""
        with self.LOCK:
            self._env.reset()
        # randint(1, 1) raises ValueError, so skip the draw when no-ops are
        # disabled instead of crashing.
        if self._noops > 0:
            noops = self._random.randint(1, self._noops + 1)
        else:
            noops = 0
        for _ in range(noops):
            done = self._env.step(0)[2]
            if done:
                with self.LOCK:
                    self._env.reset()
        self._lives = self._env.ale.lives()
        if self._grayscale:
            self._env.ale.getScreenGrayscale(self._buffers[0])
        else:
            self._env.ale.getScreenRGB2(self._buffers[0])
        self._buffers[1].fill(0)
        return self._get_obs()

    def step(self, action):
        """Repeat ``action`` and return ``(obs, total_reward, done, info)``.

        Stops early when the episode (or, with ``life_done``, a life) ends.
        """
        total_reward = 0.0
        # The final two frames of the window go to buffers 0 and 1. With a
        # single-step window the only frame must land in buffer 0, because
        # that is the buffer _get_obs reads; the unclamped offset would be
        # -1 and send it to buffer 1, leaving the observation stale.
        first_kept = max(self._action_repeat - 2, 0)
        for step in range(self._action_repeat):
            _, reward, done, info = self._env.step(action)
            total_reward += reward
            if self._life_done:
                lives = self._env.ale.lives()
                done = done or lives < self._lives
                self._lives = lives
            if done:
                break
            elif step >= first_kept:
                index = step - first_kept
                if self._grayscale:
                    self._env.ale.getScreenGrayscale(self._buffers[index])
                else:
                    self._env.ale.getScreenRGB2(self._buffers[index])
        obs = self._get_obs()
        return obs, total_reward, done, info

    def render(self, mode):
        """Delegate rendering to the underlying gym environment."""
        return self._env.render(mode)

    def _get_obs(self):
        """Build the observation dict from the screen buffers."""
        if self._action_repeat > 1:
            # Pixel-wise max over the two buffered frames, stored in place.
            np.maximum(self._buffers[0],
                       self._buffers[1], out=self._buffers[0])
        image = np.array(Image.fromarray(self._buffers[0]).resize(
            self._size, Image.BILINEAR))
        image = np.clip(image, 0, 255).astype(np.uint8)
        image = image[:, :, None] if self._grayscale else image
        return {'image': image}
|
|
|
|
|
|
class Collect:
    """Record transitions and hand finished episodes to callbacks.

    Each step's observation is cast to the configured precision. When a
    step reports ``done``, the recorded transitions are stacked per key,
    stored under ``info['episode']`` and passed to every callback.
    """

    def __init__(self, env, callbacks=None, precision=32):
        self._env = env
        self._callbacks = () if not callbacks else callbacks
        self._precision = precision
        self._episode = None

    def __getattr__(self, name):
        # Attributes not defined on the wrapper are looked up on the env.
        return getattr(self._env, name)

    def step(self, action):
        """Step the env, record the transition, and finalise on ``done``."""
        obs, reward, done, info = self._env.step(action)
        obs = {key: self._convert(value) for key, value in obs.items()}
        # Without an explicit discount, use 0 on terminal steps and 1 otherwise.
        discount = info.get('discount', np.array(1 - float(done)))
        transition = dict(obs)
        transition.update(action=action, reward=reward, discount=discount)
        self._episode.append(transition)
        if not done:
            return obs, reward, done, info
        stacked = {
            key: [step[key] for step in self._episode]
            for key in self._episode[0]
        }
        episode = {key: self._convert(seq) for key, seq in stacked.items()}
        info['episode'] = episode
        for notify in self._callbacks:
            notify(episode)
        return obs, reward, done, info

    def reset(self):
        """Reset the env and start a new episode record."""
        obs = self._env.reset()
        first = obs.copy()
        first['action'] = np.zeros(self._env.action_space.shape)
        first['reward'] = 0.0
        first['discount'] = 1.0
        self._episode = [first]
        return obs

    def _convert(self, value):
        """Cast ``value`` to an array at the configured precision.

        Floats and signed integers take the configured bit width; uint8
        stays uint8.

        Raises:
            NotImplementedError: for any other dtype.
        """
        array = np.array(value)
        source = array.dtype
        if np.issubdtype(source, np.floating):
            widths = {16: np.float16, 32: np.float32, 64: np.float64}
            return array.astype(widths[self._precision])
        if np.issubdtype(source, np.signedinteger):
            widths = {16: np.int16, 32: np.int32, 64: np.int64}
            return array.astype(widths[self._precision])
        if np.issubdtype(source, np.uint8):
            return array.astype(np.uint8)
        raise NotImplementedError(source)
|
|
|
|
|
|
class TimeLimit:
    """End episodes after a fixed number of steps.

    Once ``duration`` steps have been taken since the last ``reset``, the
    step is reported as done and the wrapper must be reset before it can
    be stepped again.
    """

    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        self._step = None

    def __getattr__(self, name):
        # Attributes not defined on the wrapper are looked up on the env.
        return getattr(self._env, name)

    def step(self, action):
        """Step the env, forcing ``done`` when the step budget is spent."""
        assert self._step is not None, 'Must reset environment.'
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step < self._duration:
            return obs, reward, done, info
        # Budget exhausted: the cut-off keeps discount 1.0 unless the env
        # already supplied one.
        if 'discount' not in info:
            info['discount'] = np.array(1.0).astype(np.float32)
        self._step = None
        return obs, reward, True, info

    def reset(self):
        """Restart the step counter and reset the wrapped env."""
        self._step = 0
        return self._env.reset()
|
|
|
|
|
|
class ActionRepeat:
    """Apply each action several times and sum the rewards."""

    def __init__(self, env, amount):
        self._env = env
        self._amount = amount

    def __getattr__(self, name):
        # Attributes not defined on the wrapper are looked up on the env.
        return getattr(self._env, name)

    def step(self, action):
        """Repeat ``action`` up to ``amount`` times, stopping on ``done``.

        Returns the last observation and info, the summed reward, and
        the last ``done`` flag.
        """
        reward_sum = 0
        remaining = self._amount
        done = False
        while remaining > 0 and not done:
            obs, reward, done, info = self._env.step(action)
            reward_sum += reward
            remaining -= 1
        return obs, reward_sum, done, info
|
|
|
|
|
|
class NormalizeActions:
    """Rescale bounded action dimensions to the range [-1, 1].

    Dimensions whose lower and upper bounds are both finite are exposed
    as [-1, 1] and mapped back linearly on ``step``. All other dimensions
    are passed through unchanged.
    """

    def __init__(self, env):
        self._env = env
        bounds_low = env.action_space.low
        bounds_high = env.action_space.high
        # True where the dimension has finite bounds on both sides.
        self._mask = np.logical_and(
            np.isfinite(bounds_low), np.isfinite(bounds_high))
        self._low = np.where(self._mask, env.action_space.low, -1)
        self._high = np.where(self._mask, env.action_space.high, 1)

    def __getattr__(self, name):
        # Attributes not defined on the wrapper are looked up on the env.
        return getattr(self._env, name)

    @property
    def action_space(self):
        """Float32 box: [-1, 1] where rescaled, stored bounds elsewhere."""
        unit = np.ones_like(self._low)
        lower = np.where(self._mask, -unit, self._low)
        upper = np.where(self._mask, unit, self._high)
        return gym.spaces.Box(lower, upper, dtype=np.float32)

    def step(self, action):
        """Map the normalised action back to env units and step."""
        rescaled = (action + 1) / 2 * (self._high - self._low) + self._low
        env_action = np.where(self._mask, rescaled, action)
        return self._env.step(env_action)
|
|
|
|
|
|
class ObsDict:
    """Wrap plain observations into a one-entry dict under ``key``."""

    def __init__(self, env, key='obs'):
        self._env = env
        self._key = key

    def __getattr__(self, name):
        # Attributes not defined on the wrapper are looked up on the env.
        return getattr(self._env, name)

    @property
    def observation_space(self):
        """Dict space holding the wrapped observation space under ``key``."""
        return gym.spaces.Dict({self._key: self._env.observation_space})

    @property
    def action_space(self):
        """Action space of the wrapped environment."""
        return self._env.action_space

    def step(self, action):
        """Step the env and return its observation as ``{key: array}``."""
        raw_obs, reward, done, info = self._env.step(action)
        return {self._key: np.array(raw_obs)}, reward, done, info

    def reset(self):
        """Reset the env and return its observation as ``{key: array}``."""
        raw_obs = self._env.reset()
        return {self._key: np.array(raw_obs)}
|
|
|
|
|
|
class OneHotAction:
    """Present a discrete action space as one-hot float vectors.

    ``step`` accepts a one-hot vector and forwards the index of its hot
    entry to the wrapped environment.
    """

    def __init__(self, env):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        self._env = env
        # _sample_action needs a random source. Without this attribute the
        # lookup fell through __getattr__ to the wrapped env and only worked
        # when that env happened to expose a `_random` of its own.
        self._random = np.random.RandomState()

    def __getattr__(self, name):
        # Attributes not defined on the wrapper are looked up on the env.
        return getattr(self._env, name)

    @property
    def action_space(self):
        """Float32 box in [0, 1] of length ``n`` that samples one-hot vectors."""
        shape = (self._env.action_space.n,)
        space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
        # Replace uniform box sampling with valid one-hot samples.
        space.sample = self._sample_action
        return space

    def step(self, action):
        """Validate the one-hot ``action`` and step with its index.

        Raises:
            ValueError: if ``action`` is not close to a one-hot vector.
        """
        index = np.argmax(action).astype(int)
        reference = np.zeros_like(action)
        reference[index] = 1
        if not np.allclose(reference, action):
            raise ValueError(f'Invalid one-hot action:\n{action}')
        return self._env.step(index)

    def reset(self):
        """Reset the wrapped environment."""
        return self._env.reset()

    def _sample_action(self):
        """Draw a uniformly random one-hot float32 action."""
        actions = self._env.action_space.n
        index = self._random.randint(0, actions)
        reference = np.zeros(actions, dtype=np.float32)
        reference[index] = 1.0
        return reference
|
|
|
|
|
|
class RewardObs:
    """Add the reward to the observation dict under ``'reward'``."""

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Attributes not defined on the wrapper are looked up on the env.
        return getattr(self._env, name)

    @property
    def observation_space(self):
        """Wrapped observation space plus a scalar float32 ``'reward'`` box."""
        # Work on a copy: writing into the wrapped space's own mapping would
        # alter that space and trip the assert below on the next access when
        # the env stores its space instead of rebuilding it. `.copy()` keeps
        # the mapping type, so key order is handled as before.
        spaces = self._env.observation_space.spaces.copy()
        assert 'reward' not in spaces
        spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
        return gym.spaces.Dict(spaces)

    def step(self, action):
        """Step the env and store the step reward in the observation."""
        obs, reward, done, info = self._env.step(action)
        obs['reward'] = reward
        return obs, reward, done, info

    def reset(self):
        """Reset the env; the first observation carries reward ``0.0``."""
        obs = self._env.reset()
        obs['reward'] = 0.0
        return obs
|
|
|
|
|
|
class Async:
    """Run an environment inline, in a thread, or in a separate process.

    With strategy ``'thread'`` or ``'process'`` the environment is built
    by ``ctor`` inside a worker, and attribute reads and method calls are
    exchanged with it as messages over a pipe. With ``'none'`` the
    environment lives in the calling thread.

    Args:
        ctor: Zero-argument callable that builds the environment.
        strategy: ``'none'``, ``'thread'`` or ``'process'``.

    Raises:
        NotImplementedError: for any other strategy.
    """

    # Message types exchanged over the pipe.
    _ACCESS = 1
    _CALL = 2
    _RESULT = 3
    _EXCEPTION = 4
    _CLOSE = 5

    def __init__(self, ctor, strategy='process'):
        self._strategy = strategy
        if strategy == 'none':
            self._env = ctor()
        elif strategy == 'thread':
            import multiprocessing.dummy as mp
        elif strategy == 'process':
            import multiprocessing as mp
        else:
            raise NotImplementedError(strategy)
        if strategy != 'none':
            self._conn, conn = mp.Pipe()
            self._process = mp.Process(target=self._worker, args=(ctor, conn))
            atexit.register(self.close)
            self._process.start()
        self._obs_space = None
        self._action_space = None

    @property
    def observation_space(self):
        """Observation space of the environment, fetched once and cached."""
        # Compare with None explicitly so a space that evaluates as falsy
        # is not fetched again on every access.
        if self._obs_space is None:
            self._obs_space = self.__getattr__('observation_space')
        return self._obs_space

    @property
    def action_space(self):
        """Action space of the environment, fetched once and cached."""
        if self._action_space is None:
            self._action_space = self.__getattr__('action_space')
        return self._action_space

    def __getattr__(self, name):
        """Read attribute ``name`` from the environment."""
        # Protocol probes such as __setstate__ or __getstate__ (made by copy
        # and pickle) must fail fast. Forwarding them recursed without end on
        # an instance whose __dict__ is still empty, and in worker mode sent
        # the probe over the pipe.
        if name.startswith('__'):
            raise AttributeError(name)
        if self._strategy == 'none':
            return getattr(self._env, name)
        self._conn.send((self._ACCESS, name))
        return self._receive()

    def call(self, name, *args, **kwargs):
        """Call method ``name`` on the environment.

        Args:
            name: Method name.
            *args: Positional arguments for the method.
            **kwargs: Keyword arguments for the method. ``blocking``
                (default ``True``) is consumed here and not forwarded.

        Returns:
            The method's result when ``blocking`` is true, otherwise a
            zero-argument callable that produces it.
        """
        blocking = kwargs.pop('blocking', True)
        if self._strategy == 'none':
            # Honour `blocking` here as well; this path used to hand back
            # the uncalled partial even for blocking calls.
            promise = functools.partial(
                getattr(self._env, name), *args, **kwargs)
        else:
            payload = name, args, kwargs
            self._conn.send((self._CALL, payload))
            promise = self._receive
        return promise() if blocking else promise

    def close(self):
        """Close the environment, or stop the worker and wait for it."""
        if self._strategy == 'none':
            try:
                self._env.close()
            except AttributeError:
                # Raised e.g. when the env defines no close(); best effort.
                pass
            return
        try:
            self._conn.send((self._CLOSE, None))
            self._conn.close()
        except IOError:
            # The connection was already closed.
            pass
        self._process.join()

    def step(self, action, blocking=True):
        """Forward ``step(action)``; see ``call`` for ``blocking``."""
        return self.call('step', action, blocking=blocking)

    def reset(self, blocking=True):
        """Forward ``reset()``; see ``call`` for ``blocking``."""
        return self.call('reset', blocking=blocking)

    def _receive(self):
        """Wait for the worker's next message and return its payload.

        Raises:
            RuntimeError: if the connection to the worker is lost.
            Exception: carrying the worker's stack trace when the worker
                reported an error.
            KeyError: on an unexpected message type.
        """
        try:
            message, payload = self._conn.recv()
        except (ConnectionResetError, EOFError) as err:
            # EOFError is what recv() raises once the other end is closed.
            raise RuntimeError('Environment worker crashed.') from err
        # Re-raise exceptions in the main process.
        if message == self._EXCEPTION:
            stacktrace = payload
            raise Exception(stacktrace)
        if message == self._RESULT:
            return payload
        raise KeyError(f'Received message of unexpected type {message}')

    def _worker(self, ctor, conn):
        """Worker loop: build the env and serve messages until closed."""
        try:
            env = ctor()
            while True:
                try:
                    # Only block for short times to have keyboard exceptions
                    # be raised.
                    if not conn.poll(0.1):
                        continue
                    message, payload = conn.recv()
                except (EOFError, KeyboardInterrupt):
                    break
                if message == self._ACCESS:
                    name = payload
                    result = getattr(env, name)
                    conn.send((self._RESULT, result))
                    continue
                if message == self._CALL:
                    name, args, kwargs = payload
                    result = getattr(env, name)(*args, **kwargs)
                    conn.send((self._RESULT, result))
                    continue
                if message == self._CLOSE:
                    assert payload is None
                    break
                raise KeyError(f'Received message of unknown type {message}')
        except Exception:
            # Report the failure to the parent, which re-raises it.
            stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
            print(f'Error in environment process: {stacktrace}')
            conn.send((self._EXCEPTION, stacktrace))
        conn.close()
|
|
|
|
|
|
class FrameStack(gym.Wrapper):
    """Stack the last ``k`` observations along the first axis."""

    def __init__(self, env, k):
        super().__init__(env)
        self._k = k
        # Bounded queue: appending the newest frame drops the oldest.
        self._frames = deque(maxlen=k)
        base_shape = env.observation_space.shape
        stacked_shape = (base_shape[0] * k,) + base_shape[1:]
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=stacked_shape,
            dtype=env.observation_space.dtype,
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        """Reset the env and fill the stack with ``k`` copies of the first obs."""
        first = self.env.reset()
        self._frames.extend(first for _ in range(self._k))
        return self._get_obs()

    def step(self, action):
        """Step the env and return the updated stack as the observation."""
        frame, reward, done, info = self.env.step(action)
        self._frames.append(frame)
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        """Concatenate the ``k`` stored frames along axis 0."""
        assert len(self._frames) == self._k
        return np.concatenate(tuple(self._frames), axis=0)
|