# Copyright 2017 The dm_control Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Wrapper that adds pixel observations to a control environment.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import dm_env from dm_env import specs STATE_KEY = 'state' class Wrapper(dm_env.Environment): """Wraps a control environment and adds a rendered pixel observation.""" def __init__(self, env, pixels_only=True, render_kwargs=None, observation_key='pixels'): """Initializes a new pixel Wrapper. Args: env: The environment to wrap. pixels_only: If True (default), the original set of 'state' observations returned by the wrapped environment will be discarded, and the `OrderedDict` of observations will only contain pixels. If False, the `OrderedDict` will contain the original observations as well as the pixel observations. render_kwargs: Optional `dict` containing keyword arguments passed to the `mujoco.Physics.render` method. observation_key: Optional custom string specifying the pixel observation's key in the `OrderedDict` of observations. Defaults to 'pixels'. Raises: ValueError: If `env`'s observation spec is not compatible with the wrapper. Supported formats are a single array, or a dict of arrays. ValueError: If `env`'s observation already contains the specified `observation_key`. """ if render_kwargs is None: render_kwargs = {} wrapped_observation_spec = env.observation_spec() if isinstance(wrapped_observation_spec, specs.Array): self._observation_is_dict = False invalid_keys = set([STATE_KEY]) elif isinstance(wrapped_observation_spec, collections.MutableMapping): self._observation_is_dict = True invalid_keys = set(wrapped_observation_spec.keys()) else: raise ValueError('Unsupported observation spec structure.') if not pixels_only and observation_key in invalid_keys: raise ValueError('Duplicate or reserved observation key {!r}.' .format(observation_key)) if pixels_only: self._observation_spec = collections.OrderedDict() elif self._observation_is_dict: self._observation_spec = wrapped_observation_spec.copy() else: self._observation_spec = collections.OrderedDict() self._observation_spec[STATE_KEY] = wrapped_observation_spec # Extend observation spec. pixels = env.physics.render(**render_kwargs) pixels_spec = specs.Array( shape=pixels.shape, dtype=pixels.dtype, name=observation_key) self._observation_spec[observation_key] = pixels_spec self._env = env self._pixels_only = pixels_only self._render_kwargs = render_kwargs self._observation_key = observation_key def reset(self): time_step = self._env.reset() return self._add_pixel_observation(time_step) def step(self, action): time_step = self._env.step(action) return self._add_pixel_observation(time_step) def observation_spec(self): return self._observation_spec def action_spec(self): return self._env.action_spec() def _add_pixel_observation(self, time_step): if self._pixels_only: observation = collections.OrderedDict() elif self._observation_is_dict: observation = type(time_step.observation)(time_step.observation) else: observation = collections.OrderedDict() observation[STATE_KEY] = time_step.observation pixels = self._env.physics.render(**self._render_kwargs) observation[self._observation_key] = pixels return time_step._replace(observation=observation) def __getattr__(self, name): return getattr(self._env, name)