sac_ae_if/local_dm_control_suite/utils/randomizers_test.py

165 lines
5.9 KiB
Python
Raw Normal View History

2023-05-24 17:51:34 +00:00
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for randomizers.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Internal dependencies.
from absl.testing import absltest
from absl.testing import parameterized
from dm_control import mujoco
from dm_control.mujoco.wrapper import mjbindings
from dm_control.suite.utils import randomizers
import numpy as np
from six.moves import range
mjlib = mjbindings.mjlib
class RandomizeUnlimitedJointsTest(parameterized.TestCase):
def setUp(self):
self.rand = np.random.RandomState(100)
def test_single_joint_of_each_type(self):
physics = mujoco.Physics.from_xml_string("""<mujoco>
<default>
<joint range="0 90" />
</default>
<worldbody>
<body>
<geom type="box" size="1 1 1"/>
<joint name="free" type="free"/>
</body>
<body>
<geom type="box" size="1 1 1"/>
<joint name="limited_hinge" type="hinge" limited="true"/>
<joint name="slide" type="slide"/>
<joint name="limited_slide" type="slide" limited="true"/>
<joint name="hinge" type="hinge"/>
</body>
<body>
<geom type="box" size="1 1 1"/>
<joint name="ball" type="ball"/>
</body>
<body>
<geom type="box" size="1 1 1"/>
<joint name="limited_ball" type="ball" limited="true"/>
</body>
</worldbody>
</mujoco>""")
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
self.assertNotEqual(0., physics.named.data.qpos['hinge'])
self.assertNotEqual(0., physics.named.data.qpos['limited_hinge'])
self.assertNotEqual(0., physics.named.data.qpos['limited_slide'])
self.assertNotEqual(0., np.sum(physics.named.data.qpos['ball']))
self.assertNotEqual(0., np.sum(physics.named.data.qpos['limited_ball']))
self.assertNotEqual(0., np.sum(physics.named.data.qpos['free'][3:]))
# Unlimited slide and the positional part of the free joint remains
# uninitialized.
self.assertEqual(0., physics.named.data.qpos['slide'])
self.assertEqual(0., np.sum(physics.named.data.qpos['free'][:3]))
def test_multiple_joints_of_same_type(self):
physics = mujoco.Physics.from_xml_string("""<mujoco>
<worldbody>
<body>
<geom type="box" size="1 1 1"/>
<joint name="hinge_1" type="hinge"/>
<joint name="hinge_2" type="hinge"/>
<joint name="hinge_3" type="hinge"/>
</body>
</worldbody>
</mujoco>""")
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
self.assertNotEqual(0., physics.named.data.qpos['hinge_1'])
self.assertNotEqual(0., physics.named.data.qpos['hinge_2'])
self.assertNotEqual(0., physics.named.data.qpos['hinge_3'])
self.assertNotEqual(physics.named.data.qpos['hinge_1'],
physics.named.data.qpos['hinge_2'])
self.assertNotEqual(physics.named.data.qpos['hinge_2'],
physics.named.data.qpos['hinge_3'])
self.assertNotEqual(physics.named.data.qpos['hinge_1'],
physics.named.data.qpos['hinge_3'])
def test_unlimited_hinge_randomization_range(self):
physics = mujoco.Physics.from_xml_string("""<mujoco>
<worldbody>
<body>
<geom type="box" size="1 1 1"/>
<joint name="hinge" type="hinge"/>
</body>
</worldbody>
</mujoco>""")
for _ in range(10):
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
self.assertBetween(physics.named.data.qpos['hinge'], -np.pi, np.pi)
def test_limited_1d_joint_limits_are_respected(self):
physics = mujoco.Physics.from_xml_string("""<mujoco>
<default>
<joint limited="true"/>
</default>
<worldbody>
<body>
<geom type="box" size="1 1 1"/>
<joint name="hinge" type="hinge" range="0 10"/>
<joint name="slide" type="slide" range="30 50"/>
</body>
</worldbody>
</mujoco>""")
for _ in range(10):
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
self.assertBetween(physics.named.data.qpos['hinge'],
np.deg2rad(0), np.deg2rad(10))
self.assertBetween(physics.named.data.qpos['slide'], 30, 50)
def test_limited_ball_joint_are_respected(self):
physics = mujoco.Physics.from_xml_string("""<mujoco>
<worldbody>
<body name="body" zaxis="1 0 0">
<geom type="box" size="1 1 1"/>
<joint name="ball" type="ball" limited="true" range="0 60"/>
</body>
</worldbody>
</mujoco>""")
body_axis = np.array([1., 0., 0.])
joint_axis = np.zeros(3)
for _ in range(10):
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
quat = physics.named.data.qpos['ball']
mjlib.mju_rotVecQuat(joint_axis, body_axis, quat)
angle_cos = np.dot(body_axis, joint_axis)
self.assertGreater(angle_cos, 0.5) # cos(60) = 0.5
if __name__ == '__main__':
absltest.main()