# Copyright 2017 The dm_control Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """A collection of MuJoCo-based Reinforcement Learning environments.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import inspect import itertools from dm_control.rl import control from local_dm_control_suite import acrobot from local_dm_control_suite import ball_in_cup from local_dm_control_suite import cartpole from local_dm_control_suite import cheetah from local_dm_control_suite import finger from local_dm_control_suite import fish from local_dm_control_suite import hopper from local_dm_control_suite import humanoid from local_dm_control_suite import humanoid_CMU from local_dm_control_suite import lqr from local_dm_control_suite import manipulator from local_dm_control_suite import pendulum from local_dm_control_suite import point_mass from local_dm_control_suite import quadruped from local_dm_control_suite import reacher from local_dm_control_suite import stacker from local_dm_control_suite import swimmer from local_dm_control_suite import walker # Find all domains imported. 
def _discover_domains(namespace):
  """Returns the entries of `namespace` that are modules exposing `SUITE`.

  Args:
    namespace: A `dict` mapping names to arbitrary objects.

  Returns:
    A `dict` mapping each name to its module, restricted to entries for which
    `inspect.ismodule` is true and which have a `SUITE` attribute.
  """
  return {
      key: candidate
      for key, candidate in namespace.items()
      if inspect.ismodule(candidate) and hasattr(candidate, 'SUITE')
  }


# Collect every imported domain module. This must run at module scope, before
# any further module-level names are introduced, so that `locals()` reflects
# the imports above.
_DOMAINS = _discover_domains(locals())


def _get_tasks(tag):
  """Returns a tuple of (domain name, task name) pairs matching `tag`.

  Args:
    tag: A tag passed to each domain's `SUITE.tagged`, or `None` to select
      every task of every domain.

  Returns:
    A tuple of `(domain_name, task_name)` pairs, with domains visited in
    sorted name order.
  """
  pairs = []
  for domain_name in sorted(_DOMAINS.keys()):
    suite = _DOMAINS[domain_name].SUITE
    selected = suite if tag is None else suite.tagged(tag)
    pairs.extend((domain_name, task_name) for task_name in selected.keys())
  return tuple(pairs)


def _get_tasks_by_domain(tasks):
  """Groups (domain name, task name) pairs by domain.

  Args:
    tasks: An iterable of `(domain_name, task_name)` pairs.

  Returns:
    A `dict` mapping each domain name to a tuple of its task names, in the
    order in which they appear in `tasks`.
  """
  grouped = {}
  for domain_name, task_name in tasks:
    grouped.setdefault(domain_name, []).append(task_name)
  return {domain: tuple(names) for domain, names in grouped.items()}


# Every (domain name, task name) pair known to the suite.
ALL_TASKS = _get_tasks(tag=None)

# Subsets of ALL_TASKS selected through the tag mechanism.
BENCHMARKING = _get_tasks('benchmarking')
EASY = _get_tasks('easy')
HARD = _get_tasks('hard')
# Everything in ALL_TASKS that is not tagged for benchmarking.
EXTRA = tuple(sorted(set(ALL_TASKS).difference(BENCHMARKING)))

# Maps each domain name to a tuple of its task names.
TASKS_BY_DOMAIN = _get_tasks_by_domain(ALL_TASKS)


def load(domain_name, task_name, task_kwargs=None, environment_kwargs=None,
         visualize_reward=False):
  """Returns an environment given a domain name, a task name and settings.

  ```python
  env = suite.load('cartpole', 'balance')
  ```

  This is a thin wrapper that forwards all of its arguments to
  `build_environment`.

  Args:
    domain_name: A string containing the name of a domain.
    task_name: A string containing the name of a task.
    task_kwargs: Optional `dict` of keyword arguments for the task.
    environment_kwargs: Optional `dict` specifying keyword arguments for the
      environment.
    visualize_reward: Optional `bool`. If `True`, object colours in rendered
      frames are set to indicate the reward at each step. Default `False`.

  Raises:
    ValueError: If the domain or task doesn't exist.

  Returns:
    The requested environment.
  """
  return build_environment(
      domain_name,
      task_name,
      task_kwargs=task_kwargs,
      environment_kwargs=environment_kwargs,
      visualize_reward=visualize_reward)


def build_environment(domain_name, task_name, task_kwargs=None,
                      environment_kwargs=None, visualize_reward=False):
  """Builds the suite environment for the given domain and task names.

  Args:
    domain_name: A string containing the name of a domain.
    task_name: A string containing the name of a task.
    task_kwargs: Optional `dict` specifying keyword arguments for the task.
    environment_kwargs: Optional `dict` specifying keyword arguments for the
      environment. When not `None` it is forwarded to the task constructor
      under the `environment_kwargs` key.
    visualize_reward: Optional `bool`. If `True`, object colours in rendered
      frames are set to indicate the reward at each step. Default `False`.

  Raises:
    ValueError: If the domain or task doesn't exist.

  Returns:
    An instance of the requested environment.
  """
  if domain_name not in _DOMAINS:
    raise ValueError('Domain {!r} does not exist.'.format(domain_name))

  suite = _DOMAINS[domain_name].SUITE
  if task_name not in suite:
    message = 'Level {!r} does not exist in domain {!r}.'.format(
        task_name, domain_name)
    raise ValueError(message)

  constructor_kwargs = task_kwargs if task_kwargs else {}
  if environment_kwargs is not None:
    # Work on a copy so the caller's `task_kwargs` dict is never modified.
    constructor_kwargs = constructor_kwargs.copy()
    constructor_kwargs['environment_kwargs'] = environment_kwargs

  environment = suite[task_name](**constructor_kwargs)
  environment.task.visualize_reward = visualize_reward
  return environment