# Copyright 2017 The dm_control Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Tests for the pixel wrapper.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections # Internal dependencies. from absl.testing import absltest from absl.testing import parameterized from local_dm_control_suite import cartpole from dm_control.suite.wrappers import pixels import dm_env from dm_env import specs import numpy as np class FakePhysics(object): def render(self, *args, **kwargs): del args del kwargs return np.zeros((4, 5, 3), dtype=np.uint8) class FakeArrayObservationEnvironment(dm_env.Environment): def __init__(self): self.physics = FakePhysics() def reset(self): return dm_env.restart(np.zeros((2,))) def step(self, action): del action return dm_env.transition(0.0, np.zeros((2,))) def action_spec(self): pass def observation_spec(self): return specs.Array(shape=(2,), dtype=np.float) class PixelsTest(parameterized.TestCase): @parameterized.parameters(True, False) def test_dict_observation(self, pixels_only): pixel_key = 'rgb' env = cartpole.swingup() # Make sure we are testing the right environment for the test. observation_spec = env.observation_spec() self.assertIsInstance(observation_spec, collections.OrderedDict) width = 320 height = 240 # The wrapper should only add one observation. wrapped = pixels.Wrapper(env, observation_key=pixel_key, pixels_only=pixels_only, render_kwargs={'width': width, 'height': height}) wrapped_observation_spec = wrapped.observation_spec() self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict) if pixels_only: self.assertLen(wrapped_observation_spec, 1) self.assertEqual([pixel_key], list(wrapped_observation_spec.keys())) else: expected_length = len(observation_spec) + 1 self.assertLen(wrapped_observation_spec, expected_length) expected_keys = list(observation_spec.keys()) + [pixel_key] self.assertEqual(expected_keys, list(wrapped_observation_spec.keys())) # Check that the added spec item is consistent with the added observation. time_step = wrapped.reset() rgb_observation = time_step.observation[pixel_key] wrapped_observation_spec[pixel_key].validate(rgb_observation) self.assertEqual(rgb_observation.shape, (height, width, 3)) self.assertEqual(rgb_observation.dtype, np.uint8) @parameterized.parameters(True, False) def test_single_array_observation(self, pixels_only): pixel_key = 'depth' env = FakeArrayObservationEnvironment() observation_spec = env.observation_spec() self.assertIsInstance(observation_spec, specs.Array) wrapped = pixels.Wrapper(env, observation_key=pixel_key, pixels_only=pixels_only) wrapped_observation_spec = wrapped.observation_spec() self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict) if pixels_only: self.assertLen(wrapped_observation_spec, 1) self.assertEqual([pixel_key], list(wrapped_observation_spec.keys())) else: self.assertLen(wrapped_observation_spec, 2) self.assertEqual([pixels.STATE_KEY, pixel_key], list(wrapped_observation_spec.keys())) time_step = wrapped.reset() depth_observation = time_step.observation[pixel_key] wrapped_observation_spec[pixel_key].validate(depth_observation) self.assertEqual(depth_observation.shape, (4, 5, 3)) self.assertEqual(depth_observation.dtype, np.uint8) if __name__ == '__main__': absltest.main()