sac_ae_if/local_dm_control_suite/utils/randomizers.py
2023-05-24 19:51:34 +02:00

92 lines
3.1 KiB
Python
Executable File

# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Randomization functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dm_control.mujoco.wrapper import mjbindings
import numpy as np
from six.moves import range
def random_limited_quaternion(random, limit):
"""Generates a random quaternion limited to the specified rotations."""
axis = random.randn(3)
axis /= np.linalg.norm(axis)
angle = random.rand() * limit
quaternion = np.zeros(4)
mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle)
return quaternion
def randomize_limited_and_rotational_joints(physics, random=None):
"""Randomizes the positions of joints defined in the physics body.
The following randomization rules apply:
- Bounded joints (hinges or sliders) are sampled uniformly in the bounds.
- Unbounded hinges are samples uniformly in [-pi, pi]
- Quaternions for unlimited free joints and ball joints are sampled
uniformly on the unit 3-sphere.
- Quaternions for limited ball joints are sampled uniformly on a sector
of the unit 3-sphere.
- The linear degrees of freedom of free joints are not randomized.
Args:
physics: Instance of 'Physics' class that holds a loaded model.
random: Optional instance of 'np.random.RandomState'. Defaults to the global
NumPy random state.
"""
random = random or np.random
hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
ball = mjbindings.enums.mjtJoint.mjJNT_BALL
free = mjbindings.enums.mjtJoint.mjJNT_FREE
qpos = physics.named.data.qpos
for joint_id in range(physics.model.njnt):
joint_name = physics.model.id2name(joint_id, 'joint')
joint_type = physics.model.jnt_type[joint_id]
is_limited = physics.model.jnt_limited[joint_id]
range_min, range_max = physics.model.jnt_range[joint_id]
if is_limited:
if joint_type == hinge or joint_type == slide:
qpos[joint_name] = random.uniform(range_min, range_max)
elif joint_type == ball:
qpos[joint_name] = random_limited_quaternion(random, range_max)
else:
if joint_type == hinge:
qpos[joint_name] = random.uniform(-np.pi, np.pi)
elif joint_type == ball:
quat = random.randn(4)
quat /= np.linalg.norm(quat)
qpos[joint_name] = quat
elif joint_type == free:
quat = random.rand(4)
quat /= np.linalg.norm(quat)
qpos[joint_name][3:] = quat