92 lines
3.1 KiB
Python
Executable File
92 lines
3.1 KiB
Python
Executable File
# Copyright 2017 The dm_control Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
|
|
"""Randomization functions."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from dm_control.mujoco.wrapper import mjbindings
|
|
import numpy as np
|
|
from six.moves import range
|
|
|
|
|
|
def random_limited_quaternion(random, limit):
  """Draws a random rotation quaternion whose angle is bounded by `limit`.

  The rotation axis is a normalized 3D Gaussian sample and the rotation angle
  is `limit` scaled by a uniform sample from [0, 1).

  Args:
    random: Random number source providing `randn` and `rand`, e.g. an
      instance of `np.random.RandomState`.
    limit: Scale applied to the uniform sample to obtain the rotation angle
      (presumably in radians, as passed on to `mju_axisAngle2Quat`).

  Returns:
    A NumPy array of length 4 holding the quaternion written by
    `mju_axisAngle2Quat`.
  """
  # Normalizing a Gaussian sample yields a direction on the unit sphere.
  direction = random.randn(3)
  direction = direction / np.linalg.norm(direction)
  rotation_angle = limit * random.rand()

  # MuJoCo writes the result into the preallocated output buffer.
  result = np.zeros(4)
  mjbindings.mjlib.mju_axisAngle2Quat(result, direction, rotation_angle)
  return result
|
|
|
|
|
|
def _random_unit_quaternion(random):
  """Returns a quaternion sampled uniformly on the unit 3-sphere."""
  # Normalizing an isotropic Gaussian sample gives a uniformly distributed
  # direction. Uniform samples (`rand`) must not be used here: they are
  # confined to the non-negative orthant and are not uniform after
  # normalization.
  quat = random.randn(4)
  quat /= np.linalg.norm(quat)
  return quat


def randomize_limited_and_rotational_joints(physics, random=None):
  """Randomizes the positions of joints defined in the physics body.

  The following randomization rules apply:
    - Bounded joints (hinges or sliders) are sampled uniformly in the bounds.
    - Unbounded hinges are sampled uniformly in [-pi, pi].
    - Quaternions for unlimited free joints and ball joints are sampled
      uniformly on the unit 3-sphere.
    - Quaternions for limited ball joints are sampled uniformly on a sector
      of the unit 3-sphere.
    - The linear degrees of freedom of free joints are not randomized.
    - All other joints (unbounded sliders, limited free joints) are left
      unchanged.

  The sampled values are written in place into `physics.named.data.qpos`.

  Args:
    physics: Instance of 'Physics' class that holds a loaded model.
    random: Optional instance of 'np.random.RandomState'. Defaults to the global
      NumPy random state.
  """
  if random is None:
    random = np.random

  joint_types = mjbindings.enums.mjtJoint
  hinge = joint_types.mjJNT_HINGE
  slide = joint_types.mjJNT_SLIDE
  ball = joint_types.mjJNT_BALL
  free = joint_types.mjJNT_FREE

  model = physics.model
  qpos = physics.named.data.qpos

  for joint_id in range(model.njnt):
    joint_name = model.id2name(joint_id, 'joint')
    joint_type = model.jnt_type[joint_id]
    is_limited = model.jnt_limited[joint_id]
    range_min, range_max = model.jnt_range[joint_id]

    if is_limited:
      if joint_type == hinge or joint_type == slide:
        qpos[joint_name] = random.uniform(range_min, range_max)

      elif joint_type == ball:
        # For ball joints only the upper range entry is used, as the maximum
        # rotation angle.
        qpos[joint_name] = random_limited_quaternion(random, range_max)

    else:
      if joint_type == hinge:
        qpos[joint_name] = random.uniform(-np.pi, np.pi)

      elif joint_type == ball:
        qpos[joint_name] = _random_unit_quaternion(random)

      elif joint_type == free:
        # Only the orientation (entries 3 onwards) is overwritten; the first
        # three entries (linear position) are kept as they are.
        # NOTE: this previously drew from `random.rand`, which did not cover
        # the whole 3-sphere; seeded results for free joints differ from the
        # old behaviour.
        qpos[joint_name][3:] = _random_unit_quaternion(random)
|
|
|