98 lines
3.3 KiB
Python
Executable File
98 lines
3.3 KiB
Python
Executable File
# Copyright 2017 The dm_control Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
|
|
"""Cheetah Domain."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
|
|
from dm_control import mujoco
|
|
from dm_control.rl import control
|
|
from local_dm_control_suite import base
|
|
from local_dm_control_suite import common
|
|
from dm_control.utils import containers
|
|
from dm_control.utils import rewards
|
|
|
|
|
|
# How long the simulation will run, in seconds.
|
|
_DEFAULT_TIME_LIMIT = 10
|
|
|
|
# Running speed above which reward is 1.
|
|
_RUN_SPEED = 10
|
|
|
|
SUITE = containers.TaggedTasks()
|
|
|
|
|
|
def get_model_and_assets():
|
|
"""Returns a tuple containing the model XML string and a dict of assets."""
|
|
return common.read_model('cheetah.xml'), common.ASSETS
|
|
|
|
|
|
@SUITE.add('benchmarking')
|
|
def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
|
|
"""Returns the run task."""
|
|
physics = Physics.from_xml_string(*get_model_and_assets())
|
|
task = Cheetah(random=random)
|
|
environment_kwargs = environment_kwargs or {}
|
|
return control.Environment(physics, task, time_limit=time_limit,
|
|
**environment_kwargs)
|
|
|
|
|
|
class Physics(mujoco.Physics):
|
|
"""Physics simulation with additional features for the Cheetah domain."""
|
|
|
|
def speed(self):
|
|
"""Returns the horizontal speed of the Cheetah."""
|
|
return self.named.data.sensordata['torso_subtreelinvel'][0]
|
|
|
|
|
|
class Cheetah(base.Task):
|
|
"""A `Task` to train a running Cheetah."""
|
|
|
|
def initialize_episode(self, physics):
|
|
"""Sets the state of the environment at the start of each episode."""
|
|
# The indexing below assumes that all joints have a single DOF.
|
|
assert physics.model.nq == physics.model.njnt
|
|
is_limited = physics.model.jnt_limited == 1
|
|
lower, upper = physics.model.jnt_range[is_limited].T
|
|
physics.data.qpos[is_limited] = self.random.uniform(lower, upper)
|
|
|
|
# Stabilize the model before the actual simulation.
|
|
for _ in range(200):
|
|
physics.step()
|
|
|
|
physics.data.time = 0
|
|
self._timeout_progress = 0
|
|
super(Cheetah, self).initialize_episode(physics)
|
|
|
|
def get_observation(self, physics):
|
|
"""Returns an observation of the state, ignoring horizontal position."""
|
|
obs = collections.OrderedDict()
|
|
# Ignores horizontal position to maintain translational invariance.
|
|
obs['position'] = physics.data.qpos[1:].copy()
|
|
obs['velocity'] = physics.velocity()
|
|
return obs
|
|
|
|
def get_reward(self, physics):
|
|
"""Returns a reward to the agent."""
|
|
return rewards.tolerance(physics.speed(),
|
|
bounds=(_RUN_SPEED, float('inf')),
|
|
margin=_RUN_SPEED,
|
|
value_at_margin=0,
|
|
sigmoid='linear')
|