121 lines
4.3 KiB
Python
121 lines
4.3 KiB
Python
|
# Copyright 2017 The dm_control Authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ============================================================================
|
||
|
|
||
|
"""Wrapper that adds pixel observations to a control environment."""
|
||
|
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import collections
|
||
|
|
||
|
import dm_env
|
||
|
from dm_env import specs
|
||
|
|
||
|
STATE_KEY = 'state'
|
||
|
|
||
|
|
||
|
class Wrapper(dm_env.Environment):
|
||
|
"""Wraps a control environment and adds a rendered pixel observation."""
|
||
|
|
||
|
def __init__(self, env, pixels_only=True, render_kwargs=None,
|
||
|
observation_key='pixels'):
|
||
|
"""Initializes a new pixel Wrapper.
|
||
|
|
||
|
Args:
|
||
|
env: The environment to wrap.
|
||
|
pixels_only: If True (default), the original set of 'state' observations
|
||
|
returned by the wrapped environment will be discarded, and the
|
||
|
`OrderedDict` of observations will only contain pixels. If False, the
|
||
|
`OrderedDict` will contain the original observations as well as the
|
||
|
pixel observations.
|
||
|
render_kwargs: Optional `dict` containing keyword arguments passed to the
|
||
|
`mujoco.Physics.render` method.
|
||
|
observation_key: Optional custom string specifying the pixel observation's
|
||
|
key in the `OrderedDict` of observations. Defaults to 'pixels'.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: If `env`'s observation spec is not compatible with the
|
||
|
wrapper. Supported formats are a single array, or a dict of arrays.
|
||
|
ValueError: If `env`'s observation already contains the specified
|
||
|
`observation_key`.
|
||
|
"""
|
||
|
if render_kwargs is None:
|
||
|
render_kwargs = {}
|
||
|
|
||
|
wrapped_observation_spec = env.observation_spec()
|
||
|
|
||
|
if isinstance(wrapped_observation_spec, specs.Array):
|
||
|
self._observation_is_dict = False
|
||
|
invalid_keys = set([STATE_KEY])
|
||
|
elif isinstance(wrapped_observation_spec, collections.MutableMapping):
|
||
|
self._observation_is_dict = True
|
||
|
invalid_keys = set(wrapped_observation_spec.keys())
|
||
|
else:
|
||
|
raise ValueError('Unsupported observation spec structure.')
|
||
|
|
||
|
if not pixels_only and observation_key in invalid_keys:
|
||
|
raise ValueError('Duplicate or reserved observation key {!r}.'
|
||
|
.format(observation_key))
|
||
|
|
||
|
if pixels_only:
|
||
|
self._observation_spec = collections.OrderedDict()
|
||
|
elif self._observation_is_dict:
|
||
|
self._observation_spec = wrapped_observation_spec.copy()
|
||
|
else:
|
||
|
self._observation_spec = collections.OrderedDict()
|
||
|
self._observation_spec[STATE_KEY] = wrapped_observation_spec
|
||
|
|
||
|
# Extend observation spec.
|
||
|
pixels = env.physics.render(**render_kwargs)
|
||
|
pixels_spec = specs.Array(
|
||
|
shape=pixels.shape, dtype=pixels.dtype, name=observation_key)
|
||
|
self._observation_spec[observation_key] = pixels_spec
|
||
|
|
||
|
self._env = env
|
||
|
self._pixels_only = pixels_only
|
||
|
self._render_kwargs = render_kwargs
|
||
|
self._observation_key = observation_key
|
||
|
|
||
|
def reset(self):
|
||
|
time_step = self._env.reset()
|
||
|
return self._add_pixel_observation(time_step)
|
||
|
|
||
|
def step(self, action):
|
||
|
time_step = self._env.step(action)
|
||
|
return self._add_pixel_observation(time_step)
|
||
|
|
||
|
def observation_spec(self):
|
||
|
return self._observation_spec
|
||
|
|
||
|
def action_spec(self):
|
||
|
return self._env.action_spec()
|
||
|
|
||
|
def _add_pixel_observation(self, time_step):
|
||
|
if self._pixels_only:
|
||
|
observation = collections.OrderedDict()
|
||
|
elif self._observation_is_dict:
|
||
|
observation = type(time_step.observation)(time_step.observation)
|
||
|
else:
|
||
|
observation = collections.OrderedDict()
|
||
|
observation[STATE_KEY] = time_step.observation
|
||
|
|
||
|
pixels = self._env.physics.render(**self._render_kwargs)
|
||
|
observation[self._observation_key] = pixels
|
||
|
return time_step._replace(observation=observation)
|
||
|
|
||
|
def __getattr__(self, name):
|
||
|
return getattr(self._env, name)
|