# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
|
||
|
|
||
|
"""A collection of MuJoCo-based Reinforcement Learning environments."""
|
||
|
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import collections
|
||
|
import inspect
|
||
|
import itertools
|
||
|
|
||
|
from dm_control.rl import control
|
||
|
|
||
|
from local_dm_control_suite import acrobot
|
||
|
from local_dm_control_suite import ball_in_cup
|
||
|
from local_dm_control_suite import cartpole
|
||
|
from local_dm_control_suite import cheetah
|
||
|
from local_dm_control_suite import finger
|
||
|
from local_dm_control_suite import fish
|
||
|
from local_dm_control_suite import hopper
|
||
|
from local_dm_control_suite import humanoid
|
||
|
from local_dm_control_suite import humanoid_CMU
|
||
|
from local_dm_control_suite import lqr
|
||
|
from local_dm_control_suite import manipulator
|
||
|
from local_dm_control_suite import pendulum
|
||
|
from local_dm_control_suite import point_mass
|
||
|
from local_dm_control_suite import quadruped
|
||
|
from local_dm_control_suite import reacher
|
||
|
from local_dm_control_suite import stacker
|
||
|
from local_dm_control_suite import swimmer
|
||
|
from local_dm_control_suite import walker
|
||
|
|
||
|
# Find all domains imported.
|
||
|
_DOMAINS = {name: module for name, module in locals().items()
|
||
|
if inspect.ismodule(module) and hasattr(module, 'SUITE')}
|
||
|
|
||
|
|
||
|
def _get_tasks(tag):
  """Collects (domain name, task name) pairs, optionally filtered by a tag.

  Args:
    tag: A tag passed to each domain's `SUITE.tagged` to select a subset of
      its tasks, or `None` to select every task in every domain.

  Returns:
    A tuple of `(domain_name, task_name)` pairs. Domains are visited in sorted
    name order; within a domain, tasks follow the iteration order of the
    selected task mapping's keys.
  """
  def _tasks_for(domain):
    # Untagged requests take the full suite; otherwise only the tagged entries.
    return domain.SUITE if tag is None else domain.SUITE.tagged(tag)

  return tuple(
      (domain_name, task_name)
      for domain_name in sorted(_DOMAINS)
      for task_name in _tasks_for(_DOMAINS[domain_name]).keys())
|
||
|
|
||
|
|
||
|
def _get_tasks_by_domain(tasks):
|
||
|
"""Returns a dict mapping from task name to a tuple of domain names."""
|
||
|
result = collections.defaultdict(list)
|
||
|
|
||
|
for domain_name, task_name in tasks:
|
||
|
result[domain_name].append(task_name)
|
||
|
|
||
|
return {k: tuple(v) for k, v in result.items()}
|
||
|
|
||
|
|
||
|
# Every (domain name, task name) pair available in the suite.
ALL_TASKS = _get_tasks(tag=None)

# Tag-selected subsets of ALL_TASKS. The generator keeps its loop variable out
# of the module namespace.
BENCHMARKING, EASY, HARD = (
    _get_tasks(tag_name) for tag_name in ('benchmarking', 'easy', 'hard'))
# All tasks that are not in the benchmarking subset, in sorted order.
EXTRA = tuple(sorted(set(ALL_TASKS).difference(BENCHMARKING)))

# Maps each domain name to the tuple of its task names.
TASKS_BY_DOMAIN = _get_tasks_by_domain(ALL_TASKS)
|
||
|
|
||
|
|
||
|
def load(domain_name, task_name, task_kwargs=None, environment_kwargs=None,
         visualize_reward=False):
  """Loads an environment given a domain name, a task name and settings.

  Convenience entry point that forwards every argument unchanged to
  `build_environment`.

  ```python
  env = suite.load('cartpole', 'balance')
  ```

  Args:
    domain_name: Name of the domain to load from.
    task_name: Name of the task within that domain.
    task_kwargs: Optional `dict` of keyword arguments for the task.
    environment_kwargs: Optional `dict` of keyword arguments for the
      environment.
    visualize_reward: Optional `bool`. When `True`, object colours in rendered
      frames reflect the reward at each step. Defaults to `False`.

  Raises:
    ValueError: Propagated from `build_environment` when the domain or task
      does not exist.

  Returns:
    The requested environment.
  """
  return build_environment(
      domain_name=domain_name,
      task_name=task_name,
      task_kwargs=task_kwargs,
      environment_kwargs=environment_kwargs,
      visualize_reward=visualize_reward)
|
||
|
|
||
|
|
||
|
def build_environment(domain_name, task_name, task_kwargs=None,
                      environment_kwargs=None, visualize_reward=False):
  """Constructs the named task from the named domain's suite.

  Args:
    domain_name: Name of a registered domain.
    task_name: Name of a task in that domain's `SUITE`.
    task_kwargs: Optional `dict` of keyword arguments for the task
      constructor. It is never mutated; a copy is made when extra entries
      are needed.
    environment_kwargs: Optional `dict` of keyword arguments for the
      environment. When given, it is passed to the task constructor under
      the `environment_kwargs` key.
    visualize_reward: Optional `bool`. When `True`, object colours in rendered
      frames reflect the reward at each step. Defaults to `False`.

  Raises:
    ValueError: If `domain_name` is not a registered domain, or `task_name`
      is not a task of that domain.

  Returns:
    The environment built by the task constructor, with its
    `task.visualize_reward` attribute set to `visualize_reward`.
  """
  # Registry values are module objects, so `None` only signals a missing key.
  domain = _DOMAINS.get(domain_name)
  if domain is None:
    raise ValueError('Domain {!r} does not exist.'.format(domain_name))

  suite = domain.SUITE
  if task_name not in suite:
    raise ValueError('Level {!r} does not exist in domain {!r}.'.format(
        task_name, domain_name))

  constructor_kwargs = task_kwargs or {}
  if environment_kwargs is not None:
    # Copy first so the caller's `task_kwargs` dict is left untouched.
    constructor_kwargs = constructor_kwargs.copy()
    constructor_kwargs['environment_kwargs'] = environment_kwargs

  make_task_env = suite[task_name]
  env = make_task_env(**constructor_kwargs)
  env.task.visualize_reward = visualize_reward
  return env
|