diff --git a/conda_env.yml b/conda_env.yml
index 5f0ae94..3289fff 100644
--- a/conda_env.yml
+++ b/conda_env.yml
@@ -2,16 +2,16 @@ name: pytorch_sac_ae
channels:
- defaults
dependencies:
- - python=3.6
+ - python=3.7
- pytorch
- torchvision
- cudatoolkit=9.2
- absl-py
- - pyparsing
+ - pyparsing < 3.0.0
- pip:
- termcolor
- - git+git://github.com/deepmind/dm_control.git
- - git+git://github.com/denisyarats/dmc2gym.git
+ - git+https://github.com/deepmind/dm_control.git
+ - git+https://github.com/denisyarats/dmc2gym.git
- tb-nightly
- imageio
- imageio-ffmpeg
diff --git a/dmc2gym/__init__.py b/dmc2gym/__init__.py
new file mode 100644
index 0000000..7c1d277
--- /dev/null
+++ b/dmc2gym/__init__.py
@@ -0,0 +1,52 @@
+import gym
+from gym.envs.registration import register
+
+
+def make(
+ domain_name,
+ task_name,
+ resource_files,
+ img_source,
+ total_frames,
+ seed=1,
+ visualize_reward=True,
+ from_pixels=False,
+ height=84,
+ width=84,
+ camera_id=0,
+ frame_skip=1,
+ episode_length=1000,
+ environment_kwargs=None
+):
+ env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed)
+
+ if from_pixels:
+ assert not visualize_reward, 'cannot use visualize reward when learning from pixels'
+
+ # shorten episode length
+ max_episode_steps = (episode_length + frame_skip - 1) // frame_skip
+
+ if not env_id in gym.envs.registry.env_specs:
+ register(
+ id=env_id,
+ entry_point='dmc2gym.wrappers:DMCWrapper',
+ kwargs={
+ 'domain_name': domain_name,
+ 'task_name': task_name,
+ 'resource_files': resource_files,
+ 'img_source': img_source,
+ 'total_frames': total_frames,
+ 'task_kwargs': {
+ 'random': seed
+ },
+ 'environment_kwargs': environment_kwargs,
+ 'visualize_reward': visualize_reward,
+ 'from_pixels': from_pixels,
+ 'height': height,
+ 'width': width,
+ 'camera_id': camera_id,
+ 'frame_skip': frame_skip,
+ },
+ max_episode_steps=max_episode_steps
+ )
+ return gym.make(env_id)
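+
+
+# Illustrative usage of `make` (a sketch; the file pattern and values below are hypothetical):
+#
+#   env = dmc2gym.make(domain_name='cheetah', task_name='run',
+#                      resource_files='~/backgrounds/*.mp4', img_source='video',
+#                      total_frames=1000, visualize_reward=False, from_pixels=True)
+#   obs = env.reset()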
diff --git a/dmc2gym/natural_imgsource.py b/dmc2gym/natural_imgsource.py
new file mode 100644
index 0000000..42ef62f
--- /dev/null
+++ b/dmc2gym/natural_imgsource.py
@@ -0,0 +1,183 @@
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import cv2
+import skvideo.io
+import random
+import tqdm
+
+class BackgroundMatting(object):
+ """
+ Produce a mask by masking the given color. This is a simple strategy
+ but effective for many games.
+ """
+ def __init__(self, color):
+ """
+ Args:
+ color: a (r, g, b) tuple or single value for grayscale
+ """
+ self._color = color
+
+ def get_mask(self, img):
+ return img == self._color
+
+
+class ImageSource(object):
+ """
+ Source of natural images to be added to a simulated environment.
+ """
+ def get_image(self):
+ """
+ Returns:
+ an RGB image of [h, w, 3] with a fixed shape.
+ """
+ pass
+
+ def reset(self):
+ """ Called when an episode ends. """
+ pass
+
+
+class FixedColorSource(ImageSource):
+ def __init__(self, shape, color):
+ """
+ Args:
+ shape: [h, w]
+ color: a 3-tuple
+ """
+ self.arr = np.zeros((shape[0], shape[1], 3))
+ self.arr[:, :] = color
+
+ def get_image(self):
+ return self.arr
+
+
+class RandomColorSource(ImageSource):
+ def __init__(self, shape):
+ """
+ Args:
+ shape: [h, w]
+ """
+ self.shape = shape
+ self.arr = None
+ self.reset()
+
+ def reset(self):
+ self._color = np.random.randint(0, 256, size=(3,))
+ self.arr = np.zeros((self.shape[0], self.shape[1], 3))
+ self.arr[:, :] = self._color
+
+ def get_image(self):
+ return self.arr
+
+
+class NoiseSource(ImageSource):
+ def __init__(self, shape, strength=255):
+ """
+ Args:
+ shape: [h, w]
+ strength (int): the strength of noise, in range [0, 255]
+ """
+ self.shape = shape
+ self.strength = strength
+
+ def get_image(self):
+ return np.random.randn(self.shape[0], self.shape[1], 3) * self.strength
+
+
+class RandomImageSource(ImageSource):
+ def __init__(self, shape, filelist, total_frames=None, grayscale=False):
+ """
+ Args:
+ shape: [h, w]
+ filelist: a list of image files
+ """
+ self.grayscale = grayscale
+ self.total_frames = total_frames
+ self.shape = shape
+ self.filelist = filelist
+ self.build_arr()
+ self.current_idx = 0
+ self.reset()
+
+ def build_arr(self):
+ self.total_frames = self.total_frames if self.total_frames else len(self.filelist)
+ self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
+ for i in range(self.total_frames):
+ # if i % len(self.filelist) == 0: random.shuffle(self.filelist)
+ fname = self.filelist[i % len(self.filelist)]
+            if self.grayscale: im = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)[..., None]
+            else: im = cv2.imread(fname, cv2.IMREAD_COLOR)
+            im = cv2.resize(im, (self.shape[1], self.shape[0]))  ## THIS IS NOT A BUG! cv2 uses (width, height)
+            # cv2.resize drops the trailing singleton channel, so restore it for grayscale frames.
+            self.arr[i] = im[..., None] if self.grayscale else im
+
+ def reset(self):
+ self._loc = np.random.randint(0, self.total_frames)
+
+ def get_image(self):
+ return self.arr[self._loc]
+
+
+class RandomVideoSource(ImageSource):
+ def __init__(self, shape, filelist, total_frames=None, grayscale=False):
+ """
+ Args:
+ shape: [h, w]
+ filelist: a list of video files
+ """
+ self.grayscale = grayscale
+ self.total_frames = total_frames
+ self.shape = shape
+ self.filelist = filelist
+ self.build_arr()
+ self.current_idx = 0
+ self.reset()
+
+ def build_arr(self):
+ if not self.total_frames:
+ self.total_frames = 0
+ self.arr = None
+ random.shuffle(self.filelist)
+ for fname in tqdm.tqdm(self.filelist, desc="Loading videos for natural", position=0):
+ if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
+ else: frames = skvideo.io.vread(fname)
+ local_arr = np.zeros((frames.shape[0], self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
+ for i in tqdm.tqdm(range(frames.shape[0]), desc="video frames", position=1):
+ local_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0])) ## THIS IS NOT A BUG! cv2 uses (width, height)
+ if self.arr is None:
+ self.arr = local_arr
+ else:
+ self.arr = np.concatenate([self.arr, local_arr], 0)
+ self.total_frames += local_arr.shape[0]
+ else:
+ self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
+ total_frame_i = 0
+ file_i = 0
+ with tqdm.tqdm(total=self.total_frames, desc="Loading videos for natural") as pbar:
+ while total_frame_i < self.total_frames:
+ if file_i % len(self.filelist) == 0: random.shuffle(self.filelist)
+ file_i += 1
+ fname = self.filelist[file_i % len(self.filelist)]
+ if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
+ else: frames = skvideo.io.vread(fname)
+ for frame_i in range(frames.shape[0]):
+ if total_frame_i >= self.total_frames: break
+ if self.grayscale:
+ self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))[..., None] ## THIS IS NOT A BUG! cv2 uses (width, height)
+ else:
+ self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))
+ pbar.update(1)
+ total_frame_i += 1
+
+
+ def reset(self):
+ self._loc = np.random.randint(0, self.total_frames)
+
+ def get_image(self):
+ img = self.arr[self._loc % self.total_frames]
+ self._loc += 1
+ return img
diff --git a/dmc2gym/wrappers.py b/dmc2gym/wrappers.py
new file mode 100644
index 0000000..077f2eb
--- /dev/null
+++ b/dmc2gym/wrappers.py
@@ -0,0 +1,198 @@
+from gym import core, spaces
+import glob
+import os
+import local_dm_control_suite as suite
+from dm_env import specs
+import numpy as np
+import skimage.io
+
+from dmc2gym import natural_imgsource
+
+
+def _spec_to_box(spec):
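+    """Flattens a list of dm_env specs into a single gym.spaces.Box with matching bounds."""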
+ def extract_min_max(s):
+ assert s.dtype == np.float64 or s.dtype == np.float32
+ dim = np.int(np.prod(s.shape))
+ if type(s) == specs.Array:
+ bound = np.inf * np.ones(dim, dtype=np.float32)
+ return -bound, bound
+ elif type(s) == specs.BoundedArray:
+ zeros = np.zeros(dim, dtype=np.float32)
+ return s.minimum + zeros, s.maximum + zeros
+
+ mins, maxs = [], []
+ for s in spec:
+ mn, mx = extract_min_max(s)
+ mins.append(mn)
+ maxs.append(mx)
+ low = np.concatenate(mins, axis=0)
+ high = np.concatenate(maxs, axis=0)
+ assert low.shape == high.shape
+ return spaces.Box(low, high, dtype=np.float32)
+
+
+def _flatten_obs(obs):
+ obs_pieces = []
+ for v in obs.values():
+ flat = np.array([v]) if np.isscalar(v) else v.ravel()
+ obs_pieces.append(flat)
+ return np.concatenate(obs_pieces, axis=0)
+
+
+class DMCWrapper(core.Env):
+ def __init__(
+ self,
+ domain_name,
+ task_name,
+ resource_files,
+ img_source,
+ total_frames,
+ task_kwargs=None,
+ visualize_reward={},
+ from_pixels=False,
+ height=84,
+ width=84,
+ camera_id=0,
+ frame_skip=1,
+ environment_kwargs=None
+ ):
+ assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'
+ self._from_pixels = from_pixels
+ self._height = height
+ self._width = width
+ self._camera_id = camera_id
+ self._frame_skip = frame_skip
+ self._img_source = img_source
+
+ # create task
+ self._env = suite.load(
+ domain_name=domain_name,
+ task_name=task_name,
+ task_kwargs=task_kwargs,
+ visualize_reward=visualize_reward,
+ environment_kwargs=environment_kwargs
+ )
+
+ # true and normalized action spaces
+ self._true_action_space = _spec_to_box([self._env.action_spec()])
+ self._norm_action_space = spaces.Box(
+ low=-1.0,
+ high=1.0,
+ shape=self._true_action_space.shape,
+ dtype=np.float32
+ )
+
+ # create observation space
+ if from_pixels:
+ self._observation_space = spaces.Box(
+ low=0, high=255, shape=[3, height, width], dtype=np.uint8
+ )
+ else:
+ self._observation_space = _spec_to_box(
+ self._env.observation_spec().values()
+ )
+
+ self._internal_state_space = spaces.Box(
+ low=-np.inf,
+ high=np.inf,
+ shape=self._env.physics.get_state().shape,
+ dtype=np.float32
+ )
+
+ # background
+ if img_source is not None:
+ shape2d = (height, width)
+ if img_source == "color":
+ self._bg_source = natural_imgsource.RandomColorSource(shape2d)
+ elif img_source == "noise":
+ self._bg_source = natural_imgsource.NoiseSource(shape2d)
+ else:
+ files = glob.glob(os.path.expanduser(resource_files))
+ assert len(files), "Pattern {} does not match any files".format(
+ resource_files
+ )
+ if img_source == "images":
+ self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames)
+ elif img_source == "video":
+ self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames)
+ else:
+ raise Exception("img_source %s not defined." % img_source)
+
+ # set seed
+ self.seed(seed=task_kwargs.get('random', 1))
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
+
+ def _get_obs(self, time_step):
+ if self._from_pixels:
+ obs = self.render(
+ height=self._height,
+ width=self._width,
+ camera_id=self._camera_id
+ )
+ if self._img_source is not None:
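+            # Heuristic: treat blue-dominant pixels as the DMC background and overlay the image source there.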
+ mask = np.logical_and((obs[:, :, 2] > obs[:, :, 1]), (obs[:, :, 2] > obs[:, :, 0])) # hardcoded for dmc
+ bg = self._bg_source.get_image()
+ obs[mask] = bg[mask]
+ obs = obs.transpose(2, 0, 1).copy()
+ else:
+ obs = _flatten_obs(time_step.observation)
+ return obs
+
+ def _convert_action(self, action):
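+        """Rescales an action from the normalized [-1, 1] space into the env's true action range."""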
+ action = action.astype(np.float64)
+ true_delta = self._true_action_space.high - self._true_action_space.low
+ norm_delta = self._norm_action_space.high - self._norm_action_space.low
+ action = (action - self._norm_action_space.low) / norm_delta
+ action = action * true_delta + self._true_action_space.low
+ action = action.astype(np.float32)
+ return action
+
+ @property
+ def observation_space(self):
+ return self._observation_space
+
+ @property
+ def internal_state_space(self):
+ return self._internal_state_space
+
+ @property
+ def action_space(self):
+ return self._norm_action_space
+
+ def seed(self, seed):
+ self._true_action_space.seed(seed)
+ self._norm_action_space.seed(seed)
+ self._observation_space.seed(seed)
+
+ def step(self, action):
+ assert self._norm_action_space.contains(action)
+ action = self._convert_action(action)
+ assert self._true_action_space.contains(action)
+ reward = 0
+ extra = {'internal_state': self._env.physics.get_state().copy()}
+
+ for _ in range(self._frame_skip):
+ time_step = self._env.step(action)
+ reward += time_step.reward or 0
+ done = time_step.last()
+ if done:
+ break
+ obs = self._get_obs(time_step)
+ extra['discount'] = time_step.discount
+ return obs, reward, done, extra
+
+ def reset(self):
+ time_step = self._env.reset()
+ obs = self._get_obs(time_step)
+ return obs
+
+ def render(self, mode='rgb_array', height=None, width=None, camera_id=0):
+ assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
+ height = height or self._height
+ width = width or self._width
+ camera_id = camera_id or self._camera_id
+ return self._env.physics.render(
+ height=height, width=width, camera_id=camera_id
+ )
diff --git a/encoder.py b/encoder.py
index b137a62..f6bc0bb 100644
--- a/encoder.py
+++ b/encoder.py
@@ -1,8 +1,9 @@
import torch
import torch.nn as nn
+import torch.nn.functional as F
-def tie_weights(src, trg):
+def tie_weights(src, trg):
assert type(src) == type(trg)
trg.weight = src.weight
trg.bias = src.bias
@@ -28,8 +29,8 @@ class PixelEncoder(nn.Module):
self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))
out_dim = OUT_DIM[num_layers]
- self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim)
- self.ln = nn.LayerNorm(self.feature_dim)
+ self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim * 2)
+ self.ln = nn.LayerNorm(self.feature_dim * 2)
self.outputs = dict()
@@ -63,11 +64,18 @@ class PixelEncoder(nn.Module):
h_norm = self.ln(h_fc)
self.outputs['ln'] = h_norm
+
+ h_tan = torch.tanh(h_norm)
- out = torch.tanh(h_norm)
- self.outputs['tanh'] = out
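+        # Split the doubled feature vector into the mean and log-std of a Gaussian over features.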
+ mu, logstd = torch.chunk(h_tan, 2, dim=-1)
+ self.outputs['mu'] = mu
+ self.outputs['logstd'] = logstd
- return out
+ std = torch.tanh(h_norm)
+ self.outputs['std'] = std
+
+ out = self.reparameterize(mu, logstd)
+ return out, mu, logstd
def copy_conv_weights_from(self, source):
"""Tie convolutional layers"""
@@ -107,6 +115,103 @@ class IdentityEncoder(nn.Module):
pass
+class TransitionModel(nn.Module):
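+    """Recurrent latent transition model: a GRU cell over (state, action) encodings that
+    predicts a diagonal-Gaussian distribution over the next latent state."""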
+ def __init__(self, state_size, hidden_size, action_size, history_size):
+ super().__init__()
+
+ self.state_size = state_size
+ self.hidden_size = hidden_size
+ self.action_size = action_size
+ self.history_size = history_size
+ self.act_fn = nn.ELU()
+
+ self.fc_state_action = nn.Linear(state_size + action_size, hidden_size)
+ self.history_cell = nn.GRUCell(hidden_size, history_size)
+ self.fc_state_mu = nn.Linear(history_size + hidden_size, state_size)
+ self.fc_state_sigma = nn.Linear(history_size + hidden_size, state_size)
+
+ self.batch_norm = nn.BatchNorm1d(hidden_size)
+ self.batch_norm2 = nn.BatchNorm1d(state_size)
+
+ self.min_sigma = 1e-4
+ self.max_sigma = 1e0
+
+ def init_states(self, batch_size, device):
+ self.prev_state = torch.zeros(batch_size, self.state_size).to(device)
+ self.prev_action = torch.zeros(batch_size, self.action_size).to(device)
+ self.prev_history = torch.zeros(batch_size, self.history_size).to(device)
+
+ def get_dist(self, mean, std):
+ distribution = torch.distributions.Normal(mean, std)
+ distribution = torch.distributions.independent.Independent(distribution, 1)
+ return distribution
+
+ def stack_states(self, states, dim=0):
+ s = dict(
+ mean = torch.stack([state['mean'] for state in states], dim=dim),
+ std = torch.stack([state['std'] for state in states], dim=dim),
+ sample = torch.stack([state['sample'] for state in states], dim=dim),
+ history = torch.stack([state['history'] for state in states], dim=dim),)
+        # `states` is a list of dicts, so check the first element for the optional key.
+        if 'distribution' in states[0]:
+ dist = dict(distribution = [state['distribution'] for state in states])
+ s.update(dist)
+ return s
+
+ def seq_to_batch(self, state, name):
+ return dict(
+ sample = torch.reshape(state[name], (state[name].shape[0]* state[name].shape[1], *state[name].shape[2:])))
+
+ def transition_step(self, prev_state, prev_action, prev_hist, prev_not_done):
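+        """Single prior step: masks out the previous state/history at episode boundaries, then predicts the next latent."""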
+ prev_state = prev_state.detach() * prev_not_done
+ prev_hist = prev_hist * prev_not_done
+
+ state_action_enc = self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))
+ state_action_enc = self.act_fn(self.batch_norm(state_action_enc))
+
+ current_hist = self.history_cell(state_action_enc, prev_hist)
+ state_mu = self.act_fn(self.fc_state_mu(torch.cat([state_action_enc, prev_hist], dim=-1)))
+ state_sigma = F.softplus(self.fc_state_sigma(torch.cat([state_action_enc, prev_hist], dim=-1)))
+ sample_state = state_mu + torch.randn_like(state_mu) * state_sigma
+
+ state_enc = {"mean": state_mu, "std": state_sigma, "sample": sample_state, "history": current_hist}
+ return state_enc
+
+ def observe_step(self, prev_state, prev_action, prev_history):
+ state_action_enc = self.act_fn(self.batch_norm(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))))
+ current_history = self.history_cell(state_action_enc, prev_history)
+ state_mu = self.act_fn(self.batch_norm2(self.fc_state_mu(torch.cat([state_action_enc, prev_history], dim=-1))))
+ state_sigma = F.softplus(self.fc_state_sigma(torch.cat([state_action_enc, prev_history], dim=-1)))
+
+ sample_state = state_mu + torch.randn_like(state_mu) * state_sigma
+ state_enc = {"mean": state_mu, "std": state_sigma, "sample": sample_state, "history": current_history}
+ return state_enc
+
+ def observe_rollout(self, rollout_states, rollout_actions, init_history, nonterms):
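+        """Runs observe_step over a time-major sequence, resetting the history where `nonterms` is 0."""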
+ observed_rollout = []
+ for i in range(rollout_states.shape[0]):
+ rollout_states_ = rollout_states[i]
+ rollout_actions_ = rollout_actions[i]
+ init_history_ = nonterms[i] * init_history
+ state_enc = self.observe_step(rollout_states_, rollout_actions_, init_history_)
+ init_history = state_enc["history"]
+ observed_rollout.append(state_enc)
+ observed_rollout = self.stack_states(observed_rollout, dim=0)
+ return observed_rollout
+
+ def reparemeterize(self, mean, std):
+ eps = torch.randn_like(mean)
+ return mean + eps * std
+
+
+def club_loss(x_samples, x_mu, x_logvar, y_samples):
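+    """Sample-based CLUB-style upper bound on the mutual information between x and y
+    (cf. Cheng et al., 2020), given a Gaussian q(y|x) with mean `x_mu` and log-variance
+    `x_logvar`; negatives pair each x with a randomly shuffled y."""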
+ sample_size = x_samples.shape[0]
+ random_index = torch.randperm(sample_size).long()
+
+ positive = -(x_mu - y_samples)**2 / x_logvar.exp()
+ negative = - (x_mu - y_samples[random_index])**2 / x_logvar.exp()
+ upper_bound = (positive.sum(dim = -1) - negative.sum(dim = -1)).mean()
+ return upper_bound/2.
+
_AVAILABLE_ENCODERS = {'pixel': PixelEncoder, 'identity': IdentityEncoder}
diff --git a/local_dm_control_suite/README.md b/local_dm_control_suite/README.md
new file mode 100755
index 0000000..135ab42
--- /dev/null
+++ b/local_dm_control_suite/README.md
@@ -0,0 +1,56 @@
+# DeepMind Control Suite.
+
+This submodule contains the domains and tasks described in the
+[DeepMind Control Suite tech report](https://arxiv.org/abs/1801.00690).
+
+## Quickstart
+
+```python
+from dm_control import suite
+import numpy as np
+
+# Load one task:
+env = suite.load(domain_name="cartpole", task_name="swingup")
+
+# Iterate over a task set:
+for domain_name, task_name in suite.BENCHMARKING:
+ env = suite.load(domain_name, task_name)
+
+# Step through an episode and print out reward, discount and observation.
+action_spec = env.action_spec()
+time_step = env.reset()
+while not time_step.last():
+ action = np.random.uniform(action_spec.minimum,
+ action_spec.maximum,
+ size=action_spec.shape)
+ time_step = env.step(action)
+ print(time_step.reward, time_step.discount, time_step.observation)
+```
+
+## Illustration video
+
+Below is a video montage of solved Control Suite tasks, with reward
+visualisation enabled.
+
+[![Video montage](https://img.youtube.com/vi/rAai4QzcYbs/0.jpg)](https://www.youtube.com/watch?v=rAai4QzcYbs)
+
+
+### Quadruped domain [April 2019]
+
+Roughly based on the 'ant' model introduced by [Schulman et al. 2015](https://arxiv.org/abs/1506.02438). Main modifications to the body are:
+
+- 4 DoFs per leg, 1 constraining tendon.
+- 3 actuators per leg: 'yaw', 'lift', 'extend'.
+- Filtered position actuators with timescale of 100ms.
+- Sensors include an IMU, force/torque sensors, and rangefinders.
+
+Four tasks:
+
+- `walk` and `run`: self-right the body then move forward at a desired speed.
+- `escape`: escape a bowl-shaped random terrain (uses rangefinders).
+- `fetch`: go to a moving ball and bring it to a target.
+
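+For example, these tasks load like any other domain (illustrative snippet, mirroring the
+quickstart above):
+
+```python
+env = suite.load(domain_name="quadruped", task_name="escape")
+```
+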
+All behaviors in the video below were trained with [Abdolmaleki et al's
+MPO](https://arxiv.org/abs/1806.06920).
+
+[![Video montage](https://img.youtube.com/vi/RhRLjbb7pBE/0.jpg)](https://www.youtube.com/watch?v=RhRLjbb7pBE)
diff --git a/local_dm_control_suite/__init__.py b/local_dm_control_suite/__init__.py
new file mode 100755
index 0000000..c4d7cb9
--- /dev/null
+++ b/local_dm_control_suite/__init__.py
@@ -0,0 +1,151 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A collection of MuJoCo-based Reinforcement Learning environments."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import inspect
+import itertools
+
+from dm_control.rl import control
+
+from local_dm_control_suite import acrobot
+from local_dm_control_suite import ball_in_cup
+from local_dm_control_suite import cartpole
+from local_dm_control_suite import cheetah
+from local_dm_control_suite import finger
+from local_dm_control_suite import fish
+from local_dm_control_suite import hopper
+from local_dm_control_suite import humanoid
+from local_dm_control_suite import humanoid_CMU
+from local_dm_control_suite import lqr
+from local_dm_control_suite import manipulator
+from local_dm_control_suite import pendulum
+from local_dm_control_suite import point_mass
+from local_dm_control_suite import quadruped
+from local_dm_control_suite import reacher
+from local_dm_control_suite import stacker
+from local_dm_control_suite import swimmer
+from local_dm_control_suite import walker
+
+# Find all domains imported.
+_DOMAINS = {name: module for name, module in locals().items()
+ if inspect.ismodule(module) and hasattr(module, 'SUITE')}
+
+
+def _get_tasks(tag):
+ """Returns a sequence of (domain name, task name) pairs for the given tag."""
+ result = []
+
+ for domain_name in sorted(_DOMAINS.keys()):
+ domain = _DOMAINS[domain_name]
+
+ if tag is None:
+ tasks_in_domain = domain.SUITE
+ else:
+ tasks_in_domain = domain.SUITE.tagged(tag)
+
+ for task_name in tasks_in_domain.keys():
+ result.append((domain_name, task_name))
+
+ return tuple(result)
+
+
+def _get_tasks_by_domain(tasks):
+ """Returns a dict mapping from task name to a tuple of domain names."""
+ result = collections.defaultdict(list)
+
+ for domain_name, task_name in tasks:
+ result[domain_name].append(task_name)
+
+ return {k: tuple(v) for k, v in result.items()}
+
+
+# A sequence containing all (domain name, task name) pairs.
+ALL_TASKS = _get_tasks(tag=None)
+
+# Subsets of ALL_TASKS, generated via the tag mechanism.
+BENCHMARKING = _get_tasks('benchmarking')
+EASY = _get_tasks('easy')
+HARD = _get_tasks('hard')
+EXTRA = tuple(sorted(set(ALL_TASKS) - set(BENCHMARKING)))
+
+# A mapping from each domain name to a sequence of its task names.
+TASKS_BY_DOMAIN = _get_tasks_by_domain(ALL_TASKS)
+
+
+def load(domain_name, task_name, task_kwargs=None, environment_kwargs=None,
+ visualize_reward=False):
+ """Returns an environment from a domain name, task name and optional settings.
+
+ ```python
+ env = suite.load('cartpole', 'balance')
+ ```
+
+ Args:
+ domain_name: A string containing the name of a domain.
+ task_name: A string containing the name of a task.
+ task_kwargs: Optional `dict` of keyword arguments for the task.
+ environment_kwargs: Optional `dict` specifying keyword arguments for the
+ environment.
+ visualize_reward: Optional `bool`. If `True`, object colours in rendered
+ frames are set to indicate the reward at each step. Default `False`.
+
+ Returns:
+ The requested environment.
+ """
+ return build_environment(domain_name, task_name, task_kwargs,
+ environment_kwargs, visualize_reward)
+
+
+def build_environment(domain_name, task_name, task_kwargs=None,
+ environment_kwargs=None, visualize_reward=False):
+ """Returns an environment from the suite given a domain name and a task name.
+
+ Args:
+ domain_name: A string containing the name of a domain.
+ task_name: A string containing the name of a task.
+ task_kwargs: Optional `dict` specifying keyword arguments for the task.
+ environment_kwargs: Optional `dict` specifying keyword arguments for the
+ environment.
+ visualize_reward: Optional `bool`. If `True`, object colours in rendered
+ frames are set to indicate the reward at each step. Default `False`.
+
+ Raises:
+ ValueError: If the domain or task doesn't exist.
+
+ Returns:
+ An instance of the requested environment.
+ """
+ if domain_name not in _DOMAINS:
+ raise ValueError('Domain {!r} does not exist.'.format(domain_name))
+
+ domain = _DOMAINS[domain_name]
+
+ if task_name not in domain.SUITE:
+ raise ValueError('Level {!r} does not exist in domain {!r}.'.format(
+ task_name, domain_name))
+
+ task_kwargs = task_kwargs or {}
+ if environment_kwargs is not None:
+ task_kwargs = task_kwargs.copy()
+ task_kwargs['environment_kwargs'] = environment_kwargs
+ env = domain.SUITE[task_name](**task_kwargs)
+ env.task.visualize_reward = visualize_reward
+ return env
diff --git a/local_dm_control_suite/acrobot.py b/local_dm_control_suite/acrobot.py
new file mode 100755
index 0000000..a12b892
--- /dev/null
+++ b/local_dm_control_suite/acrobot.py
@@ -0,0 +1,127 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Acrobot domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+import numpy as np
+
+_DEFAULT_TIME_LIMIT = 10
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('acrobot.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns Acrobot balance task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(sparse=False, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns Acrobot sparse balance."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(sparse=True, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Acrobot domain."""
+
+ def horizontal(self):
+ """Returns horizontal (x) component of body frame z-axes."""
+ return self.named.data.xmat[['upper_arm', 'lower_arm'], 'xz']
+
+ def vertical(self):
+ """Returns vertical (z) component of body frame z-axes."""
+ return self.named.data.xmat[['upper_arm', 'lower_arm'], 'zz']
+
+ def to_target(self):
+ """Returns the distance from the tip to the target."""
+ tip_to_target = (self.named.data.site_xpos['target'] -
+ self.named.data.site_xpos['tip'])
+ return np.linalg.norm(tip_to_target)
+
+ def orientations(self):
+ """Returns the sines and cosines of the pole angles."""
+ return np.concatenate((self.horizontal(), self.vertical()))
+
+
+class Balance(base.Task):
+ """An Acrobot `Task` to swing up and balance the pole."""
+
+ def __init__(self, sparse, random=None):
+ """Initializes an instance of `Balance`.
+
+ Args:
+ sparse: A `bool` specifying whether to use a sparse (indicator) reward.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._sparse = sparse
+ super(Balance, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Shoulder and elbow are set to a random position between [-pi, pi).
+
+ Args:
+ physics: An instance of `Physics`.
+ """
+ physics.named.data.qpos[
+ ['shoulder', 'elbow']] = self.random.uniform(-np.pi, np.pi, 2)
+ super(Balance, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of pole orientation and angular velocities."""
+ obs = collections.OrderedDict()
+ obs['orientations'] = physics.orientations()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def _get_reward(self, physics, sparse):
+ target_radius = physics.named.model.site_size['target', 0]
+ return rewards.tolerance(physics.to_target(),
+ bounds=(0, target_radius),
+ margin=0 if sparse else 1)
+
+ def get_reward(self, physics):
+ """Returns a sparse or a smooth reward, as specified in the constructor."""
+ return self._get_reward(physics, sparse=self._sparse)
diff --git a/local_dm_control_suite/acrobot.xml b/local_dm_control_suite/acrobot.xml
new file mode 100755
index 0000000..79b76d9
--- /dev/null
+++ b/local_dm_control_suite/acrobot.xml
@@ -0,0 +1,43 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/local_dm_control_suite/ball_in_cup.py b/local_dm_control_suite/ball_in_cup.py
new file mode 100755
index 0000000..ac3e47f
--- /dev/null
+++ b/local_dm_control_suite/ball_in_cup.py
@@ -0,0 +1,100 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Ball-in-Cup Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.utils import containers
+
+_DEFAULT_TIME_LIMIT = 20 # (seconds)
+_CONTROL_TIMESTEP = .02 # (seconds)
+
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('ball_in_cup.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking', 'easy')
+def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Ball-in-Cup task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = BallInCup(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics with additional features for the Ball-in-Cup domain."""
+
+ def ball_to_target(self):
+ """Returns the vector from the ball to the target."""
+ target = self.named.data.site_xpos['target', ['x', 'z']]
+ ball = self.named.data.xpos['ball', ['x', 'z']]
+ return target - ball
+
+ def in_target(self):
+ """Returns 1 if the ball is in the target, 0 otherwise."""
+ ball_to_target = abs(self.ball_to_target())
+ target_size = self.named.model.site_size['target', [0, 2]]
+ ball_size = self.named.model.geom_size['ball', 0]
+ return float(all(ball_to_target < target_size - ball_size))
+
+
+class BallInCup(base.Task):
+ """The Ball-in-Cup task. Put the ball in the cup."""
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ # Find a collision-free random initial position of the ball.
+ penetrating = True
+ while penetrating:
+ # Assign a random ball position.
+ physics.named.data.qpos['ball_x'] = self.random.uniform(-.2, .2)
+ physics.named.data.qpos['ball_z'] = self.random.uniform(.2, .5)
+ # Check for collisions.
+ physics.after_reset()
+ penetrating = physics.data.ncon > 0
+ super(BallInCup, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of the state."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.position()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a sparse reward."""
+ return physics.in_target()
diff --git a/local_dm_control_suite/ball_in_cup.xml b/local_dm_control_suite/ball_in_cup.xml
new file mode 100755
index 0000000..792073f
--- /dev/null
+++ b/local_dm_control_suite/ball_in_cup.xml
@@ -0,0 +1,54 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/local_dm_control_suite/base.py b/local_dm_control_suite/base.py
new file mode 100755
index 0000000..fd78318
--- /dev/null
+++ b/local_dm_control_suite/base.py
@@ -0,0 +1,112 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Base class for tasks in the Control Suite."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import mujoco
+from dm_control.rl import control
+
+import numpy as np
+
+
+class Task(control.Task):
+ """Base class for tasks in the Control Suite.
+
+ Actions are mapped directly to the states of MuJoCo actuators: each element of
+ the action array is used to set the control input for a single actuator. The
+ ordering of the actuators is the same as in the corresponding MJCF XML file.
+
+ Attributes:
+ random: A `numpy.random.RandomState` instance. This should be used to
+ generate all random variables associated with the task, such as random
+ starting states, observation noise* etc.
+
+ *If sensor noise is enabled in the MuJoCo model then this will be generated
+ using MuJoCo's internal RNG, which has its own independent state.
+ """
+
+ def __init__(self, random=None):
+ """Initializes a new continuous control task.
+
+ Args:
+ random: Optional, either a `numpy.random.RandomState` instance, an integer
+ seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ if not isinstance(random, np.random.RandomState):
+ random = np.random.RandomState(random)
+ self._random = random
+ self._visualize_reward = False
+
+ @property
+ def random(self):
+ """Task-specific `numpy.random.RandomState` instance."""
+ return self._random
+
+ def action_spec(self, physics):
+ """Returns a `BoundedArraySpec` matching the `physics` actuators."""
+ return mujoco.action_spec(physics)
+
+ def initialize_episode(self, physics):
+ """Resets geom colors to their defaults after starting a new episode.
+
+ Subclasses of `base.Task` must delegate to this method after performing
+ their own initialization.
+
+ Args:
+ physics: An instance of `mujoco.Physics`.
+ """
+ self.after_step(physics)
+
+ def before_step(self, action, physics):
+ """Sets the control signal for the actuators to values in `action`."""
+ # Support legacy internal code.
+ action = getattr(action, "continuous_actions", action)
+ physics.set_control(action)
+
+ def after_step(self, physics):
+ """Modifies colors according to the reward."""
+ if self._visualize_reward:
+ reward = np.clip(self.get_reward(physics), 0.0, 1.0)
+ _set_reward_colors(physics, reward)
+
+ @property
+ def visualize_reward(self):
+ return self._visualize_reward
+
+ @visualize_reward.setter
+ def visualize_reward(self, value):
+ if not isinstance(value, bool):
+ raise ValueError("Expected a boolean, got {}.".format(type(value)))
+ self._visualize_reward = value
+
+
+_MATERIALS = ["self", "effector", "target"]
+_DEFAULT = [name + "_default" for name in _MATERIALS]
+_HIGHLIGHT = [name + "_highlight" for name in _MATERIALS]
+
+
+def _set_reward_colors(physics, reward):
+ """Sets the highlight, effector and target colors according to the reward."""
+ assert 0.0 <= reward <= 1.0
+ colors = physics.named.model.mat_rgba
+ default = colors[_DEFAULT]
+ highlight = colors[_HIGHLIGHT]
+ blend_coef = reward ** 4 # Better color distinction near high rewards.
+ colors[_MATERIALS] = blend_coef * highlight + (1.0 - blend_coef) * default
diff --git a/local_dm_control_suite/cartpole.py b/local_dm_control_suite/cartpole.py
new file mode 100755
index 0000000..b8fec14
--- /dev/null
+++ b/local_dm_control_suite/cartpole.py
@@ -0,0 +1,230 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Cartpole domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from lxml import etree
+import numpy as np
+from six.moves import range
+
+
+_DEFAULT_TIME_LIMIT = 10
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets(num_poles=1):
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return _make_model(num_poles), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def balance(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the Cartpole Balance task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(swing_up=False, sparse=False, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def balance_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the sparse reward variant of the Cartpole Balance task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(swing_up=False, sparse=True, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the Cartpole Swing-Up task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(swing_up=True, sparse=False, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the sparse reward variant of teh Cartpole Swing-Up task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(swing_up=True, sparse=True, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+@SUITE.add()
+def two_poles(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the Cartpole Balance task with two poles."""
+ physics = Physics.from_xml_string(*get_model_and_assets(num_poles=2))
+ task = Balance(swing_up=True, sparse=False, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+@SUITE.add()
+def three_poles(time_limit=_DEFAULT_TIME_LIMIT, random=None, num_poles=3,
+ sparse=False, environment_kwargs=None):
+ """Returns the Cartpole Balance task with three or more poles."""
+ physics = Physics.from_xml_string(*get_model_and_assets(num_poles=num_poles))
+ task = Balance(swing_up=True, sparse=sparse, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+def _make_model(n_poles):
+ """Generates an xml string defining a cart with `n_poles` bodies."""
+ xml_string = common.read_model('cartpole.xml')
+ if n_poles == 1:
+ return xml_string
+ mjcf = etree.fromstring(xml_string)
+ parent = mjcf.find('./worldbody/body/body') # Find first pole.
+ # Make chain of poles.
+ for pole_index in range(2, n_poles+1):
+ child = etree.Element('body', name='pole_{}'.format(pole_index),
+ pos='0 0 1', childclass='pole')
+ etree.SubElement(child, 'joint', name='hinge_{}'.format(pole_index))
+ etree.SubElement(child, 'geom', name='pole_{}'.format(pole_index))
+ parent.append(child)
+ parent = child
+ # Move plane down.
+ floor = mjcf.find('./worldbody/geom')
+ floor.set('pos', '0 0 {}'.format(1 - n_poles - .05))
+ # Move cameras back.
+ cameras = mjcf.findall('./worldbody/camera')
+ cameras[0].set('pos', '0 {} 1'.format(-1 - 2*n_poles))
+ cameras[1].set('pos', '0 {} 2'.format(-2*n_poles))
+ return etree.tostring(mjcf, pretty_print=True)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Cartpole domain."""
+
+ def cart_position(self):
+ """Returns the position of the cart."""
+ return self.named.data.qpos['slider'][0]
+
+ def angular_vel(self):
+ """Returns the angular velocity of the pole."""
+ return self.data.qvel[1:]
+
+ def pole_angle_cosine(self):
+ """Returns the cosine of the pole angle."""
+ return self.named.data.xmat[2:, 'zz']
+
+ def bounded_position(self):
+ """Returns the state, with pole angle split into sin/cos."""
+ return np.hstack((self.cart_position(),
+ self.named.data.xmat[2:, ['zz', 'xz']].ravel()))
+
+
+class Balance(base.Task):
+ """A Cartpole `Task` to balance the pole.
+
+ State is initialized either close to the target configuration or at a random
+ configuration.
+ """
+ _CART_RANGE = (-.25, .25)
+ _ANGLE_COSINE_RANGE = (.995, 1)
+
+ def __init__(self, swing_up, sparse, random=None):
+ """Initializes an instance of `Balance`.
+
+ Args:
+ swing_up: A `bool`, which if `True` sets the cart to the middle of the
+ slider and the pole pointing towards the ground. Otherwise, sets the
+ cart to a random position on the slider and the pole to a random
+ near-vertical position.
+ sparse: A `bool`, whether to return a sparse or a smooth reward.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._sparse = sparse
+ self._swing_up = swing_up
+ super(Balance, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Initializes the cart and pole according to `swing_up`, and in both cases
+ adds a small random initial velocity to break symmetry.
+
+ Args:
+ physics: An instance of `Physics`.
+ """
+ nv = physics.model.nv
+ if self._swing_up:
+ physics.named.data.qpos['slider'] = .01*self.random.randn()
+ physics.named.data.qpos['hinge_1'] = np.pi + .01*self.random.randn()
+ physics.named.data.qpos[2:] = .1*self.random.randn(nv - 2)
+ else:
+ physics.named.data.qpos['slider'] = self.random.uniform(-.1, .1)
+ physics.named.data.qpos[1:] = self.random.uniform(-.034, .034, nv - 1)
+ physics.named.data.qvel[:] = 0.01 * self.random.randn(physics.model.nv)
+ super(Balance, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of the (bounded) physics state."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.bounded_position()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def _get_reward(self, physics, sparse):
+ if sparse:
+ cart_in_bounds = rewards.tolerance(physics.cart_position(),
+ self._CART_RANGE)
+ angle_in_bounds = rewards.tolerance(physics.pole_angle_cosine(),
+ self._ANGLE_COSINE_RANGE).prod()
+ return cart_in_bounds * angle_in_bounds
+ else:
+ upright = (physics.pole_angle_cosine() + 1) / 2
+ centered = rewards.tolerance(physics.cart_position(), margin=2)
+ centered = (1 + centered) / 2
+ small_control = rewards.tolerance(physics.control(), margin=1,
+ value_at_margin=0,
+ sigmoid='quadratic')[0]
+ small_control = (4 + small_control) / 5
+ small_velocity = rewards.tolerance(physics.angular_vel(), margin=5).min()
+ small_velocity = (1 + small_velocity) / 2
+ return upright.mean() * small_control * small_velocity * centered
+
+ def get_reward(self, physics):
+ """Returns a sparse or a smooth reward, as specified in the constructor."""
+ return self._get_reward(physics, sparse=self._sparse)
diff --git a/local_dm_control_suite/cartpole.xml b/local_dm_control_suite/cartpole.xml
new file mode 100755
index 0000000..e01869d
--- /dev/null
+++ b/local_dm_control_suite/cartpole.xml
@@ -0,0 +1,37 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/local_dm_control_suite/cheetah.py b/local_dm_control_suite/cheetah.py
new file mode 100755
index 0000000..7dd2a63
--- /dev/null
+++ b/local_dm_control_suite/cheetah.py
@@ -0,0 +1,97 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Cheetah Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+
+# How long the simulation will run, in seconds.
+_DEFAULT_TIME_LIMIT = 10
+
+# Running speed above which reward is 1.
+_RUN_SPEED = 10
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('cheetah.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Cheetah(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Cheetah domain."""
+
+ def speed(self):
+ """Returns the horizontal speed of the Cheetah."""
+ return self.named.data.sensordata['torso_subtreelinvel'][0]
+
+
+class Cheetah(base.Task):
+ """A `Task` to train a running Cheetah."""
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ # The indexing below assumes that all joints have a single DOF.
+ assert physics.model.nq == physics.model.njnt
+ is_limited = physics.model.jnt_limited == 1
+ lower, upper = physics.model.jnt_range[is_limited].T
+ physics.data.qpos[is_limited] = self.random.uniform(lower, upper)
+
+ # Stabilize the model before the actual simulation.
+ for _ in range(200):
+ physics.step()
+
+ physics.data.time = 0
+ self._timeout_progress = 0
+ super(Cheetah, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of the state, ignoring horizontal position."""
+ obs = collections.OrderedDict()
+ # Ignores horizontal position to maintain translational invariance.
+ obs['position'] = physics.data.qpos[1:].copy()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ return rewards.tolerance(physics.speed(),
+ bounds=(_RUN_SPEED, float('inf')),
+ margin=_RUN_SPEED,
+ value_at_margin=0,
+ sigmoid='linear')
diff --git a/local_dm_control_suite/cheetah.xml b/local_dm_control_suite/cheetah.xml
new file mode 100755
index 0000000..1952b5e
--- /dev/null
+++ b/local_dm_control_suite/cheetah.xml
@@ -0,0 +1,73 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/local_dm_control_suite/common/__init__.py b/local_dm_control_suite/common/__init__.py
new file mode 100755
index 0000000..62eab26
--- /dev/null
+++ b/local_dm_control_suite/common/__init__.py
@@ -0,0 +1,39 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Functions to manage the common assets for domains."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from dm_control.utils import io as resources
+
+_SUITE_DIR = os.path.dirname(os.path.dirname(__file__))
+_FILENAMES = [
+ "./common/materials.xml",
+ "./common/materials_white_floor.xml",
+ "./common/skybox.xml",
+ "./common/visual.xml",
+]
+
+ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
+ for filename in _FILENAMES}
+
+
+def read_model(model_filename):
+ """Reads a model XML file and returns its contents as a string."""
+ return resources.GetResource(os.path.join(_SUITE_DIR, model_filename))
diff --git a/local_dm_control_suite/common/materials.xml b/local_dm_control_suite/common/materials.xml
new file mode 100755
index 0000000..5a3b169
--- /dev/null
+++ b/local_dm_control_suite/common/materials.xml
@@ -0,0 +1,23 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/local_dm_control_suite/common/materials_white_floor.xml b/local_dm_control_suite/common/materials_white_floor.xml
new file mode 100755
index 0000000..a1e35c2
--- /dev/null
+++ b/local_dm_control_suite/common/materials_white_floor.xml
@@ -0,0 +1,23 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/local_dm_control_suite/common/skybox.xml b/local_dm_control_suite/common/skybox.xml
new file mode 100755
index 0000000..b888692
--- /dev/null
+++ b/local_dm_control_suite/common/skybox.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
diff --git a/local_dm_control_suite/common/visual.xml b/local_dm_control_suite/common/visual.xml
new file mode 100755
index 0000000..ede15ad
--- /dev/null
+++ b/local_dm_control_suite/common/visual.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
diff --git a/local_dm_control_suite/demos/mocap_demo.py b/local_dm_control_suite/demos/mocap_demo.py
new file mode 100755
index 0000000..2e2c7ca
--- /dev/null
+++ b/local_dm_control_suite/demos/mocap_demo.py
@@ -0,0 +1,84 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Demonstration of amc parsing for CMU mocap database.
+
+To run the demo, supply a path to a `.amc` file:
+
+ python mocap_demo --filename='path/to/mocap.amc'
+
+CMU motion capture clips are available at mocap.cs.cmu.edu
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+# Internal dependencies.
+
+from absl import app
+from absl import flags
+
+from local_dm_control_suite import humanoid_CMU
+from dm_control.suite.utils import parse_amc
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('filename', None, 'amc file to be converted.')
+flags.DEFINE_integer('max_num_frames', 90,
+ 'Maximum number of frames for plotting/playback')
+
+
+def main(unused_argv):
+ env = humanoid_CMU.stand()
+
+ # Parse and convert specified clip.
+ converted = parse_amc.convert(FLAGS.filename,
+ env.physics, env.control_timestep())
+
+ max_frame = min(FLAGS.max_num_frames, converted.qpos.shape[1] - 1)
+
+ width = 480
+ height = 480
+ video = np.zeros((max_frame, height, 2 * width, 3), dtype=np.uint8)
+
+ for i in range(max_frame):
+ p_i = converted.qpos[:, i]
+ with env.physics.reset_context():
+ env.physics.data.qpos[:] = p_i
+ video[i] = np.hstack([env.physics.render(height, width, camera_id=0),
+ env.physics.render(height, width, camera_id=1)])
+
+ tic = time.time()
+ for i in range(max_frame):
+ if i == 0:
+ img = plt.imshow(video[i])
+ else:
+ img.set_data(video[i])
+ toc = time.time()
+ clock_dt = toc - tic
+ tic = time.time()
+ # Real-time playback not always possible as clock_dt > .03
+ plt.pause(max(0.01, 0.03 - clock_dt)) # Need min display time > 0.0.
+ plt.draw()
+ plt.waitforbuttonpress()
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('filename')
+ app.run(main)
diff --git a/local_dm_control_suite/demos/zeros.amc b/local_dm_control_suite/demos/zeros.amc
new file mode 100755
index 0000000..b4590a4
--- /dev/null
+++ b/local_dm_control_suite/demos/zeros.amc
@@ -0,0 +1,213 @@
+#DUMMY AMC for testing
+:FULLY-SPECIFIED
+:DEGREES
+1
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+2
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+3
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+4
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+5
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+6
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+7
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
diff --git a/local_dm_control_suite/explore.py b/local_dm_control_suite/explore.py
new file mode 100755
index 0000000..06fb0a8
--- /dev/null
+++ b/local_dm_control_suite/explore.py
@@ -0,0 +1,84 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Control suite environments explorer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import app
+from absl import flags
+from dm_control import suite
+from dm_control.suite.wrappers import action_noise
+from six.moves import input
+
+from dm_control import viewer
+
+
+_ALL_NAMES = ['.'.join(domain_task) for domain_task in suite.ALL_TASKS]
+
+flags.DEFINE_enum('environment_name', None, _ALL_NAMES,
+ 'Optional \'domain_name.task_name\' pair specifying the '
+ 'environment to load. If unspecified a prompt will appear to '
+ 'select one.')
+flags.DEFINE_bool('timeout', True, 'Whether episodes should have a time limit.')
+flags.DEFINE_bool('visualize_reward', True,
+ 'Whether to vary the colors of geoms according to the '
+ 'current reward value.')
+flags.DEFINE_float('action_noise', 0.,
+ 'Standard deviation of Gaussian noise to apply to actions, '
+ 'expressed as a fraction of the max-min range for each '
+ 'action dimension. Defaults to 0, i.e. no noise.')
+FLAGS = flags.FLAGS
+
+
+def prompt_environment_name(prompt, values):
+ environment_name = None
+ while not environment_name:
+ environment_name = input(prompt)
+ if not environment_name or environment_name not in values:
+ print('"%s" is not a valid environment name.' % environment_name)
+ environment_name = None
+ return environment_name
+
+
+def main(argv):
+ del argv
+ environment_name = FLAGS.environment_name
+ if environment_name is None:
+ print('\n '.join(['Available environments:'] + _ALL_NAMES))
+ environment_name = prompt_environment_name(
+ 'Please select an environment name: ', _ALL_NAMES)
+
+ index = _ALL_NAMES.index(environment_name)
+ domain_name, task_name = suite.ALL_TASKS[index]
+
+ task_kwargs = {}
+ if not FLAGS.timeout:
+ task_kwargs['time_limit'] = float('inf')
+
+ def loader():
+ env = suite.load(
+ domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs)
+ env.task.visualize_reward = FLAGS.visualize_reward
+ if FLAGS.action_noise > 0:
+ env = action_noise.Wrapper(env, scale=FLAGS.action_noise)
+ return env
+
+ viewer.launch(loader)
+
+
+if __name__ == '__main__':
+ app.run(main)
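
Note: the explorer above resolves 'domain.task' names against `suite.ALL_TASKS` and hands a loader to `viewer.launch`. The same lookup can be done headlessly; a minimal sketch, assuming the standard `dm_control.suite` API already imported above (the chosen task name is only an example):

import numpy as np
from dm_control import suite

# List the available 'domain.task' pairs, as the explorer does.
all_names = ['.'.join(pair) for pair in suite.ALL_TASKS]

# Load one task and step it with uniform random actions.
env = suite.load(domain_name='cartpole', task_name='swingup')
spec = env.action_spec()
time_step = env.reset()
for _ in range(200):
    action = np.random.uniform(spec.minimum, spec.maximum, size=spec.shape)
    time_step = env.step(action)
print(time_step.reward, time_step.last())
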
diff --git a/local_dm_control_suite/finger.py b/local_dm_control_suite/finger.py
new file mode 100755
index 0000000..e700db6
--- /dev/null
+++ b/local_dm_control_suite/finger.py
@@ -0,0 +1,217 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Finger Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+import numpy as np
+from six.moves import range
+
+_DEFAULT_TIME_LIMIT = 20 # (seconds)
+_CONTROL_TIMESTEP = .02 # (seconds)
+# For TURN tasks, the 'tip' geom needs to enter a spherical target of sizes:
+_EASY_TARGET_SIZE = 0.07
+_HARD_TARGET_SIZE = 0.03
+# Initial spin velocity for the Stop task.
+_INITIAL_SPIN_VELOCITY = 100
+# Spinning slower than this value (radian/second) is considered stopped.
+_STOP_VELOCITY = 1e-6
+# Spinning faster than this value (radian/second) is considered spinning.
+_SPIN_VELOCITY = 15.0
+
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('finger.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def spin(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Spin task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Spin(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def turn_easy(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the easy Turn task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Turn(target_radius=_EASY_TARGET_SIZE, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def turn_hard(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the hard Turn task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Turn(target_radius=_HARD_TARGET_SIZE, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Finger domain."""
+
+ def touch(self):
+ """Returns logarithmically scaled signals from the two touch sensors."""
+ return np.log1p(self.named.data.sensordata[['touchtop', 'touchbottom']])
+
+ def hinge_velocity(self):
+ """Returns the velocity of the hinge joint."""
+ return self.named.data.sensordata['hinge_velocity']
+
+ def tip_position(self):
+ """Returns the (x,z) position of the tip relative to the hinge."""
+ return (self.named.data.sensordata['tip'][[0, 2]] -
+ self.named.data.sensordata['spinner'][[0, 2]])
+
+ def bounded_position(self):
+ """Returns the positions, with the hinge angle replaced by tip position."""
+ return np.hstack((self.named.data.sensordata[['proximal', 'distal']],
+ self.tip_position()))
+
+ def velocity(self):
+ """Returns the velocities (extracted from sensordata)."""
+ return self.named.data.sensordata[['proximal_velocity',
+ 'distal_velocity',
+ 'hinge_velocity']]
+
+ def target_position(self):
+ """Returns the (x,z) position of the target relative to the hinge."""
+ return (self.named.data.sensordata['target'][[0, 2]] -
+ self.named.data.sensordata['spinner'][[0, 2]])
+
+ def to_target(self):
+ """Returns the vector from the tip to the target."""
+ return self.target_position() - self.tip_position()
+
+ def dist_to_target(self):
+ """Returns the signed distance to the target surface, negative is inside."""
+ return (np.linalg.norm(self.to_target()) -
+ self.named.model.site_size['target', 0])
+
+
+class Spin(base.Task):
+ """A Finger `Task` to spin the stopped body."""
+
+ def __init__(self, random=None):
+ """Initializes a new `Spin` instance.
+
+ Args:
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ super(Spin, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ physics.named.model.site_rgba['target', 3] = 0
+ physics.named.model.site_rgba['tip', 3] = 0
+ physics.named.model.dof_damping['hinge'] = .03
+ _set_random_joint_angles(physics, self.random)
+ super(Spin, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns state and touch sensors, and target info."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.bounded_position()
+ obs['velocity'] = physics.velocity()
+ obs['touch'] = physics.touch()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a sparse reward."""
+ return float(physics.hinge_velocity() <= -_SPIN_VELOCITY)
+
+
+class Turn(base.Task):
+ """A Finger `Task` to turn the body to a target angle."""
+
+ def __init__(self, target_radius, random=None):
+ """Initializes a new `Turn` instance.
+
+ Args:
+ target_radius: Radius of the target site, which specifies the goal angle.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._target_radius = target_radius
+ super(Turn, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ target_angle = self.random.uniform(-np.pi, np.pi)
+ hinge_x, hinge_z = physics.named.data.xanchor['hinge', ['x', 'z']]
+ radius = physics.named.model.geom_size['cap1'].sum()
+ target_x = hinge_x + radius * np.sin(target_angle)
+ target_z = hinge_z + radius * np.cos(target_angle)
+ physics.named.model.site_pos['target', ['x', 'z']] = target_x, target_z
+ physics.named.model.site_size['target', 0] = self._target_radius
+
+ _set_random_joint_angles(physics, self.random)
+
+ super(Turn, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns state, touch sensors, and target info."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.bounded_position()
+ obs['velocity'] = physics.velocity()
+ obs['touch'] = physics.touch()
+ obs['target_position'] = physics.target_position()
+ obs['dist_to_target'] = physics.dist_to_target()
+ return obs
+
+ def get_reward(self, physics):
+ return float(physics.dist_to_target() <= 0)
+
+
+def _set_random_joint_angles(physics, random, max_attempts=1000):
+ """Sets the joints to a random collision-free state."""
+
+ for _ in range(max_attempts):
+ randomizers.randomize_limited_and_rotational_joints(physics, random)
+ # Check for collisions.
+ physics.after_reset()
+ if physics.data.ncon == 0:
+ break
+ else:
+ raise RuntimeError('Could not find a collision-free state '
+ 'after {} attempts'.format(max_attempts))
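
For reference, a minimal sketch of driving the Spin task defined above, assuming `local_dm_control_suite` is importable as a package (episode length and action choice are arbitrary):

import numpy as np
from local_dm_control_suite import finger

env = finger.spin(random=0)   # control.Environment built by the factory above
spec = env.action_spec()
time_step = env.reset()
for _ in range(10):
    action = np.random.uniform(spec.minimum, spec.maximum, size=spec.shape)
    time_step = env.step(action)
    # Observation keys follow Spin.get_observation: 'position', 'velocity', 'touch'.
print(time_step.reward, list(time_step.observation))
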
diff --git a/local_dm_control_suite/finger.xml b/local_dm_control_suite/finger.xml
new file mode 100755
index 0000000..3b35986
--- /dev/null
+++ b/local_dm_control_suite/finger.xml
@@ -0,0 +1,72 @@
+[MuJoCo XML markup stripped during extraction; this file is the MJCF model for the Finger domain.]
diff --git a/local_dm_control_suite/fish.py b/local_dm_control_suite/fish.py
new file mode 100755
index 0000000..3262def
--- /dev/null
+++ b/local_dm_control_suite/fish.py
@@ -0,0 +1,176 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Fish Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+import numpy as np
+
+
+_DEFAULT_TIME_LIMIT = 40
+_CONTROL_TIMESTEP = .04
+_JOINTS = ['tail1',
+ 'tail_twist',
+ 'tail2',
+ 'finright_roll',
+ 'finright_pitch',
+ 'finleft_roll',
+ 'finleft_pitch']
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('fish.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def upright(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the Fish Upright task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Upright(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def swim(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Fish Swim task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Swim(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Fish domain."""
+
+ def upright(self):
+ """Returns the projection of the torso z-axis onto the world z-axis."""
+ return self.named.data.xmat['torso', 'zz']
+
+ def torso_velocity(self):
+ """Returns velocities and angular velocities of the torso."""
+ return self.data.sensordata
+
+ def joint_velocities(self):
+ """Returns the joint velocities."""
+ return self.named.data.qvel[_JOINTS]
+
+ def joint_angles(self):
+ """Returns the joint positions."""
+ return self.named.data.qpos[_JOINTS]
+
+ def mouth_to_target(self):
+ """Returns the vector from mouth to target, in the mouth's local frame."""
+ data = self.named.data
+ mouth_to_target_global = data.geom_xpos['target'] - data.geom_xpos['mouth']
+ return mouth_to_target_global.dot(data.geom_xmat['mouth'].reshape(3, 3))
+
+
+class Upright(base.Task):
+ """A Fish `Task` for getting the torso upright with smooth reward."""
+
+ def __init__(self, random=None):
+ """Initializes an instance of `Upright`.
+
+ Args:
+ random: Either an existing `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically.
+ """
+ super(Upright, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Randomizes the tail and fin angles and the orientation of the Fish."""
+ quat = self.random.randn(4)
+ physics.named.data.qpos['root'][3:7] = quat / np.linalg.norm(quat)
+ for joint in _JOINTS:
+ physics.named.data.qpos[joint] = self.random.uniform(-.2, .2)
+ # Hide the target. It's irrelevant for this task.
+ physics.named.model.geom_rgba['target', 3] = 0
+ super(Upright, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of joint angles, velocities and uprightness."""
+ obs = collections.OrderedDict()
+ obs['joint_angles'] = physics.joint_angles()
+ obs['upright'] = physics.upright()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a smooth reward."""
+ return rewards.tolerance(physics.upright(), bounds=(1, 1), margin=1)
+
+
+class Swim(base.Task):
+ """A Fish `Task` for swimming with smooth reward."""
+
+ def __init__(self, random=None):
+ """Initializes an instance of `Swim`.
+
+ Args:
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ super(Swim, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+
+ quat = self.random.randn(4)
+ physics.named.data.qpos['root'][3:7] = quat / np.linalg.norm(quat)
+ for joint in _JOINTS:
+ physics.named.data.qpos[joint] = self.random.uniform(-.2, .2)
+ # Randomize target position.
+ physics.named.model.geom_pos['target', 'x'] = self.random.uniform(-.4, .4)
+ physics.named.model.geom_pos['target', 'y'] = self.random.uniform(-.4, .4)
+ physics.named.model.geom_pos['target', 'z'] = self.random.uniform(.1, .3)
+ super(Swim, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of joints, target direction and velocities."""
+ obs = collections.OrderedDict()
+ obs['joint_angles'] = physics.joint_angles()
+ obs['upright'] = physics.upright()
+ obs['target'] = physics.mouth_to_target()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a smooth reward."""
+ radii = physics.named.model.geom_size[['mouth', 'target'], 0].sum()
+ in_target = rewards.tolerance(np.linalg.norm(physics.mouth_to_target()),
+ bounds=(0, radii), margin=2*radii)
+ is_upright = 0.5 * (physics.upright() + 1)
+ return (7*in_target + is_upright) / 8
diff --git a/local_dm_control_suite/fish.xml b/local_dm_control_suite/fish.xml
new file mode 100755
index 0000000..43de56d
--- /dev/null
+++ b/local_dm_control_suite/fish.xml
@@ -0,0 +1,85 @@
+[MuJoCo XML markup stripped during extraction; this file is the MJCF model for the Fish domain.]
diff --git a/local_dm_control_suite/hopper.py b/local_dm_control_suite/hopper.py
new file mode 100755
index 0000000..6458e41
--- /dev/null
+++ b/local_dm_control_suite/hopper.py
@@ -0,0 +1,138 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Hopper domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+import numpy as np
+
+
+SUITE = containers.TaggedTasks()
+
+_CONTROL_TIMESTEP = .02 # (Seconds)
+
+# Default duration of an episode, in seconds.
+_DEFAULT_TIME_LIMIT = 20
+
+# Minimal height of torso over foot above which stand reward is 1.
+_STAND_HEIGHT = 0.6
+
+# Hopping speed above which hop reward is 1.
+_HOP_SPEED = 2
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('hopper.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns a Hopper that strives to stand upright, balancing its pose."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Hopper(hopping=False, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def hop(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns a Hopper that strives to hop forward."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Hopper(hopping=True, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Hopper domain."""
+
+ def height(self):
+ """Returns height of torso with respect to foot."""
+ return (self.named.data.xipos['torso', 'z'] -
+ self.named.data.xipos['foot', 'z'])
+
+ def speed(self):
+ """Returns horizontal speed of the Hopper."""
+ return self.named.data.sensordata['torso_subtreelinvel'][0]
+
+ def touch(self):
+ """Returns the signals from two foot touch sensors."""
+ return np.log1p(self.named.data.sensordata[['touch_toe', 'touch_heel']])
+
+
+class Hopper(base.Task):
+ """A Hopper's `Task` to train a standing and a jumping Hopper."""
+
+ def __init__(self, hopping, random=None):
+ """Initialize an instance of `Hopper`.
+
+ Args:
+ hopping: Boolean, if True the task is to hop forwards, otherwise it is to
+ balance upright.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._hopping = hopping
+ super(Hopper, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+ self._timeout_progress = 0
+ super(Hopper, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of positions, velocities and touch sensors."""
+ obs = collections.OrderedDict()
+ # Ignores horizontal position to maintain translational invariance:
+ obs['position'] = physics.data.qpos[1:].copy()
+ obs['velocity'] = physics.velocity()
+ obs['touch'] = physics.touch()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward applicable to the performed task."""
+ standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2))
+ if self._hopping:
+ hopping = rewards.tolerance(physics.speed(),
+ bounds=(_HOP_SPEED, float('inf')),
+ margin=_HOP_SPEED/2,
+ value_at_margin=0.5,
+ sigmoid='linear')
+ return standing * hopping
+ else:
+ small_control = rewards.tolerance(physics.control(),
+ margin=1, value_at_margin=0,
+ sigmoid='quadratic').mean()
+ small_control = (small_control + 4) / 5
+ return standing * small_control
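
The hop reward above is a shaped speed term: `rewards.tolerance` returns 1 inside `bounds` and, with `sigmoid='linear'`, falls off linearly to `value_at_margin` at a distance of `margin` outside them. A small sketch of the resulting shaping, assuming the documented behaviour of `dm_control.utils.rewards.tolerance`:

from dm_control.utils import rewards

_HOP_SPEED = 2
for speed in [0.0, 1.0, 2.0, 3.0]:
    r = rewards.tolerance(speed,
                          bounds=(_HOP_SPEED, float('inf')),
                          margin=_HOP_SPEED / 2,
                          value_at_margin=0.5,
                          sigmoid='linear')
    print(speed, float(r))
# Expected: 1.0 for speeds at or above _HOP_SPEED, 0.5 one margin below the
# bound (speed == 1.0), and 0.0 well below it (speed == 0.0).
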
diff --git a/local_dm_control_suite/hopper.xml b/local_dm_control_suite/hopper.xml
new file mode 100755
index 0000000..0c8ec28
--- /dev/null
+++ b/local_dm_control_suite/hopper.xml
@@ -0,0 +1,66 @@
+[MuJoCo XML markup stripped during extraction; this file is the MJCF model for the Hopper domain.]
diff --git a/local_dm_control_suite/humanoid.py b/local_dm_control_suite/humanoid.py
new file mode 100755
index 0000000..5a161f0
--- /dev/null
+++ b/local_dm_control_suite/humanoid.py
@@ -0,0 +1,211 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Humanoid Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+import numpy as np
+
+_DEFAULT_TIME_LIMIT = 25
+_CONTROL_TIMESTEP = .025
+
+# Height of head above which stand reward is 1.
+_STAND_HEIGHT = 1.4
+
+# Horizontal speeds above which move reward is 1.
+_WALK_SPEED = 1
+_RUN_SPEED = 10
+
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('humanoid.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Stand task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Humanoid(move_speed=0, pure_state=False, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Humanoid(move_speed=_WALK_SPEED, pure_state=False, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Humanoid(move_speed=_RUN_SPEED, pure_state=False, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add()
+def run_pure_state(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the Run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Humanoid(move_speed=_RUN_SPEED, pure_state=True, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Humanoid domain."""
+
+ def torso_upright(self):
+ """Returns the projection of the torso z-axis onto the world z-axis."""
+ return self.named.data.xmat['torso', 'zz']
+
+ def head_height(self):
+ """Returns the height of the head."""
+ return self.named.data.xpos['head', 'z']
+
+ def center_of_mass_position(self):
+ """Returns position of the center-of-mass."""
+ return self.named.data.subtree_com['torso'].copy()
+
+ def center_of_mass_velocity(self):
+ """Returns the velocity of the center-of-mass."""
+ return self.named.data.sensordata['torso_subtreelinvel'].copy()
+
+ def torso_vertical_orientation(self):
+ """Returns the z-projection of the torso orientation matrix."""
+ return self.named.data.xmat['torso', ['zx', 'zy', 'zz']]
+
+ def joint_angles(self):
+ """Returns the state without global orientation or position."""
+ return self.data.qpos[7:].copy() # Skip the 7 DoFs of the free root joint.
+
+ def extremities(self):
+ """Returns end effector positions in egocentric frame."""
+ torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['torso']
+ positions = []
+ for side in ('left_', 'right_'):
+ for limb in ('hand', 'foot'):
+ torso_to_limb = self.named.data.xpos[side + limb] - torso_pos
+ positions.append(torso_to_limb.dot(torso_frame))
+ return np.hstack(positions)
+
+
+class Humanoid(base.Task):
+ """A humanoid task."""
+
+ def __init__(self, move_speed, pure_state, random=None):
+ """Initializes an instance of `Humanoid`.
+
+ Args:
+ move_speed: A float. If this value is zero, reward is given simply for
+ standing up. Otherwise this specifies a target horizontal velocity for
+ the walking task.
+ pure_state: A bool. Whether the observations consist of the pure MuJoCo
+ state or includes some useful features thereof.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._move_speed = move_speed
+ self._pure_state = pure_state
+ super(Humanoid, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ # Find a collision-free random initial configuration.
+ penetrating = True
+ while penetrating:
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+ # Check for collisions.
+ physics.after_reset()
+ penetrating = physics.data.ncon > 0
+ super(Humanoid, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns either the pure state or a set of egocentric features."""
+ obs = collections.OrderedDict()
+ if self._pure_state:
+ obs['position'] = physics.position()
+ obs['velocity'] = physics.velocity()
+ else:
+ obs['joint_angles'] = physics.joint_angles()
+ obs['head_height'] = physics.head_height()
+ obs['extremities'] = physics.extremities()
+ obs['torso_vertical'] = physics.torso_vertical_orientation()
+ obs['com_velocity'] = physics.center_of_mass_velocity()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ standing = rewards.tolerance(physics.head_height(),
+ bounds=(_STAND_HEIGHT, float('inf')),
+ margin=_STAND_HEIGHT/4)
+ upright = rewards.tolerance(physics.torso_upright(),
+ bounds=(0.9, float('inf')), sigmoid='linear',
+ margin=1.9, value_at_margin=0)
+ stand_reward = standing * upright
+ small_control = rewards.tolerance(physics.control(), margin=1,
+ value_at_margin=0,
+ sigmoid='quadratic').mean()
+ small_control = (4 + small_control) / 5
+ if self._move_speed == 0:
+ horizontal_velocity = physics.center_of_mass_velocity()[[0, 1]]
+ dont_move = rewards.tolerance(horizontal_velocity, margin=2).mean()
+ return small_control * stand_reward * dont_move
+ else:
+ com_velocity = np.linalg.norm(physics.center_of_mass_velocity()[[0, 1]])
+ move = rewards.tolerance(com_velocity,
+ bounds=(self._move_speed, float('inf')),
+ margin=self._move_speed, value_at_margin=0,
+ sigmoid='linear')
+ move = (5*move + 1) / 6
+ return small_control * stand_reward * move
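
`Physics.extremities` above maps each world-frame limb offset into the torso frame by right-multiplying with `xmat` (MuJoCo's body-to-world rotation), which applies the inverse rotation. A minimal numpy sketch of that change of frame with a hypothetical 90-degree yaw:

import numpy as np

# Hypothetical torso orientation: rotated 90 degrees about world z
# (body-to-world rotation matrix, as stored in MuJoCo's xmat).
torso_frame = np.array([[0., -1., 0.],
                        [1.,  0., 0.],
                        [0.,  0., 1.]])
torso_pos = np.array([0.0, 0.0, 1.0])
limb_pos = np.array([0.5, 0.0, 1.0])

torso_to_limb = limb_pos - torso_pos          # offset in world frame
egocentric = torso_to_limb.dot(torso_frame)   # same offset in torso frame
print(egocentric)                             # -> [ 0.  -0.5  0. ]
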
diff --git a/local_dm_control_suite/humanoid.xml b/local_dm_control_suite/humanoid.xml
new file mode 100755
index 0000000..32b84c5
--- /dev/null
+++ b/local_dm_control_suite/humanoid.xml
@@ -0,0 +1,202 @@
+[MuJoCo XML markup stripped during extraction; this file is the MJCF model for the Humanoid domain.]
diff --git a/local_dm_control_suite/humanoid_CMU.py b/local_dm_control_suite/humanoid_CMU.py
new file mode 100755
index 0000000..d06fb63
--- /dev/null
+++ b/local_dm_control_suite/humanoid_CMU.py
@@ -0,0 +1,179 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Humanoid_CMU Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+import numpy as np
+
+_DEFAULT_TIME_LIMIT = 20
+_CONTROL_TIMESTEP = 0.02
+
+# Height of head above which stand reward is 1.
+_STAND_HEIGHT = 1.4
+
+# Horizontal speeds above which move reward is 1.
+_WALK_SPEED = 1
+_RUN_SPEED = 10
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('humanoid_CMU.xml'), common.ASSETS
+
+
+@SUITE.add()
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Stand task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = HumanoidCMU(move_speed=0, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add()
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = HumanoidCMU(move_speed=_RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the humanoid_CMU domain."""
+
+ def thorax_upright(self):
+ """Returns the projection of the thorax y-axis onto the world z-axis."""
+ return self.named.data.xmat['thorax', 'zy']
+
+ def head_height(self):
+ """Returns the height of the head."""
+ return self.named.data.xpos['head', 'z']
+
+ def center_of_mass_position(self):
+ """Returns position of the center-of-mass."""
+ return self.named.data.subtree_com['thorax']
+
+ def center_of_mass_velocity(self):
+ """Returns the velocity of the center-of-mass."""
+ return self.named.data.sensordata['thorax_subtreelinvel'].copy()
+
+ def torso_vertical_orientation(self):
+ """Returns the z-projection of the thorax orientation matrix."""
+ return self.named.data.xmat['thorax', ['zx', 'zy', 'zz']]
+
+ def joint_angles(self):
+ """Returns the state without global orientation or position."""
+ return self.data.qpos[7:].copy() # Skip the 7 DoFs of the free root joint.
+
+ def extremities(self):
+ """Returns end effector positions in egocentric frame."""
+ torso_frame = self.named.data.xmat['thorax'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['thorax']
+ positions = []
+ for side in ('l', 'r'):
+ for limb in ('hand', 'foot'):
+ torso_to_limb = self.named.data.xpos[side + limb] - torso_pos
+ positions.append(torso_to_limb.dot(torso_frame))
+ return np.hstack(positions)
+
+
+class HumanoidCMU(base.Task):
+ """A task for the CMU Humanoid."""
+
+ def __init__(self, move_speed, random=None):
+ """Initializes an instance of `Humanoid_CMU`.
+
+ Args:
+ move_speed: A float. If this value is zero, reward is given simply for
+ standing up. Otherwise this specifies a target horizontal velocity for
+ the walking task.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._move_speed = move_speed
+ super(HumanoidCMU, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets a random collision-free configuration at the start of each episode.
+
+ Args:
+ physics: An instance of `Physics`.
+ """
+ penetrating = True
+ while penetrating:
+ randomizers.randomize_limited_and_rotational_joints(
+ physics, self.random)
+ # Check for collisions.
+ physics.after_reset()
+ penetrating = physics.data.ncon > 0
+ super(HumanoidCMU, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns a set of egocentric features."""
+ obs = collections.OrderedDict()
+ obs['joint_angles'] = physics.joint_angles()
+ obs['head_height'] = physics.head_height()
+ obs['extremities'] = physics.extremities()
+ obs['torso_vertical'] = physics.torso_vertical_orientation()
+ obs['com_velocity'] = physics.center_of_mass_velocity()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ standing = rewards.tolerance(physics.head_height(),
+ bounds=(_STAND_HEIGHT, float('inf')),
+ margin=_STAND_HEIGHT/4)
+ upright = rewards.tolerance(physics.thorax_upright(),
+ bounds=(0.9, float('inf')), sigmoid='linear',
+ margin=1.9, value_at_margin=0)
+ stand_reward = standing * upright
+ small_control = rewards.tolerance(physics.control(), margin=1,
+ value_at_margin=0,
+ sigmoid='quadratic').mean()
+ small_control = (4 + small_control) / 5
+ if self._move_speed == 0:
+ horizontal_velocity = physics.center_of_mass_velocity()[[0, 1]]
+ dont_move = rewards.tolerance(horizontal_velocity, margin=2).mean()
+ return small_control * stand_reward * dont_move
+ else:
+ com_velocity = np.linalg.norm(physics.center_of_mass_velocity()[[0, 1]])
+ move = rewards.tolerance(com_velocity,
+ bounds=(self._move_speed, float('inf')),
+ margin=self._move_speed, value_at_margin=0,
+ sigmoid='linear')
+ move = (5*move + 1) / 6
+ return small_control * stand_reward * move
diff --git a/local_dm_control_suite/humanoid_CMU.xml b/local_dm_control_suite/humanoid_CMU.xml
new file mode 100755
index 0000000..9a41a16
--- /dev/null
+++ b/local_dm_control_suite/humanoid_CMU.xml
@@ -0,0 +1,289 @@
+[MuJoCo XML markup stripped during extraction; this file is the MJCF model for the CMU humanoid (humanoid_CMU domain).]
diff --git a/local_dm_control_suite/lqr.py b/local_dm_control_suite/lqr.py
new file mode 100755
index 0000000..34197b4
--- /dev/null
+++ b/local_dm_control_suite/lqr.py
@@ -0,0 +1,272 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Procedurally generated LQR domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.utils import containers
+from dm_control.utils import xml_tools
+from lxml import etree
+import numpy as np
+from six.moves import range
+
+from dm_control.utils import io as resources
+
+_DEFAULT_TIME_LIMIT = float('inf')
+_CONTROL_COST_COEF = 0.1
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets(n_bodies, n_actuators, random):
+ """Returns the model description as an XML string and a dict of assets.
+
+ Args:
+ n_bodies: An int, number of bodies of the LQR.
+ n_actuators: An int, number of actuated bodies of the LQR. `n_actuators`
+ should be less than or equal to `n_bodies`.
+ random: A `numpy.random.RandomState` instance.
+
+ Returns:
+ A tuple `(model_xml_string, assets)`, where `assets` is a dict consisting of
+ `{filename: contents_string}` pairs.
+ """
+ return _make_model(n_bodies, n_actuators, random), common.ASSETS
+
+
+@SUITE.add()
+def lqr_2_1(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns an LQR environment with 2 bodies of which the first is actuated."""
+ return _make_lqr(n_bodies=2,
+ n_actuators=1,
+ control_cost_coef=_CONTROL_COST_COEF,
+ time_limit=time_limit,
+ random=random,
+ environment_kwargs=environment_kwargs)
+
+
+@SUITE.add()
+def lqr_6_2(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns an LQR environment with 6 bodies of which first 2 are actuated."""
+ return _make_lqr(n_bodies=6,
+ n_actuators=2,
+ control_cost_coef=_CONTROL_COST_COEF,
+ time_limit=time_limit,
+ random=random,
+ environment_kwargs=environment_kwargs)
+
+
+def _make_lqr(n_bodies, n_actuators, control_cost_coef, time_limit, random,
+ environment_kwargs):
+ """Returns a LQR environment.
+
+ Args:
+ n_bodies: An int, number of bodies of the LQR.
+ n_actuators: An int, number of actuated bodies of the LQR. `n_actuators`
+ should be less than or equal to `n_bodies`.
+ control_cost_coef: A number, the coefficient of the control cost.
+ time_limit: An int, maximum time for each episode in seconds.
+ random: Either an existing `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically.
+ environment_kwargs: A `dict` specifying keyword arguments for the
+ environment, or None.
+
+ Returns:
+ A LQR environment with `n_bodies` bodies of which first `n_actuators` are
+ actuated.
+ """
+
+ if not isinstance(random, np.random.RandomState):
+ random = np.random.RandomState(random)
+
+ model_string, assets = get_model_and_assets(n_bodies, n_actuators,
+ random=random)
+ physics = Physics.from_xml_string(model_string, assets=assets)
+ task = LQRLevel(control_cost_coef, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ **environment_kwargs)
+
+
+def _make_body(body_id, stiffness_range, damping_range, random):
+ """Returns an `etree.Element` defining a body.
+
+ Args:
+ body_id: Id of the created body.
+ stiffness_range: A tuple of (stiffness_lower_bound, stiffness_upper_bound).
+ The stiffness of the joint is drawn uniformly from this range.
+ damping_range: A tuple of (damping_lower_bound, damping_upper_bound). The
+ damping of the joint is drawn uniformly from this range.
+ random: A `numpy.random.RandomState` instance.
+
+ Returns:
+ A new instance of `etree.Element`. A body element with two children: joint
+ and geom.
+ """
+ body_name = 'body_{}'.format(body_id)
+ joint_name = 'joint_{}'.format(body_id)
+ geom_name = 'geom_{}'.format(body_id)
+
+ body = etree.Element('body', name=body_name)
+ body.set('pos', '.25 0 0')
+ joint = etree.SubElement(body, 'joint', name=joint_name)
+ body.append(etree.Element('geom', name=geom_name))
+ joint.set('stiffness',
+ str(random.uniform(stiffness_range[0], stiffness_range[1])))
+ joint.set('damping',
+ str(random.uniform(damping_range[0], damping_range[1])))
+ return body
+
+
+def _make_model(n_bodies,
+ n_actuators,
+ random,
+ stiffness_range=(15, 25),
+ damping_range=(0, 0)):
+ """Returns an MJCF XML string defining a model of springs and dampers.
+
+ Args:
+ n_bodies: An integer, the number of bodies (DoFs) in the system.
+ n_actuators: An integer, the number of actuated bodies.
+ random: A `numpy.random.RandomState` instance.
+ stiffness_range: A tuple containing minimum and maximum stiffness. Each
+ joint's stiffness is sampled uniformly from this interval.
+ damping_range: A tuple containing minimum and maximum damping. Each joint's
+ damping is sampled uniformly from this interval.
+
+ Returns:
+ An MJCF string describing the linear system.
+
+ Raises:
+ ValueError: If the number of bodies or actuators is invalid.
+ """
+ if n_bodies < 1 or n_actuators < 1:
+ raise ValueError('At least 1 body and 1 actuator required.')
+ if n_actuators > n_bodies:
+ raise ValueError('At most 1 actuator per body.')
+
+ file_path = os.path.join(os.path.dirname(__file__), 'lqr.xml')
+ with resources.GetResourceAsFile(file_path) as xml_file:
+ mjcf = xml_tools.parse(xml_file)
+ parent = mjcf.find('./worldbody')
+ actuator = etree.SubElement(mjcf.getroot(), 'actuator')
+ tendon = etree.SubElement(mjcf.getroot(), 'tendon')
+
+ for body in range(n_bodies):
+ # Inserting body.
+ child = _make_body(body, stiffness_range, damping_range, random)
+ site_name = 'site_{}'.format(body)
+ child.append(etree.Element('site', name=site_name))
+
+ if body == 0:
+ child.set('pos', '.25 0 .1')
+ # Add actuators to the first n_actuators bodies.
+ if body < n_actuators:
+ # Adding actuator.
+ joint_name = 'joint_{}'.format(body)
+ motor_name = 'motor_{}'.format(body)
+ child.find('joint').set('name', joint_name)
+ actuator.append(etree.Element('motor', name=motor_name, joint=joint_name))
+
+ # Add a tendon between consecutive bodies (for visualisation purposes only).
+ if body < n_bodies - 1:
+ child_site_name = 'site_{}'.format(body + 1)
+ tendon_name = 'tendon_{}'.format(body)
+ spatial = etree.SubElement(tendon, 'spatial', name=tendon_name)
+ spatial.append(etree.Element('site', site=site_name))
+ spatial.append(etree.Element('site', site=child_site_name))
+ parent.append(child)
+ parent = child
+
+ return etree.tostring(mjcf, pretty_print=True)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the LQR domain."""
+
+ def state_norm(self):
+ """Returns the norm of the physics state."""
+ return np.linalg.norm(self.state())
+
+
+class LQRLevel(base.Task):
+ """A Linear Quadratic Regulator `Task`."""
+
+ _TERMINAL_TOL = 1e-6
+
+ def __init__(self, control_cost_coef, random=None):
+ """Initializes an LQR level with cost = sum(states^2) + c*sum(controls^2).
+
+ Args:
+ control_cost_coef: The coefficient of the control cost.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+
+ Raises:
+ ValueError: If the control cost coefficient is not positive.
+ """
+ if control_cost_coef <= 0:
+ raise ValueError('control_cost_coef must be positive.')
+
+ self._control_cost_coef = control_cost_coef
+ super(LQRLevel, self).__init__(random=random)
+
+ @property
+ def control_cost_coef(self):
+ return self._control_cost_coef
+
+ def initialize_episode(self, physics):
+ """Random state sampled from a unit sphere."""
+ ndof = physics.model.nq
+ unit = self.random.randn(ndof)
+ physics.data.qpos[:] = np.sqrt(2) * unit / np.linalg.norm(unit)
+ super(LQRLevel, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of the state."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.position()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a quadratic state and control reward."""
+ position = physics.position()
+ state_cost = 0.5 * np.dot(position, position)
+ control_signal = physics.control()
+ control_l2_norm = 0.5 * np.dot(control_signal, control_signal)
+ return 1 - (state_cost + control_l2_norm * self._control_cost_coef)
+
+ def get_evaluation(self, physics):
+ """Returns a sparse evaluation reward that is not used for learning."""
+ return float(physics.state_norm() <= 0.01)
+
+ def get_termination(self, physics):
+ """Terminates when the state norm is smaller than epsilon."""
+ if physics.state_norm() < self._TERMINAL_TOL:
+ return 0.0
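
`LQRLevel.get_reward` above is the quadratic LQR cost negated and shifted by one, r = 1 - (0.5*||x||^2 + 0.5*c*||u||^2), evaluated on the position part of the state. A quick hand check with hypothetical values:

import numpy as np

control_cost_coef = 0.1
position = np.array([0.3, -0.4])        # hypothetical qpos
control_signal = np.array([0.2])        # hypothetical actuator command

state_cost = 0.5 * np.dot(position, position)                    # 0.125
control_l2_norm = 0.5 * np.dot(control_signal, control_signal)   # 0.02
reward = 1 - (state_cost + control_l2_norm * control_cost_coef)
print(reward)  # 1 - (0.125 + 0.002) = 0.873
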
diff --git a/local_dm_control_suite/lqr.xml b/local_dm_control_suite/lqr.xml
new file mode 100755
index 0000000..d403532
--- /dev/null
+++ b/local_dm_control_suite/lqr.xml
@@ -0,0 +1,26 @@
+[MuJoCo XML markup stripped during extraction; this file is the base MJCF template that _make_model() in lqr.py extends procedurally.]
diff --git a/local_dm_control_suite/lqr_solver.py b/local_dm_control_suite/lqr_solver.py
new file mode 100755
index 0000000..3935c7d
--- /dev/null
+++ b/local_dm_control_suite/lqr_solver.py
@@ -0,0 +1,142 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+r"""Optimal policy for LQR levels.
+
+LQR control problem is described in
+https://en.wikipedia.org/wiki/Linear-quadratic_regulator#Infinite-horizon.2C_discrete-time_LQR
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+from dm_control.mujoco import wrapper
+import numpy as np
+from six.moves import range
+
+try:
+ import scipy.linalg as sp # pylint: disable=g-import-not-at-top
+except ImportError:
+ sp = None
+
+
+def _solve_dare(a, b, q, r):
+ """Solves the Discrete-time Algebraic Riccati Equation (DARE) by iteration.
+
+ Algebraic Riccati Equation:
+ ```none
+ P_{t-1} = Q + A' * P_{t} * A -
+ A' * P_{t} * B * (R + B' * P_{t} * B)^{-1} * B' * P_{t} * A
+ ```
+
+ Args:
+ a: A 2 dimensional numpy array, transition matrix A.
+ b: A 2 dimensional numpy array, control matrix B.
+ q: A 2 dimensional numpy array, symmetric positive definite cost matrix.
+ r: A 2 dimensional numpy array, symmetric positive definite cost matrix.
+
+ Returns:
+ A numpy array, a real symmetric matrix P which is the solution to DARE.
+
+ Raises:
+ RuntimeError: If the computed P matrix is not symmetric and
+ positive-definite.
+ """
+ p = np.eye(len(a))
+ for _ in range(1000000):
+ a_p = a.T.dot(p) # A' * P_t
+ a_p_b = np.dot(a_p, b) # A' * P_t * B
+ # Algebraic Riccati Equation.
+ p_next = q + np.dot(a_p, a) - a_p_b.dot(
+ np.linalg.solve(b.T.dot(p.dot(b)) + r, a_p_b.T))
+ p_next += p_next.T
+ p_next *= .5
+ if np.abs(p - p_next).max() < 1e-12:
+ break
+ p = p_next
+ else:
+ logging.warning('DARE solver did not converge')
+ try:
+ # Check that the result is symmetric and positive-definite.
+ np.linalg.cholesky(p_next)
+ except np.linalg.LinAlgError:
+ raise RuntimeError('ARE solver failed: P matrix is not symmetric and '
+ 'positive-definite.')
+ return p_next
+
+
+def solve(env):
+ """Returns the optimal value and policy for LQR problem.
+
+ Args:
+ env: An instance of `control.EnvironmentV2` with LQR level.
+
+ Returns:
+ p: A numpy array, the Hessian of the optimal total cost-to-go (value
+ function at state x) is V(x) = .5 * x' * p * x.
+ k: A numpy array which gives the optimal linear policy u = k * x.
+ beta: The maximum eigenvalue of (a + b * k). Under optimal policy, at
+ timestep n the state tends to 0 like beta^n.
+
+ Raises:
+ RuntimeError: If the controlled system is unstable.
+ """
+ n = env.physics.model.nq # number of DoFs
+ m = env.physics.model.nu # number of controls
+
+ # Compute the mass matrix.
+ mass = np.zeros((n, n))
+ wrapper.mjbindings.mjlib.mj_fullM(env.physics.model.ptr, mass,
+ env.physics.data.qM)
+
+ # Compute input matrices a, b, q and r to the DARE solvers.
+ # State transition matrix a.
+ stiffness = np.diag(env.physics.model.jnt_stiffness.ravel())
+ damping = np.diag(env.physics.model.dof_damping.ravel())
+ dt = env.physics.model.opt.timestep
+
+ j = np.linalg.solve(-mass, np.hstack((stiffness, damping)))
+ a = np.eye(2 * n) + dt * np.vstack(
+ (dt * j + np.hstack((np.zeros((n, n)), np.eye(n))), j))
+
+ # Control transition matrix b.
+ b = env.physics.data.actuator_moment.T
+ bc = np.linalg.solve(mass, b)
+ b = dt * np.vstack((dt * bc, bc))
+
+ # State cost Hessian q.
+ q = np.diag(np.hstack([np.ones(n), np.zeros(n)]))
+
+ # Control cost Hessian r.
+ r = env.task.control_cost_coef * np.eye(m)
+
+ if sp:
+ # Use scipy's faster DARE solver if available.
+ solve_dare = sp.solve_discrete_are
+ else:
+ # Otherwise fall back on a slower internal implementation.
+ solve_dare = _solve_dare
+
+ # Solve the discrete algebraic Riccati equation.
+ p = solve_dare(a, b, q, r)
+ k = -np.linalg.solve(b.T.dot(p.dot(b)) + r, b.T.dot(p.dot(a)))
+
+ # Under optimal policy, state tends to 0 like beta^n_timesteps
+ beta = np.abs(np.linalg.eigvals(a + b.dot(k))).max()
+ if beta >= 1.0:
+ raise RuntimeError('Controlled system is unstable.')
+ return p, k, beta
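+
+
+# Usage sketch (illustrative only): given an LQR environment from this suite,
+# e.g. something like `lqr.lqr_2_1()`, the optimal controller can be applied as
+#
+#   p, k, beta = solve(env)
+#   x = np.hstack([env.physics.data.qpos, env.physics.data.qvel])
+#   u = k.dot(x)  # optimal linear feedback action
+#
+# The `lqr.lqr_2_1` constructor name is assumed here for illustration.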
diff --git a/local_dm_control_suite/manipulator.py b/local_dm_control_suite/manipulator.py
new file mode 100755
index 0000000..b2ed31f
--- /dev/null
+++ b/local_dm_control_suite/manipulator.py
@@ -0,0 +1,290 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Planar Manipulator domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from dm_control.utils import xml_tools
+
+from lxml import etree
+import numpy as np
+
+_CLOSE = .01 # (Meters) Distance below which a thing is considered close.
+_CONTROL_TIMESTEP = .01 # (Seconds)
+_TIME_LIMIT = 10 # (Seconds)
+_P_IN_HAND = .1 # Probability of object-in-hand initial state
+_P_IN_TARGET = .1 # Probability of object-in-target initial state
+_ARM_JOINTS = ['arm_root', 'arm_shoulder', 'arm_elbow', 'arm_wrist',
+ 'finger', 'fingertip', 'thumb', 'thumbtip']
+_ALL_PROPS = frozenset(['ball', 'target_ball', 'cup',
+ 'peg', 'target_peg', 'slot'])
+
+SUITE = containers.TaggedTasks()
+
+
+def make_model(use_peg, insert):
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ xml_string = common.read_model('manipulator.xml')
+ parser = etree.XMLParser(remove_blank_text=True)
+ mjcf = etree.XML(xml_string, parser)
+
+ # Select the desired prop.
+ if use_peg:
+ required_props = ['peg', 'target_peg']
+ if insert:
+ required_props += ['slot']
+ else:
+ required_props = ['ball', 'target_ball']
+ if insert:
+ required_props += ['cup']
+
+ # Remove unused props
+ for unused_prop in _ALL_PROPS.difference(required_props):
+ prop = xml_tools.find_element(mjcf, 'body', unused_prop)
+ prop.getparent().remove(prop)
+
+ return etree.tostring(mjcf, pretty_print=True), common.ASSETS
+
+
+@SUITE.add('benchmarking', 'hard')
+def bring_ball(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns manipulator bring task with the ball prop."""
+ use_peg = False
+ insert = False
+ physics = Physics.from_xml_string(*make_model(use_peg, insert))
+ task = Bring(use_peg=use_peg, insert=insert,
+ fully_observable=fully_observable, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
+
+
+@SUITE.add('hard')
+def bring_peg(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns manipulator bring task with the peg prop."""
+ use_peg = True
+ insert = False
+ physics = Physics.from_xml_string(*make_model(use_peg, insert))
+ task = Bring(use_peg=use_peg, insert=insert,
+ fully_observable=fully_observable, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
+
+
+@SUITE.add('hard')
+def insert_ball(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns manipulator insert task with the ball prop."""
+ use_peg = False
+ insert = True
+ physics = Physics.from_xml_string(*make_model(use_peg, insert))
+ task = Bring(use_peg=use_peg, insert=insert,
+ fully_observable=fully_observable, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
+
+
+@SUITE.add('hard')
+def insert_peg(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns manipulator insert task with the peg prop."""
+ use_peg = True
+ insert = True
+ physics = Physics.from_xml_string(*make_model(use_peg, insert))
+ task = Bring(use_peg=use_peg, insert=insert,
+ fully_observable=fully_observable, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics with additional features for the Planar Manipulator domain."""
+
+ def bounded_joint_pos(self, joint_names):
+ """Returns joint positions as (sin, cos) values."""
+ joint_pos = self.named.data.qpos[joint_names]
+ return np.vstack([np.sin(joint_pos), np.cos(joint_pos)]).T
+
+ def joint_vel(self, joint_names):
+ """Returns joint velocities."""
+ return self.named.data.qvel[joint_names]
+
+ def body_2d_pose(self, body_names, orientation=True):
+ """Returns positions and/or orientations of bodies."""
+ if not isinstance(body_names, str):
+ body_names = np.array(body_names).reshape(-1, 1) # Broadcast indices.
+ pos = self.named.data.xpos[body_names, ['x', 'z']]
+ if orientation:
+ ori = self.named.data.xquat[body_names, ['qw', 'qy']]
+ return np.hstack([pos, ori])
+ else:
+ return pos
+
+ def touch(self):
+ return np.log1p(self.data.sensordata)
+
+ def site_distance(self, site1, site2):
+ site1_to_site2 = np.diff(self.named.data.site_xpos[[site2, site1]], axis=0)
+ return np.linalg.norm(site1_to_site2)
+
+
+class Bring(base.Task):
+ """A Bring `Task`: bring the prop to the target."""
+
+ def __init__(self, use_peg, insert, fully_observable, random=None):
+ """Initialize an instance of the `Bring` task.
+
+ Args:
+ use_peg: A `bool`, whether to replace the ball prop with the peg prop.
+ insert: A `bool`, whether to insert the prop in a receptacle.
+ fully_observable: A `bool`, whether the observation should contain the
+ position and velocity of the object being manipulated and the target
+ location.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._use_peg = use_peg
+ self._target = 'target_peg' if use_peg else 'target_ball'
+ self._object = 'peg' if self._use_peg else 'ball'
+ self._object_joints = ['_'.join([self._object, dim]) for dim in 'xzy']
+ self._receptacle = 'slot' if self._use_peg else 'cup'
+ self._insert = insert
+ self._fully_observable = fully_observable
+ super(Bring, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ # Local aliases
+ choice = self.random.choice
+ uniform = self.random.uniform
+ model = physics.named.model
+ data = physics.named.data
+
+ # Find a collision-free random initial configuration.
+ penetrating = True
+ while penetrating:
+
+ # Randomise angles of arm joints.
+ is_limited = model.jnt_limited[_ARM_JOINTS].astype(np.bool)
+ joint_range = model.jnt_range[_ARM_JOINTS]
+ lower_limits = np.where(is_limited, joint_range[:, 0], -np.pi)
+ upper_limits = np.where(is_limited, joint_range[:, 1], np.pi)
+ angles = uniform(lower_limits, upper_limits)
+ data.qpos[_ARM_JOINTS] = angles
+
+ # Symmetrize hand.
+ data.qpos['finger'] = data.qpos['thumb']
+
+ # Randomise target location.
+ target_x = uniform(-.4, .4)
+ target_z = uniform(.1, .4)
+ if self._insert:
+ target_angle = uniform(-np.pi/3, np.pi/3)
+ model.body_pos[self._receptacle, ['x', 'z']] = target_x, target_z
+ model.body_quat[self._receptacle, ['qw', 'qy']] = [
+ np.cos(target_angle/2), np.sin(target_angle/2)]
+ else:
+ target_angle = uniform(-np.pi, np.pi)
+
+ model.body_pos[self._target, ['x', 'z']] = target_x, target_z
+ model.body_quat[self._target, ['qw', 'qy']] = [
+ np.cos(target_angle/2), np.sin(target_angle/2)]
+
+ # Randomise object location.
+ object_init_probs = [_P_IN_HAND, _P_IN_TARGET, 1-_P_IN_HAND-_P_IN_TARGET]
+ init_type = choice(['in_hand', 'in_target', 'uniform'],
+ p=object_init_probs)
+ if init_type == 'in_target':
+ object_x = target_x
+ object_z = target_z
+ object_angle = target_angle
+ elif init_type == 'in_hand':
+ physics.after_reset()
+ object_x = data.site_xpos['grasp', 'x']
+ object_z = data.site_xpos['grasp', 'z']
+ grasp_direction = data.site_xmat['grasp', ['xx', 'zx']]
+ object_angle = np.pi-np.arctan2(grasp_direction[1], grasp_direction[0])
+ else:
+ object_x = uniform(-.5, .5)
+ object_z = uniform(0, .7)
+ object_angle = uniform(0, 2*np.pi)
+ data.qvel[self._object + '_x'] = uniform(-5, 5)
+
+ data.qpos[self._object_joints] = object_x, object_z, object_angle
+
+ # Check for collisions.
+ physics.after_reset()
+ penetrating = physics.data.ncon > 0
+
+ super(Bring, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns either features or only sensors (to be used with pixels)."""
+ obs = collections.OrderedDict()
+ obs['arm_pos'] = physics.bounded_joint_pos(_ARM_JOINTS)
+ obs['arm_vel'] = physics.joint_vel(_ARM_JOINTS)
+ obs['touch'] = physics.touch()
+ if self._fully_observable:
+ obs['hand_pos'] = physics.body_2d_pose('hand')
+ obs['object_pos'] = physics.body_2d_pose(self._object)
+ obs['object_vel'] = physics.joint_vel(self._object_joints)
+ obs['target_pos'] = physics.body_2d_pose(self._target)
+ return obs
+
+ def _is_close(self, distance):
+ return rewards.tolerance(distance, (0, _CLOSE), _CLOSE*2)
+
+ def _peg_reward(self, physics):
+ """Returns a reward for bringing the peg prop to the target."""
+ grasp = self._is_close(physics.site_distance('peg_grasp', 'grasp'))
+ pinch = self._is_close(physics.site_distance('peg_pinch', 'pinch'))
+ grasping = (grasp + pinch) / 2
+ bring = self._is_close(physics.site_distance('peg', 'target_peg'))
+ bring_tip = self._is_close(physics.site_distance('target_peg_tip',
+ 'peg_tip'))
+ bringing = (bring + bring_tip) / 2
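+    # Grasping alone is capped at one third of the reward; full reward requires
+    # actually bringing the peg (and its tip) to the target.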
+ return max(bringing, grasping/3)
+
+ def _ball_reward(self, physics):
+ """Returns a reward for bringing the ball prop to the target."""
+ return self._is_close(physics.site_distance('ball', 'target_ball'))
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ if self._use_peg:
+ return self._peg_reward(physics)
+ else:
+ return self._ball_reward(physics)
diff --git a/local_dm_control_suite/manipulator.xml b/local_dm_control_suite/manipulator.xml
new file mode 100755
index 0000000..d6d1767
--- /dev/null
+++ b/local_dm_control_suite/manipulator.xml
@@ -0,0 +1,211 @@
+<!-- MJCF model for the manipulator domain (XML content omitted) -->
diff --git a/local_dm_control_suite/pendulum.py b/local_dm_control_suite/pendulum.py
new file mode 100755
index 0000000..38f442b
--- /dev/null
+++ b/local_dm_control_suite/pendulum.py
@@ -0,0 +1,114 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Pendulum domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+import numpy as np
+
+
+_DEFAULT_TIME_LIMIT = 20
+_ANGLE_BOUND = 8
+_COSINE_BOUND = np.cos(np.deg2rad(_ANGLE_BOUND))
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('pendulum.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns pendulum swingup task ."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = SwingUp(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Pendulum domain."""
+
+ def pole_vertical(self):
+ """Returns vertical (z) component of pole frame."""
+ return self.named.data.xmat['pole', 'zz']
+
+ def angular_velocity(self):
+ """Returns the angular velocity of the pole."""
+ return self.named.data.qvel['hinge'].copy()
+
+ def pole_orientation(self):
+ """Returns both horizontal and vertical components of pole frame."""
+ return self.named.data.xmat['pole', ['zz', 'xz']]
+
+
+class SwingUp(base.Task):
+ """A Pendulum `Task` to swing up and balance the pole."""
+
+ def __init__(self, random=None):
+ """Initialize an instance of `Pendulum`.
+
+ Args:
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ super(SwingUp, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Pole is set to a random angle between [-pi, pi).
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ physics.named.data.qpos['hinge'] = self.random.uniform(-np.pi, np.pi)
+ super(SwingUp, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation.
+
+    The observation contains the pole orientation and its angular velocity.
+
+ Args:
+      physics: An instance of `Physics`.
+
+ Returns:
+      A `dict` of observations.
+ """
+ obs = collections.OrderedDict()
+ obs['orientation'] = physics.pole_orientation()
+ obs['velocity'] = physics.angular_velocity()
+ return obs
+
+ def get_reward(self, physics):
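+    # Sparse reward: 1 while the pole is within _ANGLE_BOUND degrees of
+    # vertical, 0 otherwise.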
+ return rewards.tolerance(physics.pole_vertical(), (_COSINE_BOUND, 1))
diff --git a/local_dm_control_suite/pendulum.xml b/local_dm_control_suite/pendulum.xml
new file mode 100755
index 0000000..14377ae
--- /dev/null
+++ b/local_dm_control_suite/pendulum.xml
@@ -0,0 +1,26 @@
+<!-- MJCF model for the pendulum domain (XML content omitted) -->
diff --git a/local_dm_control_suite/point_mass.py b/local_dm_control_suite/point_mass.py
new file mode 100755
index 0000000..b45ba17
--- /dev/null
+++ b/local_dm_control_suite/point_mass.py
@@ -0,0 +1,130 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Point-mass domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+import numpy as np
+
+_DEFAULT_TIME_LIMIT = 20
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('point_mass.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking', 'easy')
+def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the easy point_mass task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = PointMass(randomize_gains=False, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+@SUITE.add()
+def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the hard point_mass task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = PointMass(randomize_gains=True, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """physics for the point_mass domain."""
+
+ def mass_to_target(self):
+ """Returns the vector from mass to target in global coordinate."""
+ return (self.named.data.geom_xpos['target'] -
+ self.named.data.geom_xpos['pointmass'])
+
+ def mass_to_target_dist(self):
+ """Returns the distance from mass to the target."""
+ return np.linalg.norm(self.mass_to_target())
+
+
+class PointMass(base.Task):
+ """A point_mass `Task` to reach target with smooth reward."""
+
+ def __init__(self, randomize_gains, random=None):
+ """Initialize an instance of `PointMass`.
+
+ Args:
+ randomize_gains: A `bool`, whether to randomize the actuator gains.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._randomize_gains = randomize_gains
+ super(PointMass, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ If _randomize_gains is True, the relationship between the controls and
+ the joints is randomized, so that each control actuates a random linear
+ combination of joints.
+
+ Args:
+ physics: An instance of `mujoco.Physics`.
+ """
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+ if self._randomize_gains:
+ dir1 = self.random.randn(2)
+ dir1 /= np.linalg.norm(dir1)
+ # Find another actuation direction that is not 'too parallel' to dir1.
+ parallel = True
+ while parallel:
+ dir2 = self.random.randn(2)
+ dir2 /= np.linalg.norm(dir2)
+ parallel = abs(np.dot(dir1, dir2)) > 0.9
+ physics.model.wrap_prm[[0, 1]] = dir1
+ physics.model.wrap_prm[[2, 3]] = dir2
+ super(PointMass, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of the state."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.position()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ target_size = physics.named.model.geom_size['target', 0]
+ near_target = rewards.tolerance(physics.mass_to_target_dist(),
+ bounds=(0, target_size), margin=target_size)
+ control_reward = rewards.tolerance(physics.control(), margin=1,
+ value_at_margin=0,
+ sigmoid='quadratic').mean()
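+    # Rescale the control term from [0, 1] to [0.8, 1] so that large actions
+    # only mildly reduce the reward.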
+ small_control = (control_reward + 4) / 5
+ return near_target * small_control
diff --git a/local_dm_control_suite/point_mass.xml b/local_dm_control_suite/point_mass.xml
new file mode 100755
index 0000000..c447cf6
--- /dev/null
+++ b/local_dm_control_suite/point_mass.xml
@@ -0,0 +1,49 @@
+<!-- MJCF model for the point_mass domain (XML content omitted) -->
diff --git a/local_dm_control_suite/quadruped.py b/local_dm_control_suite/quadruped.py
new file mode 100755
index 0000000..9e326d7
--- /dev/null
+++ b/local_dm_control_suite/quadruped.py
@@ -0,0 +1,480 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Quadruped Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from dm_control.utils import xml_tools
+
+from lxml import etree
+import numpy as np
+from scipy import ndimage
+
+enums = mjbindings.enums
+mjlib = mjbindings.mjlib
+
+
+_DEFAULT_TIME_LIMIT = 20
+_CONTROL_TIMESTEP = .02
+
+# Horizontal speeds above which the move reward is 1.
+_RUN_SPEED = 5
+_WALK_SPEED = 0.5
+
+# Constants related to terrain generation.
+_HEIGHTFIELD_ID = 0
+_TERRAIN_SMOOTHNESS = 0.15 # 0.0: maximally bumpy; 1.0: completely smooth.
+_TERRAIN_BUMP_SCALE = 2 # Spatial scale of terrain bumps (in meters).
+
+# Named model elements.
+_TOES = ['toe_front_left', 'toe_back_left', 'toe_back_right', 'toe_front_right']
+_WALLS = ['wall_px', 'wall_py', 'wall_nx', 'wall_ny']
+
+SUITE = containers.TaggedTasks()
+
+
+def make_model(floor_size=None, terrain=False, rangefinders=False,
+ walls_and_ball=False):
+ """Returns the model XML string."""
+ xml_string = common.read_model('quadruped.xml')
+ parser = etree.XMLParser(remove_blank_text=True)
+ mjcf = etree.XML(xml_string, parser)
+
+ # Set floor size.
+ if floor_size is not None:
+ floor_geom = mjcf.find('.//geom[@name={!r}]'.format('floor'))
+ floor_geom.attrib['size'] = '{} {} .5'.format(floor_size, floor_size)
+
+ # Remove walls, ball and target.
+ if not walls_and_ball:
+ for wall in _WALLS:
+ wall_geom = xml_tools.find_element(mjcf, 'geom', wall)
+ wall_geom.getparent().remove(wall_geom)
+
+ # Remove ball.
+ ball_body = xml_tools.find_element(mjcf, 'body', 'ball')
+ ball_body.getparent().remove(ball_body)
+
+ # Remove target.
+ target_site = xml_tools.find_element(mjcf, 'site', 'target')
+ target_site.getparent().remove(target_site)
+
+ # Remove terrain.
+ if not terrain:
+ terrain_geom = xml_tools.find_element(mjcf, 'geom', 'terrain')
+ terrain_geom.getparent().remove(terrain_geom)
+
+ # Remove rangefinders if they're not used, as range computations can be
+ # expensive, especially in a scene with heightfields.
+ if not rangefinders:
+ rangefinder_sensors = mjcf.findall('.//rangefinder')
+ for rf in rangefinder_sensors:
+ rf.getparent().remove(rf)
+
+ return etree.tostring(mjcf, pretty_print=True)
+
+
+@SUITE.add()
+def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Move(desired_speed=_WALK_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add()
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Run task."""
+ xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _RUN_SPEED)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Move(desired_speed=_RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add()
+def escape(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the Escape task."""
+ xml_string = make_model(floor_size=40, terrain=True, rangefinders=True)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Escape(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add()
+def fetch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Fetch task."""
+ xml_string = make_model(walls_and_ball=True)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Fetch(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Quadruped domain."""
+
+ def _reload_from_data(self, data):
+ super(Physics, self)._reload_from_data(data)
+ # Clear cached sensor names when the physics is reloaded.
+ self._sensor_types_to_names = {}
+ self._hinge_names = []
+
+ def _get_sensor_names(self, *sensor_types):
+ try:
+ sensor_names = self._sensor_types_to_names[sensor_types]
+ except KeyError:
+ [sensor_ids] = np.where(np.in1d(self.model.sensor_type, sensor_types))
+ sensor_names = [self.model.id2name(s_id, 'sensor') for s_id in sensor_ids]
+ self._sensor_types_to_names[sensor_types] = sensor_names
+ return sensor_names
+
+ def torso_upright(self):
+ """Returns the dot-product of the torso z-axis and the global z-axis."""
+ return np.asarray(self.named.data.xmat['torso', 'zz'])
+
+ def torso_velocity(self):
+ """Returns the velocity of the torso, in the local frame."""
+ return self.named.data.sensordata['velocimeter'].copy()
+
+ def egocentric_state(self):
+ """Returns the state without global orientation or position."""
+ if not self._hinge_names:
+ [hinge_ids] = np.nonzero(self.model.jnt_type ==
+ enums.mjtJoint.mjJNT_HINGE)
+ self._hinge_names = [self.model.id2name(j_id, 'joint')
+ for j_id in hinge_ids]
+ return np.hstack((self.named.data.qpos[self._hinge_names],
+ self.named.data.qvel[self._hinge_names],
+ self.data.act))
+
+ def toe_positions(self):
+ """Returns toe positions in egocentric frame."""
+ torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['torso']
+ torso_to_toe = self.named.data.xpos[_TOES] - torso_pos
+ return torso_to_toe.dot(torso_frame)
+
+ def force_torque(self):
+ """Returns scaled force/torque sensor readings at the toes."""
+ force_torque_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_FORCE,
+ enums.mjtSensor.mjSENS_TORQUE)
+ return np.arcsinh(self.named.data.sensordata[force_torque_sensors])
+
+ def imu(self):
+ """Returns IMU-like sensor readings."""
+ imu_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_GYRO,
+ enums.mjtSensor.mjSENS_ACCELEROMETER)
+ return self.named.data.sensordata[imu_sensors]
+
+ def rangefinder(self):
+ """Returns scaled rangefinder sensor readings."""
+ rf_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_RANGEFINDER)
+ rf_readings = self.named.data.sensordata[rf_sensors]
+ no_intersection = -1.0
+ return np.where(rf_readings == no_intersection, 1.0, np.tanh(rf_readings))
+
+ def origin_distance(self):
+ """Returns the distance from the origin to the workspace."""
+ return np.asarray(np.linalg.norm(self.named.data.site_xpos['workspace']))
+
+ def origin(self):
+ """Returns origin position in the torso frame."""
+ torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['torso']
+ return -torso_pos.dot(torso_frame)
+
+ def ball_state(self):
+ """Returns ball position and velocity relative to the torso frame."""
+ data = self.named.data
+ torso_frame = data.xmat['torso'].reshape(3, 3)
+ ball_rel_pos = data.xpos['ball'] - data.xpos['torso']
+ ball_rel_vel = data.qvel['ball_root'][:3] - data.qvel['root'][:3]
+ ball_rot_vel = data.qvel['ball_root'][3:]
+ ball_state = np.vstack((ball_rel_pos, ball_rel_vel, ball_rot_vel))
+ return ball_state.dot(torso_frame).ravel()
+
+ def target_position(self):
+ """Returns target position in torso frame."""
+ torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['torso']
+ torso_to_target = self.named.data.site_xpos['target'] - torso_pos
+ return torso_to_target.dot(torso_frame)
+
+ def ball_to_target_distance(self):
+ """Returns horizontal distance from the ball to the target."""
+ ball_to_target = (self.named.data.site_xpos['target'] -
+ self.named.data.xpos['ball'])
+ return np.linalg.norm(ball_to_target[:2])
+
+ def self_to_ball_distance(self):
+ """Returns horizontal distance from the quadruped workspace to the ball."""
+    self_to_ball = (self.named.data.site_xpos['workspace'] -
+                    self.named.data.xpos['ball'])
+ return np.linalg.norm(self_to_ball[:2])
+
+
+def _find_non_contacting_height(physics, orientation, x_pos=0.0, y_pos=0.0):
+ """Find a height with no contacts given a body orientation.
+
+ Args:
+ physics: An instance of `Physics`.
+ orientation: A quaternion.
+ x_pos: A float. Position along global x-axis.
+ y_pos: A float. Position along global y-axis.
+ Raises:
+ RuntimeError: If a non-contacting configuration has not been found after
+ 10,000 attempts.
+ """
+ z_pos = 0.0 # Start embedded in the floor.
+ num_contacts = 1
+ num_attempts = 0
+ # Move up in 1cm increments until no contacts.
+ while num_contacts > 0:
+ try:
+ with physics.reset_context():
+ physics.named.data.qpos['root'][:3] = x_pos, y_pos, z_pos
+ physics.named.data.qpos['root'][3:] = orientation
+ except control.PhysicsError:
+ # We may encounter a PhysicsError here due to filling the contact
+ # buffer, in which case we simply increment the height and continue.
+ pass
+ num_contacts = physics.data.ncon
+ z_pos += 0.01
+ num_attempts += 1
+ if num_attempts > 10000:
+ raise RuntimeError('Failed to find a non-contacting configuration.')
+
+
+def _common_observations(physics):
+ """Returns the observations common to all tasks."""
+ obs = collections.OrderedDict()
+ obs['egocentric_state'] = physics.egocentric_state()
+ obs['torso_velocity'] = physics.torso_velocity()
+ obs['torso_upright'] = physics.torso_upright()
+ obs['imu'] = physics.imu()
+ obs['force_torque'] = physics.force_torque()
+ return obs
+
+
+def _upright_reward(physics, deviation_angle=0):
+ """Returns a reward proportional to how upright the torso is.
+
+ Args:
+ physics: an instance of `Physics`.
+ deviation_angle: A float, in degrees. The reward is 0 when the torso is
+ exactly upside-down and 1 when the torso's z-axis is less than
+ `deviation_angle` away from the global z-axis.
+ """
+ deviation = np.cos(np.deg2rad(deviation_angle))
+ return rewards.tolerance(
+ physics.torso_upright(),
+ bounds=(deviation, float('inf')),
+ sigmoid='linear',
+ margin=1 + deviation,
+ value_at_margin=0)
+
+
+class Move(base.Task):
+ """A quadruped task solved by moving forward at a designated speed."""
+
+ def __init__(self, desired_speed, random=None):
+ """Initializes an instance of `Move`.
+
+ Args:
+ desired_speed: A float. If this value is zero, reward is given simply
+ for standing upright. Otherwise this specifies the horizontal velocity
+ at which the velocity-dependent reward component is maximized.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._desired_speed = desired_speed
+ super(Move, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ # Initial configuration.
+ orientation = self.random.randn(4)
+ orientation /= np.linalg.norm(orientation)
+ _find_non_contacting_height(physics, orientation)
+ super(Move, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ return _common_observations(physics)
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+
+ # Move reward term.
+ move_reward = rewards.tolerance(
+ physics.torso_velocity()[0],
+ bounds=(self._desired_speed, float('inf')),
+ margin=self._desired_speed,
+ value_at_margin=0.5,
+ sigmoid='linear')
+
+ return _upright_reward(physics) * move_reward
+
+
+class Escape(base.Task):
+ """A quadruped task solved by escaping a bowl-shaped terrain."""
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ # Get heightfield resolution, assert that it is square.
+ res = physics.model.hfield_nrow[_HEIGHTFIELD_ID]
+ assert res == physics.model.hfield_ncol[_HEIGHTFIELD_ID]
+ # Sinusoidal bowl shape.
+ row_grid, col_grid = np.ogrid[-1:1:res*1j, -1:1:res*1j]
+ radius = np.clip(np.sqrt(col_grid**2 + row_grid**2), .04, 1)
+ bowl_shape = .5 - np.cos(2*np.pi*radius)/2
+ # Random smooth bumps.
+ terrain_size = 2 * physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
+ bump_res = int(terrain_size / _TERRAIN_BUMP_SCALE)
+ bumps = self.random.uniform(_TERRAIN_SMOOTHNESS, 1, (bump_res, bump_res))
+ smooth_bumps = ndimage.zoom(bumps, res / float(bump_res))
+ # Terrain is elementwise product.
+ terrain = bowl_shape * smooth_bumps
+ start_idx = physics.model.hfield_adr[_HEIGHTFIELD_ID]
+ physics.model.hfield_data[start_idx:start_idx+res**2] = terrain.ravel()
+ super(Escape, self).initialize_episode(physics)
+
+ # If we have a rendering context, we need to re-upload the modified
+ # heightfield data.
+ if physics.contexts:
+ with physics.contexts.gl.make_current() as ctx:
+ ctx.call(mjlib.mjr_uploadHField,
+ physics.model.ptr,
+ physics.contexts.mujoco.ptr,
+ _HEIGHTFIELD_ID)
+
+ # Initial configuration.
+ orientation = self.random.randn(4)
+ orientation /= np.linalg.norm(orientation)
+ _find_non_contacting_height(physics, orientation)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ obs = _common_observations(physics)
+ obs['origin'] = physics.origin()
+ obs['rangefinder'] = physics.rangefinder()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+
+ # Escape reward term.
+ terrain_size = physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
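+    # Linear ramp: 0 at the arena centre, 1 once the torso is at least a full
+    # terrain half-extent away from the origin.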
+ escape_reward = rewards.tolerance(
+ physics.origin_distance(),
+ bounds=(terrain_size, float('inf')),
+ margin=terrain_size,
+ value_at_margin=0,
+ sigmoid='linear')
+
+ return _upright_reward(physics, deviation_angle=20) * escape_reward
+
+
+class Fetch(base.Task):
+ """A quadruped task solved by bringing a ball to the origin."""
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ # Initial configuration, random azimuth and horizontal position.
+ azimuth = self.random.uniform(0, 2*np.pi)
+ orientation = np.array((np.cos(azimuth/2), 0, 0, np.sin(azimuth/2)))
+ spawn_radius = 0.9 * physics.named.model.geom_size['floor', 0]
+ x_pos, y_pos = self.random.uniform(-spawn_radius, spawn_radius, size=(2,))
+ _find_non_contacting_height(physics, orientation, x_pos, y_pos)
+
+ # Initial ball state.
+ physics.named.data.qpos['ball_root'][:2] = self.random.uniform(
+ -spawn_radius, spawn_radius, size=(2,))
+ physics.named.data.qpos['ball_root'][2] = 2
+ physics.named.data.qvel['ball_root'][:2] = 5*self.random.randn(2)
+ super(Fetch, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ obs = _common_observations(physics)
+ obs['ball_state'] = physics.ball_state()
+ obs['target_position'] = physics.target_position()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+
+ # Reward for moving close to the ball.
+ arena_radius = physics.named.model.geom_size['floor', 0] * np.sqrt(2)
+ workspace_radius = physics.named.model.site_size['workspace', 0]
+ ball_radius = physics.named.model.geom_size['ball', 0]
+ reach_reward = rewards.tolerance(
+ physics.self_to_ball_distance(),
+ bounds=(0, workspace_radius+ball_radius),
+ sigmoid='linear',
+ margin=arena_radius, value_at_margin=0)
+
+ # Reward for bringing the ball to the target.
+ target_radius = physics.named.model.site_size['target', 0]
+ fetch_reward = rewards.tolerance(
+ physics.ball_to_target_distance(),
+ bounds=(0, target_radius),
+ sigmoid='linear',
+ margin=arena_radius, value_at_margin=0)
+
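+    # Getting close to the ball yields at most half the reward; the remainder
+    # requires the ball to actually reach the target.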
+ reach_then_fetch = reach_reward * (0.5 + 0.5*fetch_reward)
+
+ return _upright_reward(physics) * reach_then_fetch
diff --git a/local_dm_control_suite/quadruped.xml b/local_dm_control_suite/quadruped.xml
new file mode 100755
index 0000000..958d2c0
--- /dev/null
+++ b/local_dm_control_suite/quadruped.xml
@@ -0,0 +1,329 @@
+<!-- MJCF model for the quadruped domain (XML content omitted) -->
diff --git a/local_dm_control_suite/reacher.py b/local_dm_control_suite/reacher.py
new file mode 100755
index 0000000..feea8b4
--- /dev/null
+++ b/local_dm_control_suite/reacher.py
@@ -0,0 +1,116 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Reacher domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+import numpy as np
+
+SUITE = containers.TaggedTasks()
+_DEFAULT_TIME_LIMIT = 20
+_BIG_TARGET = .05
+_SMALL_TARGET = .015
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('reacher.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking', 'easy')
+def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns reacher with sparse reward with 5e-2 tol and randomized target."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Reacher(target_size=_BIG_TARGET, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns reacher with sparse reward with 1e-2 tol and randomized target."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Reacher(target_size=_SMALL_TARGET, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Reacher domain."""
+
+ def finger_to_target(self):
+ """Returns the vector from target to finger in global coordinates."""
+ return (self.named.data.geom_xpos['target', :2] -
+ self.named.data.geom_xpos['finger', :2])
+
+ def finger_to_target_dist(self):
+ """Returns the signed distance between the finger and target surface."""
+ return np.linalg.norm(self.finger_to_target())
+
+
+class Reacher(base.Task):
+ """A reacher `Task` to reach the target."""
+
+ def __init__(self, target_size, random=None):
+ """Initialize an instance of `Reacher`.
+
+ Args:
+ target_size: A `float`, tolerance to determine whether finger reached the
+ target.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._target_size = target_size
+ super(Reacher, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ physics.named.model.geom_size['target', 0] = self._target_size
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+
+ # Randomize target position
+ angle = self.random.uniform(0, 2 * np.pi)
+ radius = self.random.uniform(.05, .20)
+ physics.named.model.geom_pos['target', 'x'] = radius * np.sin(angle)
+ physics.named.model.geom_pos['target', 'y'] = radius * np.cos(angle)
+
+ super(Reacher, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of the state and the target position."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.position()
+ obs['to_target'] = physics.finger_to_target()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ radii = physics.named.model.geom_size[['target', 'finger'], 0].sum()
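+    # Sparse reward: 1 while the finger is within the summed radii of the
+    # target and finger geoms.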
+ return rewards.tolerance(physics.finger_to_target_dist(), (0, radii))
diff --git a/local_dm_control_suite/reacher.xml b/local_dm_control_suite/reacher.xml
new file mode 100755
index 0000000..343f799
--- /dev/null
+++ b/local_dm_control_suite/reacher.xml
@@ -0,0 +1,47 @@
+<!-- MJCF model for the reacher domain (XML content omitted) -->
diff --git a/local_dm_control_suite/stacker.py b/local_dm_control_suite/stacker.py
new file mode 100755
index 0000000..6d4d49c
--- /dev/null
+++ b/local_dm_control_suite/stacker.py
@@ -0,0 +1,208 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Planar Stacker domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from dm_control.utils import xml_tools
+
+from lxml import etree
+import numpy as np
+
+
+_CLOSE = .01 # (Meters) Distance below which a thing is considered close.
+_CONTROL_TIMESTEP = .01 # (Seconds)
+_TIME_LIMIT = 10 # (Seconds)
+_ARM_JOINTS = ['arm_root', 'arm_shoulder', 'arm_elbow', 'arm_wrist',
+ 'finger', 'fingertip', 'thumb', 'thumbtip']
+
+SUITE = containers.TaggedTasks()
+
+
+def make_model(n_boxes):
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ xml_string = common.read_model('stacker.xml')
+ parser = etree.XMLParser(remove_blank_text=True)
+ mjcf = etree.XML(xml_string, parser)
+
+ # Remove unused boxes
+ for b in range(n_boxes, 4):
+ box = xml_tools.find_element(mjcf, 'body', 'box' + str(b))
+ box.getparent().remove(box)
+
+ return etree.tostring(mjcf, pretty_print=True), common.ASSETS
+
+
+@SUITE.add('hard')
+def stack_2(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns stacker task with 2 boxes."""
+ n_boxes = 2
+ physics = Physics.from_xml_string(*make_model(n_boxes=n_boxes))
+ task = Stack(n_boxes=n_boxes,
+ fully_observable=fully_observable,
+ random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
+
+
+@SUITE.add('hard')
+def stack_4(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns stacker task with 4 boxes."""
+ n_boxes = 4
+ physics = Physics.from_xml_string(*make_model(n_boxes=n_boxes))
+ task = Stack(n_boxes=n_boxes,
+ fully_observable=fully_observable,
+ random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics with additional features for the Planar Manipulator domain."""
+
+ def bounded_joint_pos(self, joint_names):
+ """Returns joint positions as (sin, cos) values."""
+ joint_pos = self.named.data.qpos[joint_names]
+ return np.vstack([np.sin(joint_pos), np.cos(joint_pos)]).T
+
+ def joint_vel(self, joint_names):
+ """Returns joint velocities."""
+ return self.named.data.qvel[joint_names]
+
+ def body_2d_pose(self, body_names, orientation=True):
+ """Returns positions and/or orientations of bodies."""
+ if not isinstance(body_names, str):
+ body_names = np.array(body_names).reshape(-1, 1) # Broadcast indices.
+ pos = self.named.data.xpos[body_names, ['x', 'z']]
+ if orientation:
+ ori = self.named.data.xquat[body_names, ['qw', 'qy']]
+ return np.hstack([pos, ori])
+ else:
+ return pos
+
+ def touch(self):
+ return np.log1p(self.data.sensordata)
+
+ def site_distance(self, site1, site2):
+ site1_to_site2 = np.diff(self.named.data.site_xpos[[site2, site1]], axis=0)
+ return np.linalg.norm(site1_to_site2)
+
+
+class Stack(base.Task):
+ """A Stack `Task`: stack the boxes."""
+
+ def __init__(self, n_boxes, fully_observable, random=None):
+ """Initialize an instance of the `Stack` task.
+
+ Args:
+ n_boxes: An `int`, number of boxes to stack.
+ fully_observable: A `bool`, whether the observation should contain the
+ positions and velocities of the boxes and the location of the target.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._n_boxes = n_boxes
+ self._box_names = ['box' + str(b) for b in range(n_boxes)]
+ self._box_joint_names = []
+ for name in self._box_names:
+ for dim in 'xyz':
+ self._box_joint_names.append('_'.join([name, dim]))
+ self._fully_observable = fully_observable
+ super(Stack, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ # Local aliases
+ randint = self.random.randint
+ uniform = self.random.uniform
+ model = physics.named.model
+ data = physics.named.data
+
+ # Find a collision-free random initial configuration.
+ penetrating = True
+ while penetrating:
+
+ # Randomise angles of arm joints.
+ is_limited = model.jnt_limited[_ARM_JOINTS].astype(np.bool)
+ joint_range = model.jnt_range[_ARM_JOINTS]
+ lower_limits = np.where(is_limited, joint_range[:, 0], -np.pi)
+ upper_limits = np.where(is_limited, joint_range[:, 1], np.pi)
+ angles = uniform(lower_limits, upper_limits)
+ data.qpos[_ARM_JOINTS] = angles
+
+ # Symmetrize hand.
+ data.qpos['finger'] = data.qpos['thumb']
+
+ # Randomise target location.
+ target_height = 2*randint(self._n_boxes) + 1
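+        # Odd multiples of the box half-size correspond to the centre heights
+        # of the 1st, 2nd, ... box in a stack resting on the floor.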
+ box_size = model.geom_size['target', 0]
+ model.body_pos['target', 'z'] = box_size * target_height
+ model.body_pos['target', 'x'] = uniform(-.37, .37)
+
+ # Randomise box locations.
+ for name in self._box_names:
+ data.qpos[name + '_x'] = uniform(.1, .3)
+ data.qpos[name + '_z'] = uniform(0, .7)
+ data.qpos[name + '_y'] = uniform(0, 2*np.pi)
+
+ # Check for collisions.
+ physics.after_reset()
+ penetrating = physics.data.ncon > 0
+
+ super(Stack, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns either features or only sensors (to be used with pixels)."""
+ obs = collections.OrderedDict()
+ obs['arm_pos'] = physics.bounded_joint_pos(_ARM_JOINTS)
+ obs['arm_vel'] = physics.joint_vel(_ARM_JOINTS)
+ obs['touch'] = physics.touch()
+ if self._fully_observable:
+ obs['hand_pos'] = physics.body_2d_pose('hand')
+ obs['box_pos'] = physics.body_2d_pose(self._box_names)
+ obs['box_vel'] = physics.joint_vel(self._box_joint_names)
+ obs['target_pos'] = physics.body_2d_pose('target', orientation=False)
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ box_size = physics.named.model.geom_size['target', 0]
+ min_box_to_target_distance = min(physics.site_distance(name, 'target')
+ for name in self._box_names)
+ box_is_close = rewards.tolerance(min_box_to_target_distance,
+ margin=2*box_size)
+ hand_to_target_distance = physics.site_distance('grasp', 'target')
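+    # Requiring the hand to be far from the target discourages merely holding
+    # a box at the target location.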
+ hand_is_far = rewards.tolerance(hand_to_target_distance,
+ bounds=(.1, float('inf')),
+ margin=_CLOSE)
+ return box_is_close * hand_is_far
diff --git a/local_dm_control_suite/stacker.xml b/local_dm_control_suite/stacker.xml
new file mode 100755
index 0000000..7af4877
--- /dev/null
+++ b/local_dm_control_suite/stacker.xml
@@ -0,0 +1,193 @@
+<!-- MJCF model for the stacker domain (XML content omitted) -->
diff --git a/local_dm_control_suite/swimmer.py b/local_dm_control_suite/swimmer.py
new file mode 100755
index 0000000..96fd8ea
--- /dev/null
+++ b/local_dm_control_suite/swimmer.py
@@ -0,0 +1,215 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Procedurally generated Swimmer domain."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from lxml import etree
+import numpy as np
+from six.moves import range
+
+_DEFAULT_TIME_LIMIT = 30
+_CONTROL_TIMESTEP = .03 # (Seconds)
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets(n_joints):
+ """Returns a tuple containing the model XML string and a dict of assets.
+
+ Args:
+ n_joints: An integer specifying the number of joints in the swimmer.
+
+ Returns:
+ A tuple `(model_xml_string, assets)`, where `assets` is a dict consisting of
+ `{filename: contents_string}` pairs.
+ """
+ return _make_model(n_joints), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def swimmer6(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns a 6-link swimmer."""
+ return _make_swimmer(6, time_limit, random=random,
+ environment_kwargs=environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def swimmer15(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns a 15-link swimmer."""
+ return _make_swimmer(15, time_limit, random=random,
+ environment_kwargs=environment_kwargs)
+
+
+def swimmer(n_links=3, time_limit=_DEFAULT_TIME_LIMIT,
+ random=None, environment_kwargs=None):
+ """Returns a swimmer with n links."""
+ return _make_swimmer(n_links, time_limit, random=random,
+ environment_kwargs=environment_kwargs)
+
+
+def _make_swimmer(n_joints, time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns a swimmer control environment."""
+ model_string, assets = get_model_and_assets(n_joints)
+ physics = Physics.from_xml_string(model_string, assets=assets)
+ task = Swimmer(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+def _make_model(n_bodies):
+ """Generates an xml string defining a swimmer with `n_bodies` bodies."""
+ if n_bodies < 3:
+ raise ValueError('At least 3 bodies required. Received {}'.format(n_bodies))
+ mjcf = etree.fromstring(common.read_model('swimmer.xml'))
+ head_body = mjcf.find('./worldbody/body')
+ actuator = etree.SubElement(mjcf, 'actuator')
+ sensor = etree.SubElement(mjcf, 'sensor')
+
+ parent = head_body
+ for body_index in range(n_bodies - 1):
+ site_name = 'site_{}'.format(body_index)
+ child = _make_body(body_index=body_index)
+ child.append(etree.Element('site', name=site_name))
+ joint_name = 'joint_{}'.format(body_index)
+ joint_limit = 360.0/n_bodies
+ joint_range = '{} {}'.format(-joint_limit, joint_limit)
+ child.append(etree.Element('joint', {'name': joint_name,
+ 'range': joint_range}))
+ motor_name = 'motor_{}'.format(body_index)
+ actuator.append(etree.Element('motor', name=motor_name, joint=joint_name))
+ velocimeter_name = 'velocimeter_{}'.format(body_index)
+ sensor.append(etree.Element('velocimeter', name=velocimeter_name,
+ site=site_name))
+ gyro_name = 'gyro_{}'.format(body_index)
+ sensor.append(etree.Element('gyro', name=gyro_name, site=site_name))
+ parent.append(child)
+ parent = child
+
+ # Move tracking cameras further away from the swimmer according to its length.
+ cameras = mjcf.findall('./worldbody/body/camera')
+ scale = n_bodies / 6.0
+ for cam in cameras:
+ if cam.get('mode') == 'trackcom':
+ old_pos = cam.get('pos').split(' ')
+ new_pos = ' '.join([str(float(dim) * scale) for dim in old_pos])
+ cam.set('pos', new_pos)
+
+ return etree.tostring(mjcf, pretty_print=True)
+
+
+def _make_body(body_index):
+ """Generates an xml string defining a single physical body."""
+ body_name = 'segment_{}'.format(body_index)
+ visual_name = 'visual_{}'.format(body_index)
+ inertial_name = 'inertial_{}'.format(body_index)
+ body = etree.Element('body', name=body_name)
+ body.set('pos', '0 .1 0')
+ etree.SubElement(body, 'geom', {'class': 'visual', 'name': visual_name})
+ etree.SubElement(body, 'geom', {'class': 'inertial', 'name': inertial_name})
+ return body
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the swimmer domain."""
+
+ def nose_to_target(self):
+ """Returns a vector from nose to target in local coordinate of the head."""
+ nose_to_target = (self.named.data.geom_xpos['target'] -
+ self.named.data.geom_xpos['nose'])
+ head_orientation = self.named.data.xmat['head'].reshape(3, 3)
+ return nose_to_target.dot(head_orientation)[:2]
+
+ def nose_to_target_dist(self):
+ """Returns the distance from the nose to the target."""
+ return np.linalg.norm(self.nose_to_target())
+
+ def body_velocities(self):
+ """Returns local body velocities: x,y linear, z rotational."""
+ xvel_local = self.data.sensordata[12:].reshape((-1, 6))
+ vx_vy_wz = [0, 1, 5] # Indices for linear x,y vels and rotational z vel.
+ return xvel_local[:, vx_vy_wz].ravel()
+
+ def joints(self):
+ """Returns all internal joint angles (excluding root joints)."""
+ return self.data.qpos[3:].copy()
+
+
+class Swimmer(base.Task):
+ """A swimmer `Task` to reach the target or just swim."""
+
+ def __init__(self, random=None):
+ """Initializes an instance of `Swimmer`.
+
+ Args:
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ super(Swimmer, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Initializes the swimmer orientation uniformly in [-pi, pi) and the relative
+ angle of each joint uniformly within its range.
+
+ Args:
+ physics: An instance of `Physics`.
+ """
+ # Random joint angles:
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+ # Random target position.
+ close_target = self.random.rand() < .2 # Probability of a close target.
+ target_box = .3 if close_target else 2
+ xpos, ypos = self.random.uniform(-target_box, target_box, size=2)
+ physics.named.model.geom_pos['target', 'x'] = xpos
+ physics.named.model.geom_pos['target', 'y'] = ypos
+ physics.named.model.light_pos['target_light', 'x'] = xpos
+ physics.named.model.light_pos['target_light', 'y'] = ypos
+
+ super(Swimmer, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of joint angles, body velocities and target."""
+ obs = collections.OrderedDict()
+ obs['joints'] = physics.joints()
+ obs['to_target'] = physics.nose_to_target()
+ obs['body_velocities'] = physics.body_velocities()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a smooth reward."""
+ target_size = physics.named.model.geom_size['target', 0]
+ return rewards.tolerance(physics.nose_to_target_dist(),
+ bounds=(0, target_size),
+ margin=5*target_size,
+ sigmoid='long_tail')
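+
+# Illustrative usage sketch (not part of the upstream file); assumes the vendored
+# package layout used elsewhere in this diff:
+#
+#   import numpy as np
+#   from local_dm_control_suite import swimmer
+#   env = swimmer.swimmer(n_links=8)            # generic n-link constructor
+#   time_step = env.reset()
+#   action = np.zeros(env.action_spec().shape)
+#   time_step = env.step(action)                # reward shaped by rewards.tolerance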
diff --git a/local_dm_control_suite/swimmer.xml b/local_dm_control_suite/swimmer.xml
new file mode 100755
index 0000000..29c7bc8
--- /dev/null
+++ b/local_dm_control_suite/swimmer.xml
@@ -0,0 +1,57 @@
+<!-- swimmer.xml: MJCF model definition for the swimmer domain (markup not preserved in this diff excerpt). -->
diff --git a/local_dm_control_suite/tests/domains_test.py b/local_dm_control_suite/tests/domains_test.py
new file mode 100755
index 0000000..4c148cf
--- /dev/null
+++ b/local_dm_control_suite/tests/domains_test.py
@@ -0,0 +1,292 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.suite domains."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import suite
+from dm_control.rl import control
+import mock
+import numpy as np
+import six
+from six.moves import range
+from six.moves import zip
+
+
+def uniform_random_policy(action_spec, random=None):
+ lower_bounds = action_spec.minimum
+ upper_bounds = action_spec.maximum
+ # Draw values between -1 and 1 for unbounded actions.
+ lower_bounds = np.where(np.isinf(lower_bounds), -1.0, lower_bounds)
+ upper_bounds = np.where(np.isinf(upper_bounds), 1.0, upper_bounds)
+ random_state = np.random.RandomState(random)
+ def policy(time_step):
+ del time_step # Unused.
+ return random_state.uniform(lower_bounds, upper_bounds)
+ return policy
+
+
+def step_environment(env, policy, num_episodes=5, max_steps_per_episode=10):
+ for _ in range(num_episodes):
+ step_count = 0
+ time_step = env.reset()
+ yield time_step
+ while not time_step.last():
+ action = policy(time_step)
+ time_step = env.step(action)
+ step_count += 1
+ yield time_step
+ if step_count >= max_steps_per_episode:
+ break
+
+
+def make_trajectory(domain, task, seed, **trajectory_kwargs):
+ env = suite.load(domain, task, task_kwargs={'random': seed})
+ policy = uniform_random_policy(env.action_spec(), random=seed)
+ return step_environment(env, policy, **trajectory_kwargs)
+
+
+class DomainTest(parameterized.TestCase):
+ """Tests run on all the tasks registered."""
+
+ def test_constants(self):
+ num_tasks = sum(len(tasks) for tasks in
+ six.itervalues(suite.TASKS_BY_DOMAIN))
+
+ self.assertLen(suite.ALL_TASKS, num_tasks)
+
+ def _validate_observation(self, observation_dict, observation_spec):
+ obs = observation_dict.copy()
+ for name, spec in six.iteritems(observation_spec):
+ arr = obs.pop(name)
+ self.assertEqual(arr.shape, spec.shape)
+ self.assertEqual(arr.dtype, spec.dtype)
+ self.assertTrue(
+ np.all(np.isfinite(arr)),
+ msg='{!r} has non-finite value(s): {!r}'.format(name, arr))
+ self.assertEmpty(
+ obs,
+ msg='Observation contains array(s) that are not in the spec: {!r}'
+ .format(obs))
+
+ def _validate_reward_range(self, time_step):
+ if time_step.first():
+ self.assertIsNone(time_step.reward)
+ else:
+ self.assertIsInstance(time_step.reward, float)
+ self.assertBetween(time_step.reward, 0, 1)
+
+ def _validate_discount(self, time_step):
+ if time_step.first():
+ self.assertIsNone(time_step.discount)
+ else:
+ self.assertIsInstance(time_step.discount, float)
+ self.assertBetween(time_step.discount, 0, 1)
+
+ def _validate_control_range(self, lower_bounds, upper_bounds):
+ for b in lower_bounds:
+ self.assertEqual(b, -1.0)
+ for b in upper_bounds:
+ self.assertEqual(b, 1.0)
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_components_have_names(self, domain, task):
+ env = suite.load(domain, task)
+ model = env.physics.model
+
+ object_types_and_size_fields = [
+ ('body', 'nbody'),
+ ('joint', 'njnt'),
+ ('geom', 'ngeom'),
+ ('site', 'nsite'),
+ ('camera', 'ncam'),
+ ('light', 'nlight'),
+ ('mesh', 'nmesh'),
+ ('hfield', 'nhfield'),
+ ('texture', 'ntex'),
+ ('material', 'nmat'),
+ ('equality', 'neq'),
+ ('tendon', 'ntendon'),
+ ('actuator', 'nu'),
+ ('sensor', 'nsensor'),
+ ('numeric', 'nnumeric'),
+ ('text', 'ntext'),
+ ('tuple', 'ntuple'),
+ ]
+ for object_type, size_field in object_types_and_size_fields:
+ for idx in range(getattr(model, size_field)):
+ object_name = model.id2name(idx, object_type)
+ self.assertNotEqual(object_name, '',
+ msg='Model {!r} contains unnamed {!r} with ID {}.'
+ .format(model.name, object_type, idx))
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_model_has_at_least_2_cameras(self, domain, task):
+ env = suite.load(domain, task)
+ model = env.physics.model
+ self.assertGreaterEqual(model.ncam, 2,
+ 'Model {!r} should have at least 2 cameras, has {}.'
+ .format(model.name, model.ncam))
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_task_conforms_to_spec(self, domain, task):
+ """Tests that the environment timesteps conform to specifications."""
+ is_benchmark = (domain, task) in suite.BENCHMARKING
+ env = suite.load(domain, task)
+ observation_spec = env.observation_spec()
+ action_spec = env.action_spec()
+
+ # Check action bounds.
+ if is_benchmark:
+ self._validate_control_range(action_spec.minimum, action_spec.maximum)
+
+ # Step through the environment, applying random actions sampled within the
+ # valid range and check the observations, rewards, and discounts.
+ policy = uniform_random_policy(action_spec)
+ for time_step in step_environment(env, policy):
+ self._validate_observation(time_step.observation, observation_spec)
+ self._validate_discount(time_step)
+ if is_benchmark:
+ self._validate_reward_range(time_step)
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_environment_is_deterministic(self, domain, task):
+ """Tests that identical seeds and actions produce identical trajectories."""
+ seed = 0
+ # Iterate over two trajectories generated using identical sequences of
+ # random actions, and with identical task random states. Check that the
+ # observations, rewards, discounts and step types are identical.
+ trajectory1 = make_trajectory(domain=domain, task=task, seed=seed)
+ trajectory2 = make_trajectory(domain=domain, task=task, seed=seed)
+ for time_step1, time_step2 in zip(trajectory1, trajectory2):
+ self.assertEqual(time_step1.step_type, time_step2.step_type)
+ self.assertEqual(time_step1.reward, time_step2.reward)
+ self.assertEqual(time_step1.discount, time_step2.discount)
+ for key in six.iterkeys(time_step1.observation):
+ np.testing.assert_array_equal(
+ time_step1.observation[key], time_step2.observation[key],
+ err_msg='Observation {!r} is not equal.'.format(key))
+
+ def assertCorrectColors(self, physics, reward):
+ colors = physics.named.model.mat_rgba
+ for material_name in ('self', 'effector', 'target'):
+ highlight = colors[material_name + '_highlight']
+ default = colors[material_name + '_default']
+ blend_coef = reward ** 4
+ expected = blend_coef * highlight + (1.0 - blend_coef) * default
+ actual = colors[material_name]
+ err_msg = ('Material {!r} has unexpected color.\nExpected: {!r}\n'
+ 'Actual: {!r}'.format(material_name, expected, actual))
+ np.testing.assert_array_almost_equal(expected, actual, err_msg=err_msg)
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_visualize_reward(self, domain, task):
+ env = suite.load(domain, task)
+ env.task.visualize_reward = True
+ action = np.zeros(env.action_spec().shape)
+
+ with mock.patch.object(env.task, 'get_reward') as mock_get_reward:
+ mock_get_reward.return_value = -3.0 # Rewards < 0 should be clipped.
+ env.reset()
+ mock_get_reward.assert_called_with(env.physics)
+ self.assertCorrectColors(env.physics, reward=0.0)
+
+ mock_get_reward.reset_mock()
+ mock_get_reward.return_value = 0.5
+ env.step(action)
+ mock_get_reward.assert_called_with(env.physics)
+ self.assertCorrectColors(env.physics, reward=mock_get_reward.return_value)
+
+ mock_get_reward.reset_mock()
+ mock_get_reward.return_value = 2.0 # Rewards > 1 should be clipped.
+ env.step(action)
+ mock_get_reward.assert_called_with(env.physics)
+ self.assertCorrectColors(env.physics, reward=1.0)
+
+ mock_get_reward.reset_mock()
+ mock_get_reward.return_value = 0.25
+ env.reset()
+ mock_get_reward.assert_called_with(env.physics)
+ self.assertCorrectColors(env.physics, reward=mock_get_reward.return_value)
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_task_supports_environment_kwargs(self, domain, task):
+ env = suite.load(domain, task,
+ environment_kwargs=dict(flat_observation=True))
+ # Check that the kwargs are actually passed through to the environment.
+ self.assertSetEqual(set(env.observation_spec()),
+ {control.FLAT_OBSERVATION_KEY})
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_observation_arrays_dont_share_memory(self, domain, task):
+ env = suite.load(domain, task)
+ first_timestep = env.reset()
+ action = np.zeros(env.action_spec().shape)
+ second_timestep = env.step(action)
+ for name, first_array in six.iteritems(first_timestep.observation):
+ second_array = second_timestep.observation[name]
+ self.assertFalse(
+ np.may_share_memory(first_array, second_array),
+ msg='Consecutive observations of {!r} may share memory.'.format(name))
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_observations_dont_contain_constant_elements(self, domain, task):
+ env = suite.load(domain, task)
+ trajectory = make_trajectory(domain=domain, task=task, seed=0,
+ num_episodes=2, max_steps_per_episode=1000)
+ observations = {name: [] for name in env.observation_spec()}
+ for time_step in trajectory:
+ for name, array in six.iteritems(time_step.observation):
+ observations[name].append(array)
+
+ failures = []
+
+ for name, array_list in six.iteritems(observations):
+ # Sampling random uniform actions generally isn't sufficient to trigger
+ # these touch sensors.
+ if (domain in ('manipulator', 'stacker') and name == 'touch' or
+ domain == 'quadruped' and name == 'force_torque'):
+ continue
+ stacked_arrays = np.array(array_list)
+ is_constant = np.all(stacked_arrays == stacked_arrays[0], axis=0)
+ has_constant_elements = (
+ is_constant if np.isscalar(is_constant) else np.any(is_constant))
+ if has_constant_elements:
+ failures.append((name, is_constant))
+
+ self.assertEmpty(
+ failures,
+ msg='The following observation(s) contain constant elements:\n{}'
+ .format('\n'.join(':\t'.join([name, str(is_constant)])
+ for (name, is_constant) in failures)))
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_initial_state_is_randomized(self, domain, task):
+ env = suite.load(domain, task, task_kwargs={'random': 42})
+ obs1 = env.reset().observation
+ obs2 = env.reset().observation
+ self.assertFalse(
+ all(np.all(obs1[k] == obs2[k]) for k in obs1),
+ 'Two consecutive initial states have identical observations.\n'
+ 'First: {}\nSecond: {}'.format(obs1, obs2))
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/local_dm_control_suite/tests/loader_test.py b/local_dm_control_suite/tests/loader_test.py
new file mode 100755
index 0000000..cbce4f5
--- /dev/null
+++ b/local_dm_control_suite/tests/loader_test.py
@@ -0,0 +1,52 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the dm_control.suite loader."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+
+from dm_control import suite
+from dm_control.rl import control
+
+
+class LoaderTest(absltest.TestCase):
+
+ def test_load_without_kwargs(self):
+ env = suite.load('cartpole', 'swingup')
+ self.assertIsInstance(env, control.Environment)
+
+ def test_load_with_kwargs(self):
+ env = suite.load('cartpole', 'swingup',
+ task_kwargs={'time_limit': 40, 'random': 99})
+ self.assertIsInstance(env, control.Environment)
+
+
+class LoaderConstantsTest(absltest.TestCase):
+
+ def testSuiteConstants(self):
+ self.assertNotEmpty(suite.BENCHMARKING)
+ self.assertNotEmpty(suite.EASY)
+ self.assertNotEmpty(suite.HARD)
+ self.assertNotEmpty(suite.EXTRA)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/local_dm_control_suite/tests/lqr_test.py b/local_dm_control_suite/tests/lqr_test.py
new file mode 100755
index 0000000..d6edcf0
--- /dev/null
+++ b/local_dm_control_suite/tests/lqr_test.py
@@ -0,0 +1,88 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests specific to the LQR domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import unittest
+
+# Internal dependencies.
+from absl import logging
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from local_dm_control_suite import lqr
+from local_dm_control_suite import lqr_solver
+
+import numpy as np
+from six.moves import range
+
+
+class LqrTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('lqr_2_1', lqr.lqr_2_1),
+ ('lqr_6_2', lqr.lqr_6_2))
+ def test_lqr_optimal_policy(self, make_env):
+ env = make_env()
+ p, k, beta = lqr_solver.solve(env)
+ self.assertPolicyisOptimal(env, p, k, beta)
+
+ @parameterized.named_parameters(
+ ('lqr_2_1', lqr.lqr_2_1),
+ ('lqr_6_2', lqr.lqr_6_2))
+ @unittest.skipUnless(
+ condition=lqr_solver.sp,
+ reason='scipy is not available, so non-scipy DARE solver is the default.')
+ def test_lqr_optimal_policy_no_scipy(self, make_env):
+ env = make_env()
+ old_sp = lqr_solver.sp
+ try:
+ lqr_solver.sp = None # Force the solver to use the non-scipy code path.
+ p, k, beta = lqr_solver.solve(env)
+ finally:
+ lqr_solver.sp = old_sp
+ self.assertPolicyisOptimal(env, p, k, beta)
+
+ def assertPolicyisOptimal(self, env, p, k, beta):
+ tolerance = 1e-3
+ n_steps = int(math.ceil(math.log10(tolerance) / math.log10(beta)))
+ logging.info('%d timesteps for %g convergence.', n_steps, tolerance)
+ total_loss = 0.0
+
+ timestep = env.reset()
+ initial_state = np.hstack((timestep.observation['position'],
+ timestep.observation['velocity']))
+ logging.info('Measuring total cost over %d steps.', n_steps)
+ for _ in range(n_steps):
+ x = np.hstack((timestep.observation['position'],
+ timestep.observation['velocity']))
+ # u = k*x is the optimal policy
+ u = k.dot(x)
+ total_loss += 1 - (timestep.reward or 0.0)
+ timestep = env.step(u)
+
+ logging.info('Analytical expected total cost is .5*x^T*p*x.')
+ expected_loss = .5 * initial_state.T.dot(p).dot(initial_state)
+ logging.info('Comparing measured and predicted costs.')
+ np.testing.assert_allclose(expected_loss, total_loss, rtol=tolerance)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/local_dm_control_suite/utils/__init__.py b/local_dm_control_suite/utils/__init__.py
new file mode 100755
index 0000000..2ea19cf
--- /dev/null
+++ b/local_dm_control_suite/utils/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Utility functions used in the control suite."""
diff --git a/local_dm_control_suite/utils/parse_amc.py b/local_dm_control_suite/utils/parse_amc.py
new file mode 100755
index 0000000..3cea2ab
--- /dev/null
+++ b/local_dm_control_suite/utils/parse_amc.py
@@ -0,0 +1,251 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Parse and convert amc motion capture data."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+from scipy import interpolate
+from six.moves import range
+
+mjlib = mjbindings.mjlib
+
+MOCAP_DT = 1.0/120.0
+CONVERSION_LENGTH = 0.056444
+
+_CMU_MOCAP_JOINT_ORDER = (
+ 'root0', 'root1', 'root2', 'root3', 'root4', 'root5', 'lowerbackrx',
+ 'lowerbackry', 'lowerbackrz', 'upperbackrx', 'upperbackry', 'upperbackrz',
+ 'thoraxrx', 'thoraxry', 'thoraxrz', 'lowerneckrx', 'lowerneckry',
+ 'lowerneckrz', 'upperneckrx', 'upperneckry', 'upperneckrz', 'headrx',
+ 'headry', 'headrz', 'rclaviclery', 'rclaviclerz', 'rhumerusrx',
+ 'rhumerusry', 'rhumerusrz', 'rradiusrx', 'rwristry', 'rhandrx', 'rhandrz',
+ 'rfingersrx', 'rthumbrx', 'rthumbrz', 'lclaviclery', 'lclaviclerz',
+ 'lhumerusrx', 'lhumerusry', 'lhumerusrz', 'lradiusrx', 'lwristry',
+ 'lhandrx', 'lhandrz', 'lfingersrx', 'lthumbrx', 'lthumbrz', 'rfemurrx',
+ 'rfemurry', 'rfemurrz', 'rtibiarx', 'rfootrx', 'rfootrz', 'rtoesrx',
+ 'lfemurrx', 'lfemurry', 'lfemurrz', 'ltibiarx', 'lfootrx', 'lfootrz',
+ 'ltoesrx'
+)
+
+Converted = collections.namedtuple('Converted',
+ ['qpos', 'qvel', 'time'])
+
+
+def convert(file_name, physics, timestep):
+ """Converts the parsed .amc values into qpos and qvel values and resamples.
+
+ Args:
+ file_name: The .amc file to be parsed and converted.
+ physics: The corresponding physics instance.
+ timestep: Desired output interval between resampled frames.
+
+ Returns:
+ A namedtuple with fields:
+ `qpos`, a numpy array containing converted positional variables.
+ `qvel`, a numpy array containing converted velocity variables.
+ `time`, a numpy array containing the corresponding times.
+ """
+ frame_values = parse(file_name)
+ joint2index = {}
+ for name in physics.named.data.qpos.axes.row.names:
+ joint2index[name] = physics.named.data.qpos.axes.row.convert_key_item(name)
+ index2joint = {}
+ for joint, index in joint2index.items():
+ if isinstance(index, slice):
+ indices = range(index.start, index.stop)
+ else:
+ indices = [index]
+ for ii in indices:
+ index2joint[ii] = joint
+
+ # Convert frame_values to qpos
+ amcvals2qpos_transformer = Amcvals2qpos(index2joint, _CMU_MOCAP_JOINT_ORDER)
+ qpos_values = []
+ for frame_value in frame_values:
+ qpos_values.append(amcvals2qpos_transformer(frame_value))
+ qpos_values = np.stack(qpos_values) # Time by nq
+
+ # Interpolate/resample.
+ # Note: interpolate quaternions rather than euler angles (slerp).
+ # see https://en.wikipedia.org/wiki/Slerp
+ qpos_values_resampled = []
+ time_vals = np.arange(0, len(frame_values)*MOCAP_DT - 1e-8, MOCAP_DT)
+ time_vals_new = np.arange(0, len(frame_values)*MOCAP_DT, timestep)
+ while time_vals_new[-1] > time_vals[-1]:
+ time_vals_new = time_vals_new[:-1]
+
+ for i in range(qpos_values.shape[1]):
+ f = interpolate.splrep(time_vals, qpos_values[:, i])
+ qpos_values_resampled.append(interpolate.splev(time_vals_new, f))
+
+ qpos_values_resampled = np.stack(qpos_values_resampled) # nq by ntime
+
+ qvel_list = []
+ for t in range(qpos_values_resampled.shape[1]-1):
+ p_tp1 = qpos_values_resampled[:, t + 1]
+ p_t = qpos_values_resampled[:, t]
+ qvel = [(p_tp1[:3]-p_t[:3])/ timestep,
+ mj_quat2vel(mj_quatdiff(p_t[3:7], p_tp1[3:7]), timestep),
+ (p_tp1[7:]-p_t[7:])/ timestep]
+ qvel_list.append(np.concatenate(qvel))
+
+ qvel_values_resampled = np.vstack(qvel_list).T
+
+ return Converted(qpos_values_resampled, qvel_values_resampled, time_vals_new)
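+
+# Illustrative usage sketch (not part of the upstream file); the clip path is a
+# placeholder:
+#
+#   from local_dm_control_suite import humanoid_CMU
+#   env = humanoid_CMU.stand()
+#   converted = convert('clip.amc', env.physics, env.control_timestep())
+#   # converted.qpos is (nq, n_frames); converted.qvel has one fewer column.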
+
+
+def parse(file_name):
+ """Parses the amc file format."""
+ values = []
+ fid = open(file_name, 'r')
+ line = fid.readline().strip()
+ frame_ind = 1
+ first_frame = True
+ while True:
+ # Parse first frame.
+ if first_frame and line[0] == str(frame_ind):
+ first_frame = False
+ frame_ind += 1
+ frame_vals = []
+ while True:
+ line = fid.readline().strip()
+ if not line or line == str(frame_ind):
+ values.append(np.array(frame_vals, dtype=float))
+ break
+ tokens = line.split()
+ frame_vals.extend(tokens[1:])
+ # Parse other frames.
+ elif line == str(frame_ind):
+ frame_ind += 1
+ frame_vals = []
+ while True:
+ line = fid.readline().strip()
+ if not line or line == str(frame_ind):
+ values.append(np.array(frame_vals, dtype=float))
+ break
+ tokens = line.split()
+ frame_vals.extend(tokens[1:])
+ else:
+ line = fid.readline().strip()
+ if not line:
+ break
+ return values
+
+
+class Amcvals2qpos(object):
+ """Callable that converts .amc values for a frame and to MuJoCo qpos format.
+ """
+
+ def __init__(self, index2joint, joint_order):
+ """Initializes a new Amcvals2qpos instance.
+
+ Args:
+ index2joint: Dict mapping qpos indices to MJCF joint names.
+ joint_order: Sequence of joint names in the order they appear in the .amc file.
+ """
+ # Root is x,y,z, then quat.
+ # Map the default .amc value ordering onto the corresponding qpos indices.
+ self.qpos_root_xyz_ind = [0, 1, 2]
+ self.root_xyz_transform = np.array(
+ [[1, 0, 0], [0, 0, -1], [0, 1, 0]]) * CONVERSION_LENGTH
+ self.qpos_root_quat_ind = [3, 4, 5, 6]
+ amc2qpos_transform = np.zeros((len(index2joint), len(joint_order)))
+ for i in range(len(index2joint)):
+ for j in range(len(joint_order)):
+ if index2joint[i] == joint_order[j]:
+ if 'rx' in index2joint[i]:
+ amc2qpos_transform[i][j] = 1
+ elif 'ry' in index2joint[i]:
+ amc2qpos_transform[i][j] = 1
+ elif 'rz' in index2joint[i]:
+ amc2qpos_transform[i][j] = 1
+ self.amc2qpos_transform = amc2qpos_transform
+
+ def __call__(self, amc_val):
+ """Converts a `.amc` frame to MuJoCo qpos format."""
+ amc_val_rad = np.deg2rad(amc_val)
+ qpos = np.dot(self.amc2qpos_transform, amc_val_rad)
+
+ # Root.
+ qpos[:3] = np.dot(self.root_xyz_transform, amc_val[:3])
+ qpos_quat = euler2quat(amc_val[3], amc_val[4], amc_val[5])
+ qpos_quat = mj_quatprod(euler2quat(90, 0, 0), qpos_quat)
+
+ for i, ind in enumerate(self.qpos_root_quat_ind):
+ qpos[ind] = qpos_quat[i]
+
+ return qpos
+
+
+def euler2quat(ax, ay, az):
+ """Converts euler angles to a quaternion.
+
+ Note: rotation order is zyx
+
+ Args:
+ ax: Roll angle (deg)
+ ay: Pitch angle (deg).
+ az: Yaw angle (deg).
+
+ Returns:
+ A numpy array representing the rotation as a quaternion.
+ """
+ r1 = az
+ r2 = ay
+ r3 = ax
+
+ c1 = np.cos(np.deg2rad(r1 / 2))
+ s1 = np.sin(np.deg2rad(r1 / 2))
+ c2 = np.cos(np.deg2rad(r2 / 2))
+ s2 = np.sin(np.deg2rad(r2 / 2))
+ c3 = np.cos(np.deg2rad(r3 / 2))
+ s3 = np.sin(np.deg2rad(r3 / 2))
+
+ q0 = c1 * c2 * c3 + s1 * s2 * s3
+ q1 = c1 * c2 * s3 - s1 * s2 * c3
+ q2 = c1 * s2 * c3 + s1 * c2 * s3
+ q3 = s1 * c2 * c3 - c1 * s2 * s3
+
+ return np.array([q0, q1, q2, q3])
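+
+# Worked example (illustrative, not part of the upstream file): euler2quat(90, 0, 0)
+# yields approximately [0.7071, 0.7071, 0, 0], i.e. a 90-degree rotation about the
+# x axis, which is the correction applied to the root orientation in Amcvals2qpos.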
+
+
+def mj_quatprod(q, r):
+ quaternion = np.zeros(4)
+ mjlib.mju_mulQuat(quaternion, np.ascontiguousarray(q),
+ np.ascontiguousarray(r))
+ return quaternion
+
+
+def mj_quat2vel(q, dt):
+ vel = np.zeros(3)
+ mjlib.mju_quat2Vel(vel, np.ascontiguousarray(q), dt)
+ return vel
+
+
+def mj_quatneg(q):
+ quaternion = np.zeros(4)
+ mjlib.mju_negQuat(quaternion, np.ascontiguousarray(q))
+ return quaternion
+
+
+def mj_quatdiff(source, target):
+ return mj_quatprod(mj_quatneg(source), np.ascontiguousarray(target))
diff --git a/local_dm_control_suite/utils/parse_amc_test.py b/local_dm_control_suite/utils/parse_amc_test.py
new file mode 100755
index 0000000..c8a9052
--- /dev/null
+++ b/local_dm_control_suite/utils/parse_amc_test.py
@@ -0,0 +1,68 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for parse_amc utility."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from local_dm_control_suite import humanoid_CMU
+from dm_control.suite.utils import parse_amc
+
+from dm_control.utils import io as resources
+
+_TEST_AMC_PATH = resources.GetResourceFilename(
+ os.path.join(os.path.dirname(__file__), '../demos/zeros.amc'))
+
+
+class ParseAMCTest(absltest.TestCase):
+
+ def test_sizes_of_parsed_data(self):
+
+ # Instantiate the humanoid environment.
+ env = humanoid_CMU.stand()
+
+ # Parse and convert specified clip.
+ converted = parse_amc.convert(
+ _TEST_AMC_PATH, env.physics, env.control_timestep())
+
+ self.assertEqual(converted.qpos.shape[0], 63)
+ self.assertEqual(converted.qvel.shape[0], 62)
+ self.assertEqual(converted.time.shape[0], converted.qpos.shape[1])
+ self.assertEqual(converted.qpos.shape[1],
+ converted.qvel.shape[1] + 1)
+
+ # Parse and convert specified clip -- WITH SMALLER TIMESTEP
+ converted2 = parse_amc.convert(
+ _TEST_AMC_PATH, env.physics, 0.5 * env.control_timestep())
+
+ self.assertEqual(converted2.qpos.shape[0], 63)
+ self.assertEqual(converted2.qvel.shape[0], 62)
+ self.assertEqual(converted2.time.shape[0], converted2.qpos.shape[1])
+ self.assertEqual(converted2.qpos.shape[1],
+ converted2.qvel.shape[1] + 1)
+
+ # Compare sizes of parsed objects for different timesteps
+ self.assertEqual(converted.qpos.shape[1] * 2, converted2.qpos.shape[1])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/local_dm_control_suite/utils/randomizers.py b/local_dm_control_suite/utils/randomizers.py
new file mode 100755
index 0000000..30ec182
--- /dev/null
+++ b/local_dm_control_suite/utils/randomizers.py
@@ -0,0 +1,91 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Randomization functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+from six.moves import range
+
+
+def random_limited_quaternion(random, limit):
+ """Generates a random quaternion limited to the specified rotations."""
+ axis = random.randn(3)
+ axis /= np.linalg.norm(axis)
+ angle = random.rand() * limit
+
+ quaternion = np.zeros(4)
+ mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle)
+
+ return quaternion
+
+
+def randomize_limited_and_rotational_joints(physics, random=None):
+ """Randomizes the positions of joints defined in the physics body.
+
+ The following randomization rules apply:
+ - Bounded joints (hinges or sliders) are sampled uniformly in the bounds.
+ - Unbounded hinges are sampled uniformly in [-pi, pi].
+ - Quaternions for unlimited free joints and ball joints are sampled
+ uniformly on the unit 3-sphere.
+ - Quaternions for limited ball joints are sampled uniformly on a sector
+ of the unit 3-sphere.
+ - The linear degrees of freedom of free joints are not randomized.
+
+ Args:
+ physics: Instance of 'Physics' class that holds a loaded model.
+ random: Optional instance of 'np.random.RandomState'. Defaults to the global
+ NumPy random state.
+ """
+ random = random or np.random
+
+ hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
+ slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
+ ball = mjbindings.enums.mjtJoint.mjJNT_BALL
+ free = mjbindings.enums.mjtJoint.mjJNT_FREE
+
+ qpos = physics.named.data.qpos
+
+ for joint_id in range(physics.model.njnt):
+ joint_name = physics.model.id2name(joint_id, 'joint')
+ joint_type = physics.model.jnt_type[joint_id]
+ is_limited = physics.model.jnt_limited[joint_id]
+ range_min, range_max = physics.model.jnt_range[joint_id]
+
+ if is_limited:
+ if joint_type == hinge or joint_type == slide:
+ qpos[joint_name] = random.uniform(range_min, range_max)
+
+ elif joint_type == ball:
+ qpos[joint_name] = random_limited_quaternion(random, range_max)
+
+ else:
+ if joint_type == hinge:
+ qpos[joint_name] = random.uniform(-np.pi, np.pi)
+
+ elif joint_type == ball:
+ quat = random.randn(4)
+ quat /= np.linalg.norm(quat)
+ qpos[joint_name] = quat
+
+ elif joint_type == free:
+ quat = random.rand(4)
+ quat /= np.linalg.norm(quat)
+ qpos[joint_name][3:] = quat
+
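+# Illustrative usage sketch (not part of the upstream file); `some_mjcf_string`
+# is a placeholder for any MJCF model:
+#
+#   import numpy as np
+#   from dm_control import mujoco
+#   physics = mujoco.Physics.from_xml_string(some_mjcf_string)
+#   randomize_limited_and_rotational_joints(physics, np.random.RandomState(0))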
diff --git a/local_dm_control_suite/utils/randomizers_test.py b/local_dm_control_suite/utils/randomizers_test.py
new file mode 100755
index 0000000..8b9b72d
--- /dev/null
+++ b/local_dm_control_suite/utils/randomizers_test.py
@@ -0,0 +1,164 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for randomizers.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import mujoco
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.suite.utils import randomizers
+import numpy as np
+from six.moves import range
+
+mjlib = mjbindings.mjlib
+
+
+class RandomizeUnlimitedJointsTest(parameterized.TestCase):
+
+ def setUp(self):
+ self.rand = np.random.RandomState(100)
+
+ def test_single_joint_of_each_type(self):
+ physics = mujoco.Physics.from_xml_string("""
+ <!-- MJCF model with one joint of each type: a free joint, limited and unlimited
+      hinge and slide joints, and limited and unlimited ball joints
+      (markup not preserved in this diff excerpt). -->
+ """)
+
+ randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
+ self.assertNotEqual(0., physics.named.data.qpos['hinge'])
+ self.assertNotEqual(0., physics.named.data.qpos['limited_hinge'])
+ self.assertNotEqual(0., physics.named.data.qpos['limited_slide'])
+
+ self.assertNotEqual(0., np.sum(physics.named.data.qpos['ball']))
+ self.assertNotEqual(0., np.sum(physics.named.data.qpos['limited_ball']))
+
+ self.assertNotEqual(0., np.sum(physics.named.data.qpos['free'][3:]))
+
+ # Unlimited slide and the positional part of the free joint remains
+ # uninitialized.
+ self.assertEqual(0., physics.named.data.qpos['slide'])
+ self.assertEqual(0., np.sum(physics.named.data.qpos['free'][:3]))
+
+ def test_multiple_joints_of_same_type(self):
+ physics = mujoco.Physics.from_xml_string("""
+ <!-- MJCF model with three unlimited hinge joints, hinge_1 to hinge_3
+      (markup not preserved in this diff excerpt). -->
+ """)
+
+ randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
+ self.assertNotEqual(0., physics.named.data.qpos['hinge_1'])
+ self.assertNotEqual(0., physics.named.data.qpos['hinge_2'])
+ self.assertNotEqual(0., physics.named.data.qpos['hinge_3'])
+
+ self.assertNotEqual(physics.named.data.qpos['hinge_1'],
+ physics.named.data.qpos['hinge_2'])
+
+ self.assertNotEqual(physics.named.data.qpos['hinge_2'],
+ physics.named.data.qpos['hinge_3'])
+
+ self.assertNotEqual(physics.named.data.qpos['hinge_1'],
+ physics.named.data.qpos['hinge_3'])
+
+ def test_unlimited_hinge_randomization_range(self):
+ physics = mujoco.Physics.from_xml_string("""
+ <!-- MJCF model with a single unlimited hinge joint named 'hinge'
+      (markup not preserved in this diff excerpt). -->
+ """)
+
+ for _ in range(10):
+ randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
+ self.assertBetween(physics.named.data.qpos['hinge'], -np.pi, np.pi)
+
+ def test_limited_1d_joint_limits_are_respected(self):
+ physics = mujoco.Physics.from_xml_string("""
+ <!-- MJCF model with a hinge joint limited to [0, 10] degrees and a slide joint
+      limited to [30, 50] (markup not preserved in this diff excerpt). -->
+ """)
+
+ for _ in range(10):
+ randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
+ self.assertBetween(physics.named.data.qpos['hinge'],
+ np.deg2rad(0), np.deg2rad(10))
+ self.assertBetween(physics.named.data.qpos['slide'], 30, 50)
+
+ def test_limited_ball_joint_are_respected(self):
+ physics = mujoco.Physics.from_xml_string("""
+ <!-- MJCF model with a limited ball joint named 'ball'
+      (markup not preserved in this diff excerpt). -->
+ """)
+
+ body_axis = np.array([1., 0., 0.])
+ joint_axis = np.zeros(3)
+ for _ in range(10):
+ randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
+
+ quat = physics.named.data.qpos['ball']
+ mjlib.mju_rotVecQuat(joint_axis, body_axis, quat)
+ angle_cos = np.dot(body_axis, joint_axis)
+ self.assertGreater(angle_cos, 0.5) # cos(60) = 0.5
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/local_dm_control_suite/walker.py b/local_dm_control_suite/walker.py
new file mode 100755
index 0000000..b7bfd58
--- /dev/null
+++ b/local_dm_control_suite/walker.py
@@ -0,0 +1,158 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Planar Walker Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mujoco
+from dm_control.rl import control
+from local_dm_control_suite import base
+from local_dm_control_suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+
+_DEFAULT_TIME_LIMIT = 25
+_CONTROL_TIMESTEP = .025
+
+# Minimal height of torso over foot above which stand reward is 1.
+_STAND_HEIGHT = 1.2
+
+# Horizontal speeds (meters/second) above which move reward is 1.
+_WALK_SPEED = 1
+_RUN_SPEED = 8
+
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('walker.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Stand task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = PlanarWalker(move_speed=0, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = PlanarWalker(move_speed=_WALK_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('benchmarking')
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = PlanarWalker(move_speed=_RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Walker domain."""
+
+ def torso_upright(self):
+ """Returns projection from z-axes of torso to the z-axes of world."""
+ return self.named.data.xmat['torso', 'zz']
+
+ def torso_height(self):
+ """Returns the height of the torso."""
+ return self.named.data.xpos['torso', 'z']
+
+ def horizontal_velocity(self):
+ """Returns the horizontal velocity of the center-of-mass."""
+ return self.named.data.sensordata['torso_subtreelinvel'][0]
+
+ def orientations(self):
+ """Returns planar orientations of all bodies."""
+ return self.named.data.xmat[1:, ['xx', 'xz']].ravel()
+
+
+class PlanarWalker(base.Task):
+ """A planar walker task."""
+
+ def __init__(self, move_speed, random=None):
+ """Initializes an instance of `PlanarWalker`.
+
+ Args:
+ move_speed: A float. If this value is zero, reward is given simply for
+ standing up. Otherwise this specifies a target horizontal velocity for
+ the walking task.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._move_speed = move_speed
+ super(PlanarWalker, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ In 'standing' mode, use initial orientation and small velocities.
+ In 'random' mode, randomize joint angles and let fall to the floor.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+ super(PlanarWalker, self).initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of body orientations, height and velocites."""
+ obs = collections.OrderedDict()
+ obs['orientations'] = physics.orientations()
+ obs['height'] = physics.torso_height()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_STAND_HEIGHT, float('inf')),
+ margin=_STAND_HEIGHT/2)
+ upright = (1 + physics.torso_upright()) / 2
+ stand_reward = (3*standing + upright) / 4
+ if self._move_speed == 0:
+ return stand_reward
+ else:
+ move_reward = rewards.tolerance(physics.horizontal_velocity(),
+ bounds=(self._move_speed, float('inf')),
+ margin=self._move_speed/2,
+ value_at_margin=0.5,
+ sigmoid='linear')
+ return stand_reward * (5*move_reward + 1) / 6
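+
+# Illustrative note (not part of the upstream file): with standing = 1 and
+# torso_upright() = 1, stand_reward = (3*1 + 1)/4 = 1; the returned reward then
+# ranges from stand_reward/6 when move_reward = 0 up to stand_reward when
+# move_reward = 1.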
diff --git a/local_dm_control_suite/walker.xml b/local_dm_control_suite/walker.xml
new file mode 100755
index 0000000..d87ae82
--- /dev/null
+++ b/local_dm_control_suite/walker.xml
@@ -0,0 +1,70 @@
+<!-- walker.xml: MJCF model definition for the planar walker domain (markup not preserved in this diff excerpt). -->
diff --git a/local_dm_control_suite/wrappers/__init__.py b/local_dm_control_suite/wrappers/__init__.py
new file mode 100755
index 0000000..f7e4a68
--- /dev/null
+++ b/local_dm_control_suite/wrappers/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Environment wrappers used to extend or modify environment behaviour."""
diff --git a/local_dm_control_suite/wrappers/action_noise.py b/local_dm_control_suite/wrappers/action_noise.py
new file mode 100755
index 0000000..dab9970
--- /dev/null
+++ b/local_dm_control_suite/wrappers/action_noise.py
@@ -0,0 +1,74 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Wrapper control suite environments that adds Gaussian noise to actions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import dm_env
+import numpy as np
+
+
+_BOUNDS_MUST_BE_FINITE = (
+ 'All bounds in `env.action_spec()` must be finite, got: {action_spec}')
+
+
+class Wrapper(dm_env.Environment):
+ """Wraps a control environment and adds Gaussian noise to actions."""
+
+ def __init__(self, env, scale=0.01):
+ """Initializes a new action noise Wrapper.
+
+ Args:
+ env: The control suite environment to wrap.
+ scale: The standard deviation of the noise, expressed as a fraction
+ of the max-min range for each action dimension.
+
+ Raises:
+ ValueError: If any of the action dimensions of the wrapped environment are
+ unbounded.
+ """
+ action_spec = env.action_spec()
+ if not (np.all(np.isfinite(action_spec.minimum)) and
+ np.all(np.isfinite(action_spec.maximum))):
+ raise ValueError(_BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec))
+ self._minimum = action_spec.minimum
+ self._maximum = action_spec.maximum
+ self._noise_std = scale * (action_spec.maximum - action_spec.minimum)
+ self._env = env
+
+ def step(self, action):
+ noisy_action = action + self._env.task.random.normal(scale=self._noise_std)
+ # Clip the noisy actions in place so that they fall within the bounds
+ # specified by the `action_spec`. Note that MuJoCo implicitly clips out-of-
+ # bounds control inputs, but we also clip here in case the actions do not
+ # correspond directly to MuJoCo actuators, or if there are other wrapper
+ # layers that expect the actions to be within bounds.
+ np.clip(noisy_action, self._minimum, self._maximum, out=noisy_action)
+ return self._env.step(noisy_action)
+
+ def reset(self):
+ return self._env.reset()
+
+ def observation_spec(self):
+ return self._env.observation_spec()
+
+ def action_spec(self):
+ return self._env.action_spec()
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
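+
+# Illustrative usage sketch (not part of the upstream file); assumes the vendored
+# package layout used elsewhere in this diff:
+#
+#   from local_dm_control_suite import walker
+#   from local_dm_control_suite.wrappers import action_noise
+#   env = action_noise.Wrapper(walker.walk(), scale=0.01)
+#   time_step = env.reset()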
diff --git a/local_dm_control_suite/wrappers/action_noise_test.py b/local_dm_control_suite/wrappers/action_noise_test.py
new file mode 100755
index 0000000..dcc5330
--- /dev/null
+++ b/local_dm_control_suite/wrappers/action_noise_test.py
@@ -0,0 +1,136 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the action noise wrapper."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.rl import control
+from dm_control.suite.wrappers import action_noise
+from dm_env import specs
+import mock
+import numpy as np
+
+
+class ActionNoiseTest(parameterized.TestCase):
+
+ def make_action_spec(self, lower=(-1.,), upper=(1.,)):
+ lower, upper = np.broadcast_arrays(lower, upper)
+ return specs.BoundedArray(
+ shape=lower.shape, dtype=float, minimum=lower, maximum=upper)
+
+ def make_mock_env(self, action_spec=None):
+ action_spec = action_spec or self.make_action_spec()
+ env = mock.Mock(spec=control.Environment)
+ env.action_spec.return_value = action_spec
+ return env
+
+ def assertStepCalledOnceWithCorrectAction(self, env, expected_action):
+ # NB: `assert_called_once_with()` doesn't support numpy arrays.
+ env.step.assert_called_once()
+ actual_action = env.step.call_args_list[0][0][0]
+ np.testing.assert_array_equal(expected_action, actual_action)
+
+ @parameterized.parameters([
+ dict(lower=np.r_[-1., 0.], upper=np.r_[1., 2.], scale=0.05),
+ dict(lower=np.r_[-1., 0.], upper=np.r_[1., 2.], scale=0.),
+ dict(lower=np.r_[-1., 0.], upper=np.r_[-1., 0.], scale=0.05),
+ ])
+ def test_step(self, lower, upper, scale):
+ seed = 0
+ std = scale * (upper - lower)
+ expected_noise = np.random.RandomState(seed).normal(scale=std)
+ action = np.random.RandomState(seed).uniform(lower, upper)
+ expected_noisy_action = np.clip(action + expected_noise, lower, upper)
+ task = mock.Mock(spec=control.Task)
+ task.random = np.random.RandomState(seed)
+ action_spec = self.make_action_spec(lower=lower, upper=upper)
+ env = self.make_mock_env(action_spec=action_spec)
+ env.task = task
+ wrapped_env = action_noise.Wrapper(env, scale=scale)
+ time_step = wrapped_env.step(action)
+ self.assertStepCalledOnceWithCorrectAction(env, expected_noisy_action)
+ self.assertIs(time_step, env.step(expected_noisy_action))
+
+ @parameterized.named_parameters([
+ dict(testcase_name='within_bounds', action=np.r_[-1.], noise=np.r_[0.1]),
+ dict(testcase_name='below_lower', action=np.r_[-1.], noise=np.r_[-0.1]),
+ dict(testcase_name='above_upper', action=np.r_[1.], noise=np.r_[0.1]),
+ ])
+ def test_action_clipping(self, action, noise):
+ lower = -1.
+ upper = 1.
+ expected_noisy_action = np.clip(action + noise, lower, upper)
+ task = mock.Mock(spec=control.Task)
+ task.random = mock.Mock(spec=np.random.RandomState)
+ task.random.normal.return_value = noise
+ action_spec = self.make_action_spec(lower=lower, upper=upper)
+ env = self.make_mock_env(action_spec=action_spec)
+ env.task = task
+ wrapped_env = action_noise.Wrapper(env)
+ time_step = wrapped_env.step(action)
+ self.assertStepCalledOnceWithCorrectAction(env, expected_noisy_action)
+ self.assertIs(time_step, env.step(expected_noisy_action))
+
+ @parameterized.parameters([
+ dict(lower=np.r_[-1., 0.], upper=np.r_[1., np.inf]),
+ dict(lower=np.r_[np.nan, 0.], upper=np.r_[1., 2.]),
+ ])
+ def test_error_if_action_bounds_non_finite(self, lower, upper):
+ action_spec = self.make_action_spec(lower=lower, upper=upper)
+ env = self.make_mock_env(action_spec=action_spec)
+ with self.assertRaisesWithLiteralMatch(
+ ValueError,
+ action_noise._BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec)):
+ _ = action_noise.Wrapper(env)
+
+ def test_reset(self):
+ env = self.make_mock_env()
+ wrapped_env = action_noise.Wrapper(env)
+ time_step = wrapped_env.reset()
+ env.reset.assert_called_once_with()
+ self.assertIs(time_step, env.reset())
+
+ def test_observation_spec(self):
+ env = self.make_mock_env()
+ wrapped_env = action_noise.Wrapper(env)
+ observation_spec = wrapped_env.observation_spec()
+ env.observation_spec.assert_called_once_with()
+ self.assertIs(observation_spec, env.observation_spec())
+
+ def test_action_spec(self):
+ env = self.make_mock_env()
+ wrapped_env = action_noise.Wrapper(env)
+ # `env.action_spec()` is called in `Wrapper.__init__()`
+ env.action_spec.reset_mock()
+ action_spec = wrapped_env.action_spec()
+ env.action_spec.assert_called_once_with()
+ self.assertIs(action_spec, env.action_spec())
+
+ @parameterized.parameters(['task', 'physics', 'control_timestep'])
+ def test_getattr(self, attribute_name):
+ env = self.make_mock_env()
+ wrapped_env = action_noise.Wrapper(env)
+ attr = getattr(wrapped_env, attribute_name)
+ self.assertIs(attr, getattr(env, attribute_name))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/local_dm_control_suite/wrappers/pixels.py b/local_dm_control_suite/wrappers/pixels.py
new file mode 100755
index 0000000..0f55fff
--- /dev/null
+++ b/local_dm_control_suite/wrappers/pixels.py
@@ -0,0 +1,120 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Wrapper that adds pixel observations to a control environment."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import collections.abc
+
+import dm_env
+from dm_env import specs
+
+STATE_KEY = 'state'
+
+
+class Wrapper(dm_env.Environment):
+ """Wraps a control environment and adds a rendered pixel observation."""
+
+ def __init__(self, env, pixels_only=True, render_kwargs=None,
+ observation_key='pixels'):
+ """Initializes a new pixel Wrapper.
+
+ Args:
+ env: The environment to wrap.
+ pixels_only: If True (default), the original set of 'state' observations
+ returned by the wrapped environment will be discarded, and the
+ `OrderedDict` of observations will only contain pixels. If False, the
+ `OrderedDict` will contain the original observations as well as the
+ pixel observations.
+ render_kwargs: Optional `dict` containing keyword arguments passed to the
+ `mujoco.Physics.render` method.
+ observation_key: Optional custom string specifying the pixel observation's
+ key in the `OrderedDict` of observations. Defaults to 'pixels'.
+
+ Raises:
+ ValueError: If `env`'s observation spec is not compatible with the
+ wrapper. Supported formats are a single array, or a dict of arrays.
+ ValueError: If `env`'s observation already contains the specified
+ `observation_key`.
+ """
+ if render_kwargs is None:
+ render_kwargs = {}
+
+ wrapped_observation_spec = env.observation_spec()
+
+ if isinstance(wrapped_observation_spec, specs.Array):
+ self._observation_is_dict = False
+ invalid_keys = set([STATE_KEY])
+ elif isinstance(wrapped_observation_spec, collections.abc.MutableMapping):
+ self._observation_is_dict = True
+ invalid_keys = set(wrapped_observation_spec.keys())
+ else:
+ raise ValueError('Unsupported observation spec structure.')
+
+ if not pixels_only and observation_key in invalid_keys:
+ raise ValueError('Duplicate or reserved observation key {!r}.'
+ .format(observation_key))
+
+ if pixels_only:
+ self._observation_spec = collections.OrderedDict()
+ elif self._observation_is_dict:
+ self._observation_spec = wrapped_observation_spec.copy()
+ else:
+ self._observation_spec = collections.OrderedDict()
+ self._observation_spec[STATE_KEY] = wrapped_observation_spec
+
+ # Extend observation spec.
+ pixels = env.physics.render(**render_kwargs)
+ pixels_spec = specs.Array(
+ shape=pixels.shape, dtype=pixels.dtype, name=observation_key)
+ self._observation_spec[observation_key] = pixels_spec
+
+ self._env = env
+ self._pixels_only = pixels_only
+ self._render_kwargs = render_kwargs
+ self._observation_key = observation_key
+
+ def reset(self):
+ time_step = self._env.reset()
+ return self._add_pixel_observation(time_step)
+
+ def step(self, action):
+ time_step = self._env.step(action)
+ return self._add_pixel_observation(time_step)
+
+ def observation_spec(self):
+ return self._observation_spec
+
+ def action_spec(self):
+ return self._env.action_spec()
+
+ def _add_pixel_observation(self, time_step):
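+ # Rebuild the observation: keep or drop the wrapped env's entries depending on pixels_only, then append the rendered frame under the configured key.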
+ if self._pixels_only:
+ observation = collections.OrderedDict()
+ elif self._observation_is_dict:
+ observation = type(time_step.observation)(time_step.observation)
+ else:
+ observation = collections.OrderedDict()
+ observation[STATE_KEY] = time_step.observation
+
+ pixels = self._env.physics.render(**self._render_kwargs)
+ observation[self._observation_key] = pixels
+ return time_step._replace(observation=observation)
+
+ def __getattr__(self, name):
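+ # Delegate all other attribute lookups to the wrapped environment.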
+ return getattr(self._env, name)
diff --git a/local_dm_control_suite/wrappers/pixels_test.py b/local_dm_control_suite/wrappers/pixels_test.py
new file mode 100755
index 0000000..26b7fc1
--- /dev/null
+++ b/local_dm_control_suite/wrappers/pixels_test.py
@@ -0,0 +1,133 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the pixel wrapper."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from local_dm_control_suite import cartpole
+from dm_control.suite.wrappers import pixels
+import dm_env
+from dm_env import specs
+
+import numpy as np
+
+
+class FakePhysics(object):
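+ # Minimal stand-in for a mujoco Physics object; render() returns a fixed 4x5x3 uint8 frame.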
+
+ def render(self, *args, **kwargs):
+ del args
+ del kwargs
+ return np.zeros((4, 5, 3), dtype=np.uint8)
+
+
+class FakeArrayObservationEnvironment(dm_env.Environment):
+
+ def __init__(self):
+ self.physics = FakePhysics()
+
+ def reset(self):
+ return dm_env.restart(np.zeros((2,)))
+
+ def step(self, action):
+ del action
+ return dm_env.transition(0.0, np.zeros((2,)))
+
+ def action_spec(self):
+ pass
+
+ def observation_spec(self):
+ return specs.Array(shape=(2,), dtype=np.float64)
+
+
+class PixelsTest(parameterized.TestCase):
+
+ @parameterized.parameters(True, False)
+ def test_dict_observation(self, pixels_only):
+ pixel_key = 'rgb'
+
+ env = cartpole.swingup()
+
+ # Make sure we are testing the right environment for the test.
+ observation_spec = env.observation_spec()
+ self.assertIsInstance(observation_spec, collections.OrderedDict)
+
+ width = 320
+ height = 240
+
+ # The wrapper should only add one observation.
+ wrapped = pixels.Wrapper(env,
+ observation_key=pixel_key,
+ pixels_only=pixels_only,
+ render_kwargs={'width': width, 'height': height})
+
+ wrapped_observation_spec = wrapped.observation_spec()
+ self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
+
+ if pixels_only:
+ self.assertLen(wrapped_observation_spec, 1)
+ self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
+ else:
+ expected_length = len(observation_spec) + 1
+ self.assertLen(wrapped_observation_spec, expected_length)
+ expected_keys = list(observation_spec.keys()) + [pixel_key]
+ self.assertEqual(expected_keys, list(wrapped_observation_spec.keys()))
+
+ # Check that the added spec item is consistent with the added observation.
+ time_step = wrapped.reset()
+ rgb_observation = time_step.observation[pixel_key]
+ wrapped_observation_spec[pixel_key].validate(rgb_observation)
+
+ self.assertEqual(rgb_observation.shape, (height, width, 3))
+ self.assertEqual(rgb_observation.dtype, np.uint8)
+
+ @parameterized.parameters(True, False)
+ def test_single_array_observation(self, pixels_only):
+ pixel_key = 'depth'
+
+ env = FakeArrayObservationEnvironment()
+ observation_spec = env.observation_spec()
+ self.assertIsInstance(observation_spec, specs.Array)
+
+ wrapped = pixels.Wrapper(env, observation_key=pixel_key,
+ pixels_only=pixels_only)
+ wrapped_observation_spec = wrapped.observation_spec()
+ self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
+
+ if pixels_only:
+ self.assertLen(wrapped_observation_spec, 1)
+ self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
+ else:
+ self.assertLen(wrapped_observation_spec, 2)
+ self.assertEqual([pixels.STATE_KEY, pixel_key],
+ list(wrapped_observation_spec.keys()))
+
+ time_step = wrapped.reset()
+
+ depth_observation = time_step.observation[pixel_key]
+ wrapped_observation_spec[pixel_key].validate(depth_observation)
+
+ self.assertEqual(depth_observation.shape, (4, 5, 3))
+ self.assertEqual(depth_observation.dtype, np.uint8)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/logger.py b/logger.py
index 8e31fd4..3a2adda 100644
--- a/logger.py
+++ b/logger.py
@@ -7,6 +7,7 @@ import torch
import torchvision
import numpy as np
from termcolor import colored
+from datetime import datetime
FORMAT_CONFIG = {
'rl': {
@@ -93,8 +94,10 @@ class MetersGroup(object):
class Logger(object):
def __init__(self, log_dir, use_tb=True, config='rl'):
self._log_dir = log_dir
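+ # Timestamp each TensorBoard run directory so repeated runs log to separate folders.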
+ now = datetime.now()
+ dt_string = now.strftime("%d_%m_%Y-%H_%M_%S")
if use_tb:
- tb_dir = os.path.join(log_dir, 'tb')
+ tb_dir = os.path.join(log_dir, 'runs', 'tb_' + dt_string)
if os.path.exists(tb_dir):
shutil.rmtree(tb_dir)
self._sw = SummaryWriter(tb_dir)
diff --git a/sac_ae.py b/sac_ae.py
index 0e5d915..3499fac 100644
--- a/sac_ae.py
+++ b/sac_ae.py
@@ -6,7 +6,7 @@ import copy
import math
import utils
-from encoder import make_encoder
+from encoder import make_encoder, club_loss, TransitionModel
from decoder import make_decoder
LOG_FREQ = 10000
@@ -70,10 +70,8 @@ class Actor(nn.Module):
self.outputs = dict()
self.apply(weight_init)
- def forward(
- self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False
- ):
- obs = self.encoder(obs, detach=detach_encoder)
+ def forward(self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False):
+ obs, _, _ = self.encoder(obs, detach=detach_encoder)
mu, log_std = self.trunk(obs).chunk(2, dim=-1)
@@ -100,7 +98,6 @@ class Actor(nn.Module):
log_pi = None
mu, pi, log_pi = squash(mu, pi, log_pi)
-
return mu, pi, log_pi, log_std
def log(self, L, step, log_freq=LOG_FREQ):
@@ -159,7 +156,7 @@ class Critic(nn.Module):
def forward(self, obs, action, detach_encoder=False):
# detach_encoder allows to stop gradient propogation to encoder
- obs = self.encoder(obs, detach=detach_encoder)
+ obs, _, _ = self.encoder(obs, detach=detach_encoder)
q1 = self.Q1(obs, action)
q2 = self.Q2(obs, action)
@@ -182,7 +179,53 @@ class Critic(nn.Module):
L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step)
L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step)
+class CURL(nn.Module):
+ """
+ CURL
+ """
+ def __init__(self, obs_shape, z_dim, a_dim, batch_size, critic, critic_target, output_type="continuous"):
+ super(CURL, self).__init__()
+ self.batch_size = batch_size
+
+ self.encoder = critic.encoder
+
+ self.encoder_target = critic_target.encoder
+
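+ # W parameterizes the bilinear similarity z_a^T W z_pos used by compute_logits;
+ # combine fuses the target latent with the action before it serves as the positive key.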
+ self.W = nn.Parameter(torch.rand(z_dim, z_dim))
+ self.combine = nn.Linear(z_dim + a_dim, z_dim)
+ self.output_type = output_type
+
+ def encode(self, x, a=None, detach=False, ema=False):
+ """
+ Encoder: z_t = e(x_t)
+ :param x: x_t, x y coordinates
+ :return: z_t, value in r2
+ """
+ if ema:
+ with torch.no_grad():
+ z_out = self.encoder_target(x)[0]
+ z_out = self.combine(torch.cat((z_out, a), dim=-1))
+ else:
+ z_out = self.encoder(x)[0]
+
+ if detach:
+ z_out = z_out.detach()
+ return z_out
+
+ def compute_logits(self, z_a, z_pos):
+ """
+ Uses logits trick for CURL:
+ - compute (B,B) matrix z_a (W z_pos.T)
+ - positives are all diagonal elements
+ - negatives are all other elements
+ - to compute loss use multiclass cross entropy with identity matrix for labels
+ """
+ Wz = torch.matmul(self.W, z_pos.T) # (z_dim,B)
+ logits = torch.matmul(z_a, Wz) # (B,B)
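+ # subtract the per-row max for numerical stability before cross-entropy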
+ logits = logits - torch.max(logits, 1)[0][:, None]
+ return logits
+
class SacAeAgent(object):
"""SAC+AE algorithm."""
def __init__(
@@ -224,6 +267,12 @@ class SacAeAgent(object):
self.critic_target_update_freq = critic_target_update_freq
self.decoder_update_freq = decoder_update_freq
self.decoder_latent_lambda = decoder_latent_lambda
+
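+ # Recurrent latent transition model; update_decoder uses it to predict a
+ # distribution over the next latent from the current latent, action and history.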
+ self.transition_model = TransitionModel(
+ encoder_feature_dim,
+ hidden_dim,
+ action_shape[0],
+ encoder_feature_dim).to(device)
self.actor = Actor(
obs_shape, action_shape, hidden_dim, encoder_type,
@@ -251,6 +300,11 @@ class SacAeAgent(object):
# set target entropy to -|A|
self.target_entropy = -np.prod(action_shape)
+ self.CURL = CURL(obs_shape, encoder_feature_dim, action_shape[0],
+ obs_shape[0], self.critic, self.critic_target, output_type='continuous').to(self.device)
+
+ self.cross_entropy_loss = nn.CrossEntropyLoss()
+
self.decoder = None
if decoder_type != 'identity':
# create decoder
@@ -281,6 +335,10 @@ class SacAeAgent(object):
self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
)
+ self.cpc_optimizer = torch.optim.Adam(
+ self.CURL.parameters(), lr=encoder_lr
+ )
+
self.log_alpha_optimizer = torch.optim.Adam(
[self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
)
@@ -329,7 +387,6 @@ class SacAeAgent(object):
target_Q) + F.mse_loss(current_Q2, target_Q)
L.log('train_critic/loss', critic_loss, step)
-
# Optimize the critic
self.critic_optimizer.zero_grad()
critic_loss.backward()
@@ -366,12 +423,38 @@ class SacAeAgent(object):
alpha_loss.backward()
self.log_alpha_optimizer.step()
- def update_decoder(self, obs, target_obs, L, step):
- h = self.critic.encoder(obs)
+ def update_decoder(self, last_obs, last_action, last_reward, curr_obs, last_not_done, action, reward, next_obs, not_done, target_obs, L, step):
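+ # The auxiliary objective combines image reconstruction (rec_loss) with three latent terms:
+ # a KL between the encoder's posterior and the transition model's prediction (enc_loss),
+ # a CURL contrastive term (lb_loss), and an upper-bound regularizer from club_loss (ub_loss).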
+ h_curr, mu_h_curr, std_h_curr = self.critic.encoder(curr_obs)
+ with torch.no_grad():
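+ # Roll the transition model through the previous step without gradients to build its history.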
+ h_last, _, _ = self.critic.encoder(last_obs)
+ self.transition_model.init_states(last_obs.shape[0], self.device)
+ curr_state = self.transition_model.transition_step(h_last, last_action, self.transition_model.prev_history, last_not_done)
+
+ hist = curr_state["history"]
+ next_state = self.transition_model.transition_step(h_curr, action, hist, not_done)
+
+ next_state_mu = next_state["mean"]
+ next_state_sigma = next_state["std"]
+ next_state_sample = next_state["sample"]
+ pred_dist = torch.distributions.Normal(next_state_mu, next_state_sigma)
+
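+ # Encoder posterior over the next latent; the KL below pulls it toward the transition model's prediction.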
+ h, mu_h_next, logstd_h_next = self.critic.encoder(next_obs)
+ std_h_next = torch.exp(logstd_h_next)
+ enc_dist = torch.distributions.Normal(mu_h_next, std_h_next)
+ enc_loss = torch.mean(torch.distributions.kl.kl_divergence(enc_dist, pred_dist)) * 0.1
+
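+ # CURL contrastive term: each anchor h_curr should match its own action-conditioned target latent of next_obs.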
+ z_pos = self.CURL.encode(next_obs, action.detach(), ema=True)
+ logits = self.CURL.compute_logits(h_curr, z_pos)
+ labels = torch.arange(logits.shape[0]).long().to(self.device)
+ lb_loss = self.cross_entropy_loss(logits, labels) * 0.1
+
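+ # club_loss (imported from encoder.py) provides an upper-bound style regularizer between the encoded and predicted next latents.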
+ ub_loss = club_loss(h, mu_h_next, logstd_h_next, next_state_sample) * 0.1
+
if target_obs.dim() == 4:
# preprocess images to be in [-0.5, 0.5] range
target_obs = utils.preprocess_obs(target_obs)
+
rec_obs = self.decoder(h)
rec_loss = F.mse_loss(target_obs, rec_obs)
@@ -379,26 +462,35 @@ class SacAeAgent(object):
# see https://arxiv.org/pdf/1903.12436.pdf
latent_loss = (0.5 * h.pow(2).sum(1)).mean()
- loss = rec_loss + self.decoder_latent_lambda * latent_loss
+ loss = rec_loss + enc_loss + lb_loss + ub_loss  # latent L2 penalty (decoder_latent_lambda * latent_loss) is disabled here
self.encoder_optimizer.zero_grad()
self.decoder_optimizer.zero_grad()
+ self.cpc_optimizer.zero_grad()
loss.backward()
-
- self.encoder_optimizer.step()
+
+ self.encoder_optimizer.step()
self.decoder_optimizer.step()
+ self.cpc_optimizer.step()
L.log('train_ae/ae_loss', loss, step)
+ L.log('train_ae/lb_loss', lb_loss, step)
+ L.log('train_ae/ub_loss', ub_loss, step)
+ L.log('train_ae/enc_loss', enc_loss, step)
+ L.log('train_ae/dec_loss', rec_loss, step)
self.decoder.log(L, step, log_freq=LOG_FREQ)
def update(self, replay_buffer, L, step):
- obs, action, reward, next_obs, not_done = replay_buffer.sample()
+ last_obs, last_action, last_reward, curr_obs, last_not_done, action, reward, next_obs, not_done = replay_buffer.sample()
- L.log('train/batch_reward', reward.mean(), step)
+ L.log('train/batch_reward', last_reward.mean(), step)
- self.update_critic(obs, action, reward, next_obs, not_done, L, step)
+ self.update_critic(curr_obs, action, reward, next_obs, not_done, L, step)
if step % self.actor_update_freq == 0:
- self.update_actor_and_alpha(obs, L, step)
+ self.update_actor_and_alpha(curr_obs, L, step)
if step % self.critic_target_update_freq == 0:
utils.soft_update_params(
@@ -413,7 +505,7 @@ class SacAeAgent(object):
)
if self.decoder is not None and step % self.decoder_update_freq == 0:
- self.update_decoder(obs, obs, L, step)
+ self.update_decoder(last_obs, last_action, last_reward, curr_obs, last_not_done, action, reward, next_obs, not_done, next_obs, L, step)
def save(self, model_dir, step):
torch.save(
diff --git a/utils.py b/utils.py
index 067715c..6eece02 100644
--- a/utils.py
+++ b/utils.py
@@ -75,18 +75,26 @@ class ReplayBuffer(object):
# the proprioceptive obs is stored as float32, pixels obs as uint8
obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8
- self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
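+ # Each entry stores two consecutive transitions,
+ # (last_obs, last_action, last_reward, curr_obs) and (action, reward, next_obs),
+ # so the agent can condition its transition model on the previous step.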
+ self.last_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
+ self.curr_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
+ self.last_actions = np.empty((capacity, *action_shape), dtype=np.float32)
self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
+ self.last_rewards = np.empty((capacity, 1), dtype=np.float32)
self.rewards = np.empty((capacity, 1), dtype=np.float32)
+ self.last_not_dones = np.empty((capacity, 1), dtype=np.float32)
self.not_dones = np.empty((capacity, 1), dtype=np.float32)
self.idx = 0
self.last_save = 0
self.full = False
- def add(self, obs, action, reward, next_obs, done):
- np.copyto(self.obses[self.idx], obs)
+ def add(self, last_obs, last_action, last_reward, curr_obs, last_done, action, reward, next_obs, done):
+ np.copyto(self.last_obses[self.idx], last_obs)
+ np.copyto(self.last_actions[self.idx], last_action)
+ np.copyto(self.last_rewards[self.idx], last_reward)
+ np.copyto(self.curr_obses[self.idx], curr_obs)
+ np.copyto(self.last_not_dones[self.idx], not last_done)
np.copyto(self.actions[self.idx], action)
np.copyto(self.rewards[self.idx], reward)
np.copyto(self.next_obses[self.idx], next_obs)
@@ -100,25 +108,31 @@ class ReplayBuffer(object):
0, self.capacity if self.full else self.idx, size=self.batch_size
)
- obses = torch.as_tensor(self.obses[idxs], device=self.device).float()
+ last_obses = torch.as_tensor(self.last_obses[idxs], device=self.device).float()
+ last_actions = torch.as_tensor(self.last_actions[idxs], device=self.device)
+ last_rewards = torch.as_tensor(self.last_rewards[idxs], device=self.device)
+ curr_obses = torch.as_tensor(self.curr_obses[idxs], device=self.device).float()
+ last_not_dones = torch.as_tensor(self.last_not_dones[idxs], device=self.device)
actions = torch.as_tensor(self.actions[idxs], device=self.device)
rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
- next_obses = torch.as_tensor(
- self.next_obses[idxs], device=self.device
- ).float()
+ next_obses = torch.as_tensor(self.next_obses[idxs], device=self.device).float()
not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
- return obses, actions, rewards, next_obses, not_dones
+ return last_obses, last_actions, last_rewards, curr_obses, last_not_dones, actions, rewards, next_obses, not_dones
def save(self, save_dir):
if self.idx == self.last_save:
return
path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
payload = [
- self.obses[self.last_save:self.idx],
- self.next_obses[self.last_save:self.idx],
+ self.last_obses[self.last_save:self.idx],
+ self.last_actions[self.last_save:self.idx],
+ self.last_rewards[self.last_save:self.idx],
+ self.curr_obses[self.last_save:self.idx],
+ self.last_not_dones[self.last_save:self.idx],
self.actions[self.last_save:self.idx],
self.rewards[self.last_save:self.idx],
+ self.next_obses[self.last_save:self.idx],
self.not_dones[self.last_save:self.idx]
]
self.last_save = self.idx
@@ -132,10 +146,14 @@ class ReplayBuffer(object):
path = os.path.join(save_dir, chunk)
payload = torch.load(path)
assert self.idx == start
- self.obses[start:end] = payload[0]
- self.next_obses[start:end] = payload[1]
+ self.last_obses[start:end] = payload[0]
+ self.last_actions[start:end] = payload[1]
+ self.last_rewards[start:end] = payload[2]
+ self.curr_obses[start:end] = payload[3]
+ self.last_not_dones[start:end] = payload[4]
- self.actions[start:end] = payload[2]
- self.rewards[start:end] = payload[3]
- self.not_dones[start:end] = payload[4]
+ self.actions[start:end] = payload[5]
+ self.rewards[start:end] = payload[6]
+ self.next_obses[start:end] = payload[7]
+ self.not_dones[start:end] = payload[8]
self.idx = end