Initial commit

commit 0fc1b8bd37

CARLA_0.9.6/PythonAPI/carla/agents/navigation/carla_env.py (new executable file, 749 lines)
@@ -0,0 +1,749 @@
#!/usr/bin/env python

# Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de
# Barcelona (UAB).
#
# This work is licensed under the terms of the MIT license.
# For a copy, see <https://opensource.org/licenses/MIT>.
#
# Modified for DBC paper.

import random
import glob
import os
import sys
import time
from PIL import Image
from PIL.PngImagePlugin import PngImageFile, PngInfo

try:
    sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % (
        sys.version_info.major,
        sys.version_info.minor,
        'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0])
except IndexError:
    pass

import carla
import math

from dotmap import DotMap

try:
    import pygame
except ImportError:
    raise RuntimeError('cannot import pygame, make sure pygame package is installed')

try:
    import numpy as np
except ImportError:
    raise RuntimeError('cannot import numpy, make sure numpy package is installed')

try:
    import queue
except ImportError:
    import Queue as queue

from agents.navigation.agent import Agent, AgentState
from agents.navigation.local_planner import LocalPlanner

class CarlaSyncMode(object):
    """
    Context manager to synchronize output from different sensors. Synchronous
    mode is enabled as long as we are inside this context

        with CarlaSyncMode(world, sensors) as sync_mode:
            while True:
                data = sync_mode.tick(timeout=1.0)

    """

    def __init__(self, world, *sensors, **kwargs):
        self.world = world
        self.sensors = sensors
        self.frame = None
        self.delta_seconds = 1.0 / kwargs.get('fps', 20)
        self._queues = []
        self._settings = None

        self.start()

    def start(self):
        self._settings = self.world.get_settings()
        self.frame = self.world.apply_settings(carla.WorldSettings(
            no_rendering_mode=False,
            synchronous_mode=True,
            fixed_delta_seconds=self.delta_seconds))

        def make_queue(register_event):
            q = queue.Queue()
            register_event(q.put)
            self._queues.append(q)

        make_queue(self.world.on_tick)
        for sensor in self.sensors:
            make_queue(sensor.listen)

    def tick(self, timeout):
        self.frame = self.world.tick()
        data = [self._retrieve_data(q, timeout) for q in self._queues]
        assert all(x.frame == self.frame for x in data)
        return data

    def __exit__(self, *args, **kwargs):
        self.world.apply_settings(self._settings)

    def _retrieve_data(self, sensor_queue, timeout):
        while True:
            data = sensor_queue.get(timeout=timeout)
            if data.frame == self.frame:
                return data

def draw_image(surface, image, blend=False):
    array = np.frombuffer(image.raw_data, dtype=np.dtype("uint8"))
    array = np.reshape(array, (image.height, image.width, 4))
    array = array[:, :, :3]
    array = array[:, :, ::-1]
    image_surface = pygame.surfarray.make_surface(array.swapaxes(0, 1))
    if blend:
        image_surface.set_alpha(100)
    surface.blit(image_surface, (0, 0))


def get_font():
    fonts = [x for x in pygame.font.get_fonts()]
    default_font = 'ubuntumono'
    font = default_font if default_font in fonts else fonts[0]
    font = pygame.font.match_font(font)
    return pygame.font.Font(font, 14)


def should_quit():
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            return True
        elif event.type == pygame.KEYUP:
            if event.key == pygame.K_ESCAPE:
                return True
    return False


def clamp(value, minimum=0.0, maximum=100.0):
    return max(minimum, min(value, maximum))


class Sun(object):
    def __init__(self, azimuth, altitude):
        self.azimuth = azimuth
        self.altitude = altitude
        self._t = 0.0

    def tick(self, delta_seconds):
        self._t += 0.008 * delta_seconds
        self._t %= 2.0 * math.pi
        self.azimuth += 0.25 * delta_seconds
        self.azimuth %= 360.0
        # self.altitude = (70 * math.sin(self._t)) - 20  # [50, -90]
        min_alt, max_alt = [20, 90]
        self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * (max_alt - min_alt) * math.cos(self._t)

    def __str__(self):
        return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth)


class Storm(object):
    def __init__(self, precipitation):
        self._t = precipitation if precipitation > 0.0 else -50.0
        self._increasing = True
        self.clouds = 0.0
        self.rain = 0.0
        self.wetness = 0.0
        self.puddles = 0.0
        self.wind = 0.0
        self.fog = 0.0

    def tick(self, delta_seconds):
        delta = (1.3 if self._increasing else -1.3) * delta_seconds
        self._t = clamp(delta + self._t, -250.0, 100.0)
        self.clouds = clamp(self._t + 40.0, 0.0, 60.0)
        self.rain = clamp(self._t, 0.0, 80.0)
        self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40
        if self._t == -250.0:
            self._increasing = True
        if self._t == 100.0:
            self._increasing = False

    def __str__(self):
        return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind)


class Weather(object):
    def __init__(self, world, changing_weather_speed):
        self.world = world
        self.reset()
        self.weather = world.get_weather()
        self.changing_weather_speed = changing_weather_speed
        self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle)
        self._storm = Storm(self.weather.precipitation)

    def reset(self):
        weather_params = carla.WeatherParameters(sun_altitude_angle=90.)
        self.world.set_weather(weather_params)

    def tick(self):
        self._sun.tick(self.changing_weather_speed)
        self._storm.tick(self.changing_weather_speed)
        self.weather.cloudiness = self._storm.clouds
        self.weather.precipitation = self._storm.rain
        self.weather.precipitation_deposits = self._storm.puddles
        self.weather.wind_intensity = self._storm.wind
        self.weather.fog_density = self._storm.fog
        self.weather.wetness = self._storm.wetness
        self.weather.sun_azimuth_angle = self._sun.azimuth
        self.weather.sun_altitude_angle = self._sun.altitude
        self.world.set_weather(self.weather)

    def __str__(self):
        return '%s %s' % (self._sun, self._storm)

class CarlaEnv(object):

    def __init__(self,
                 render_display=0,  # 0, 1
                 record_display_images=0,  # 0, 1
                 record_rl_images=0,  # 0, 1
                 changing_weather_speed=0.0,  # [0, +inf)
                 display_text=0,  # 0, 1
                 rl_image_size=84,
                 max_episode_steps=1000,
                 frame_skip=1,
                 is_other_cars=True,
                 start_lane=None,
                 fov=60,  # degrees for rl camera
                 num_cameras=5,
                 port=2000
                 ):
        if record_display_images:
            assert render_display
        self.render_display = render_display
        self.save_display_images = record_display_images
        self.save_rl_images = record_rl_images
        self.changing_weather_speed = changing_weather_speed
        self.display_text = display_text
        self.rl_image_size = rl_image_size
        self._max_episode_steps = max_episode_steps  # DMC uses this
        self.frame_skip = frame_skip
        self.is_other_cars = is_other_cars
        self.start_lane = start_lane
        self.num_cameras = num_cameras

        self.actor_list = []

        if self.render_display:
            pygame.init()
            self.display = pygame.display.set_mode((800, 600), pygame.HWSURFACE | pygame.DOUBLEBUF)
            self.font = get_font()
            self.clock = pygame.time.Clock()

        self.client = carla.Client('localhost', port)
        self.client.set_timeout(5.0)

        self.world = self.client.load_world("Town04")
        self.map = self.world.get_map()
        assert self.map.name == "Town04"

        # remove old vehicles and sensors (in case they survived)
        self.world.tick()
        actor_list = self.world.get_actors()
        for vehicle in actor_list.filter("*vehicle*"):
            # if vehicle.id != self.vehicle.id:
            print("Warning: removing old vehicle")
            vehicle.destroy()
        for sensor in actor_list.filter("*sensor*"):
            print("Warning: removing old sensor")
            sensor.destroy()

        self.vehicle = None
        self.vehicle_start_pose = None
        self.vehicles_list = []  # their ids
        self.vehicles = None
        self.reset_vehicle()  # creates self.vehicle
        self.actor_list.append(self.vehicle)

        blueprint_library = self.world.get_blueprint_library()

        if render_display:
            self.camera_rgb = self.world.spawn_actor(
                blueprint_library.find('sensor.camera.rgb'),
                carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)),
                attach_to=self.vehicle)
            self.actor_list.append(self.camera_rgb)

        # we'll use up to five cameras, which we'll stitch together
        bp = blueprint_library.find('sensor.camera.rgb')
        bp.set_attribute('image_size_x', str(self.rl_image_size))
        bp.set_attribute('image_size_y', str(self.rl_image_size))
        bp.set_attribute('fov', str(fov))
        location = carla.Location(x=1.6, z=1.7)
        self.camera_rl = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=0.0)), attach_to=self.vehicle)
        self.camera_rl_left = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=-float(fov))), attach_to=self.vehicle)
        self.camera_rl_lefter = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=-2 * float(fov))), attach_to=self.vehicle)
        self.camera_rl_right = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=float(fov))), attach_to=self.vehicle)
        self.camera_rl_righter = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=2 * float(fov))), attach_to=self.vehicle)
        self.actor_list.append(self.camera_rl)
        self.actor_list.append(self.camera_rl_left)
        self.actor_list.append(self.camera_rl_lefter)
        self.actor_list.append(self.camera_rl_right)
        self.actor_list.append(self.camera_rl_righter)

        bp = self.world.get_blueprint_library().find('sensor.other.collision')
        self.collision_sensor = self.world.spawn_actor(bp, carla.Transform(), attach_to=self.vehicle)
        self.collision_sensor.listen(lambda event: self._on_collision(event))
        self.actor_list.append(self.collision_sensor)
        self._collision_intensities_during_last_time_step = []

        if self.save_display_images or self.save_rl_images:
            import datetime
            now = datetime.datetime.now()
            image_dir = "images-" + now.strftime("%Y-%m-%d-%H-%M-%S")
            os.mkdir(image_dir)
            self.image_dir = image_dir

        if self.render_display:
            self.sync_mode = CarlaSyncMode(self.world, self.camera_rgb, self.camera_rl, self.camera_rl_left, self.camera_rl_lefter, self.camera_rl_right, self.camera_rl_righter, fps=20)
        else:
            self.sync_mode = CarlaSyncMode(self.world, self.camera_rl, self.camera_rl_left, self.camera_rl_lefter, self.camera_rl_right, self.camera_rl_righter, fps=20)

        # weather
        self.weather = Weather(self.world, self.changing_weather_speed)

        # dummy variables given bisim's assumption on deep-mind-control suite APIs
        low = -1.0
        high = 1.0
        self.action_space = DotMap()
        self.action_space.low.min = lambda: low
        self.action_space.high.max = lambda: high
        self.action_space.shape = [2]
        self.observation_space = DotMap()
        self.observation_space.shape = (3, rl_image_size, num_cameras * rl_image_size)
        self.observation_space.dtype = np.dtype(np.uint8)
        self.reward_range = None
        self.metadata = None
        self.action_space.sample = lambda: np.random.uniform(low=low, high=high, size=self.action_space.shape[0]).astype(np.float32)

        # roaming carla agent
        self.agent = None
        self.count = 0
        self.dist_s = 0
        self.return_ = 0
        self.velocities = []
        self.world.tick()
        self.reset()  # creates self.agent

    def dist_from_center_lane(self, vehicle, info):
        # assume on highway
        vehicle_location = vehicle.get_location()
        vehicle_waypoint = self.map.get_waypoint(vehicle_location)
        vehicle_velocity = vehicle.get_velocity()  # Vector3D
        vehicle_velocity_xy = np.array([vehicle_velocity.x, vehicle_velocity.y])
        speed = np.linalg.norm(vehicle_velocity_xy)

        vehicle_waypoint_closest_to_road = \
            self.map.get_waypoint(vehicle_location, project_to_road=True, lane_type=carla.LaneType.Driving)
        road_id = vehicle_waypoint_closest_to_road.road_id
        assert road_id is not None
        lane_id_sign = int(np.sign(vehicle_waypoint_closest_to_road.lane_id))
        assert lane_id_sign in [-1, 1]

        current_waypoint = self.map.get_waypoint(vehicle_location, project_to_road=False)
        if current_waypoint is None:
            print("Episode fail: current waypoint is off the road! (frame %d)" % self.count)
            info['reason_episode_ended'] = 'off_road'
            done, dist, vel_s = True, 100., 0.
            return dist, vel_s, speed, done, info

        goal_waypoint = current_waypoint.next(5.)[0]

        if goal_waypoint is None:
            print("Episode fail: goal waypoint is off the road! (frame %d)" % self.count)
            info['reason_episode_ended'] = 'off_road'
            done, dist, vel_s = True, 100., 0.
        else:
            goal_location = goal_waypoint.transform.location
            goal_xy = np.array([goal_location.x, goal_location.y])
            dist = 0.

            next_goal_waypoint = goal_waypoint.next(0.1)  # waypoints are every 0.02 meters
            if len(next_goal_waypoint) != 1:
                print('warning: {} waypoints (not 1)'.format(len(next_goal_waypoint)))
            if len(next_goal_waypoint) == 0:
                print("Episode done: no more waypoints left. (frame %d)" % self.count)
                info['reason_episode_ended'] = 'no_waypoints'
                done, vel_s = True, 0.
            else:
                location_ahead = next_goal_waypoint[0].transform.location
                highway_vector = np.array([location_ahead.x, location_ahead.y]) - goal_xy
                highway_unit_vector = np.array(highway_vector) / np.linalg.norm(highway_vector)
                vel_s = np.dot(vehicle_velocity_xy, highway_unit_vector)
                done = False

        # not the algorithm's fault, but the simulator sometimes throws the car into the air weirdly
        if vehicle_velocity.z > 1. and self.count < 20:
            print("Episode done: vertical velocity too high ({}), usually a simulator glitch (frame {})".format(vehicle_velocity.z, self.count))
            info['reason_episode_ended'] = 'carla_bug'
            done = True
        if vehicle_location.z > 0.5 and self.count < 20:
            print("Episode done: vertical location too high ({}), usually a simulator glitch (frame {})".format(vehicle_location.z, self.count))
            info['reason_episode_ended'] = 'carla_bug'
            done = True

        return dist, vel_s, speed, done, info

    def _on_collision(self, event):
        impulse = event.normal_impulse
        intensity = math.sqrt(impulse.x ** 2 + impulse.y ** 2 + impulse.z ** 2)
        print('Collision (intensity {})'.format(intensity))
        self._collision_intensities_during_last_time_step.append(intensity)

    def reset(self):
        self.reset_vehicle()
        self.world.tick()
        self.reset_other_vehicles()
        self.world.tick()
        self.agent = RoamingAgentModified(self.vehicle, follow_traffic_lights=False)
        self.count = 0
        self.dist_s = 0
        self.return_ = 0
        self.velocities = []

        # get obs:
        obs, _, _, _ = self.step(action=None)
        return obs

    def reset_vehicle(self):
        start_lane = self.start_lane if self.start_lane is not None else np.random.choice([1, 2, 3, 4])
        start_x = 1.5 + 3.5 * start_lane  # 3.5 = lane width
        self.vehicle_start_pose = carla.Transform(carla.Location(x=start_x, y=0, z=0.1), carla.Rotation(yaw=-90))
        if self.vehicle is None:
            # create vehicle
            blueprint_library = self.world.get_blueprint_library()
            vehicle_blueprint = blueprint_library.find('vehicle.audi.a2')
            self.vehicle = self.world.spawn_actor(vehicle_blueprint, self.vehicle_start_pose)
        else:
            self.vehicle.set_transform(self.vehicle_start_pose)
            self.vehicle.set_velocity(carla.Vector3D())
            self.vehicle.set_angular_velocity(carla.Vector3D())

    def reset_other_vehicles(self):
        if not self.is_other_cars:
            return

        # clear out old vehicles
        self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list])
        self.world.tick()
        self.vehicles_list = []

        blueprints = self.world.get_blueprint_library().filter('vehicle.*')
        blueprints = [x for x in blueprints if int(x.get_attribute('number_of_wheels')) == 4]

        num_vehicles = 10
        other_car_transforms = []
        for _ in range(num_vehicles):
            lane_id = random.choice([1, 2, 3, 4])
            start_x = 1.5 + 3.5 * lane_id
            start_y = random.uniform(-40., 40.)
            transform = carla.Transform(carla.Location(x=start_x, y=start_y, z=0.1), carla.Rotation(yaw=-90))
            other_car_transforms.append(transform)

        # Spawn vehicles
        batch = []
        for n, transform in enumerate(other_car_transforms):
            blueprint = random.choice(blueprints)
            if blueprint.has_attribute('color'):
                color = random.choice(blueprint.get_attribute('color').recommended_values)
                blueprint.set_attribute('color', color)
            if blueprint.has_attribute('driver_id'):
                driver_id = random.choice(blueprint.get_attribute('driver_id').recommended_values)
                blueprint.set_attribute('driver_id', driver_id)
            blueprint.set_attribute('role_name', 'autopilot')
            batch.append(carla.command.SpawnActor(blueprint, transform).then(
                carla.command.SetAutopilot(carla.command.FutureActor, True)))
        for response in self.client.apply_batch_sync(batch, False):
            self.vehicles_list.append(response.actor_id)

        for response in self.client.apply_batch_sync(batch):
            if response.error:
                pass
                # print(response.error)
            else:
                self.vehicles_list.append(response.actor_id)

    def compute_steer_action(self):
        control = self.agent.run_step()  # PID decides control.steer
        steer = control.steer
        throttle = control.throttle
        brake = control.brake
        throttle_brake = -brake
        if throttle > 0.:
            throttle_brake = throttle
        steer_action = np.array([steer, throttle_brake], dtype=np.float32)
        return steer_action

    def step(self, action):
        rewards = []
        for _ in range(self.frame_skip):  # default 1
            next_obs, reward, done, info = self._simulator_step(action)
            rewards.append(reward)
            if done:
                break
        return next_obs, np.mean(rewards), done, info  # just last info?

    def _simulator_step(self, action, dt=0.05):

        if self.render_display:
            if should_quit():
                return
            self.clock.tick()

        if action is not None:
            steer = float(action[0])
            throttle_brake = float(action[1])
            if throttle_brake >= 0.0:
                throttle = throttle_brake
                brake = 0.0
            else:
                throttle = 0.0
                brake = -throttle_brake

            assert 0.0 <= throttle <= 1.0
            assert -1.0 <= steer <= 1.0
            assert 0.0 <= brake <= 1.0
            vehicle_control = carla.VehicleControl(
                throttle=throttle,
                steer=steer,
                brake=brake,
                hand_brake=False,
                reverse=False,
                manual_gear_shift=False
            )
            self.vehicle.apply_control(vehicle_control)
        else:
            throttle, steer, brake = 0., 0., 0.

        # Advance the simulation and wait for the data.
        if self.render_display:
            snapshot, image_rgb, image_rl, image_rl_left, image_rl_lefter, image_rl_right, image_rl_righter = self.sync_mode.tick(timeout=2.0)
        else:
            snapshot, image_rl, image_rl_left, image_rl_lefter, image_rl_right, image_rl_righter = self.sync_mode.tick(timeout=2.0)

        info = {}
        info['reason_episode_ended'] = ''
        dist_from_center, vel_s, speed, done, info = self.dist_from_center_lane(self.vehicle, info)
        collision_intensities_during_last_time_step = sum(self._collision_intensities_during_last_time_step)
        self._collision_intensities_during_last_time_step.clear()  # clear it ready for next time step
        assert collision_intensities_during_last_time_step >= 0.
        collision_cost = 0.0001 * collision_intensities_during_last_time_step
        vel_t = math.sqrt(speed**2 - vel_s**2)
        reward = vel_s * dt - collision_cost - abs(steer)  # doesn't work if 0.001 cost collisions

        info['crash_intensity'] = collision_intensities_during_last_time_step
        info['steer'] = steer
        info['brake'] = brake
        info['distance'] = vel_s * dt

        self.dist_s += vel_s * dt
        self.return_ += reward

        self.weather.tick()

        # Draw the display.
        if self.render_display:
            draw_image(self.display, image_rgb)
            if self.display_text:
                self.display.blit(self.font.render('frame %d' % self.count, True, (255, 255, 255)), (8, 10))
                self.display.blit(self.font.render('highway progression %4.1f m/s (%5.1f m) (%5.2f speed)' % (vel_s, self.dist_s, speed), True, (255, 255, 255)), (8, 28))
                self.display.blit(self.font.render('%5.2f meters off center' % dist_from_center, True, (255, 255, 255)), (8, 46))
                self.display.blit(self.font.render('%5.2f reward (return %.2f)' % (reward, self.return_), True, (255, 255, 255)), (8, 64))
                self.display.blit(self.font.render('%5.2f collision intensity ' % collision_intensities_during_last_time_step, True, (255, 255, 255)), (8, 82))
                self.display.blit(self.font.render('%5.2f throttle, %3.2f steer, %3.2f brake' % (throttle, steer, brake), True, (255, 255, 255)), (8, 100))
                self.display.blit(self.font.render(str(self.weather), True, (255, 255, 255)), (8, 118))
            pygame.display.flip()

        rgbs = []
        if self.num_cameras == 1:
            ims = [image_rl]
        elif self.num_cameras == 3:
            ims = [image_rl_left, image_rl, image_rl_right]
        elif self.num_cameras == 5:
            ims = [image_rl_lefter, image_rl_left, image_rl, image_rl_right, image_rl_righter]
        else:
            raise ValueError("num cameras must be 1 or 3 or 5")
        for im in ims:
            bgra = np.array(im.raw_data).reshape(self.rl_image_size, self.rl_image_size, 4)  # BGRA format
            bgr = bgra[:, :, :3]  # BGR format (84 x 84 x 3)
            rgb = np.flip(bgr, axis=2)  # RGB format (84 x 84 x 3)
            rgbs.append(rgb)
        rgb = np.concatenate(rgbs, axis=1)  # (84 x 252 x 3)

        # Rowan added
        if self.render_display and self.save_display_images:
            image_name = os.path.join(self.image_dir, "display%08d.jpg" % self.count)
            pygame.image.save(self.display, image_name)
            # ffmpeg -r 20 -pattern_type glob -i 'display*.jpg' carla.mp4
        if self.save_rl_images:
            image_name = os.path.join(self.image_dir, "rl%08d.png" % self.count)

            im = Image.fromarray(rgb)
            metadata = PngInfo()
            metadata.add_text("throttle", str(throttle))
            metadata.add_text("steer", str(steer))
            metadata.add_text("brake", str(brake))
            im.save(image_name, "PNG", pnginfo=metadata)

            # # Example usage:
            # from PIL.PngImagePlugin import PngImageFile
            # im = PngImageFile("rl00001234.png")
            # # Actions are stored in the image's metadata:
            # print("Actions: %s" % im.text)
            # throttle = float(im.text['throttle'])  # range [0, 1]
            # steer = float(im.text['steer'])  # range [-1, 1]
            # brake = float(im.text['brake'])  # range [0, 1]
        self.count += 1

        next_obs = rgb  # (84 x 252 x 3) or (84 x 420 x 3)
        # debugging - to inspect images:
        # import matplotlib.pyplot as plt
        # import pdb; pdb.set_trace()
        # plt.imshow(next_obs)
        # plt.show()
        next_obs = np.transpose(next_obs, [2, 0, 1])  # 3 x 84 x 84/252/420
        assert next_obs.shape == self.observation_space.shape
        if self.count >= self._max_episode_steps:
            print("Episode success: I've reached the episode horizon ({}).".format(self._max_episode_steps))
            info['reason_episode_ended'] = 'success'
            done = True
        if speed < 0.02 and self.count >= 100 and self.count % 100 == 0:  # a hack, instead of a counter
            print("Episode fail: speed too small ({}), think I'm stuck! (frame {})".format(speed, self.count))
            info['reason_episode_ended'] = 'stuck'
            done = True
        return next_obs, reward, done, info

    def finish(self):
        print('destroying actors.')
        for actor in self.actor_list:
            actor.destroy()
        print('\ndestroying %d vehicles' % len(self.vehicles_list))
        self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list])
        time.sleep(0.5)
        pygame.quit()
        print('done.')


class LocalPlannerModified(LocalPlanner):

    def __del__(self):
        pass  # otherwise it deletes our vehicle object

    def run_step(self):
        return super().run_step(debug=False)  # otherwise by default shows waypoints, that interfere with our camera

class RoamingAgentModified(Agent):
    """
    RoamingAgent implements a basic agent that navigates scenes making random
    choices when facing an intersection.

    This agent respects traffic lights and other vehicles.
    """

    def __init__(self, vehicle, follow_traffic_lights=True):
        """
        :param vehicle: actor to apply the local planner logic onto
        """
        super(RoamingAgentModified, self).__init__(vehicle)
        self._proximity_threshold = 10.0  # meters
        self._state = AgentState.NAVIGATING
        self._follow_traffic_lights = follow_traffic_lights

        # for throttle 0.5, 0.75, 1.0
        args_lateral_dict = {
            'K_P': 1.0,
            'K_D': 0.005,
            'K_I': 0.0,
            'dt': 1.0 / 20.0}
        opt_dict = {'lateral_control_dict': args_lateral_dict}

        self._local_planner = LocalPlannerModified(self._vehicle, opt_dict)

    def run_step(self, debug=False):
        """
        Execute one step of navigation.

        :return: carla.VehicleControl
        """

        # is there an obstacle in front of us?
        hazard_detected = False

        # retrieve relevant elements for safe navigation, i.e.: traffic lights
        # and other vehicles
        actor_list = self._world.get_actors()
        vehicle_list = actor_list.filter("*vehicle*")
        lights_list = actor_list.filter("*traffic_light*")

        # check possible obstacles
        vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list)
        if vehicle_state:
            if debug:
                print('!!! VEHICLE BLOCKING AHEAD [{}])'.format(vehicle.id))

            self._state = AgentState.BLOCKED_BY_VEHICLE
            hazard_detected = True

        # check for the state of the traffic lights
        light_state, traffic_light = self._is_light_red(lights_list)
        if light_state and self._follow_traffic_lights:
            if debug:
                print('=== RED LIGHT AHEAD [{}])'.format(traffic_light.id))

            self._state = AgentState.BLOCKED_RED_LIGHT
            hazard_detected = True

        if hazard_detected:
            control = self.emergency_stop()
        else:
            self._state = AgentState.NAVIGATING
            # standard local planner behavior
            control = self._local_planner.run_step()

        return control

if __name__ == '__main__':

    env = CarlaEnv(
        render_display=1,  # 0, 1
        record_display_images=0,  # 0, 1
        record_rl_images=1,  # 0, 1
        changing_weather_speed=1.0,  # [0, +inf)
        display_text=1,  # 0, 1
        is_other_cars=True,
        frame_skip=4,
        max_episode_steps=100000,
        rl_image_size=84,
        start_lane=1,
    )

    try:
        done = False
        while not done:
            action = env.compute_steer_action()
            next_obs, reward, done, info = env.step(action)
        obs = env.reset()
    finally:
        env.finish()

CARLA_0.9.8/PythonAPI/carla/agents/navigation/carla_env.py (new file, 744 lines)
@@ -0,0 +1,744 @@
#!/usr/bin/env python

# Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de
# Barcelona (UAB).
#
# This work is licensed under the terms of the MIT license.
# For a copy, see <https://opensource.org/licenses/MIT>.
#
# Modified for DBC paper.

import random
import glob
import os
import sys
import time
from PIL import Image
from PIL.PngImagePlugin import PngImageFile, PngInfo

try:
    sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % (
        sys.version_info.major,
        sys.version_info.minor,
        'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0])
except IndexError:
    pass

import carla
import math

from dotmap import DotMap

try:
    import pygame
except ImportError:
    raise RuntimeError('cannot import pygame, make sure pygame package is installed')

try:
    import numpy as np
except ImportError:
    raise RuntimeError('cannot import numpy, make sure numpy package is installed')

try:
    import queue
except ImportError:
    import Queue as queue

from agents.navigation.agent import Agent, AgentState
from agents.navigation.local_planner import LocalPlanner

class CarlaSyncMode(object):
    """
    Context manager to synchronize output from different sensors. Synchronous
    mode is enabled as long as we are inside this context

        with CarlaSyncMode(world, sensors) as sync_mode:
            while True:
                data = sync_mode.tick(timeout=1.0)

    """

    def __init__(self, world, *sensors, **kwargs):
        self.world = world
        self.sensors = sensors
        self.frame = None
        self.delta_seconds = 1.0 / kwargs.get('fps', 20)
        self._queues = []
        self._settings = None

        self.start()

    def start(self):
        self._settings = self.world.get_settings()
        self.frame = self.world.apply_settings(carla.WorldSettings(
            no_rendering_mode=False,
            synchronous_mode=True,
            fixed_delta_seconds=self.delta_seconds))

        def make_queue(register_event):
            q = queue.Queue()
            register_event(q.put)
            self._queues.append(q)

        make_queue(self.world.on_tick)
        for sensor in self.sensors:
            make_queue(sensor.listen)

    def tick(self, timeout):
        self.frame = self.world.tick()
        data = [self._retrieve_data(q, timeout) for q in self._queues]
        assert all(x.frame == self.frame for x in data)
        return data

    def __exit__(self, *args, **kwargs):
        self.world.apply_settings(self._settings)

    def _retrieve_data(self, sensor_queue, timeout):
        while True:
            data = sensor_queue.get(timeout=timeout)
            if data.frame == self.frame:
                return data

def draw_image(surface, image, blend=False):
    array = np.frombuffer(image.raw_data, dtype=np.dtype("uint8"))
    array = np.reshape(array, (image.height, image.width, 4))
    array = array[:, :, :3]
    array = array[:, :, ::-1]
    image_surface = pygame.surfarray.make_surface(array.swapaxes(0, 1))
    if blend:
        image_surface.set_alpha(100)
    surface.blit(image_surface, (0, 0))


def get_font():
    fonts = [x for x in pygame.font.get_fonts()]
    default_font = 'ubuntumono'
    font = default_font if default_font in fonts else fonts[0]
    font = pygame.font.match_font(font)
    return pygame.font.Font(font, 14)


def should_quit():
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            return True
        elif event.type == pygame.KEYUP:
            if event.key == pygame.K_ESCAPE:
                return True
    return False


def clamp(value, minimum=0.0, maximum=100.0):
    return max(minimum, min(value, maximum))


class Sun(object):
    def __init__(self, azimuth, altitude):
        self.azimuth = azimuth
        self.altitude = altitude
        self._t = 0.0

    def tick(self, delta_seconds):
        self._t += 0.008 * delta_seconds
        self._t %= 2.0 * math.pi
        self.azimuth += 0.25 * delta_seconds
        self.azimuth %= 360.0
        # self.altitude = (70 * math.sin(self._t)) - 20  # [50, -90]
        min_alt, max_alt = [30, 90]
        self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * (max_alt - min_alt) * math.cos(self._t)

    def __str__(self):
        return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth)


class Storm(object):
    def __init__(self, precipitation):
        self._t = precipitation if precipitation > 0.0 else -50.0
        self._increasing = True
        self.clouds = 0.0
        self.rain = 0.0
        self.wetness = 0.0
        self.puddles = 0.0
        self.wind = 0.0
        self.fog = 0.0

    def tick(self, delta_seconds):
        delta = (1.3 if self._increasing else -1.3) * delta_seconds
        self._t = clamp(delta + self._t, -250.0, 100.0)
        self.clouds = clamp(self._t + 40.0, 0.0, 60.0)
        self.rain = clamp(self._t, 0.0, 80.0)
        self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40
        if self._t == -250.0:
            self._increasing = True
        if self._t == 100.0:
            self._increasing = False

    def __str__(self):
        return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind)


class Weather(object):
    def __init__(self, world, changing_weather_speed):
        self.world = world
        self.reset()
        self.weather = world.get_weather()
        self.changing_weather_speed = changing_weather_speed
        self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle)
        self._storm = Storm(self.weather.precipitation)

    def reset(self):
        weather_params = carla.WeatherParameters(sun_altitude_angle=90.)
        self.world.set_weather(weather_params)

    def tick(self):
        self._sun.tick(self.changing_weather_speed)
        self._storm.tick(self.changing_weather_speed)
        self.weather.cloudiness = self._storm.clouds
        self.weather.precipitation = self._storm.rain
        self.weather.precipitation_deposits = self._storm.puddles
        self.weather.wind_intensity = self._storm.wind
        self.weather.fog_density = self._storm.fog
        self.weather.wetness = self._storm.wetness
        self.weather.sun_azimuth_angle = self._sun.azimuth
        self.weather.sun_altitude_angle = self._sun.altitude
        self.world.set_weather(self.weather)

    def __str__(self):
        return '%s %s' % (self._sun, self._storm)

class CarlaEnv(object):

    def __init__(self,
                 render_display=0,  # 0, 1
                 record_display_images=0,  # 0, 1
                 record_rl_images=0,  # 0, 1
                 changing_weather_speed=0.0,  # [0, +inf)
                 display_text=0,  # 0, 1
                 rl_image_size=84,
                 max_episode_steps=1000,
                 frame_skip=1,
                 is_other_cars=True,
                 fov=60,  # degrees for rl camera
                 num_cameras=3,
                 port=2000
                 ):
        if record_display_images:
            assert render_display
        self.render_display = render_display
        self.save_display_images = record_display_images
        self.save_rl_images = record_rl_images
        self.changing_weather_speed = changing_weather_speed
        self.display_text = display_text
        self.rl_image_size = rl_image_size
        self._max_episode_steps = max_episode_steps  # DMC uses this
        self.frame_skip = frame_skip
        self.is_other_cars = is_other_cars
        self.num_cameras = num_cameras

        self.actor_list = []

        if self.render_display:
            pygame.init()
            self.display = pygame.display.set_mode((800, 600), pygame.HWSURFACE | pygame.DOUBLEBUF)
            self.font = get_font()
            self.clock = pygame.time.Clock()

        self.client = carla.Client('localhost', port)
        self.client.set_timeout(5.0)
        self.world = self.client.load_world("Town04")
        self.map = self.world.get_map()
        assert self.map.name == "Town04"

        # remove old vehicles and sensors (in case they survived)
        self.world.tick()
        actor_list = self.world.get_actors()
        for vehicle in actor_list.filter("*vehicle*"):
            print("Warning: removing old vehicle")
            vehicle.destroy()
        for sensor in actor_list.filter("*sensor*"):
            print("Warning: removing old sensor")
            sensor.destroy()

        self.vehicle = None
        self.vehicle_start_pose = None
        self.vehicles_list = []  # their ids
        self.vehicles = None
        self.reset_vehicle()  # creates self.vehicle
        self.actor_list.append(self.vehicle)

        blueprint_library = self.world.get_blueprint_library()

        if render_display:
            bp = blueprint_library.find('sensor.camera.rgb')
            bp.set_attribute('enable_postprocess_effects', str(True))
            self.camera_rgb = self.world.spawn_actor(bp, carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)), attach_to=self.vehicle)
            self.actor_list.append(self.camera_rgb)

        # we'll use up to five cameras, which we'll stitch together
        bp = blueprint_library.find('sensor.camera.rgb')
        bp.set_attribute('image_size_x', str(self.rl_image_size))
        bp.set_attribute('image_size_y', str(self.rl_image_size))
        bp.set_attribute('fov', str(fov))
        bp.set_attribute('enable_postprocess_effects', str(True))
        location = carla.Location(x=1.6, z=1.7)
        self.camera_rl = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=0.0)), attach_to=self.vehicle)
        self.camera_rl_left = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=-float(fov))), attach_to=self.vehicle)
        self.camera_rl_lefter = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=-2 * float(fov))), attach_to=self.vehicle)
        self.camera_rl_right = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=float(fov))), attach_to=self.vehicle)
        self.camera_rl_righter = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=2 * float(fov))), attach_to=self.vehicle)
        self.actor_list.append(self.camera_rl)
        self.actor_list.append(self.camera_rl_left)
        self.actor_list.append(self.camera_rl_lefter)
        self.actor_list.append(self.camera_rl_right)
        self.actor_list.append(self.camera_rl_righter)

        bp = self.world.get_blueprint_library().find('sensor.other.collision')
        self.collision_sensor = self.world.spawn_actor(bp, carla.Transform(), attach_to=self.vehicle)
        self.collision_sensor.listen(lambda event: self._on_collision(event))
        self.actor_list.append(self.collision_sensor)
        self._collision_intensities_during_last_time_step = []

        if self.save_display_images or self.save_rl_images:
            import datetime
            now = datetime.datetime.now()
            image_dir = "images-" + now.strftime("%Y-%m-%d-%H-%M-%S")
            os.mkdir(image_dir)
            self.image_dir = image_dir

        if self.render_display:
            self.sync_mode = CarlaSyncMode(self.world, self.camera_rgb, self.camera_rl, self.camera_rl_left, self.camera_rl_lefter, self.camera_rl_right, self.camera_rl_righter, fps=20)
        else:
            self.sync_mode = CarlaSyncMode(self.world, self.camera_rl, self.camera_rl_left, self.camera_rl_lefter, self.camera_rl_right, self.camera_rl_righter, fps=20)

        # weather
        self.weather = Weather(self.world, self.changing_weather_speed)

        # dummy variables given bisim's assumption on deep-mind-control suite APIs
        low = -1.0
        high = 1.0
        self.action_space = DotMap()
        self.action_space.low.min = lambda: low
        self.action_space.high.max = lambda: high
        self.action_space.shape = [2]
        self.observation_space = DotMap()
        self.observation_space.shape = (3, rl_image_size, num_cameras * rl_image_size)
        self.observation_space.dtype = np.dtype(np.uint8)
        self.reward_range = None
        self.metadata = None
        self.action_space.sample = lambda: np.random.uniform(low=low, high=high, size=self.action_space.shape[0]).astype(np.float32)

        # roaming carla agent
        self.agent = None
        self.count = 0
        self.dist_s = 0
        self.return_ = 0
        self.collide_count = 0
        self.velocities = []
        self.world.tick()
        self.reset()  # creates self.agent

    def dist_from_center_lane(self, vehicle):
        # assume on highway
        vehicle_location = vehicle.get_location()
        vehicle_waypoint = self.map.get_waypoint(vehicle_location)
        vehicle_xy = np.array([vehicle_location.x, vehicle_location.y])
        vehicle_s = vehicle_waypoint.s
        vehicle_velocity = vehicle.get_velocity()  # Vector3D
        vehicle_velocity_xy = np.array([vehicle_velocity.x, vehicle_velocity.y])
        speed = np.linalg.norm(vehicle_velocity_xy)

        vehicle_waypoint_closest_to_road = \
            self.map.get_waypoint(vehicle_location, project_to_road=True, lane_type=carla.LaneType.Driving)
        road_id = vehicle_waypoint_closest_to_road.road_id
        assert road_id is not None
        lane_id = int(vehicle_waypoint_closest_to_road.lane_id)
        goal_lane_id = lane_id

        current_waypoint = self.map.get_waypoint(vehicle_location, project_to_road=False)
        goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s)
        if goal_waypoint is None:
            # try to fix, bit of a hack, with CARLA waypoint discretizations
            carla_waypoint_discretization = 0.02  # meters
            goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s - carla_waypoint_discretization)
            if goal_waypoint is None:
                goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s + carla_waypoint_discretization)

        if goal_waypoint is None:
            print("Episode fail: goal waypoint is off the road! (frame %d)" % self.count)
            done, dist, vel_s = True, 100., 0.
        else:
            goal_location = goal_waypoint.transform.location
            goal_xy = np.array([goal_location.x, goal_location.y])
            dist = np.linalg.norm(vehicle_xy - goal_xy)

            next_goal_waypoint = goal_waypoint.next(0.1)  # waypoints are every 0.02 meters
            if len(next_goal_waypoint) != 1:
                print('warning: {} waypoints (not 1)'.format(len(next_goal_waypoint)))
            if len(next_goal_waypoint) == 0:
                print("Episode done: no more waypoints left. (frame %d)" % self.count)
                done, vel_s = True, 0.
            else:
                location_ahead = next_goal_waypoint[0].transform.location
                highway_vector = np.array([location_ahead.x, location_ahead.y]) - goal_xy
                highway_unit_vector = np.array(highway_vector) / np.linalg.norm(highway_vector)
                vel_s = np.dot(vehicle_velocity_xy, highway_unit_vector)
                done = False

        # not the algorithm's fault, but the simulator sometimes throws the car into the air weirdly
        if vehicle_velocity.z > 1. and self.count < 20:
            print("Episode done: vertical velocity too high ({}), usually a simulator glitch (frame {})".format(vehicle_velocity.z, self.count))
            done = True
        if vehicle_location.z > 0.5 and self.count < 20:
            print("Episode done: vertical location too high ({}), usually a simulator glitch (frame {})".format(vehicle_location.z, self.count))
            done = True

        return dist, vel_s, speed, done

    def _on_collision(self, event):
        impulse = event.normal_impulse
        intensity = math.sqrt(impulse.x ** 2 + impulse.y ** 2 + impulse.z ** 2)
        print('Collision (intensity {})'.format(intensity))
        self._collision_intensities_during_last_time_step.append(intensity)

    def reset(self):
        self.reset_vehicle()
        self.world.tick()
        self.reset_other_vehicles()
        self.world.tick()
        self.agent = RoamingAgentModified(self.vehicle, follow_traffic_lights=False)
        self.count = 0
        self.dist_s = 0
        self.return_ = 0
        self.velocities = []

        obs, _, _, _ = self.step(action=None)
        return obs

    def reset_vehicle(self):
        start_lane = random.choice([-1, -2, -3, -4])
        self.vehicle_start_pose = self.map.get_waypoint_xodr(road_id=45, lane_id=start_lane, s=100.).transform
        if self.vehicle is None:
            # create vehicle
            blueprint_library = self.world.get_blueprint_library()
            vehicle_blueprint = blueprint_library.find('vehicle.audi.a2')
            self.vehicle = self.world.spawn_actor(vehicle_blueprint, self.vehicle_start_pose)
            # self.vehicle.set_light_state(carla.libcarla.VehicleLightState.HighBeam)  # HighBeam # LowBeam # All
        else:
            self.vehicle.set_transform(self.vehicle_start_pose)
            self.vehicle.set_velocity(carla.Vector3D())
            self.vehicle.set_angular_velocity(carla.Vector3D())

    def reset_other_vehicles(self):
        if not self.is_other_cars:
            return

        # clear out old vehicles
        self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list])
        self.world.tick()
        self.vehicles_list = []

        traffic_manager = self.client.get_trafficmanager()  # 8000? which port?
        traffic_manager.set_global_distance_to_leading_vehicle(2.0)
        traffic_manager.set_synchronous_mode(True)
        blueprints = self.world.get_blueprint_library().filter('vehicle.*')
        blueprints = [x for x in blueprints if int(x.get_attribute('number_of_wheels')) == 4]

        road_id = 45
        num_vehicles = 20
        other_car_waypoints = []
        for _ in range(num_vehicles):
            lane_id = random.choice([-1, -2, -3, -4])
            vehicle_s = np.random.uniform(100., 300.)
            other_car_waypoints.append(self.map.get_waypoint_xodr(road_id, lane_id, vehicle_s))

        # Spawn vehicles
        batch = []
        for n, waypoint in enumerate(other_car_waypoints):
            transform = waypoint.transform
            transform.location.z += 0.1
            blueprint = random.choice(blueprints)
            if blueprint.has_attribute('color'):
                color = random.choice(blueprint.get_attribute('color').recommended_values)
                blueprint.set_attribute('color', color)
            if blueprint.has_attribute('driver_id'):
                driver_id = random.choice(blueprint.get_attribute('driver_id').recommended_values)
                blueprint.set_attribute('driver_id', driver_id)
            blueprint.set_attribute('role_name', 'autopilot')
            batch.append(carla.command.SpawnActor(blueprint, transform).then(
                carla.command.SetAutopilot(carla.command.FutureActor, True)))
        for response in self.client.apply_batch_sync(batch, False):
            self.vehicles_list.append(response.actor_id)

        for response in self.client.apply_batch_sync(batch):
            if response.error:
                pass
                # print(response.error)
            else:
                self.vehicles_list.append(response.actor_id)

        traffic_manager.global_percentage_speed_difference(30.0)

    def compute_steer_action(self):
        control = self.agent.run_step()  # PID decides control.steer
        steer = control.steer
        throttle = control.throttle
        brake = control.brake
        throttle_brake = -brake
        if throttle > 0.:
            throttle_brake = throttle
        steer_action = np.array([steer, throttle_brake], dtype=np.float32)
        return steer_action

    def step(self, action):
        rewards = []
        for _ in range(self.frame_skip):  # default 1
            next_obs, reward, done, info = self._simulator_step(action)
            rewards.append(reward)
            if done:
                break
        return next_obs, np.mean(rewards), done, info  # just last info?

    def _simulator_step(self, action, dt=0.05):

        if self.render_display:
            if should_quit():
                return
            self.clock.tick()

        if action is not None:
            steer = float(action[0])
            throttle_brake = float(action[1])
            if throttle_brake >= 0.0:
                throttle = throttle_brake
                brake = 0.0
            else:
                throttle = 0.0
                brake = -throttle_brake

            assert 0.0 <= throttle <= 1.0
            assert -1.0 <= steer <= 1.0
            assert 0.0 <= brake <= 1.0
            vehicle_control = carla.VehicleControl(
                throttle=throttle,
                steer=steer,
                brake=brake,
                hand_brake=False,
                reverse=False,
                manual_gear_shift=False
            )
            self.vehicle.apply_control(vehicle_control)
        else:
            throttle, steer, brake = 0., 0., 0.

        # Advance the simulation and wait for the data.
        if self.render_display:
            snapshot, image_rgb, image_rl, image_rl_left, image_rl_lefter, image_rl_right, image_rl_righter = self.sync_mode.tick(timeout=2.0)
        else:
            snapshot, image_rl, image_rl_left, image_rl_lefter, image_rl_right, image_rl_righter = self.sync_mode.tick(timeout=2.0)

        dist_from_center, vel_s, speed, done = self.dist_from_center_lane(self.vehicle)
        collision_intensities_during_last_time_step = sum(self._collision_intensities_during_last_time_step)
        self._collision_intensities_during_last_time_step.clear()  # clear it ready for next time step
        assert collision_intensities_during_last_time_step >= 0.
        colliding = float(collision_intensities_during_last_time_step > 0.)
        if colliding:
            self.collide_count += 1
        else:
            self.collide_count = 0
        if self.collide_count >= 20:
            print("Episode fail: too many collisions ({})! (frame {})".format(self.collide_count, self.count))
            done = True

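        # Reward: highway progression vel_s * dt, attenuated by distance from the lane
        # center, minus penalties for colliding this step (1.0) and for brake (0.1) and
        # steering magnitude (0.1).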
        reward = vel_s * dt / (1. + dist_from_center) - 1.0 * colliding - 0.1 * brake - 0.1 * abs(steer)

        self.dist_s += vel_s * dt
        self.return_ += reward

        self.weather.tick()

        # Draw the display.
        if self.render_display:
            draw_image(self.display, image_rgb)
            if self.display_text:
                self.display.blit(self.font.render('frame %d' % self.count, True, (255, 255, 255)), (8, 10))
                self.display.blit(self.font.render('highway progression %4.1f m/s (%5.2f m) (%5.2f speed)' % (vel_s, self.dist_s, speed), True, (255, 255, 255)), (8, 28))
                self.display.blit(self.font.render('%5.2f meters off center' % dist_from_center, True, (255, 255, 255)), (8, 46))
                self.display.blit(self.font.render('%5.2f reward (return %.1f)' % (reward, self.return_), True, (255, 255, 255)), (8, 64))
                self.display.blit(self.font.render('%5.2f collision intensity ' % collision_intensities_during_last_time_step, True, (255, 255, 255)), (8, 82))
                self.display.blit(self.font.render('%5.2f throttle, %5.2f steer, %5.2f brake' % (throttle, steer, brake), True, (255, 255, 255)), (8, 100))
                self.display.blit(self.font.render(str(self.weather), True, (255, 255, 255)), (8, 118))
            pygame.display.flip()

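        # Assemble the RL observation: convert each camera's BGRA buffer to RGB and
        # concatenate the views side by side (width = rl_image_size * num_cameras).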
        rgbs = []
        if self.num_cameras == 1:
            ims = [image_rl]
        elif self.num_cameras == 3:
            ims = [image_rl_left, image_rl, image_rl_right]
        elif self.num_cameras == 5:
            ims = [image_rl_lefter, image_rl_left, image_rl, image_rl_right, image_rl_righter]
        else:
            raise ValueError("num cameras must be 1 or 3 or 5")
        for im in ims:
            bgra = np.array(im.raw_data).reshape(self.rl_image_size, self.rl_image_size, 4)  # BGRA format
            bgr = bgra[:, :, :3]  # BGR format (84 x 84 x 3)
            rgb = np.flip(bgr, axis=2)  # RGB format (84 x 84 x 3)
            rgbs.append(rgb)
        rgb = np.concatenate(rgbs, axis=1)  # (84 x 252 x 3)

        # Rowan added
        if self.render_display and self.save_display_images:
            image_name = os.path.join(self.image_dir, "display%08d.jpg" % self.count)
            pygame.image.save(self.display, image_name)
            # ffmpeg -r 20 -pattern_type glob -i 'display*.jpg' carla.mp4
        if self.save_rl_images:
            image_name = os.path.join(self.image_dir, "rl%08d.png" % self.count)
            im = Image.fromarray(rgb)
            metadata = PngInfo()
            metadata.add_text("throttle", str(throttle))
            metadata.add_text("steer", str(steer))
            metadata.add_text("brake", str(brake))
            im.save(image_name, "PNG", pnginfo=metadata)

            # # Example usage:
            # from PIL.PngImagePlugin import PngImageFile
            # im = PngImageFile("rl00001234.png")
            # # Actions are stored in the image's metadata:
            # print("Actions: %s" % im.text)
            # throttle = float(im.text['throttle'])  # range [0, 1]
            # steer = float(im.text['steer'])  # range [-1, 1]
            # brake = float(im.text['brake'])  # range [0, 1]
        self.count += 1

        next_obs = rgb  # (84 x 252 x 3) or (84 x 420 x 3)
        # debugging - to inspect images:
        # import matplotlib.pyplot as plt
        # import pdb; pdb.set_trace()
        # plt.imshow(next_obs)
        # plt.show()
        next_obs = np.transpose(next_obs, [2, 0, 1])  # 3 x 84 x 84/252/420
        assert next_obs.shape == self.observation_space.shape
        if self.count >= self._max_episode_steps:
            print("Episode success: I've reached the episode horizon ({}).".format(self._max_episode_steps))
            done = True
        if speed < 0.02 and self.count >= 100 and self.count % 100 == 0:  # a hack, instead of a counter
            print("Episode fail: speed too small ({}), think I'm stuck! (frame {})".format(speed, self.count))
            done = True
        info = None
        return next_obs, reward, done, info

    def finish(self):
        print('destroying actors.')
        for actor in self.actor_list:
            actor.destroy()
        print('\ndestroying %d vehicles' % len(self.vehicles_list))
        self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list])
        time.sleep(0.5)
        pygame.quit()
        print('done.')


class LocalPlannerModified(LocalPlanner):

    def __del__(self):
        pass  # otherwise it deletes our vehicle object

    def run_step(self):
        return super().run_step(debug=False)  # otherwise it draws waypoints by default, which interfere with our camera


class RoamingAgentModified(Agent):
    """
    RoamingAgent implements a basic agent that navigates scenes making random
    choices when facing an intersection.

    This agent respects traffic lights and other vehicles.
    """

    def __init__(self, vehicle, follow_traffic_lights=True):
        """
        :param vehicle: actor to apply the local planner logic onto
        """
        super(RoamingAgentModified, self).__init__(vehicle)
        self._proximity_threshold = 10.0  # meters
        self._state = AgentState.NAVIGATING
        self._follow_traffic_lights = follow_traffic_lights

        # for throttle 0.5, 0.75, 1.0
        args_lateral_dict = {
            'K_P': 1.0,
            'K_D': 0.005,
            'K_I': 0.0,
            'dt': 1.0 / 20.0}
        opt_dict = {'lateral_control_dict': args_lateral_dict}

        self._local_planner = LocalPlannerModified(self._vehicle, opt_dict)

    def run_step(self, debug=False):
        """
        Execute one step of navigation.
        :return: carla.VehicleControl
        """

        # is there an obstacle in front of us?
        hazard_detected = False

        # retrieve relevant elements for safe navigation, i.e. traffic lights
        # and other vehicles
        actor_list = self._world.get_actors()
        vehicle_list = actor_list.filter("*vehicle*")
        lights_list = actor_list.filter("*traffic_light*")

        # check possible obstacles
        vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list)
        if vehicle_state:
            if debug:
                print('!!! VEHICLE BLOCKING AHEAD [{}]'.format(vehicle.id))

            self._state = AgentState.BLOCKED_BY_VEHICLE
            hazard_detected = True

        # check for the state of the traffic lights
        light_state, traffic_light = self._is_light_red(lights_list)
        if light_state and self._follow_traffic_lights:
            if debug:
                print('=== RED LIGHT AHEAD [{}]'.format(traffic_light.id))

            self._state = AgentState.BLOCKED_RED_LIGHT
            hazard_detected = True

        if hazard_detected:
            control = self.emergency_stop()
        else:
            self._state = AgentState.NAVIGATING
            # standard local planner behavior
            control = self._local_planner.run_step()

        return control


if __name__ == '__main__':

    env = CarlaEnv(
        render_display=1,  # 0, 1
        record_display_images=1,  # 0, 1
        record_rl_images=1,  # 0, 1
        changing_weather_speed=1.0,  # [0, +inf)
        display_text=0,  # 0, 1
        is_other_cars=True,
        frame_skip=4,
        max_episode_steps=100000,
        rl_image_size=84
    )

    try:
        done = False
        while not done:
            action = env.compute_steer_action()
            next_obs, reward, done, info = env.step(action)

        obs = env.reset()
    finally:
        env.finish()
45
CODE_OF_CONDUCT.md
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# Open Source Code of Conduct
|
||||||
|
|
||||||
|
## Our Pledge
|
||||||
|
|
||||||
|
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
|
||||||
|
|
||||||
|
## Our Standards
|
||||||
|
|
||||||
|
Examples of behavior that contributes to creating a positive environment include:
|
||||||
|
|
||||||
|
Using welcoming and inclusive language
|
||||||
|
Being respectful of differing viewpoints and experiences
|
||||||
|
Gracefully accepting constructive criticism
|
||||||
|
Focusing on what is best for the community
|
||||||
|
Showing empathy towards other community members
|
||||||
|
Examples of unacceptable behavior by participants include:
|
||||||
|
|
||||||
|
The use of sexualized language or imagery and unwelcome sexual attention or advances
|
||||||
|
Trolling, insulting/derogatory comments, and personal or political attacks
|
||||||
|
Public or private harassment
|
||||||
|
Publishing others’ private information, such as a physical or electronic address, without explicit permission
|
||||||
|
Other conduct which could reasonably be considered inappropriate in a professional setting
|
||||||
|
|
||||||
|
## Our Responsibilities
|
||||||
|
|
||||||
|
Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
|
||||||
|
|
||||||
|
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
|
||||||
|
This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
|
||||||
|
|
||||||
|
## Enforcement
|
||||||
|
|
||||||
|
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
|
||||||
|
|
||||||
|
Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership.
|
||||||
|
|
||||||
|
## Attribution
|
||||||
|
|
||||||
|
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||||
|
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||||
|
|
||||||
|
[homepage]: https://www.contributor-covenant.org
|
39
CONTRIBUTING.md
Normal file
@ -0,0 +1,39 @@
# Contributing to __________
We want to make contributing to this project as easy and transparent as
possible.

## Our Development Process
... (in particular how this is synced with internal changes to the project)

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## Coding Style
* 2 spaces for indentation rather than tabs
* 80 character line length
* ...

## License
By contributing to __________, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
399
LICENSE
Normal file
@ -0,0 +1,399 @@
|
|||||||
|
Attribution-NonCommercial 4.0 International
|
||||||
|
|
||||||
|
=======================================================================
|
||||||
|
|
||||||
|
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
||||||
|
does not provide legal services or legal advice. Distribution of
|
||||||
|
Creative Commons public licenses does not create a lawyer-client or
|
||||||
|
other relationship. Creative Commons makes its licenses and related
|
||||||
|
information available on an "as-is" basis. Creative Commons gives no
|
||||||
|
warranties regarding its licenses, any material licensed under their
|
||||||
|
terms and conditions, or any related information. Creative Commons
|
||||||
|
disclaims all liability for damages resulting from their use to the
|
||||||
|
fullest extent possible.
|
||||||
|
|
||||||
|
Using Creative Commons Public Licenses
|
||||||
|
|
||||||
|
Creative Commons public licenses provide a standard set of terms and
|
||||||
|
conditions that creators and other rights holders may use to share
|
||||||
|
original works of authorship and other material subject to copyright
|
||||||
|
and certain other rights specified in the public license below. The
|
||||||
|
following considerations are for informational purposes only, are not
|
||||||
|
exhaustive, and do not form part of our licenses.
|
||||||
|
|
||||||
|
Considerations for licensors: Our public licenses are
|
||||||
|
intended for use by those authorized to give the public
|
||||||
|
permission to use material in ways otherwise restricted by
|
||||||
|
copyright and certain other rights. Our licenses are
|
||||||
|
irrevocable. Licensors should read and understand the terms
|
||||||
|
and conditions of the license they choose before applying it.
|
||||||
|
Licensors should also secure all rights necessary before
|
||||||
|
applying our licenses so that the public can reuse the
|
||||||
|
material as expected. Licensors should clearly mark any
|
||||||
|
material not subject to the license. This includes other CC-
|
||||||
|
licensed material, or material used under an exception or
|
||||||
|
limitation to copyright. More considerations for licensors:
|
||||||
|
wiki.creativecommons.org/Considerations_for_licensors
|
||||||
|
|
||||||
|
Considerations for the public: By using one of our public
|
||||||
|
licenses, a licensor grants the public permission to use the
|
||||||
|
licensed material under specified terms and conditions. If
|
||||||
|
the licensor's permission is not necessary for any reason--for
|
||||||
|
example, because of any applicable exception or limitation to
|
||||||
|
copyright--then that use is not regulated by the license. Our
|
||||||
|
licenses grant only permissions under copyright and certain
|
||||||
|
other rights that a licensor has authority to grant. Use of
|
||||||
|
the licensed material may still be restricted for other
|
||||||
|
reasons, including because others have copyright or other
|
||||||
|
rights in the material. A licensor may make special requests,
|
||||||
|
such as asking that all changes be marked or described.
|
||||||
|
Although not required by our licenses, you are encouraged to
|
||||||
|
respect those requests where reasonable. More_considerations
|
||||||
|
for the public:
|
||||||
|
wiki.creativecommons.org/Considerations_for_licensees
|
||||||
|
|
||||||
|
=======================================================================
|
||||||
|
|
||||||
|
Creative Commons Attribution-NonCommercial 4.0 International Public
|
||||||
|
License
|
||||||
|
|
||||||
|
By exercising the Licensed Rights (defined below), You accept and agree
|
||||||
|
to be bound by the terms and conditions of this Creative Commons
|
||||||
|
Attribution-NonCommercial 4.0 International Public License ("Public
|
||||||
|
License"). To the extent this Public License may be interpreted as a
|
||||||
|
contract, You are granted the Licensed Rights in consideration of Your
|
||||||
|
acceptance of these terms and conditions, and the Licensor grants You
|
||||||
|
such rights in consideration of benefits the Licensor receives from
|
||||||
|
making the Licensed Material available under these terms and
|
||||||
|
conditions.
|
||||||
|
|
||||||
|
Section 1 -- Definitions.
|
||||||
|
|
||||||
|
a. Adapted Material means material subject to Copyright and Similar
|
||||||
|
Rights that is derived from or based upon the Licensed Material
|
||||||
|
and in which the Licensed Material is translated, altered,
|
||||||
|
arranged, transformed, or otherwise modified in a manner requiring
|
||||||
|
permission under the Copyright and Similar Rights held by the
|
||||||
|
Licensor. For purposes of this Public License, where the Licensed
|
||||||
|
Material is a musical work, performance, or sound recording,
|
||||||
|
Adapted Material is always produced where the Licensed Material is
|
||||||
|
synched in timed relation with a moving image.
|
||||||
|
|
||||||
|
b. Adapter's License means the license You apply to Your Copyright
|
||||||
|
and Similar Rights in Your contributions to Adapted Material in
|
||||||
|
accordance with the terms and conditions of this Public License.
|
||||||
|
|
||||||
|
c. Copyright and Similar Rights means copyright and/or similar rights
|
||||||
|
closely related to copyright including, without limitation,
|
||||||
|
performance, broadcast, sound recording, and Sui Generis Database
|
||||||
|
Rights, without regard to how the rights are labeled or
|
||||||
|
categorized. For purposes of this Public License, the rights
|
||||||
|
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
||||||
|
Rights.
|
||||||
|
d. Effective Technological Measures means those measures that, in the
|
||||||
|
absence of proper authority, may not be circumvented under laws
|
||||||
|
fulfilling obligations under Article 11 of the WIPO Copyright
|
||||||
|
Treaty adopted on December 20, 1996, and/or similar international
|
||||||
|
agreements.
|
||||||
|
|
||||||
|
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
||||||
|
any other exception or limitation to Copyright and Similar Rights
|
||||||
|
that applies to Your use of the Licensed Material.
|
||||||
|
|
||||||
|
f. Licensed Material means the artistic or literary work, database,
|
||||||
|
or other material to which the Licensor applied this Public
|
||||||
|
License.
|
||||||
|
|
||||||
|
g. Licensed Rights means the rights granted to You subject to the
|
||||||
|
terms and conditions of this Public License, which are limited to
|
||||||
|
all Copyright and Similar Rights that apply to Your use of the
|
||||||
|
Licensed Material and that the Licensor has authority to license.
|
||||||
|
|
||||||
|
h. Licensor means the individual(s) or entity(ies) granting rights
|
||||||
|
under this Public License.
|
||||||
|
|
||||||
|
i. NonCommercial means not primarily intended for or directed towards
|
||||||
|
commercial advantage or monetary compensation. For purposes of
|
||||||
|
this Public License, the exchange of the Licensed Material for
|
||||||
|
other material subject to Copyright and Similar Rights by digital
|
||||||
|
file-sharing or similar means is NonCommercial provided there is
|
||||||
|
no payment of monetary compensation in connection with the
|
||||||
|
exchange.
|
||||||
|
|
||||||
|
j. Share means to provide material to the public by any means or
|
||||||
|
process that requires permission under the Licensed Rights, such
|
||||||
|
as reproduction, public display, public performance, distribution,
|
||||||
|
dissemination, communication, or importation, and to make material
|
||||||
|
available to the public including in ways that members of the
|
||||||
|
public may access the material from a place and at a time
|
||||||
|
individually chosen by them.
|
||||||
|
|
||||||
|
k. Sui Generis Database Rights means rights other than copyright
|
||||||
|
resulting from Directive 96/9/EC of the European Parliament and of
|
||||||
|
the Council of 11 March 1996 on the legal protection of databases,
|
||||||
|
as amended and/or succeeded, as well as other essentially
|
||||||
|
equivalent rights anywhere in the world.
|
||||||
|
|
||||||
|
l. You means the individual or entity exercising the Licensed Rights
|
||||||
|
under this Public License. Your has a corresponding meaning.
|
||||||
|
|
||||||
|
Section 2 -- Scope.
|
||||||
|
|
||||||
|
a. License grant.
|
||||||
|
|
||||||
|
1. Subject to the terms and conditions of this Public License,
|
||||||
|
the Licensor hereby grants You a worldwide, royalty-free,
|
||||||
|
non-sublicensable, non-exclusive, irrevocable license to
|
||||||
|
exercise the Licensed Rights in the Licensed Material to:
|
||||||
|
|
||||||
|
a. reproduce and Share the Licensed Material, in whole or
|
||||||
|
in part, for NonCommercial purposes only; and
|
||||||
|
|
||||||
|
b. produce, reproduce, and Share Adapted Material for
|
||||||
|
NonCommercial purposes only.
|
||||||
|
|
||||||
|
2. Exceptions and Limitations. For the avoidance of doubt, where
|
||||||
|
Exceptions and Limitations apply to Your use, this Public
|
||||||
|
License does not apply, and You do not need to comply with
|
||||||
|
its terms and conditions.
|
||||||
|
|
||||||
|
3. Term. The term of this Public License is specified in Section
|
||||||
|
6(a).
|
||||||
|
|
||||||
|
4. Media and formats; technical modifications allowed. The
|
||||||
|
Licensor authorizes You to exercise the Licensed Rights in
|
||||||
|
all media and formats whether now known or hereafter created,
|
||||||
|
and to make technical modifications necessary to do so. The
|
||||||
|
Licensor waives and/or agrees not to assert any right or
|
||||||
|
authority to forbid You from making technical modifications
|
||||||
|
necessary to exercise the Licensed Rights, including
|
||||||
|
technical modifications necessary to circumvent Effective
|
||||||
|
Technological Measures. For purposes of this Public License,
|
||||||
|
simply making modifications authorized by this Section 2(a)
|
||||||
|
(4) never produces Adapted Material.
|
||||||
|
|
||||||
|
5. Downstream recipients.
|
||||||
|
|
||||||
|
a. Offer from the Licensor -- Licensed Material. Every
|
||||||
|
recipient of the Licensed Material automatically
|
||||||
|
receives an offer from the Licensor to exercise the
|
||||||
|
Licensed Rights under the terms and conditions of this
|
||||||
|
Public License.
|
||||||
|
|
||||||
|
b. No downstream restrictions. You may not offer or impose
|
||||||
|
any additional or different terms or conditions on, or
|
||||||
|
apply any Effective Technological Measures to, the
|
||||||
|
Licensed Material if doing so restricts exercise of the
|
||||||
|
Licensed Rights by any recipient of the Licensed
|
||||||
|
Material.
|
||||||
|
|
||||||
|
6. No endorsement. Nothing in this Public License constitutes or
|
||||||
|
may be construed as permission to assert or imply that You
|
||||||
|
are, or that Your use of the Licensed Material is, connected
|
||||||
|
with, or sponsored, endorsed, or granted official status by,
|
||||||
|
the Licensor or others designated to receive attribution as
|
||||||
|
provided in Section 3(a)(1)(A)(i).
|
||||||
|
|
||||||
|
b. Other rights.
|
||||||
|
|
||||||
|
1. Moral rights, such as the right of integrity, are not
|
||||||
|
licensed under this Public License, nor are publicity,
|
||||||
|
privacy, and/or other similar personality rights; however, to
|
||||||
|
the extent possible, the Licensor waives and/or agrees not to
|
||||||
|
assert any such rights held by the Licensor to the limited
|
||||||
|
extent necessary to allow You to exercise the Licensed
|
||||||
|
Rights, but not otherwise.
|
||||||
|
|
||||||
|
2. Patent and trademark rights are not licensed under this
|
||||||
|
Public License.
|
||||||
|
|
||||||
|
3. To the extent possible, the Licensor waives any right to
|
||||||
|
collect royalties from You for the exercise of the Licensed
|
||||||
|
Rights, whether directly or through a collecting society
|
||||||
|
under any voluntary or waivable statutory or compulsory
|
||||||
|
licensing scheme. In all other cases the Licensor expressly
|
||||||
|
reserves any right to collect such royalties, including when
|
||||||
|
the Licensed Material is used other than for NonCommercial
|
||||||
|
purposes.
|
||||||
|
|
||||||
|
Section 3 -- License Conditions.
|
||||||
|
|
||||||
|
Your exercise of the Licensed Rights is expressly made subject to the
|
||||||
|
following conditions.
|
||||||
|
|
||||||
|
a. Attribution.
|
||||||
|
|
||||||
|
1. If You Share the Licensed Material (including in modified
|
||||||
|
form), You must:
|
||||||
|
|
||||||
|
a. retain the following if it is supplied by the Licensor
|
||||||
|
with the Licensed Material:
|
||||||
|
|
||||||
|
i. identification of the creator(s) of the Licensed
|
||||||
|
Material and any others designated to receive
|
||||||
|
attribution, in any reasonable manner requested by
|
||||||
|
the Licensor (including by pseudonym if
|
||||||
|
designated);
|
||||||
|
|
||||||
|
ii. a copyright notice;
|
||||||
|
|
||||||
|
iii. a notice that refers to this Public License;
|
||||||
|
|
||||||
|
iv. a notice that refers to the disclaimer of
|
||||||
|
warranties;
|
||||||
|
|
||||||
|
v. a URI or hyperlink to the Licensed Material to the
|
||||||
|
extent reasonably practicable;
|
||||||
|
|
||||||
|
b. indicate if You modified the Licensed Material and
|
||||||
|
retain an indication of any previous modifications; and
|
||||||
|
|
||||||
|
c. indicate the Licensed Material is licensed under this
|
||||||
|
Public License, and include the text of, or the URI or
|
||||||
|
hyperlink to, this Public License.
|
||||||
|
|
||||||
|
2. You may satisfy the conditions in Section 3(a)(1) in any
|
||||||
|
reasonable manner based on the medium, means, and context in
|
||||||
|
which You Share the Licensed Material. For example, it may be
|
||||||
|
reasonable to satisfy the conditions by providing a URI or
|
||||||
|
hyperlink to a resource that includes the required
|
||||||
|
information.
|
||||||
|
|
||||||
|
3. If requested by the Licensor, You must remove any of the
|
||||||
|
information required by Section 3(a)(1)(A) to the extent
|
||||||
|
reasonably practicable.
|
||||||
|
|
||||||
|
4. If You Share Adapted Material You produce, the Adapter's
|
||||||
|
License You apply must not prevent recipients of the Adapted
|
||||||
|
Material from complying with this Public License.
|
||||||
|
|
||||||
|
Section 4 -- Sui Generis Database Rights.
|
||||||
|
|
||||||
|
Where the Licensed Rights include Sui Generis Database Rights that
|
||||||
|
apply to Your use of the Licensed Material:
|
||||||
|
|
||||||
|
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
||||||
|
to extract, reuse, reproduce, and Share all or a substantial
|
||||||
|
portion of the contents of the database for NonCommercial purposes
|
||||||
|
only;
|
||||||
|
|
||||||
|
b. if You include all or a substantial portion of the database
|
||||||
|
contents in a database in which You have Sui Generis Database
|
||||||
|
Rights, then the database in which You have Sui Generis Database
|
||||||
|
Rights (but not its individual contents) is Adapted Material; and
|
||||||
|
|
||||||
|
c. You must comply with the conditions in Section 3(a) if You Share
|
||||||
|
all or a substantial portion of the contents of the database.
|
||||||
|
|
||||||
|
For the avoidance of doubt, this Section 4 supplements and does not
|
||||||
|
replace Your obligations under this Public License where the Licensed
|
||||||
|
Rights include other Copyright and Similar Rights.
|
||||||
|
|
||||||
|
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
||||||
|
|
||||||
|
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
||||||
|
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
||||||
|
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
||||||
|
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
||||||
|
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
||||||
|
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
||||||
|
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
||||||
|
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
||||||
|
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
||||||
|
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
||||||
|
|
||||||
|
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
||||||
|
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
||||||
|
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
||||||
|
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
||||||
|
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
||||||
|
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
||||||
|
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
||||||
|
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
||||||
|
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
||||||
|
|
||||||
|
c. The disclaimer of warranties and limitation of liability provided
|
||||||
|
above shall be interpreted in a manner that, to the extent
|
||||||
|
possible, most closely approximates an absolute disclaimer and
|
||||||
|
waiver of all liability.
|
||||||
|
|
||||||
|
Section 6 -- Term and Termination.
|
||||||
|
|
||||||
|
a. This Public License applies for the term of the Copyright and
|
||||||
|
Similar Rights licensed here. However, if You fail to comply with
|
||||||
|
this Public License, then Your rights under this Public License
|
||||||
|
terminate automatically.
|
||||||
|
|
||||||
|
b. Where Your right to use the Licensed Material has terminated under
|
||||||
|
Section 6(a), it reinstates:
|
||||||
|
|
||||||
|
1. automatically as of the date the violation is cured, provided
|
||||||
|
it is cured within 30 days of Your discovery of the
|
||||||
|
violation; or
|
||||||
|
|
||||||
|
2. upon express reinstatement by the Licensor.
|
||||||
|
|
||||||
|
For the avoidance of doubt, this Section 6(b) does not affect any
|
||||||
|
right the Licensor may have to seek remedies for Your violations
|
||||||
|
of this Public License.
|
||||||
|
|
||||||
|
c. For the avoidance of doubt, the Licensor may also offer the
|
||||||
|
Licensed Material under separate terms or conditions or stop
|
||||||
|
distributing the Licensed Material at any time; however, doing so
|
||||||
|
will not terminate this Public License.
|
||||||
|
|
||||||
|
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
||||||
|
License.
|
||||||
|
|
||||||
|
Section 7 -- Other Terms and Conditions.
|
||||||
|
|
||||||
|
a. The Licensor shall not be bound by any additional or different
|
||||||
|
terms or conditions communicated by You unless expressly agreed.
|
||||||
|
|
||||||
|
b. Any arrangements, understandings, or agreements regarding the
|
||||||
|
Licensed Material not stated herein are separate from and
|
||||||
|
independent of the terms and conditions of this Public License.
|
||||||
|
|
||||||
|
Section 8 -- Interpretation.
|
||||||
|
|
||||||
|
a. For the avoidance of doubt, this Public License does not, and
|
||||||
|
shall not be interpreted to, reduce, limit, restrict, or impose
|
||||||
|
conditions on any use of the Licensed Material that could lawfully
|
||||||
|
be made without permission under this Public License.
|
||||||
|
|
||||||
|
b. To the extent possible, if any provision of this Public License is
|
||||||
|
deemed unenforceable, it shall be automatically reformed to the
|
||||||
|
minimum extent necessary to make it enforceable. If the provision
|
||||||
|
cannot be reformed, it shall be severed from this Public License
|
||||||
|
without affecting the enforceability of the remaining terms and
|
||||||
|
conditions.
|
||||||
|
|
||||||
|
c. No term or condition of this Public License will be waived and no
|
||||||
|
failure to comply consented to unless expressly agreed to by the
|
||||||
|
Licensor.
|
||||||
|
|
||||||
|
d. Nothing in this Public License constitutes or may be interpreted
|
||||||
|
as a limitation upon, or waiver of, any privileges and immunities
|
||||||
|
that apply to the Licensor or You, including from the legal
|
||||||
|
processes of any jurisdiction or authority.
|
||||||
|
|
||||||
|
=======================================================================
|
||||||
|
|
||||||
|
Creative Commons is not a party to its public
|
||||||
|
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
||||||
|
its public licenses to material it publishes and in those instances
|
||||||
|
will be considered the “Licensor.” The text of the Creative Commons
|
||||||
|
public licenses is dedicated to the public domain under the CC0 Public
|
||||||
|
Domain Dedication. Except for the limited purpose of indicating that
|
||||||
|
material is shared under a Creative Commons public license or as
|
||||||
|
otherwise permitted by the Creative Commons policies published at
|
||||||
|
creativecommons.org/policies, Creative Commons does not authorize the
|
||||||
|
use of the trademark "Creative Commons" or any other trademark or logo
|
||||||
|
of Creative Commons without its prior written consent including,
|
||||||
|
without limitation, in connection with any unauthorized modifications
|
||||||
|
to any of its public licenses or any other arrangements,
|
||||||
|
understandings, or agreements concerning use of licensed material. For
|
||||||
|
the avoidance of doubt, this paragraph does not form part of the
|
||||||
|
public licenses.
|
||||||
|
|
||||||
|
Creative Commons may be contacted at creativecommons.org.
|
96
README.md
Normal file
@ -0,0 +1,96 @@
# Learning Invariant Representations for Reinforcement Learning without Reconstruction

## Requirements
We assume you have access to a GPU that can run CUDA 9.2. Then, the simplest way to install all required dependencies is to create an anaconda environment by running:
```
conda env create -f conda_env.yml
```
After the installation ends you can activate your environment with:
```
source activate dbc
```

## Instructions
To train a DBC agent on the `cheetah run` task from image-based observations, run:
```
python train.py \
    --domain_name cheetah \
    --task_name run \
    --encoder_type pixel \
    --decoder_type identity \
    --action_repeat 4 \
    --save_video \
    --save_tb \
    --work_dir ./log \
    --seed 1
```
This will create a `log` folder, where all the outputs are stored, including train/eval logs, tensorboard blobs, and evaluation episode videos. One can attach a tensorboard to monitor training by running:
```
tensorboard --logdir log
```
and opening up tensorboard in your browser.

The console output is also available in the following form:
```
| train | E: 1 | S: 1000 | D: 0.8 s | R: 0.0000 | BR: 0.0000 | ALOSS: 0.0000 | CLOSS: 0.0000 | RLOSS: 0.0000
```
A training entry decodes as:
```
train - training episode
E - total number of episodes
S - total number of environment steps
D - duration in seconds to train 1 episode
R - episode reward
BR - average reward of sampled batch
ALOSS - average loss of actor
CLOSS - average loss of critic
RLOSS - average reconstruction loss (only if trained from pixels with a decoder)
```
while an evaluation entry looks like:
```
| eval | S: 0 | ER: 21.1676
```
which reports the expected reward `ER` of the current policy after `S` environment steps. Note that `ER` is the average evaluation performance over `num_eval_episodes` episodes (usually 10).

## CARLA
Download CARLA from https://github.com/carla-simulator/carla/releases, e.g.:
1. https://carla-releases.s3.eu-west-3.amazonaws.com/Linux/CARLA_0.9.8.tar.gz
2. https://carla-releases.s3.eu-west-3.amazonaws.com/Linux/AdditionalMaps_0.9.8.tar.gz

Add to your python path:
```
export PYTHONPATH=$PYTHONPATH:/home/rmcallister/code/bisim_metric/CARLA_0.9.8/PythonAPI
export PYTHONPATH=$PYTHONPATH:/home/rmcallister/code/bisim_metric/CARLA_0.9.8/PythonAPI/carla
export PYTHONPATH=$PYTHONPATH:/home/rmcallister/code/bisim_metric/CARLA_0.9.8/PythonAPI/carla/dist/carla-0.9.8-py3.5-linux-x86_64.egg
```
and merge the directories.

Then pull the altered carla branch files:
```
git fetch
git checkout carla
```

Install the extra dependencies:
```
pip install pygame
pip install networkx
```

Terminal 1 (start the simulator):
```
cd CARLA_0.9.6
bash CarlaUE4.sh -fps 20
```

Terminal 2 (run an agent against it):
```
cd CARLA_0.9.6
# can run expert autopilot (uses privileged game-state information):
python PythonAPI/carla/agents/navigation/carla_env.py
# or can run bisim:
./run_local_carla096.sh --agent bisim --transition_model_type probabilistic --domain_name carla
```
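
You can also drive the CARLA environment directly from Python. Below is a minimal sketch mirroring the `__main__` block of `PythonAPI/carla/agents/navigation/carla_env.py`; it assumes the `PYTHONPATH` exports above, and the constructor values shown are illustrative rather than required:
```
from agents.navigation.carla_env import CarlaEnv

# Expert autopilot rollout: the built-in PID agent picks the actions.
env = CarlaEnv(
    render_display=0,            # 0, 1
    record_display_images=0,     # 0, 1
    record_rl_images=0,          # 0, 1
    changing_weather_speed=1.0,  # [0, +inf)
    display_text=0,              # 0, 1
    is_other_cars=True,
    frame_skip=4,
    max_episode_steps=1000,
    rl_image_size=84,
)
try:
    done = False
    while not done:
        action = env.compute_steer_action()              # [steer, throttle_brake]
        next_obs, reward, done, info = env.step(action)  # obs is 3 x 84 x (84 * num_cameras)
finally:
    env.finish()
```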

## License
This project is CC-BY-NC 4.0 licensed, as found in the LICENSE file.
0
agent/__init__.py
Normal file
363
agent/baseline_agent.py
Normal file
@ -0,0 +1,363 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import utils
|
||||||
|
from sac_ae import Actor, Critic, weight_init, LOG_FREQ
|
||||||
|
from transition_model import make_transition_model
|
||||||
|
from decoder import make_decoder
|
||||||
|
|
||||||
|
|
||||||
|
class BaselineAgent(object):
|
||||||
|
"""Baseline algorithm with transition model and various decoder types."""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
obs_shape,
|
||||||
|
action_shape,
|
||||||
|
device,
|
||||||
|
hidden_dim=256,
|
||||||
|
discount=0.99,
|
||||||
|
init_temperature=0.01,
|
||||||
|
alpha_lr=1e-3,
|
||||||
|
alpha_beta=0.9,
|
||||||
|
actor_lr=1e-3,
|
||||||
|
actor_beta=0.9,
|
||||||
|
actor_log_std_min=-10,
|
||||||
|
actor_log_std_max=2,
|
||||||
|
actor_update_freq=2,
|
||||||
|
critic_lr=1e-3,
|
||||||
|
critic_beta=0.9,
|
||||||
|
critic_tau=0.005,
|
||||||
|
critic_target_update_freq=2,
|
||||||
|
encoder_type='pixel',
|
||||||
|
encoder_stride=2,
|
||||||
|
encoder_feature_dim=50,
|
||||||
|
encoder_lr=1e-3,
|
||||||
|
encoder_tau=0.005,
|
||||||
|
decoder_type='pixel',
|
||||||
|
decoder_lr=1e-3,
|
||||||
|
decoder_update_freq=1,
|
||||||
|
decoder_weight_lambda=0.0,
|
||||||
|
transition_model_type='deterministic',
|
||||||
|
num_layers=4,
|
||||||
|
num_filters=32
|
||||||
|
):
|
||||||
|
self.device = device
|
||||||
|
self.discount = discount
|
||||||
|
self.critic_tau = critic_tau
|
||||||
|
self.encoder_tau = encoder_tau
|
||||||
|
self.actor_update_freq = actor_update_freq
|
||||||
|
self.critic_target_update_freq = critic_target_update_freq
|
||||||
|
self.decoder_update_freq = decoder_update_freq
|
||||||
|
self.decoder_type = decoder_type
|
||||||
|
self.hinge = 1.
|
||||||
|
self.sigma = 0.5
|
||||||
|
|
||||||
|
self.actor = Actor(
|
||||||
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
|
encoder_feature_dim, actor_log_std_min, actor_log_std_max,
|
||||||
|
num_layers, num_filters, encoder_stride
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
self.critic = Critic(
|
||||||
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
|
encoder_feature_dim, num_layers, num_filters, encoder_stride
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
self.critic_target = Critic(
|
||||||
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
|
encoder_feature_dim, num_layers, num_filters, encoder_stride
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||||
|
|
||||||
|
self.transition_model = make_transition_model(
|
||||||
|
transition_model_type, encoder_feature_dim, action_shape
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
# optimizer for decoder
|
||||||
|
self.decoder_optimizer = torch.optim.Adam(
|
||||||
|
self.transition_model.parameters(),
|
||||||
|
lr=decoder_lr,
|
||||||
|
weight_decay=decoder_weight_lambda
|
||||||
|
)
|
||||||
|
|
||||||
|
# tie encoders between actor and critic
|
||||||
|
self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
|
||||||
|
|
||||||
|
self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
|
||||||
|
self.log_alpha.requires_grad = True
|
||||||
|
# set target entropy to -|A|
|
||||||
|
self.target_entropy = -np.prod(action_shape)
|
||||||
|
|
||||||
|
self.decoder = None
|
||||||
|
encoder_params = list(self.critic.encoder.parameters()) + list(self.transition_model.parameters())
|
||||||
|
if decoder_type == 'pixel':
|
||||||
|
# create decoder
|
||||||
|
self.decoder = make_decoder(
|
||||||
|
decoder_type, obs_shape, encoder_feature_dim, num_layers,
|
||||||
|
num_filters
|
||||||
|
).to(device)
|
||||||
|
self.decoder.apply(weight_init)
|
||||||
|
elif decoder_type == 'inverse':
|
||||||
|
self.inverse_model = nn.Sequential(
|
||||||
|
nn.Linear(encoder_feature_dim * 2, 512),
|
||||||
|
nn.LayerNorm(512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(512, action_shape[0])).to(device)
|
||||||
|
encoder_params += list(self.inverse_model.parameters())
|
||||||
|
if decoder_type != 'identity':
|
||||||
|
# optimizer for critic encoder for reconstruction loss
|
||||||
|
self.encoder_optimizer = torch.optim.Adam(encoder_params, lr=encoder_lr)
|
||||||
|
if decoder_type == 'pixel': # optimizer for decoder
|
||||||
|
self.decoder_optimizer = torch.optim.Adam(
|
||||||
|
self.decoder.parameters(),
|
||||||
|
lr=decoder_lr,
|
||||||
|
weight_decay=decoder_weight_lambda
|
||||||
|
)
|
||||||
|
# optimizer for critic encoder for reconstruction loss
|
||||||
|
self.encoder_optimizer = torch.optim.Adam(
|
||||||
|
self.critic.encoder.parameters(), lr=encoder_lr
|
||||||
|
)
|
||||||
|
|
||||||
|
# optimizers
|
||||||
|
self.actor_optimizer = torch.optim.Adam(
|
||||||
|
self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.critic_optimizer = torch.optim.Adam(
|
||||||
|
self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log_alpha_optimizer = torch.optim.Adam(
|
||||||
|
[self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.train()
|
||||||
|
self.critic_target.train()
|
||||||
|
|
||||||
|
def energy(self, state, action, next_state, no_trans=False):
|
||||||
|
"""Energy function based on normalized squared L2 norm."""
|
||||||
|
|
||||||
|
norm = 0.5 / (self.sigma**2)
|
||||||
|
|
||||||
|
if no_trans:
|
||||||
|
diff = state - next_state
|
||||||
|
normalization = 0.
|
||||||
|
else:
|
||||||
|
pred_trans_mu, pred_trans_sigma = self.transition_model(torch.cat([state, action], dim=1))
|
||||||
|
if pred_trans_sigma is None:
|
||||||
|
pred_trans_sigma = torch.Tensor([1.]).to(self.device)
|
||||||
|
if isinstance(pred_trans_mu, list): # i.e. comes from an ensemble
|
||||||
|
raise NotImplementedError # TODO: handle the additional ensemble dimension (0) in this case
|
||||||
|
diff = (state + pred_trans_mu - next_state) / pred_trans_sigma
|
||||||
|
normalization = torch.log(pred_trans_sigma)
|
||||||
|
return norm * (diff.pow(2) + normalization).sum(1)
|
||||||
|
|
||||||
|
def contrastive_loss(self, state, action, next_state):
|
||||||
|
|
||||||
|
# Sample negative state across episodes at random
|
||||||
|
batch_size = state.size(0)
|
||||||
|
perm = np.random.permutation(batch_size)
|
||||||
|
neg_state = state[perm]
|
||||||
|
|
||||||
|
self.pos_loss = self.energy(state, action, next_state)
|
||||||
|
zeros = torch.zeros_like(self.pos_loss)
|
||||||
|
|
||||||
|
self.pos_loss = self.pos_loss.mean()
|
||||||
|
self.neg_loss = torch.max(
|
||||||
|
zeros, self.hinge - self.energy(
|
||||||
|
state, action, neg_state, no_trans=True)).mean()
|
||||||
|
|
||||||
|
loss = self.pos_loss + self.neg_loss
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def train(self, training=True):
|
||||||
|
self.training = training
|
||||||
|
self.actor.train(training)
|
||||||
|
self.critic.train(training)
|
||||||
|
if self.decoder is not None:
|
||||||
|
self.decoder.train(training)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def alpha(self):
|
||||||
|
return self.log_alpha.exp()
|
||||||
|
|
||||||
|
def select_action(self, obs):
|
||||||
|
with torch.no_grad():
|
||||||
|
obs = torch.FloatTensor(obs).to(self.device)
|
||||||
|
obs = obs.unsqueeze(0)
|
||||||
|
mu, _, _, _ = self.actor(
|
||||||
|
obs, compute_pi=False, compute_log_pi=False
|
||||||
|
)
|
||||||
|
return mu.cpu().data.numpy().flatten()
|
||||||
|
|
||||||
|
def sample_action(self, obs):
|
||||||
|
with torch.no_grad():
|
||||||
|
obs = torch.FloatTensor(obs).to(self.device)
|
||||||
|
obs = obs.unsqueeze(0)
|
||||||
|
mu, pi, _, _ = self.actor(obs, compute_log_pi=False)
|
||||||
|
return pi.cpu().data.numpy().flatten()
|
||||||
|
|
||||||
|
def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
|
||||||
|
with torch.no_grad():
|
||||||
|
_, policy_action, log_pi, _ = self.actor(next_obs)
|
||||||
|
target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
|
||||||
|
target_V = torch.min(target_Q1,
|
||||||
|
target_Q2) - self.alpha.detach() * log_pi
|
||||||
|
target_Q = reward + (not_done * self.discount * target_V)
|
||||||
|
|
||||||
|
# get current Q estimates
|
||||||
|
current_Q1, current_Q2 = self.critic(obs, action, detach_encoder=False)
|
||||||
|
critic_loss = F.mse_loss(current_Q1,
|
||||||
|
target_Q) + F.mse_loss(current_Q2, target_Q)
|
||||||
|
L.log('train_critic/loss', critic_loss, step)
|
||||||
|
|
||||||
|
# Optimize the critic
|
||||||
|
self.critic_optimizer.zero_grad()
|
||||||
|
critic_loss.backward()
|
||||||
|
self.critic_optimizer.step()
|
||||||
|
|
||||||
|
self.critic.log(L, step)
|
||||||
|
|
||||||
|
def update_actor_and_alpha(self, obs, L, step):
|
||||||
|
# detach encoder, so we don't update it with the actor loss
|
||||||
|
_, pi, log_pi, log_std = self.actor(obs, detach_encoder=True)
|
||||||
|
actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True)
|
||||||
|
|
||||||
|
actor_Q = torch.min(actor_Q1, actor_Q2)
|
||||||
|
actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()
|
||||||
|
|
||||||
|
L.log('train_actor/loss', actor_loss, step)
|
||||||
|
L.log('train_actor/target_entropy', self.target_entropy, step)
|
||||||
|
entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)
|
||||||
|
) + log_std.sum(dim=-1)
|
||||||
|
L.log('train_actor/entropy', entropy.mean(), step)
|
||||||
|
|
||||||
|
# optimize the actor
|
||||||
|
self.actor_optimizer.zero_grad()
|
||||||
|
actor_loss.backward()
|
||||||
|
self.actor_optimizer.step()
|
||||||
|
|
||||||
|
self.actor.log(L, step)
|
||||||
|
|
||||||
|
self.log_alpha_optimizer.zero_grad()
|
||||||
|
alpha_loss = (self.alpha *
|
||||||
|
(-log_pi - self.target_entropy).detach()).mean()
|
||||||
|
L.log('train_alpha/loss', alpha_loss, step)
|
||||||
|
L.log('train_alpha/value', self.alpha, step)
|
||||||
|
alpha_loss.backward()
|
||||||
|
self.log_alpha_optimizer.step()
|
||||||
|
|
||||||
|
def update_decoder(self, obs, action, target_obs, L, step): # uses transition model
|
||||||
|
# image might be stacked, just grab the first 3 (rgb)!
|
||||||
|
assert target_obs.dim() == 4
|
||||||
|
target_obs = target_obs[:, :3, :, :]
|
||||||
|
|
||||||
|
h = self.critic.encoder(obs)
|
||||||
|
next_h = self.transition_model.sample_prediction(torch.cat([h, action], dim=1))
|
||||||
|
if target_obs.dim() == 4:
|
||||||
|
# preprocess images to be in [-0.5, 0.5] range
|
||||||
|
target_obs = utils.preprocess_obs(target_obs)
|
||||||
|
rec_obs = self.decoder(next_h)
|
||||||
|
loss = F.mse_loss(target_obs, rec_obs)
|
||||||
|
|
||||||
|
self.encoder_optimizer.zero_grad()
|
||||||
|
self.decoder_optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
self.encoder_optimizer.step()
|
||||||
|
self.decoder_optimizer.step()
|
||||||
|
L.log('train_ae/ae_loss', loss, step)
|
||||||
|
|
||||||
|
self.decoder.log(L, step, log_freq=LOG_FREQ)
|
||||||
|
|
||||||
|
def update_contrastive(self, obs, action, next_obs, L, step):
|
||||||
|
latent = self.critic.encoder(obs)
|
||||||
|
next_latent = self.critic.encoder(next_obs)
|
||||||
|
loss = self.contrastive_loss(latent, action, next_latent)
|
||||||
|
self.encoder_optimizer.zero_grad()
|
||||||
|
self.decoder_optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
self.encoder_optimizer.step()
|
||||||
|
self.decoder_optimizer.step()
|
||||||
|
L.log('train_ae/contrastive_loss', loss, step)
|
||||||
|
|
||||||
|
def update_inverse(self, obs, action, next_obs, L, step):
|
||||||
|
non_final_mask = torch.tensor(tuple(map(lambda s: not (s == 0).all(), next_obs)), device=self.device).long() # hack
|
||||||
|
latent = self.critic.encoder(obs[non_final_mask])
|
||||||
|
next_latent = self.critic.encoder(next_obs[non_final_mask].to(self.device).float())
|
||||||
|
# pred_next_latent = self.transition_model(torch.cat([latent, action], dim=1))
|
||||||
|
# fpred_action = self.inverse_model(latent, pred_next_latent)
|
||||||
|
pred_action = self.inverse_model(torch.cat([latent, next_latent], dim=1))
|
||||||
|
loss = F.mse_loss(pred_action, action[non_final_mask]) # + F.mse_loss(fpred_action, action)
|
||||||
|
self.encoder_optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
self.encoder_optimizer.step()
|
||||||
|
L.log('train_ae/inverse_loss', loss, step)
|
||||||
|
|
||||||
|
def update(self, replay_buffer, L, step):
|
||||||
|
if self.decoder_type == 'inverse':
|
||||||
|
obs, action, reward, next_obs, not_done, k_obs = replay_buffer.sample(k=True)
|
||||||
|
else:
|
||||||
|
obs, action, _, reward, next_obs, not_done = replay_buffer.sample()
|
||||||
|
|
||||||
|
L.log('train/batch_reward', reward.mean(), step)
|
||||||
|
|
||||||
|
self.update_critic(obs, action, reward, next_obs, not_done, L, step)
|
||||||
|
|
||||||
|
if step % self.actor_update_freq == 0:
|
||||||
|
self.update_actor_and_alpha(obs, L, step)
|
||||||
|
|
||||||
|
if step % self.critic_target_update_freq == 0:
|
||||||
|
utils.soft_update_params(
|
||||||
|
self.critic.Q1, self.critic_target.Q1, self.critic_tau
|
||||||
|
)
|
||||||
|
utils.soft_update_params(
|
||||||
|
self.critic.Q2, self.critic_target.Q2, self.critic_tau
|
||||||
|
)
|
||||||
|
utils.soft_update_params(
|
||||||
|
self.critic.encoder, self.critic_target.encoder,
|
||||||
|
self.encoder_tau
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.decoder is not None and step % self.decoder_update_freq == 0: # decoder_type is pixel
|
||||||
|
self.update_decoder(obs, action, next_obs, L, step)
|
||||||
|
|
||||||
|
if self.decoder_type == 'contrastive':
|
||||||
|
self.update_contrastive(obs, action, next_obs, L, step)
|
||||||
|
elif self.decoder_type == 'inverse':
|
||||||
|
self.update_inverse(obs, action, k_obs, L, step)
|
||||||
|
|
||||||
|
def save(self, model_dir, step):
|
||||||
|
torch.save(
|
||||||
|
self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
|
||||||
|
)
|
||||||
|
torch.save(
|
||||||
|
self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
|
||||||
|
)
|
||||||
|
if self.decoder is not None:
|
||||||
|
torch.save(
|
||||||
|
self.decoder.state_dict(),
|
||||||
|
'%s/decoder_%s.pt' % (model_dir, step)
|
||||||
|
)
|
||||||
|
|
||||||
|
def load(self, model_dir, step):
|
||||||
|
self.actor.load_state_dict(
|
||||||
|
torch.load('%s/actor_%s.pt' % (model_dir, step))
|
||||||
|
)
|
||||||
|
self.critic.load_state_dict(
|
||||||
|
torch.load('%s/critic_%s.pt' % (model_dir, step))
|
||||||
|
)
|
||||||
|
if self.decoder is not None:
|
||||||
|
self.decoder.load_state_dict(
|
||||||
|
torch.load('%s/decoder_%s.pt' % (model_dir, step))
|
||||||
|
)
|
314
agent/bisim_agent.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import utils
|
||||||
|
from sac_ae import Actor, Critic, LOG_FREQ
|
||||||
|
from transition_model import make_transition_model
|
||||||
|
|
||||||
|
|
||||||
|
class BisimAgent(object):
|
||||||
|
"""Bisimulation metric algorithm."""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
obs_shape,
|
||||||
|
action_shape,
|
||||||
|
device,
|
||||||
|
transition_model_type,
|
||||||
|
hidden_dim=256,
|
||||||
|
discount=0.99,
|
||||||
|
init_temperature=0.01,
|
||||||
|
alpha_lr=1e-3,
|
||||||
|
alpha_beta=0.9,
|
||||||
|
actor_lr=1e-3,
|
||||||
|
actor_beta=0.9,
|
||||||
|
actor_log_std_min=-10,
|
||||||
|
actor_log_std_max=2,
|
||||||
|
actor_update_freq=2,
|
||||||
|
encoder_stride=2,
|
||||||
|
critic_lr=1e-3,
|
||||||
|
critic_beta=0.9,
|
||||||
|
critic_tau=0.005,
|
||||||
|
critic_target_update_freq=2,
|
||||||
|
encoder_type='pixel',
|
||||||
|
encoder_feature_dim=50,
|
||||||
|
encoder_lr=1e-3,
|
||||||
|
encoder_tau=0.005,
|
||||||
|
decoder_type='pixel',
|
||||||
|
decoder_lr=1e-3,
|
||||||
|
decoder_update_freq=1,
|
||||||
|
decoder_latent_lambda=0.0,
|
||||||
|
decoder_weight_lambda=0.0,
|
||||||
|
num_layers=4,
|
||||||
|
num_filters=32,
|
||||||
|
bisim_coef=0.5
|
||||||
|
):
|
||||||
|
self.device = device
|
||||||
|
self.discount = discount
|
||||||
|
self.critic_tau = critic_tau
|
||||||
|
self.encoder_tau = encoder_tau
|
||||||
|
self.actor_update_freq = actor_update_freq
|
||||||
|
self.critic_target_update_freq = critic_target_update_freq
|
||||||
|
self.decoder_update_freq = decoder_update_freq
|
||||||
|
self.decoder_latent_lambda = decoder_latent_lambda
|
||||||
|
self.transition_model_type = transition_model_type
|
||||||
|
self.bisim_coef = bisim_coef
|
||||||
|
|
||||||
|
self.actor = Actor(
|
||||||
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
|
encoder_feature_dim, actor_log_std_min, actor_log_std_max,
|
||||||
|
num_layers, num_filters, encoder_stride
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
self.critic = Critic(
|
||||||
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
|
encoder_feature_dim, num_layers, num_filters, encoder_stride
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
self.critic_target = Critic(
|
||||||
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
|
encoder_feature_dim, num_layers, num_filters, encoder_stride
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||||
|
|
||||||
|
self.transition_model = make_transition_model(
|
||||||
|
transition_model_type, encoder_feature_dim, action_shape
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
self.reward_decoder = nn.Sequential(
|
||||||
|
nn.Linear(encoder_feature_dim, 512),
|
||||||
|
nn.LayerNorm(512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(512, 1)).to(device)
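# Shape sketch for the reward head above (values are illustrative only): it maps a
# latent batch of shape (B, encoder_feature_dim) to a scalar reward prediction (B, 1), e.g.
#   r_hat = self.reward_decoder(torch.randn(8, encoder_feature_dim, device=device))
#   # r_hat.shape == torch.Size([8, 1])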
|
||||||
|
|
||||||
|
# tie encoders between actor and critic
|
||||||
|
self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
|
||||||
|
|
||||||
|
self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
|
||||||
|
self.log_alpha.requires_grad = True
|
||||||
|
# set target entropy to -|A|
|
||||||
|
self.target_entropy = -np.prod(action_shape)
|
||||||
|
|
||||||
|
# optimizers
|
||||||
|
self.actor_optimizer = torch.optim.Adam(
|
||||||
|
self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.critic_optimizer = torch.optim.Adam(
|
||||||
|
self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log_alpha_optimizer = torch.optim.Adam(
|
||||||
|
[self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
|
||||||
|
)
|
||||||
|
|
||||||
|
# optimizer for decoder
|
||||||
|
self.decoder_optimizer = torch.optim.Adam(
|
||||||
|
list(self.reward_decoder.parameters()) + list(self.transition_model.parameters()),
|
||||||
|
lr=decoder_lr,
|
||||||
|
weight_decay=decoder_weight_lambda
|
||||||
|
)
|
||||||
|
|
||||||
|
# optimizer for critic encoder for reconstruction loss
|
||||||
|
self.encoder_optimizer = torch.optim.Adam(
|
||||||
|
self.critic.encoder.parameters(), lr=encoder_lr
|
||||||
|
)
|
||||||
|
|
||||||
|
self.train()
|
||||||
|
self.critic_target.train()
|
||||||
|
|
||||||
|
def train(self, training=True):
|
||||||
|
self.training = training
|
||||||
|
self.actor.train(training)
|
||||||
|
self.critic.train(training)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def alpha(self):
|
||||||
|
return self.log_alpha.exp()
|
||||||
|
|
||||||
|
def select_action(self, obs):
|
||||||
|
with torch.no_grad():
|
||||||
|
obs = torch.FloatTensor(obs).to(self.device)
|
||||||
|
obs = obs.unsqueeze(0)
|
||||||
|
mu, _, _, _ = self.actor(
|
||||||
|
obs, compute_pi=False, compute_log_pi=False
|
||||||
|
)
|
||||||
|
return mu.cpu().data.numpy().flatten()
|
||||||
|
|
||||||
|
def sample_action(self, obs):
|
||||||
|
with torch.no_grad():
|
||||||
|
obs = torch.FloatTensor(obs).to(self.device)
|
||||||
|
obs = obs.unsqueeze(0)
|
||||||
|
mu, pi, _, _ = self.actor(obs, compute_log_pi=False)
|
||||||
|
return pi.cpu().data.numpy().flatten()
|
||||||
|
|
||||||
|
def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
|
||||||
|
with torch.no_grad():
|
||||||
|
_, policy_action, log_pi, _ = self.actor(next_obs)
|
||||||
|
target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
|
||||||
|
target_V = torch.min(target_Q1,
|
||||||
|
target_Q2) - self.alpha.detach() * log_pi
|
||||||
|
target_Q = reward + (not_done * self.discount * target_V)
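# Soft Bellman target used above: with entropy temperature alpha,
#   V(s') = min(Q1'(s', a'), Q2'(s', a')) - alpha * log pi(a'|s')
#   Q_target = r + gamma * (1 - done) * V(s')
# computed under no_grad so only the online critic receives gradients.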
|
||||||
|
|
||||||
|
# get current Q estimates
|
||||||
|
current_Q1, current_Q2 = self.critic(obs, action, detach_encoder=False)
|
||||||
|
critic_loss = F.mse_loss(current_Q1,
|
||||||
|
target_Q) + F.mse_loss(current_Q2, target_Q)
|
||||||
|
L.log('train_critic/loss', critic_loss, step)
|
||||||
|
|
||||||
|
# Optimize the critic
|
||||||
|
self.critic_optimizer.zero_grad()
|
||||||
|
critic_loss.backward()
|
||||||
|
self.critic_optimizer.step()
|
||||||
|
|
||||||
|
self.critic.log(L, step)
|
||||||
|
|
||||||
|
def update_actor_and_alpha(self, obs, L, step):
|
||||||
|
# detach encoder, so we don't update it with the actor loss
|
||||||
|
_, pi, log_pi, log_std = self.actor(obs, detach_encoder=True)
|
||||||
|
actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True)
|
||||||
|
|
||||||
|
actor_Q = torch.min(actor_Q1, actor_Q2)
|
||||||
|
actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()
|
||||||
|
|
||||||
|
L.log('train_actor/loss', actor_loss, step)
|
||||||
|
L.log('train_actor/target_entropy', self.target_entropy, step)
|
||||||
|
entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)
|
||||||
|
) + log_std.sum(dim=-1)
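# Differential entropy of a diagonal Gaussian policy with D action dimensions:
#   H = 0.5 * D * (1 + log(2 * pi)) + sum_i log(sigma_i)
# which is the quantity logged as train_actor/entropy below.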
|
||||||
|
L.log('train_actor/entropy', entropy.mean(), step)
|
||||||
|
|
||||||
|
# optimize the actor
|
||||||
|
self.actor_optimizer.zero_grad()
|
||||||
|
actor_loss.backward()
|
||||||
|
self.actor_optimizer.step()
|
||||||
|
|
||||||
|
self.actor.log(L, step)
|
||||||
|
|
||||||
|
self.log_alpha_optimizer.zero_grad()
|
||||||
|
alpha_loss = (self.alpha *
|
||||||
|
(-log_pi - self.target_entropy).detach()).mean()
|
||||||
|
L.log('train_alpha/loss', alpha_loss, step)
|
||||||
|
L.log('train_alpha/value', self.alpha, step)
|
||||||
|
alpha_loss.backward()
|
||||||
|
self.log_alpha_optimizer.step()
|
||||||
|
|
||||||
|
def update_encoder(self, obs, action, reward, L, step):
|
||||||
|
h = self.critic.encoder(obs)
|
||||||
|
|
||||||
|
# Sample random states across episodes at random
|
||||||
|
batch_size = obs.size(0)
|
||||||
|
perm = np.random.permutation(batch_size)
|
||||||
|
h2 = h[perm]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# action, _, _, _ = self.actor(obs, compute_pi=False, compute_log_pi=False)
|
||||||
|
pred_next_latent_mu1, pred_next_latent_sigma1 = self.transition_model(torch.cat([h, action], dim=1))
|
||||||
|
# reward = self.reward_decoder(pred_next_latent_mu1)
|
||||||
|
reward2 = reward[perm]
|
||||||
|
if pred_next_latent_sigma1 is None:
|
||||||
|
pred_next_latent_sigma1 = torch.zeros_like(pred_next_latent_mu1)
|
||||||
|
if pred_next_latent_mu1.ndim == 2: # shape (B, Z), no ensemble
|
||||||
|
pred_next_latent_mu2 = pred_next_latent_mu1[perm]
|
||||||
|
pred_next_latent_sigma2 = pred_next_latent_sigma1[perm]
|
||||||
|
elif pred_next_latent_mu1.ndim == 3: # shape (B, E, Z), using an ensemble
|
||||||
|
pred_next_latent_mu2 = pred_next_latent_mu1[:, perm]
|
||||||
|
pred_next_latent_sigma2 = pred_next_latent_sigma1[:, perm]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
z_dist = F.smooth_l1_loss(h, h2, reduction='none')
|
||||||
|
r_dist = F.smooth_l1_loss(reward, reward2, reduction='none')
|
||||||
|
if self.transition_model_type == '':
|
||||||
|
transition_dist = F.smooth_l1_loss(pred_next_latent_mu1, pred_next_latent_mu2, reduction='none')
|
||||||
|
else:
|
||||||
|
transition_dist = torch.sqrt(
|
||||||
|
(pred_next_latent_mu1 - pred_next_latent_mu2).pow(2) +
|
||||||
|
(pred_next_latent_sigma1 - pred_next_latent_sigma2).pow(2)
|
||||||
|
)
|
||||||
|
# transition_dist = F.smooth_l1_loss(pred_next_latent_mu1, pred_next_latent_mu2, reduction='none') \
|
||||||
|
# + F.smooth_l1_loss(pred_next_latent_sigma1, pred_next_latent_sigma2, reduction='none')
|
||||||
|
|
||||||
|
bisimilarity = r_dist + self.discount * transition_dist
|
||||||
|
loss = (z_dist - bisimilarity).pow(2).mean()
|
||||||
|
L.log('train_ae/encoder_loss', loss, step)
|
||||||
|
return loss
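# Bisimulation encoder loss, as implemented above: for a batch element i paired with a
# randomly permuted partner j,
#   d_ij     = |phi(s_i) - phi(s_j)|                 (latent distance, smooth L1)
#   bisim_ij = |r_i - r_j| + gamma * W2(P_i, P_j)
# where the W2 term between the predicted diagonal-Gaussian next latents reduces to
#   sqrt(||mu_i - mu_j||^2 + ||sigma_i - sigma_j||^2),
# and the loss is the mean squared gap (d_ij - bisim_ij)^2.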
|
||||||
|
|
||||||
|
def update_transition_reward_model(self, obs, action, next_obs, reward, L, step):
|
||||||
|
h = self.critic.encoder(obs)
|
||||||
|
pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(torch.cat([h, action], dim=1))
|
||||||
|
if pred_next_latent_sigma is None:
|
||||||
|
pred_next_latent_sigma = torch.ones_like(pred_next_latent_mu)
|
||||||
|
|
||||||
|
next_h = self.critic.encoder(next_obs)
|
||||||
|
diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
|
||||||
|
loss = torch.mean(0.5 * diff.pow(2) + torch.log(pred_next_latent_sigma))
|
||||||
|
L.log('train_ae/transition_loss', loss, step)
|
||||||
|
|
||||||
|
pred_next_latent = self.transition_model.sample_prediction(torch.cat([h, action], dim=1))
|
||||||
|
pred_next_reward = self.reward_decoder(pred_next_latent)
|
||||||
|
reward_loss = F.mse_loss(pred_next_reward, reward)
|
||||||
|
total_loss = loss + reward_loss
|
||||||
|
return total_loss
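# The transition term above is a Gaussian negative log-likelihood (up to constants):
#   0.5 * ((mu(z, a) - z') / sigma)^2 + log(sigma)
# and the reward term is an MSE between the reward predicted from a sampled next latent
# and the observed reward; both heads are trained through the shared decoder optimizer.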
|
||||||
|
|
||||||
|
def update(self, replay_buffer, L, step):
|
||||||
|
obs, action, _, reward, next_obs, not_done = replay_buffer.sample()
|
||||||
|
|
||||||
|
L.log('train/batch_reward', reward.mean(), step)
|
||||||
|
|
||||||
|
self.update_critic(obs, action, reward, next_obs, not_done, L, step)
|
||||||
|
transition_reward_loss = self.update_transition_reward_model(obs, action, next_obs, reward, L, step)
|
||||||
|
encoder_loss = self.update_encoder(obs, action, reward, L, step)
|
||||||
|
total_loss = self.bisim_coef * encoder_loss + transition_reward_loss
|
||||||
|
self.encoder_optimizer.zero_grad()
|
||||||
|
self.decoder_optimizer.zero_grad()
|
||||||
|
total_loss.backward()
|
||||||
|
self.encoder_optimizer.step()
|
||||||
|
self.decoder_optimizer.step()
|
||||||
|
|
||||||
|
if step % self.actor_update_freq == 0:
|
||||||
|
self.update_actor_and_alpha(obs, L, step)
|
||||||
|
|
||||||
|
if step % self.critic_target_update_freq == 0:
|
||||||
|
utils.soft_update_params(
|
||||||
|
self.critic.Q1, self.critic_target.Q1, self.critic_tau
|
||||||
|
)
|
||||||
|
utils.soft_update_params(
|
||||||
|
self.critic.Q2, self.critic_target.Q2, self.critic_tau
|
||||||
|
)
|
||||||
|
utils.soft_update_params(
|
||||||
|
self.critic.encoder, self.critic_target.encoder,
|
||||||
|
self.encoder_tau
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def save(self, model_dir, step):
|
||||||
|
torch.save(
|
||||||
|
self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
|
||||||
|
)
|
||||||
|
torch.save(
|
||||||
|
self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
|
||||||
|
)
|
||||||
|
torch.save(
|
||||||
|
self.reward_decoder.state_dict(),
|
||||||
|
'%s/reward_decoder_%s.pt' % (model_dir, step)
|
||||||
|
)
|
||||||
|
|
||||||
|
def load(self, model_dir, step):
|
||||||
|
self.actor.load_state_dict(
|
||||||
|
torch.load('%s/actor_%s.pt' % (model_dir, step))
|
||||||
|
)
|
||||||
|
self.critic.load_state_dict(
|
||||||
|
torch.load('%s/critic_%s.pt' % (model_dir, step))
|
||||||
|
)
|
||||||
|
self.reward_decoder.load_state_dict(
|
||||||
|
torch.load('%s/reward_decoder_%s.pt' % (model_dir, step))
|
||||||
|
)
|
||||||
|
|
314
agent/deepmdp_agent.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import utils
|
||||||
|
from sac_ae import Actor, Critic, weight_init, LOG_FREQ
|
||||||
|
from transition_model import make_transition_model
|
||||||
|
from decoder import make_decoder
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DeepMDPAgent(object):
|
||||||
|
"""Baseline algorithm with transition model and various decoder types."""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
obs_shape,
|
||||||
|
action_shape,
|
||||||
|
device,
|
||||||
|
hidden_dim=256,
|
||||||
|
discount=0.99,
|
||||||
|
init_temperature=0.01,
|
||||||
|
alpha_lr=1e-3,
|
||||||
|
alpha_beta=0.9,
|
||||||
|
actor_lr=1e-3,
|
||||||
|
actor_beta=0.9,
|
||||||
|
actor_log_std_min=-10,
|
||||||
|
actor_log_std_max=2,
|
||||||
|
actor_update_freq=2,
|
||||||
|
encoder_stride=2,
|
||||||
|
critic_lr=1e-3,
|
||||||
|
critic_beta=0.9,
|
||||||
|
critic_tau=0.005,
|
||||||
|
critic_target_update_freq=2,
|
||||||
|
encoder_type='pixel',
|
||||||
|
encoder_feature_dim=50,
|
||||||
|
encoder_lr=1e-3,
|
||||||
|
encoder_tau=0.005,
|
||||||
|
decoder_type='pixel',
|
||||||
|
decoder_lr=1e-3,
|
||||||
|
decoder_update_freq=1,
|
||||||
|
decoder_weight_lambda=0.0,
|
||||||
|
transition_model_type='deterministic',
|
||||||
|
num_layers=4,
|
||||||
|
num_filters=32
|
||||||
|
):
|
||||||
|
self.reconstruction = False
|
||||||
|
if decoder_type == 'reconstruction':
|
||||||
|
decoder_type = 'pixel'
|
||||||
|
self.reconstruction = True
|
||||||
|
self.device = device
|
||||||
|
self.discount = discount
|
||||||
|
self.critic_tau = critic_tau
|
||||||
|
self.encoder_tau = encoder_tau
|
||||||
|
self.actor_update_freq = actor_update_freq
|
||||||
|
self.critic_target_update_freq = critic_target_update_freq
|
||||||
|
self.decoder_update_freq = decoder_update_freq
|
||||||
|
self.decoder_type = decoder_type
|
||||||
|
|
||||||
|
self.actor = Actor(
|
||||||
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
|
encoder_feature_dim, actor_log_std_min, actor_log_std_max,
|
||||||
|
num_layers, num_filters, encoder_stride
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
self.critic = Critic(
|
||||||
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
|
encoder_feature_dim, num_layers, num_filters, encoder_stride
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
self.critic_target = Critic(
|
||||||
|
obs_shape, action_shape, hidden_dim, encoder_type,
|
||||||
|
encoder_feature_dim, num_layers, num_filters, encoder_stride
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||||
|
|
||||||
|
self.transition_model = make_transition_model(
|
||||||
|
transition_model_type, encoder_feature_dim, action_shape
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
self.reward_decoder = nn.Sequential(
|
||||||
|
nn.Linear(encoder_feature_dim + action_shape[0], 512),
|
||||||
|
nn.LayerNorm(512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(512, 1)).to(device)
|
||||||
|
|
||||||
|
decoder_params = list(self.transition_model.parameters()) + list(self.reward_decoder.parameters())
|
||||||
|
|
||||||
|
# tie encoders between actor and critic
|
||||||
|
self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
|
||||||
|
|
||||||
|
self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
|
||||||
|
self.log_alpha.requires_grad = True
|
||||||
|
# set target entropy to -|A|
|
||||||
|
self.target_entropy = -np.prod(action_shape)
|
||||||
|
|
||||||
|
self.decoder = None
|
||||||
|
if decoder_type == 'pixel':
|
||||||
|
# create decoder
|
||||||
|
self.decoder = make_decoder(
|
||||||
|
decoder_type, obs_shape, encoder_feature_dim, num_layers,
|
||||||
|
num_filters
|
||||||
|
).to(device)
|
||||||
|
self.decoder.apply(weight_init)
|
||||||
|
decoder_params += list(self.decoder.parameters())
|
||||||
|
|
||||||
|
self.decoder_optimizer = torch.optim.Adam(
|
||||||
|
decoder_params,
|
||||||
|
lr=decoder_lr,
|
||||||
|
weight_decay=decoder_weight_lambda
|
||||||
|
)
|
||||||
|
|
||||||
|
# optimizer for critic encoder for reconstruction loss
|
||||||
|
self.encoder_optimizer = torch.optim.Adam(
|
||||||
|
self.critic.encoder.parameters(), lr=encoder_lr
|
||||||
|
)
|
||||||
|
|
||||||
|
# optimizers
|
||||||
|
self.actor_optimizer = torch.optim.Adam(
|
||||||
|
self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.critic_optimizer = torch.optim.Adam(
|
||||||
|
self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log_alpha_optimizer = torch.optim.Adam(
|
||||||
|
[self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.train()
|
||||||
|
self.critic_target.train()
|
||||||
|
|
||||||
|
def train(self, training=True):
|
||||||
|
self.training = training
|
||||||
|
self.actor.train(training)
|
||||||
|
self.critic.train(training)
|
||||||
|
if self.decoder is not None:
|
||||||
|
self.decoder.train(training)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def alpha(self):
|
||||||
|
return self.log_alpha.exp()
|
||||||
|
|
||||||
|
def select_action(self, obs):
|
||||||
|
with torch.no_grad():
|
||||||
|
obs = torch.FloatTensor(obs).to(self.device)
|
||||||
|
obs = obs.unsqueeze(0)
|
||||||
|
mu, _, _, _ = self.actor(
|
||||||
|
obs, compute_pi=False, compute_log_pi=False
|
||||||
|
)
|
||||||
|
return mu.cpu().data.numpy().flatten()
|
||||||
|
|
||||||
|
def sample_action(self, obs):
|
||||||
|
with torch.no_grad():
|
||||||
|
obs = torch.FloatTensor(obs).to(self.device)
|
||||||
|
obs = obs.unsqueeze(0)
|
||||||
|
mu, pi, _, _ = self.actor(obs, compute_log_pi=False)
|
||||||
|
return pi.cpu().data.numpy().flatten()
|
||||||
|
|
||||||
|
def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
|
||||||
|
with torch.no_grad():
|
||||||
|
_, policy_action, log_pi, _ = self.actor(next_obs)
|
||||||
|
target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
|
||||||
|
target_V = torch.min(target_Q1,
|
||||||
|
target_Q2) - self.alpha.detach() * log_pi
|
||||||
|
target_Q = reward + (not_done * self.discount * target_V)
|
||||||
|
|
||||||
|
# get current Q estimates
|
||||||
|
current_Q1, current_Q2 = self.critic(obs, action, detach_encoder=False)
|
||||||
|
critic_loss = F.mse_loss(current_Q1,
|
||||||
|
target_Q) + F.mse_loss(current_Q2, target_Q)
|
||||||
|
L.log('train_critic/loss', critic_loss, step)
|
||||||
|
|
||||||
|
# Optimize the critic
|
||||||
|
self.critic_optimizer.zero_grad()
|
||||||
|
critic_loss.backward()
|
||||||
|
self.critic_optimizer.step()
|
||||||
|
|
||||||
|
self.critic.log(L, step)
|
||||||
|
|
||||||
|
def update_actor_and_alpha(self, obs, L, step):
|
||||||
|
# detach encoder, so we don't update it with the actor loss
|
||||||
|
_, pi, log_pi, log_std = self.actor(obs, detach_encoder=True)
|
||||||
|
actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True)
|
||||||
|
|
||||||
|
actor_Q = torch.min(actor_Q1, actor_Q2)
|
||||||
|
actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()
|
||||||
|
|
||||||
|
L.log('train_actor/loss', actor_loss, step)
|
||||||
|
L.log('train_actor/target_entropy', self.target_entropy, step)
|
||||||
|
entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)
|
||||||
|
) + log_std.sum(dim=-1)
|
||||||
|
L.log('train_actor/entropy', entropy.mean(), step)
|
||||||
|
|
||||||
|
# optimize the actor
|
||||||
|
self.actor_optimizer.zero_grad()
|
||||||
|
actor_loss.backward()
|
||||||
|
self.actor_optimizer.step()
|
||||||
|
|
||||||
|
self.actor.log(L, step)
|
||||||
|
|
||||||
|
self.log_alpha_optimizer.zero_grad()
|
||||||
|
alpha_loss = (self.alpha *
|
||||||
|
(-log_pi - self.target_entropy).detach()).mean()
|
||||||
|
L.log('train_alpha/loss', alpha_loss, step)
|
||||||
|
L.log('train_alpha/value', self.alpha, step)
|
||||||
|
alpha_loss.backward()
|
||||||
|
self.log_alpha_optimizer.step()
|
||||||
|
|
||||||
|
def update_transition_reward_model(self, obs, action, next_obs, reward, L, step):
|
||||||
|
h = self.critic.encoder(obs)
|
||||||
|
pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(torch.cat([h, action], dim=1))
|
||||||
|
if pred_next_latent_sigma is None:
|
||||||
|
pred_next_latent_sigma = torch.ones_like(pred_next_latent_mu)
|
||||||
|
|
||||||
|
next_h = self.critic.encoder(next_obs)
|
||||||
|
diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
|
||||||
|
loss = torch.mean(0.5 * diff.pow(2) + torch.log(pred_next_latent_sigma))
|
||||||
|
L.log('train_ae/transition_loss', loss, step)
|
||||||
|
|
||||||
|
pred_next_reward = self.reward_decoder(torch.cat([h, action], dim=1))
|
||||||
|
reward_loss = F.mse_loss(pred_next_reward, reward)
|
||||||
|
total_loss = loss + reward_loss
|
||||||
|
self.encoder_optimizer.zero_grad()
|
||||||
|
self.decoder_optimizer.zero_grad()
|
||||||
|
total_loss.backward()
|
||||||
|
self.encoder_optimizer.step()
|
||||||
|
self.decoder_optimizer.step()
|
||||||
|
|
||||||
|
def update_decoder(self, obs, action, target_obs, L, step): # uses transition model
|
||||||
|
# observations may be frame-stacked; reconstruct only the first 3 channels (RGB)
|
||||||
|
assert target_obs.dim() == 4
|
||||||
|
target_obs = target_obs[:, :3, :, :]
|
||||||
|
|
||||||
|
h = self.critic.encoder(obs)
|
||||||
|
if not self.reconstruction:
|
||||||
|
next_h = self.transition_model.sample_prediction(torch.cat([h, action], dim=1))
|
||||||
|
if target_obs.dim() == 4:
|
||||||
|
# preprocess images to be in [-0.5, 0.5] range
|
||||||
|
target_obs = utils.preprocess_obs(target_obs)
|
||||||
|
rec_obs = self.decoder(next_h)
|
||||||
|
loss = F.mse_loss(target_obs, rec_obs)
|
||||||
|
else:
|
||||||
|
rec_obs = self.decoder(h)
|
||||||
|
loss = F.mse_loss(obs, rec_obs)
|
||||||
|
|
||||||
|
self.encoder_optimizer.zero_grad()
|
||||||
|
self.decoder_optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
self.encoder_optimizer.step()
|
||||||
|
self.decoder_optimizer.step()
|
||||||
|
L.log('train_ae/ae_loss', loss, step)
|
||||||
|
|
||||||
|
self.decoder.log(L, step, log_freq=LOG_FREQ)
|
||||||
|
|
||||||
|
def update(self, replay_buffer, L, step):
|
||||||
|
obs, action, _, reward, next_obs, not_done = replay_buffer.sample()
|
||||||
|
|
||||||
|
L.log('train/batch_reward', reward.mean(), step)
|
||||||
|
|
||||||
|
self.update_critic(obs, action, reward, next_obs, not_done, L, step)
|
||||||
|
self.update_transition_reward_model(obs, action, next_obs, reward, L, step)
|
||||||
|
|
||||||
|
if step % self.actor_update_freq == 0:
|
||||||
|
self.update_actor_and_alpha(obs, L, step)
|
||||||
|
|
||||||
|
if step % self.critic_target_update_freq == 0:
|
||||||
|
utils.soft_update_params(
|
||||||
|
self.critic.Q1, self.critic_target.Q1, self.critic_tau
|
||||||
|
)
|
||||||
|
utils.soft_update_params(
|
||||||
|
self.critic.Q2, self.critic_target.Q2, self.critic_tau
|
||||||
|
)
|
||||||
|
utils.soft_update_params(
|
||||||
|
self.critic.encoder, self.critic_target.encoder,
|
||||||
|
self.encoder_tau
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.decoder is not None and step % self.decoder_update_freq == 0: # decoder_type is pixel
|
||||||
|
self.update_decoder(obs, action, next_obs, L, step)
|
||||||
|
|
||||||
|
def save(self, model_dir, step):
|
||||||
|
torch.save(
|
||||||
|
self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
|
||||||
|
)
|
||||||
|
torch.save(
|
||||||
|
self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
|
||||||
|
)
|
||||||
|
if self.decoder is not None:
|
||||||
|
torch.save(
|
||||||
|
self.decoder.state_dict(),
|
||||||
|
'%s/decoder_%s.pt' % (model_dir, step)
|
||||||
|
)
|
||||||
|
|
||||||
|
def load(self, model_dir, step):
|
||||||
|
self.actor.load_state_dict(
|
||||||
|
torch.load('%s/actor_%s.pt' % (model_dir, step))
|
||||||
|
)
|
||||||
|
self.critic.load_state_dict(
|
||||||
|
torch.load('%s/critic_%s.pt' % (model_dir, step))
|
||||||
|
)
|
||||||
|
if self.decoder is not None:
|
||||||
|
self.decoder.load_state_dict(
|
||||||
|
torch.load('%s/decoder_%s.pt' % (model_dir, step))
|
||||||
|
)
|
25
conda_env.yml
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
name: dbc
|
||||||
|
channels:
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- python=3.6
|
||||||
|
- pytorch
|
||||||
|
- torchvision
|
||||||
|
- cudatoolkit=9.2
|
||||||
|
- absl-py
|
||||||
|
- pyparsing
|
||||||
|
- pip:
|
||||||
|
- termcolor
|
||||||
|
- git+git://github.com/deepmind/dm_control.git
|
||||||
|
- git+git://github.com/1nadequacy/dmc2gym.git
|
||||||
|
- opencv-python
|
||||||
|
- pillow=6.1
|
||||||
|
- scikit-image
|
||||||
|
- scikit-video
|
||||||
|
- tb-nightly
|
||||||
|
- tqdm
|
||||||
|
- imageio
|
||||||
|
- imageio-ffmpeg
|
||||||
|
- pygame
|
||||||
|
- networkx
|
||||||
|
- dotmap
|
83
decoder.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class PixelDecoder(nn.Module):
|
||||||
|
def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.num_filters = num_filters
|
||||||
|
self.init_height = 4
|
||||||
|
self.init_width = 25
|
||||||
|
num_out_channels = 3 # rgb
|
||||||
|
kernel = 3
|
||||||
|
|
||||||
|
self.fc = nn.Linear(
|
||||||
|
feature_dim, num_filters * self.init_height * self.init_width
|
||||||
|
)
|
||||||
|
|
||||||
|
self.deconvs = nn.ModuleList()
|
||||||
|
|
||||||
|
pads = [0, 1, 0]
|
||||||
|
for i in range(self.num_layers - 1):
|
||||||
|
output_padding = pads[i]
|
||||||
|
self.deconvs.append(
|
||||||
|
nn.ConvTranspose2d(num_filters, num_filters, kernel, stride=2, output_padding=output_padding)
|
||||||
|
)
|
||||||
|
self.deconvs.append(
|
||||||
|
nn.ConvTranspose2d(
|
||||||
|
num_filters, num_out_channels, kernel, stride=2, output_padding=1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.outputs = dict()
|
||||||
|
|
||||||
|
def forward(self, h):
|
||||||
|
h = torch.relu(self.fc(h))
|
||||||
|
self.outputs['fc'] = h
|
||||||
|
|
||||||
|
deconv = h.view(-1, self.num_filters, self.init_height, self.init_width)
|
||||||
|
self.outputs['deconv1'] = deconv
|
||||||
|
|
||||||
|
for i in range(0, self.num_layers - 1):
|
||||||
|
deconv = torch.relu(self.deconvs[i](deconv))
|
||||||
|
self.outputs['deconv%s' % (i + 1)] = deconv
|
||||||
|
|
||||||
|
obs = self.deconvs[-1](deconv)
|
||||||
|
self.outputs['obs'] = obs
|
||||||
|
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def log(self, L, step, log_freq):
|
||||||
|
if step % log_freq != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for k, v in self.outputs.items():
|
||||||
|
L.log_histogram('train_decoder/%s_hist' % k, v, step)
|
||||||
|
if len(v.shape) > 2:
|
||||||
|
L.log_image('train_decoder/%s_i' % k, v[0], step)
|
||||||
|
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
L.log_param(
|
||||||
|
'train_decoder/deconv%s' % (i + 1), self.deconvs[i], step
|
||||||
|
)
|
||||||
|
L.log_param('train_decoder/fc', self.fc, step)
|
||||||
|
|
||||||
|
|
||||||
|
_AVAILABLE_DECODERS = {'pixel': PixelDecoder}
|
||||||
|
|
||||||
|
|
||||||
|
def make_decoder(
|
||||||
|
decoder_type, obs_shape, feature_dim, num_layers, num_filters
|
||||||
|
):
|
||||||
|
assert decoder_type in _AVAILABLE_DECODERS
|
||||||
|
return _AVAILABLE_DECODERS[decoder_type](
|
||||||
|
obs_shape, feature_dim, num_layers, num_filters
|
||||||
|
)
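# Hypothetical usage sketch for make_decoder (argument values are illustrative; obs_shape
# is accepted but the output size is fixed by the hardcoded init_height/init_width and the
# deconv output paddings above, e.g. with num_layers=4 the height upsamples 4 -> 9 -> 20 -> 41 -> 84):
#   decoder = make_decoder('pixel', obs_shape=(3, 84, 420), feature_dim=50,
#                          num_layers=4, num_filters=32)
#   obs_hat = decoder(torch.randn(8, 50))   # (8, 3, H, W) reconstructed RGB batch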
|
178
distractors/n_body_problem.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.animation as animation
|
||||||
|
|
||||||
|
from scipy.integrate import odeint
|
||||||
|
|
||||||
|
|
||||||
|
class Planets(object):
|
||||||
|
"""
|
||||||
|
Implements a 2D environment where there are N bodies (planets) that attract each other according to a 1/r law.
|
||||||
|
|
||||||
|
We assume the mass of each body is 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# For each dimension of the hypercube
|
||||||
|
MIN_POS = 0. # if box exists
|
||||||
|
MAX_POS = 1. # if box exists
|
||||||
|
INIT_MAX_VEL = 1.
|
||||||
|
GRAVITATIONAL_CONSTANT = 1.
|
||||||
|
|
||||||
|
def __init__(self, num_bodies, num_dimensions=2, dt=0.01, contained_in_a_box=True):
|
||||||
|
self.num_bodies = num_bodies
|
||||||
|
self.num_dimensions = num_dimensions
|
||||||
|
self.dt = dt
|
||||||
|
self.contained_in_a_box = contained_in_a_box
|
||||||
|
|
||||||
|
# state variables
|
||||||
|
self.body_positions = None
|
||||||
|
self.body_velocities = None
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.body_positions = np.random.uniform(self.MIN_POS, self.MAX_POS, size=(self.num_bodies, self.num_dimensions))
|
||||||
|
self.body_velocities = self.INIT_MAX_VEL * np.random.uniform(-1, 1, size=(self.num_bodies, self.num_dimensions))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self):
|
||||||
|
return np.concatenate((self.body_positions, self.body_velocities), axis=1) # (N, 2D)
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
|
||||||
|
# Helper functions since ode solver requires flattened inputs
|
||||||
|
def flatten(positions, velocities): # positions shape (N, D); velocities shape (N, D)
|
||||||
|
system_state = np.concatenate((positions, velocities), axis=1) # (N, 2D)
|
||||||
|
system_state_flat = system_state.flatten() # ode solver requires flat, (N*2D,)
|
||||||
|
return system_state_flat
|
||||||
|
|
||||||
|
def unflatten(system_state_flat): # system_state_flat shape (N*2*D,)
|
||||||
|
system_state = system_state_flat.reshape(self.num_bodies, 2 * self.num_dimensions) # (N, 2*D)
|
||||||
|
positions = system_state[:, :self.num_dimensions] # (N, D)
|
||||||
|
velocities = system_state[:, self.num_dimensions:] # (N, D)
|
||||||
|
return positions, velocities
|
||||||
|
|
||||||
|
# ODE function
|
||||||
|
def system_first_order_ode(system_state_flat, _):
|
||||||
|
|
||||||
|
positions, velocities = unflatten(system_state_flat)
|
||||||
|
accelerations = np.zeros_like(velocities) # init (N, D)
|
||||||
|
|
||||||
|
for i in range(self.num_bodies):
|
||||||
|
relative_positions = positions - positions[i] # (N, D)
|
||||||
|
distances = np.linalg.norm(relative_positions, axis=1, keepdims=True) # (N, 1)
|
||||||
|
distances[i] = 1. # bodies don't affect themselves, and we don't want to divide by zero next
|
||||||
|
|
||||||
|
# forces (see https://en.wikipedia.org/wiki/Numerical_model_of_the_Solar_System)
|
||||||
|
force_vectors = self.GRAVITATIONAL_CONSTANT * relative_positions / (distances**self.num_dimensions) # (N,D)
|
||||||
|
force_vector = np.sum(force_vectors, axis=0) # (D,)
|
||||||
|
accelerations[i] = force_vector # assuming mass 1.
|
||||||
|
|
||||||
|
d_system_state_flat = flatten(velocities, accelerations)
|
||||||
|
return d_system_state_flat
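# The ODE above implements, for each body i with unit mass,
#   a_i = G * sum_{j != i} (x_j - x_i) / ||x_j - x_i||^D        (D = num_dimensions)
# so the state derivative returned to the solver is (velocities, accelerations).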
|
||||||
|
|
||||||
|
# integrate + update
|
||||||
|
current_system_state_flat = flatten(self.body_positions, self.body_velocities) # (N*2*D,)
|
||||||
|
_, next_system_state_flat = odeint(system_first_order_ode, current_system_state_flat, [0., self.dt]) # (N*2*D,)
|
||||||
|
self.body_positions, self.body_velocities = unflatten(next_system_state_flat) # (N, D), (N, D)
|
||||||
|
|
||||||
|
# bounce off boundaries of box
|
||||||
|
if self.contained_in_a_box:
|
||||||
|
ind_below_min = self.body_positions < self.MIN_POS
|
||||||
|
ind_above_max = self.body_positions > self.MAX_POS
|
||||||
|
self.body_positions[ind_below_min] += 2. * (self.MIN_POS - self.body_positions[ind_below_min])
|
||||||
|
self.body_positions[ind_above_max] += 2. * (self.MAX_POS - self.body_positions[ind_above_max])
|
||||||
|
self.body_velocities[ind_below_min] *= -1.
|
||||||
|
self.body_velocities[ind_above_max] *= -1.
|
||||||
|
self.assert_bodies_in_box() # check for bugs
|
||||||
|
|
||||||
|
def animate(self, file_name=None, frames=1000, pixel_length=None, tight_format=True):
|
||||||
|
"""
|
||||||
|
Animation function for visual debugging.
|
||||||
|
"""
|
||||||
|
if self.num_dimensions != 2:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
if pixel_length is None:
|
||||||
|
fig = plt.figure()
|
||||||
|
else:
|
||||||
|
# matplotlib can't render if pixel_length is too small, so just run in the background if a pixel size is specified
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
my_dpi = 96 # find your screen's dpi here: https://www.infobyip.com/detectmonitordpi.php
|
||||||
|
fig = plt.figure(facecolor='lightslategray', figsize=(pixel_length/my_dpi, pixel_length/my_dpi), dpi=my_dpi)
|
||||||
|
|
||||||
|
ax = fig.add_subplot(1, 1, 1)
|
||||||
|
plt.axis('off')
|
||||||
|
if tight_format:
|
||||||
|
plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=None, hspace=None)
|
||||||
|
body_colors = np.random.uniform(size=self.num_bodies)
|
||||||
|
|
||||||
|
def render(_):
|
||||||
|
self.step()
|
||||||
|
x = self.body_positions[:, 0]
|
||||||
|
y = self.body_positions[:, 1]
|
||||||
|
|
||||||
|
ax.clear()
|
||||||
|
# if tight_format:
|
||||||
|
# plt.subplots_adjust(left=0., right=1., top=1., bottom=0.)
|
||||||
|
ax.scatter(x, y, marker='o', c=body_colors, cmap='viridis')
|
||||||
|
# ax.set_title(self.__class__.__name__ + "\n(temperature inside box: {:.1f})".format(self.temperature))
|
||||||
|
ax.set_xlim(self.MIN_POS, self.MAX_POS)
|
||||||
|
ax.set_ylim(self.MIN_POS, self.MAX_POS)
|
||||||
|
ax.set_xticks([])
|
||||||
|
ax.set_yticks([])
|
||||||
|
ax.set_aspect('equal')
|
||||||
|
# ax.axis('off')
|
||||||
|
ax.set_facecolor('black')
|
||||||
|
if tight_format:
|
||||||
|
ax.margins(x=0., y=0.)
|
||||||
|
|
||||||
|
interval_milliseconds = 1000 * self.dt
|
||||||
|
anim = animation.FuncAnimation(fig, render, frames=frames, interval=interval_milliseconds)
|
||||||
|
|
||||||
|
plt.pause(1)
|
||||||
|
if file_name is None:
|
||||||
|
file_name = self.__class__.__name__.lower() + '.gif'
|
||||||
|
file_name = 'images/' + file_name
|
||||||
|
print('Saving file {} ...'.format(file_name))
|
||||||
|
anim.save(file_name, writer='imagemagick')
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
def assert_bodies_in_box(self):
|
||||||
|
"""
|
||||||
|
If the simulation moves fast enough, bodies can bounce one step outside of the box. For now we simply assert that this has not happened; a proper fix can come later.
|
||||||
|
"""
|
||||||
|
assert np.all(self.body_positions >= self.MIN_POS) and np.all(self.body_positions <= self.MAX_POS)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def temperature(self):
|
||||||
|
"""
|
||||||
|
Temperature is the average kinetic energy of the system.
|
||||||
|
:return: float
|
||||||
|
"""
|
||||||
|
average_kinetic_energy = 0.5 * np.mean(np.linalg.norm(self.body_velocities, axis=1)) # (N, D) --> (1,)
|
||||||
|
return average_kinetic_energy
|
||||||
|
|
||||||
|
|
||||||
|
class Electrons(Planets):
|
||||||
|
"""
|
||||||
|
Implements a 2D environment where there are N bodies (electrons) that repel each other according to a 1/r law.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# override
|
||||||
|
GRAVITATIONAL_CONSTANT = -1. # negative means they repel
|
||||||
|
|
||||||
|
|
||||||
|
class IdealGas(Planets):
|
||||||
|
"""
|
||||||
|
Implements a 2D environment where there are N bodies (gas molecules) that do not interact with each other.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# override
|
||||||
|
GRAVITATIONAL_CONSTANT = 0. # zero means they don't interact
|
19
distractors/render_n_body_problem_envs.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from envs.n_body_problem import Planets, Electrons, IdealGas
|
||||||
|
|
||||||
|
|
||||||
|
# env1 = Planets(num_bodies=10, num_dimensions=2, dt=0.01, contained_in_a_box=True)
|
||||||
|
# env1.animate() # only animates if num_dimensions == 2
|
||||||
|
#
|
||||||
|
# env2 = Electrons(num_bodies=10, num_dimensions=2, dt=0.01, contained_in_a_box=True)
|
||||||
|
# env2.animate() # only animates if num_dimensions == 2
|
||||||
|
|
||||||
|
for i in range(1):
|
||||||
|
env3 = IdealGas(num_bodies=10, num_dimensions=2, dt=0.01, contained_in_a_box=True)
|
||||||
|
file_name = 'idealgas{}.mp4'.format(i)
|
||||||
|
env3.animate(file_name=file_name, pixel_length=64)
|
52
dmc2gym/__init__.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
import gym
|
||||||
|
from gym.envs.registration import register
|
||||||
|
|
||||||
|
|
||||||
|
def make(
|
||||||
|
domain_name,
|
||||||
|
task_name,
|
||||||
|
resource_files,
|
||||||
|
img_source,
|
||||||
|
total_frames,
|
||||||
|
seed=1,
|
||||||
|
visualize_reward=True,
|
||||||
|
from_pixels=False,
|
||||||
|
height=84,
|
||||||
|
width=84,
|
||||||
|
camera_id=0,
|
||||||
|
frame_skip=1,
|
||||||
|
episode_length=1000,
|
||||||
|
environment_kwargs=None
|
||||||
|
):
|
||||||
|
env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed)
|
||||||
|
|
||||||
|
if from_pixels:
|
||||||
|
assert not visualize_reward, 'cannot use visualize reward when learning from pixels'
|
||||||
|
|
||||||
|
# shorten episode length
|
||||||
|
max_episode_steps = (episode_length + frame_skip - 1) // frame_skip
|
||||||
|
|
||||||
|
if env_id not in gym.envs.registry.env_specs:
|
||||||
|
register(
|
||||||
|
id=env_id,
|
||||||
|
entry_point='dmc2gym.wrappers:DMCWrapper',
|
||||||
|
kwargs={
|
||||||
|
'domain_name': domain_name,
|
||||||
|
'task_name': task_name,
|
||||||
|
'resource_files': resource_files,
|
||||||
|
'img_source': img_source,
|
||||||
|
'total_frames': total_frames,
|
||||||
|
'task_kwargs': {
|
||||||
|
'random': seed
|
||||||
|
},
|
||||||
|
'environment_kwargs': environment_kwargs,
|
||||||
|
'visualize_reward': visualize_reward,
|
||||||
|
'from_pixels': from_pixels,
|
||||||
|
'height': height,
|
||||||
|
'width': width,
|
||||||
|
'camera_id': camera_id,
|
||||||
|
'frame_skip': frame_skip,
|
||||||
|
},
|
||||||
|
max_episode_steps=max_episode_steps
|
||||||
|
)
|
||||||
|
return gym.make(env_id)
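# Minimal usage sketch, assuming the natural-video distractor setup this module supports
# (all argument values below are illustrative):
#   env = make('cheetah', 'run', resource_files='~/kinetics/*.mp4', img_source='video',
#              total_frames=1000, seed=1, visualize_reward=False, from_pixels=True,
#              height=84, width=84, frame_skip=4)
#   obs = env.reset()   # (3, 84, 84) uint8 frame with the video composited as background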
|
183
dmc2gym/natural_imgsource.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import skvideo.io
|
||||||
|
import random
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
class BackgroundMatting(object):
|
||||||
|
"""
|
||||||
|
Produce a mask by masking the given color. This is a simple strategy
|
||||||
|
but effective for many games.
|
||||||
|
"""
|
||||||
|
def __init__(self, color):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
color: a (r, g, b) tuple or single value for grayscale
|
||||||
|
"""
|
||||||
|
self._color = color
|
||||||
|
|
||||||
|
def get_mask(self, img):
|
||||||
|
return img == self._color
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSource(object):
|
||||||
|
"""
|
||||||
|
Source of natural images to be added to a simulated environment.
|
||||||
|
"""
|
||||||
|
def get_image(self):
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
an RGB image of [h, w, 3] with a fixed shape.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
""" Called when an episode ends. """
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FixedColorSource(ImageSource):
|
||||||
|
def __init__(self, shape, color):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
shape: [h, w]
|
||||||
|
color: a 3-tuple
|
||||||
|
"""
|
||||||
|
self.arr = np.zeros((shape[0], shape[1], 3))
|
||||||
|
self.arr[:, :] = color
|
||||||
|
|
||||||
|
def get_image(self):
|
||||||
|
return self.arr
|
||||||
|
|
||||||
|
|
||||||
|
class RandomColorSource(ImageSource):
|
||||||
|
def __init__(self, shape):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
shape: [h, w]
|
||||||
|
"""
|
||||||
|
self.shape = shape
|
||||||
|
self.arr = None
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._color = np.random.randint(0, 256, size=(3,))
|
||||||
|
self.arr = np.zeros((self.shape[0], self.shape[1], 3))
|
||||||
|
self.arr[:, :] = self._color
|
||||||
|
|
||||||
|
def get_image(self):
|
||||||
|
return self.arr
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseSource(ImageSource):
|
||||||
|
def __init__(self, shape, strength=255):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
shape: [h, w]
|
||||||
|
strength (int): the strength of noise, in range [0, 255]
|
||||||
|
"""
|
||||||
|
self.shape = shape
|
||||||
|
self.strength = strength
|
||||||
|
|
||||||
|
def get_image(self):
|
||||||
|
return np.random.randn(self.shape[0], self.shape[1], 3) * self.strength
|
||||||
|
|
||||||
|
|
||||||
|
class RandomImageSource(ImageSource):
|
||||||
|
def __init__(self, shape, filelist, total_frames=None, grayscale=False):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
shape: [h, w]
|
||||||
|
filelist: a list of image files
|
||||||
|
"""
|
||||||
|
self.grayscale = grayscale
|
||||||
|
self.total_frames = total_frames
|
||||||
|
self.shape = shape
|
||||||
|
self.filelist = filelist
|
||||||
|
self.build_arr()
|
||||||
|
self.current_idx = 0
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def build_arr(self):
|
||||||
|
self.total_frames = self.total_frames if self.total_frames else len(self.filelist)
|
||||||
|
self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
|
||||||
|
for i in range(self.total_frames):
|
||||||
|
# if i % len(self.filelist) == 0: random.shuffle(self.filelist)
|
||||||
|
fname = self.filelist[i % len(self.filelist)]
|
||||||
|
if self.grayscale: im = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)[..., None]
|
||||||
|
else: im = cv2.imread(fname, cv2.IMREAD_COLOR)
|
||||||
|
self.arr[i] = cv2.resize(im, (self.shape[1], self.shape[0])) ## THIS IS NOT A BUG! cv2 uses (width, height)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._loc = np.random.randint(0, self.total_frames)
|
||||||
|
|
||||||
|
def get_image(self):
|
||||||
|
return self.arr[self._loc]
|
||||||
|
|
||||||
|
|
||||||
|
class RandomVideoSource(ImageSource):
|
||||||
|
def __init__(self, shape, filelist, total_frames=None, grayscale=False):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
shape: [h, w]
|
||||||
|
filelist: a list of video files
|
||||||
|
"""
|
||||||
|
self.grayscale = grayscale
|
||||||
|
self.total_frames = total_frames
|
||||||
|
self.shape = shape
|
||||||
|
self.filelist = filelist
|
||||||
|
self.build_arr()
|
||||||
|
self.current_idx = 0
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def build_arr(self):
|
||||||
|
if not self.total_frames:
|
||||||
|
self.total_frames = 0
|
||||||
|
self.arr = None
|
||||||
|
random.shuffle(self.filelist)
|
||||||
|
for fname in tqdm.tqdm(self.filelist, desc="Loading videos for natural", position=0):
|
||||||
|
if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
|
||||||
|
else: frames = skvideo.io.vread(fname)
|
||||||
|
local_arr = np.zeros((frames.shape[0], self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
|
||||||
|
for i in tqdm.tqdm(range(frames.shape[0]), desc="video frames", position=1):
|
||||||
|
local_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0])) ## THIS IS NOT A BUG! cv2 uses (width, height)
|
||||||
|
if self.arr is None:
|
||||||
|
self.arr = local_arr
|
||||||
|
else:
|
||||||
|
self.arr = np.concatenate([self.arr, local_arr], 0)
|
||||||
|
self.total_frames += local_arr.shape[0]
|
||||||
|
else:
|
||||||
|
self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
|
||||||
|
total_frame_i = 0
|
||||||
|
file_i = 0
|
||||||
|
with tqdm.tqdm(total=self.total_frames, desc="Loading videos for natural") as pbar:
|
||||||
|
while total_frame_i < self.total_frames:
|
||||||
|
if file_i % len(self.filelist) == 0: random.shuffle(self.filelist)
|
||||||
|
file_i += 1
|
||||||
|
fname = self.filelist[file_i % len(self.filelist)]
|
||||||
|
if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
|
||||||
|
else: frames = skvideo.io.vread(fname)
|
||||||
|
for frame_i in range(frames.shape[0]):
|
||||||
|
if total_frame_i >= self.total_frames: break
|
||||||
|
if self.grayscale:
|
||||||
|
self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))[..., None] ## THIS IS NOT A BUG! cv2 uses (width, height)
|
||||||
|
else:
|
||||||
|
self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))
|
||||||
|
pbar.update(1)
|
||||||
|
total_frame_i += 1
|
||||||
|
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._loc = np.random.randint(0, self.total_frames)
|
||||||
|
|
||||||
|
def get_image(self):
|
||||||
|
img = self.arr[self._loc % self.total_frames]
|
||||||
|
self._loc += 1
|
||||||
|
return img
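# Hypothetical usage sketch for the video-backed source (paths are illustrative):
#   files = glob.glob(os.path.expanduser('~/kinetics/*.mp4'))
#   src = RandomVideoSource((84, 84), files, total_frames=1000, grayscale=True)
#   frame = src.get_image()   # (84, 84, 1) array; advances one video frame per call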
|
198
dmc2gym/wrappers.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
from gym import core, spaces
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import local_dm_control_suite as suite
|
||||||
|
from dm_env import specs
|
||||||
|
import numpy as np
|
||||||
|
import skimage.io
|
||||||
|
|
||||||
|
from dmc2gym import natural_imgsource
|
||||||
|
|
||||||
|
|
||||||
|
def _spec_to_box(spec):
|
||||||
|
def extract_min_max(s):
|
||||||
|
assert s.dtype == np.float64 or s.dtype == np.float32
|
||||||
|
dim = int(np.prod(s.shape))
|
||||||
|
if type(s) == specs.Array:
|
||||||
|
bound = np.inf * np.ones(dim, dtype=np.float32)
|
||||||
|
return -bound, bound
|
||||||
|
elif type(s) == specs.BoundedArray:
|
||||||
|
zeros = np.zeros(dim, dtype=np.float32)
|
||||||
|
return s.minimum + zeros, s.maximum + zeros
|
||||||
|
|
||||||
|
mins, maxs = [], []
|
||||||
|
for s in spec:
|
||||||
|
mn, mx = extract_min_max(s)
|
||||||
|
mins.append(mn)
|
||||||
|
maxs.append(mx)
|
||||||
|
low = np.concatenate(mins, axis=0)
|
||||||
|
high = np.concatenate(maxs, axis=0)
|
||||||
|
assert low.shape == high.shape
|
||||||
|
return spaces.Box(low, high, dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten_obs(obs):
|
||||||
|
obs_pieces = []
|
||||||
|
for v in obs.values():
|
||||||
|
flat = np.array([v]) if np.isscalar(v) else v.ravel()
|
||||||
|
obs_pieces.append(flat)
|
||||||
|
return np.concatenate(obs_pieces, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
class DMCWrapper(core.Env):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
domain_name,
|
||||||
|
task_name,
|
||||||
|
resource_files,
|
||||||
|
img_source,
|
||||||
|
total_frames,
|
||||||
|
task_kwargs=None,
|
||||||
|
visualize_reward={},
|
||||||
|
from_pixels=False,
|
||||||
|
height=84,
|
||||||
|
width=84,
|
||||||
|
camera_id=0,
|
||||||
|
frame_skip=1,
|
||||||
|
environment_kwargs=None
|
||||||
|
):
|
||||||
|
assert 'random' in task_kwargs, 'please specify a seed for deterministic behaviour'
|
||||||
|
self._from_pixels = from_pixels
|
||||||
|
self._height = height
|
||||||
|
self._width = width
|
||||||
|
self._camera_id = camera_id
|
||||||
|
self._frame_skip = frame_skip
|
||||||
|
self._img_source = img_source
|
||||||
|
|
||||||
|
# create task
|
||||||
|
self._env = suite.load(
|
||||||
|
domain_name=domain_name,
|
||||||
|
task_name=task_name,
|
||||||
|
task_kwargs=task_kwargs,
|
||||||
|
visualize_reward=visualize_reward,
|
||||||
|
environment_kwargs=environment_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# true and normalized action spaces
|
||||||
|
self._true_action_space = _spec_to_box([self._env.action_spec()])
|
||||||
|
self._norm_action_space = spaces.Box(
|
||||||
|
low=-1.0,
|
||||||
|
high=1.0,
|
||||||
|
shape=self._true_action_space.shape,
|
||||||
|
dtype=np.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
# create observation space
|
||||||
|
if from_pixels:
|
||||||
|
self._observation_space = spaces.Box(
|
||||||
|
low=0, high=255, shape=[3, height, width], dtype=np.uint8
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._observation_space = _spec_to_box(
|
||||||
|
self._env.observation_spec().values()
|
||||||
|
)
|
||||||
|
|
||||||
|
self._internal_state_space = spaces.Box(
|
||||||
|
low=-np.inf,
|
||||||
|
high=np.inf,
|
||||||
|
shape=self._env.physics.get_state().shape,
|
||||||
|
dtype=np.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
# background
|
||||||
|
if img_source is not None:
|
||||||
|
shape2d = (height, width)
|
||||||
|
if img_source == "color":
|
||||||
|
self._bg_source = natural_imgsource.RandomColorSource(shape2d)
|
||||||
|
elif img_source == "noise":
|
||||||
|
self._bg_source = natural_imgsource.NoiseSource(shape2d)
|
||||||
|
else:
|
||||||
|
files = glob.glob(os.path.expanduser(resource_files))
|
||||||
|
assert len(files), "Pattern {} does not match any files".format(
|
||||||
|
resource_files
|
||||||
|
)
|
||||||
|
if img_source == "images":
|
||||||
|
self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames)
|
||||||
|
elif img_source == "video":
|
||||||
|
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames)
|
||||||
|
else:
|
||||||
|
raise Exception("img_source %s not defined." % img_source)
|
||||||
|
|
||||||
|
# set seed
|
||||||
|
self.seed(seed=task_kwargs.get('random', 1))
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self._env, name)
|
||||||
|
|
||||||
|
def _get_obs(self, time_step):
|
||||||
|
if self._from_pixels:
|
||||||
|
obs = self.render(
|
||||||
|
height=self._height,
|
||||||
|
width=self._width,
|
||||||
|
camera_id=self._camera_id
|
||||||
|
)
|
||||||
|
if self._img_source is not None:
|
||||||
|
mask = np.logical_and((obs[:, :, 2] > obs[:, :, 1]), (obs[:, :, 2] > obs[:, :, 0])) # hardcoded for dmc
|
||||||
|
bg = self._bg_source.get_image()
|
||||||
|
obs[mask] = bg[mask]
|
||||||
|
obs = obs.transpose(2, 0, 1).copy()
|
||||||
|
else:
|
||||||
|
obs = _flatten_obs(time_step.observation)
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def _convert_action(self, action):
|
||||||
|
action = action.astype(np.float64)
|
||||||
|
true_delta = self._true_action_space.high - self._true_action_space.low
|
||||||
|
norm_delta = self._norm_action_space.high - self._norm_action_space.low
|
||||||
|
action = (action - self._norm_action_space.low) / norm_delta
|
||||||
|
action = action * true_delta + self._true_action_space.low
|
||||||
|
action = action.astype(np.float32)
|
||||||
|
return action
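# The rescaling above is the affine map from the normalized box [-1, 1]^n to the true
# action bounds:
#   a_true = low + (a_norm - (-1)) / 2 * (high - low)
# so the agent can always act in [-1, 1] regardless of the underlying task's limits.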
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_space(self):
|
||||||
|
return self._observation_space
|
||||||
|
|
||||||
|
@property
|
||||||
|
def internal_state_space(self):
|
||||||
|
return self._internal_state_space
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_space(self):
|
||||||
|
return self._norm_action_space
|
||||||
|
|
||||||
|
def seed(self, seed):
|
||||||
|
self._true_action_space.seed(seed)
|
||||||
|
self._norm_action_space.seed(seed)
|
||||||
|
self._observation_space.seed(seed)
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
assert self._norm_action_space.contains(action)
|
||||||
|
action = self._convert_action(action)
|
||||||
|
assert self._true_action_space.contains(action)
|
||||||
|
reward = 0
|
||||||
|
extra = {'internal_state': self._env.physics.get_state().copy()}
|
||||||
|
|
||||||
|
for _ in range(self._frame_skip):
|
||||||
|
time_step = self._env.step(action)
|
||||||
|
reward += time_step.reward or 0
|
||||||
|
done = time_step.last()
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
obs = self._get_obs(time_step)
|
||||||
|
extra['discount'] = time_step.discount
|
||||||
|
return obs, reward, done, extra
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
time_step = self._env.reset()
|
||||||
|
obs = self._get_obs(time_step)
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def render(self, mode='rgb_array', height=None, width=None, camera_id=0):
|
||||||
|
assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
|
||||||
|
height = height or self._height
|
||||||
|
width = width or self._width
|
||||||
|
camera_id = camera_id or self._camera_id
|
||||||
|
return self._env.physics.render(
|
||||||
|
height=height, width=width, camera_id=camera_id
|
||||||
|
)
|
169
encoder.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def tie_weights(src, trg):
|
||||||
|
assert type(src) == type(trg)
|
||||||
|
trg.weight = src.weight
|
||||||
|
trg.bias = src.bias
|
||||||
|
|
||||||
|
|
||||||
|
class PixelEncoder(nn.Module):
|
||||||
|
"""Convolutional encoder of pixels observations."""
|
||||||
|
def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32, stride=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert len(obs_shape) == 3
|
||||||
|
|
||||||
|
self.feature_dim = feature_dim
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]
|
||||||
|
)
|
||||||
|
for i in range(num_layers - 1):
|
||||||
|
self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))
|
||||||
|
|
||||||
|
out_dim = {2: 39, 4: 35, 6: 31}[num_layers]
|
||||||
|
self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim)
|
||||||
|
self.ln = nn.LayerNorm(self.feature_dim)
|
||||||
|
|
||||||
|
self.outputs = dict()
|
||||||
|
|
||||||
|
def reparameterize(self, mu, logstd):
|
||||||
|
std = torch.exp(logstd)
|
||||||
|
eps = torch.randn_like(std)
|
||||||
|
return mu + eps * std
|
||||||
|
|
||||||
|
def forward_conv(self, obs):
|
||||||
|
obs = obs / 255.
|
||||||
|
self.outputs['obs'] = obs
|
||||||
|
|
||||||
|
conv = torch.relu(self.convs[0](obs))
|
||||||
|
self.outputs['conv1'] = conv
|
||||||
|
|
||||||
|
for i in range(1, self.num_layers):
|
||||||
|
conv = torch.relu(self.convs[i](conv))
|
||||||
|
self.outputs['conv%s' % (i + 1)] = conv
|
||||||
|
|
||||||
|
h = conv.view(conv.size(0), -1)
|
||||||
|
return h
|
||||||
|
|
||||||
|
def forward(self, obs, detach=False):
|
||||||
|
h = self.forward_conv(obs)
|
||||||
|
|
||||||
|
if detach:
|
||||||
|
h = h.detach()
|
||||||
|
|
||||||
|
h_fc = self.fc(h)
|
||||||
|
self.outputs['fc'] = h_fc
|
||||||
|
|
||||||
|
out = self.ln(h_fc)
|
||||||
|
self.outputs['ln'] = out
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def copy_conv_weights_from(self, source):
|
||||||
|
"""Tie convolutional layers"""
|
||||||
|
# only tie conv layers
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
tie_weights(src=source.convs[i], trg=self.convs[i])
|
||||||
|
|
||||||
|
def log(self, L, step, log_freq):
|
||||||
|
if step % log_freq != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for k, v in self.outputs.items():
|
||||||
|
L.log_histogram('train_encoder/%s_hist' % k, v, step)
|
||||||
|
if len(v.shape) > 2:
|
||||||
|
L.log_image('train_encoder/%s_img' % k, v[0], step)
|
||||||
|
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step)
|
||||||
|
L.log_param('train_encoder/fc', self.fc, step)
|
||||||
|
L.log_param('train_encoder/ln', self.ln, step)
|
||||||
|
|
||||||
|
|
||||||
|
class PixelEncoderCarla096(PixelEncoder):
|
||||||
|
"""Convolutional encoder of pixels observations."""
|
||||||
|
def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32, stride=1):
|
||||||
|
super(PixelEncoder, self).__init__()
|
||||||
|
|
||||||
|
assert len(obs_shape) == 3
|
||||||
|
|
||||||
|
self.feature_dim = feature_dim
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]
|
||||||
|
)
|
||||||
|
for i in range(num_layers - 1):
|
||||||
|
self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=stride))
|
||||||
|
|
||||||
|
out_dims = 100 # if defaults change, adjust this as needed
|
||||||
|
self.fc = nn.Linear(num_filters * out_dims, self.feature_dim)
|
||||||
|
self.ln = nn.LayerNorm(self.feature_dim)
|
||||||
|
|
||||||
|
self.outputs = dict()
|
||||||
|
|
||||||
|
|
||||||
|
class PixelEncoderCarla098(PixelEncoder):
|
||||||
|
"""Convolutional encoder of pixels observations."""
|
||||||
|
def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32, stride=1):
|
||||||
|
super(PixelEncoder, self).__init__()
|
||||||
|
|
||||||
|
assert len(obs_shape) == 3
|
||||||
|
|
||||||
|
self.feature_dim = feature_dim
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList()
|
||||||
|
self.convs.append(nn.Conv2d(obs_shape[0], 64, 5, stride=2))
|
||||||
|
self.convs.append(nn.Conv2d(64, 128, 3, stride=2))
|
||||||
|
self.convs.append(nn.Conv2d(128, 256, 3, stride=2))
|
||||||
|
self.convs.append(nn.Conv2d(256, 256, 3, stride=2))
|
||||||
|
|
||||||
|
out_dims = 56 # 3 cameras
|
||||||
|
# out_dims = 100 # 5 cameras
|
||||||
|
self.fc = nn.Linear(256 * out_dims, self.feature_dim)
|
||||||
|
self.ln = nn.LayerNorm(self.feature_dim)
|
||||||
|
|
||||||
|
self.outputs = dict()
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityEncoder(nn.Module):
|
||||||
|
def __init__(self, obs_shape, feature_dim, num_layers, num_filters):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert len(obs_shape) == 1
|
||||||
|
self.feature_dim = obs_shape[0]
|
||||||
|
|
||||||
|
def forward(self, obs, detach=False):
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def copy_conv_weights_from(self, source):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def log(self, L, step, log_freq):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_AVAILABLE_ENCODERS = {'pixel': PixelEncoder,
|
||||||
|
'pixelCarla096': PixelEncoderCarla096,
|
||||||
|
'pixelCarla098': PixelEncoderCarla098,
|
||||||
|
'identity': IdentityEncoder}
|
||||||
|
|
||||||
|
|
||||||
|
def make_encoder(
|
||||||
|
encoder_type, obs_shape, feature_dim, num_layers, num_filters, stride
|
||||||
|
):
|
||||||
|
assert encoder_type in _AVAILABLE_ENCODERS
|
||||||
|
return _AVAILABLE_ENCODERS[encoder_type](
|
||||||
|
obs_shape, feature_dim, num_layers, num_filters, stride
|
||||||
|
)
|
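As a rough orientation, here is a minimal usage sketch of the encoder factory above. The observation shape, `feature_dim`, and layer counts are illustrative assumptions, not values fixed by this file; the only constraint baked into `PixelEncoder` is that its hard-coded `out_dim` table corresponds to 84x84 inputs.

```python
import torch

# Illustrative shapes: 3 stacked RGB frames at 84x84 (an assumption, not fixed by this file).
obs_shape = (9, 84, 84)
critic_enc = make_encoder('pixel', obs_shape, feature_dim=50,
                          num_layers=4, num_filters=32, stride=1)
actor_enc = make_encoder('pixel', obs_shape, feature_dim=50,
                         num_layers=4, num_filters=32, stride=1)

# Share the convolutional trunk between the two encoders; fc/ln stay separate.
actor_enc.copy_conv_weights_from(critic_enc)

obs = torch.randint(0, 256, (8,) + obs_shape).float()
z = critic_enc(obs)                      # shape (8, 50)
z_no_grad = actor_enc(obs, detach=True)  # no gradients flow back into the conv trunk
```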
313  graph_utils.py  Normal file
@@ -0,0 +1,313 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import matplotlib.pyplot as plt
import os
import json
import re
import scipy.interpolate


def read_log_file(file_name, key_name, value_name, smooth=3):
    keys, values = [], []
    try:
        with open(file_name, 'r') as f:
            for line in f:
                try:
                    e = json.loads(line.strip())
                    key, value = e[key_name], e[value_name]
                    keys.append(int(key))
                    values.append(float(value))
                except:
                    pass
    except:
        print('bad file: %s' % file_name)
        return None, None
    keys, values = np.array(keys), np.array(values)
    if smooth > 1 and values.shape[0] > 0:
        K = np.ones(smooth)
        ones = np.ones(values.shape[0])
        values = np.convolve(values, K, 'same') / np.convolve(ones, K, 'same')

    return keys, values


def parse_log_files(
    file_name_template,
    key_name,
    value_name,
    num_seeds,
    smooth,
    best_k=None,
    max_key=True
):
    all_values = []
    all_keys = []
    actual_keys = None
    for seed in range(1, num_seeds + 1):
        file_name = file_name_template % seed
        keys, values = read_log_file(file_name, key_name, value_name, smooth)
        if keys is None or keys.shape[0] == 0:
            continue
        all_keys.append(keys)
        all_values.append(values)

    if len(all_values) == 0:
        return None, None, None

    all_keys_tmp = sorted(all_keys, key=lambda x: x[-1])
    keys = all_keys_tmp[-1] if max_key else all_keys_tmp[0]
    threshold = keys.shape[0]

    # interpolate
    for idx, (key, value) in enumerate(zip(all_keys, all_values)):
        f = scipy.interpolate.interp1d(key, value, fill_value='extrapolate')
        all_keys[idx] = keys
        all_values[idx] = f(keys)

    means, half_stds = [], []
    for i in range(threshold):
        vals = []

        for v in all_values:
            if i < v.shape[0]:
                vals.append(v[i])
        if best_k is not None:
            vals = sorted(vals)[-best_k:]
        means.append(np.mean(vals))
        half_stds.append(0.5 * np.std(vals))

    means = np.array(means)
    half_stds = np.array(half_stds)

    keys = all_keys[-1][:threshold]
    assert means.shape[0] == keys.shape[0]

    print(file_name_template, means[-1])
    return keys, means, half_stds
    # return all_keys, all_values


def print_result(
    root,
    title,
    label=None,
    num_seeds=1,
    smooth=3,
    train=False,
    key_name='step',
    value_name='episode_reward',
    max_time=None,
    best_k=None,
    timescale=1,
    max_key=False
):
    file_name = 'train.log' if train else 'eval.log'
    file_name_template = os.path.join(root, 'seed_%d', file_name)
    keys, means, half_stds = parse_log_files(
        file_name_template,
        key_name,
        value_name,
        num_seeds,
        smooth=smooth,
        best_k=best_k,
        max_key=max_key
    )
    label = label or root.split('/')[-1]
    if keys is None:
        return

    if max_time is not None:
        idxs = np.where(keys <= max_time)
        keys = keys[idxs]
        means = means[idxs]
        half_stds = half_stds[idxs]

    keys *= timescale

    plt.plot(keys, means, label=label)
    plt.locator_params(nbins=10, axis='x')
    plt.locator_params(nbins=10, axis='y')
    plt.rcParams['figure.figsize'] = (10, 7)
    plt.rcParams['figure.dpi'] = 100
    plt.rcParams['font.size'] = 10
    plt.subplots_adjust(left=0.165, right=0.99, bottom=0.16, top=0.95)
    #plt.ylim(0, 1050)
    plt.tight_layout()

    plt.grid(alpha=0.8)
    plt.title(title)
    plt.fill_between(keys, means - half_stds, means + half_stds, alpha=0.2)
    plt.legend(loc='lower right', prop={
        'size': 6
    }).get_frame().set_edgecolor('0.1')
    plt.xlabel(key_name)
    plt.ylabel(value_name)


def plot_seeds(
    task,
    exp_query,
    root,
    train=True,
    smooth=3,
    key_name='step',
    value_name='episode_reward',
    num_seeds=10
):
    # root = os.path.join(root, task)
    experiment = None
    for exp in os.listdir(root):
        if re.match(exp_query, exp):
            experiment = os.path.join(root, exp)
            break
    if experiment is None:
        return
    file_name = 'train.log' if train else 'eval.log'
    file_name_template = os.path.join(experiment, 'seed_%d', file_name)

    plt.locator_params(nbins=10, axis='x')
    plt.locator_params(nbins=10, axis='y')
    plt.rcParams['figure.figsize'] = (10, 7)
    plt.rcParams['figure.dpi'] = 100
    plt.rcParams['font.size'] = 10
    plt.subplots_adjust(left=0.165, right=0.99, bottom=0.16, top=0.95)
    plt.grid(alpha=0.8)
    plt.tight_layout()
    plt.title(task)

    plt.xlabel(key_name)
    plt.ylabel(value_name)

    for seed in range(1, num_seeds + 1):
        file_name = file_name_template % seed
        keys, values = read_log_file(file_name, key_name, value_name, smooth=smooth)
        if keys is None or keys.shape[0] == 0:
            continue

        plt.plot(keys, values, label='seed_%d' % seed, linewidth=0.5)

    plt.legend(loc='lower right', prop={
        'size': 6
    }).get_frame().set_edgecolor('0.1')


def print_baseline(task, baseline, data, color):
    try:
        value = data[task][baseline]
    except:
        return

    plt.axhline(y=value, label=baseline, linestyle='--', color=color)
    plt.legend(loc='lower right', prop={
        'size': 6
    }).get_frame().set_edgecolor('0.1')


def print_planet_baseline(
    task, data, max_time=None, label='planet', color='black', offset=0
):
    try:
        keys, means, half_stds = data[task]
    except:
        return

    if max_time is not None:
        idx = np.searchsorted(keys, max_time)
        keys = keys[:idx]
        means = means[:idx]
        half_stds = half_stds[:idx]

    plt.plot(keys + offset, means, label=label, color=color)
    plt.fill_between(
        keys + offset,
        means - half_stds,
        means + half_stds,
        alpha=0.2,
        color=color
    )
    plt.legend(loc='lower right', prop={
        'size': 6
    }).get_frame().set_edgecolor('0.1')


def plot_experiment(
    task,
    exp_query,
    neg_exp_query=None,
    root='runs',
    exp_ids=None,
    smooth=3,
    train=False,
    key_name='step',
    value_name='eval_episode_reward',
    baselines_data=None,
    num_seeds=10,
    planet_data=None,
    slac_data=None,
    max_time=None,
    best_k=None,
    timescale=1,
    max_key=False
):
    root = os.path.join(root, task)

    experiments = set()
    for exp in os.listdir(root):
        if re.match(exp_query, exp) and (neg_exp_query is None or re.match(neg_exp_query, exp) is None):
            exp = os.path.join(root, exp)
            experiments.add(exp)

    exp_ids = list(range(len(experiments))) if exp_ids is None else exp_ids
    for exp_id, exp in enumerate(sorted(experiments)):
        if exp_id in exp_ids:
            print_result(
                exp,
                task,
                smooth=smooth,
                num_seeds=num_seeds,
                train=train,
                key_name=key_name,
                value_name=value_name,
                max_time=max_time,
                best_k=best_k,
                timescale=timescale,
                max_key=max_key
            )

    if baselines_data is not None:
        print_baseline(task, 'd4pg_pixels', baselines_data, color='gray')
        print_baseline(task, 'd4pg', baselines_data, color='black')

    if planet_data is not None:
        print_planet_baseline(
            task,
            planet_data,
            max_time=max_time,
            label='planet',
            color='peru',
            offset=5
        )

    if slac_data is not None:
        action_repeat = {
            'ball_in_cup_catch': 4,
            'cartpole_swingup': 8,
            'cheetah_run': 4,
            'finger_spin': 2,
            'walker_walk': 2,
            'reacher_easy': 4
        }
        offset = 10 * action_repeat[task]
        print_planet_baseline(
            task,
            slac_data,
            max_time=max_time,
            label='slac',
            color='black',
            offset=offset
        )
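A brief usage sketch of the plotting helpers above. The directory layout and file path below are hypothetical; what the code actually requires is one `seed_<i>/eval.log` (or `train.log`) per seed under the experiment directory, with one JSON object per line containing the chosen key and value fields.

```python
import matplotlib.pyplot as plt

# Hypothetical layout: runs/cheetah_run/my_experiment/seed_1/eval.log, seed_2/..., where each
# line is a JSON record such as {"step": 10000, "episode_reward": 312.5}.
print_result(
    root='runs/cheetah_run/my_experiment',
    title='cheetah_run',
    num_seeds=3,
    smooth=3,
    value_name='episode_reward',
)
plt.savefig('cheetah_run.png')
```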
56  local_dm_control_suite/README.md  Executable file
@@ -0,0 +1,56 @@
# DeepMind Control Suite.

This submodule contains the domains and tasks described in the
[DeepMind Control Suite tech report](https://arxiv.org/abs/1801.00690).

## Quickstart

```python
from dm_control import suite
import numpy as np

# Load one task:
env = suite.load(domain_name="cartpole", task_name="swingup")

# Iterate over a task set:
for domain_name, task_name in suite.BENCHMARKING:
  env = suite.load(domain_name, task_name)

# Step through an episode and print out reward, discount and observation.
action_spec = env.action_spec()
time_step = env.reset()
while not time_step.last():
  action = np.random.uniform(action_spec.minimum,
                             action_spec.maximum,
                             size=action_spec.shape)
  time_step = env.step(action)
  print(time_step.reward, time_step.discount, time_step.observation)
```

## Illustration video

Below is a video montage of solved Control Suite tasks, with reward
visualisation enabled.

[![Video montage](https://img.youtube.com/vi/rAai4QzcYbs/0.jpg)](https://www.youtube.com/watch?v=rAai4QzcYbs)

### Quadruped domain [April 2019]

Roughly based on the 'ant' model introduced by [Schulman et al. 2015](https://arxiv.org/abs/1506.02438). Main modifications to the body are:

- 4 DoFs per leg, 1 constraining tendon.
- 3 actuators per leg: 'yaw', 'lift', 'extend'.
- Filtered position actuators with timescale of 100ms.
- Sensors include an IMU, force/torque sensors, and rangefinders.

Four tasks:

- `walk` and `run`: self-right the body then move forward at a desired speed.
- `escape`: escape a bowl-shaped random terrain (uses rangefinders).
- `fetch`: go to a moving ball and bring it to a target.

All behaviors in the video below were trained with [Abdolmaleki et al's
MPO](https://arxiv.org/abs/1806.06920).

[![Video montage](https://img.youtube.com/vi/RhRLjbb7pBE/0.jpg)](https://www.youtube.com/watch?v=RhRLjbb7pBE)
151  local_dm_control_suite/__init__.py  Executable file
@@ -0,0 +1,151 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""A collection of MuJoCo-based Reinforcement Learning environments."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import inspect
import itertools

from dm_control.rl import control

from local_dm_control_suite import acrobot
from local_dm_control_suite import ball_in_cup
from local_dm_control_suite import cartpole
from local_dm_control_suite import cheetah
from local_dm_control_suite import finger
from local_dm_control_suite import fish
from local_dm_control_suite import hopper
from local_dm_control_suite import humanoid
from local_dm_control_suite import humanoid_CMU
from local_dm_control_suite import lqr
from local_dm_control_suite import manipulator
from local_dm_control_suite import pendulum
from local_dm_control_suite import point_mass
from local_dm_control_suite import quadruped
from local_dm_control_suite import reacher
from local_dm_control_suite import stacker
from local_dm_control_suite import swimmer
from local_dm_control_suite import walker

# Find all domains imported.
_DOMAINS = {name: module for name, module in locals().items()
            if inspect.ismodule(module) and hasattr(module, 'SUITE')}


def _get_tasks(tag):
    """Returns a sequence of (domain name, task name) pairs for the given tag."""
    result = []

    for domain_name in sorted(_DOMAINS.keys()):
        domain = _DOMAINS[domain_name]

        if tag is None:
            tasks_in_domain = domain.SUITE
        else:
            tasks_in_domain = domain.SUITE.tagged(tag)

        for task_name in tasks_in_domain.keys():
            result.append((domain_name, task_name))

    return tuple(result)


def _get_tasks_by_domain(tasks):
    """Returns a dict mapping from domain name to a tuple of its task names."""
    result = collections.defaultdict(list)

    for domain_name, task_name in tasks:
        result[domain_name].append(task_name)

    return {k: tuple(v) for k, v in result.items()}


# A sequence containing all (domain name, task name) pairs.
ALL_TASKS = _get_tasks(tag=None)

# Subsets of ALL_TASKS, generated via the tag mechanism.
BENCHMARKING = _get_tasks('benchmarking')
EASY = _get_tasks('easy')
HARD = _get_tasks('hard')
EXTRA = tuple(sorted(set(ALL_TASKS) - set(BENCHMARKING)))

# A mapping from each domain name to a sequence of its task names.
TASKS_BY_DOMAIN = _get_tasks_by_domain(ALL_TASKS)


def load(domain_name, task_name, task_kwargs=None, environment_kwargs=None,
         visualize_reward=False):
    """Returns an environment from a domain name, task name and optional settings.

    ```python
    env = suite.load('cartpole', 'balance')
    ```

    Args:
        domain_name: A string containing the name of a domain.
        task_name: A string containing the name of a task.
        task_kwargs: Optional `dict` of keyword arguments for the task.
        environment_kwargs: Optional `dict` specifying keyword arguments for the
            environment.
        visualize_reward: Optional `bool`. If `True`, object colours in rendered
            frames are set to indicate the reward at each step. Default `False`.

    Returns:
        The requested environment.
    """
    return build_environment(domain_name, task_name, task_kwargs,
                             environment_kwargs, visualize_reward)


def build_environment(domain_name, task_name, task_kwargs=None,
                      environment_kwargs=None, visualize_reward=False):
    """Returns an environment from the suite given a domain name and a task name.

    Args:
        domain_name: A string containing the name of a domain.
        task_name: A string containing the name of a task.
        task_kwargs: Optional `dict` specifying keyword arguments for the task.
        environment_kwargs: Optional `dict` specifying keyword arguments for the
            environment.
        visualize_reward: Optional `bool`. If `True`, object colours in rendered
            frames are set to indicate the reward at each step. Default `False`.

    Raises:
        ValueError: If the domain or task doesn't exist.

    Returns:
        An instance of the requested environment.
    """
    if domain_name not in _DOMAINS:
        raise ValueError('Domain {!r} does not exist.'.format(domain_name))

    domain = _DOMAINS[domain_name]

    if task_name not in domain.SUITE:
        raise ValueError('Level {!r} does not exist in domain {!r}.'.format(
            task_name, domain_name))

    task_kwargs = task_kwargs or {}
    if environment_kwargs is not None:
        task_kwargs = task_kwargs.copy()
        task_kwargs['environment_kwargs'] = environment_kwargs
    env = domain.SUITE[task_name](**task_kwargs)
    env.task.visualize_reward = visualize_reward
    return env
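A small sketch of how the loader above is typically called; the keyword values are illustrative. `task_kwargs` is forwarded to the task constructor registered in the domain's `SUITE`, and `environment_kwargs` is forwarded to `control.Environment`.

```python
import local_dm_control_suite as suite

# Seed the task and forward a setting to control.Environment (values are illustrative).
env = suite.load('cartpole', 'swingup',
                 task_kwargs={'random': 1},
                 environment_kwargs={'flat_observation': True},
                 visualize_reward=True)

time_step = env.reset()
print(time_step.observation)
```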
127  local_dm_control_suite/acrobot.py  Executable file
@@ -0,0 +1,127 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Acrobot domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
import numpy as np

_DEFAULT_TIME_LIMIT = 10
SUITE = containers.TaggedTasks()


def get_model_and_assets():
    """Returns a tuple containing the model XML string and a dict of assets."""
    return common.read_model('acrobot.xml'), common.ASSETS


@SUITE.add('benchmarking')
def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
            environment_kwargs=None):
    """Returns Acrobot balance task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Balance(sparse=False, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(
        physics, task, time_limit=time_limit, **environment_kwargs)


@SUITE.add('benchmarking')
def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
                   environment_kwargs=None):
    """Returns Acrobot sparse balance."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Balance(sparse=True, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(
        physics, task, time_limit=time_limit, **environment_kwargs)


class Physics(mujoco.Physics):
    """Physics simulation with additional features for the Acrobot domain."""

    def horizontal(self):
        """Returns horizontal (x) component of body frame z-axes."""
        return self.named.data.xmat[['upper_arm', 'lower_arm'], 'xz']

    def vertical(self):
        """Returns vertical (z) component of body frame z-axes."""
        return self.named.data.xmat[['upper_arm', 'lower_arm'], 'zz']

    def to_target(self):
        """Returns the distance from the tip to the target."""
        tip_to_target = (self.named.data.site_xpos['target'] -
                         self.named.data.site_xpos['tip'])
        return np.linalg.norm(tip_to_target)

    def orientations(self):
        """Returns the sines and cosines of the pole angles."""
        return np.concatenate((self.horizontal(), self.vertical()))


class Balance(base.Task):
    """An Acrobot `Task` to swing up and balance the pole."""

    def __init__(self, sparse, random=None):
        """Initializes an instance of `Balance`.

        Args:
            sparse: A `bool` specifying whether to use a sparse (indicator) reward.
            random: Optional, either a `numpy.random.RandomState` instance, an
                integer seed for creating a new `RandomState`, or None to select a seed
                automatically (default).
        """
        self._sparse = sparse
        super(Balance, self).__init__(random=random)

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode.

        Shoulder and elbow are set to a random position between [-pi, pi).

        Args:
            physics: An instance of `Physics`.
        """
        physics.named.data.qpos[
            ['shoulder', 'elbow']] = self.random.uniform(-np.pi, np.pi, 2)
        super(Balance, self).initialize_episode(physics)

    def get_observation(self, physics):
        """Returns an observation of pole orientation and angular velocities."""
        obs = collections.OrderedDict()
        obs['orientations'] = physics.orientations()
        obs['velocity'] = physics.velocity()
        return obs

    def _get_reward(self, physics, sparse):
        target_radius = physics.named.model.site_size['target', 0]
        return rewards.tolerance(physics.to_target(),
                                 bounds=(0, target_radius),
                                 margin=0 if sparse else 1)

    def get_reward(self, physics):
        """Returns a sparse or a smooth reward, as specified in the constructor."""
        return self._get_reward(physics, sparse=self._sparse)
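To see what the `margin` switch in `Balance._get_reward` does, here is a small sketch evaluating `rewards.tolerance` by hand. The target radius of 0.2 matches the `target` site size in `acrobot.xml`; with `margin=0` the reward is a 0/1 indicator, while `margin=1` decays smoothly once the tip leaves the target.

```python
from dm_control.utils import rewards

target_radius = 0.2  # site_size['target', 0] in acrobot.xml
for dist in (0.0, 0.1, 0.5, 1.0):
    sparse = rewards.tolerance(dist, bounds=(0, target_radius), margin=0)
    smooth = rewards.tolerance(dist, bounds=(0, target_radius), margin=1)
    print(dist, sparse, smooth)  # sparse is 1 inside the target and 0 outside
```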
43  local_dm_control_suite/acrobot.xml  Executable file
@@ -0,0 +1,43 @@
<!--
Based on Coulom's [1] rather than Spong's [2] model.
[1] Coulom, Rémi. Reinforcement learning using neural networks, with applications to motor control.
    Diss. Institut National Polytechnique de Grenoble-INPG, 2002.
[2] Spong, Mark W. "The swing up control problem for the acrobot."
    IEEE control systems 15, no. 1 (1995): 49-55.
-->
<mujoco model="acrobot">
  <include file="./common/visual.xml"/>
  <include file="./common/skybox.xml"/>
  <include file="./common/materials.xml"/>

  <default>
    <joint damping=".05"/>
    <geom type="capsule" mass="1"/>
  </default>

  <option timestep="0.01" integrator="RK4">
    <flag constraint="disable" energy="enable"/>
  </option>

  <worldbody>
    <light name="light" pos="0 0 6"/>
    <geom name="floor" size="3 3 .2" type="plane" material="grid"/>
    <site name="target" type="sphere" pos="0 0 4" size="0.2" material="target" group="3"/>
    <camera name="fixed" pos="0 -6 2" zaxis="0 -1 0"/>
    <camera name="lookat" mode="targetbodycom" target="upper_arm" pos="0 -2 3"/>
    <body name="upper_arm" pos="0 0 2">
      <joint name="shoulder" type="hinge" axis="0 1 0"/>
      <geom name="upper_arm_decoration" material="decoration" type="cylinder" fromto="0 -.06 0 0 .06 0" size="0.051" mass="0"/>
      <geom name="upper_arm" fromto="0 0 0 0 0 1" size="0.05" material="self"/>
      <body name="lower_arm" pos="0 0 1">
        <joint name="elbow" type="hinge" axis="0 1 0"/>
        <geom name="lower_arm" fromto="0 0 0 0 0 1" size="0.049" material="self"/>
        <site name="tip" pos="0 0 1" size="0.01"/>
      </body>
    </body>
  </worldbody>

  <actuator>
    <motor name="elbow" joint="elbow" gear="2" ctrllimited="true" ctrlrange="-1 1"/>
  </actuator>
</mujoco>
100  local_dm_control_suite/ball_in_cup.py  Executable file
@@ -0,0 +1,100 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Ball-in-Cup Domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.utils import containers

_DEFAULT_TIME_LIMIT = 20  # (seconds)
_CONTROL_TIMESTEP = .02  # (seconds)


SUITE = containers.TaggedTasks()


def get_model_and_assets():
    """Returns a tuple containing the model XML string and a dict of assets."""
    return common.read_model('ball_in_cup.xml'), common.ASSETS


@SUITE.add('benchmarking', 'easy')
def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Ball-in-Cup task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = BallInCup(random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(
        physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
        **environment_kwargs)


class Physics(mujoco.Physics):
    """Physics with additional features for the Ball-in-Cup domain."""

    def ball_to_target(self):
        """Returns the vector from the ball to the target."""
        target = self.named.data.site_xpos['target', ['x', 'z']]
        ball = self.named.data.xpos['ball', ['x', 'z']]
        return target - ball

    def in_target(self):
        """Returns 1 if the ball is in the target, 0 otherwise."""
        ball_to_target = abs(self.ball_to_target())
        target_size = self.named.model.site_size['target', [0, 2]]
        ball_size = self.named.model.geom_size['ball', 0]
        return float(all(ball_to_target < target_size - ball_size))


class BallInCup(base.Task):
    """The Ball-in-Cup task. Put the ball in the cup."""

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode.

        Args:
            physics: An instance of `Physics`.
        """
        # Find a collision-free random initial position of the ball.
        penetrating = True
        while penetrating:
            # Assign a random ball position.
            physics.named.data.qpos['ball_x'] = self.random.uniform(-.2, .2)
            physics.named.data.qpos['ball_z'] = self.random.uniform(.2, .5)
            # Check for collisions.
            physics.after_reset()
            penetrating = physics.data.ncon > 0
        super(BallInCup, self).initialize_episode(physics)

    def get_observation(self, physics):
        """Returns an observation of the state."""
        obs = collections.OrderedDict()
        obs['position'] = physics.position()
        obs['velocity'] = physics.velocity()
        return obs

    def get_reward(self, physics):
        """Returns a sparse reward."""
        return physics.in_target()
54  local_dm_control_suite/ball_in_cup.xml  Executable file
@@ -0,0 +1,54 @@
<mujoco model="ball in cup">

  <include file="./common/visual.xml"/>
  <include file="./common/skybox.xml"/>
  <include file="./common/materials.xml"/>

  <default>
    <motor ctrllimited="true" ctrlrange="-1 1" gear="5"/>
    <default class="cup">
      <joint type="slide" damping="3" stiffness="20"/>
      <geom type="capsule" size=".008" material="self"/>
    </default>
  </default>

  <worldbody>
    <light name="light" directional="true" diffuse=".6 .6 .6" pos="0 0 2" specular=".3 .3 .3"/>
    <geom name="ground" type="plane" pos="0 0 0" size=".6 .2 10" material="grid"/>
    <camera name="cam0" pos="0 -1 .8" xyaxes="1 0 0 0 1 2"/>
    <camera name="cam1" pos="0 -1 .4" xyaxes="1 0 0 0 0 1" />

    <body name="cup" pos="0 0 .6" childclass="cup">
      <joint name="cup_x" axis="1 0 0"/>
      <joint name="cup_z" axis="0 0 1"/>
      <geom name="cup_part_0" fromto="-.05 0 0 -.05 0 -.075" />
      <geom name="cup_part_1" fromto="-.05 0 -.075 -.025 0 -.1" />
      <geom name="cup_part_2" fromto="-.025 0 -.1 .025 0 -.1" />
      <geom name="cup_part_3" fromto=".025 0 -.1 .05 0 -.075" />
      <geom name="cup_part_4" fromto=".05 0 -.075 .05 0 0" />
      <site name="cup" pos="0 0 -.108" size=".005"/>
      <site name="target" type="box" pos="0 0 -.05" size=".05 .006 .05" group="4"/>
    </body>

    <body name="ball" pos="0 0 .2">
      <joint name="ball_x" type="slide" axis="1 0 0"/>
      <joint name="ball_z" type="slide" axis="0 0 1"/>
      <geom name="ball" type="sphere" size=".025" material="effector"/>
      <site name="ball" size=".005"/>
    </body>
  </worldbody>

  <actuator>
    <motor name="x" joint="cup_x"/>
    <motor name="z" joint="cup_z"/>
  </actuator>

  <tendon>
    <spatial name="string" limited="true" range="0 0.3" width="0.003">
      <site site="ball"/>
      <site site="cup"/>
    </spatial>
  </tendon>

</mujoco>
112  local_dm_control_suite/base.py  Executable file
@@ -0,0 +1,112 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Base class for tasks in the Control Suite."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from dm_control import mujoco
from dm_control.rl import control

import numpy as np


class Task(control.Task):
    """Base class for tasks in the Control Suite.

    Actions are mapped directly to the states of MuJoCo actuators: each element of
    the action array is used to set the control input for a single actuator. The
    ordering of the actuators is the same as in the corresponding MJCF XML file.

    Attributes:
        random: A `numpy.random.RandomState` instance. This should be used to
            generate all random variables associated with the task, such as random
            starting states, observation noise* etc.

        *If sensor noise is enabled in the MuJoCo model then this will be generated
        using MuJoCo's internal RNG, which has its own independent state.
    """

    def __init__(self, random=None):
        """Initializes a new continuous control task.

        Args:
            random: Optional, either a `numpy.random.RandomState` instance, an integer
                seed for creating a new `RandomState`, or None to select a seed
                automatically (default).
        """
        if not isinstance(random, np.random.RandomState):
            random = np.random.RandomState(random)
        self._random = random
        self._visualize_reward = False

    @property
    def random(self):
        """Task-specific `numpy.random.RandomState` instance."""
        return self._random

    def action_spec(self, physics):
        """Returns a `BoundedArraySpec` matching the `physics` actuators."""
        return mujoco.action_spec(physics)

    def initialize_episode(self, physics):
        """Resets geom colors to their defaults after starting a new episode.

        Subclasses of `base.Task` must delegate to this method after performing
        their own initialization.

        Args:
            physics: An instance of `mujoco.Physics`.
        """
        self.after_step(physics)

    def before_step(self, action, physics):
        """Sets the control signal for the actuators to values in `action`."""
        # Support legacy internal code.
        action = getattr(action, "continuous_actions", action)
        physics.set_control(action)

    def after_step(self, physics):
        """Modifies colors according to the reward."""
        if self._visualize_reward:
            reward = np.clip(self.get_reward(physics), 0.0, 1.0)
            _set_reward_colors(physics, reward)

    @property
    def visualize_reward(self):
        return self._visualize_reward

    @visualize_reward.setter
    def visualize_reward(self, value):
        if not isinstance(value, bool):
            raise ValueError("Expected a boolean, got {}.".format(type(value)))
        self._visualize_reward = value


_MATERIALS = ["self", "effector", "target"]
_DEFAULT = [name + "_default" for name in _MATERIALS]
_HIGHLIGHT = [name + "_highlight" for name in _MATERIALS]


def _set_reward_colors(physics, reward):
    """Sets the highlight, effector and target colors according to the reward."""
    assert 0.0 <= reward <= 1.0
    colors = physics.named.model.mat_rgba
    default = colors[_DEFAULT]
    highlight = colors[_HIGHLIGHT]
    blend_coef = reward ** 4  # Better color distinction near high rewards.
    colors[_MATERIALS] = blend_coef * highlight + (1.0 - blend_coef) * default
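For orientation, here is a hypothetical minimal subclass showing the hooks `Task` expects from the domains in this suite: seed-controlled initialization that delegates back to `base.Task`, an ordered observation dict, and a scalar reward. This is a sketch only; no such task exists in the suite.

```python
import collections

class KeepStill(Task):  # hypothetical example task
    """Toy task: reward keeping the first joint close to zero."""

    def initialize_episode(self, physics):
        # Use the task's own RandomState so episodes are reproducible per seed.
        physics.data.qpos[:] = self.random.uniform(-0.1, 0.1, physics.model.nq)
        super(KeepStill, self).initialize_episode(physics)  # required delegation

    def get_observation(self, physics):
        obs = collections.OrderedDict()
        obs['position'] = physics.position()
        obs['velocity'] = physics.velocity()
        return obs

    def get_reward(self, physics):
        return float(abs(physics.data.qpos[0]) < 0.05)
```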
230  local_dm_control_suite/cartpole.py  Executable file
@@ -0,0 +1,230 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Cartpole domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
from lxml import etree
import numpy as np
from six.moves import range


_DEFAULT_TIME_LIMIT = 10
SUITE = containers.TaggedTasks()


def get_model_and_assets(num_poles=1):
    """Returns a tuple containing the model XML string and a dict of assets."""
    return _make_model(num_poles), common.ASSETS


@SUITE.add('benchmarking')
def balance(time_limit=_DEFAULT_TIME_LIMIT, random=None,
            environment_kwargs=None):
    """Returns the Cartpole Balance task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Balance(swing_up=False, sparse=False, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(
        physics, task, time_limit=time_limit, **environment_kwargs)


@SUITE.add('benchmarking')
def balance_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
                   environment_kwargs=None):
    """Returns the sparse reward variant of the Cartpole Balance task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Balance(swing_up=False, sparse=True, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(
        physics, task, time_limit=time_limit, **environment_kwargs)


@SUITE.add('benchmarking')
def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
            environment_kwargs=None):
    """Returns the Cartpole Swing-Up task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Balance(swing_up=True, sparse=False, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(
        physics, task, time_limit=time_limit, **environment_kwargs)


@SUITE.add('benchmarking')
def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
                   environment_kwargs=None):
    """Returns the sparse reward variant of the Cartpole Swing-Up task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Balance(swing_up=True, sparse=True, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(
        physics, task, time_limit=time_limit, **environment_kwargs)


@SUITE.add()
def two_poles(time_limit=_DEFAULT_TIME_LIMIT, random=None,
              environment_kwargs=None):
    """Returns the Cartpole Balance task with two poles."""
    physics = Physics.from_xml_string(*get_model_and_assets(num_poles=2))
    task = Balance(swing_up=True, sparse=False, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(
        physics, task, time_limit=time_limit, **environment_kwargs)


@SUITE.add()
def three_poles(time_limit=_DEFAULT_TIME_LIMIT, random=None, num_poles=3,
                sparse=False, environment_kwargs=None):
    """Returns the Cartpole Balance task with three or more poles."""
    physics = Physics.from_xml_string(*get_model_and_assets(num_poles=num_poles))
    task = Balance(swing_up=True, sparse=sparse, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(
        physics, task, time_limit=time_limit, **environment_kwargs)


def _make_model(n_poles):
    """Generates an xml string defining a cart with `n_poles` bodies."""
    xml_string = common.read_model('cartpole.xml')
    if n_poles == 1:
        return xml_string
    mjcf = etree.fromstring(xml_string)
    parent = mjcf.find('./worldbody/body/body')  # Find first pole.
    # Make chain of poles.
    for pole_index in range(2, n_poles+1):
        child = etree.Element('body', name='pole_{}'.format(pole_index),
                              pos='0 0 1', childclass='pole')
        etree.SubElement(child, 'joint', name='hinge_{}'.format(pole_index))
        etree.SubElement(child, 'geom', name='pole_{}'.format(pole_index))
        parent.append(child)
        parent = child
    # Move plane down.
    floor = mjcf.find('./worldbody/geom')
    floor.set('pos', '0 0 {}'.format(1 - n_poles - .05))
    # Move cameras back.
    cameras = mjcf.findall('./worldbody/camera')
    cameras[0].set('pos', '0 {} 1'.format(-1 - 2*n_poles))
    cameras[1].set('pos', '0 {} 2'.format(-2*n_poles))
    return etree.tostring(mjcf, pretty_print=True)


class Physics(mujoco.Physics):
    """Physics simulation with additional features for the Cartpole domain."""

    def cart_position(self):
        """Returns the position of the cart."""
        return self.named.data.qpos['slider'][0]

    def angular_vel(self):
        """Returns the angular velocity of the pole."""
        return self.data.qvel[1:]

    def pole_angle_cosine(self):
        """Returns the cosine of the pole angle."""
        return self.named.data.xmat[2:, 'zz']

    def bounded_position(self):
        """Returns the state, with pole angle split into sin/cos."""
        return np.hstack((self.cart_position(),
                          self.named.data.xmat[2:, ['zz', 'xz']].ravel()))


class Balance(base.Task):
    """A Cartpole `Task` to balance the pole.

    State is initialized either close to the target configuration or at a random
    configuration.
    """
    _CART_RANGE = (-.25, .25)
    _ANGLE_COSINE_RANGE = (.995, 1)

    def __init__(self, swing_up, sparse, random=None):
        """Initializes an instance of `Balance`.

        Args:
            swing_up: A `bool`, which if `True` sets the cart to the middle of the
                slider and the pole pointing towards the ground. Otherwise, sets the
                cart to a random position on the slider and the pole to a random
                near-vertical position.
            sparse: A `bool`, whether to return a sparse or a smooth reward.
            random: Optional, either a `numpy.random.RandomState` instance, an
                integer seed for creating a new `RandomState`, or None to select a seed
                automatically (default).
        """
        self._sparse = sparse
        self._swing_up = swing_up
        super(Balance, self).__init__(random=random)

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode.

        Initializes the cart and pole according to `swing_up`, and in both cases
        adds a small random initial velocity to break symmetry.

        Args:
            physics: An instance of `Physics`.
        """
        nv = physics.model.nv
        if self._swing_up:
            physics.named.data.qpos['slider'] = .01*self.random.randn()
            physics.named.data.qpos['hinge_1'] = np.pi + .01*self.random.randn()
            physics.named.data.qpos[2:] = .1*self.random.randn(nv - 2)
        else:
            physics.named.data.qpos['slider'] = self.random.uniform(-.1, .1)
            physics.named.data.qpos[1:] = self.random.uniform(-.034, .034, nv - 1)
        physics.named.data.qvel[:] = 0.01 * self.random.randn(physics.model.nv)
        super(Balance, self).initialize_episode(physics)

    def get_observation(self, physics):
        """Returns an observation of the (bounded) physics state."""
        obs = collections.OrderedDict()
        obs['position'] = physics.bounded_position()
        obs['velocity'] = physics.velocity()
        return obs

    def _get_reward(self, physics, sparse):
        if sparse:
            cart_in_bounds = rewards.tolerance(physics.cart_position(),
                                               self._CART_RANGE)
            angle_in_bounds = rewards.tolerance(physics.pole_angle_cosine(),
                                                self._ANGLE_COSINE_RANGE).prod()
            return cart_in_bounds * angle_in_bounds
        else:
            upright = (physics.pole_angle_cosine() + 1) / 2
            centered = rewards.tolerance(physics.cart_position(), margin=2)
            centered = (1 + centered) / 2
            small_control = rewards.tolerance(physics.control(), margin=1,
                                              value_at_margin=0,
                                              sigmoid='quadratic')[0]
            small_control = (4 + small_control) / 5
            small_velocity = rewards.tolerance(physics.angular_vel(), margin=5).min()
            small_velocity = (1 + small_velocity) / 2
            return upright.mean() * small_control * small_velocity * centered

    def get_reward(self, physics):
        """Returns a sparse or a smooth reward, as specified in the constructor."""
        return self._get_reward(physics, sparse=self._sparse)
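The multi-pole variants above are registered without tags, so they do not appear in `BENCHMARKING` but still load like any other task; `num_poles` reaches `_make_model` through `task_kwargs`. A short sketch, with an illustrative pole count:

```python
import local_dm_control_suite as suite

env_two = suite.load('cartpole', 'two_poles')
env_four = suite.load('cartpole', 'three_poles', task_kwargs={'num_poles': 4})

print(env_four.physics.model.njnt)  # slider + 4 pole hinges = 5 joints
```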
37  local_dm_control_suite/cartpole.xml  Executable file
@@ -0,0 +1,37 @@
<mujoco model="cart-pole">
  <include file="./common/skybox.xml"/>
  <include file="./common/visual.xml"/>
  <include file="./common/materials.xml"/>

  <option timestep="0.01" integrator="RK4">
    <flag contact="disable" energy="enable"/>
  </option>

  <default>
    <default class="pole">
      <joint type="hinge" axis="0 1 0" damping="2e-6"/>
      <geom type="capsule" fromto="0 0 0 0 0 1" size="0.045" material="self" mass=".1"/>
    </default>
  </default>

  <worldbody>
    <light name="light" pos="0 0 6"/>
    <camera name="fixed" pos="0 -4 1" zaxis="0 -1 0"/>
    <camera name="lookatcart" mode="targetbody" target="cart" pos="0 -2 2"/>
    <geom name="floor" pos="0 0 -.05" size="4 4 .2" type="plane" material="grid"/>
    <geom name="rail1" type="capsule" pos="0 .07 1" zaxis="1 0 0" size="0.02 2" material="decoration" />
    <geom name="rail2" type="capsule" pos="0 -.07 1" zaxis="1 0 0" size="0.02 2" material="decoration" />
    <body name="cart" pos="0 0 1">
      <joint name="slider" type="slide" limited="true" axis="1 0 0" range="-1.8 1.8" solreflimit=".08 1" damping="5e-4"/>
      <geom name="cart" type="box" size="0.2 0.15 0.1" material="self" mass="1"/>
      <body name="pole_1" childclass="pole">
        <joint name="hinge_1"/>
        <geom name="pole_1"/>
      </body>
    </body>
  </worldbody>

  <actuator>
    <motor name="slide" joint="slider" gear="10" ctrllimited="true" ctrlrange="-1 1" />
  </actuator>
</mujoco>
97
local_dm_control_suite/cheetah.py
Executable file
@ -0,0 +1,97 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Cheetah Domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.utils import containers
from dm_control.utils import rewards


# How long the simulation will run, in seconds.
_DEFAULT_TIME_LIMIT = 10

# Running speed above which reward is 1.
_RUN_SPEED = 10

SUITE = containers.TaggedTasks()


def get_model_and_assets():
  """Returns a tuple containing the model XML string and a dict of assets."""
  return common.read_model('cheetah.xml'), common.ASSETS


@SUITE.add('benchmarking')
def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the run task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Cheetah(random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(physics, task, time_limit=time_limit,
                             **environment_kwargs)


class Physics(mujoco.Physics):
  """Physics simulation with additional features for the Cheetah domain."""

  def speed(self):
    """Returns the horizontal speed of the Cheetah."""
    return self.named.data.sensordata['torso_subtreelinvel'][0]


class Cheetah(base.Task):
  """A `Task` to train a running Cheetah."""

  def initialize_episode(self, physics):
    """Sets the state of the environment at the start of each episode."""
    # The indexing below assumes that all joints have a single DOF.
    assert physics.model.nq == physics.model.njnt
    is_limited = physics.model.jnt_limited == 1
    lower, upper = physics.model.jnt_range[is_limited].T
    physics.data.qpos[is_limited] = self.random.uniform(lower, upper)

    # Stabilize the model before the actual simulation.
    for _ in range(200):
      physics.step()

    physics.data.time = 0
    self._timeout_progress = 0
    super(Cheetah, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns an observation of the state, ignoring horizontal position."""
    obs = collections.OrderedDict()
    # Ignores horizontal position to maintain translational invariance.
    obs['position'] = physics.data.qpos[1:].copy()
    obs['velocity'] = physics.velocity()
    return obs

  def get_reward(self, physics):
    """Returns a reward to the agent."""
    return rewards.tolerance(physics.speed(),
                             bounds=(_RUN_SPEED, float('inf')),
                             margin=_RUN_SPEED,
                             value_at_margin=0,
                             sigmoid='linear')
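A minimal usage sketch for cheetah.py above, assuming local_dm_control_suite is importable from the working directory; the run() factory and the observation keys come from the file as committed, while the random-action loop is only illustrative.

# Usage sketch (assumption: local_dm_control_suite is on PYTHONPATH):
# load the "run" task defined above and drive it with uniform random actions.
import numpy as np
from local_dm_control_suite import cheetah

env = cheetah.run()                      # control.Environment built by the factory above
action_spec = env.action_spec()
time_step = env.reset()
while not time_step.last():
    action = np.random.uniform(action_spec.minimum, action_spec.maximum,
                               size=action_spec.shape)
    time_step = env.step(action)         # reward follows rewards.tolerance(speed, ...)
print(time_step.observation['position'].shape, time_step.reward)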
73
local_dm_control_suite/cheetah.xml
Executable file
@ -0,0 +1,73 @@
<mujoco model="cheetah">
  <include file="./common/skybox.xml"/>
  <include file="./common/visual.xml"/>
  <include file="./common/materials_white_floor.xml"/>

  <compiler settotalmass="14"/>

  <default>
    <default class="cheetah">
      <joint limited="true" damping=".01" armature=".1" stiffness="8" type="hinge" axis="0 1 0"/>
      <geom contype="1" conaffinity="1" condim="3" friction=".4 .1 .1" material="self"/>
    </default>
    <default class="free">
      <joint limited="false" damping="0" armature="0" stiffness="0"/>
    </default>
    <motor ctrllimited="true" ctrlrange="-1 1"/>
  </default>

  <statistic center="0 0 .7" extent="2"/>

  <option timestep="0.01"/>

  <worldbody>
    <geom name="ground" type="plane" conaffinity="1" pos="98 0 0" size="100 .8 .5" material="grid"/>
    <body name="torso" pos="0 0 .7" childclass="cheetah">
      <light name="light" pos="0 0 2" mode="trackcom"/>
      <camera name="side" pos="0 -3 0" quat="0.707 0.707 0 0" mode="trackcom"/>
      <camera name="back" pos="-1.8 -1.3 0.8" xyaxes="0.45 -0.9 0 0.3 0.15 0.94" mode="trackcom"/>
      <joint name="rootx" type="slide" axis="1 0 0" class="free"/>
      <joint name="rootz" type="slide" axis="0 0 1" class="free"/>
      <joint name="rooty" type="hinge" axis="0 1 0" class="free"/>
      <geom name="torso" type="capsule" fromto="-.5 0 0 .5 0 0" size="0.046"/>
      <geom name="head" type="capsule" pos=".6 0 .1" euler="0 50 0" size="0.046 .15"/>
      <body name="bthigh" pos="-.5 0 0">
        <joint name="bthigh" range="-30 60" stiffness="240" damping="6"/>
        <geom name="bthigh" type="capsule" pos=".1 0 -.13" euler="0 -218 0" size="0.046 .145"/>
        <body name="bshin" pos=".16 0 -.25">
          <joint name="bshin" range="-50 50" stiffness="180" damping="4.5"/>
          <geom name="bshin" type="capsule" pos="-.14 0 -.07" euler="0 -116 0" size="0.046 .15"/>
          <body name="bfoot" pos="-.28 0 -.14">
            <joint name="bfoot" range="-230 50" stiffness="120" damping="3"/>
            <geom name="bfoot" type="capsule" pos=".03 0 -.097" euler="0 -15 0" size="0.046 .094"/>
          </body>
        </body>
      </body>
      <body name="fthigh" pos=".5 0 0">
        <joint name="fthigh" range="-57 .40" stiffness="180" damping="4.5"/>
        <geom name="fthigh" type="capsule" pos="-.07 0 -.12" euler="0 30 0" size="0.046 .133"/>
        <body name="fshin" pos="-.14 0 -.24">
          <joint name="fshin" range="-70 50" stiffness="120" damping="3"/>
          <geom name="fshin" type="capsule" pos=".065 0 -.09" euler="0 -34 0" size="0.046 .106"/>
          <body name="ffoot" pos=".13 0 -.18">
            <joint name="ffoot" range="-28 28" stiffness="60" damping="1.5"/>
            <geom name="ffoot" type="capsule" pos=".045 0 -.07" euler="0 -34 0" size="0.046 .07"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>

  <sensor>
    <subtreelinvel name="torso_subtreelinvel" body="torso"/>
  </sensor>

  <actuator>
    <motor name="bthigh" joint="bthigh" gear="120" />
    <motor name="bshin" joint="bshin" gear="90" />
    <motor name="bfoot" joint="bfoot" gear="60" />
    <motor name="fthigh" joint="fthigh" gear="90" />
    <motor name="fshin" joint="fshin" gear="60" />
    <motor name="ffoot" joint="ffoot" gear="30" />
  </actuator>
</mujoco>
39
local_dm_control_suite/common/__init__.py
Executable file
@ -0,0 +1,39 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Functions to manage the common assets for domains."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from dm_control.utils import io as resources

_SUITE_DIR = os.path.dirname(os.path.dirname(__file__))
_FILENAMES = [
    "./common/materials.xml",
    "./common/materials_white_floor.xml",
    "./common/skybox.xml",
    "./common/visual.xml",
]

ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
          for filename in _FILENAMES}


def read_model(model_filename):
  """Reads a model XML file and returns its contents as a string."""
  return resources.GetResource(os.path.join(_SUITE_DIR, model_filename))
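A short sketch of how the helpers above are consumed by the domain files in this commit: read_model() returns the model XML string and ASSETS supplies the included common files, so the pair can be handed to mujoco.Physics.from_xml_string; the cheetah.xml choice here is only an example.

# Sketch: build a raw Physics object from a suite model using the helpers above.
from dm_control import mujoco
from local_dm_control_suite import common

xml_string = common.read_model('cheetah.xml')      # resolved under _SUITE_DIR
physics = mujoco.Physics.from_xml_string(xml_string, common.ASSETS)
print(physics.model.nq, physics.model.nu)          # joint and actuator counts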
23
local_dm_control_suite/common/materials.xml
Executable file
@ -0,0 +1,23 @@
<!--
Common textures, colors and materials to be used throughout this suite. Some
materials such as xxx_highlight are activated on occurrence of certain events,
for example receiving a positive reward.
-->
<mujoco>
  <asset>
    <texture name="grid" type="2d" builtin="checker" rgb1=".1 .1 .4" rgb2=".2 .2 .8" width="300" height="300" mark="edge" markrgb=".1 .1 .4"/>
    <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
    <material name="self" rgba=".7 .5 .3 1"/>
    <material name="self_default" rgba=".7 .5 .3 1"/>
    <material name="self_highlight" rgba="0 .5 .3 1"/>
    <material name="effector" rgba=".7 .4 .2 1"/>
    <material name="effector_default" rgba=".7 .4 .2 1"/>
    <material name="effector_highlight" rgba="0 .5 .3 1"/>
    <material name="decoration" rgba=".7 .5 .3 1"/>
    <material name="eye" rgba="0 .2 1 1"/>
    <material name="target" rgba=".6 .3 .3 1"/>
    <material name="target_default" rgba=".6 .3 .3 1"/>
    <material name="target_highlight" rgba=".6 .3 .3 .4"/>
    <material name="site" rgba=".5 .5 .5 .3"/>
  </asset>
</mujoco>
23
local_dm_control_suite/common/materials_white_floor.xml
Executable file
@ -0,0 +1,23 @@
<!--
Common textures, colors and materials to be used throughout this suite. Some
materials such as xxx_highlight are activated on occurrence of certain events,
for example receiving a positive reward.
-->
<mujoco>
  <asset>
    <texture name="grid" type="2d" builtin="checker" rgb1=".1 .1 .1" rgb2=".2 .2 .2" width="300" height="300" mark="edge" markrgb=".1 .1 .1"/>
    <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
    <material name="self" rgba=".7 .5 .3 1"/>
    <material name="self_default" rgba=".7 .5 .3 1"/>
    <material name="self_highlight" rgba="0 .5 .3 1"/>
    <material name="effector" rgba=".7 .4 .2 1"/>
    <material name="effector_default" rgba=".7 .4 .2 1"/>
    <material name="effector_highlight" rgba="0 .5 .3 1"/>
    <material name="decoration" rgba=".3 .5 .7 1"/>
    <material name="eye" rgba="0 .2 1 1"/>
    <material name="target" rgba=".6 .3 .3 1"/>
    <material name="target_default" rgba=".6 .3 .3 1"/>
    <material name="target_highlight" rgba=".6 .3 .3 .4"/>
    <material name="site" rgba=".5 .5 .5 .3"/>
  </asset>
</mujoco>
6
local_dm_control_suite/common/skybox.xml
Executable file
@ -0,0 +1,6 @@
<mujoco>
  <asset>
    <texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6 .8" rgb2="0 0 0"
             width="800" height="800" mark="random" markrgb="0 0 0"/>
  </asset>
</mujoco>
7
local_dm_control_suite/common/visual.xml
Executable file
@ -0,0 +1,7 @@
<mujoco>
  <visual>
    <headlight ambient=".4 .4 .4" diffuse=".8 .8 .8" specular="0.1 0.1 0.1"/>
    <map znear=".01"/>
    <quality shadowsize="2048"/>
  </visual>
</mujoco>
84
local_dm_control_suite/demos/mocap_demo.py
Executable file
@ -0,0 +1,84 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Demonstration of amc parsing for CMU mocap database.

To run the demo, supply a path to a `.amc` file:

    python mocap_demo --filename='path/to/mocap.amc'

CMU motion capture clips are available at mocap.cs.cmu.edu
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
# Internal dependencies.

from absl import app
from absl import flags

from local_dm_control_suite import humanoid_CMU
from dm_control.suite.utils import parse_amc

import matplotlib.pyplot as plt
import numpy as np

FLAGS = flags.FLAGS
flags.DEFINE_string('filename', None, 'amc file to be converted.')
flags.DEFINE_integer('max_num_frames', 90,
                     'Maximum number of frames for plotting/playback')


def main(unused_argv):
  env = humanoid_CMU.stand()

  # Parse and convert specified clip.
  converted = parse_amc.convert(FLAGS.filename,
                                env.physics, env.control_timestep())

  max_frame = min(FLAGS.max_num_frames, converted.qpos.shape[1] - 1)

  width = 480
  height = 480
  video = np.zeros((max_frame, height, 2 * width, 3), dtype=np.uint8)

  for i in range(max_frame):
    p_i = converted.qpos[:, i]
    with env.physics.reset_context():
      env.physics.data.qpos[:] = p_i
    video[i] = np.hstack([env.physics.render(height, width, camera_id=0),
                          env.physics.render(height, width, camera_id=1)])

  tic = time.time()
  for i in range(max_frame):
    if i == 0:
      img = plt.imshow(video[i])
    else:
      img.set_data(video[i])
    toc = time.time()
    clock_dt = toc - tic
    tic = time.time()
    # Real-time playback not always possible as clock_dt > .03
    plt.pause(max(0.01, 0.03 - clock_dt))  # Need min display time > 0.0.
    plt.draw()
  plt.waitforbuttonpress()


if __name__ == '__main__':
  flags.mark_flag_as_required('filename')
  app.run(main)
213
local_dm_control_suite/demos/zeros.amc
Executable file
@ -0,0 +1,213 @@
#DUMMY AMC for testing
:FULLY-SPECIFIED
:DEGREES
1
root 0 0 0 0 0 0
lowerback 0 0 0
upperback 0 0 0
thorax 0 0 0
lowerneck 0 0 0
upperneck 0 0 0
head 0 0 0
rclavicle 0 0
rhumerus 0 0 0
rradius 0
rwrist 0
rhand 0 0
rfingers 0
rthumb 0 0
lclavicle 0 0
lhumerus 0 0 0
lradius 0
lwrist 0
lhand 0 0
lfingers 0
lthumb 0 0
rfemur 0 0 0
rtibia 0
rfoot 0 0
rtoes 0
lfemur 0 0 0
ltibia 0
lfoot 0 0
ltoes 0
2
root 0 0 0 0 0 0
lowerback 0 0 0
upperback 0 0 0
thorax 0 0 0
lowerneck 0 0 0
upperneck 0 0 0
head 0 0 0
rclavicle 0 0
rhumerus 0 0 0
rradius 0
rwrist 0
rhand 0 0
rfingers 0
rthumb 0 0
lclavicle 0 0
lhumerus 0 0 0
lradius 0
lwrist 0
lhand 0 0
lfingers 0
lthumb 0 0
rfemur 0 0 0
rtibia 0
rfoot 0 0
rtoes 0
lfemur 0 0 0
ltibia 0
lfoot 0 0
ltoes 0
3
root 0 0 0 0 0 0
lowerback 0 0 0
upperback 0 0 0
thorax 0 0 0
lowerneck 0 0 0
upperneck 0 0 0
head 0 0 0
rclavicle 0 0
rhumerus 0 0 0
rradius 0
rwrist 0
rhand 0 0
rfingers 0
rthumb 0 0
lclavicle 0 0
lhumerus 0 0 0
lradius 0
lwrist 0
lhand 0 0
lfingers 0
lthumb 0 0
rfemur 0 0 0
rtibia 0
rfoot 0 0
rtoes 0
lfemur 0 0 0
ltibia 0
lfoot 0 0
ltoes 0
4
root 0 0 0 0 0 0
lowerback 0 0 0
upperback 0 0 0
thorax 0 0 0
lowerneck 0 0 0
upperneck 0 0 0
head 0 0 0
rclavicle 0 0
rhumerus 0 0 0
rradius 0
rwrist 0
rhand 0 0
rfingers 0
rthumb 0 0
lclavicle 0 0
lhumerus 0 0 0
lradius 0
lwrist 0
lhand 0 0
lfingers 0
lthumb 0 0
rfemur 0 0 0
rtibia 0
rfoot 0 0
rtoes 0
lfemur 0 0 0
ltibia 0
lfoot 0 0
ltoes 0
5
root 0 0 0 0 0 0
lowerback 0 0 0
upperback 0 0 0
thorax 0 0 0
lowerneck 0 0 0
upperneck 0 0 0
head 0 0 0
rclavicle 0 0
rhumerus 0 0 0
rradius 0
rwrist 0
rhand 0 0
rfingers 0
rthumb 0 0
lclavicle 0 0
lhumerus 0 0 0
lradius 0
lwrist 0
lhand 0 0
lfingers 0
lthumb 0 0
rfemur 0 0 0
rtibia 0
rfoot 0 0
rtoes 0
lfemur 0 0 0
ltibia 0
lfoot 0 0
ltoes 0
6
root 0 0 0 0 0 0
lowerback 0 0 0
upperback 0 0 0
thorax 0 0 0
lowerneck 0 0 0
upperneck 0 0 0
head 0 0 0
rclavicle 0 0
rhumerus 0 0 0
rradius 0
rwrist 0
rhand 0 0
rfingers 0
rthumb 0 0
lclavicle 0 0
lhumerus 0 0 0
lradius 0
lwrist 0
lhand 0 0
lfingers 0
lthumb 0 0
rfemur 0 0 0
rtibia 0
rfoot 0 0
rtoes 0
lfemur 0 0 0
ltibia 0
lfoot 0 0
ltoes 0
7
root 0 0 0 0 0 0
lowerback 0 0 0
upperback 0 0 0
thorax 0 0 0
lowerneck 0 0 0
upperneck 0 0 0
head 0 0 0
rclavicle 0 0
rhumerus 0 0 0
rradius 0
rwrist 0
rhand 0 0
rfingers 0
rthumb 0 0
lclavicle 0 0
lhumerus 0 0 0
lradius 0
lwrist 0
lhand 0 0
lfingers 0
lthumb 0 0
rfemur 0 0 0
rtibia 0
rfoot 0 0
rtoes 0
lfemur 0 0 0
ltibia 0
lfoot 0 0
ltoes 0
84
local_dm_control_suite/explore.py
Executable file
@ -0,0 +1,84 @@
# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Control suite environments explorer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
from dm_control import suite
from dm_control.suite.wrappers import action_noise
from six.moves import input

from dm_control import viewer


_ALL_NAMES = ['.'.join(domain_task) for domain_task in suite.ALL_TASKS]

flags.DEFINE_enum('environment_name', None, _ALL_NAMES,
                  'Optional \'domain_name.task_name\' pair specifying the '
                  'environment to load. If unspecified a prompt will appear to '
                  'select one.')
flags.DEFINE_bool('timeout', True, 'Whether episodes should have a time limit.')
flags.DEFINE_bool('visualize_reward', True,
                  'Whether to vary the colors of geoms according to the '
                  'current reward value.')
flags.DEFINE_float('action_noise', 0.,
                   'Standard deviation of Gaussian noise to apply to actions, '
                   'expressed as a fraction of the max-min range for each '
                   'action dimension. Defaults to 0, i.e. no noise.')
FLAGS = flags.FLAGS


def prompt_environment_name(prompt, values):
  environment_name = None
  while not environment_name:
    environment_name = input(prompt)
    if not environment_name or values.index(environment_name) < 0:
      print('"%s" is not a valid environment name.' % environment_name)
      environment_name = None
  return environment_name


def main(argv):
  del argv
  environment_name = FLAGS.environment_name
  if environment_name is None:
    print('\n '.join(['Available environments:'] + _ALL_NAMES))
    environment_name = prompt_environment_name(
        'Please select an environment name: ', _ALL_NAMES)

  index = _ALL_NAMES.index(environment_name)
  domain_name, task_name = suite.ALL_TASKS[index]

  task_kwargs = {}
  if not FLAGS.timeout:
    task_kwargs['time_limit'] = float('inf')

  def loader():
    env = suite.load(
        domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs)
    env.task.visualize_reward = FLAGS.visualize_reward
    if FLAGS.action_noise > 0:
      env = action_noise.Wrapper(env, scale=FLAGS.action_noise)
    return env

  viewer.launch(loader)


if __name__ == '__main__':
  app.run(main)
217
local_dm_control_suite/finger.py
Executable file
@ -0,0 +1,217 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Finger Domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
import numpy as np
from six.moves import range

_DEFAULT_TIME_LIMIT = 20  # (seconds)
_CONTROL_TIMESTEP = .02   # (seconds)
# For TURN tasks, the 'tip' geom needs to enter a spherical target of sizes:
_EASY_TARGET_SIZE = 0.07
_HARD_TARGET_SIZE = 0.03
# Initial spin velocity for the Stop task.
_INITIAL_SPIN_VELOCITY = 100
# Spinning slower than this value (radian/second) is considered stopped.
_STOP_VELOCITY = 1e-6
# Spinning faster than this value (radian/second) is considered spinning.
_SPIN_VELOCITY = 15.0


SUITE = containers.TaggedTasks()


def get_model_and_assets():
  """Returns a tuple containing the model XML string and a dict of assets."""
  return common.read_model('finger.xml'), common.ASSETS


@SUITE.add('benchmarking')
def spin(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the Spin task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Spin(random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


@SUITE.add('benchmarking')
def turn_easy(time_limit=_DEFAULT_TIME_LIMIT, random=None,
              environment_kwargs=None):
  """Returns the easy Turn task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Turn(target_radius=_EASY_TARGET_SIZE, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


@SUITE.add('benchmarking')
def turn_hard(time_limit=_DEFAULT_TIME_LIMIT, random=None,
              environment_kwargs=None):
  """Returns the hard Turn task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Turn(target_radius=_HARD_TARGET_SIZE, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


class Physics(mujoco.Physics):
  """Physics simulation with additional features for the Finger domain."""

  def touch(self):
    """Returns logarithmically scaled signals from the two touch sensors."""
    return np.log1p(self.named.data.sensordata[['touchtop', 'touchbottom']])

  def hinge_velocity(self):
    """Returns the velocity of the hinge joint."""
    return self.named.data.sensordata['hinge_velocity']

  def tip_position(self):
    """Returns the (x,z) position of the tip relative to the hinge."""
    return (self.named.data.sensordata['tip'][[0, 2]] -
            self.named.data.sensordata['spinner'][[0, 2]])

  def bounded_position(self):
    """Returns the positions, with the hinge angle replaced by tip position."""
    return np.hstack((self.named.data.sensordata[['proximal', 'distal']],
                      self.tip_position()))

  def velocity(self):
    """Returns the velocities (extracted from sensordata)."""
    return self.named.data.sensordata[['proximal_velocity',
                                       'distal_velocity',
                                       'hinge_velocity']]

  def target_position(self):
    """Returns the (x,z) position of the target relative to the hinge."""
    return (self.named.data.sensordata['target'][[0, 2]] -
            self.named.data.sensordata['spinner'][[0, 2]])

  def to_target(self):
    """Returns the vector from the tip to the target."""
    return self.target_position() - self.tip_position()

  def dist_to_target(self):
    """Returns the signed distance to the target surface, negative is inside."""
    return (np.linalg.norm(self.to_target()) -
            self.named.model.site_size['target', 0])


class Spin(base.Task):
  """A Finger `Task` to spin the stopped body."""

  def __init__(self, random=None):
    """Initializes a new `Spin` instance.

    Args:
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
    super(Spin, self).__init__(random=random)

  def initialize_episode(self, physics):
    physics.named.model.site_rgba['target', 3] = 0
    physics.named.model.site_rgba['tip', 3] = 0
    physics.named.model.dof_damping['hinge'] = .03
    _set_random_joint_angles(physics, self.random)
    super(Spin, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns state and touch sensors, and target info."""
    obs = collections.OrderedDict()
    obs['position'] = physics.bounded_position()
    obs['velocity'] = physics.velocity()
    obs['touch'] = physics.touch()
    return obs

  def get_reward(self, physics):
    """Returns a sparse reward."""
    return float(physics.hinge_velocity() <= -_SPIN_VELOCITY)


class Turn(base.Task):
  """A Finger `Task` to turn the body to a target angle."""

  def __init__(self, target_radius, random=None):
    """Initializes a new `Turn` instance.

    Args:
      target_radius: Radius of the target site, which specifies the goal angle.
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
    self._target_radius = target_radius
    super(Turn, self).__init__(random=random)

  def initialize_episode(self, physics):
    target_angle = self.random.uniform(-np.pi, np.pi)
    hinge_x, hinge_z = physics.named.data.xanchor['hinge', ['x', 'z']]
    radius = physics.named.model.geom_size['cap1'].sum()
    target_x = hinge_x + radius * np.sin(target_angle)
    target_z = hinge_z + radius * np.cos(target_angle)
    physics.named.model.site_pos['target', ['x', 'z']] = target_x, target_z
    physics.named.model.site_size['target', 0] = self._target_radius

    _set_random_joint_angles(physics, self.random)

    super(Turn, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns state, touch sensors, and target info."""
    obs = collections.OrderedDict()
    obs['position'] = physics.bounded_position()
    obs['velocity'] = physics.velocity()
    obs['touch'] = physics.touch()
    obs['target_position'] = physics.target_position()
    obs['dist_to_target'] = physics.dist_to_target()
    return obs

  def get_reward(self, physics):
    return float(physics.dist_to_target() <= 0)


def _set_random_joint_angles(physics, random, max_attempts=1000):
  """Sets the joints to a random collision-free state."""

  for _ in range(max_attempts):
    randomizers.randomize_limited_and_rotational_joints(physics, random)
    # Check for collisions.
    physics.after_reset()
    if physics.data.ncon == 0:
      break
  else:
    raise RuntimeError('Could not find a collision-free state '
                       'after {} attempts'.format(max_attempts))
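A brief sketch of the Turn interface defined above, assuming the package is importable; the observation keys are those set in Turn.get_observation, and the reward is sparse (1 only while the tip is inside the target site, i.e. dist_to_target() <= 0).

# Sketch: inspect the easy Turn task built by the factory above.
from local_dm_control_suite import finger

env = finger.turn_easy(random=0)      # target radius _EASY_TARGET_SIZE = 0.07
time_step = env.reset()
print(sorted(time_step.observation))  # dist_to_target, position, target_position, touch, velocity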
72
local_dm_control_suite/finger.xml
Executable file
@ -0,0 +1,72 @@
<mujoco model="finger">
  <include file="./common/visual.xml"/>
  <include file="./common/skybox.xml"/>
  <include file="./common/materials.xml"/>

  <option timestep="0.01" cone="elliptic" iterations="200">
    <flag gravity="disable"/>
  </option>

  <default>
    <geom solimp="0 0.9 0.01" solref=".02 1"/>
    <joint type="hinge" axis="0 -1 0"/>
    <motor ctrllimited="true" ctrlrange="-1 1"/>
    <default class="finger">
      <joint damping="2.5" limited="true"/>
      <site type="ellipsoid" size=".025 .03 .025" material="site" group="3"/>
    </default>
  </default>

  <worldbody>
    <light name="light" directional="true" diffuse=".6 .6 .6" pos="0 0 2" specular=".3 .3 .3"/>
    <geom name="ground" type="plane" pos="0 0 0" size=".6 .2 10" material="grid"/>
    <camera name="cam0" pos="0 -1 .8" xyaxes="1 0 0 0 1 2"/>
    <camera name="cam1" pos="0 -1 .4" xyaxes="1 0 0 0 0 1" />

    <body name="proximal" pos="-.2 0 .4" childclass="finger">
      <geom name="proximal_decoration" type="cylinder" fromto="0 -.033 0 0 .033 0" size=".034" material="decoration"/>
      <joint name="proximal" range="-110 110" ref="-90"/>
      <geom name="proximal" type="capsule" material="self" size=".03" fromto="0 0 0 0 0 -.17"/>
      <body name="distal" pos="0 0 -.18" childclass="finger">
        <joint name="distal" range="-110 110"/>
        <geom name="distal" type="capsule" size=".028" material="self" fromto="0 0 0 0 0 -.16" contype="0" conaffinity="0"/>
        <geom name="fingertip" type="capsule" size=".03" material="effector" fromto="0 0 -.13 0 0 -.161"/>
        <site name="touchtop" pos=".01 0 -.17"/>
        <site name="touchbottom" pos="-.01 0 -.17"/>
      </body>
    </body>

    <body name="spinner" pos=".2 0 .4">
      <joint name="hinge" frictionloss=".1" damping=".5"/>
      <geom name="cap1" type="capsule" size=".04 .09" material="self" pos=".02 0 0"/>
      <geom name="cap2" type="capsule" size=".04 .09" material="self" pos="-.02 0 0"/>
      <site name="tip" type="sphere" size=".02" pos="0 0 .13" material="target"/>
      <geom name="spinner_decoration" type="cylinder" fromto="0 -.045 0 0 .045 0" size=".02" material="decoration"/>
    </body>

    <site name="target" type="sphere" size=".03" pos="0 0 .4" material="target"/>
  </worldbody>

  <actuator>
    <motor name="proximal" joint="proximal" gear="30"/>
    <motor name="distal" joint="distal" gear="15"/>
  </actuator>

  <!-- All finger observations are functions of sensors. This is useful for finite-differencing. -->
  <sensor>
    <jointpos name="proximal" joint="proximal"/>
    <jointpos name="distal" joint="distal"/>
    <jointvel name="proximal_velocity" joint="proximal"/>
    <jointvel name="distal_velocity" joint="distal"/>
    <jointvel name="hinge_velocity" joint="hinge"/>
    <framepos name="tip" objtype="site" objname="tip"/>
    <framepos name="target" objtype="site" objname="target"/>
    <framepos name="spinner" objtype="xbody" objname="spinner"/>
    <touch name="touchtop" site="touchtop"/>
    <touch name="touchbottom" site="touchbottom"/>
    <framepos name="touchtop_pos" objtype="site" objname="touchtop"/>
    <framepos name="touchbottom_pos" objtype="site" objname="touchbottom"/>
  </sensor>

</mujoco>
176
local_dm_control_suite/fish.py
Executable file
@ -0,0 +1,176 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Fish Domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
import numpy as np


_DEFAULT_TIME_LIMIT = 40
_CONTROL_TIMESTEP = .04
_JOINTS = ['tail1',
           'tail_twist',
           'tail2',
           'finright_roll',
           'finright_pitch',
           'finleft_roll',
           'finleft_pitch']
SUITE = containers.TaggedTasks()


def get_model_and_assets():
  """Returns a tuple containing the model XML string and a dict of assets."""
  return common.read_model('fish.xml'), common.ASSETS


@SUITE.add('benchmarking')
def upright(time_limit=_DEFAULT_TIME_LIMIT, random=None,
            environment_kwargs=None):
  """Returns the Fish Upright task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Upright(random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
      **environment_kwargs)


@SUITE.add('benchmarking')
def swim(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the Fish Swim task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Swim(random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
      **environment_kwargs)


class Physics(mujoco.Physics):
  """Physics simulation with additional features for the Fish domain."""

  def upright(self):
    """Returns projection from z-axes of torso to the z-axes of worldbody."""
    return self.named.data.xmat['torso', 'zz']

  def torso_velocity(self):
    """Returns velocities and angular velocities of the torso."""
    return self.data.sensordata

  def joint_velocities(self):
    """Returns the joint velocities."""
    return self.named.data.qvel[_JOINTS]

  def joint_angles(self):
    """Returns the joint positions."""
    return self.named.data.qpos[_JOINTS]

  def mouth_to_target(self):
    """Returns a vector, from mouth to target in local coordinate of mouth."""
    data = self.named.data
    mouth_to_target_global = data.geom_xpos['target'] - data.geom_xpos['mouth']
    return mouth_to_target_global.dot(data.geom_xmat['mouth'].reshape(3, 3))


class Upright(base.Task):
  """A Fish `Task` for getting the torso upright with smooth reward."""

  def __init__(self, random=None):
    """Initializes an instance of `Upright`.

    Args:
      random: Either an existing `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically.
    """
    super(Upright, self).__init__(random=random)

  def initialize_episode(self, physics):
    """Randomizes the tail and fin angles and the orientation of the Fish."""
    quat = self.random.randn(4)
    physics.named.data.qpos['root'][3:7] = quat / np.linalg.norm(quat)
    for joint in _JOINTS:
      physics.named.data.qpos[joint] = self.random.uniform(-.2, .2)
    # Hide the target. It's irrelevant for this task.
    physics.named.model.geom_rgba['target', 3] = 0
    super(Upright, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns an observation of joint angles, velocities and uprightness."""
    obs = collections.OrderedDict()
    obs['joint_angles'] = physics.joint_angles()
    obs['upright'] = physics.upright()
    obs['velocity'] = physics.velocity()
    return obs

  def get_reward(self, physics):
    """Returns a smooth reward."""
    return rewards.tolerance(physics.upright(), bounds=(1, 1), margin=1)


class Swim(base.Task):
  """A Fish `Task` for swimming with smooth reward."""

  def __init__(self, random=None):
    """Initializes an instance of `Swim`.

    Args:
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
    super(Swim, self).__init__(random=random)

  def initialize_episode(self, physics):
    """Sets the state of the environment at the start of each episode."""

    quat = self.random.randn(4)
    physics.named.data.qpos['root'][3:7] = quat / np.linalg.norm(quat)
    for joint in _JOINTS:
      physics.named.data.qpos[joint] = self.random.uniform(-.2, .2)
    # Randomize target position.
    physics.named.model.geom_pos['target', 'x'] = self.random.uniform(-.4, .4)
    physics.named.model.geom_pos['target', 'y'] = self.random.uniform(-.4, .4)
    physics.named.model.geom_pos['target', 'z'] = self.random.uniform(.1, .3)
    super(Swim, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns an observation of joints, target direction and velocities."""
    obs = collections.OrderedDict()
    obs['joint_angles'] = physics.joint_angles()
    obs['upright'] = physics.upright()
    obs['target'] = physics.mouth_to_target()
    obs['velocity'] = physics.velocity()
    return obs

  def get_reward(self, physics):
    """Returns a smooth reward."""
    radii = physics.named.model.geom_size[['mouth', 'target'], 0].sum()
    in_target = rewards.tolerance(np.linalg.norm(physics.mouth_to_target()),
                                  bounds=(0, radii), margin=2*radii)
    is_upright = 0.5 * (physics.upright() + 1)
    return (7*in_target + is_upright) / 8
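The Swim reward above blends target proximity and uprightness with weights 7:1; a quick numeric check of that blend (the two input values below are chosen purely for illustration, not taken from the simulator).

# Illustrative check of Swim.get_reward's blend: reward = (7*in_target + is_upright) / 8.
in_target = 0.0      # mouth far from the target
is_upright = 1.0     # torso perfectly upright
print((7 * in_target + is_upright) / 8)   # 0.125: uprightness alone is worth at most 1/8
in_target = 1.0      # mouth inside the target radius
print((7 * in_target + is_upright) / 8)   # 1.0: full reward requires both terms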
85
local_dm_control_suite/fish.xml
Executable file
@ -0,0 +1,85 @@
<mujoco model="fish">
  <include file="./common/visual.xml"/>
  <include file="./common/materials.xml"/>
  <asset>
    <texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6 .8" rgb2="0 0 0" width="800" height="800" mark="random" markrgb="1 1 1"/>
  </asset>

  <option timestep="0.004" density="5000">
    <flag gravity="disable" constraint="disable"/>
  </option>

  <default>
    <general ctrllimited="true"/>
    <default class="fish">
      <joint type="hinge" limited="false" range="-60 60" damping="2e-5" solreflimit=".1 1" solimplimit="0 .8 .1"/>
      <geom material="self"/>
    </default>
  </default>

  <worldbody>
    <camera name="tracking_top" pos="0 0 1" xyaxes="1 0 0 0 1 0" mode="trackcom"/>
    <camera name="tracking_x" pos="-.3 0 .2" xyaxes="0 -1 0 0.342 0 0.940" fovy="60" mode="trackcom"/>
    <camera name="tracking_y" pos="0 -.3 .2" xyaxes="1 0 0 0 0.342 0.940" fovy="60" mode="trackcom"/>
    <camera name="fixed_top" pos="0 0 5.5" fovy="10"/>
    <geom name="ground" type="plane" size=".5 .5 .1" material="grid"/>
    <geom name="target" type="sphere" pos="0 .4 .1" size=".04" material="target"/>
    <body name="torso" pos="0 0 .1" childclass="fish">
      <light name="light" diffuse=".6 .6 .6" pos="0 0 0.5" dir="0 0 -1" specular=".3 .3 .3" mode="track"/>
      <joint name="root" type="free" damping="0" limited="false"/>
      <site name="torso" size=".01" rgba="0 0 0 0"/>
      <geom name="eye" type="ellipsoid" pos="0 .055 .015" size=".008 .012 .008" euler="-10 0 0" material="eye" mass="0"/>
      <camera name="eye" pos="0 .06 .02" xyaxes="1 0 0 0 0 1"/>
      <geom name="mouth" type="capsule" fromto="0 .079 0 0 .07 0" size=".005" material="effector" mass="0"/>
      <geom name="lower_mouth" type="capsule" fromto="0 .079 -.004 0 .07 -.003" size=".0045" material="effector" mass="0"/>
      <geom name="torso" type="ellipsoid" size=".01 .08 .04" mass="0"/>
      <geom name="back_fin" type="ellipsoid" size=".001 .03 .015" pos="0 -.03 .03" material="effector" mass="0"/>
      <geom name="torso_massive" type="box" size=".002 .06 .03" group="4"/>
      <body name="tail1" pos="0 -.09 0">
        <joint name="tail1" axis="0 0 1" pos="0 .01 0"/>
        <joint name="tail_twist" axis="0 1 0" pos="0 .01 0" range="-30 30"/>
        <geom name="tail1" type="ellipsoid" size=".001 .008 .016"/>
        <body name="tail2" pos="0 -.028 0">
          <joint name="tail2" axis="0 0 1" pos="0 .02 0" stiffness="8e-5"/>
          <geom name="tail2" type="ellipsoid" size=".001 .018 .035"/>
        </body>
      </body>
      <body name="finright" pos=".01 0 0">
        <joint name="finright_roll" axis="0 1 0"/>
        <joint name="finright_pitch" axis="1 0 0" pos="0 .005 0"/>
        <geom name="finright" type="ellipsoid" pos=".015 0 0" size=".02 .015 .001" />
      </body>
      <body name="finleft" pos="-.01 0 0">
        <joint name="finleft_roll" axis="0 1 0"/>
        <joint name="finleft_pitch" axis="1 0 0" pos="0 .005 0"/>
        <geom name="finleft" type="ellipsoid" pos="-.015 0 0" size=".02 .015 .001"/>
      </body>
    </body>
  </worldbody>

  <tendon>
    <fixed name="fins_flap">
      <joint joint="finleft_roll" coef="-.5"/>
      <joint joint="finright_roll" coef=".5"/>
    </fixed>
    <fixed name="fins_sym" stiffness="1e-4">
      <joint joint="finleft_roll" coef=".5"/>
      <joint joint="finright_roll" coef=".5"/>
    </fixed>
  </tendon>

  <actuator>
    <position name="tail" joint="tail1" ctrlrange="-1 1" kp="5e-4"/>
    <position name="tail_twist" joint="tail_twist" ctrlrange="-1 1" kp="1e-4"/>
    <position name="fins_flap" tendon="fins_flap" ctrlrange="-1 1" kp="3e-4"/>
    <position name="finleft_pitch" joint="finleft_pitch" ctrlrange="-1 1" kp="1e-4"/>
    <position name="finright_pitch" joint="finright_pitch" ctrlrange="-1 1" kp="1e-4"/>
  </actuator>

  <sensor>
    <velocimeter name="velocimeter" site="torso"/>
    <gyro name="gyro" site="torso"/>
  </sensor>
</mujoco>
138
local_dm_control_suite/hopper.py
Executable file
@ -0,0 +1,138 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Hopper domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
import numpy as np


SUITE = containers.TaggedTasks()

_CONTROL_TIMESTEP = .02  # (Seconds)

# Default duration of an episode, in seconds.
_DEFAULT_TIME_LIMIT = 20

# Minimal height of torso over foot above which stand reward is 1.
_STAND_HEIGHT = 0.6

# Hopping speed above which hop reward is 1.
_HOP_SPEED = 2


def get_model_and_assets():
  """Returns a tuple containing the model XML string and a dict of assets."""
  return common.read_model('hopper.xml'), common.ASSETS


@SUITE.add('benchmarking')
def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns a Hopper that strives to stand upright, balancing its pose."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Hopper(hopping=False, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


@SUITE.add('benchmarking')
def hop(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns a Hopper that strives to hop forward."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Hopper(hopping=True, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


class Physics(mujoco.Physics):
  """Physics simulation with additional features for the Hopper domain."""

  def height(self):
    """Returns height of torso with respect to foot."""
    return (self.named.data.xipos['torso', 'z'] -
            self.named.data.xipos['foot', 'z'])

  def speed(self):
    """Returns horizontal speed of the Hopper."""
    return self.named.data.sensordata['torso_subtreelinvel'][0]

  def touch(self):
    """Returns the signals from two foot touch sensors."""
    return np.log1p(self.named.data.sensordata[['touch_toe', 'touch_heel']])


class Hopper(base.Task):
  """A Hopper's `Task` to train a standing and a jumping Hopper."""

  def __init__(self, hopping, random=None):
    """Initialize an instance of `Hopper`.

    Args:
      hopping: Boolean, if True the task is to hop forwards, otherwise it is to
        balance upright.
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
    self._hopping = hopping
    super(Hopper, self).__init__(random=random)

  def initialize_episode(self, physics):
    """Sets the state of the environment at the start of each episode."""
    randomizers.randomize_limited_and_rotational_joints(physics, self.random)
    self._timeout_progress = 0
    super(Hopper, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns an observation of positions, velocities and touch sensors."""
    obs = collections.OrderedDict()
    # Ignores horizontal position to maintain translational invariance:
    obs['position'] = physics.data.qpos[1:].copy()
    obs['velocity'] = physics.velocity()
    obs['touch'] = physics.touch()
    return obs

  def get_reward(self, physics):
    """Returns a reward applicable to the performed task."""
    standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2))
    if self._hopping:
      hopping = rewards.tolerance(physics.speed(),
                                  bounds=(_HOP_SPEED, float('inf')),
                                  margin=_HOP_SPEED/2,
                                  value_at_margin=0.5,
                                  sigmoid='linear')
      return standing * hopping
    else:
      small_control = rewards.tolerance(physics.control(),
                                        margin=1, value_at_margin=0,
                                        sigmoid='quadratic').mean()
      small_control = (small_control + 4) / 5
      return standing * small_control
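Both factories above share one Hopper task class and differ only in the hopping flag, which selects the reward branch in get_reward; a minimal sketch of choosing between them, assuming the package is importable.

# Sketch: the two Hopper factories differ only in the reward term used.
from local_dm_control_suite import hopper

stand_env = hopper.stand(random=0)   # reward = standing * small_control
hop_env = hopper.hop(random=0)       # reward = standing * hopping (forward-speed term)
print(stand_env.action_spec().shape, hop_env.control_timestep())  # 4 actuators, 0.02 s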
66
local_dm_control_suite/hopper.xml
Executable file
66
local_dm_control_suite/hopper.xml
Executable file
@ -0,0 +1,66 @@
<mujoco model="planar hopper">
  <include file="./common/skybox.xml"/>
  <include file="./common/visual.xml"/>
  <include file="./common/materials_white_floor.xml"/>

  <statistic extent="2" center="0 0 .5"/>

  <default>
    <default class="hopper">
      <joint type="hinge" axis="0 1 0" limited="true" damping=".05" armature=".2"/>
      <geom type="capsule" material="self"/>
      <site type="sphere" size="0.05" group="3"/>
    </default>
    <default class="free">
      <joint limited="false" damping="0" armature="0" stiffness="0"/>
    </default>
    <motor ctrlrange="-1 1" ctrllimited="true"/>
  </default>

  <option timestep="0.005"/>

  <worldbody>
    <camera name="cam0" pos="0 -2.8 0.8" euler="90 0 0" mode="trackcom"/>
    <camera name="back" pos="-2 -.2 1.2" xyaxes="0.2 -1 0 .5 0 2" mode="trackcom"/>
    <geom name="floor" type="plane" conaffinity="1" pos="48 0 0" size="50 1 .2" material="grid"/>
    <body name="torso" pos="0 0 1" childclass="hopper">
      <light name="top" pos="0 0 2" mode="trackcom"/>
      <joint name="rootx" type="slide" axis="1 0 0" class="free"/>
      <joint name="rootz" type="slide" axis="0 0 1" class="free"/>
      <joint name="rooty" type="hinge" axis="0 1 0" class="free"/>
      <geom name="torso" fromto="0 0 -.05 0 0 .2" size="0.0653"/>
      <geom name="nose" fromto=".08 0 .13 .15 0 .14" size="0.03"/>
      <body name="pelvis" pos="0 0 -.05">
        <joint name="waist" range="-30 30"/>
        <geom name="pelvis" fromto="0 0 0 0 0 -.15" size="0.065"/>
        <body name="thigh" pos="0 0 -.2">
          <joint name="hip" range="-170 10"/>
          <geom name="thigh" fromto="0 0 0 0 0 -.33" size="0.04"/>
          <body name="calf" pos="0 0 -.33">
            <joint name="knee" range="5 150"/>
            <geom name="calf" fromto="0 0 0 0 0 -.32" size="0.03"/>
            <body name="foot" pos="0 0 -.32">
              <joint name="ankle" range="-45 45"/>
              <geom name="foot" fromto="-.08 0 0 .17 0 0" size="0.04"/>
              <site name="touch_toe" pos=".17 0 0"/>
              <site name="touch_heel" pos="-.08 0 0"/>
            </body>
          </body>
        </body>
      </body>
    </body>
  </worldbody>

  <sensor>
    <subtreelinvel name="torso_subtreelinvel" body="torso"/>
    <touch name="touch_toe" site="touch_toe"/>
    <touch name="touch_heel" site="touch_heel"/>
  </sensor>

  <actuator>
    <motor name="waist" joint="waist" gear="30"/>
    <motor name="hip" joint="hip" gear="40"/>
    <motor name="knee" joint="knee" gear="30"/>
    <motor name="ankle" joint="ankle" gear="10"/>
  </actuator>
</mujoco>
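A hedged sketch of how the model above is consumed by the Python side (it mirrors the `Physics.from_xml_string(*get_model_and_assets())` pattern used throughout the suite files; it assumes hopper.py defines the `Physics` class that the height/speed/touch helpers above belong to, as in upstream dm_control):

# Illustrative only: build a standalone Physics from the hopper model and query
# the helper observables defined in hopper.py.
from local_dm_control_suite import common
from local_dm_control_suite.hopper import Physics

physics = Physics.from_xml_string(common.read_model('hopper.xml'), common.ASSETS)
print(physics.height(), physics.speed(), physics.touch())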
211 local_dm_control_suite/humanoid.py Executable file
@@ -0,0 +1,211 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Humanoid Domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
import numpy as np

_DEFAULT_TIME_LIMIT = 25
_CONTROL_TIMESTEP = .025

# Height of head above which stand reward is 1.
_STAND_HEIGHT = 1.4

# Horizontal speeds above which move reward is 1.
_WALK_SPEED = 1
_RUN_SPEED = 10


SUITE = containers.TaggedTasks()


def get_model_and_assets():
  """Returns a tuple containing the model XML string and a dict of assets."""
  return common.read_model('humanoid.xml'), common.ASSETS


@SUITE.add('benchmarking')
def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the Stand task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Humanoid(move_speed=0, pure_state=False, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


@SUITE.add('benchmarking')
def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the Walk task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Humanoid(move_speed=_WALK_SPEED, pure_state=False, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


@SUITE.add('benchmarking')
def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the Run task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Humanoid(move_speed=_RUN_SPEED, pure_state=False, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


@SUITE.add()
def run_pure_state(time_limit=_DEFAULT_TIME_LIMIT, random=None,
                   environment_kwargs=None):
  """Returns the Run task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = Humanoid(move_speed=_RUN_SPEED, pure_state=True, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


class Physics(mujoco.Physics):
  """Physics simulation with additional features for the Walker domain."""

  def torso_upright(self):
    """Returns projection from z-axes of torso to the z-axes of world."""
    return self.named.data.xmat['torso', 'zz']

  def head_height(self):
    """Returns the height of the torso."""
    return self.named.data.xpos['head', 'z']

  def center_of_mass_position(self):
    """Returns position of the center-of-mass."""
    return self.named.data.subtree_com['torso'].copy()

  def center_of_mass_velocity(self):
    """Returns the velocity of the center-of-mass."""
    return self.named.data.sensordata['torso_subtreelinvel'].copy()

  def torso_vertical_orientation(self):
    """Returns the z-projection of the torso orientation matrix."""
    return self.named.data.xmat['torso', ['zx', 'zy', 'zz']]

  def joint_angles(self):
    """Returns the state without global orientation or position."""
    return self.data.qpos[7:].copy()  # Skip the 7 DoFs of the free root joint.

  def extremities(self):
    """Returns end effector positions in egocentric frame."""
    torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
    torso_pos = self.named.data.xpos['torso']
    positions = []
    for side in ('left_', 'right_'):
      for limb in ('hand', 'foot'):
        torso_to_limb = self.named.data.xpos[side + limb] - torso_pos
        positions.append(torso_to_limb.dot(torso_frame))
    return np.hstack(positions)


class Humanoid(base.Task):
  """A humanoid task."""

  def __init__(self, move_speed, pure_state, random=None):
    """Initializes an instance of `Humanoid`.

    Args:
      move_speed: A float. If this value is zero, reward is given simply for
        standing up. Otherwise this specifies a target horizontal velocity for
        the walking task.
      pure_state: A bool. Whether the observations consist of the pure MuJoCo
        state or includes some useful features thereof.
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
    self._move_speed = move_speed
    self._pure_state = pure_state
    super(Humanoid, self).__init__(random=random)

  def initialize_episode(self, physics):
    """Sets the state of the environment at the start of each episode.

    Args:
      physics: An instance of `Physics`.

    """
    # Find a collision-free random initial configuration.
    penetrating = True
    while penetrating:
      randomizers.randomize_limited_and_rotational_joints(physics, self.random)
      # Check for collisions.
      physics.after_reset()
      penetrating = physics.data.ncon > 0
    super(Humanoid, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns either the pure state or a set of egocentric features."""
    obs = collections.OrderedDict()
    if self._pure_state:
      obs['position'] = physics.position()
      obs['velocity'] = physics.velocity()
    else:
      obs['joint_angles'] = physics.joint_angles()
      obs['head_height'] = physics.head_height()
      obs['extremities'] = physics.extremities()
      obs['torso_vertical'] = physics.torso_vertical_orientation()
      obs['com_velocity'] = physics.center_of_mass_velocity()
      obs['velocity'] = physics.velocity()
    return obs

  def get_reward(self, physics):
    """Returns a reward to the agent."""
    standing = rewards.tolerance(physics.head_height(),
                                 bounds=(_STAND_HEIGHT, float('inf')),
                                 margin=_STAND_HEIGHT/4)
    upright = rewards.tolerance(physics.torso_upright(),
                                bounds=(0.9, float('inf')), sigmoid='linear',
                                margin=1.9, value_at_margin=0)
    stand_reward = standing * upright
    small_control = rewards.tolerance(physics.control(), margin=1,
                                      value_at_margin=0,
                                      sigmoid='quadratic').mean()
    small_control = (4 + small_control) / 5
    if self._move_speed == 0:
      horizontal_velocity = physics.center_of_mass_velocity()[[0, 1]]
      dont_move = rewards.tolerance(horizontal_velocity, margin=2).mean()
      return small_control * stand_reward * dont_move
    else:
      com_velocity = np.linalg.norm(physics.center_of_mass_velocity()[[0, 1]])
      move = rewards.tolerance(com_velocity,
                               bounds=(self._move_speed, float('inf')),
                               margin=self._move_speed, value_at_margin=0,
                               sigmoid='linear')
      move = (5*move + 1) / 6
      return small_control * stand_reward * move
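A small numeric sketch of the stand-reward shaping in `Humanoid.get_reward` above (the values are illustrative, not from the commit; it uses the same `rewards.tolerance` helper the file imports):

# Illustrative: both factors are 1 inside their bounds, so a head height above
# _STAND_HEIGHT (1.4) and a near-vertical torso give stand_reward close to 1.
from dm_control.utils import rewards

head_height = 1.5      # example observation
torso_upright = 0.95   # example projection of the torso z-axis onto the world z-axis

standing = rewards.tolerance(head_height, bounds=(1.4, float('inf')), margin=1.4 / 4)
upright = rewards.tolerance(torso_upright, bounds=(0.9, float('inf')),
                            sigmoid='linear', margin=1.9, value_at_margin=0)
stand_reward = standing * upright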
202 local_dm_control_suite/humanoid.xml Executable file
@@ -0,0 +1,202 @@
<mujoco model="humanoid">
  <include file="./common/skybox.xml"/>
  <include file="./common/visual.xml"/>
  <include file="./common/materials.xml"/>

  <statistic extent="2" center="0 0 1"/>

  <option timestep=".005"/>

  <default>
    <motor ctrlrange="-1 1" ctrllimited="true"/>
    <default class="body">
      <geom type="capsule" condim="1" friction=".7" solimp=".9 .99 .003" solref=".015 1" material="self"/>
      <joint type="hinge" damping=".2" stiffness="1" armature=".01" limited="true" solimplimit="0 .99 .01"/>
      <default class="big_joint">
        <joint damping="5" stiffness="10"/>
        <default class="big_stiff_joint">
          <joint stiffness="20"/>
        </default>
      </default>
      <site size=".04" group="3"/>
      <default class="force-torque">
        <site type="box" size=".01 .01 .02" rgba="1 0 0 1" />
      </default>
      <default class="touch">
        <site type="capsule" rgba="0 0 1 .3"/>
      </default>
    </default>
  </default>

  <worldbody>
    <geom name="floor" type="plane" conaffinity="1" size="100 100 .2" material="grid"/>
    <body name="torso" pos="0 0 1.5" childclass="body">
      <light name="top" pos="0 0 2" mode="trackcom"/>
      <camera name="back" pos="-3 0 1" xyaxes="0 -1 0 1 0 2" mode="trackcom"/>
      <camera name="side" pos="0 -3 1" xyaxes="1 0 0 0 1 2" mode="trackcom"/>
      <freejoint name="root"/>
      <site name="root" class="force-torque"/>
      <geom name="torso" fromto="0 -.07 0 0 .07 0" size=".07"/>
      <geom name="upper_waist" fromto="-.01 -.06 -.12 -.01 .06 -.12" size=".06"/>
      <site name="torso" class="touch" type="box" pos="0 0 -.05" size=".075 .14 .13"/>
      <body name="head" pos="0 0 .19">
        <geom name="head" type="sphere" size=".09"/>
        <site name="head" class="touch" type="sphere" size=".091"/>
        <camera name="egocentric" pos=".09 0 0" xyaxes="0 -1 0 .1 0 1" fovy="80"/>
      </body>
      <body name="lower_waist" pos="-.01 0 -.260" quat="1.000 0 -.002 0">
        <geom name="lower_waist" fromto="0 -.06 0 0 .06 0" size=".06"/>
        <site name="lower_waist" class="touch" size=".061 .06" zaxis="0 1 0"/>
        <joint name="abdomen_z" pos="0 0 .065" axis="0 0 1" range="-45 45" class="big_stiff_joint"/>
        <joint name="abdomen_y" pos="0 0 .065" axis="0 1 0" range="-75 30" class="big_joint"/>
        <body name="pelvis" pos="0 0 -.165" quat="1.000 0 -.002 0">
          <joint name="abdomen_x" pos="0 0 .1" axis="1 0 0" range="-35 35" class="big_joint"/>
          <geom name="butt" fromto="-.02 -.07 0 -.02 .07 0" size=".09"/>
          <site name="butt" class="touch" size=".091 .07" pos="-.02 0 0" zaxis="0 1 0"/>
          <body name="right_thigh" pos="0 -.1 -.04">
            <site name="right_hip" class="force-torque"/>
            <joint name="right_hip_x" axis="1 0 0" range="-25 5" class="big_joint"/>
            <joint name="right_hip_z" axis="0 0 1" range="-60 35" class="big_joint"/>
            <joint name="right_hip_y" axis="0 1 0" range="-110 20" class="big_stiff_joint"/>
            <geom name="right_thigh" fromto="0 0 0 0 .01 -.34" size=".06"/>
            <site name="right_thigh" class="touch" pos="0 .005 -.17" size=".061 .17" zaxis="0 -1 34"/>
            <body name="right_shin" pos="0 .01 -.403">
              <site name="right_knee" class="force-torque" pos="0 0 .02"/>
              <joint name="right_knee" pos="0 0 .02" axis="0 -1 0" range="-160 2"/>
              <geom name="right_shin" fromto="0 0 0 0 0 -.3" size=".049"/>
              <site name="right_shin" class="touch" pos="0 0 -.15" size=".05 .15"/>
              <body name="right_foot" pos="0 0 -.39">
                <site name="right_ankle" class="force-torque"/>
                <joint name="right_ankle_y" pos="0 0 .08" axis="0 1 0" range="-50 50" stiffness="6"/>
                <joint name="right_ankle_x" pos="0 0 .04" axis="1 0 .5" range="-50 50" stiffness="3"/>
                <geom name="right_right_foot" fromto="-.07 -.02 0 .14 -.04 0" size=".027"/>
                <geom name="left_right_foot" fromto="-.07 0 0 .14 .02 0" size=".027"/>
                <site name="right_right_foot" class="touch" pos=".035 -.03 0" size=".03 .11" zaxis="21 -2 0"/>
                <site name="left_right_foot" class="touch" pos=".035 .01 0" size=".03 .11" zaxis="21 2 0"/>
              </body>
            </body>
          </body>
          <body name="left_thigh" pos="0 .1 -.04">
            <site name="left_hip" class="force-torque"/>
            <joint name="left_hip_x" axis="-1 0 0" range="-25 5" class="big_joint"/>
            <joint name="left_hip_z" axis="0 0 -1" range="-60 35" class="big_joint"/>
            <joint name="left_hip_y" axis="0 1 0" range="-120 20" class="big_stiff_joint"/>
            <geom name="left_thigh" fromto="0 0 0 0 -.01 -.34" size=".06"/>
            <site name="left_thigh" class="touch" pos="0 -.005 -.17" size=".061 .17" zaxis="0 1 34"/>
            <body name="left_shin" pos="0 -.01 -.403">
              <site name="left_knee" class="force-torque" pos="0 0 .02"/>
              <joint name="left_knee" pos="0 0 .02" axis="0 -1 0" range="-160 2"/>
              <geom name="left_shin" fromto="0 0 0 0 0 -.3" size=".049"/>
              <site name="left_shin" class="touch" pos="0 0 -.15" size=".05 .15"/>
              <body name="left_foot" pos="0 0 -.39">
                <site name="left_ankle" class="force-torque"/>
                <joint name="left_ankle_y" pos="0 0 .08" axis="0 1 0" range="-50 50" stiffness="6"/>
                <joint name="left_ankle_x" pos="0 0 .04" axis="1 0 .5" range="-50 50" stiffness="3"/>
                <geom name="left_left_foot" fromto="-.07 .02 0 .14 .04 0" size=".027"/>
                <geom name="right_left_foot" fromto="-.07 0 0 .14 -.02 0" size=".027"/>
                <site name="right_left_foot" class="touch" pos=".035 -.01 0" size=".03 .11" zaxis="21 -2 0"/>
                <site name="left_left_foot" class="touch" pos=".035 .03 0" size=".03 .11" zaxis="21 2 0"/>
              </body>
            </body>
          </body>
        </body>
      </body>
      <body name="right_upper_arm" pos="0 -.17 .06">
        <joint name="right_shoulder1" axis="2 1 1" range="-85 60"/>
        <joint name="right_shoulder2" axis="0 -1 1" range="-85 60"/>
        <geom name="right_upper_arm" fromto="0 0 0 .16 -.16 -.16" size=".04 .16"/>
        <site name="right_upper_arm" class="touch" pos=".08 -.08 -.08" size=".041 .14" zaxis="1 -1 -1"/>
        <body name="right_lower_arm" pos=".18 -.18 -.18">
          <joint name="right_elbow" axis="0 -1 1" range="-90 50" stiffness="0"/>
          <geom name="right_lower_arm" fromto=".01 .01 .01 .17 .17 .17" size=".031"/>
          <site name="right_lower_arm" class="touch" pos=".09 .09 .09" size=".032 .14" zaxis="1 1 1"/>
          <body name="right_hand" pos=".18 .18 .18">
            <geom name="right_hand" type="sphere" size=".04"/>
            <site name="right_hand" class="touch" type="sphere" size=".041"/>
          </body>
        </body>
      </body>
      <body name="left_upper_arm" pos="0 .17 .06">
        <joint name="left_shoulder1" axis="2 -1 1" range="-60 85"/>
        <joint name="left_shoulder2" axis="0 1 1" range="-60 85"/>
        <geom name="left_upper_arm" fromto="0 0 0 .16 .16 -.16" size=".04 .16"/>
        <site name="left_upper_arm" class="touch" pos=".08 .08 -.08" size=".041 .14" zaxis="1 1 -1"/>
        <body name="left_lower_arm" pos=".18 .18 -.18">
          <joint name="left_elbow" axis="0 -1 -1" range="-90 50" stiffness="0"/>
          <geom name="left_lower_arm" fromto=".01 -.01 .01 .17 -.17 .17" size=".031"/>
          <site name="left_lower_arm" class="touch" pos=".09 -.09 .09" size=".032 .14" zaxis="1 -1 1"/>
          <body name="left_hand" pos=".18 -.18 .18">
            <geom name="left_hand" type="sphere" size=".04"/>
            <site name="left_hand" class="touch" type="sphere" size=".041"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>

  <actuator>
    <motor name="abdomen_y" gear="40" joint="abdomen_y"/>
    <motor name="abdomen_z" gear="40" joint="abdomen_z"/>
    <motor name="abdomen_x" gear="40" joint="abdomen_x"/>
    <motor name="right_hip_x" gear="40" joint="right_hip_x"/>
    <motor name="right_hip_z" gear="40" joint="right_hip_z"/>
    <motor name="right_hip_y" gear="120" joint="right_hip_y"/>
    <motor name="right_knee" gear="80" joint="right_knee"/>
    <motor name="right_ankle_x" gear="20" joint="right_ankle_x"/>
    <motor name="right_ankle_y" gear="20" joint="right_ankle_y"/>
    <motor name="left_hip_x" gear="40" joint="left_hip_x"/>
    <motor name="left_hip_z" gear="40" joint="left_hip_z"/>
    <motor name="left_hip_y" gear="120" joint="left_hip_y"/>
    <motor name="left_knee" gear="80" joint="left_knee"/>
    <motor name="left_ankle_x" gear="20" joint="left_ankle_x"/>
    <motor name="left_ankle_y" gear="20" joint="left_ankle_y"/>
    <motor name="right_shoulder1" gear="20" joint="right_shoulder1"/>
    <motor name="right_shoulder2" gear="20" joint="right_shoulder2"/>
    <motor name="right_elbow" gear="40" joint="right_elbow"/>
    <motor name="left_shoulder1" gear="20" joint="left_shoulder1"/>
    <motor name="left_shoulder2" gear="20" joint="left_shoulder2"/>
    <motor name="left_elbow" gear="40" joint="left_elbow"/>
  </actuator>

  <sensor>
    <subtreelinvel name="torso_subtreelinvel" body="torso"/>
    <accelerometer name="torso_accel" site="root"/>
    <velocimeter name="torso_vel" site="root"/>
    <gyro name="torso_gyro" site="root"/>

    <force name="left_ankle_force" site="left_ankle"/>
    <force name="right_ankle_force" site="right_ankle"/>
    <force name="left_knee_force" site="left_knee"/>
    <force name="right_knee_force" site="right_knee"/>
    <force name="left_hip_force" site="left_hip"/>
    <force name="right_hip_force" site="right_hip"/>

    <torque name="left_ankle_torque" site="left_ankle"/>
    <torque name="right_ankle_torque" site="right_ankle"/>
    <torque name="left_knee_torque" site="left_knee"/>
    <torque name="right_knee_torque" site="right_knee"/>
    <torque name="left_hip_torque" site="left_hip"/>
    <torque name="right_hip_torque" site="right_hip"/>

    <touch name="torso_touch" site="torso"/>
    <touch name="head_touch" site="head"/>
    <touch name="lower_waist_touch" site="lower_waist"/>
    <touch name="butt_touch" site="butt"/>
    <touch name="right_thigh_touch" site="right_thigh"/>
    <touch name="right_shin_touch" site="right_shin"/>
    <touch name="right_right_foot_touch" site="right_right_foot"/>
    <touch name="left_right_foot_touch" site="left_right_foot"/>
    <touch name="left_thigh_touch" site="left_thigh"/>
    <touch name="left_shin_touch" site="left_shin"/>
    <touch name="right_left_foot_touch" site="right_left_foot"/>
    <touch name="left_left_foot_touch" site="left_left_foot"/>
    <touch name="right_upper_arm_touch" site="right_upper_arm"/>
    <touch name="right_lower_arm_touch" site="right_lower_arm"/>
    <touch name="right_hand_touch" site="right_hand"/>
    <touch name="left_upper_arm_touch" site="left_upper_arm"/>
    <touch name="left_lower_arm_touch" site="left_lower_arm"/>
    <touch name="left_hand_touch" site="left_hand"/>
  </sensor>

</mujoco>
179 local_dm_control_suite/humanoid_CMU.py Executable file
@@ -0,0 +1,179 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Humanoid_CMU Domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
import numpy as np

_DEFAULT_TIME_LIMIT = 20
_CONTROL_TIMESTEP = 0.02

# Height of head above which stand reward is 1.
_STAND_HEIGHT = 1.4

# Horizontal speeds above which move reward is 1.
_WALK_SPEED = 1
_RUN_SPEED = 10

SUITE = containers.TaggedTasks()


def get_model_and_assets():
  """Returns a tuple containing the model XML string and a dict of assets."""
  return common.read_model('humanoid_CMU.xml'), common.ASSETS


@SUITE.add()
def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the Stand task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = HumanoidCMU(move_speed=0, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


@SUITE.add()
def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the Run task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = HumanoidCMU(move_speed=_RUN_SPEED, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


class Physics(mujoco.Physics):
  """Physics simulation with additional features for the humanoid_CMU domain."""

  def thorax_upright(self):
    """Returns projection from y-axes of thorax to the z-axes of world."""
    return self.named.data.xmat['thorax', 'zy']

  def head_height(self):
    """Returns the height of the head."""
    return self.named.data.xpos['head', 'z']

  def center_of_mass_position(self):
    """Returns position of the center-of-mass."""
    return self.named.data.subtree_com['thorax']

  def center_of_mass_velocity(self):
    """Returns the velocity of the center-of-mass."""
    return self.named.data.sensordata['thorax_subtreelinvel'].copy()

  def torso_vertical_orientation(self):
    """Returns the z-projection of the thorax orientation matrix."""
    return self.named.data.xmat['thorax', ['zx', 'zy', 'zz']]

  def joint_angles(self):
    """Returns the state without global orientation or position."""
    return self.data.qpos[7:].copy()  # Skip the 7 DoFs of the free root joint.

  def extremities(self):
    """Returns end effector positions in egocentric frame."""
    torso_frame = self.named.data.xmat['thorax'].reshape(3, 3)
    torso_pos = self.named.data.xpos['thorax']
    positions = []
    for side in ('l', 'r'):
      for limb in ('hand', 'foot'):
        torso_to_limb = self.named.data.xpos[side + limb] - torso_pos
        positions.append(torso_to_limb.dot(torso_frame))
    return np.hstack(positions)


class HumanoidCMU(base.Task):
  """A task for the CMU Humanoid."""

  def __init__(self, move_speed, random=None):
    """Initializes an instance of `Humanoid_CMU`.

    Args:
      move_speed: A float. If this value is zero, reward is given simply for
        standing up. Otherwise this specifies a target horizontal velocity for
        the walking task.
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
    self._move_speed = move_speed
    super(HumanoidCMU, self).__init__(random=random)

  def initialize_episode(self, physics):
    """Sets a random collision-free configuration at the start of each episode.

    Args:
      physics: An instance of `Physics`.
    """
    penetrating = True
    while penetrating:
      randomizers.randomize_limited_and_rotational_joints(
          physics, self.random)
      # Check for collisions.
      physics.after_reset()
      penetrating = physics.data.ncon > 0
    super(HumanoidCMU, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns a set of egocentric features."""
    obs = collections.OrderedDict()
    obs['joint_angles'] = physics.joint_angles()
    obs['head_height'] = physics.head_height()
    obs['extremities'] = physics.extremities()
    obs['torso_vertical'] = physics.torso_vertical_orientation()
    obs['com_velocity'] = physics.center_of_mass_velocity()
    obs['velocity'] = physics.velocity()
    return obs

  def get_reward(self, physics):
    """Returns a reward to the agent."""
    standing = rewards.tolerance(physics.head_height(),
                                 bounds=(_STAND_HEIGHT, float('inf')),
                                 margin=_STAND_HEIGHT/4)
    upright = rewards.tolerance(physics.thorax_upright(),
                                bounds=(0.9, float('inf')), sigmoid='linear',
                                margin=1.9, value_at_margin=0)
    stand_reward = standing * upright
    small_control = rewards.tolerance(physics.control(), margin=1,
                                      value_at_margin=0,
                                      sigmoid='quadratic').mean()
    small_control = (4 + small_control) / 5
    if self._move_speed == 0:
      horizontal_velocity = physics.center_of_mass_velocity()[[0, 1]]
      dont_move = rewards.tolerance(horizontal_velocity, margin=2).mean()
      return small_control * stand_reward * dont_move
    else:
      com_velocity = np.linalg.norm(physics.center_of_mass_velocity()[[0, 1]])
      move = rewards.tolerance(com_velocity,
                               bounds=(self._move_speed, float('inf')),
                               margin=self._move_speed, value_at_margin=0,
                               sigmoid='linear')
      move = (5*move + 1) / 6
      return small_control * stand_reward * move
289 local_dm_control_suite/humanoid_CMU.xml Executable file
@@ -0,0 +1,289 @@
<mujoco model="humanoid_CMU">

  <include file="./common/skybox.xml"/>
  <include file="./common/visual.xml"/>
  <include file="./common/materials.xml"/>

  <statistic extent="2" center="0 0 1"/>

  <default class="main">
    <joint limited="true" solimplimit="0 0.99 0.01" stiffness="0.1" armature=".01" damping="1"/>
    <geom friction="0.7" solref="0.015 1" solimp="0.95 0.99 0.003"/>
    <motor ctrllimited="true" ctrlrange="-1 1"/>
    <default class="humanoid">
      <geom type="capsule" material="self"/>
      <default class="stiff_low">
        <joint stiffness=".5" damping="4"/>
      </default>
      <default class="stiff_medium">
        <joint stiffness="10" damping="5"/>
      </default>
      <default class="stiff_high">
        <joint stiffness="30" damping="10"/>
      </default>
      <default class="touch">
        <site group="3" rgba="0 0 1 .5"/>
      </default>
    </default>
  </default>

  <worldbody>
    <geom name="floor" type="plane" conaffinity="1" size="100 100 .2" material="grid"/>
    <light name="tracking_light" pos="0 0 7" dir="0 0 -1" mode="trackcom"/>
    <camera name="back" pos="0 3 2.4" xyaxes="-1 0 0 0 -1 2" mode="trackcom"/>
    <camera name="side" pos="-3 0 2.4" xyaxes="0 -1 0 1 0 2" mode="trackcom"/>
    <body name="root" childclass="humanoid" pos="0 0 1" euler="90 0 0">
      <site name="root" size=".01" rgba="0.5 0.5 0.5 0"/>
      <freejoint name="root"/>
      <geom name="root_geom" size="0.09 0.06" pos="0 -0.05 0" quat="1 0 -1 0"/>
      <body name="lhipjoint">
        <geom name="lhipjoint" size="0.008 0.022" pos="0.051 -0.046 0.025" quat="0.5708 -0.566602 -0.594264 0"/>
        <body name="lfemur" pos="0.102 -0.092 0.05" quat="1 0 0 0.17365">
          <joint name="lfemurrz" axis="0 0 1" range="-60 70" class="stiff_medium"/>
          <joint name="lfemurry" axis="0 1 0" range="-70 70" class="stiff_medium"/>
          <joint name="lfemurrx" axis="1 0 0" range="-160 20" class="stiff_medium"/>
          <geom name="lfemur" size="0.06 0.17" pos="-.01 -0.202473 0" quat="0.7 -0.7 -0.1228 -0.07"/>
          <body name="ltibia" pos="0 -0.404945 0">
            <joint name="ltibiarx" axis="1 0 0" range="1 170" class="stiff_low"/>
            <geom name="ltibia" size="0.03 0.1825614" pos="0 -0.202846 0" quat="0.7 -0.7 -0.1228 -0.1228"/>
            <geom name="lcalf" size="0.045 0.08" pos="0 -0.1 -.01" quat="0.7 -0.7 -0.1228 -0.1228"/>
            <body name="lfoot" pos="0 -0.405693 0" quat="0.707107 -0.707107 0 0">
              <site name="lfoot_touch" type="box" pos="-.005 -.02 -0.025" size=".04 .08 .02" euler="10 0 0" class="touch"/>
              <joint name="lfootrz" axis="0 0 1" range="-70 20" class="stiff_medium"/>
              <joint name="lfootrx" axis="1 0 0" range="-45 90" class="stiff_medium"/>
              <geom name="lfoot0" size="0.02 0.06" pos="-0.02 -0.023 -0.01" euler="100 -2 0"/>
              <geom name="lfoot1" size="0.02 0.06" pos="0 -0.023 -0.01" euler="100 0 0"/>
              <geom name="lfoot2" size="0.02 0.06" pos=".01 -0.023 -0.01" euler="100 10 0"/>
              <body name="ltoes" pos="0 -0.106372 -0.0227756">
                <joint name="ltoesrx" axis="1 0 0" range="-90 20"/>
                <geom name="ltoes0" type="sphere" size="0.02" pos="-.025 -0.01 -.01"/>
                <geom name="ltoes1" type="sphere" size="0.02" pos="0 -0.005 -.01"/>
                <geom name="ltoes2" type="sphere" size="0.02" pos=".02 .001 -.01"/>
                <site name="ltoes_touch" type="capsule" pos="-.005 -.005 -.01" size="0.025 0.02" zaxis="1 .2 0" class="touch"/>
              </body>
            </body>
          </body>
        </body>
      </body>
      <body name="rhipjoint">
        <geom name="rhipjoint" size="0.008 0.022" pos="-0.051 -0.046 0.025" quat="0.574856 -0.547594 0.608014 0"/>
        <body name="rfemur" pos="-0.102 -0.092 0.05" quat="1 0 0 -0.17365">
          <joint name="rfemurrz" axis="0 0 1" range="-70 60" class="stiff_medium"/>
          <joint name="rfemurry" axis="0 1 0" range="-70 70" class="stiff_medium"/>
          <joint name="rfemurrx" axis="1 0 0" range="-160 20" class="stiff_medium"/>
          <geom name="rfemur" size="0.06 0.17" pos=".01 -0.202473 0" quat="0.7 -0.7 0.1228 0.07"/>
          <body name="rtibia" pos="0 -0.404945 0">
            <joint name="rtibiarx" axis="1 0 0" range="1 170" class="stiff_low"/>
            <geom name="rtibia" size="0.03 0.1825614" pos="0 -0.202846 0" quat="0.7 -0.7 0.1228 0.1228"/>
            <geom name="rcalf" size="0.045 0.08" pos="0 -0.1 -.01" quat="0.7 -0.7 -0.1228 -0.1228"/>
            <body name="rfoot" pos="0 -0.405693 0" quat="0.707107 -0.707107 0 0">
              <site name="rfoot_touch" type="box" pos=".005 -.02 -0.025" size=".04 .08 .02" euler="10 0 0" class="touch"/>
              <joint name="rfootrz" axis="0 0 1" range="-20 70" class="stiff_medium"/>
              <joint name="rfootrx" axis="1 0 0" range="-45 90" class="stiff_medium"/>
              <geom name="rfoot0" size="0.02 0.06" pos="0.02 -0.023 -0.01" euler="100 2 0"/>
              <geom name="rfoot1" size="0.02 0.06" pos="0 -0.023 -0.01" euler="100 0 0"/>
              <geom name="rfoot2" size="0.02 0.06" pos="-.01 -0.023 -0.01" euler="100 -10 0"/>
              <body name="rtoes" pos="0 -0.106372 -0.0227756">
                <joint name="rtoesrx" axis="1 0 0" range="-90 20"/>
                <geom name="rtoes0" type="sphere" size="0.02" pos=".025 -0.01 -.01"/>
                <geom name="rtoes1" type="sphere" size="0.02" pos="0 -0.005 -.01"/>
                <geom name="rtoes2" type="sphere" size="0.02" pos="-.02 .001 -.01"/>
                <site name="rtoes_touch" type="capsule" pos=".005 -.005 -.01" size="0.025 0.02" zaxis="1 -.2 0" class="touch"/>
              </body>
            </body>
          </body>
        </body>
      </body>
      <body name="lowerback">
        <joint name="lowerbackrz" axis="0 0 1" range="-30 30" class="stiff_high"/>
        <joint name="lowerbackry" axis="0 1 0" range="-30 30" class="stiff_high"/>
        <joint name="lowerbackrx" axis="1 0 0" range="-20 45" class="stiff_high"/>
        <geom name="lowerback" size="0.065 0.055" pos="0 0.056 .03" quat="1 0 1 0"/>
        <body name="upperback" pos="0 0.1 -0.01">
          <joint name="upperbackrz" axis="0 0 1" range="-30 30" class="stiff_high"/>
          <joint name="upperbackry" axis="0 1 0" range="-30 30" class="stiff_high"/>
          <joint name="upperbackrx" axis="1 0 0" range="-20 45" class="stiff_high"/>
          <geom name="upperback" size="0.06 0.06" pos="0 0.06 0.02" quat="1 0 1 0"/>
          <body name="thorax" pos="0.000512528 0.11356 0.000936821">
            <joint name="thoraxrz" axis="0 0 1" range="-30 30" class="stiff_high"/>
            <joint name="thoraxry" axis="0 1 0" range="-30 30" class="stiff_high"/>
            <joint name="thoraxrx" axis="1 0 0" range="-20 45" class="stiff_high"/>
            <geom name="thorax" size="0.08 0.07" pos="0 0.05 0" quat="1 0 1 0"/>
            <body name="lowerneck" pos="0 0.113945 0.00468037">
              <joint name="lowerneckrz" axis="0 0 1" range="-30 30" class="stiff_medium"/>
              <joint name="lowerneckry" axis="0 1 0" range="-30 30" class="stiff_medium"/>
              <joint name="lowerneckrx" axis="1 0 0" range="-20 45" class="stiff_medium"/>
              <geom name="lowerneck" size="0.08 0.02" pos="0 0.04 -.02" quat="1 1 0 0"/>
              <body name="upperneck" pos="0 0.09 0.01">
                <joint name="upperneckrz" axis="0 0 1" range="-30 30" class="stiff_medium"/>
                <joint name="upperneckry" axis="0 1 0" range="-30 30" class="stiff_medium"/>
                <joint name="upperneckrx" axis="1 0 0" range="-20 45" class="stiff_medium"/>
                <geom name="upperneck" size="0.05 0.03" pos="0 0.05 0" quat=".8 1 0 0"/>
                <body name="head" pos="0 0.09 0">
                  <camera name="egocentric" pos="0 0 0" xyaxes="-1 0 0 0 1 0" fovy="80"/>
                  <joint name="headrz" axis="0 0 1" range="-30 30" class="stiff_medium"/>
                  <joint name="headry" axis="0 1 0" range="-30 30" class="stiff_medium"/>
                  <joint name="headrx" axis="1 0 0" range="-20 45" class="stiff_medium"/>
                  <geom name="head" size="0.085 0.035" pos="0 0.11 0.03" quat="1 .9 0 0"/>
                  <geom name="leye" type="sphere" size="0.02" pos=" .03 0.11 0.1"/>
                  <geom name="reye" type="sphere" size="0.02" pos="-.03 0.11 0.1"/>
                </body>
              </body>
            </body>
            <body name="lclavicle" pos="0 0.113945 0.00468037">
              <joint name="lclaviclerz" axis="0 0 1" range="0 20" class="stiff_high"/>
              <joint name="lclaviclery" axis="0 1 0" range="-20 10" class="stiff_high"/>
              <geom name="lclavicle" size="0.08 0.04" pos="0.09 0.05 -.01" quat="1 0 -1 -.4"/>
              <body name="lhumerus" pos="0.183 0.076 0.01" quat="0.18 0.68 -0.68 0.18">
                <joint name="lhumerusrz" axis="0 0 1" range="-90 90" class="stiff_low"/>
                <joint name="lhumerusry" axis="0 1 0" range="-90 90" class="stiff_low"/>
                <joint name="lhumerusrx" axis="1 0 0" range="-60 90" class="stiff_low"/>
                <geom name="lhumerus" size="0.035 0.124" pos="0 -0.138 0" quat="0.612 -0.612 0.35 0.35"/>
                <body name="lradius" pos="0 -0.277 0">
                  <joint name="lradiusrx" axis="1 0 0" range="-10 170" class="stiff_low"/>
                  <geom name="lradius" size="0.03 0.06" pos="0 -0.08 0" quat="0.612 -0.612 0.35 0.35"/>
                  <body name="lwrist" pos="0 -0.17 0" quat="-0.5 0 0.866 0">
                    <joint name="lwristry" axis="0 1 0" range="-180 0"/>
                    <geom name="lwrist" size="0.025 0.03" pos="0 -0.02 0" quat="0 0 -1 -1"/>
                    <body name="lhand" pos="0 -0.08 0">
                      <joint name="lhandrz" axis="0 0 1" range="-45 45"/>
                      <joint name="lhandrx" axis="1 0 0" range="-90 90"/>
                      <geom name="lhand" type="ellipsoid" size=".048 0.02 0.06" pos="0 -0.047 0" quat="0 0 -1 -1"/>
                      <body name="lfingers" pos="0 -0.08 0">
                        <joint name="lfingersrx" axis="1 0 0" range="0 90"/>
                        <geom name="lfinger0" size="0.01 0.04" pos="-.03 -0.05 0" quat="1 -1 0 0" />
                        <geom name="lfinger1" size="0.01 0.04" pos="-.008 -0.06 0" quat="1 -1 0 0" />
                        <geom name="lfinger2" size="0.009 0.04" pos=".014 -0.06 0" quat="1 -1 0 0" />
                        <geom name="lfinger3" size="0.008 0.04" pos=".032 -0.05 0" quat="1 -1 0 0" />
                      </body>
                      <body name="lthumb" pos="-.02 -.03 0" quat="0.92388 0 0 -0.382683">
                        <joint name="lthumbrz" axis="0 0 1" range="-45 45"/>
                        <joint name="lthumbrx" axis="1 0 0" range="0 90"/>
                        <geom name="lthumb" size="0.012 0.04" pos="0 -0.06 0" quat="0 0 -1 -1"/>
                      </body>
                    </body>
                  </body>
                </body>
              </body>
            </body>
            <body name="rclavicle" pos="0 0.113945 0.00468037">
              <joint name="rclaviclerz" axis="0 0 1" range="-20 0" class="stiff_high"/>
              <joint name="rclaviclery" axis="0 1 0" range="-10 20" class="stiff_high"/>
              <geom name="rclavicle" size="0.08 0.04" pos="-.09 0.05 -.01" quat="1 0 -1 .4"/>
              <body name="rhumerus" pos="-0.183 0.076 0.01" quat="0.18 0.68 0.68 -0.18">
                <joint name="rhumerusrz" axis="0 0 1" range="-90 90" class="stiff_low"/>
                <joint name="rhumerusry" axis="0 1 0" range="-90 90" class="stiff_low"/>
                <joint name="rhumerusrx" axis="1 0 0" range="-60 90" class="stiff_low"/>
                <geom name="rhumerus" size="0.035 0.124" pos="0 -0.138 0" quat="0.61 -0.61 -0.35 -0.35"/>
                <body name="rradius" pos="0 -0.277 0">
                  <joint name="rradiusrx" axis="1 0 0" range="-10 170" class="stiff_low"/>
                  <geom name="rradius" size="0.03 0.06" pos="0 -0.08 0" quat="0.612 -0.612 -0.35 -0.35"/>
                  <body name="rwrist" pos="0 -0.17 0" quat="-0.5 0 -0.866 0">
                    <joint name="rwristry" axis="0 1 0" range="-180 0"/>
                    <geom name="rwrist" size="0.025 0.03" pos="0 -0.02 0" quat="0 0 1 1"/>
                    <body name="rhand" pos="0 -0.08 0">
                      <joint name="rhandrz" axis="0 0 1" range="-45 45"/>
                      <joint name="rhandrx" axis="1 0 0" range="-90 90"/>
                      <geom name="rhand" type="ellipsoid" size=".048 0.02 .06" pos="0 -0.047 0" quat="0 0 1 1"/>
                      <body name="rfingers" pos="0 -0.08 0">
                        <joint name="rfingersrx" axis="1 0 0" range="0 90"/>
                        <geom name="rfinger0" size="0.01 0.04" pos=".03 -0.05 0" quat="1 -1 0 0" />
                        <geom name="rfinger1" size="0.01 0.04" pos=".008 -0.06 0" quat="1 -1 0 0" />
                        <geom name="rfinger2" size="0.009 0.04" pos="-.014 -0.06 0" quat="1 -1 0 0" />
                        <geom name="rfinger3" size="0.008 0.04" pos="-.032 -0.05 0" quat="1 -1 0 0" />
                      </body>
                      <body name="rthumb" pos=".02 -.03 0" quat="0.92388 0 0 0.382683">
                        <joint name="rthumbrz" axis="0 0 1" range="-45 45"/>
                        <joint name="rthumbrx" axis="1 0 0" range="0 90"/>
                        <geom name="rthumb" size="0.012 0.04" pos="0 -0.06 0" quat="0 0 1 1"/>
                      </body>
                    </body>
                  </body>
                </body>
              </body>
            </body>
          </body>
        </body>
      </body>
    </body>
  </worldbody>

  <contact>
    <exclude body1="lclavicle" body2="rclavicle"/>
    <exclude body1="lowerneck" body2="lclavicle"/>
    <exclude body1="lowerneck" body2="rclavicle"/>
    <exclude body1="upperneck" body2="lclavicle"/>
    <exclude body1="upperneck" body2="rclavicle"/>
  </contact>

  <actuator>
    <motor name="headrx" joint="headrx" gear="20"/>
    <motor name="headry" joint="headry" gear="20"/>
    <motor name="headrz" joint="headrz" gear="20"/>
    <motor name="lclaviclery" joint="lclaviclery" gear="20"/>
    <motor name="lclaviclerz" joint="lclaviclerz" gear="20"/>
    <motor name="lfemurrx" joint="lfemurrx" gear="120"/>
    <motor name="lfemurry" joint="lfemurry" gear="40"/>
    <motor name="lfemurrz" joint="lfemurrz" gear="40"/>
    <motor name="lfingersrx" joint="lfingersrx" gear="20"/>
    <motor name="lfootrx" joint="lfootrx" gear="20"/>
    <motor name="lfootrz" joint="lfootrz" gear="20"/>
    <motor name="lhandrx" joint="lhandrx" gear="20"/>
    <motor name="lhandrz" joint="lhandrz" gear="20"/>
    <motor name="lhumerusrx" joint="lhumerusrx" gear="40"/>
    <motor name="lhumerusry" joint="lhumerusry" gear="40"/>
    <motor name="lhumerusrz" joint="lhumerusrz" gear="40"/>
    <motor name="lowerbackrx" joint="lowerbackrx" gear="40"/>
    <motor name="lowerbackry" joint="lowerbackry" gear="40"/>
    <motor name="lowerbackrz" joint="lowerbackrz" gear="40"/>
    <motor name="lowerneckrx" joint="lowerneckrx" gear="20"/>
    <motor name="lowerneckry" joint="lowerneckry" gear="20"/>
    <motor name="lowerneckrz" joint="lowerneckrz" gear="20"/>
    <motor name="lradiusrx" joint="lradiusrx" gear="40"/>
    <motor name="lthumbrx" joint="lthumbrx" gear="20"/>
    <motor name="lthumbrz" joint="lthumbrz" gear="20"/>
    <motor name="ltibiarx" joint="ltibiarx" gear="80"/>
    <motor name="ltoesrx" joint="ltoesrx" gear="20"/>
    <motor name="lwristry" joint="lwristry" gear="20"/>
    <motor name="rclaviclery" joint="rclaviclery" gear="20"/>
    <motor name="rclaviclerz" joint="rclaviclerz" gear="20"/>
    <motor name="rfemurrx" joint="rfemurrx" gear="120"/>
    <motor name="rfemurry" joint="rfemurry" gear="40"/>
    <motor name="rfemurrz" joint="rfemurrz" gear="40"/>
    <motor name="rfingersrx" joint="rfingersrx" gear="20"/>
    <motor name="rfootrx" joint="rfootrx" gear="20"/>
    <motor name="rfootrz" joint="rfootrz" gear="20"/>
    <motor name="rhandrx" joint="rhandrx" gear="20"/>
    <motor name="rhandrz" joint="rhandrz" gear="20"/>
    <motor name="rhumerusrx" joint="rhumerusrx" gear="40"/>
    <motor name="rhumerusry" joint="rhumerusry" gear="40"/>
    <motor name="rhumerusrz" joint="rhumerusrz" gear="40"/>
    <motor name="rradiusrx" joint="rradiusrx" gear="40"/>
    <motor name="rthumbrx" joint="rthumbrx" gear="20"/>
    <motor name="rthumbrz" joint="rthumbrz" gear="20"/>
    <motor name="rtibiarx" joint="rtibiarx" gear="80"/>
    <motor name="rtoesrx" joint="rtoesrx" gear="20"/>
    <motor name="rwristry" joint="rwristry" gear="20"/>
    <motor name="thoraxrx" joint="thoraxrx" gear="40"/>
    <motor name="thoraxry" joint="thoraxry" gear="40"/>
    <motor name="thoraxrz" joint="thoraxrz" gear="40"/>
    <motor name="upperbackrx" joint="upperbackrx" gear="40"/>
    <motor name="upperbackry" joint="upperbackry" gear="40"/>
    <motor name="upperbackrz" joint="upperbackrz" gear="40"/>
    <motor name="upperneckrx" joint="upperneckrx" gear="20"/>
    <motor name="upperneckry" joint="upperneckry" gear="20"/>
    <motor name="upperneckrz" joint="upperneckrz" gear="20"/>
  </actuator>

  <sensor>
    <subtreelinvel name="thorax_subtreelinvel" body="thorax"/>
    <velocimeter name="sensor_root_veloc" site="root"/>
    <gyro name="sensor_root_gyro" site="root"/>
    <accelerometer name="sensor_root_accel" site="root"/>
    <touch name="sensor_touch_ltoes" site="ltoes_touch"/>
    <touch name="sensor_touch_rtoes" site="rtoes_touch"/>
    <touch name="sensor_touch_rfoot" site="rfoot_touch"/>
    <touch name="sensor_touch_lfoot" site="lfoot_touch"/>
  </sensor>

</mujoco>
272 local_dm_control_suite/lqr.py Executable file
@@ -0,0 +1,272 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Procedurally generated LQR domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.utils import containers
from dm_control.utils import xml_tools
from lxml import etree
import numpy as np
from six.moves import range

from dm_control.utils import io as resources

_DEFAULT_TIME_LIMIT = float('inf')
_CONTROL_COST_COEF = 0.1
SUITE = containers.TaggedTasks()


def get_model_and_assets(n_bodies, n_actuators, random):
  """Returns the model description as an XML string and a dict of assets.

  Args:
    n_bodies: An int, number of bodies of the LQR.
    n_actuators: An int, number of actuated bodies of the LQR. `n_actuators`
      should be less or equal than `n_bodies`.
    random: A `numpy.random.RandomState` instance.

  Returns:
    A tuple `(model_xml_string, assets)`, where `assets` is a dict consisting of
    `{filename: contents_string}` pairs.
  """
  return _make_model(n_bodies, n_actuators, random), common.ASSETS


@SUITE.add()
def lqr_2_1(time_limit=_DEFAULT_TIME_LIMIT, random=None,
            environment_kwargs=None):
  """Returns an LQR environment with 2 bodies of which the first is actuated."""
  return _make_lqr(n_bodies=2,
                   n_actuators=1,
                   control_cost_coef=_CONTROL_COST_COEF,
                   time_limit=time_limit,
                   random=random,
                   environment_kwargs=environment_kwargs)


@SUITE.add()
def lqr_6_2(time_limit=_DEFAULT_TIME_LIMIT, random=None,
            environment_kwargs=None):
  """Returns an LQR environment with 6 bodies of which first 2 are actuated."""
  return _make_lqr(n_bodies=6,
                   n_actuators=2,
                   control_cost_coef=_CONTROL_COST_COEF,
                   time_limit=time_limit,
                   random=random,
                   environment_kwargs=environment_kwargs)


def _make_lqr(n_bodies, n_actuators, control_cost_coef, time_limit, random,
              environment_kwargs):
  """Returns a LQR environment.

  Args:
    n_bodies: An int, number of bodies of the LQR.
    n_actuators: An int, number of actuated bodies of the LQR. `n_actuators`
      should be less or equal than `n_bodies`.
    control_cost_coef: A number, the coefficient of the control cost.
    time_limit: An int, maximum time for each episode in seconds.
    random: Either an existing `numpy.random.RandomState` instance, an
      integer seed for creating a new `RandomState`, or None to select a seed
      automatically.
    environment_kwargs: A `dict` specifying keyword arguments for the
      environment, or None.

  Returns:
    A LQR environment with `n_bodies` bodies of which first `n_actuators` are
    actuated.
  """

  if not isinstance(random, np.random.RandomState):
    random = np.random.RandomState(random)

  model_string, assets = get_model_and_assets(n_bodies, n_actuators,
                                              random=random)
  physics = Physics.from_xml_string(model_string, assets=assets)
  task = LQRLevel(control_cost_coef, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(physics, task, time_limit=time_limit,
                             **environment_kwargs)


def _make_body(body_id, stiffness_range, damping_range, random):
  """Returns an `etree.Element` defining a body.

  Args:
    body_id: Id of the created body.
    stiffness_range: A tuple of (stiffness_lower_bound, stiffness_upper_bound).
      The stiffness of the joint is drawn uniformly from this range.
    damping_range: A tuple of (damping_lower_bound, damping_upper_bound). The
      damping of the joint is drawn uniformly from this range.
    random: A `numpy.random.RandomState` instance.

  Returns:
    A new instance of `etree.Element`. A body element with two children: joint
    and geom.
  """
  body_name = 'body_{}'.format(body_id)
  joint_name = 'joint_{}'.format(body_id)
  geom_name = 'geom_{}'.format(body_id)

  body = etree.Element('body', name=body_name)
  body.set('pos', '.25 0 0')
  joint = etree.SubElement(body, 'joint', name=joint_name)
  body.append(etree.Element('geom', name=geom_name))
  joint.set('stiffness',
            str(random.uniform(stiffness_range[0], stiffness_range[1])))
  joint.set('damping',
            str(random.uniform(damping_range[0], damping_range[1])))
  return body


def _make_model(n_bodies,
                n_actuators,
                random,
                stiffness_range=(15, 25),
                damping_range=(0, 0)):
  """Returns an MJCF XML string defining a model of springs and dampers.

  Args:
    n_bodies: An integer, the number of bodies (DoFs) in the system.
    n_actuators: An integer, the number of actuated bodies.
    random: A `numpy.random.RandomState` instance.
    stiffness_range: A tuple containing minimum and maximum stiffness. Each
      joint's stiffness is sampled uniformly from this interval.
    damping_range: A tuple containing minimum and maximum damping. Each joint's
      damping is sampled uniformly from this interval.

  Returns:
    An MJCF string describing the linear system.

  Raises:
    ValueError: If the number of bodies or actuators is erroneous.
  """
  if n_bodies < 1 or n_actuators < 1:
    raise ValueError('At least 1 body and 1 actuator required.')
  if n_actuators > n_bodies:
    raise ValueError('At most 1 actuator per body.')

  file_path = os.path.join(os.path.dirname(__file__), 'lqr.xml')
  with resources.GetResourceAsFile(file_path) as xml_file:
    mjcf = xml_tools.parse(xml_file)
  parent = mjcf.find('./worldbody')
  actuator = etree.SubElement(mjcf.getroot(), 'actuator')
  tendon = etree.SubElement(mjcf.getroot(), 'tendon')

  for body in range(n_bodies):
    # Inserting body.
    child = _make_body(body, stiffness_range, damping_range, random)
    site_name = 'site_{}'.format(body)
    child.append(etree.Element('site', name=site_name))

    if body == 0:
      child.set('pos', '.25 0 .1')
    # Add actuators to the first n_actuators bodies.
    if body < n_actuators:
      # Adding actuator.
      joint_name = 'joint_{}'.format(body)
      motor_name = 'motor_{}'.format(body)
      child.find('joint').set('name', joint_name)
      actuator.append(etree.Element('motor', name=motor_name, joint=joint_name))

    # Add a tendon between consecutive bodies (for visualisation purposes only).
    if body < n_bodies - 1:
      child_site_name = 'site_{}'.format(body + 1)
      tendon_name = 'tendon_{}'.format(body)
      spatial = etree.SubElement(tendon, 'spatial', name=tendon_name)
      spatial.append(etree.Element('site', site=site_name))
      spatial.append(etree.Element('site', site=child_site_name))
    parent.append(child)
    parent = child

  return etree.tostring(mjcf, pretty_print=True)


class Physics(mujoco.Physics):
  """Physics simulation with additional features for the LQR domain."""

  def state_norm(self):
    """Returns the norm of the physics state."""
    return np.linalg.norm(self.state())


class LQRLevel(base.Task):
  """A Linear Quadratic Regulator `Task`."""

  _TERMINAL_TOL = 1e-6

  def __init__(self, control_cost_coef, random=None):
    """Initializes an LQR level with cost = sum(states^2) + c*sum(controls^2).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
control_cost_coef: The coefficient of the control cost.
|
||||||
|
random: Optional, either a `numpy.random.RandomState` instance, an
|
||||||
|
integer seed for creating a new `RandomState`, or None to select a seed
|
||||||
|
automatically (default).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the control cost coefficient is not positive.
|
||||||
|
"""
|
||||||
|
if control_cost_coef <= 0:
|
||||||
|
raise ValueError('control_cost_coef must be positive.')
|
||||||
|
|
||||||
|
self._control_cost_coef = control_cost_coef
|
||||||
|
super(LQRLevel, self).__init__(random=random)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def control_cost_coef(self):
|
||||||
|
return self._control_cost_coef
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Random state sampled from a unit sphere."""
|
||||||
|
ndof = physics.model.nq
|
||||||
|
unit = self.random.randn(ndof)
|
||||||
|
physics.data.qpos[:] = np.sqrt(2) * unit / np.linalg.norm(unit)
|
||||||
|
super(LQRLevel, self).initialize_episode(physics)
|
||||||
|
|
||||||
|
def get_observation(self, physics):
|
||||||
|
"""Returns an observation of the state."""
|
||||||
|
obs = collections.OrderedDict()
|
||||||
|
obs['position'] = physics.position()
|
||||||
|
obs['velocity'] = physics.velocity()
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
"""Returns a quadratic state and control reward."""
|
||||||
|
position = physics.position()
|
||||||
|
state_cost = 0.5 * np.dot(position, position)
|
||||||
|
control_signal = physics.control()
|
||||||
|
control_l2_norm = 0.5 * np.dot(control_signal, control_signal)
|
||||||
|
return 1 - (state_cost + control_l2_norm * self._control_cost_coef)
|
||||||
|
|
||||||
|
def get_evaluation(self, physics):
|
||||||
|
"""Returns a sparse evaluation reward that is not used for learning."""
|
||||||
|
return float(physics.state_norm() <= 0.01)
|
||||||
|
|
||||||
|
def get_termination(self, physics):
|
||||||
|
"""Terminates when the state norm is smaller than epsilon."""
|
||||||
|
if physics.state_norm() < self._TERMINAL_TOL:
|
||||||
|
return 0.0
|
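A minimal usage sketch for the LQR tasks defined above (editor's illustration, not part of the original file; it assumes the package is importable as `local_dm_control_suite`, consistent with the imports used elsewhere in this commit, and the zero-action loop is purely illustrative):

    import numpy as np
    from local_dm_control_suite import lqr

    env = lqr.lqr_2_1(random=0)            # 2 bodies, first one actuated.
    time_step = env.reset()
    while not time_step.last():
      zero_action = np.zeros(env.action_spec().shape)
      time_step = env.step(zero_action)    # Reward is 1 minus the quadratic cost.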
26
local_dm_control_suite/lqr.xml
Executable file
@ -0,0 +1,26 @@
<mujoco model="LQR">
  <include file="./common/skybox.xml"/>
  <include file="./common/visual.xml"/>
  <include file="./common/materials.xml"/>

  <option timestep=".03"/>

  <default>
    <joint type="slide" axis="0 1 0"/>
    <geom type="sphere" size=".1" material="self"/>
    <site size=".01"/>
    <tendon width=".02" material="self"/>
  </default>

  <option>
    <flag constraint="disable"/>
  </option>

  <worldbody>
    <light name="light" pos="0 0 2"/>
    <camera name="cam0" pos="-1.428 -0.311 0.856" xyaxes="0.099 -0.995 0.000 0.350 0.035 0.936"/>
    <camera name="cam1" pos="1.787 2.452 4.331" xyaxes="-1 0 0 0 -0.868 0.497"/>
    <geom name="floor" size="4 1 .2" type="plane" material="grid"/>
    <geom name="origin" pos="2 0 .05" size="2 .003 .05" type="box" rgba=".5 .5 .5 .5"/>
  </worldbody>
</mujoco>
142
local_dm_control_suite/lqr_solver.py
Executable file
@ -0,0 +1,142 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

r"""Optimal policy for LQR levels.

The LQR control problem is described in
https://en.wikipedia.org/wiki/Linear-quadratic_regulator#Infinite-horizon.2C_discrete-time_LQR
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging
from dm_control.mujoco import wrapper
import numpy as np
from six.moves import range

try:
  import scipy.linalg as sp  # pylint: disable=g-import-not-at-top
except ImportError:
  sp = None


def _solve_dare(a, b, q, r):
  """Solves the Discrete-time Algebraic Riccati Equation (DARE) by iteration.

  Algebraic Riccati Equation:
  ```none
  P_{t-1} = Q + A' * P_{t} * A -
            A' * P_{t} * B * (R + B' * P_{t} * B)^{-1} * B' * P_{t} * A
  ```

  Args:
    a: A 2 dimensional numpy array, transition matrix A.
    b: A 2 dimensional numpy array, control matrix B.
    q: A 2 dimensional numpy array, symmetric positive definite cost matrix.
    r: A 2 dimensional numpy array, symmetric positive definite cost matrix.

  Returns:
    A numpy array, a real symmetric matrix P which is the solution to DARE.

  Raises:
    RuntimeError: If the computed P matrix is not symmetric and
      positive-definite.
  """
  p = np.eye(len(a))
  for _ in range(1000000):
    a_p = a.T.dot(p)  # A' * P_t
    a_p_b = np.dot(a_p, b)  # A' * P_t * B
    # Algebraic Riccati Equation.
    p_next = q + np.dot(a_p, a) - a_p_b.dot(
        np.linalg.solve(b.T.dot(p.dot(b)) + r, a_p_b.T))
    p_next += p_next.T
    p_next *= .5
    if np.abs(p - p_next).max() < 1e-12:
      break
    p = p_next
  else:
    logging.warning('DARE solver did not converge')
  try:
    # Check that the result is symmetric and positive-definite.
    np.linalg.cholesky(p_next)
  except np.linalg.LinAlgError:
    raise RuntimeError('ARE solver failed: P matrix is not symmetric and '
                       'positive-definite.')
  return p_next


def solve(env):
  """Returns the optimal value and policy for LQR problem.

  Args:
    env: An instance of `control.EnvironmentV2` with LQR level.

  Returns:
    p: A numpy array, the Hessian of the optimal total cost-to-go (value
      function at state x) is V(x) = .5 * x' * p * x.
    k: A numpy array which gives the optimal linear policy u = k * x.
    beta: The maximum eigenvalue of (a + b * k). Under optimal policy, at
      timestep n the state tends to 0 like beta^n.

  Raises:
    RuntimeError: If the controlled system is unstable.
  """
  n = env.physics.model.nq  # number of DoFs
  m = env.physics.model.nu  # number of controls

  # Compute the mass matrix.
  mass = np.zeros((n, n))
  wrapper.mjbindings.mjlib.mj_fullM(env.physics.model.ptr, mass,
                                    env.physics.data.qM)

  # Compute input matrices a, b, q and r to the DARE solvers.
  # State transition matrix a.
  stiffness = np.diag(env.physics.model.jnt_stiffness.ravel())
  damping = np.diag(env.physics.model.dof_damping.ravel())
  dt = env.physics.model.opt.timestep

  j = np.linalg.solve(-mass, np.hstack((stiffness, damping)))
  a = np.eye(2 * n) + dt * np.vstack(
      (dt * j + np.hstack((np.zeros((n, n)), np.eye(n))), j))

  # Control transition matrix b.
  b = env.physics.data.actuator_moment.T
  bc = np.linalg.solve(mass, b)
  b = dt * np.vstack((dt * bc, bc))

  # State cost Hessian q.
  q = np.diag(np.hstack([np.ones(n), np.zeros(n)]))

  # Control cost Hessian r.
  r = env.task.control_cost_coef * np.eye(m)

  if sp:
    # Use scipy's faster DARE solver if available.
    solve_dare = sp.solve_discrete_are
  else:
    # Otherwise fall back on a slower internal implementation.
    solve_dare = _solve_dare

  # Solve the discrete algebraic Riccati equation.
  p = solve_dare(a, b, q, r)
  k = -np.linalg.solve(b.T.dot(p.dot(b)) + r, b.T.dot(p.dot(a)))

  # Under optimal policy, state tends to 0 like beta^n_timesteps
  beta = np.abs(np.linalg.eigvals(a + b.dot(k))).max()
  if beta >= 1.0:
    raise RuntimeError('Controlled system is unstable.')
  return p, k, beta
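A short sketch of how `solve` can drive an LQR environment with the optimal linear policy u = k * x (editor's illustration under the assumption that `env.physics.state()` returns the concatenated [qpos, qvel] vector used to build the A and B matrices above; the import path follows this commit's layout):

    from local_dm_control_suite import lqr, lqr_solver

    env = lqr.lqr_2_1(random=0)
    p, k, beta = lqr_solver.solve(env)

    time_step = env.reset()
    while not time_step.last():
      x = env.physics.state()     # Concatenated [qpos, qvel].
      u = k.dot(x)                # Optimal linear feedback.
      time_step = env.step(u)
    # Under this policy the state norm should contract roughly like beta ** num_steps.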
290
local_dm_control_suite/manipulator.py
Executable file
@ -0,0 +1,290 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Planar Manipulator domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
from dm_control.utils import xml_tools

from lxml import etree
import numpy as np

_CLOSE = .01  # (Meters) Distance below which a thing is considered close.
_CONTROL_TIMESTEP = .01  # (Seconds)
_TIME_LIMIT = 10  # (Seconds)
_P_IN_HAND = .1  # Probability of object-in-hand initial state
_P_IN_TARGET = .1  # Probability of object-in-target initial state
_ARM_JOINTS = ['arm_root', 'arm_shoulder', 'arm_elbow', 'arm_wrist',
               'finger', 'fingertip', 'thumb', 'thumbtip']
_ALL_PROPS = frozenset(['ball', 'target_ball', 'cup',
                        'peg', 'target_peg', 'slot'])

SUITE = containers.TaggedTasks()


def make_model(use_peg, insert):
  """Returns a tuple containing the model XML string and a dict of assets."""
  xml_string = common.read_model('manipulator.xml')
  parser = etree.XMLParser(remove_blank_text=True)
  mjcf = etree.XML(xml_string, parser)

  # Select the desired prop.
  if use_peg:
    required_props = ['peg', 'target_peg']
    if insert:
      required_props += ['slot']
  else:
    required_props = ['ball', 'target_ball']
    if insert:
      required_props += ['cup']

  # Remove unused props
  for unused_prop in _ALL_PROPS.difference(required_props):
    prop = xml_tools.find_element(mjcf, 'body', unused_prop)
    prop.getparent().remove(prop)

  return etree.tostring(mjcf, pretty_print=True), common.ASSETS


@SUITE.add('benchmarking', 'hard')
def bring_ball(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
               environment_kwargs=None):
  """Returns manipulator bring task with the ball prop."""
  use_peg = False
  insert = False
  physics = Physics.from_xml_string(*make_model(use_peg, insert))
  task = Bring(use_peg=use_peg, insert=insert,
               fully_observable=fully_observable, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
      **environment_kwargs)


@SUITE.add('hard')
def bring_peg(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
              environment_kwargs=None):
  """Returns manipulator bring task with the peg prop."""
  use_peg = True
  insert = False
  physics = Physics.from_xml_string(*make_model(use_peg, insert))
  task = Bring(use_peg=use_peg, insert=insert,
               fully_observable=fully_observable, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
      **environment_kwargs)


@SUITE.add('hard')
def insert_ball(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
                environment_kwargs=None):
  """Returns manipulator insert task with the ball prop."""
  use_peg = False
  insert = True
  physics = Physics.from_xml_string(*make_model(use_peg, insert))
  task = Bring(use_peg=use_peg, insert=insert,
               fully_observable=fully_observable, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
      **environment_kwargs)


@SUITE.add('hard')
def insert_peg(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
               environment_kwargs=None):
  """Returns manipulator insert task with the peg prop."""
  use_peg = True
  insert = True
  physics = Physics.from_xml_string(*make_model(use_peg, insert))
  task = Bring(use_peg=use_peg, insert=insert,
               fully_observable=fully_observable, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
      **environment_kwargs)


class Physics(mujoco.Physics):
  """Physics with additional features for the Planar Manipulator domain."""

  def bounded_joint_pos(self, joint_names):
    """Returns joint positions as (sin, cos) values."""
    joint_pos = self.named.data.qpos[joint_names]
    return np.vstack([np.sin(joint_pos), np.cos(joint_pos)]).T

  def joint_vel(self, joint_names):
    """Returns joint velocities."""
    return self.named.data.qvel[joint_names]

  def body_2d_pose(self, body_names, orientation=True):
    """Returns positions and/or orientations of bodies."""
    if not isinstance(body_names, str):
      body_names = np.array(body_names).reshape(-1, 1)  # Broadcast indices.
    pos = self.named.data.xpos[body_names, ['x', 'z']]
    if orientation:
      ori = self.named.data.xquat[body_names, ['qw', 'qy']]
      return np.hstack([pos, ori])
    else:
      return pos

  def touch(self):
    return np.log1p(self.data.sensordata)

  def site_distance(self, site1, site2):
    site1_to_site2 = np.diff(self.named.data.site_xpos[[site2, site1]], axis=0)
    return np.linalg.norm(site1_to_site2)


class Bring(base.Task):
  """A Bring `Task`: bring the prop to the target."""

  def __init__(self, use_peg, insert, fully_observable, random=None):
    """Initialize an instance of the `Bring` task.

    Args:
      use_peg: A `bool`, whether to replace the ball prop with the peg prop.
      insert: A `bool`, whether to insert the prop in a receptacle.
      fully_observable: A `bool`, whether the observation should contain the
        position and velocity of the object being manipulated and the target
        location.
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
    self._use_peg = use_peg
    self._target = 'target_peg' if use_peg else 'target_ball'
    self._object = 'peg' if self._use_peg else 'ball'
    self._object_joints = ['_'.join([self._object, dim]) for dim in 'xzy']
    self._receptacle = 'slot' if self._use_peg else 'cup'
    self._insert = insert
    self._fully_observable = fully_observable
    super(Bring, self).__init__(random=random)

  def initialize_episode(self, physics):
    """Sets the state of the environment at the start of each episode."""
    # Local aliases
    choice = self.random.choice
    uniform = self.random.uniform
    model = physics.named.model
    data = physics.named.data

    # Find a collision-free random initial configuration.
    penetrating = True
    while penetrating:

      # Randomise angles of arm joints.
      is_limited = model.jnt_limited[_ARM_JOINTS].astype(np.bool)
      joint_range = model.jnt_range[_ARM_JOINTS]
      lower_limits = np.where(is_limited, joint_range[:, 0], -np.pi)
      upper_limits = np.where(is_limited, joint_range[:, 1], np.pi)
      angles = uniform(lower_limits, upper_limits)
      data.qpos[_ARM_JOINTS] = angles

      # Symmetrize hand.
      data.qpos['finger'] = data.qpos['thumb']

      # Randomise target location.
      target_x = uniform(-.4, .4)
      target_z = uniform(.1, .4)
      if self._insert:
        target_angle = uniform(-np.pi/3, np.pi/3)
        model.body_pos[self._receptacle, ['x', 'z']] = target_x, target_z
        model.body_quat[self._receptacle, ['qw', 'qy']] = [
            np.cos(target_angle/2), np.sin(target_angle/2)]
      else:
        target_angle = uniform(-np.pi, np.pi)

      model.body_pos[self._target, ['x', 'z']] = target_x, target_z
      model.body_quat[self._target, ['qw', 'qy']] = [
          np.cos(target_angle/2), np.sin(target_angle/2)]

      # Randomise object location.
      object_init_probs = [_P_IN_HAND, _P_IN_TARGET, 1-_P_IN_HAND-_P_IN_TARGET]
      init_type = choice(['in_hand', 'in_target', 'uniform'],
                         p=object_init_probs)
      if init_type == 'in_target':
        object_x = target_x
        object_z = target_z
        object_angle = target_angle
      elif init_type == 'in_hand':
        physics.after_reset()
        object_x = data.site_xpos['grasp', 'x']
        object_z = data.site_xpos['grasp', 'z']
        grasp_direction = data.site_xmat['grasp', ['xx', 'zx']]
        object_angle = np.pi-np.arctan2(grasp_direction[1], grasp_direction[0])
      else:
        object_x = uniform(-.5, .5)
        object_z = uniform(0, .7)
        object_angle = uniform(0, 2*np.pi)
        data.qvel[self._object + '_x'] = uniform(-5, 5)

      data.qpos[self._object_joints] = object_x, object_z, object_angle

      # Check for collisions.
      physics.after_reset()
      penetrating = physics.data.ncon > 0

    super(Bring, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns either features or only sensors (to be used with pixels)."""
    obs = collections.OrderedDict()
    obs['arm_pos'] = physics.bounded_joint_pos(_ARM_JOINTS)
    obs['arm_vel'] = physics.joint_vel(_ARM_JOINTS)
    obs['touch'] = physics.touch()
    if self._fully_observable:
      obs['hand_pos'] = physics.body_2d_pose('hand')
      obs['object_pos'] = physics.body_2d_pose(self._object)
      obs['object_vel'] = physics.joint_vel(self._object_joints)
      obs['target_pos'] = physics.body_2d_pose(self._target)
    return obs

  def _is_close(self, distance):
    return rewards.tolerance(distance, (0, _CLOSE), _CLOSE*2)

  def _peg_reward(self, physics):
    """Returns a reward for bringing the peg prop to the target."""
    grasp = self._is_close(physics.site_distance('peg_grasp', 'grasp'))
    pinch = self._is_close(physics.site_distance('peg_pinch', 'pinch'))
    grasping = (grasp + pinch) / 2
    bring = self._is_close(physics.site_distance('peg', 'target_peg'))
    bring_tip = self._is_close(physics.site_distance('target_peg_tip',
                                                     'peg_tip'))
    bringing = (bring + bring_tip) / 2
    return max(bringing, grasping/3)

  def _ball_reward(self, physics):
    """Returns a reward for bringing the ball prop to the target."""
    return self._is_close(physics.site_distance('ball', 'target_ball'))

  def get_reward(self, physics):
    """Returns a reward to the agent."""
    if self._use_peg:
      return self._peg_reward(physics)
    else:
      return self._ball_reward(physics)
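For reference, a hedged sketch of what the fully-observable `bring_ball` observation looks like (editor's illustration; key names follow `get_observation` above, and the shapes are indicative only):

    from local_dm_control_suite import manipulator

    env = manipulator.bring_ball(fully_observable=True, random=0)
    obs = env.reset().observation
    # obs['arm_pos']    : (8, 2) sin/cos pairs for the eight arm joints.
    # obs['arm_vel']    : (8,)   joint velocities.
    # obs['touch']      : (5,)   log1p-scaled touch sensors.
    # obs['hand_pos'], obs['object_pos'], obs['target_pos'] : (4,) x, z, qw, qy.
    # obs['object_vel'] : (3,)   velocities of the object's x, z, y joints.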
211
local_dm_control_suite/manipulator.xml
Executable file
@ -0,0 +1,211 @@
<mujoco model="planar manipulator">

  <include file="./common/visual.xml"/>
  <include file="./common/skybox.xml"/>
  <include file="./common/materials.xml"/>
  <asset>
    <texture name="background" builtin="flat" type="2d" mark="random" markrgb="1 1 1" width="800" height="800" rgb1=".2 .3 .4"/>
    <material name="background" texture="background" texrepeat="1 1" texuniform="true"/>
  </asset>

  <visual>
    <map shadowclip=".5"/>
    <quality shadowsize="2048"/>
  </visual>

  <option timestep="0.001" cone="elliptic"/>

  <default>
    <geom friction=".7" solimp="0.9 0.97 0.001" solref=".005 1"/>
    <joint solimplimit="0 0.99 0.01" solreflimit=".005 1"/>
    <general ctrllimited="true"/>
    <tendon width="0.01"/>
    <site size=".003 .003 .003" material="site" group="3"/>

    <default class="arm">
      <geom type="capsule" material="self" density="500"/>
      <joint type="hinge" pos="0 0 0" axis="0 -1 0" limited="true"/>
      <default class="hand">
        <joint damping=".5" range="-10 60"/>
        <geom size=".008"/>
        <site type="box" size=".018 .005 .005" pos=".022 0 -.002" euler="0 15 0" group="4"/>
        <default class="fingertip">
          <geom type="sphere" size=".008" material="effector"/>
          <joint damping=".01" stiffness=".01" range="-40 20"/>
          <site size=".012 .005 .008" pos=".003 0 .003" group="4" euler="0 0 0"/>
        </default>
      </default>
    </default>

    <default class="object">
      <geom material="self"/>
    </default>

    <default class="task">
      <site rgba="0 0 0 0"/>
    </default>

    <default class="obstacle">
      <geom material="decoration" friction="0"/>
    </default>

    <default class="ghost">
      <geom material="target" contype="0" conaffinity="0"/>
    </default>
  </default>

  <worldbody>
    <!-- Arena -->
    <light name="light" directional="true" diffuse=".6 .6 .6" pos="0 0 1" specular=".3 .3 .3"/>
    <geom name="floor" type="plane" pos="0 0 0" size=".4 .2 10" material="grid"/>
    <geom name="wall1" type="plane" pos="-.682843 0 .282843" size=".4 .2 10" material="grid" zaxis="1 0 1"/>
    <geom name="wall2" type="plane" pos=".682843 0 .282843" size=".4 .2 10" material="grid" zaxis="-1 0 1"/>
    <geom name="background" type="plane" pos="0 .2 .5" size="1 .5 10" material="background" zaxis="0 -1 0"/>
    <camera name="fixed" pos="0 -16 .4" xyaxes="1 0 0 0 0 1" fovy="4"/>

    <!-- Arm -->
    <geom name="arm_root" type="cylinder" fromto="0 -.022 .4 0 .022 .4" size=".024"
          material="decoration" contype="0" conaffinity="0"/>
    <body name="upper_arm" pos="0 0 .4" childclass="arm">
      <joint name="arm_root" damping="2" limited="false"/>
      <geom name="upper_arm" size=".02" fromto="0 0 0 0 0 .18"/>
      <body name="middle_arm" pos="0 0 .18" childclass="arm">
        <joint name="arm_shoulder" damping="1.5" range="-160 160"/>
        <geom name="middle_arm" size=".017" fromto="0 0 0 0 0 .15"/>
        <body name="lower_arm" pos="0 0 .15">
          <joint name="arm_elbow" damping="1" range="-160 160"/>
          <geom name="lower_arm" size=".014" fromto="0 0 0 0 0 .12"/>
          <body name="hand" pos="0 0 .12">
            <joint name="arm_wrist" damping=".5" range="-140 140" />
            <geom name="hand" size=".011" fromto="0 0 0 0 0 .03"/>
            <geom name="palm1" fromto="0 0 .03 .03 0 .045" class="hand"/>
            <geom name="palm2" fromto="0 0 .03 -.03 0 .045" class="hand"/>
            <site name="grasp" pos="0 0 .065"/>
            <body name="pinch site" pos="0 0 .090">
              <site name="pinch"/>
              <inertial pos="0 0 0" mass="1e-6" diaginertia="1e-12 1e-12 1e-12"/>
              <camera name="hand" pos="0 -.3 0" xyaxes="1 0 0 0 0 1" mode="track"/>
            </body>
            <site name="palm_touch" type="box" group="4" size=".025 .005 .008" pos="0 0 .043"/>

            <body name="thumb" pos=".03 0 .045" euler="0 -90 0" childclass="hand">
              <joint name="thumb"/>
              <geom name="thumb1" fromto="0 0 0 .02 0 -.01" size=".007"/>
              <geom name="thumb2" fromto=".02 0 -.01 .04 0 -.01" size=".007"/>
              <site name="thumb_touch" group="4"/>
              <body name="thumbtip" pos=".05 0 -.01" childclass="fingertip">
                <joint name="thumbtip"/>
                <geom name="thumbtip1" pos="-.003 0 0" />
                <geom name="thumbtip2" pos=".003 0 0" />
                <site name="thumbtip_touch" group="4"/>
              </body>
            </body>

            <body name="finger" pos="-.03 0 .045" euler="0 90 180" childclass="hand">
              <joint name="finger"/>
              <geom name="finger1" fromto="0 0 0 .02 0 -.01" size=".007" />
              <geom name="finger2" fromto=".02 0 -.01 .04 0 -.01" size=".007"/>
              <site name="finger_touch"/>
              <body name="fingertip" pos=".05 0 -.01" childclass="fingertip">
                <joint name="fingertip"/>
                <geom name="fingertip1" pos="-.003 0 0" />
                <geom name="fingertip2" pos=".003 0 0" />
                <site name="fingertip_touch"/>
              </body>
            </body>
          </body>
        </body>
      </body>
    </body>

    <!-- props -->
    <body name="ball" pos=".4 0 .4" childclass="object">
      <joint name="ball_x" type="slide" axis="1 0 0" ref=".4"/>
      <joint name="ball_z" type="slide" axis="0 0 1" ref=".4"/>
      <joint name="ball_y" type="hinge" axis="0 1 0"/>
      <geom name="ball" type="sphere" size=".022" />
      <site name="ball" type="sphere"/>
    </body>

    <body name="peg" pos="-.4 0 .4" childclass="object">
      <joint name="peg_x" type="slide" axis="1 0 0" ref="-.4"/>
      <joint name="peg_z" type="slide" axis="0 0 1" ref=".4"/>
      <joint name="peg_y" type="hinge" axis="0 1 0"/>
      <geom name="blade" type="capsule" size=".005" fromto="0 0 -.013 0 0 -.113"/>
      <geom name="guard" type="capsule" size=".005" fromto="-.017 0 -.043 .017 0 -.043"/>
      <body name="pommel" pos="0 0 -.013">
        <geom name="pommel" type="sphere" size=".009"/>
      </body>
      <site name="peg" type="box" pos="0 0 -.063"/>
      <site name="peg_pinch" type="box" pos="0 0 -.025"/>
      <site name="peg_grasp" type="box" pos="0 0 0"/>
      <site name="peg_tip" type="box" pos="0 0 -.113"/>
    </body>

    <!-- receptacles -->
    <body name="slot" pos="-.405 0 .2" euler="0 20 0" childclass="obstacle">
      <geom name="slot_0" type="box" pos="-.0252 0 -.083" size=".0198 .01 .035"/>
      <geom name="slot_1" type="box" pos=" .0252 0 -.083" size=".0198 .01 .035"/>
      <geom name="slot_2" type="box" pos=" 0 0 -.138" size=".045 .01 .02"/>
      <site name="slot" type="box" pos="0 0 0"/>
      <site name="slot_end" type="box" pos="0 0 -.05"/>
    </body>

    <body name="cup" pos=".3 0 .4" euler="0 -15 0" childclass="obstacle">
      <geom name="cup_0" type="capsule" size=".008" fromto="-.03 0 .06 -.03 0 -.015" />
      <geom name="cup_1" type="capsule" size=".008" fromto="-.03 0 -.015 0 0 -.04" />
      <geom name="cup_2" type="capsule" size=".008" fromto="0 0 -.04 .03 0 -.015" />
      <geom name="cup_3" type="capsule" size=".008" fromto=".03 0 -.015 .03 0 .06" />
      <site name="cup" size=".005"/>
    </body>

    <!-- targets -->
    <body name="target_ball" pos=".4 .001 .4" childclass="ghost">
      <geom name="target_ball" type="sphere" size=".02" />
      <site name="target_ball" type="sphere"/>
    </body>

    <body name="target_peg" pos="-.2 .001 .4" childclass="ghost">
      <geom name="target_blade" type="capsule" size=".005" fromto="0 0 -.013 0 0 -.113"/>
      <geom name="target_guard" type="capsule" size=".005" fromto="-.017 0 -.043 .017 0 -.043"/>
      <geom name="target_pommel" type="sphere" size=".009" pos="0 0 -.013"/>
      <site name="target_peg" type="box" pos="0 0 -.063"/>
      <site name="target_peg_pinch" type="box" pos="0 0 -.025"/>
      <site name="target_peg_grasp" type="box" pos="0 0 0"/>
      <site name="target_peg_tip" type="box" pos="0 0 -.113"/>
    </body>

  </worldbody>

  <tendon>
    <fixed name="grasp">
      <joint joint="thumb" coef=".5"/>
      <joint joint="finger" coef=".5"/>
    </fixed>
    <fixed name="coupling">
      <joint joint="thumb" coef="-.5"/>
      <joint joint="finger" coef=".5"/>
    </fixed>
  </tendon>

  <equality>
    <tendon name="coupling" tendon1="coupling" solimp="0.95 0.99 0.001" solref=".005 .5"/>
  </equality>

  <sensor>
    <touch name="palm_touch" site="palm_touch"/>
    <touch name="finger_touch" site="finger_touch"/>
    <touch name="thumb_touch" site="thumb_touch"/>
    <touch name="fingertip_touch" site="fingertip_touch"/>
    <touch name="thumbtip_touch" site="thumbtip_touch"/>
  </sensor>

  <actuator>
    <motor name="root" joint="arm_root" ctrlrange="-1 1" gear="12"/>
    <motor name="shoulder" joint="arm_shoulder" ctrlrange="-1 1" gear="8"/>
    <motor name="elbow" joint="arm_elbow" ctrlrange="-1 1" gear="4"/>
    <motor name="wrist" joint="arm_wrist" ctrlrange="-1 1" gear="2"/>
    <motor name="grasp" tendon="grasp" ctrlrange="-1 1" gear="2"/>
  </actuator>

</mujoco>
114
local_dm_control_suite/pendulum.py
Executable file
@ -0,0 +1,114 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Pendulum domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
import numpy as np


_DEFAULT_TIME_LIMIT = 20
_ANGLE_BOUND = 8
_COSINE_BOUND = np.cos(np.deg2rad(_ANGLE_BOUND))
SUITE = containers.TaggedTasks()


def get_model_and_assets():
  """Returns a tuple containing the model XML string and a dict of assets."""
  return common.read_model('pendulum.xml'), common.ASSETS


@SUITE.add('benchmarking')
def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
            environment_kwargs=None):
  """Returns pendulum swingup task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = SwingUp(random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, **environment_kwargs)


class Physics(mujoco.Physics):
  """Physics simulation with additional features for the Pendulum domain."""

  def pole_vertical(self):
    """Returns vertical (z) component of pole frame."""
    return self.named.data.xmat['pole', 'zz']

  def angular_velocity(self):
    """Returns the angular velocity of the pole."""
    return self.named.data.qvel['hinge'].copy()

  def pole_orientation(self):
    """Returns both horizontal and vertical components of pole frame."""
    return self.named.data.xmat['pole', ['zz', 'xz']]


class SwingUp(base.Task):
  """A Pendulum `Task` to swing up and balance the pole."""

  def __init__(self, random=None):
    """Initialize an instance of `Pendulum`.

    Args:
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
    super(SwingUp, self).__init__(random=random)

  def initialize_episode(self, physics):
    """Sets the state of the environment at the start of each episode.

    Pole is set to a random angle between [-pi, pi).

    Args:
      physics: An instance of `Physics`.

    """
    physics.named.data.qpos['hinge'] = self.random.uniform(-np.pi, np.pi)
    super(SwingUp, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns an observation.

    Observations are states concatenating pole orientation and angular velocity
    and pixels from fixed camera.

    Args:
      physics: An instance of `physics`, Pendulum physics.

    Returns:
      A `dict` of observation.
    """
    obs = collections.OrderedDict()
    obs['orientation'] = physics.pole_orientation()
    obs['velocity'] = physics.angular_velocity()
    return obs

  def get_reward(self, physics):
    return rewards.tolerance(physics.pole_vertical(), (_COSINE_BOUND, 1))
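The reward above is 1 only while the pole's z-axis lies within _ANGLE_BOUND degrees of vertical, since `rewards.tolerance` with its default zero margin returns 1 inside the bounds and 0 outside. A small hedged check (editor's illustration, values are indicative):

    import numpy as np
    from dm_control.utils import rewards

    cosine_bound = np.cos(np.deg2rad(8))
    print(rewards.tolerance(1.0, (cosine_bound, 1)))   # upright    -> 1.0
    print(rewards.tolerance(0.0, (cosine_bound, 1)))   # horizontal -> 0.0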
26
local_dm_control_suite/pendulum.xml
Executable file
@ -0,0 +1,26 @@
<mujoco model="pendulum">
  <include file="./common/visual.xml"/>
  <include file="./common/skybox.xml"/>
  <include file="./common/materials.xml"/>

  <option timestep="0.02">
    <flag contact="disable" energy="enable"/>
  </option>

  <worldbody>
    <light name="light" pos="0 0 2"/>
    <geom name="floor" size="2 2 .2" type="plane" material="grid"/>
    <camera name="fixed" pos="0 -1.5 2" xyaxes='1 0 0 0 1 1'/>
    <camera name="lookat" mode="targetbodycom" target="pole" pos="0 -2 1"/>
    <body name="pole" pos="0 0 .6">
      <joint name="hinge" type="hinge" axis="0 1 0" damping="0.1"/>
      <geom name="base" material="decoration" type="cylinder" fromto="0 -.03 0 0 .03 0" size="0.021" mass="0"/>
      <geom name="pole" material="self" type="capsule" fromto="0 0 0 0 0 0.5" size="0.02" mass="0"/>
      <geom name="mass" material="effector" type="sphere" pos="0 0 0.5" size="0.05" mass="1"/>
    </body>
  </worldbody>

  <actuator>
    <motor name="torque" joint="hinge" gear="1" ctrlrange="-1 1" ctrllimited="true"/>
  </actuator>
</mujoco>
130
local_dm_control_suite/point_mass.py
Executable file
@ -0,0 +1,130 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Point-mass domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
import numpy as np

_DEFAULT_TIME_LIMIT = 20
SUITE = containers.TaggedTasks()


def get_model_and_assets():
  """Returns a tuple containing the model XML string and a dict of assets."""
  return common.read_model('point_mass.xml'), common.ASSETS


@SUITE.add('benchmarking', 'easy')
def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the easy point_mass task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = PointMass(randomize_gains=False, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, **environment_kwargs)


@SUITE.add()
def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the hard point_mass task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = PointMass(randomize_gains=True, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, **environment_kwargs)


class Physics(mujoco.Physics):
  """Physics for the point_mass domain."""

  def mass_to_target(self):
    """Returns the vector from mass to target in global coordinate."""
    return (self.named.data.geom_xpos['target'] -
            self.named.data.geom_xpos['pointmass'])

  def mass_to_target_dist(self):
    """Returns the distance from mass to the target."""
    return np.linalg.norm(self.mass_to_target())


class PointMass(base.Task):
  """A point_mass `Task` to reach target with smooth reward."""

  def __init__(self, randomize_gains, random=None):
    """Initialize an instance of `PointMass`.

    Args:
      randomize_gains: A `bool`, whether to randomize the actuator gains.
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
    self._randomize_gains = randomize_gains
    super(PointMass, self).__init__(random=random)

  def initialize_episode(self, physics):
    """Sets the state of the environment at the start of each episode.

    If _randomize_gains is True, the relationship between the controls and
    the joints is randomized, so that each control actuates a random linear
    combination of joints.

    Args:
      physics: An instance of `mujoco.Physics`.
    """
    randomizers.randomize_limited_and_rotational_joints(physics, self.random)
    if self._randomize_gains:
      dir1 = self.random.randn(2)
      dir1 /= np.linalg.norm(dir1)
      # Find another actuation direction that is not 'too parallel' to dir1.
      parallel = True
      while parallel:
        dir2 = self.random.randn(2)
        dir2 /= np.linalg.norm(dir2)
        parallel = abs(np.dot(dir1, dir2)) > 0.9
      physics.model.wrap_prm[[0, 1]] = dir1
      physics.model.wrap_prm[[2, 3]] = dir2
    super(PointMass, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns an observation of the state."""
    obs = collections.OrderedDict()
    obs['position'] = physics.position()
    obs['velocity'] = physics.velocity()
    return obs

  def get_reward(self, physics):
    """Returns a reward to the agent."""
    target_size = physics.named.model.geom_size['target', 0]
    near_target = rewards.tolerance(physics.mass_to_target_dist(),
                                    bounds=(0, target_size), margin=target_size)
    control_reward = rewards.tolerance(physics.control(), margin=1,
                                       value_at_margin=0,
                                       sigmoid='quadratic').mean()
    small_control = (control_reward + 4) / 5
    return near_target * small_control
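A hedged usage sketch for the point_mass tasks (editor's illustration with random actions, not part of the original file; the import path follows this commit's layout):

    import numpy as np
    from local_dm_control_suite import point_mass

    env = point_mass.easy(random=0)
    spec = env.action_spec()
    time_step = env.reset()
    while not time_step.last():
      action = np.random.uniform(spec.minimum, spec.maximum, size=spec.shape)
      time_step = env.step(action)   # Reward blends target proximity and small control.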
49
local_dm_control_suite/point_mass.xml
Executable file
@ -0,0 +1,49 @@
<mujoco model="planar point mass">
  <include file="./common/skybox.xml"/>
  <include file="./common/visual.xml"/>
  <include file="./common/materials.xml"/>

  <option timestep="0.02">
    <flag contact="disable"/>
  </option>

  <default>
    <joint type="hinge" axis="0 0 1" limited="true" range="-.29 .29" damping="1"/>
    <motor gear=".1" ctrlrange="-1 1" ctrllimited="true"/>
  </default>

  <worldbody>
    <light name="light" pos="0 0 1"/>
    <camera name="fixed" pos="0 0 .75" quat="1 0 0 0"/>
    <geom name="ground" type="plane" pos="0 0 0" size=".3 .3 .1" material="grid"/>
    <geom name="wall_x" type="plane" pos="-.3 0 .02" zaxis="1 0 0" size=".02 .3 .02" material="decoration"/>
    <geom name="wall_y" type="plane" pos="0 -.3 .02" zaxis="0 1 0" size=".3 .02 .02" material="decoration"/>
    <geom name="wall_neg_x" type="plane" pos=".3 0 .02" zaxis="-1 0 0" size=".02 .3 .02" material="decoration"/>
    <geom name="wall_neg_y" type="plane" pos="0 .3 .02" zaxis="0 -1 0" size=".3 .02 .02" material="decoration"/>

    <body name="pointmass" pos="0 0 .01">
      <camera name="cam0" pos="0 -0.3 0.3" xyaxes="1 0 0 0 0.7 0.7"/>
      <joint name="root_x" type="slide" pos="0 0 0" axis="1 0 0" />
      <joint name="root_y" type="slide" pos="0 0 0" axis="0 1 0" />
      <geom name="pointmass" type="sphere" size=".01" material="self" mass=".3"/>
    </body>

    <geom name="target" pos="0 0 .01" material="target" type="sphere" size=".015"/>
  </worldbody>

  <tendon>
    <fixed name="t1">
      <joint joint="root_x" coef="1"/>
      <joint joint="root_y" coef="0"/>
    </fixed>
    <fixed name="t2">
      <joint joint="root_x" coef="0"/>
      <joint joint="root_y" coef="1"/>
    </fixed>
  </tendon>

  <actuator>
    <motor name="t1" tendon="t1"/>
    <motor name="t2" tendon="t2"/>
  </actuator>
</mujoco>
480
local_dm_control_suite/quadruped.py
Executable file
@ -0,0 +1,480 @@
|
|||||||
|
# Copyright 2019 The dm_control Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Quadruped Domain."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
from dm_control import mujoco
|
||||||
|
from dm_control.mujoco.wrapper import mjbindings
|
||||||
|
from dm_control.rl import control
|
||||||
|
from local_dm_control_suite import base
|
||||||
|
from local_dm_control_suite import common
|
||||||
|
from dm_control.utils import containers
|
||||||
|
from dm_control.utils import rewards
|
||||||
|
from dm_control.utils import xml_tools
|
||||||
|
|
||||||
|
from lxml import etree
|
||||||
|
import numpy as np
|
||||||
|
from scipy import ndimage
|
||||||
|
|
||||||
|
enums = mjbindings.enums
|
||||||
|
mjlib = mjbindings.mjlib
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_TIME_LIMIT = 20
|
||||||
|
_CONTROL_TIMESTEP = .02
|
||||||
|
|
||||||
|
# Horizontal speeds above which the move reward is 1.
|
||||||
|
_RUN_SPEED = 5
|
||||||
|
_WALK_SPEED = 0.5
|
||||||
|
|
||||||
|
# Constants related to terrain generation.
|
||||||
|
_HEIGHTFIELD_ID = 0
|
||||||
|
_TERRAIN_SMOOTHNESS = 0.15 # 0.0: maximally bumpy; 1.0: completely smooth.
|
||||||
|
_TERRAIN_BUMP_SCALE = 2 # Spatial scale of terrain bumps (in meters).
|
||||||
|
|
||||||
|
# Named model elements.
|
||||||
|
_TOES = ['toe_front_left', 'toe_back_left', 'toe_back_right', 'toe_front_right']
|
||||||
|
_WALLS = ['wall_px', 'wall_py', 'wall_nx', 'wall_ny']
|
||||||
|
|
||||||
|
SUITE = containers.TaggedTasks()
|
||||||
|
|
||||||
|
|
||||||
|
def make_model(floor_size=None, terrain=False, rangefinders=False,
|
||||||
|
walls_and_ball=False):
|
||||||
|
"""Returns the model XML string."""
|
||||||
|
xml_string = common.read_model('quadruped.xml')
|
||||||
|
parser = etree.XMLParser(remove_blank_text=True)
|
||||||
|
mjcf = etree.XML(xml_string, parser)
|
||||||
|
|
||||||
|
# Set floor size.
|
||||||
|
if floor_size is not None:
|
||||||
|
floor_geom = mjcf.find('.//geom[@name={!r}]'.format('floor'))
|
||||||
|
floor_geom.attrib['size'] = '{} {} .5'.format(floor_size, floor_size)
|
||||||
|
|
||||||
|
# Remove walls, ball and target.
|
||||||
|
if not walls_and_ball:
|
||||||
|
for wall in _WALLS:
|
||||||
|
wall_geom = xml_tools.find_element(mjcf, 'geom', wall)
|
||||||
|
wall_geom.getparent().remove(wall_geom)
|
||||||
|
|
||||||
|
# Remove ball.
|
||||||
|
ball_body = xml_tools.find_element(mjcf, 'body', 'ball')
|
||||||
|
ball_body.getparent().remove(ball_body)
|
||||||
|
|
||||||
|
# Remove target.
|
||||||
|
target_site = xml_tools.find_element(mjcf, 'site', 'target')
|
||||||
|
target_site.getparent().remove(target_site)
|
||||||
|
|
||||||
|
# Remove terrain.
|
||||||
|
if not terrain:
|
||||||
|
terrain_geom = xml_tools.find_element(mjcf, 'geom', 'terrain')
|
||||||
|
terrain_geom.getparent().remove(terrain_geom)
|
||||||
|
|
||||||
|
# Remove rangefinders if they're not used, as range computations can be
|
||||||
|
# expensive, especially in a scene with heightfields.
|
||||||
|
if not rangefinders:
|
||||||
|
rangefinder_sensors = mjcf.findall('.//rangefinder')
|
||||||
|
for rf in rangefinder_sensors:
|
||||||
|
rf.getparent().remove(rf)
|
||||||
|
|
||||||
|
return etree.tostring(mjcf, pretty_print=True)
|
||||||
|
|
||||||
|
|
||||||
|
@SUITE.add()
|
||||||
|
def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
|
||||||
|
"""Returns the Walk task."""
|
||||||
|
xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
|
||||||
|
physics = Physics.from_xml_string(xml_string, common.ASSETS)
|
||||||
|
task = Move(desired_speed=_WALK_SPEED, random=random)
|
||||||
|
environment_kwargs = environment_kwargs or {}
|
||||||
|
return control.Environment(physics, task, time_limit=time_limit,
|
||||||
|
control_timestep=_CONTROL_TIMESTEP,
|
||||||
|
**environment_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@SUITE.add()
|
||||||
|
def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
|
||||||
|
"""Returns the Run task."""
|
||||||
|
xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _RUN_SPEED)
|
||||||
|
physics = Physics.from_xml_string(xml_string, common.ASSETS)
|
||||||
|
task = Move(desired_speed=_RUN_SPEED, random=random)
|
||||||
|
environment_kwargs = environment_kwargs or {}
|
||||||
|
return control.Environment(physics, task, time_limit=time_limit,
|
||||||
|
control_timestep=_CONTROL_TIMESTEP,
|
||||||
|
**environment_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@SUITE.add()
|
||||||
|
def escape(time_limit=_DEFAULT_TIME_LIMIT, random=None,
|
||||||
|
environment_kwargs=None):
|
||||||
|
"""Returns the Escape task."""
|
||||||
|
xml_string = make_model(floor_size=40, terrain=True, rangefinders=True)
|
||||||
|
physics = Physics.from_xml_string(xml_string, common.ASSETS)
|
||||||
|
task = Escape(random=random)
|
||||||
|
environment_kwargs = environment_kwargs or {}
|
||||||
|
return control.Environment(physics, task, time_limit=time_limit,
|
||||||
|
control_timestep=_CONTROL_TIMESTEP,
|
||||||
|
**environment_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@SUITE.add()
|
||||||
|
def fetch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
|
||||||
|
"""Returns the Fetch task."""
|
||||||
|
xml_string = make_model(walls_and_ball=True)
|
||||||
|
physics = Physics.from_xml_string(xml_string, common.ASSETS)
|
||||||
|
task = Fetch(random=random)
|
||||||
|
environment_kwargs = environment_kwargs or {}
|
||||||
|
return control.Environment(physics, task, time_limit=time_limit,
|
||||||
|
control_timestep=_CONTROL_TIMESTEP,
|
||||||
|
**environment_kwargs)
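# A minimal usage sketch (illustrative only, not part of the suite itself):
# it drives the Walk task defined above with uniform random actions through
# the standard dm_env interface exposed by `control.Environment`. The helper
# name `_example_random_rollout` is hypothetical.
def _example_random_rollout(num_steps=100):
  """Steps the Walk task with random actions and returns the summed reward."""
  env = walk()
  action_spec = env.action_spec()
  env.reset()
  total_reward = 0.0
  for _ in range(num_steps):
    action = np.random.uniform(action_spec.minimum, action_spec.maximum,
                               size=action_spec.shape)
    time_step = env.step(action)
    total_reward += time_step.reward
  return total_reward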
|
||||||
|
|
||||||
|
|
||||||
|
class Physics(mujoco.Physics):
|
||||||
|
"""Physics simulation with additional features for the Quadruped domain."""
|
||||||
|
|
||||||
|
def _reload_from_data(self, data):
|
||||||
|
super(Physics, self)._reload_from_data(data)
|
||||||
|
# Clear cached sensor names when the physics is reloaded.
|
||||||
|
self._sensor_types_to_names = {}
|
||||||
|
self._hinge_names = []
|
||||||
|
|
||||||
|
def _get_sensor_names(self, *sensor_types):
|
||||||
|
try:
|
||||||
|
sensor_names = self._sensor_types_to_names[sensor_types]
|
||||||
|
except KeyError:
|
||||||
|
[sensor_ids] = np.where(np.in1d(self.model.sensor_type, sensor_types))
|
||||||
|
sensor_names = [self.model.id2name(s_id, 'sensor') for s_id in sensor_ids]
|
||||||
|
self._sensor_types_to_names[sensor_types] = sensor_names
|
||||||
|
return sensor_names
|
||||||
|
|
||||||
|
def torso_upright(self):
|
||||||
|
"""Returns the dot-product of the torso z-axis and the global z-axis."""
|
||||||
|
return np.asarray(self.named.data.xmat['torso', 'zz'])
|
||||||
|
|
||||||
|
def torso_velocity(self):
|
||||||
|
"""Returns the velocity of the torso, in the local frame."""
|
||||||
|
return self.named.data.sensordata['velocimeter'].copy()
|
||||||
|
|
||||||
|
def egocentric_state(self):
|
||||||
|
"""Returns the state without global orientation or position."""
|
||||||
|
if not self._hinge_names:
|
||||||
|
[hinge_ids] = np.nonzero(self.model.jnt_type ==
|
||||||
|
enums.mjtJoint.mjJNT_HINGE)
|
||||||
|
self._hinge_names = [self.model.id2name(j_id, 'joint')
|
||||||
|
for j_id in hinge_ids]
|
||||||
|
return np.hstack((self.named.data.qpos[self._hinge_names],
|
||||||
|
self.named.data.qvel[self._hinge_names],
|
||||||
|
self.data.act))
|
||||||
|
|
||||||
|
def toe_positions(self):
|
||||||
|
"""Returns toe positions in egocentric frame."""
|
||||||
|
torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
|
||||||
|
torso_pos = self.named.data.xpos['torso']
|
||||||
|
torso_to_toe = self.named.data.xpos[_TOES] - torso_pos
|
||||||
|
return torso_to_toe.dot(torso_frame)
|
||||||
|
|
||||||
|
def force_torque(self):
|
||||||
|
"""Returns scaled force/torque sensor readings at the toes."""
|
||||||
|
force_torque_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_FORCE,
|
||||||
|
enums.mjtSensor.mjSENS_TORQUE)
|
||||||
|
return np.arcsinh(self.named.data.sensordata[force_torque_sensors])
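    # Note: arcsinh is used (rather than a log) because contact forces and
    # torques span a large dynamic range but can be zero or negative; arcsinh
    # compresses large magnitudes roughly logarithmically while staying
    # approximately linear, and signed, near zero.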
|
||||||
|
|
||||||
|
def imu(self):
|
||||||
|
"""Returns IMU-like sensor readings."""
|
||||||
|
imu_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_GYRO,
|
||||||
|
enums.mjtSensor.mjSENS_ACCELEROMETER)
|
||||||
|
return self.named.data.sensordata[imu_sensors]
|
||||||
|
|
||||||
|
def rangefinder(self):
|
||||||
|
"""Returns scaled rangefinder sensor readings."""
|
||||||
|
rf_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_RANGEFINDER)
|
||||||
|
rf_readings = self.named.data.sensordata[rf_sensors]
|
||||||
|
no_intersection = -1.0
|
||||||
|
return np.where(rf_readings == no_intersection, 1.0, np.tanh(rf_readings))
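    # Note: MuJoCo rangefinders report -1 when the ray hits nothing, so those
    # readings are mapped to the saturation value 1.0; finite distances are
    # squashed into [0, 1) by tanh.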
|
||||||
|
|
||||||
|
def origin_distance(self):
|
||||||
|
"""Returns the distance from the origin to the workspace."""
|
||||||
|
return np.asarray(np.linalg.norm(self.named.data.site_xpos['workspace']))
|
||||||
|
|
||||||
|
def origin(self):
|
||||||
|
"""Returns origin position in the torso frame."""
|
||||||
|
torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
|
||||||
|
torso_pos = self.named.data.xpos['torso']
|
||||||
|
return -torso_pos.dot(torso_frame)
|
||||||
|
|
||||||
|
def ball_state(self):
|
||||||
|
"""Returns ball position and velocity relative to the torso frame."""
|
||||||
|
data = self.named.data
|
||||||
|
torso_frame = data.xmat['torso'].reshape(3, 3)
|
||||||
|
ball_rel_pos = data.xpos['ball'] - data.xpos['torso']
|
||||||
|
ball_rel_vel = data.qvel['ball_root'][:3] - data.qvel['root'][:3]
|
||||||
|
ball_rot_vel = data.qvel['ball_root'][3:]
|
||||||
|
ball_state = np.vstack((ball_rel_pos, ball_rel_vel, ball_rot_vel))
|
||||||
|
return ball_state.dot(torso_frame).ravel()
|
||||||
|
|
||||||
|
def target_position(self):
|
||||||
|
"""Returns target position in torso frame."""
|
||||||
|
torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
|
||||||
|
torso_pos = self.named.data.xpos['torso']
|
||||||
|
torso_to_target = self.named.data.site_xpos['target'] - torso_pos
|
||||||
|
return torso_to_target.dot(torso_frame)
|
||||||
|
|
||||||
|
def ball_to_target_distance(self):
|
||||||
|
"""Returns horizontal distance from the ball to the target."""
|
||||||
|
ball_to_target = (self.named.data.site_xpos['target'] -
|
||||||
|
self.named.data.xpos['ball'])
|
||||||
|
return np.linalg.norm(ball_to_target[:2])
|
||||||
|
|
||||||
|
def self_to_ball_distance(self):
|
||||||
|
"""Returns horizontal distance from the quadruped workspace to the ball."""
|
||||||
|
self_to_ball = (self.named.data.site_xpos['workspace']
|
||||||
|
-self.named.data.xpos['ball'])
|
||||||
|
return np.linalg.norm(self_to_ball[:2])
|
||||||
|
|
||||||
|
|
||||||
|
def _find_non_contacting_height(physics, orientation, x_pos=0.0, y_pos=0.0):
|
||||||
|
"""Find a height with no contacts given a body orientation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
physics: An instance of `Physics`.
|
||||||
|
orientation: A quaternion.
|
||||||
|
x_pos: A float. Position along global x-axis.
|
||||||
|
y_pos: A float. Position along global y-axis.
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If a non-contacting configuration has not been found after
|
||||||
|
10,000 attempts.
|
||||||
|
"""
|
||||||
|
z_pos = 0.0 # Start embedded in the floor.
|
||||||
|
num_contacts = 1
|
||||||
|
num_attempts = 0
|
||||||
|
# Move up in 1cm increments until no contacts.
|
||||||
|
while num_contacts > 0:
|
||||||
|
try:
|
||||||
|
with physics.reset_context():
|
||||||
|
physics.named.data.qpos['root'][:3] = x_pos, y_pos, z_pos
|
||||||
|
physics.named.data.qpos['root'][3:] = orientation
|
||||||
|
except control.PhysicsError:
|
||||||
|
# We may encounter a PhysicsError here due to filling the contact
|
||||||
|
# buffer, in which case we simply increment the height and continue.
|
||||||
|
pass
|
||||||
|
num_contacts = physics.data.ncon
|
||||||
|
z_pos += 0.01
|
||||||
|
num_attempts += 1
|
||||||
|
if num_attempts > 10000:
|
||||||
|
raise RuntimeError('Failed to find a non-contacting configuration.')
|
||||||
|
|
||||||
|
|
||||||
|
def _common_observations(physics):
|
||||||
|
"""Returns the observations common to all tasks."""
|
||||||
|
obs = collections.OrderedDict()
|
||||||
|
obs['egocentric_state'] = physics.egocentric_state()
|
||||||
|
obs['torso_velocity'] = physics.torso_velocity()
|
||||||
|
obs['torso_upright'] = physics.torso_upright()
|
||||||
|
obs['imu'] = physics.imu()
|
||||||
|
obs['force_torque'] = physics.force_torque()
|
||||||
|
return obs
|
||||||
|
|
||||||
|
|
||||||
|
def _upright_reward(physics, deviation_angle=0):
|
||||||
|
"""Returns a reward proportional to how upright the torso is.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
physics: an instance of `Physics`.
|
||||||
|
deviation_angle: A float, in degrees. The reward is 0 when the torso is
|
||||||
|
exactly upside-down and 1 when the torso's z-axis is less than
|
||||||
|
`deviation_angle` away from the global z-axis.
|
||||||
|
"""
|
||||||
|
deviation = np.cos(np.deg2rad(deviation_angle))
|
||||||
|
return rewards.tolerance(
|
||||||
|
physics.torso_upright(),
|
||||||
|
bounds=(deviation, float('inf')),
|
||||||
|
sigmoid='linear',
|
||||||
|
margin=1 + deviation,
|
||||||
|
value_at_margin=0)
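# Worked example: with the default deviation_angle=0 the threshold is
# cos(0) = 1, so the reward is simply (torso_upright + 1) / 2. With
# deviation_angle=20 (as used by the Escape task) the threshold is
# cos(20 deg) ~= 0.94: any torso_upright above that earns the full reward of
# 1, decreasing linearly to 0 when the torso points straight down.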
|
||||||
|
|
||||||
|
|
||||||
|
class Move(base.Task):
|
||||||
|
"""A quadruped task solved by moving forward at a designated speed."""
|
||||||
|
|
||||||
|
def __init__(self, desired_speed, random=None):
|
||||||
|
"""Initializes an instance of `Move`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
desired_speed: A float. If this value is zero, reward is given simply
|
||||||
|
for standing upright. Otherwise this specifies the horizontal velocity
|
||||||
|
at which the velocity-dependent reward component is maximized.
|
||||||
|
random: Optional, either a `numpy.random.RandomState` instance, an
|
||||||
|
integer seed for creating a new `RandomState`, or None to select a seed
|
||||||
|
automatically (default).
|
||||||
|
"""
|
||||||
|
self._desired_speed = desired_speed
|
||||||
|
super(Move, self).__init__(random=random)
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
physics: An instance of `Physics`.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Initial configuration.
|
||||||
|
orientation = self.random.randn(4)
|
||||||
|
orientation /= np.linalg.norm(orientation)
|
||||||
|
_find_non_contacting_height(physics, orientation)
|
||||||
|
super(Move, self).initialize_episode(physics)
|
||||||
|
|
||||||
|
def get_observation(self, physics):
|
||||||
|
"""Returns an observation to the agent."""
|
||||||
|
return _common_observations(physics)
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
"""Returns a reward to the agent."""
|
||||||
|
|
||||||
|
# Move reward term.
|
||||||
|
move_reward = rewards.tolerance(
|
||||||
|
physics.torso_velocity()[0],
|
||||||
|
bounds=(self._desired_speed, float('inf')),
|
||||||
|
margin=self._desired_speed,
|
||||||
|
value_at_margin=0.5,
|
||||||
|
sigmoid='linear')
|
||||||
|
|
||||||
|
return _upright_reward(physics) * move_reward
|
||||||
|
|
||||||
|
|
||||||
|
class Escape(base.Task):
|
||||||
|
"""A quadruped task solved by escaping a bowl-shaped terrain."""
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
physics: An instance of `Physics`.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Get heightfield resolution, assert that it is square.
|
||||||
|
res = physics.model.hfield_nrow[_HEIGHTFIELD_ID]
|
||||||
|
assert res == physics.model.hfield_ncol[_HEIGHTFIELD_ID]
|
||||||
|
# Sinusoidal bowl shape.
|
||||||
|
row_grid, col_grid = np.ogrid[-1:1:res*1j, -1:1:res*1j]
|
||||||
|
radius = np.clip(np.sqrt(col_grid**2 + row_grid**2), .04, 1)
|
||||||
|
bowl_shape = .5 - np.cos(2*np.pi*radius)/2
|
||||||
|
# Random smooth bumps.
|
||||||
|
terrain_size = 2 * physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
|
||||||
|
bump_res = int(terrain_size / _TERRAIN_BUMP_SCALE)
|
||||||
|
bumps = self.random.uniform(_TERRAIN_SMOOTHNESS, 1, (bump_res, bump_res))
|
||||||
|
smooth_bumps = ndimage.zoom(bumps, res / float(bump_res))
|
||||||
|
# Terrain is elementwise product.
|
||||||
|
terrain = bowl_shape * smooth_bumps
|
||||||
|
start_idx = physics.model.hfield_adr[_HEIGHTFIELD_ID]
|
||||||
|
physics.model.hfield_data[start_idx:start_idx+res**2] = terrain.ravel()
|
||||||
|
super(Escape, self).initialize_episode(physics)
|
||||||
|
|
||||||
|
# If we have a rendering context, we need to re-upload the modified
|
||||||
|
# heightfield data.
|
||||||
|
if physics.contexts:
|
||||||
|
with physics.contexts.gl.make_current() as ctx:
|
||||||
|
ctx.call(mjlib.mjr_uploadHField,
|
||||||
|
physics.model.ptr,
|
||||||
|
physics.contexts.mujoco.ptr,
|
||||||
|
_HEIGHTFIELD_ID)
|
||||||
|
|
||||||
|
# Initial configuration.
|
||||||
|
orientation = self.random.randn(4)
|
||||||
|
orientation /= np.linalg.norm(orientation)
|
||||||
|
_find_non_contacting_height(physics, orientation)
|
||||||
|
|
||||||
|
def get_observation(self, physics):
|
||||||
|
"""Returns an observation to the agent."""
|
||||||
|
obs = _common_observations(physics)
|
||||||
|
obs['origin'] = physics.origin()
|
||||||
|
obs['rangefinder'] = physics.rangefinder()
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
"""Returns a reward to the agent."""
|
||||||
|
|
||||||
|
# Escape reward term.
|
||||||
|
terrain_size = physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
|
||||||
|
escape_reward = rewards.tolerance(
|
||||||
|
physics.origin_distance(),
|
||||||
|
bounds=(terrain_size, float('inf')),
|
||||||
|
margin=terrain_size,
|
||||||
|
value_at_margin=0,
|
||||||
|
sigmoid='linear')
|
||||||
|
|
||||||
|
return _upright_reward(physics, deviation_angle=20) * escape_reward
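    # Note: `terrain_size` is the heightfield half-extent, so the escape term
    # rises linearly from 0 at the origin to 1 once the torso has moved past
    # the rim of the bowl, and the whole term is scaled by the 20-degree
    # upright reward.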
|
||||||
|
|
||||||
|
|
||||||
|
class Fetch(base.Task):
|
||||||
|
"""A quadruped task solved by bringing a ball to the origin."""
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
physics: An instance of `Physics`.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Initial configuration, random azimuth and horizontal position.
|
||||||
|
azimuth = self.random.uniform(0, 2*np.pi)
|
||||||
|
orientation = np.array((np.cos(azimuth/2), 0, 0, np.sin(azimuth/2)))
|
||||||
|
spawn_radius = 0.9 * physics.named.model.geom_size['floor', 0]
|
||||||
|
x_pos, y_pos = self.random.uniform(-spawn_radius, spawn_radius, size=(2,))
|
||||||
|
_find_non_contacting_height(physics, orientation, x_pos, y_pos)
|
||||||
|
|
||||||
|
# Initial ball state.
|
||||||
|
physics.named.data.qpos['ball_root'][:2] = self.random.uniform(
|
||||||
|
-spawn_radius, spawn_radius, size=(2,))
|
||||||
|
physics.named.data.qpos['ball_root'][2] = 2
|
||||||
|
physics.named.data.qvel['ball_root'][:2] = 5*self.random.randn(2)
|
||||||
|
super(Fetch, self).initialize_episode(physics)
|
||||||
|
|
||||||
|
def get_observation(self, physics):
|
||||||
|
"""Returns an observation to the agent."""
|
||||||
|
obs = _common_observations(physics)
|
||||||
|
obs['ball_state'] = physics.ball_state()
|
||||||
|
obs['target_position'] = physics.target_position()
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
"""Returns a reward to the agent."""
|
||||||
|
|
||||||
|
# Reward for moving close to the ball.
|
||||||
|
arena_radius = physics.named.model.geom_size['floor', 0] * np.sqrt(2)
|
||||||
|
workspace_radius = physics.named.model.site_size['workspace', 0]
|
||||||
|
ball_radius = physics.named.model.geom_size['ball', 0]
|
||||||
|
reach_reward = rewards.tolerance(
|
||||||
|
physics.self_to_ball_distance(),
|
||||||
|
bounds=(0, workspace_radius+ball_radius),
|
||||||
|
sigmoid='linear',
|
||||||
|
margin=arena_radius, value_at_margin=0)
|
||||||
|
|
||||||
|
# Reward for bringing the ball to the target.
|
||||||
|
target_radius = physics.named.model.site_size['target', 0]
|
||||||
|
fetch_reward = rewards.tolerance(
|
||||||
|
physics.ball_to_target_distance(),
|
||||||
|
bounds=(0, target_radius),
|
||||||
|
sigmoid='linear',
|
||||||
|
margin=arena_radius, value_at_margin=0)
|
||||||
|
|
||||||
|
reach_then_fetch = reach_reward * (0.5 + 0.5*fetch_reward)
|
||||||
|
|
||||||
|
return _upright_reward(physics) * reach_then_fetch
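    # Note: reach_reward * (0.5 + 0.5 * fetch_reward) caps the return at 0.5
    # for merely reaching the ball; the remaining half requires bringing the
    # ball onto the target site, and the product with _upright_reward
    # penalizes solutions that topple the body.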
|
329
local_dm_control_suite/quadruped.xml
Executable file
@ -0,0 +1,329 @@
|
|||||||
|
<mujoco model="quadruped">
|
||||||
|
|
||||||
|
<include file="./common/skybox.xml"/>
|
||||||
|
<include file="./common/visual.xml"/>
|
||||||
|
<include file="./common/materials.xml"/>
|
||||||
|
|
||||||
|
<visual>
|
||||||
|
<rgba rangefinder="1 1 0.1 0.1"/>
|
||||||
|
<map znear=".005" zfar="20"/>
|
||||||
|
</visual>
|
||||||
|
|
||||||
|
<asset>
|
||||||
|
<hfield name="terrain" ncol="201" nrow="201" size="30 30 5 .1"/>
|
||||||
|
</asset>
|
||||||
|
|
||||||
|
<option timestep=".005"/>
|
||||||
|
|
||||||
|
<default>
|
||||||
|
<geom solimp=".9 .99 .003" solref=".01 1"/>
|
||||||
|
<default class="body">
|
||||||
|
<geom type="capsule" size=".08" condim="1" material="self" density="500"/>
|
||||||
|
<joint type="hinge" damping="30" armature=".01"
|
||||||
|
limited="true" solimplimit="0 .99 .01"/>
|
||||||
|
<default class="hip">
|
||||||
|
<default class="yaw">
|
||||||
|
<joint axis="0 0 1" range="-50 50"/>
|
||||||
|
</default>
|
||||||
|
<default class="pitch">
|
||||||
|
<joint axis="0 1 0" range="-20 60"/>
|
||||||
|
</default>
|
||||||
|
<geom fromto="0 0 0 .3 0 .11"/>
|
||||||
|
</default>
|
||||||
|
<default class="knee">
|
||||||
|
<joint axis="0 1 0" range="-60 50"/>
|
||||||
|
<geom size=".065" fromto="0 0 0 .25 0 -.25"/>
|
||||||
|
</default>
|
||||||
|
<default class="ankle">
|
||||||
|
<joint axis="0 1 0" range="-45 55"/>
|
||||||
|
<geom size=".055" fromto="0 0 0 0 0 -.25"/>
|
||||||
|
</default>
|
||||||
|
<default class="toe">
|
||||||
|
<geom type="sphere" size=".08" material="effector" friction="1.5"/>
|
||||||
|
<site type="sphere" size=".084" material="site" group="4"/>
|
||||||
|
</default>
|
||||||
|
</default>
|
||||||
|
<default class="rangefinder">
|
||||||
|
<site type="capsule" size=".005 .1" material="site" group="4"/>
|
||||||
|
</default>
|
||||||
|
<default class="wall">
|
||||||
|
<geom type="plane" material="decoration"/>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
<default class="coupling">
|
||||||
|
<equality solimp="0.95 0.99 0.01" solref=".005 .5"/>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
<general ctrllimited="true" gainprm="1000" biasprm="0 -1000" biastype="affine" dyntype="filter" dynprm=".1"/>
|
||||||
|
<default class="yaw_act">
|
||||||
|
<general ctrlrange="-1 1"/>
|
||||||
|
</default>
|
||||||
|
<default class="lift_act">
|
||||||
|
<general ctrlrange="-1 1.1"/>
|
||||||
|
</default>
|
||||||
|
<default class="extend_act">
|
||||||
|
<general ctrlrange="-.8 .8"/>
|
||||||
|
</default>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
<asset>
|
||||||
|
<texture name="ball" builtin="checker" mark="cross" width="151" height="151"
|
||||||
|
rgb1="0.1 0.1 0.1" rgb2="0.9 0.9 0.9" markrgb="1 1 1"/>
|
||||||
|
<material name="ball" texture="ball" />
|
||||||
|
</asset>
|
||||||
|
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<geom name="floor" type="plane" size="15 15 .5" material="grid"/>
|
||||||
|
<geom name="wall_px" class="wall" pos="-15.7 0 .7" zaxis="1 0 1" size="1 15 .5"/>
|
||||||
|
<geom name="wall_py" class="wall" pos="0 -15.7 .7" zaxis="0 1 1" size="15 1 .5"/>
|
||||||
|
<geom name="wall_nx" class="wall" pos="15.7 0 .7" zaxis="-1 0 1" size="1 15 .5"/>
|
||||||
|
<geom name="wall_ny" class="wall" pos="0 15.7 .7" zaxis="0 -1 1" size="15 1 .5"/>
|
||||||
|
<site name="target" type="cylinder" size=".4 .06" pos="0 0 .05" material="target"/>
|
||||||
|
|
||||||
|
<geom name="terrain" type="hfield" hfield="terrain" rgba=".2 .3 .4 1" pos="0 0 -.01"/>
|
||||||
|
|
||||||
|
<camera name="global" pos="-10 10 10" xyaxes="-1 -1 0 1 0 1" mode="trackcom"/>
|
||||||
|
<body name="torso" childclass="body" pos="0 0 .57">
|
||||||
|
<freejoint name="root"/>
|
||||||
|
|
||||||
|
<camera name="x" pos="-1.7 0 1" xyaxes="0 -1 0 .75 0 1" mode="trackcom"/>
|
||||||
|
<camera name="y" pos="0 4 2" xyaxes="-1 0 0 0 -.5 1" mode="trackcom"/>
|
||||||
|
<camera name="egocentric" pos=".3 0 .11" xyaxes="0 -1 0 .4 0 1" fovy="60"/>
|
||||||
|
<light name="light" pos="0 0 4" mode="trackcom"/>
|
||||||
|
|
||||||
|
<geom name="eye_r" type="cylinder" size=".05" fromto=".1 -.07 .12 .31 -.07 .08" mass="0"/>
|
||||||
|
<site name="pupil_r" type="sphere" size=".033" pos=".3 -.07 .08" zaxis="1 0 0" material="eye"/>
|
||||||
|
<geom name="eye_l" type="cylinder" size=".05" fromto=".1 .07 .12 .31 .07 .08" mass="0"/>
|
||||||
|
<site name="pupil_l" type="sphere" size=".033" pos=".3 .07 .08" zaxis="1 0 0" material="eye"/>
|
||||||
|
<site name="workspace" type="sphere" size=".3 .3 .3" material="site" pos=".8 0 -.2" group="3"/>
|
||||||
|
|
||||||
|
<site name="rf_00" class="rangefinder" fromto=".41 -.02 .11 .34 0 .115"/>
|
||||||
|
<site name="rf_01" class="rangefinder" fromto=".41 -.01 .11 .34 0 .115"/>
|
||||||
|
<site name="rf_02" class="rangefinder" fromto=".41 0 .11 .34 0 .115"/>
|
||||||
|
<site name="rf_03" class="rangefinder" fromto=".41 .01 .11 .34 0 .115"/>
|
||||||
|
<site name="rf_04" class="rangefinder" fromto=".41 .02 .11 .34 0 .115"/>
|
||||||
|
<site name="rf_10" class="rangefinder" fromto=".41 -.02 .1 .36 0 .11"/>
|
||||||
|
<site name="rf_11" class="rangefinder" fromto=".41 -.02 .1 .36 0 .11"/>
|
||||||
|
<site name="rf_12" class="rangefinder" fromto=".41 0 .1 .36 0 .11"/>
|
||||||
|
<site name="rf_13" class="rangefinder" fromto=".41 .01 .1 .36 0 .11"/>
|
||||||
|
<site name="rf_14" class="rangefinder" fromto=".41 .02 .1 .36 0 .11"/>
|
||||||
|
<site name="rf_20" class="rangefinder" fromto=".41 -.02 .09 .38 0 .105"/>
|
||||||
|
<site name="rf_21" class="rangefinder" fromto=".41 -.01 .09 .38 0 .105"/>
|
||||||
|
<site name="rf_22" class="rangefinder" fromto=".41 0 .09 .38 0 .105"/>
|
||||||
|
<site name="rf_23" class="rangefinder" fromto=".41 .01 .09 .38 0 .105"/>
|
||||||
|
<site name="rf_24" class="rangefinder" fromto=".41 .02 .09 .38 0 .105"/>
|
||||||
|
<site name="rf_30" class="rangefinder" fromto=".41 -.02 .08 .4 0 .1"/>
|
||||||
|
<site name="rf_31" class="rangefinder" fromto=".41 -.01 .08 .4 0 .1"/>
|
||||||
|
<site name="rf_32" class="rangefinder" fromto=".41 0 .08 .4 0 .1"/>
|
||||||
|
<site name="rf_33" class="rangefinder" fromto=".41 .01 .08 .4 0 .1"/>
|
||||||
|
<site name="rf_34" class="rangefinder" fromto=".41 .02 .08 .4 0 .1"/>
|
||||||
|
|
||||||
|
<geom name="torso" type="ellipsoid" size=".3 .27 .2" density="1000"/>
|
||||||
|
<site name="torso_touch" type="box" size=".26 .26 .26" rgba="0 0 1 0"/>
|
||||||
|
<site name="torso" size=".05" rgba="1 0 0 1" />
|
||||||
|
|
||||||
|
<body name="hip_front_left" pos=".2 .2 0" euler="0 0 45" childclass="hip">
|
||||||
|
<joint name="yaw_front_left" class="yaw"/>
|
||||||
|
<joint name="pitch_front_left" class="pitch"/>
|
||||||
|
<geom name="thigh_front_left"/>
|
||||||
|
<body name="knee_front_left" pos=".3 0 .11" childclass="knee">
|
||||||
|
<joint name="knee_front_left"/>
|
||||||
|
<geom name="shin_front_left"/>
|
||||||
|
<body name="ankle_front_left" pos=".25 0 -.25" childclass="ankle">
|
||||||
|
<joint name="ankle_front_left"/>
|
||||||
|
<geom name="foot_front_left"/>
|
||||||
|
<body name="toe_front_left" pos="0 0 -.3" childclass="toe">
|
||||||
|
<geom name="toe_front_left"/>
|
||||||
|
<site name="toe_front_left"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="hip_front_right" pos=".2 -.2 0" euler="0 0 -45" childclass="hip">
|
||||||
|
<joint name="yaw_front_right" class="yaw"/>
|
||||||
|
<joint name="pitch_front_right" class="pitch"/>
|
||||||
|
<geom name="thigh_front_right"/>
|
||||||
|
<body name="knee_front_right" pos=".3 0 .11" childclass="knee">
|
||||||
|
<joint name="knee_front_right"/>
|
||||||
|
<geom name="shin_front_right"/>
|
||||||
|
<body name="ankle_front_right" pos=".25 0 -.25" childclass="ankle">
|
||||||
|
<joint name="ankle_front_right"/>
|
||||||
|
<geom name="foot_front_right"/>
|
||||||
|
<body name="toe_front_right" pos="0 0 -.3" childclass="toe">
|
||||||
|
<geom name="toe_front_right"/>
|
||||||
|
<site name="toe_front_right"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="hip_back_right" pos="-.2 -.2 0" euler="0 0 -135" childclass="hip">
|
||||||
|
<joint name="yaw_back_right" class="yaw"/>
|
||||||
|
<joint name="pitch_back_right" class="pitch"/>
|
||||||
|
<geom name="thigh_back_right"/>
|
||||||
|
<body name="knee_back_right" pos=".3 0 .11" childclass="knee">
|
||||||
|
<joint name="knee_back_right"/>
|
||||||
|
<geom name="shin_back_right"/>
|
||||||
|
<body name="ankle_back_right" pos=".25 0 -.25" childclass="ankle">
|
||||||
|
<joint name="ankle_back_right"/>
|
||||||
|
<geom name="foot_back_right"/>
|
||||||
|
<body name="toe_back_right" pos="0 0 -.3" childclass="toe">
|
||||||
|
<geom name="toe_back_right"/>
|
||||||
|
<site name="toe_back_right"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="hip_back_left" pos="-.2 .2 0" euler="0 0 135" childclass="hip">
|
||||||
|
<joint name="yaw_back_left" class="yaw"/>
|
||||||
|
<joint name="pitch_back_left" class="pitch"/>
|
||||||
|
<geom name="thigh_back_left"/>
|
||||||
|
<body name="knee_back_left" pos=".3 0 .11" childclass="knee">
|
||||||
|
<joint name="knee_back_left"/>
|
||||||
|
<geom name="shin_back_left"/>
|
||||||
|
<body name="ankle_back_left" pos=".25 0 -.25" childclass="ankle">
|
||||||
|
<joint name="ankle_back_left"/>
|
||||||
|
<geom name="foot_back_left"/>
|
||||||
|
<body name="toe_back_left" pos="0 0 -.3" childclass="toe">
|
||||||
|
<geom name="toe_back_left"/>
|
||||||
|
<site name="toe_back_left"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="ball" pos="0 0 3">
|
||||||
|
<freejoint name="ball_root"/>
|
||||||
|
<geom name="ball" size=".15" material="ball" priority="1" condim="6" friction=".7 .005 .005"
|
||||||
|
solref="-10000 -30"/>
|
||||||
|
<light name="ball_light" pos="0 0 4" mode="trackcom"/>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<tendon>
|
||||||
|
<fixed name="coupling_front_left">
|
||||||
|
<joint joint="pitch_front_left" coef=".333"/>
|
||||||
|
<joint joint="knee_front_left" coef=".333"/>
|
||||||
|
<joint joint="ankle_front_left" coef=".333"/>
|
||||||
|
</fixed>
|
||||||
|
<fixed name="coupling_front_right">
|
||||||
|
<joint joint="pitch_front_right" coef=".333"/>
|
||||||
|
<joint joint="knee_front_right" coef=".333"/>
|
||||||
|
<joint joint="ankle_front_right" coef=".333"/>
|
||||||
|
</fixed>
|
||||||
|
<fixed name="coupling_back_right">
|
||||||
|
<joint joint="pitch_back_right" coef=".333"/>
|
||||||
|
<joint joint="knee_back_right" coef=".333"/>
|
||||||
|
<joint joint="ankle_back_right" coef=".333"/>
|
||||||
|
</fixed>
|
||||||
|
<fixed name="coupling_back_left">
|
||||||
|
<joint joint="pitch_back_left" coef=".333"/>
|
||||||
|
<joint joint="knee_back_left" coef=".333"/>
|
||||||
|
<joint joint="ankle_back_left" coef=".333"/>
|
||||||
|
</fixed>
|
||||||
|
|
||||||
|
<fixed name="extend_front_left">
|
||||||
|
<joint joint="pitch_front_left" coef=".25"/>
|
||||||
|
<joint joint="knee_front_left" coef="-.5"/>
|
||||||
|
<joint joint="ankle_front_left" coef=".25"/>
|
||||||
|
</fixed>
|
||||||
|
<fixed name="lift_front_left">
|
||||||
|
<joint joint="pitch_front_left" coef=".5"/>
|
||||||
|
<joint joint="ankle_front_left" coef="-.5"/>
|
||||||
|
</fixed>
|
||||||
|
|
||||||
|
<fixed name="extend_front_right">
|
||||||
|
<joint joint="pitch_front_right" coef=".25"/>
|
||||||
|
<joint joint="knee_front_right" coef="-.5"/>
|
||||||
|
<joint joint="ankle_front_right" coef=".25"/>
|
||||||
|
</fixed>
|
||||||
|
<fixed name="lift_front_right">
|
||||||
|
<joint joint="pitch_front_right" coef=".5"/>
|
||||||
|
<joint joint="ankle_front_right" coef="-.5"/>
|
||||||
|
</fixed>
|
||||||
|
|
||||||
|
<fixed name="extend_back_right">
|
||||||
|
<joint joint="pitch_back_right" coef=".25"/>
|
||||||
|
<joint joint="knee_back_right" coef="-.5"/>
|
||||||
|
<joint joint="ankle_back_right" coef=".25"/>
|
||||||
|
</fixed>
|
||||||
|
<fixed name="lift_back_right">
|
||||||
|
<joint joint="pitch_back_right" coef=".5"/>
|
||||||
|
<joint joint="ankle_back_right" coef="-.5"/>
|
||||||
|
</fixed>
|
||||||
|
|
||||||
|
<fixed name="extend_back_left">
|
||||||
|
<joint joint="pitch_back_left" coef=".25"/>
|
||||||
|
<joint joint="knee_back_left" coef="-.5"/>
|
||||||
|
<joint joint="ankle_back_left" coef=".25"/>
|
||||||
|
</fixed>
|
||||||
|
<fixed name="lift_back_left">
|
||||||
|
<joint joint="pitch_back_left" coef=".5"/>
|
||||||
|
<joint joint="ankle_back_left" coef="-.5"/>
|
||||||
|
</fixed>
|
||||||
|
</tendon>
|
||||||
|
|
||||||
|
<equality>
|
||||||
|
<tendon name="coupling_front_left" tendon1="coupling_front_left" class="coupling"/>
|
||||||
|
<tendon name="coupling_front_right" tendon1="coupling_front_right" class="coupling"/>
|
||||||
|
<tendon name="coupling_back_right" tendon1="coupling_back_right" class="coupling"/>
|
||||||
|
<tendon name="coupling_back_left" tendon1="coupling_back_left" class="coupling"/>
|
||||||
|
</equality>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<general name="yaw_front_left" class="yaw_act" joint="yaw_front_left"/>
|
||||||
|
<general name="lift_front_left" class="lift_act" tendon="lift_front_left"/>
|
||||||
|
<general name="extend_front_left" class="extend_act" tendon="extend_front_left"/>
|
||||||
|
<general name="yaw_front_right" class="yaw_act" joint="yaw_front_right"/>
|
||||||
|
<general name="lift_front_right" class="lift_act" tendon="lift_front_right"/>
|
||||||
|
<general name="extend_front_right" class="extend_act" tendon="extend_front_right"/>
|
||||||
|
<general name="yaw_back_right" class="yaw_act" joint="yaw_back_right"/>
|
||||||
|
<general name="lift_back_right" class="lift_act" tendon="lift_back_right"/>
|
||||||
|
<general name="extend_back_right" class="extend_act" tendon="extend_back_right"/>
|
||||||
|
<general name="yaw_back_left" class="yaw_act" joint="yaw_back_left"/>
|
||||||
|
<general name="lift_back_left" class="lift_act" tendon="lift_back_left"/>
|
||||||
|
<general name="extend_back_left" class="extend_act" tendon="extend_back_left"/>
|
||||||
|
</actuator>
|
||||||
|
|
||||||
|
<sensor>
|
||||||
|
<accelerometer name="imu_accel" site="torso"/>
|
||||||
|
<gyro name="imu_gyro" site="torso"/>
|
||||||
|
<velocimeter name="velocimeter" site="torso"/>
|
||||||
|
<force name="force_toe_front_left" site="toe_front_left"/>
|
||||||
|
<force name="force_toe_front_right" site="toe_front_right"/>
|
||||||
|
<force name="force_toe_back_right" site="toe_back_right"/>
|
||||||
|
<force name="force_toe_back_left" site="toe_back_left"/>
|
||||||
|
<torque name="torque_toe_front_left" site="toe_front_left"/>
|
||||||
|
<torque name="torque_toe_front_right" site="toe_front_right"/>
|
||||||
|
<torque name="torque_toe_back_right" site="toe_back_right"/>
|
||||||
|
<torque name="torque_toe_back_left" site="toe_back_left"/>
|
||||||
|
<subtreecom name="center_of_mass" body="torso"/>
|
||||||
|
<rangefinder name="rf_00" site="rf_00"/>
|
||||||
|
<rangefinder name="rf_01" site="rf_01"/>
|
||||||
|
<rangefinder name="rf_02" site="rf_02"/>
|
||||||
|
<rangefinder name="rf_03" site="rf_03"/>
|
||||||
|
<rangefinder name="rf_04" site="rf_04"/>
|
||||||
|
<rangefinder name="rf_10" site="rf_10"/>
|
||||||
|
<rangefinder name="rf_11" site="rf_11"/>
|
||||||
|
<rangefinder name="rf_12" site="rf_12"/>
|
||||||
|
<rangefinder name="rf_13" site="rf_13"/>
|
||||||
|
<rangefinder name="rf_14" site="rf_14"/>
|
||||||
|
<rangefinder name="rf_20" site="rf_20"/>
|
||||||
|
<rangefinder name="rf_21" site="rf_21"/>
|
||||||
|
<rangefinder name="rf_22" site="rf_22"/>
|
||||||
|
<rangefinder name="rf_23" site="rf_23"/>
|
||||||
|
<rangefinder name="rf_24" site="rf_24"/>
|
||||||
|
<rangefinder name="rf_30" site="rf_30"/>
|
||||||
|
<rangefinder name="rf_31" site="rf_31"/>
|
||||||
|
<rangefinder name="rf_32" site="rf_32"/>
|
||||||
|
<rangefinder name="rf_33" site="rf_33"/>
|
||||||
|
<rangefinder name="rf_34" site="rf_34"/>
|
||||||
|
</sensor>
|
||||||
|
|
||||||
|
</mujoco>
|
||||||
|
|
116
local_dm_control_suite/reacher.py
Executable file
@ -0,0 +1,116 @@
|
|||||||
|
# Copyright 2017 The dm_control Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Reacher domain."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
from dm_control import mujoco
|
||||||
|
from dm_control.rl import control
|
||||||
|
from local_dm_control_suite import base
|
||||||
|
from local_dm_control_suite import common
|
||||||
|
from dm_control.suite.utils import randomizers
|
||||||
|
from dm_control.utils import containers
|
||||||
|
from dm_control.utils import rewards
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
SUITE = containers.TaggedTasks()
|
||||||
|
_DEFAULT_TIME_LIMIT = 20
|
||||||
|
_BIG_TARGET = .05
|
||||||
|
_SMALL_TARGET = .015
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_and_assets():
|
||||||
|
"""Returns a tuple containing the model XML string and a dict of assets."""
|
||||||
|
return common.read_model('reacher.xml'), common.ASSETS
|
||||||
|
|
||||||
|
|
||||||
|
@SUITE.add('benchmarking', 'easy')
|
||||||
|
def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
|
||||||
|
"""Returns reacher with sparse reward with 5e-2 tol and randomized target."""
|
||||||
|
physics = Physics.from_xml_string(*get_model_and_assets())
|
||||||
|
task = Reacher(target_size=_BIG_TARGET, random=random)
|
||||||
|
environment_kwargs = environment_kwargs or {}
|
||||||
|
return control.Environment(
|
||||||
|
physics, task, time_limit=time_limit, **environment_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@SUITE.add('benchmarking')
|
||||||
|
def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
|
||||||
|
"""Returns reacher with sparse reward with 1e-2 tol and randomized target."""
|
||||||
|
physics = Physics.from_xml_string(*get_model_and_assets())
|
||||||
|
task = Reacher(target_size=_SMALL_TARGET, random=random)
|
||||||
|
environment_kwargs = environment_kwargs or {}
|
||||||
|
return control.Environment(
|
||||||
|
physics, task, time_limit=time_limit, **environment_kwargs)
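# A minimal sketch (illustrative only) of obtaining pixel observations from
# the tasks above, e.g. for pixel-based agents; `_example_render_frame` is a
# hypothetical helper name, not part of the original suite.
def _example_render_frame(height=84, width=84):
  """Resets the easy task and returns a single rendered RGB frame."""
  env = easy()
  env.reset()
  return env.physics.render(height=height, width=width, camera_id=0)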
|
||||||
|
|
||||||
|
|
||||||
|
class Physics(mujoco.Physics):
|
||||||
|
"""Physics simulation with additional features for the Reacher domain."""
|
||||||
|
|
||||||
|
def finger_to_target(self):
|
||||||
|
"""Returns the vector from target to finger in global coordinates."""
|
||||||
|
return (self.named.data.geom_xpos['target', :2] -
|
||||||
|
self.named.data.geom_xpos['finger', :2])
|
||||||
|
|
||||||
|
def finger_to_target_dist(self):
|
||||||
|
"""Returns the signed distance between the finger and target surface."""
|
||||||
|
return np.linalg.norm(self.finger_to_target())
|
||||||
|
|
||||||
|
|
||||||
|
class Reacher(base.Task):
|
||||||
|
"""A reacher `Task` to reach the target."""
|
||||||
|
|
||||||
|
def __init__(self, target_size, random=None):
|
||||||
|
"""Initialize an instance of `Reacher`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_size: A `float`, tolerance to determine whether finger reached the
|
||||||
|
target.
|
||||||
|
random: Optional, either a `numpy.random.RandomState` instance, an
|
||||||
|
integer seed for creating a new `RandomState`, or None to select a seed
|
||||||
|
automatically (default).
|
||||||
|
"""
|
||||||
|
self._target_size = target_size
|
||||||
|
super(Reacher, self).__init__(random=random)
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
physics.named.model.geom_size['target', 0] = self._target_size
|
||||||
|
randomizers.randomize_limited_and_rotational_joints(physics, self.random)
|
||||||
|
|
||||||
|
# Randomize target position
|
||||||
|
angle = self.random.uniform(0, 2 * np.pi)
|
||||||
|
radius = self.random.uniform(.05, .20)
|
||||||
|
physics.named.model.geom_pos['target', 'x'] = radius * np.sin(angle)
|
||||||
|
physics.named.model.geom_pos['target', 'y'] = radius * np.cos(angle)
|
||||||
|
|
||||||
|
super(Reacher, self).initialize_episode(physics)
|
||||||
|
|
||||||
|
def get_observation(self, physics):
|
||||||
|
"""Returns an observation of the state and the target position."""
|
||||||
|
obs = collections.OrderedDict()
|
||||||
|
obs['position'] = physics.position()
|
||||||
|
obs['to_target'] = physics.finger_to_target()
|
||||||
|
obs['velocity'] = physics.velocity()
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
radii = physics.named.model.geom_size[['target', 'finger'], 0].sum()
|
||||||
|
return rewards.tolerance(physics.finger_to_target_dist(), (0, radii))
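    # Note: with the default margin of 0, rewards.tolerance acts as a sparse
    # indicator here: the reward is 1 whenever the finger centre lies within
    # the combined finger and target radii of the target centre, 0 otherwise.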
|
47
local_dm_control_suite/reacher.xml
Executable file
@ -0,0 +1,47 @@
|
|||||||
|
<mujoco model="two-link planar reacher">
|
||||||
|
<include file="./common/skybox.xml"/>
|
||||||
|
<include file="./common/visual.xml"/>
|
||||||
|
<include file="./common/materials.xml"/>
|
||||||
|
|
||||||
|
<option timestep="0.02">
|
||||||
|
<flag contact="disable"/>
|
||||||
|
</option>
|
||||||
|
|
||||||
|
<default>
|
||||||
|
<joint type="hinge" axis="0 0 1" damping="0.01"/>
|
||||||
|
<motor gear=".05" ctrlrange="-1 1" ctrllimited="true"/>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<light name="light" pos="0 0 1"/>
|
||||||
|
<camera name="fixed" pos="0 0 .75" quat="1 0 0 0"/>
|
||||||
|
<!-- Arena -->
|
||||||
|
<geom name="ground" type="plane" pos="0 0 0" size=".3 .3 10" material="grid"/>
|
||||||
|
<geom name="wall_x" type="plane" pos="-.3 0 .02" zaxis="1 0 0" size=".02 .3 .02" material="decoration"/>
|
||||||
|
<geom name="wall_y" type="plane" pos="0 -.3 .02" zaxis="0 1 0" size=".3 .02 .02" material="decoration"/>
|
||||||
|
<geom name="wall_neg_x" type="plane" pos=".3 0 .02" zaxis="-1 0 0" size=".02 .3 .02" material="decoration"/>
|
||||||
|
<geom name="wall_neg_y" type="plane" pos="0 .3 .02" zaxis="0 -1 0" size=".3 .02 .02" material="decoration"/>
|
||||||
|
|
||||||
|
<!-- Arm -->
|
||||||
|
<geom name="root" type="cylinder" fromto="0 0 0 0 0 0.02" size=".011" material="decoration"/>
|
||||||
|
<body name="arm" pos="0 0 .01">
|
||||||
|
<geom name="arm" type="capsule" fromto="0 0 0 0.12 0 0" size=".01" material="self"/>
|
||||||
|
<joint name="shoulder"/>
|
||||||
|
<body name="hand" pos=".12 0 0">
|
||||||
|
<geom name="hand" type="capsule" fromto="0 0 0 0.1 0 0" size=".01" material="self"/>
|
||||||
|
<joint name="wrist" limited="true" range="-160 160"/>
|
||||||
|
<body name="finger" pos=".12 0 0">
|
||||||
|
<camera name="hand" pos="0 0 .2" mode="track"/>
|
||||||
|
<geom name="finger" type="sphere" size=".01" material="effector"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
<!-- Target -->
|
||||||
|
<geom name="target" pos="0 0 .01" material="target" type="sphere" size=".05"/>
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<motor name="shoulder" joint="shoulder"/>
|
||||||
|
<motor name="wrist" joint="wrist"/>
|
||||||
|
</actuator>
|
||||||
|
</mujoco>
|
208
local_dm_control_suite/stacker.py
Executable file
@ -0,0 +1,208 @@
|
|||||||
|
# Copyright 2017 The dm_control Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Planar Stacker domain."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
from dm_control import mujoco
|
||||||
|
from dm_control.rl import control
|
||||||
|
from local_dm_control_suite import base
|
||||||
|
from local_dm_control_suite import common
|
||||||
|
from dm_control.utils import containers
|
||||||
|
from dm_control.utils import rewards
|
||||||
|
from dm_control.utils import xml_tools
|
||||||
|
|
||||||
|
from lxml import etree
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
_CLOSE = .01 # (Meters) Distance below which a thing is considered close.
|
||||||
|
_CONTROL_TIMESTEP = .01 # (Seconds)
|
||||||
|
_TIME_LIMIT = 10 # (Seconds)
|
||||||
|
_ARM_JOINTS = ['arm_root', 'arm_shoulder', 'arm_elbow', 'arm_wrist',
|
||||||
|
'finger', 'fingertip', 'thumb', 'thumbtip']
|
||||||
|
|
||||||
|
SUITE = containers.TaggedTasks()
|
||||||
|
|
||||||
|
|
||||||
|
def make_model(n_boxes):
|
||||||
|
"""Returns a tuple containing the model XML string and a dict of assets."""
|
||||||
|
xml_string = common.read_model('stacker.xml')
|
||||||
|
parser = etree.XMLParser(remove_blank_text=True)
|
||||||
|
mjcf = etree.XML(xml_string, parser)
|
||||||
|
|
||||||
|
# Remove unused boxes
|
||||||
|
for b in range(n_boxes, 4):
|
||||||
|
box = xml_tools.find_element(mjcf, 'body', 'box' + str(b))
|
||||||
|
box.getparent().remove(box)
|
||||||
|
|
||||||
|
return etree.tostring(mjcf, pretty_print=True), common.ASSETS
|
||||||
|
|
||||||
|
|
||||||
|
@SUITE.add('hard')
|
||||||
|
def stack_2(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
|
||||||
|
environment_kwargs=None):
|
||||||
|
"""Returns stacker task with 2 boxes."""
|
||||||
|
n_boxes = 2
|
||||||
|
physics = Physics.from_xml_string(*make_model(n_boxes=n_boxes))
|
||||||
|
task = Stack(n_boxes=n_boxes,
|
||||||
|
fully_observable=fully_observable,
|
||||||
|
random=random)
|
||||||
|
environment_kwargs = environment_kwargs or {}
|
||||||
|
return control.Environment(
|
||||||
|
physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
|
||||||
|
**environment_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@SUITE.add('hard')
|
||||||
|
def stack_4(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
|
||||||
|
environment_kwargs=None):
|
||||||
|
"""Returns stacker task with 4 boxes."""
|
||||||
|
n_boxes = 4
|
||||||
|
physics = Physics.from_xml_string(*make_model(n_boxes=n_boxes))
|
||||||
|
task = Stack(n_boxes=n_boxes,
|
||||||
|
fully_observable=fully_observable,
|
||||||
|
random=random)
|
||||||
|
environment_kwargs = environment_kwargs or {}
|
||||||
|
return control.Environment(
|
||||||
|
physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
|
||||||
|
**environment_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Physics(mujoco.Physics):
|
||||||
|
"""Physics with additional features for the Planar Manipulator domain."""
|
||||||
|
|
||||||
|
def bounded_joint_pos(self, joint_names):
|
||||||
|
"""Returns joint positions as (sin, cos) values."""
|
||||||
|
joint_pos = self.named.data.qpos[joint_names]
|
||||||
|
return np.vstack([np.sin(joint_pos), np.cos(joint_pos)]).T
|
||||||
|
|
||||||
|
def joint_vel(self, joint_names):
|
||||||
|
"""Returns joint velocities."""
|
||||||
|
return self.named.data.qvel[joint_names]
|
||||||
|
|
||||||
|
def body_2d_pose(self, body_names, orientation=True):
|
||||||
|
"""Returns positions and/or orientations of bodies."""
|
||||||
|
if not isinstance(body_names, str):
|
||||||
|
body_names = np.array(body_names).reshape(-1, 1) # Broadcast indices.
|
||||||
|
pos = self.named.data.xpos[body_names, ['x', 'z']]
|
||||||
|
if orientation:
|
||||||
|
ori = self.named.data.xquat[body_names, ['qw', 'qy']]
|
||||||
|
return np.hstack([pos, ori])
|
||||||
|
else:
|
||||||
|
return pos
|
||||||
|
|
||||||
|
def touch(self):
|
||||||
|
return np.log1p(self.data.sensordata)
|
||||||
|
|
||||||
|
def site_distance(self, site1, site2):
|
||||||
|
site1_to_site2 = np.diff(self.named.data.site_xpos[[site2, site1]], axis=0)
|
||||||
|
return np.linalg.norm(site1_to_site2)
|
||||||
|
|
||||||
|
|
||||||
|
class Stack(base.Task):
|
||||||
|
"""A Stack `Task`: stack the boxes."""
|
||||||
|
|
||||||
|
def __init__(self, n_boxes, fully_observable, random=None):
|
||||||
|
"""Initialize an instance of the `Stack` task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_boxes: An `int`, number of boxes to stack.
|
||||||
|
fully_observable: A `bool`, whether the observation should contain the
|
||||||
|
positions and velocities of the boxes and the location of the target.
|
||||||
|
random: Optional, either a `numpy.random.RandomState` instance, an
|
||||||
|
integer seed for creating a new `RandomState`, or None to select a seed
|
||||||
|
automatically (default).
|
||||||
|
"""
|
||||||
|
self._n_boxes = n_boxes
|
||||||
|
self._box_names = ['box' + str(b) for b in range(n_boxes)]
|
||||||
|
self._box_joint_names = []
|
||||||
|
for name in self._box_names:
|
||||||
|
for dim in 'xyz':
|
||||||
|
self._box_joint_names.append('_'.join([name, dim]))
|
||||||
|
self._fully_observable = fully_observable
|
||||||
|
super(Stack, self).__init__(random=random)
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
# Local aliases
|
||||||
|
randint = self.random.randint
|
||||||
|
uniform = self.random.uniform
|
||||||
|
model = physics.named.model
|
||||||
|
data = physics.named.data
|
||||||
|
|
||||||
|
# Find a collision-free random initial configuration.
|
||||||
|
penetrating = True
|
||||||
|
while penetrating:
|
||||||
|
|
||||||
|
# Randomise angles of arm joints.
|
||||||
|
is_limited = model.jnt_limited[_ARM_JOINTS].astype(bool)  # np.bool is a removed NumPy alias.
|
||||||
|
joint_range = model.jnt_range[_ARM_JOINTS]
|
||||||
|
lower_limits = np.where(is_limited, joint_range[:, 0], -np.pi)
|
||||||
|
upper_limits = np.where(is_limited, joint_range[:, 1], np.pi)
|
||||||
|
angles = uniform(lower_limits, upper_limits)
|
||||||
|
data.qpos[_ARM_JOINTS] = angles
|
||||||
|
|
||||||
|
# Symmetrize hand.
|
||||||
|
data.qpos['finger'] = data.qpos['thumb']
|
||||||
|
|
||||||
|
# Randomise target location.
|
||||||
|
target_height = 2*randint(self._n_boxes) + 1
|
||||||
|
box_size = model.geom_size['target', 0]
|
||||||
|
model.body_pos['target', 'z'] = box_size * target_height
|
||||||
|
model.body_pos['target', 'x'] = uniform(-.37, .37)
|
||||||
|
|
||||||
|
# Randomise box locations.
|
||||||
|
for name in self._box_names:
|
||||||
|
data.qpos[name + '_x'] = uniform(.1, .3)
|
||||||
|
data.qpos[name + '_z'] = uniform(0, .7)
|
||||||
|
data.qpos[name + '_y'] = uniform(0, 2*np.pi)
|
||||||
|
|
||||||
|
# Check for collisions.
|
||||||
|
physics.after_reset()
|
||||||
|
penetrating = physics.data.ncon > 0
|
||||||
|
|
||||||
|
super(Stack, self).initialize_episode(physics)
|
||||||
|
|
||||||
|
def get_observation(self, physics):
|
||||||
|
"""Returns either features or only sensors (to be used with pixels)."""
|
||||||
|
obs = collections.OrderedDict()
|
||||||
|
obs['arm_pos'] = physics.bounded_joint_pos(_ARM_JOINTS)
|
||||||
|
obs['arm_vel'] = physics.joint_vel(_ARM_JOINTS)
|
||||||
|
obs['touch'] = physics.touch()
|
||||||
|
if self._fully_observable:
|
||||||
|
obs['hand_pos'] = physics.body_2d_pose('hand')
|
||||||
|
obs['box_pos'] = physics.body_2d_pose(self._box_names)
|
||||||
|
obs['box_vel'] = physics.joint_vel(self._box_joint_names)
|
||||||
|
obs['target_pos'] = physics.body_2d_pose('target', orientation=False)
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
"""Returns a reward to the agent."""
|
||||||
|
box_size = physics.named.model.geom_size['target', 0]
|
||||||
|
min_box_to_target_distance = min(physics.site_distance(name, 'target')
|
||||||
|
for name in self._box_names)
|
||||||
|
box_is_close = rewards.tolerance(min_box_to_target_distance,
|
||||||
|
margin=2*box_size)
|
||||||
|
hand_to_target_distance = physics.site_distance('grasp', 'target')
|
||||||
|
hand_is_far = rewards.tolerance(hand_to_target_distance,
|
||||||
|
bounds=(.1, float('inf')),
|
||||||
|
margin=_CLOSE)
|
||||||
|
return box_is_close * hand_is_far
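    # Note: box_is_close peaks when some box site coincides with the target
    # and decays over a margin of 2 * box_size, while hand_is_far is 1 only
    # once the grasp site has withdrawn at least 10 cm from the target; their
    # product therefore rewards stacking a box and letting go, rather than
    # holding a box against the target.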
|
193
local_dm_control_suite/stacker.xml
Executable file
@ -0,0 +1,193 @@
|
|||||||
|
<mujoco model="planar stacker">
|
||||||
|
|
||||||
|
<include file="./common/visual.xml"/>
|
||||||
|
<include file="./common/skybox.xml"/>
|
||||||
|
<include file="./common/materials_white_floor.xml"/>
|
||||||
|
<asset>
|
||||||
|
<texture name="background" builtin="flat" type="2d" mark="random" markrgb="1 1 1" width="800" height="800" rgb1=".2 .3 .4"/>
|
||||||
|
<material name="background" texture="background" texrepeat="1 1" texuniform="true"/>
|
||||||
|
</asset>
|
||||||
|
|
||||||
|
<visual>
|
||||||
|
<map shadowclip=".5"/>
|
||||||
|
<quality shadowsize="2048"/>
|
||||||
|
</visual>
|
||||||
|
|
||||||
|
<option timestep="0.001" cone="elliptic"/>
|
||||||
|
|
||||||
|
<default>
|
||||||
|
<geom friction=".7" solimp="0.9 0.97 0.001" solref=".01 1"/>
|
||||||
|
<joint solimplimit="0 0.99 0.01" solreflimit=".005 1"/>
|
||||||
|
<general ctrllimited="true"/>
|
||||||
|
<tendon width="0.01"/>
|
||||||
|
<site size=".003 .003 .003" material="site" group="3"/>
|
||||||
|
|
||||||
|
<default class="arm">
|
||||||
|
<geom type="capsule" material="self" density="500"/>
|
||||||
|
<joint type="hinge" pos="0 0 0" axis="0 -1 0" limited="true"/>
|
||||||
|
<default class="hand">
|
||||||
|
<joint damping=".5" range="-10 60"/>
|
||||||
|
<geom size=".008"/>
|
||||||
|
<site type="box" size=".018 .005 .005" pos=".022 0 -.002" euler="0 15 0" group="4"/>
|
||||||
|
<default class="fingertip">
|
||||||
|
<geom type="sphere" size=".008" material="effector"/>
|
||||||
|
<joint damping=".01" stiffness=".01" range="-40 20"/>
|
||||||
|
<site size=".012 .005 .008" pos=".003 0 .003" group="4" euler="0 0 0"/>
|
||||||
|
</default>
|
||||||
|
</default>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
<default class="object">
|
||||||
|
<geom material="self"/>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
<default class="task">
|
||||||
|
<site rgba="0 0 0 0"/>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
<default class="obstacle">
|
||||||
|
<geom material="decoration" friction="0"/>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
<default class="ghost">
|
||||||
|
<geom material="target" contype="0" conaffinity="0"/>
|
||||||
|
</default>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<!-- Arena -->
|
||||||
|
<light name="light" directional="true" diffuse=".6 .6 .6" pos="0 0 1" specular=".3 .3 .3"/>
|
||||||
|
<geom name="floor" type="plane" pos="0 0 0" size=".4 .2 10" material="grid"/>
|
||||||
|
<geom name="wall1" type="plane" pos="-.682843 0 .282843" size=".4 .2 10" material="grid" zaxis="1 0 1"/>
|
||||||
|
<geom name="wall2" type="plane" pos=".682843 0 .282843" size=".4 .2 10" material="grid" zaxis="-1 0 1"/>
|
||||||
|
<geom name="background" type="plane" pos="0 .2 .5" size="1 .5 10" material="background" zaxis="0 -1 0"/>
|
||||||
|
<camera name="fixed" pos="0 -16 .4" xyaxes="1 0 0 0 0 1" fovy="4"/>
|
||||||
|
|
||||||
|
<!-- Arm -->
|
||||||
|
<geom name="arm_root" type="cylinder" fromto="0 -.022 .4 0 .022 .4" size=".024"
|
||||||
|
material="decoration" contype="0" conaffinity="0"/>
|
||||||
|
<body name="upper_arm" pos="0 0 .4" childclass="arm">
|
||||||
|
<joint name="arm_root" damping="2" limited="false"/>
|
||||||
|
<geom name="upper_arm" size=".02" fromto="0 0 0 0 0 .18"/>
|
||||||
|
<body name="middle_arm" pos="0 0 .18" childclass="arm">
|
||||||
|
<joint name="arm_shoulder" damping="1.5" range="-160 160"/>
|
||||||
|
<geom name="middle_arm" size=".017" fromto="0 0 0 0 0 .15"/>
|
||||||
|
<body name="lower_arm" pos="0 0 .15">
|
||||||
|
<joint name="arm_elbow" damping="1" range="-160 160"/>
|
||||||
|
<geom name="lower_arm" size=".014" fromto="0 0 0 0 0 .12"/>
|
||||||
|
<body name="hand" pos="0 0 .12">
|
||||||
|
<joint name="arm_wrist" damping=".5" range="-140 140" />
|
||||||
|
<geom name="hand" size=".011" fromto="0 0 0 0 0 .03"/>
|
||||||
|
<geom name="palm1" fromto="0 0 .03 .03 0 .045" class="hand"/>
|
||||||
|
<geom name="palm2" fromto="0 0 .03 -.03 0 .045" class="hand"/>
|
||||||
|
<site name="grasp" pos="0 0 .065"/>
|
||||||
|
<body name="pinch site" pos="0 0 .090">
|
||||||
|
<site name="pinch"/>
|
||||||
|
<inertial pos="0 0 0" mass="1e-6" diaginertia="1e-12 1e-12 1e-12"/>
|
||||||
|
<camera name="hand" pos="0 -.3 0" xyaxes="1 0 0 0 0 1" mode="track"/>
|
||||||
|
</body>
|
||||||
|
<site name="palm_touch" type="box" group="4" size=".025 .005 .008" pos="0 0 .043"/>
|
||||||
|
|
||||||
|
<body name="thumb" pos=".03 0 .045" euler="0 -90 0" childclass="hand">
|
||||||
|
<joint name="thumb"/>
|
||||||
|
<geom name="thumb1" fromto="0 0 0 .02 0 -.01" size=".007"/>
|
||||||
|
<geom name="thumb2" fromto=".02 0 -.01 .04 0 -.01" size=".007"/>
|
||||||
|
<site name="thumb_touch" group="4"/>
|
||||||
|
<body name="thumbtip" pos=".05 0 -.01" childclass="fingertip">
|
||||||
|
<joint name="thumbtip"/>
|
||||||
|
<geom name="thumbtip1" pos="-.003 0 0" />
|
||||||
|
<geom name="thumbtip2" pos=".003 0 0" />
|
||||||
|
<site name="thumbtip_touch" group="4"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="finger" pos="-.03 0 .045" euler="0 90 180" childclass="hand">
|
||||||
|
<joint name="finger"/>
|
||||||
|
<geom name="finger1" fromto="0 0 0 .02 0 -.01" size=".007" />
|
||||||
|
<geom name="finger2" fromto=".02 0 -.01 .04 0 -.01" size=".007"/>
|
||||||
|
<site name="finger_touch"/>
|
||||||
|
<body name="fingertip" pos=".05 0 -.01" childclass="fingertip">
|
||||||
|
<joint name="fingertip"/>
|
||||||
|
<geom name="fingertip1" pos="-.003 0 0" />
|
||||||
|
<geom name="fingertip2" pos=".003 0 0" />
|
||||||
|
<site name="fingertip_touch"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<!-- props -->
|
||||||
|
<body name="box0" pos=".5 0 .4" childclass="object">
|
||||||
|
<joint name="box0_x" type="slide" axis="1 0 0" ref=".5"/>
|
||||||
|
<joint name="box0_z" type="slide" axis="0 0 1" ref=".4"/>
|
||||||
|
<joint name="box0_y" type="hinge" axis="0 1 0"/>
|
||||||
|
<geom name="box0" type="box" size=".022 .022 .022" />
|
||||||
|
<site name="box0" type="sphere"/>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="box1" pos=".4 0 .4" childclass="object">
|
||||||
|
<joint name="box1_x" type="slide" axis="1 0 0" ref=".4"/>
|
||||||
|
<joint name="box1_z" type="slide" axis="0 0 1" ref=".4"/>
|
||||||
|
<joint name="box1_y" type="hinge" axis="0 1 0"/>
|
||||||
|
<geom name="box1" type="box" size=".022 .022 .022" />
|
||||||
|
<site name="box1" type="sphere"/>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="box2" pos=".3 0 .4" childclass="object">
|
||||||
|
<joint name="box2_x" type="slide" axis="1 0 0" ref=".3"/>
|
||||||
|
<joint name="box2_z" type="slide" axis="0 0 1" ref=".4"/>
|
||||||
|
<joint name="box2_y" type="hinge" axis="0 1 0"/>
|
||||||
|
<geom name="box2" type="box" size=".022 .022 .022" />
|
||||||
|
<site name="box2" type="sphere"/>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="box3" pos=".2 0 .4" childclass="object">
|
||||||
|
<joint name="box3_x" type="slide" axis="1 0 0" ref=".2"/>
|
||||||
|
<joint name="box3_z" type="slide" axis="0 0 1" ref=".4"/>
|
||||||
|
<joint name="box3_y" type="hinge" axis="0 1 0"/>
|
||||||
|
<geom name="box3" type="box" size=".022 .022 .022" />
|
||||||
|
<site name="box3" type="sphere"/>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
|
||||||
|
<!-- targets -->
|
||||||
|
<body name="target" pos="0 .001 .022" childclass="ghost">
|
||||||
|
<geom name="target" type="box" size=".022 .022 .022" />
|
||||||
|
<site name="target" type="sphere"/>
|
||||||
|
</body>
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<tendon>
|
||||||
|
<fixed name="grasp">
|
||||||
|
<joint joint="thumb" coef=".5"/>
|
||||||
|
<joint joint="finger" coef=".5"/>
|
||||||
|
</fixed>
|
||||||
|
<fixed name="coupling">
|
||||||
|
<joint joint="thumb" coef="-.5"/>
|
||||||
|
<joint joint="finger" coef=".5"/>
|
||||||
|
</fixed>
|
||||||
|
</tendon>
|
||||||
|
|
||||||
|
<equality>
|
||||||
|
<tendon name="coupling" tendon1="coupling" solimp="0.95 0.99 0.001" solref=".005 .5"/>
|
||||||
|
</equality>
|
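<!-- Reading of the tendon setup above (inferred from the coefficients, not stated
     in the original model): the "grasp" tendon sums the thumb and finger joints so
     the single "grasp" motor closes both digits together, while the "coupling"
     tendon measures their difference and the equality constraint keeps that
     difference near zero, so the two digits move symmetrically. -->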
||||||
|
|
||||||
|
<sensor>
|
||||||
|
<touch name="palm_touch" site="palm_touch"/>
|
||||||
|
<touch name="finger_touch" site="finger_touch"/>
|
||||||
|
<touch name="thumb_touch" site="thumb_touch"/>
|
||||||
|
<touch name="fingertip_touch" site="fingertip_touch"/>
|
||||||
|
<touch name="thumbtip_touch" site="thumbtip_touch"/>
|
||||||
|
</sensor>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<motor name="root" joint="arm_root" ctrlrange="-1 1" gear="12"/>
|
||||||
|
<motor name="shoulder" joint="arm_shoulder" ctrlrange="-1 1" gear="8"/>
|
||||||
|
<motor name="elbow" joint="arm_elbow" ctrlrange="-1 1" gear="4"/>
|
||||||
|
<motor name="wrist" joint="arm_wrist" ctrlrange="-1 1" gear="2"/>
|
||||||
|
<motor name="grasp" tendon="grasp" ctrlrange="-1 1" gear="2"/>
|
||||||
|
</actuator>
|
||||||
|
|
||||||
|
</mujoco>
|
215
local_dm_control_suite/swimmer.py
Executable file
@ -0,0 +1,215 @@
|
|||||||
|
# Copyright 2017 The dm_control Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Procedurally generated Swimmer domain."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
from dm_control import mujoco
|
||||||
|
from dm_control.rl import control
|
||||||
|
from local_dm_control_suite import base
|
||||||
|
from local_dm_control_suite import common
|
||||||
|
from dm_control.suite.utils import randomizers
|
||||||
|
from dm_control.utils import containers
|
||||||
|
from dm_control.utils import rewards
|
||||||
|
from lxml import etree
|
||||||
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
|
_DEFAULT_TIME_LIMIT = 30
|
||||||
|
_CONTROL_TIMESTEP = .03 # (Seconds)
|
||||||
|
|
||||||
|
SUITE = containers.TaggedTasks()
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_and_assets(n_joints):
|
||||||
|
"""Returns a tuple containing the model XML string and a dict of assets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_joints: An integer specifying the number of joints in the swimmer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple `(model_xml_string, assets)`, where `assets` is a dict consisting of
|
||||||
|
`{filename: contents_string}` pairs.
|
||||||
|
"""
|
||||||
|
return _make_model(n_joints), common.ASSETS
|
||||||
|
|
||||||
|
|
||||||
|
@SUITE.add('benchmarking')
|
||||||
|
def swimmer6(time_limit=_DEFAULT_TIME_LIMIT, random=None,
|
||||||
|
environment_kwargs=None):
|
||||||
|
"""Returns a 6-link swimmer."""
|
||||||
|
return _make_swimmer(6, time_limit, random=random,
|
||||||
|
environment_kwargs=environment_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@SUITE.add('benchmarking')
|
||||||
|
def swimmer15(time_limit=_DEFAULT_TIME_LIMIT, random=None,
|
||||||
|
environment_kwargs=None):
|
||||||
|
"""Returns a 15-link swimmer."""
|
||||||
|
return _make_swimmer(15, time_limit, random=random,
|
||||||
|
environment_kwargs=environment_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def swimmer(n_links=3, time_limit=_DEFAULT_TIME_LIMIT,
|
||||||
|
random=None, environment_kwargs=None):
|
||||||
|
"""Returns a swimmer with n links."""
|
||||||
|
return _make_swimmer(n_links, time_limit, random=random,
|
||||||
|
environment_kwargs=environment_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_swimmer(n_joints, time_limit=_DEFAULT_TIME_LIMIT, random=None,
|
||||||
|
environment_kwargs=None):
|
||||||
|
"""Returns a swimmer control environment."""
|
||||||
|
model_string, assets = get_model_and_assets(n_joints)
|
||||||
|
physics = Physics.from_xml_string(model_string, assets=assets)
|
||||||
|
task = Swimmer(random=random)
|
||||||
|
environment_kwargs = environment_kwargs or {}
|
||||||
|
return control.Environment(
|
||||||
|
physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
|
||||||
|
**environment_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_model(n_bodies):
|
||||||
|
"""Generates an xml string defining a swimmer with `n_bodies` bodies."""
|
||||||
|
if n_bodies < 3:
|
||||||
|
raise ValueError('At least 3 bodies required. Received {}'.format(n_bodies))
|
||||||
|
mjcf = etree.fromstring(common.read_model('swimmer.xml'))
|
||||||
|
head_body = mjcf.find('./worldbody/body')
|
||||||
|
actuator = etree.SubElement(mjcf, 'actuator')
|
||||||
|
sensor = etree.SubElement(mjcf, 'sensor')
|
||||||
|
|
||||||
|
parent = head_body
|
||||||
|
for body_index in range(n_bodies - 1):
|
||||||
|
site_name = 'site_{}'.format(body_index)
|
||||||
|
child = _make_body(body_index=body_index)
|
||||||
|
child.append(etree.Element('site', name=site_name))
|
||||||
|
joint_name = 'joint_{}'.format(body_index)
|
||||||
|
joint_limit = 360.0/n_bodies
|
||||||
|
joint_range = '{} {}'.format(-joint_limit, joint_limit)
|
||||||
|
child.append(etree.Element('joint', {'name': joint_name,
|
||||||
|
'range': joint_range}))
|
||||||
|
motor_name = 'motor_{}'.format(body_index)
|
||||||
|
actuator.append(etree.Element('motor', name=motor_name, joint=joint_name))
|
||||||
|
velocimeter_name = 'velocimeter_{}'.format(body_index)
|
||||||
|
sensor.append(etree.Element('velocimeter', name=velocimeter_name,
|
||||||
|
site=site_name))
|
||||||
|
gyro_name = 'gyro_{}'.format(body_index)
|
||||||
|
sensor.append(etree.Element('gyro', name=gyro_name, site=site_name))
|
||||||
|
parent.append(child)
|
||||||
|
parent = child
|
||||||
|
|
||||||
|
# Move tracking cameras further away from the swimmer according to its length.
|
||||||
|
cameras = mjcf.findall('./worldbody/body/camera')
|
||||||
|
scale = n_bodies / 6.0
|
||||||
|
for cam in cameras:
|
||||||
|
if cam.get('mode') == 'trackcom':
|
||||||
|
old_pos = cam.get('pos').split(' ')
|
||||||
|
new_pos = ' '.join([str(float(dim) * scale) for dim in old_pos])
|
||||||
|
cam.set('pos', new_pos)
|
||||||
|
|
||||||
|
return etree.tostring(mjcf, pretty_print=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_body(body_index):
|
||||||
|
"""Generates an xml string defining a single physical body."""
|
||||||
|
body_name = 'segment_{}'.format(body_index)
|
||||||
|
visual_name = 'visual_{}'.format(body_index)
|
||||||
|
inertial_name = 'inertial_{}'.format(body_index)
|
||||||
|
body = etree.Element('body', name=body_name)
|
||||||
|
body.set('pos', '0 .1 0')
|
||||||
|
etree.SubElement(body, 'geom', {'class': 'visual', 'name': visual_name})
|
||||||
|
etree.SubElement(body, 'geom', {'class': 'inertial', 'name': inertial_name})
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
class Physics(mujoco.Physics):
|
||||||
|
"""Physics simulation with additional features for the swimmer domain."""
|
||||||
|
|
||||||
|
def nose_to_target(self):
|
||||||
|
"""Returns a vector from nose to target in local coordinate of the head."""
|
||||||
|
nose_to_target = (self.named.data.geom_xpos['target'] -
|
||||||
|
self.named.data.geom_xpos['nose'])
|
||||||
|
head_orientation = self.named.data.xmat['head'].reshape(3, 3)
|
||||||
|
return nose_to_target.dot(head_orientation)[:2]
|
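# `xmat` stores the head's local axes expressed in world coordinates, so the dot
# product above re-expresses the world-frame nose-to-target vector in the head's
# local frame; the [:2] slice keeps only the planar x/y components.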
||||||
|
|
||||||
|
def nose_to_target_dist(self):
|
||||||
|
"""Returns the distance from the nose to the target."""
|
||||||
|
return np.linalg.norm(self.nose_to_target())
|
||||||
|
|
||||||
|
def body_velocities(self):
|
||||||
|
"""Returns local body velocities: x,y linear, z rotational."""
|
||||||
|
xvel_local = self.data.sensordata[12:].reshape((-1, 6))
|
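# The first 12 sensordata entries correspond to the framepos/framexaxis/frameyaxis
# sensors declared in swimmer.xml (3 values each); everything after index 12 is
# velocimeter and gyro readings, 6 values per body, hence the reshape into rows of 6.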
||||||
|
vx_vy_wz = [0, 1, 5] # Indices for linear x,y vels and rotational z vel.
|
||||||
|
return xvel_local[:, vx_vy_wz].ravel()
|
||||||
|
|
||||||
|
def joints(self):
|
||||||
|
"""Returns all internal joint angles (excluding root joints)."""
|
||||||
|
return self.data.qpos[3:].copy()
|
||||||
|
|
||||||
|
|
||||||
|
class Swimmer(base.Task):
|
||||||
|
"""A swimmer `Task` to reach the target or just swim."""
|
||||||
|
|
||||||
|
def __init__(self, random=None):
|
||||||
|
"""Initializes an instance of `Swimmer`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
random: Optional, either a `numpy.random.RandomState` instance, an
|
||||||
|
integer seed for creating a new `RandomState`, or None to select a seed
|
||||||
|
automatically (default).
|
||||||
|
"""
|
||||||
|
super(Swimmer, self).__init__(random=random)
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode.
|
||||||
|
|
||||||
|
Initializes the swimmer orientation to [-pi, pi) and the relative joint
|
||||||
|
angle of each joint uniformly within its range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
physics: An instance of `Physics`.
|
||||||
|
"""
|
||||||
|
# Random joint angles:
|
||||||
|
randomizers.randomize_limited_and_rotational_joints(physics, self.random)
|
||||||
|
# Random target position.
|
||||||
|
close_target = self.random.rand() < .2 # Probability of a close target.
|
||||||
|
target_box = .3 if close_target else 2
|
||||||
|
xpos, ypos = self.random.uniform(-target_box, target_box, size=2)
|
||||||
|
physics.named.model.geom_pos['target', 'x'] = xpos
|
||||||
|
physics.named.model.geom_pos['target', 'y'] = ypos
|
||||||
|
physics.named.model.light_pos['target_light', 'x'] = xpos
|
||||||
|
physics.named.model.light_pos['target_light', 'y'] = ypos
|
||||||
|
|
||||||
|
super(Swimmer, self).initialize_episode(physics)
|
||||||
|
|
||||||
|
def get_observation(self, physics):
|
||||||
|
"""Returns an observation of joint angles, body velocities and target."""
|
||||||
|
obs = collections.OrderedDict()
|
||||||
|
obs['joints'] = physics.joints()
|
||||||
|
obs['to_target'] = physics.nose_to_target()
|
||||||
|
obs['body_velocities'] = physics.body_velocities()
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
"""Returns a smooth reward."""
|
||||||
|
target_size = physics.named.model.geom_size['target', 0]
|
||||||
|
return rewards.tolerance(physics.nose_to_target_dist(),
|
||||||
|
bounds=(0, target_size),
|
||||||
|
margin=5*target_size,
|
||||||
|
sigmoid='long_tail')
|
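The three factory functions above (`swimmer6`, `swimmer15` and the generic `swimmer`) all route through `_make_swimmer`, which builds the MJCF string procedurally and wraps it in a `control.Environment`. A minimal usage sketch, assuming `dm_control` is installed and this file is importable as `local_dm_control_suite.swimmer`:

import numpy as np
from local_dm_control_suite import swimmer

env = swimmer.swimmer6()            # 6-link swimmer with the default 30 s time limit.
action_spec = env.action_spec()
time_step = env.reset()
for _ in range(100):
    # Zero torques just let the swimmer drift; substitute a real policy here.
    time_step = env.step(np.zeros(action_spec.shape))
    print(time_step.reward)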
57
local_dm_control_suite/swimmer.xml
Executable file
@ -0,0 +1,57 @@
|
|||||||
|
<mujoco model="swimmer">
|
||||||
|
<include file="./common/visual.xml"/>
|
||||||
|
<include file="./common/skybox.xml"/>
|
||||||
|
<include file="./common/materials.xml"/>
|
||||||
|
|
||||||
|
<option timestep="0.002" density="3000">
|
||||||
|
<flag contact="disable"/>
|
||||||
|
</option>
|
||||||
|
|
||||||
|
<default>
|
||||||
|
<default class="swimmer">
|
||||||
|
<joint type="hinge" pos="0 -.05 0" axis="0 0 1" limited="true" solreflimit=".05 1" solimplimit="0 .8 .1" armature="1e-6"/>
|
||||||
|
<default class="inertial">
|
||||||
|
<geom type="box" size=".001 .05 .01" rgba="0 0 0 0" mass=".01"/>
|
||||||
|
</default>
|
||||||
|
<default class="visual">
|
||||||
|
<geom type="capsule" size=".01" fromto="0 -.05 0 0 .05 0" material="self" mass="0"/>
|
||||||
|
</default>
|
||||||
|
<site size=".01" rgba="0 0 0 0"/>
|
||||||
|
</default>
|
||||||
|
<default class="free">
|
||||||
|
<joint limited="false" stiffness="0" armature="0"/>
|
||||||
|
</default>
|
||||||
|
<motor gear="5e-4" ctrllimited="true" ctrlrange="-1 1"/>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<geom name="ground" type="plane" size="2 2 0.1" material="grid"/>
|
||||||
|
<body name="head" pos="0 0 .05" childclass="swimmer">
|
||||||
|
<light name="light_1" diffuse=".8 .8 .8" pos="0 0 1.5"/>
|
||||||
|
<geom name="head" type="ellipsoid" size=".02 .04 .017" pos="0 -.022 0" material="self" mass="0"/>
|
||||||
|
<geom name="nose" type="sphere" pos="0 -.06 0" size=".004" material="effector" mass="0"/>
|
||||||
|
<geom name="eyes" type="capsule" fromto="-.006 -.054 .005 .006 -.054 .005" size=".004" material="eye" mass="0"/>
|
||||||
|
<camera name="tracking1" pos="0 -.2 .5" xyaxes="1 0 0 0 1 1" mode="trackcom" fovy="60"/>
|
||||||
|
<camera name="tracking2" pos="-.9 .5 .15" xyaxes="0 -1 0 .3 0 1" mode="trackcom" fovy="60"/>
|
||||||
|
<camera name="eyes" pos="0 -.058 .005" xyaxes="-1 0 0 0 0 1"/>
|
||||||
|
<joint name="rootx" class="free" type="slide" axis="1 0 0" pos="0 -.05 0"/>
|
||||||
|
<joint name="rooty" class="free" type="slide" axis="0 1 0" pos="0 -.05 0"/>
|
||||||
|
<joint name="rootz" class="free" type="hinge" axis="0 0 1" pos="0 -.05 0"/>
|
||||||
|
<geom name="inertial" class="inertial"/>
|
||||||
|
<geom name="visual" class="visual"/>
|
||||||
|
<site name="head"/>
|
||||||
|
</body>
|
||||||
|
<geom name="target" type="sphere" pos="1 1 .05" size=".1" material="target"/>
|
||||||
|
<light name="target_light" diffuse="1 1 1" pos="1 1 1.5"/>
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<sensor>
|
||||||
|
<framepos name="nose_pos" objtype="geom" objname="nose"/>
|
||||||
|
<framepos name="target_pos" objtype="geom" objname="target"/>
|
||||||
|
<framexaxis name="head_xaxis" objtype="xbody" objname="head"/>
|
||||||
|
<frameyaxis name="head_yaxis" objtype="xbody" objname="head"/>
|
||||||
|
<velocimeter name="head_vel" site="head"/>
|
||||||
|
<gyro name="head_gyro" site="head"/>
|
||||||
|
</sensor>
|
||||||
|
|
||||||
|
</mujoco>
|
292
local_dm_control_suite/tests/domains_test.py
Executable file
@ -0,0 +1,292 @@
|
|||||||
|
# Copyright 2017 The dm_control Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Tests for dm_control.suite domains."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# Internal dependencies.
|
||||||
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
|
from dm_control import suite
|
||||||
|
from dm_control.rl import control
|
||||||
|
import mock
|
||||||
|
import numpy as np
|
||||||
|
import six
|
||||||
|
from six.moves import range
|
||||||
|
from six.moves import zip
|
||||||
|
|
||||||
|
|
||||||
|
def uniform_random_policy(action_spec, random=None):
|
||||||
|
lower_bounds = action_spec.minimum
|
||||||
|
upper_bounds = action_spec.maximum
|
||||||
|
# Draw values between -1 and 1 for unbounded actions.
|
||||||
|
lower_bounds = np.where(np.isinf(lower_bounds), -1.0, lower_bounds)
|
||||||
|
upper_bounds = np.where(np.isinf(upper_bounds), 1.0, upper_bounds)
|
||||||
|
random_state = np.random.RandomState(random)
|
||||||
|
def policy(time_step):
|
||||||
|
del time_step # Unused.
|
||||||
|
return random_state.uniform(lower_bounds, upper_bounds)
|
||||||
|
return policy
|
||||||
|
|
||||||
|
|
||||||
|
def step_environment(env, policy, num_episodes=5, max_steps_per_episode=10):
|
||||||
|
for _ in range(num_episodes):
|
||||||
|
step_count = 0
|
||||||
|
time_step = env.reset()
|
||||||
|
yield time_step
|
||||||
|
while not time_step.last():
|
||||||
|
action = policy(time_step)
|
||||||
|
time_step = env.step(action)
|
||||||
|
step_count += 1
|
||||||
|
yield time_step
|
||||||
|
if step_count >= max_steps_per_episode:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def make_trajectory(domain, task, seed, **trajectory_kwargs):
|
||||||
|
env = suite.load(domain, task, task_kwargs={'random': seed})
|
||||||
|
policy = uniform_random_policy(env.action_spec(), random=seed)
|
||||||
|
return step_environment(env, policy, **trajectory_kwargs)
|
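# Illustrative example (not part of the test suite): the helpers above can be used
# directly to spot-check a single task, e.g.
#
#   for time_step in make_trajectory('cartpole', 'swingup', seed=0,
#                                    num_episodes=1, max_steps_per_episode=5):
#     print(time_step.step_type, time_step.reward)
#
# which yields the reset step followed by up to five stepped transitions.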
||||||
|
|
||||||
|
|
||||||
|
class DomainTest(parameterized.TestCase):
|
||||||
|
"""Tests run on all the tasks registered."""
|
||||||
|
|
||||||
|
def test_constants(self):
|
||||||
|
num_tasks = sum(len(tasks) for tasks in
|
||||||
|
six.itervalues(suite.TASKS_BY_DOMAIN))
|
||||||
|
|
||||||
|
self.assertLen(suite.ALL_TASKS, num_tasks)
|
||||||
|
|
||||||
|
def _validate_observation(self, observation_dict, observation_spec):
|
||||||
|
obs = observation_dict.copy()
|
||||||
|
for name, spec in six.iteritems(observation_spec):
|
||||||
|
arr = obs.pop(name)
|
||||||
|
self.assertEqual(arr.shape, spec.shape)
|
||||||
|
self.assertEqual(arr.dtype, spec.dtype)
|
||||||
|
self.assertTrue(
|
||||||
|
np.all(np.isfinite(arr)),
|
||||||
|
msg='{!r} has non-finite value(s): {!r}'.format(name, arr))
|
||||||
|
self.assertEmpty(
|
||||||
|
obs,
|
||||||
|
msg='Observation contains arrays(s) that are not in the spec: {!r}'
|
||||||
|
.format(obs))
|
||||||
|
|
||||||
|
def _validate_reward_range(self, time_step):
|
||||||
|
if time_step.first():
|
||||||
|
self.assertIsNone(time_step.reward)
|
||||||
|
else:
|
||||||
|
self.assertIsInstance(time_step.reward, float)
|
||||||
|
self.assertBetween(time_step.reward, 0, 1)
|
||||||
|
|
||||||
|
def _validate_discount(self, time_step):
|
||||||
|
if time_step.first():
|
||||||
|
self.assertIsNone(time_step.discount)
|
||||||
|
else:
|
||||||
|
self.assertIsInstance(time_step.discount, float)
|
||||||
|
self.assertBetween(time_step.discount, 0, 1)
|
||||||
|
|
||||||
|
def _validate_control_range(self, lower_bounds, upper_bounds):
|
||||||
|
for b in lower_bounds:
|
||||||
|
self.assertEqual(b, -1.0)
|
||||||
|
for b in upper_bounds:
|
||||||
|
self.assertEqual(b, 1.0)
|
||||||
|
|
||||||
|
@parameterized.parameters(*suite.ALL_TASKS)
|
||||||
|
def test_components_have_names(self, domain, task):
|
||||||
|
env = suite.load(domain, task)
|
||||||
|
model = env.physics.model
|
||||||
|
|
||||||
|
object_types_and_size_fields = [
|
||||||
|
('body', 'nbody'),
|
||||||
|
('joint', 'njnt'),
|
||||||
|
('geom', 'ngeom'),
|
||||||
|
('site', 'nsite'),
|
||||||
|
('camera', 'ncam'),
|
||||||
|
('light', 'nlight'),
|
||||||
|
('mesh', 'nmesh'),
|
||||||
|
('hfield', 'nhfield'),
|
||||||
|
('texture', 'ntex'),
|
||||||
|
('material', 'nmat'),
|
||||||
|
('equality', 'neq'),
|
||||||
|
('tendon', 'ntendon'),
|
||||||
|
('actuator', 'nu'),
|
||||||
|
('sensor', 'nsensor'),
|
||||||
|
('numeric', 'nnumeric'),
|
||||||
|
('text', 'ntext'),
|
||||||
|
('tuple', 'ntuple'),
|
||||||
|
]
|
||||||
|
for object_type, size_field in object_types_and_size_fields:
|
||||||
|
for idx in range(getattr(model, size_field)):
|
||||||
|
object_name = model.id2name(idx, object_type)
|
||||||
|
self.assertNotEqual(object_name, '',
|
||||||
|
msg='Model {!r} contains unnamed {!r} with ID {}.'
|
||||||
|
.format(model.name, object_type, idx))
|
||||||
|
|
||||||
|
@parameterized.parameters(*suite.ALL_TASKS)
|
||||||
|
def test_model_has_at_least_2_cameras(self, domain, task):
|
||||||
|
env = suite.load(domain, task)
|
||||||
|
model = env.physics.model
|
||||||
|
self.assertGreaterEqual(model.ncam, 2,
|
||||||
|
'Model {!r} should have at least 2 cameras, has {}.'
|
||||||
|
.format(model.name, model.ncam))
|
||||||
|
|
||||||
|
@parameterized.parameters(*suite.ALL_TASKS)
|
||||||
|
def test_task_conforms_to_spec(self, domain, task):
|
||||||
|
"""Tests that the environment timesteps conform to specifications."""
|
||||||
|
is_benchmark = (domain, task) in suite.BENCHMARKING
|
||||||
|
env = suite.load(domain, task)
|
||||||
|
observation_spec = env.observation_spec()
|
||||||
|
action_spec = env.action_spec()
|
||||||
|
|
||||||
|
# Check action bounds.
|
||||||
|
if is_benchmark:
|
||||||
|
self._validate_control_range(action_spec.minimum, action_spec.maximum)
|
||||||
|
|
||||||
|
# Step through the environment, applying random actions sampled within the
|
||||||
|
# valid range and check the observations, rewards, and discounts.
|
||||||
|
policy = uniform_random_policy(action_spec)
|
||||||
|
for time_step in step_environment(env, policy):
|
||||||
|
self._validate_observation(time_step.observation, observation_spec)
|
||||||
|
self._validate_discount(time_step)
|
||||||
|
if is_benchmark:
|
||||||
|
self._validate_reward_range(time_step)
|
||||||
|
|
||||||
|
@parameterized.parameters(*suite.ALL_TASKS)
|
||||||
|
def test_environment_is_deterministic(self, domain, task):
|
||||||
|
"""Tests that identical seeds and actions produce identical trajectories."""
|
||||||
|
seed = 0
|
||||||
|
# Iterate over two trajectories generated using identical sequences of
|
||||||
|
# random actions, and with identical task random states. Check that the
|
||||||
|
# observations, rewards, discounts and step types are identical.
|
||||||
|
trajectory1 = make_trajectory(domain=domain, task=task, seed=seed)
|
||||||
|
trajectory2 = make_trajectory(domain=domain, task=task, seed=seed)
|
||||||
|
for time_step1, time_step2 in zip(trajectory1, trajectory2):
|
||||||
|
self.assertEqual(time_step1.step_type, time_step2.step_type)
|
||||||
|
self.assertEqual(time_step1.reward, time_step2.reward)
|
||||||
|
self.assertEqual(time_step1.discount, time_step2.discount)
|
||||||
|
for key in six.iterkeys(time_step1.observation):
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
time_step1.observation[key], time_step2.observation[key],
|
||||||
|
err_msg='Observation {!r} is not equal.'.format(key))
|
||||||
|
|
||||||
|
def assertCorrectColors(self, physics, reward):
|
||||||
|
colors = physics.named.model.mat_rgba
|
||||||
|
for material_name in ('self', 'effector', 'target'):
|
||||||
|
highlight = colors[material_name + '_highlight']
|
||||||
|
default = colors[material_name + '_default']
|
||||||
|
blend_coef = reward ** 4
|
||||||
|
expected = blend_coef * highlight + (1.0 - blend_coef) * default
|
||||||
|
actual = colors[material_name]
|
||||||
|
err_msg = ('Material {!r} has unexpected color.\nExpected: {!r}\n'
|
||||||
|
'Actual: {!r}'.format(material_name, expected, actual))
|
||||||
|
np.testing.assert_array_almost_equal(expected, actual, err_msg=err_msg)
|
||||||
|
|
||||||
|
@parameterized.parameters(*suite.ALL_TASKS)
|
||||||
|
def test_visualize_reward(self, domain, task):
|
||||||
|
env = suite.load(domain, task)
|
||||||
|
env.task.visualize_reward = True
|
||||||
|
action = np.zeros(env.action_spec().shape)
|
||||||
|
|
||||||
|
with mock.patch.object(env.task, 'get_reward') as mock_get_reward:
|
||||||
|
mock_get_reward.return_value = -3.0 # Rewards < 0 should be clipped.
|
||||||
|
env.reset()
|
||||||
|
mock_get_reward.assert_called_with(env.physics)
|
||||||
|
self.assertCorrectColors(env.physics, reward=0.0)
|
||||||
|
|
||||||
|
mock_get_reward.reset_mock()
|
||||||
|
mock_get_reward.return_value = 0.5
|
||||||
|
env.step(action)
|
||||||
|
mock_get_reward.assert_called_with(env.physics)
|
||||||
|
self.assertCorrectColors(env.physics, reward=mock_get_reward.return_value)
|
||||||
|
|
||||||
|
mock_get_reward.reset_mock()
|
||||||
|
mock_get_reward.return_value = 2.0 # Rewards > 1 should be clipped.
|
||||||
|
env.step(action)
|
||||||
|
mock_get_reward.assert_called_with(env.physics)
|
||||||
|
self.assertCorrectColors(env.physics, reward=1.0)
|
||||||
|
|
||||||
|
mock_get_reward.reset_mock()
|
||||||
|
mock_get_reward.return_value = 0.25
|
||||||
|
env.reset()
|
||||||
|
mock_get_reward.assert_called_with(env.physics)
|
||||||
|
self.assertCorrectColors(env.physics, reward=mock_get_reward.return_value)
|
||||||
|
|
||||||
|
@parameterized.parameters(*suite.ALL_TASKS)
|
||||||
|
def test_task_supports_environment_kwargs(self, domain, task):
|
||||||
|
env = suite.load(domain, task,
|
||||||
|
environment_kwargs=dict(flat_observation=True))
|
||||||
|
# Check that the kwargs are actually passed through to the environment.
|
||||||
|
self.assertSetEqual(set(env.observation_spec()),
|
||||||
|
{control.FLAT_OBSERVATION_KEY})
|
||||||
|
|
||||||
|
@parameterized.parameters(*suite.ALL_TASKS)
|
||||||
|
def test_observation_arrays_dont_share_memory(self, domain, task):
|
||||||
|
env = suite.load(domain, task)
|
||||||
|
first_timestep = env.reset()
|
||||||
|
action = np.zeros(env.action_spec().shape)
|
||||||
|
second_timestep = env.step(action)
|
||||||
|
for name, first_array in six.iteritems(first_timestep.observation):
|
||||||
|
second_array = second_timestep.observation[name]
|
||||||
|
self.assertFalse(
|
||||||
|
np.may_share_memory(first_array, second_array),
|
||||||
|
msg='Consecutive observations of {!r} may share memory.'.format(name))
|
||||||
|
|
||||||
|
@parameterized.parameters(*suite.ALL_TASKS)
|
||||||
|
def test_observations_dont_contain_constant_elements(self, domain, task):
|
||||||
|
env = suite.load(domain, task)
|
||||||
|
trajectory = make_trajectory(domain=domain, task=task, seed=0,
|
||||||
|
num_episodes=2, max_steps_per_episode=1000)
|
||||||
|
observations = {name: [] for name in env.observation_spec()}
|
||||||
|
for time_step in trajectory:
|
||||||
|
for name, array in six.iteritems(time_step.observation):
|
||||||
|
observations[name].append(array)
|
||||||
|
|
||||||
|
failures = []
|
||||||
|
|
||||||
|
for name, array_list in six.iteritems(observations):
|
||||||
|
# Sampling random uniform actions generally isn't sufficient to trigger
|
||||||
|
# these touch sensors.
|
||||||
|
if (domain in ('manipulator', 'stacker') and name == 'touch' or
|
||||||
|
domain == 'quadruped' and name == 'force_torque'):
|
||||||
|
continue
|
||||||
|
stacked_arrays = np.array(array_list)
|
||||||
|
is_constant = np.all(stacked_arrays == stacked_arrays[0], axis=0)
|
||||||
|
has_constant_elements = (
|
||||||
|
is_constant if np.isscalar(is_constant) else np.any(is_constant))
|
||||||
|
if has_constant_elements:
|
||||||
|
failures.append((name, is_constant))
|
||||||
|
|
||||||
|
self.assertEmpty(
|
||||||
|
failures,
|
||||||
|
msg='The following observation(s) contain constant elements:\n{}'
|
||||||
|
.format('\n'.join(':\t'.join([name, str(is_constant)])
|
||||||
|
for (name, is_constant) in failures)))
|
||||||
|
|
||||||
|
@parameterized.parameters(*suite.ALL_TASKS)
|
||||||
|
def test_initial_state_is_randomized(self, domain, task):
|
||||||
|
env = suite.load(domain, task, task_kwargs={'random': 42})
|
||||||
|
obs1 = env.reset().observation
|
||||||
|
obs2 = env.reset().observation
|
||||||
|
self.assertFalse(
|
||||||
|
all(np.all(obs1[k] == obs2[k]) for k in obs1),
|
||||||
|
'Two consecutive initial states have identical observations.\n'
|
||||||
|
'First: {}\nSecond: {}'.format(obs1, obs2))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
52
local_dm_control_suite/tests/loader_test.py
Executable file
@ -0,0 +1,52 @@
|
|||||||
|
# Copyright 2017 The dm_control Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Tests for the dm_control.suite loader."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# Internal dependencies.
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
|
||||||
|
from dm_control import suite
|
||||||
|
from dm_control.rl import control
|
||||||
|
|
||||||
|
|
||||||
|
class LoaderTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_load_without_kwargs(self):
|
||||||
|
env = suite.load('cartpole', 'swingup')
|
||||||
|
self.assertIsInstance(env, control.Environment)
|
||||||
|
|
||||||
|
def test_load_with_kwargs(self):
|
||||||
|
env = suite.load('cartpole', 'swingup',
|
||||||
|
task_kwargs={'time_limit': 40, 'random': 99})
|
||||||
|
self.assertIsInstance(env, control.Environment)
|
||||||
|
|
||||||
|
|
||||||
|
class LoaderConstantsTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def testSuiteConstants(self):
|
||||||
|
self.assertNotEmpty(suite.BENCHMARKING)
|
||||||
|
self.assertNotEmpty(suite.EASY)
|
||||||
|
self.assertNotEmpty(suite.HARD)
|
||||||
|
self.assertNotEmpty(suite.EXTRA)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
88
local_dm_control_suite/tests/lqr_test.py
Executable file
@ -0,0 +1,88 @@
|
|||||||
|
# Copyright 2017 The dm_control Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Tests specific to the LQR domain."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
# Internal dependencies.
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from local_dm_control_suite import lqr
|
||||||
|
from local_dm_control_suite import lqr_solver
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
|
|
||||||
|
class LqrTest(parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('lqr_2_1', lqr.lqr_2_1),
|
||||||
|
('lqr_6_2', lqr.lqr_6_2))
|
||||||
|
def test_lqr_optimal_policy(self, make_env):
|
||||||
|
env = make_env()
|
||||||
|
p, k, beta = lqr_solver.solve(env)
|
||||||
|
self.assertPolicyisOptimal(env, p, k, beta)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('lqr_2_1', lqr.lqr_2_1),
|
||||||
|
('lqr_6_2', lqr.lqr_6_2))
|
||||||
|
@unittest.skipUnless(
|
||||||
|
condition=lqr_solver.sp,
|
||||||
|
reason='scipy is not available, so non-scipy DARE solver is the default.')
|
||||||
|
def test_lqr_optimal_policy_no_scipy(self, make_env):
|
||||||
|
env = make_env()
|
||||||
|
old_sp = lqr_solver.sp
|
||||||
|
try:
|
||||||
|
lqr_solver.sp = None # Force the solver to use the non-scipy code path.
|
||||||
|
p, k, beta = lqr_solver.solve(env)
|
||||||
|
finally:
|
||||||
|
lqr_solver.sp = old_sp
|
||||||
|
self.assertPolicyisOptimal(env, p, k, beta)
|
||||||
|
|
||||||
|
def assertPolicyisOptimal(self, env, p, k, beta):
|
||||||
|
tolerance = 1e-3
|
||||||
|
n_steps = int(math.ceil(math.log10(tolerance) / math.log10(beta)))
|
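# The step count follows from requiring the geometric decay factor to reach the
# tolerance: beta**n <= tolerance implies n >= log(tolerance) / log(beta), which
# is exactly the ceiling computed above.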
||||||
|
logging.info('%d timesteps for %g convergence.', n_steps, tolerance)
|
||||||
|
total_loss = 0.0
|
||||||
|
|
||||||
|
timestep = env.reset()
|
||||||
|
initial_state = np.hstack((timestep.observation['position'],
|
||||||
|
timestep.observation['velocity']))
|
||||||
|
logging.info('Measuring total cost over %d steps.', n_steps)
|
||||||
|
for _ in range(n_steps):
|
||||||
|
x = np.hstack((timestep.observation['position'],
|
||||||
|
timestep.observation['velocity']))
|
||||||
|
# u = k*x is the optimal policy
|
||||||
|
u = k.dot(x)
|
||||||
|
total_loss += 1 - (timestep.reward or 0.0)
|
||||||
|
timestep = env.step(u)
|
||||||
|
|
||||||
|
logging.info('Analytical expected total cost is .5*x^T*p*x.')
|
||||||
|
expected_loss = .5 * initial_state.T.dot(p).dot(initial_state)
|
||||||
|
logging.info('Comparing measured and predicted costs.')
|
||||||
|
np.testing.assert_allclose(expected_loss, total_loss, rtol=tolerance)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
16
local_dm_control_suite/utils/__init__.py
Executable file
@ -0,0 +1,16 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Utility functions used in the control suite."""
251
local_dm_control_suite/utils/parse_amc.py
Executable file
@ -0,0 +1,251 @@
|
|||||||
|
# Copyright 2017 The dm_control Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Parse and convert amc motion capture data."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
from dm_control.mujoco.wrapper import mjbindings
|
||||||
|
import numpy as np
|
||||||
|
from scipy import interpolate
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
|
mjlib = mjbindings.mjlib
|
||||||
|
|
||||||
|
MOCAP_DT = 1.0/120.0
|
||||||
|
CONVERSION_LENGTH = 0.056444
|
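# MOCAP_DT assumes the CMU motion-capture clips are sampled at 120 Hz;
# CONVERSION_LENGTH appears to rescale the CMU skeleton's length units to metres.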
||||||
|
|
||||||
|
_CMU_MOCAP_JOINT_ORDER = (
|
||||||
|
'root0', 'root1', 'root2', 'root3', 'root4', 'root5', 'lowerbackrx',
|
||||||
|
'lowerbackry', 'lowerbackrz', 'upperbackrx', 'upperbackry', 'upperbackrz',
|
||||||
|
'thoraxrx', 'thoraxry', 'thoraxrz', 'lowerneckrx', 'lowerneckry',
|
||||||
|
'lowerneckrz', 'upperneckrx', 'upperneckry', 'upperneckrz', 'headrx',
|
||||||
|
'headry', 'headrz', 'rclaviclery', 'rclaviclerz', 'rhumerusrx',
|
||||||
|
'rhumerusry', 'rhumerusrz', 'rradiusrx', 'rwristry', 'rhandrx', 'rhandrz',
|
||||||
|
'rfingersrx', 'rthumbrx', 'rthumbrz', 'lclaviclery', 'lclaviclerz',
|
||||||
|
'lhumerusrx', 'lhumerusry', 'lhumerusrz', 'lradiusrx', 'lwristry',
|
||||||
|
'lhandrx', 'lhandrz', 'lfingersrx', 'lthumbrx', 'lthumbrz', 'rfemurrx',
|
||||||
|
'rfemurry', 'rfemurrz', 'rtibiarx', 'rfootrx', 'rfootrz', 'rtoesrx',
|
||||||
|
'lfemurrx', 'lfemurry', 'lfemurrz', 'ltibiarx', 'lfootrx', 'lfootrz',
|
||||||
|
'ltoesrx'
|
||||||
|
)
|
||||||
|
|
||||||
|
Converted = collections.namedtuple('Converted',
|
||||||
|
['qpos', 'qvel', 'time'])
|
||||||
|
|
||||||
|
|
||||||
|
def convert(file_name, physics, timestep):
|
||||||
|
"""Converts the parsed .amc values into qpos and qvel values and resamples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_name: The .amc file to be parsed and converted.
|
||||||
|
physics: The corresponding physics instance.
|
||||||
|
timestep: Desired output interval between resampled frames.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A namedtuple with fields:
|
||||||
|
`qpos`, a numpy array containing converted positional variables.
|
||||||
|
`qvel`, a numpy array containing converted velocity variables.
|
||||||
|
`time`, a numpy array containing the corresponding times.
|
||||||
|
"""
|
||||||
|
frame_values = parse(file_name)
|
||||||
|
joint2index = {}
|
||||||
|
for name in physics.named.data.qpos.axes.row.names:
|
||||||
|
joint2index[name] = physics.named.data.qpos.axes.row.convert_key_item(name)
|
||||||
|
index2joint = {}
|
||||||
|
for joint, index in joint2index.items():
|
||||||
|
if isinstance(index, slice):
|
||||||
|
indices = range(index.start, index.stop)
|
||||||
|
else:
|
||||||
|
indices = [index]
|
||||||
|
for ii in indices:
|
||||||
|
index2joint[ii] = joint
|
||||||
|
|
||||||
|
# Convert frame_values to qpos
|
||||||
|
amcvals2qpos_transformer = Amcvals2qpos(index2joint, _CMU_MOCAP_JOINT_ORDER)
|
||||||
|
qpos_values = []
|
||||||
|
for frame_value in frame_values:
|
||||||
|
qpos_values.append(amcvals2qpos_transformer(frame_value))
|
||||||
|
qpos_values = np.stack(qpos_values) # Time by nq
|
||||||
|
|
||||||
|
# Interpolate/resample.
|
||||||
|
# Note: interpolate quaternions rather than euler angles (slerp).
|
||||||
|
# see https://en.wikipedia.org/wiki/Slerp
|
||||||
|
qpos_values_resampled = []
|
||||||
|
time_vals = np.arange(0, len(frame_values)*MOCAP_DT - 1e-8, MOCAP_DT)
|
||||||
|
time_vals_new = np.arange(0, len(frame_values)*MOCAP_DT, timestep)
|
||||||
|
while time_vals_new[-1] > time_vals[-1]:
|
||||||
|
time_vals_new = time_vals_new[:-1]
|
||||||
|
|
||||||
|
for i in range(qpos_values.shape[1]):
|
||||||
|
f = interpolate.splrep(time_vals, qpos_values[:, i])
|
||||||
|
qpos_values_resampled.append(interpolate.splev(time_vals_new, f))
|
||||||
|
|
||||||
|
qpos_values_resampled = np.stack(qpos_values_resampled) # nq by ntime
|
||||||
|
|
||||||
|
qvel_list = []
|
||||||
|
for t in range(qpos_values_resampled.shape[1]-1):
|
||||||
|
p_tp1 = qpos_values_resampled[:, t + 1]
|
||||||
|
p_t = qpos_values_resampled[:, t]
|
||||||
|
qvel = [(p_tp1[:3]-p_t[:3])/ timestep,
|
||||||
|
mj_quat2vel(mj_quatdiff(p_t[3:7], p_tp1[3:7]), timestep),
|
||||||
|
(p_tp1[7:]-p_t[7:])/ timestep]
|
||||||
|
qvel_list.append(np.concatenate(qvel))
|
||||||
|
|
||||||
|
qvel_values_resampled = np.vstack(qvel_list).T
|
||||||
|
|
||||||
|
return Converted(qpos_values_resampled, qvel_values_resampled, time_vals_new)
|
||||||
|
|
||||||
|
|
||||||
|
def parse(file_name):
|
||||||
|
"""Parses the amc file format."""
|
||||||
|
values = []
|
||||||
|
fid = open(file_name, 'r')
|
||||||
|
line = fid.readline().strip()
|
||||||
|
frame_ind = 1
|
||||||
|
first_frame = True
|
||||||
|
while True:
|
||||||
|
# Parse first frame.
|
||||||
|
if first_frame and line[0] == str(frame_ind):
|
||||||
|
first_frame = False
|
||||||
|
frame_ind += 1
|
||||||
|
frame_vals = []
|
||||||
|
while True:
|
||||||
|
line = fid.readline().strip()
|
||||||
|
if not line or line == str(frame_ind):
|
||||||
|
values.append(np.array(frame_vals, dtype=float))
|
||||||
|
break
|
||||||
|
tokens = line.split()
|
||||||
|
frame_vals.extend(tokens[1:])
|
||||||
|
# Parse other frames.
|
||||||
|
elif line == str(frame_ind):
|
||||||
|
frame_ind += 1
|
||||||
|
frame_vals = []
|
||||||
|
while True:
|
||||||
|
line = fid.readline().strip()
|
||||||
|
if not line or line == str(frame_ind):
|
||||||
|
values.append(np.array(frame_vals, dtype=float))
|
||||||
|
break
|
||||||
|
tokens = line.split()
|
||||||
|
frame_vals.extend(tokens[1:])
|
||||||
|
else:
|
||||||
|
line = fid.readline().strip()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class Amcvals2qpos(object):
|
||||||
|
"""Callable that converts .amc values for a frame and to MuJoCo qpos format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, index2joint, joint_order):
|
||||||
|
"""Initializes a new Amcvals2qpos instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index2joint: Dict mapping qpos indices to joint names.
|
||||||
|
joint_order: List of joint names in the order in which they appear in each .amc frame.
|
||||||
|
"""
|
||||||
|
# Root is x,y,z, then quat.
|
||||||
|
# We need the qpos indices that correspond to the default .amc joint order.
|
||||||
|
self.qpos_root_xyz_ind = [0, 1, 2]
|
||||||
|
self.root_xyz_transform = np.array(
|
||||||
|
[[1, 0, 0], [0, 0, -1], [0, 1, 0]]) * CONVERSION_LENGTH
|
||||||
|
self.qpos_root_quat_ind = [3, 4, 5, 6]
|
||||||
|
amc2qpos_transform = np.zeros((len(index2joint), len(joint_order)))
|
||||||
|
for i in range(len(index2joint)):
|
||||||
|
for j in range(len(joint_order)):
|
||||||
|
if index2joint[i] == joint_order[j]:
|
||||||
|
# Any rotational channel (rx, ry or rz) maps one-to-one onto the matching qpos entry.
if any(axis in index2joint[i] for axis in ('rx', 'ry', 'rz')):
amc2qpos_transform[i][j] = 1
|
||||||
|
self.amc2qpos_transform = amc2qpos_transform
|
||||||
|
|
||||||
|
def __call__(self, amc_val):
|
||||||
|
"""Converts a `.amc` frame to MuJoCo qpos format."""
|
||||||
|
amc_val_rad = np.deg2rad(amc_val)
|
||||||
|
qpos = np.dot(self.amc2qpos_transform, amc_val_rad)
|
||||||
|
|
||||||
|
# Root.
|
||||||
|
qpos[:3] = np.dot(self.root_xyz_transform, amc_val[:3])
|
||||||
|
qpos_quat = euler2quat(amc_val[3], amc_val[4], amc_val[5])
|
||||||
|
qpos_quat = mj_quatprod(euler2quat(90, 0, 0), qpos_quat)
|
||||||
|
|
||||||
|
for i, ind in enumerate(self.qpos_root_quat_ind):
|
||||||
|
qpos[ind] = qpos_quat[i]
|
||||||
|
|
||||||
|
return qpos
|
||||||
|
|
||||||
|
|
||||||
|
def euler2quat(ax, ay, az):
|
||||||
|
"""Converts euler angles to a quaternion.
|
||||||
|
|
||||||
|
Note: rotation order is zyx
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ax: Roll angle (deg).
|
||||||
|
ay: Pitch angle (deg).
|
||||||
|
az: Yaw angle (deg).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy array representing the rotation as a quaternion.
|
||||||
|
"""
|
||||||
|
r1 = az
|
||||||
|
r2 = ay
|
||||||
|
r3 = ax
|
||||||
|
|
||||||
|
c1 = np.cos(np.deg2rad(r1 / 2))
|
||||||
|
s1 = np.sin(np.deg2rad(r1 / 2))
|
||||||
|
c2 = np.cos(np.deg2rad(r2 / 2))
|
||||||
|
s2 = np.sin(np.deg2rad(r2 / 2))
|
||||||
|
c3 = np.cos(np.deg2rad(r3 / 2))
|
||||||
|
s3 = np.sin(np.deg2rad(r3 / 2))
|
||||||
|
|
||||||
|
q0 = c1 * c2 * c3 + s1 * s2 * s3
|
||||||
|
q1 = c1 * c2 * s3 - s1 * s2 * c3
|
||||||
|
q2 = c1 * s2 * c3 + s1 * c2 * s3
|
||||||
|
q3 = s1 * c2 * c3 - c1 * s2 * s3
|
||||||
|
|
||||||
|
return np.array([q0, q1, q2, q3])
|
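# Quick sanity check (illustrative): with zero input angles every sine term above
# vanishes and every cosine is 1, so euler2quat(0, 0, 0) returns the identity
# quaternion [1, 0, 0, 0] in MuJoCo's scalar-first (w, x, y, z) convention.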
||||||
|
|
||||||
|
|
||||||
|
def mj_quatprod(q, r):
|
||||||
|
quaternion = np.zeros(4)
|
||||||
|
mjlib.mju_mulQuat(quaternion, np.ascontiguousarray(q),
|
||||||
|
np.ascontiguousarray(r))
|
||||||
|
return quaternion
|
||||||
|
|
||||||
|
|
||||||
|
def mj_quat2vel(q, dt):
|
||||||
|
vel = np.zeros(3)
|
||||||
|
mjlib.mju_quat2Vel(vel, np.ascontiguousarray(q), dt)
|
||||||
|
return vel
|
||||||
|
|
||||||
|
|
||||||
|
def mj_quatneg(q):
|
||||||
|
quaternion = np.zeros(4)
|
||||||
|
mjlib.mju_negQuat(quaternion, np.ascontiguousarray(q))
|
||||||
|
return quaternion
|
||||||
|
|
||||||
|
|
||||||
|
def mj_quatdiff(source, target):
|
||||||
|
return mj_quatprod(mj_quatneg(source), np.ascontiguousarray(target))
|
68
local_dm_control_suite/utils/parse_amc_test.py
Executable file
@ -0,0 +1,68 @@
|
|||||||
|
# Copyright 2017 The dm_control Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Tests for parse_amc utility."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Internal dependencies.
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from local_dm_control_suite import humanoid_CMU
|
||||||
|
from dm_control.suite.utils import parse_amc
|
||||||
|
|
||||||
|
from dm_control.utils import io as resources
|
||||||
|
|
||||||
|
_TEST_AMC_PATH = resources.GetResourceFilename(
|
||||||
|
os.path.join(os.path.dirname(__file__), '../demos/zeros.amc'))
|
||||||
|
|
||||||
|
|
||||||
|
class ParseAMCTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_sizes_of_parsed_data(self):
|
||||||
|
|
||||||
|
# Instantiate the humanoid environment.
|
||||||
|
env = humanoid_CMU.stand()
|
||||||
|
|
||||||
|
# Parse and convert specified clip.
|
||||||
|
converted = parse_amc.convert(
|
||||||
|
_TEST_AMC_PATH, env.physics, env.control_timestep())
|
||||||
|
|
||||||
|
self.assertEqual(converted.qpos.shape[0], 63)
|
||||||
|
self.assertEqual(converted.qvel.shape[0], 62)
|
||||||
|
self.assertEqual(converted.time.shape[0], converted.qpos.shape[1])
|
||||||
|
self.assertEqual(converted.qpos.shape[1],
|
||||||
|
converted.qvel.shape[1] + 1)
|
||||||
|
|
||||||
|
# Parse and convert the same clip, this time with a smaller timestep.
|
||||||
|
converted2 = parse_amc.convert(
|
||||||
|
_TEST_AMC_PATH, env.physics, 0.5 * env.control_timestep())
|
||||||
|
|
||||||
|
self.assertEqual(converted2.qpos.shape[0], 63)
|
||||||
|
self.assertEqual(converted2.qvel.shape[0], 62)
|
||||||
|
self.assertEqual(converted2.time.shape[0], converted2.qpos.shape[1])
|
||||||
|
self.assertEqual(converted2.qpos.shape[1],
|
||||||
|
converted2.qvel.shape[1] + 1)
|
||||||
|
|
||||||
|
# Compare sizes of parsed objects for different timesteps
|
||||||
|
self.assertEqual(converted.qpos.shape[1] * 2, converted2.qpos.shape[1])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
91
local_dm_control_suite/utils/randomizers.py
Executable file
@ -0,0 +1,91 @@
|
|||||||
|
# Copyright 2017 The dm_control Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Randomization functions."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from dm_control.mujoco.wrapper import mjbindings
|
||||||
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
|
|
||||||
|
def random_limited_quaternion(random, limit):
|
||||||
|
"""Generates a random quaternion limited to the specified rotations."""
|
||||||
|
axis = random.randn(3)
|
||||||
|
axis /= np.linalg.norm(axis)
|
||||||
|
angle = random.rand() * limit
|
||||||
|
|
||||||
|
quaternion = np.zeros(4)
|
||||||
|
mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle)
|
||||||
|
|
||||||
|
return quaternion
|
||||||
|
|
||||||
|
|
||||||
|
def randomize_limited_and_rotational_joints(physics, random=None):
|
||||||
|
"""Randomizes the positions of joints defined in the physics body.
|
||||||
|
|
||||||
|
The following randomization rules apply:
|
||||||
|
- Bounded joints (hinges or sliders) are sampled uniformly in the bounds.
|
||||||
|
- Unbounded hinges are sampled uniformly in [-pi, pi].
|
||||||
|
- Quaternions for unlimited free joints and ball joints are sampled
|
||||||
|
uniformly on the unit 3-sphere.
|
||||||
|
- Quaternions for limited ball joints are sampled uniformly on a sector
|
||||||
|
of the unit 3-sphere.
|
||||||
|
- The linear degrees of freedom of free joints are not randomized.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
physics: Instance of 'Physics' class that holds a loaded model.
|
||||||
|
random: Optional instance of 'np.random.RandomState'. Defaults to the global
|
||||||
|
NumPy random state.
|
||||||
|
"""
|
||||||
|
random = random or np.random
|
||||||
|
|
||||||
|
hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
|
||||||
|
slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
|
||||||
|
ball = mjbindings.enums.mjtJoint.mjJNT_BALL
|
||||||
|
free = mjbindings.enums.mjtJoint.mjJNT_FREE
|
||||||
|
|
||||||
|
qpos = physics.named.data.qpos
|
||||||
|
|
||||||
|
for joint_id in range(physics.model.njnt):
|
||||||
|
joint_name = physics.model.id2name(joint_id, 'joint')
|
||||||
|
joint_type = physics.model.jnt_type[joint_id]
|
||||||
|
is_limited = physics.model.jnt_limited[joint_id]
|
||||||
|
range_min, range_max = physics.model.jnt_range[joint_id]
|
||||||
|
|
||||||
|
if is_limited:
|
||||||
|
if joint_type == hinge or joint_type == slide:
|
||||||
|
qpos[joint_name] = random.uniform(range_min, range_max)
|
||||||
|
|
||||||
|
elif joint_type == ball:
|
||||||
|
qpos[joint_name] = random_limited_quaternion(random, range_max)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if joint_type == hinge:
|
||||||
|
qpos[joint_name] = random.uniform(-np.pi, np.pi)
|
||||||
|
|
||||||
|
elif joint_type == ball:
|
||||||
|
quat = random.randn(4)
|
||||||
|
quat /= np.linalg.norm(quat)
|
||||||
|
qpos[joint_name] = quat
|
||||||
|
|
||||||
|
elif joint_type == free:
|
||||||
|
quat = random.rand(4)
|
||||||
|
quat /= np.linalg.norm(quat)
|
||||||
|
qpos[joint_name][3:] = quat
|
||||||
|
|
164
local_dm_control_suite/utils/randomizers_test.py
Executable file
@ -0,0 +1,164 @@
|
|||||||
|
# Copyright 2017 The dm_control Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Tests for randomizers.py."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# Internal dependencies.
|
||||||
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
|
from dm_control import mujoco
|
||||||
|
from dm_control.mujoco.wrapper import mjbindings
|
||||||
|
from dm_control.suite.utils import randomizers
|
||||||
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
|
mjlib = mjbindings.mjlib
|
||||||
|
|
||||||
|
|
||||||
|
class RandomizeUnlimitedJointsTest(parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.rand = np.random.RandomState(100)
|
||||||
|
|
||||||
|
def test_single_joint_of_each_type(self):
|
||||||
|
physics = mujoco.Physics.from_xml_string("""<mujoco>
|
||||||
|
<default>
|
||||||
|
<joint range="0 90" />
|
||||||
|
</default>
|
||||||
|
<worldbody>
|
||||||
|
<body>
|
||||||
|
<geom type="box" size="1 1 1"/>
|
||||||
|
<joint name="free" type="free"/>
|
||||||
|
</body>
|
||||||
|
<body>
|
||||||
|
<geom type="box" size="1 1 1"/>
|
||||||
|
<joint name="limited_hinge" type="hinge" limited="true"/>
|
||||||
|
<joint name="slide" type="slide"/>
|
||||||
|
<joint name="limited_slide" type="slide" limited="true"/>
|
||||||
|
<joint name="hinge" type="hinge"/>
|
||||||
|
</body>
|
||||||
|
<body>
|
||||||
|
<geom type="box" size="1 1 1"/>
|
||||||
|
<joint name="ball" type="ball"/>
|
||||||
|
</body>
|
||||||
|
<body>
|
||||||
|
<geom type="box" size="1 1 1"/>
|
||||||
|
<joint name="limited_ball" type="ball" limited="true"/>
|
||||||
|
</body>
|
||||||
|
</worldbody>
|
||||||
|
</mujoco>""")
|
||||||
|
|
||||||
|
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
|
||||||
|
self.assertNotEqual(0., physics.named.data.qpos['hinge'])
|
||||||
|
self.assertNotEqual(0., physics.named.data.qpos['limited_hinge'])
|
||||||
|
self.assertNotEqual(0., physics.named.data.qpos['limited_slide'])
|
||||||
|
|
||||||
|
self.assertNotEqual(0., np.sum(physics.named.data.qpos['ball']))
|
||||||
|
self.assertNotEqual(0., np.sum(physics.named.data.qpos['limited_ball']))
|
||||||
|
|
||||||
|
self.assertNotEqual(0., np.sum(physics.named.data.qpos['free'][3:]))
|
||||||
|
|
||||||
|
# The unlimited slide joint and the positional part of the free joint remain
|
||||||
|
# uninitialized.
|
||||||
|
self.assertEqual(0., physics.named.data.qpos['slide'])
|
||||||
|
self.assertEqual(0., np.sum(physics.named.data.qpos['free'][:3]))
|
||||||
|
|
||||||
|
def test_multiple_joints_of_same_type(self):
|
||||||
|
physics = mujoco.Physics.from_xml_string("""<mujoco>
|
||||||
|
<worldbody>
|
||||||
|
<body>
|
||||||
|
<geom type="box" size="1 1 1"/>
|
||||||
|
<joint name="hinge_1" type="hinge"/>
|
||||||
|
<joint name="hinge_2" type="hinge"/>
|
||||||
|
<joint name="hinge_3" type="hinge"/>
|
||||||
|
</body>
|
||||||
|
</worldbody>
|
||||||
|
</mujoco>""")
|
||||||
|
|
||||||
|
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
|
||||||
|
self.assertNotEqual(0., physics.named.data.qpos['hinge_1'])
|
||||||
|
self.assertNotEqual(0., physics.named.data.qpos['hinge_2'])
|
||||||
|
self.assertNotEqual(0., physics.named.data.qpos['hinge_3'])
|
||||||
|
|
||||||
|
self.assertNotEqual(physics.named.data.qpos['hinge_1'],
|
||||||
|
physics.named.data.qpos['hinge_2'])
|
||||||
|
|
||||||
|
self.assertNotEqual(physics.named.data.qpos['hinge_2'],
|
||||||
|
physics.named.data.qpos['hinge_3'])
|
||||||
|
|
||||||
|
self.assertNotEqual(physics.named.data.qpos['hinge_1'],
|
||||||
|
physics.named.data.qpos['hinge_3'])
|
||||||
|
|
||||||
|
def test_unlimited_hinge_randomization_range(self):
|
||||||
|
physics = mujoco.Physics.from_xml_string("""<mujoco>
|
||||||
|
<worldbody>
|
||||||
|
<body>
|
||||||
|
<geom type="box" size="1 1 1"/>
|
||||||
|
<joint name="hinge" type="hinge"/>
|
||||||
|
</body>
|
||||||
|
</worldbody>
|
||||||
|
</mujoco>""")
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
|
||||||
|
self.assertBetween(physics.named.data.qpos['hinge'], -np.pi, np.pi)
|
||||||
|
|
||||||
|
def test_limited_1d_joint_limits_are_respected(self):
|
||||||
|
physics = mujoco.Physics.from_xml_string("""<mujoco>
|
||||||
|
<default>
|
||||||
|
<joint limited="true"/>
|
||||||
|
</default>
|
||||||
|
<worldbody>
|
||||||
|
<body>
|
||||||
|
<geom type="box" size="1 1 1"/>
|
||||||
|
<joint name="hinge" type="hinge" range="0 10"/>
|
||||||
|
<joint name="slide" type="slide" range="30 50"/>
|
||||||
|
</body>
|
||||||
|
</worldbody>
|
||||||
|
</mujoco>""")
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
|
||||||
|
self.assertBetween(physics.named.data.qpos['hinge'],
|
||||||
|
np.deg2rad(0), np.deg2rad(10))
|
||||||
|
self.assertBetween(physics.named.data.qpos['slide'], 30, 50)
|
||||||
|
|
||||||
|
def test_limited_ball_joint_are_respected(self):
|
||||||
|
physics = mujoco.Physics.from_xml_string("""<mujoco>
|
||||||
|
<worldbody>
|
||||||
|
<body name="body" zaxis="1 0 0">
|
||||||
|
<geom type="box" size="1 1 1"/>
|
||||||
|
<joint name="ball" type="ball" limited="true" range="0 60"/>
|
||||||
|
</body>
|
||||||
|
</worldbody>
|
||||||
|
</mujoco>""")
|
||||||
|
|
||||||
|
body_axis = np.array([1., 0., 0.])
|
||||||
|
joint_axis = np.zeros(3)
|
||||||
|
for _ in range(10):
|
||||||
|
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
|
||||||
|
|
||||||
|
quat = physics.named.data.qpos['ball']
|
||||||
|
mjlib.mju_rotVecQuat(joint_axis, body_axis, quat)
|
||||||
|
angle_cos = np.dot(body_axis, joint_axis)
|
||||||
|
self.assertGreater(angle_cos, 0.5) # cos(60) = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
158
local_dm_control_suite/walker.py
Executable file
@ -0,0 +1,158 @@

# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Planar Walker Domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from local_dm_control_suite import base
from local_dm_control_suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards


_DEFAULT_TIME_LIMIT = 25
_CONTROL_TIMESTEP = .025

# Minimal height of torso over foot above which stand reward is 1.
_STAND_HEIGHT = 1.2

# Horizontal speeds (meters/second) above which move reward is 1.
_WALK_SPEED = 1
_RUN_SPEED = 8


SUITE = containers.TaggedTasks()


def get_model_and_assets():
  """Returns a tuple containing the model XML string and a dict of assets."""
  return common.read_model('walker.xml'), common.ASSETS


@SUITE.add('benchmarking')
def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the Stand task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = PlanarWalker(move_speed=0, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


@SUITE.add('benchmarking')
def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the Walk task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = PlanarWalker(move_speed=_WALK_SPEED, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


@SUITE.add('benchmarking')
def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the Run task."""
  physics = Physics.from_xml_string(*get_model_and_assets())
  task = PlanarWalker(move_speed=_RUN_SPEED, random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
      **environment_kwargs)


class Physics(mujoco.Physics):
  """Physics simulation with additional features for the Walker domain."""

  def torso_upright(self):
    """Returns projection from z-axes of torso to the z-axes of world."""
    return self.named.data.xmat['torso', 'zz']

  def torso_height(self):
    """Returns the height of the torso."""
    return self.named.data.xpos['torso', 'z']

  def horizontal_velocity(self):
    """Returns the horizontal velocity of the center-of-mass."""
    return self.named.data.sensordata['torso_subtreelinvel'][0]

  def orientations(self):
    """Returns planar orientations of all bodies."""
    return self.named.data.xmat[1:, ['xx', 'xz']].ravel()


class PlanarWalker(base.Task):
  """A planar walker task."""

  def __init__(self, move_speed, random=None):
    """Initializes an instance of `PlanarWalker`.

    Args:
      move_speed: A float. If this value is zero, reward is given simply for
        standing up. Otherwise this specifies a target horizontal velocity for
        the walking task.
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
    self._move_speed = move_speed
    super(PlanarWalker, self).__init__(random=random)

  def initialize_episode(self, physics):
    """Sets the state of the environment at the start of each episode.

    In 'standing' mode, use initial orientation and small velocities.
    In 'random' mode, randomize joint angles and let fall to the floor.

    Args:
      physics: An instance of `Physics`.
    """
    randomizers.randomize_limited_and_rotational_joints(physics, self.random)
    super(PlanarWalker, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns an observation of body orientations, height and velocities."""
    obs = collections.OrderedDict()
    obs['orientations'] = physics.orientations()
    obs['height'] = physics.torso_height()
    obs['velocity'] = physics.velocity()
    return obs

  def get_reward(self, physics):
    """Returns a reward to the agent."""
    standing = rewards.tolerance(physics.torso_height(),
                                 bounds=(_STAND_HEIGHT, float('inf')),
                                 margin=_STAND_HEIGHT/2)
    upright = (1 + physics.torso_upright()) / 2
    stand_reward = (3*standing + upright) / 4
    if self._move_speed == 0:
      return stand_reward
    else:
      move_reward = rewards.tolerance(physics.horizontal_velocity(),
                                      bounds=(self._move_speed, float('inf')),
                                      margin=self._move_speed/2,
                                      value_at_margin=0.5,
                                      sigmoid='linear')
      return stand_reward * (5*move_reward + 1) / 6
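
A minimal usage sketch for the tasks registered above (not part of the original commit). It assumes `local_dm_control_suite` and its MuJoCo assets are importable on the Python path; the task names and observation keys are taken from the file above, while the random-action rollout is purely illustrative.

# Sketch: build the Walk task and roll out random actions through the
# dm_control `control.Environment` / dm_env interface.
import numpy as np
from local_dm_control_suite import walker

env = walker.walk()                      # or walker.stand() / walker.run()
spec = env.action_spec()
time_step = env.reset()
while not time_step.last():
    action = np.random.uniform(spec.minimum, spec.maximum, size=spec.shape)
    time_step = env.step(action)         # observation keys: orientations, height, velocity
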
70
local_dm_control_suite/walker.xml
Executable file
@ -0,0 +1,70 @@

<mujoco model="planar walker">
  <include file="./common/visual.xml"/>
  <include file="./common/skybox.xml"/>
  <include file="./common/materials_white_floor.xml"/>

  <option timestep="0.0025"/>

  <statistic extent="2" center="0 0 1"/>

  <default>
    <joint damping=".1" armature="0.01" limited="true" solimplimit="0 .99 .01"/>
    <geom contype="1" conaffinity="0" friction=".7 .1 .1"/>
    <motor ctrlrange="-1 1" ctrllimited="true"/>
    <site size="0.01"/>
    <default class="walker">
      <geom material="self" type="capsule"/>
      <joint axis="0 -1 0"/>
    </default>
  </default>

  <worldbody>
    <geom name="floor" type="plane" conaffinity="1" pos="248 0 0" size="250 .8 .2" material="grid" zaxis="0 0 1"/>
    <body name="torso" pos="0 0 1.3" childclass="walker">
      <light name="light" pos="0 0 2" mode="trackcom"/>
      <camera name="side" pos="0 -2 .7" euler="60 0 0" mode="trackcom"/>
      <camera name="back" pos="-2 0 .5" xyaxes="0 -1 0 1 0 3" mode="trackcom"/>
      <joint name="rootz" axis="0 0 1" type="slide" limited="false" armature="0" damping="0"/>
      <joint name="rootx" axis="1 0 0" type="slide" limited="false" armature="0" damping="0"/>
      <joint name="rooty" axis="0 1 0" type="hinge" limited="false" armature="0" damping="0"/>
      <geom name="torso" size="0.07 0.3"/>
      <body name="right_thigh" pos="0 -.05 -0.3">
        <joint name="right_hip" range="-20 100"/>
        <geom name="right_thigh" pos="0 0 -0.225" size="0.05 0.225"/>
        <body name="right_leg" pos="0 0 -0.7">
          <joint name="right_knee" pos="0 0 0.25" range="-150 0"/>
          <geom name="right_leg" size="0.04 0.25"/>
          <body name="right_foot" pos="0.06 0 -0.25">
            <joint name="right_ankle" pos="-0.06 0 0" range="-45 45"/>
            <geom name="right_foot" zaxis="1 0 0" size="0.05 0.1"/>
          </body>
        </body>
      </body>
      <body name="left_thigh" pos="0 .05 -0.3">
        <joint name="left_hip" range="-20 100"/>
        <geom name="left_thigh" pos="0 0 -0.225" size="0.05 0.225"/>
        <body name="left_leg" pos="0 0 -0.7">
          <joint name="left_knee" pos="0 0 0.25" range="-150 0"/>
          <geom name="left_leg" size="0.04 0.25"/>
          <body name="left_foot" pos="0.06 0 -0.25">
            <joint name="left_ankle" pos="-0.06 0 0" range="-45 45"/>
            <geom name="left_foot" zaxis="1 0 0" size="0.05 0.1"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>

  <sensor>
    <subtreelinvel name="torso_subtreelinvel" body="torso"/>
  </sensor>

  <actuator>
    <motor name="right_hip" joint="right_hip" gear="100"/>
    <motor name="right_knee" joint="right_knee" gear="50"/>
    <motor name="right_ankle" joint="right_ankle" gear="20"/>
    <motor name="left_hip" joint="left_hip" gear="100"/>
    <motor name="left_knee" joint="left_knee" gear="50"/>
    <motor name="left_ankle" joint="left_ankle" gear="20"/>
  </actuator>
</mujoco>
16
local_dm_control_suite/wrappers/__init__.py
Executable file
@ -0,0 +1,16 @@

# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Environment wrappers used to extend or modify environment behaviour."""
74
local_dm_control_suite/wrappers/action_noise.py
Executable file
@ -0,0 +1,74 @@

# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Wrapper for control suite environments that adds Gaussian noise to actions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import dm_env
import numpy as np


_BOUNDS_MUST_BE_FINITE = (
    'All bounds in `env.action_spec()` must be finite, got: {action_spec}')


class Wrapper(dm_env.Environment):
  """Wraps a control environment and adds Gaussian noise to actions."""

  def __init__(self, env, scale=0.01):
    """Initializes a new action noise Wrapper.

    Args:
      env: The control suite environment to wrap.
      scale: The standard deviation of the noise, expressed as a fraction
        of the max-min range for each action dimension.

    Raises:
      ValueError: If any of the action dimensions of the wrapped environment are
        unbounded.
    """
    action_spec = env.action_spec()
    if not (np.all(np.isfinite(action_spec.minimum)) and
            np.all(np.isfinite(action_spec.maximum))):
      raise ValueError(_BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec))
    self._minimum = action_spec.minimum
    self._maximum = action_spec.maximum
    self._noise_std = scale * (action_spec.maximum - action_spec.minimum)
    self._env = env

  def step(self, action):
    noisy_action = action + self._env.task.random.normal(scale=self._noise_std)
    # Clip the noisy actions in place so that they fall within the bounds
    # specified by the `action_spec`. Note that MuJoCo implicitly clips out-of-
    # bounds control inputs, but we also clip here in case the actions do not
    # correspond directly to MuJoCo actuators, or if there are other wrapper
    # layers that expect the actions to be within bounds.
    np.clip(noisy_action, self._minimum, self._maximum, out=noisy_action)
    return self._env.step(noisy_action)

  def reset(self):
    return self._env.reset()

  def observation_spec(self):
    return self._env.observation_spec()

  def action_spec(self):
    return self._env.action_spec()

  def __getattr__(self, name):
    return getattr(self._env, name)
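
A minimal usage sketch of the wrapper above (not part of the original commit). It assumes the local suite's `cartpole` task is importable; the wrapper adds zero-mean Gaussian noise with standard deviation `scale * (max - min)` per action dimension, then clips back to the action bounds.

# Sketch: wrap a control-suite task so every action passed to step() is
# perturbed with Gaussian noise and clipped to the action spec's bounds.
import numpy as np
from local_dm_control_suite import cartpole
from local_dm_control_suite.wrappers import action_noise

env = action_noise.Wrapper(cartpole.swingup(), scale=0.01)
spec = env.action_spec()
time_step = env.reset()
time_step = env.step(np.zeros(spec.shape))  # underlying env sees a noisy, clipped action
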
136
local_dm_control_suite/wrappers/action_noise_test.py
Executable file
@ -0,0 +1,136 @@

# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Tests for the action noise wrapper."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Internal dependencies.
from absl.testing import absltest
from absl.testing import parameterized
from dm_control.rl import control
from dm_control.suite.wrappers import action_noise
from dm_env import specs
import mock
import numpy as np


class ActionNoiseTest(parameterized.TestCase):

  def make_action_spec(self, lower=(-1.,), upper=(1.,)):
    lower, upper = np.broadcast_arrays(lower, upper)
    return specs.BoundedArray(
        shape=lower.shape, dtype=float, minimum=lower, maximum=upper)

  def make_mock_env(self, action_spec=None):
    action_spec = action_spec or self.make_action_spec()
    env = mock.Mock(spec=control.Environment)
    env.action_spec.return_value = action_spec
    return env

  def assertStepCalledOnceWithCorrectAction(self, env, expected_action):
    # NB: `assert_called_once_with()` doesn't support numpy arrays.
    env.step.assert_called_once()
    actual_action = env.step.call_args_list[0][0][0]
    np.testing.assert_array_equal(expected_action, actual_action)

  @parameterized.parameters([
      dict(lower=np.r_[-1., 0.], upper=np.r_[1., 2.], scale=0.05),
      dict(lower=np.r_[-1., 0.], upper=np.r_[1., 2.], scale=0.),
      dict(lower=np.r_[-1., 0.], upper=np.r_[-1., 0.], scale=0.05),
  ])
  def test_step(self, lower, upper, scale):
    seed = 0
    std = scale * (upper - lower)
    expected_noise = np.random.RandomState(seed).normal(scale=std)
    action = np.random.RandomState(seed).uniform(lower, upper)
    expected_noisy_action = np.clip(action + expected_noise, lower, upper)
    task = mock.Mock(spec=control.Task)
    task.random = np.random.RandomState(seed)
    action_spec = self.make_action_spec(lower=lower, upper=upper)
    env = self.make_mock_env(action_spec=action_spec)
    env.task = task
    wrapped_env = action_noise.Wrapper(env, scale=scale)
    time_step = wrapped_env.step(action)
    self.assertStepCalledOnceWithCorrectAction(env, expected_noisy_action)
    self.assertIs(time_step, env.step(expected_noisy_action))

  @parameterized.named_parameters([
      dict(testcase_name='within_bounds', action=np.r_[-1.], noise=np.r_[0.1]),
      dict(testcase_name='below_lower', action=np.r_[-1.], noise=np.r_[-0.1]),
      dict(testcase_name='above_upper', action=np.r_[1.], noise=np.r_[0.1]),
  ])
  def test_action_clipping(self, action, noise):
    lower = -1.
    upper = 1.
    expected_noisy_action = np.clip(action + noise, lower, upper)
    task = mock.Mock(spec=control.Task)
    task.random = mock.Mock(spec=np.random.RandomState)
    task.random.normal.return_value = noise
    action_spec = self.make_action_spec(lower=lower, upper=upper)
    env = self.make_mock_env(action_spec=action_spec)
    env.task = task
    wrapped_env = action_noise.Wrapper(env)
    time_step = wrapped_env.step(action)
    self.assertStepCalledOnceWithCorrectAction(env, expected_noisy_action)
    self.assertIs(time_step, env.step(expected_noisy_action))

  @parameterized.parameters([
      dict(lower=np.r_[-1., 0.], upper=np.r_[1., np.inf]),
      dict(lower=np.r_[np.nan, 0.], upper=np.r_[1., 2.]),
  ])
  def test_error_if_action_bounds_non_finite(self, lower, upper):
    action_spec = self.make_action_spec(lower=lower, upper=upper)
    env = self.make_mock_env(action_spec=action_spec)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        action_noise._BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec)):
      _ = action_noise.Wrapper(env)

  def test_reset(self):
    env = self.make_mock_env()
    wrapped_env = action_noise.Wrapper(env)
    time_step = wrapped_env.reset()
    env.reset.assert_called_once_with()
    self.assertIs(time_step, env.reset())

  def test_observation_spec(self):
    env = self.make_mock_env()
    wrapped_env = action_noise.Wrapper(env)
    observation_spec = wrapped_env.observation_spec()
    env.observation_spec.assert_called_once_with()
    self.assertIs(observation_spec, env.observation_spec())

  def test_action_spec(self):
    env = self.make_mock_env()
    wrapped_env = action_noise.Wrapper(env)
    # `env.action_spec()` is called in `Wrapper.__init__()`
    env.action_spec.reset_mock()
    action_spec = wrapped_env.action_spec()
    env.action_spec.assert_called_once_with()
    self.assertIs(action_spec, env.action_spec())

  @parameterized.parameters(['task', 'physics', 'control_timestep'])
  def test_getattr(self, attribute_name):
    env = self.make_mock_env()
    wrapped_env = action_noise.Wrapper(env)
    attr = getattr(wrapped_env, attribute_name)
    self.assertIs(attr, getattr(env, attribute_name))


if __name__ == '__main__':
  absltest.main()
120
local_dm_control_suite/wrappers/pixels.py
Executable file
@ -0,0 +1,120 @@

# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Wrapper that adds pixel observations to a control environment."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import dm_env
from dm_env import specs

STATE_KEY = 'state'


class Wrapper(dm_env.Environment):
  """Wraps a control environment and adds a rendered pixel observation."""

  def __init__(self, env, pixels_only=True, render_kwargs=None,
               observation_key='pixels'):
    """Initializes a new pixel Wrapper.

    Args:
      env: The environment to wrap.
      pixels_only: If True (default), the original set of 'state' observations
        returned by the wrapped environment will be discarded, and the
        `OrderedDict` of observations will only contain pixels. If False, the
        `OrderedDict` will contain the original observations as well as the
        pixel observations.
      render_kwargs: Optional `dict` containing keyword arguments passed to the
        `mujoco.Physics.render` method.
      observation_key: Optional custom string specifying the pixel observation's
        key in the `OrderedDict` of observations. Defaults to 'pixels'.

    Raises:
      ValueError: If `env`'s observation spec is not compatible with the
        wrapper. Supported formats are a single array, or a dict of arrays.
      ValueError: If `env`'s observation already contains the specified
        `observation_key`.
    """
    if render_kwargs is None:
      render_kwargs = {}

    wrapped_observation_spec = env.observation_spec()

    if isinstance(wrapped_observation_spec, specs.Array):
      self._observation_is_dict = False
      invalid_keys = set([STATE_KEY])
    elif isinstance(wrapped_observation_spec, collections.MutableMapping):
      self._observation_is_dict = True
      invalid_keys = set(wrapped_observation_spec.keys())
    else:
      raise ValueError('Unsupported observation spec structure.')

    if not pixels_only and observation_key in invalid_keys:
      raise ValueError('Duplicate or reserved observation key {!r}.'
                       .format(observation_key))

    if pixels_only:
      self._observation_spec = collections.OrderedDict()
    elif self._observation_is_dict:
      self._observation_spec = wrapped_observation_spec.copy()
    else:
      self._observation_spec = collections.OrderedDict()
      self._observation_spec[STATE_KEY] = wrapped_observation_spec

    # Extend observation spec.
    pixels = env.physics.render(**render_kwargs)
    pixels_spec = specs.Array(
        shape=pixels.shape, dtype=pixels.dtype, name=observation_key)
    self._observation_spec[observation_key] = pixels_spec

    self._env = env
    self._pixels_only = pixels_only
    self._render_kwargs = render_kwargs
    self._observation_key = observation_key

  def reset(self):
    time_step = self._env.reset()
    return self._add_pixel_observation(time_step)

  def step(self, action):
    time_step = self._env.step(action)
    return self._add_pixel_observation(time_step)

  def observation_spec(self):
    return self._observation_spec

  def action_spec(self):
    return self._env.action_spec()

  def _add_pixel_observation(self, time_step):
    if self._pixels_only:
      observation = collections.OrderedDict()
    elif self._observation_is_dict:
      observation = type(time_step.observation)(time_step.observation)
    else:
      observation = collections.OrderedDict()
      observation[STATE_KEY] = time_step.observation

    pixels = self._env.physics.render(**self._render_kwargs)
    observation[self._observation_key] = pixels
    return time_step._replace(observation=observation)

  def __getattr__(self, name):
    return getattr(self._env, name)
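
A minimal usage sketch of the pixel wrapper above (not part of the original commit). It assumes the local suite's `cartpole` task and a working MuJoCo rendering backend; the 84x84 size and `camera_id` are illustrative `render_kwargs`, not values fixed by this file.

# Sketch: add an RGB rendering to a state-based task's observations. With
# pixels_only=True only the 'pixels' key remains in the observation dict.
from local_dm_control_suite import cartpole
from local_dm_control_suite.wrappers import pixels

env = pixels.Wrapper(cartpole.swingup(),
                     pixels_only=True,
                     render_kwargs={'height': 84, 'width': 84, 'camera_id': 0})
time_step = env.reset()
frame = time_step.observation['pixels']  # uint8 array of shape (84, 84, 3)
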
133
local_dm_control_suite/wrappers/pixels_test.py
Executable file
@ -0,0 +1,133 @@

# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Tests for the pixel wrapper."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

# Internal dependencies.
from absl.testing import absltest
from absl.testing import parameterized
from local_dm_control_suite import cartpole
from dm_control.suite.wrappers import pixels
import dm_env
from dm_env import specs

import numpy as np


class FakePhysics(object):

  def render(self, *args, **kwargs):
    del args
    del kwargs
    return np.zeros((4, 5, 3), dtype=np.uint8)


class FakeArrayObservationEnvironment(dm_env.Environment):

  def __init__(self):
    self.physics = FakePhysics()

  def reset(self):
    return dm_env.restart(np.zeros((2,)))

  def step(self, action):
    del action
    return dm_env.transition(0.0, np.zeros((2,)))

  def action_spec(self):
    pass

  def observation_spec(self):
    return specs.Array(shape=(2,), dtype=np.float)


class PixelsTest(parameterized.TestCase):

  @parameterized.parameters(True, False)
  def test_dict_observation(self, pixels_only):
    pixel_key = 'rgb'

    env = cartpole.swingup()

    # Make sure we are testing the right environment for the test.
    observation_spec = env.observation_spec()
    self.assertIsInstance(observation_spec, collections.OrderedDict)

    width = 320
    height = 240

    # The wrapper should only add one observation.
    wrapped = pixels.Wrapper(env,
                             observation_key=pixel_key,
                             pixels_only=pixels_only,
                             render_kwargs={'width': width, 'height': height})

    wrapped_observation_spec = wrapped.observation_spec()
    self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)

    if pixels_only:
      self.assertLen(wrapped_observation_spec, 1)
      self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
    else:
      expected_length = len(observation_spec) + 1
      self.assertLen(wrapped_observation_spec, expected_length)
      expected_keys = list(observation_spec.keys()) + [pixel_key]
      self.assertEqual(expected_keys, list(wrapped_observation_spec.keys()))

    # Check that the added spec item is consistent with the added observation.
    time_step = wrapped.reset()
    rgb_observation = time_step.observation[pixel_key]
    wrapped_observation_spec[pixel_key].validate(rgb_observation)

    self.assertEqual(rgb_observation.shape, (height, width, 3))
    self.assertEqual(rgb_observation.dtype, np.uint8)

  @parameterized.parameters(True, False)
  def test_single_array_observation(self, pixels_only):
    pixel_key = 'depth'

    env = FakeArrayObservationEnvironment()
    observation_spec = env.observation_spec()
    self.assertIsInstance(observation_spec, specs.Array)

    wrapped = pixels.Wrapper(env, observation_key=pixel_key,
                             pixels_only=pixels_only)
    wrapped_observation_spec = wrapped.observation_spec()
    self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)

    if pixels_only:
      self.assertLen(wrapped_observation_spec, 1)
      self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
    else:
      self.assertLen(wrapped_observation_spec, 2)
      self.assertEqual([pixels.STATE_KEY, pixel_key],
                       list(wrapped_observation_spec.keys()))

    time_step = wrapped.reset()

    depth_observation = time_step.observation[pixel_key]
    wrapped_observation_spec[pixel_key].validate(depth_observation)

    self.assertEqual(depth_observation.shape, (4, 5, 3))
    self.assertEqual(depth_observation.dtype, np.uint8)


if __name__ == '__main__':
  absltest.main()
170
logger.py
Normal file
@ -0,0 +1,170 @@

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
import json
import os
import shutil
import torch
import torchvision
import numpy as np
from termcolor import colored

FORMAT_CONFIG = {
    'rl': {
        'train': [
            ('episode', 'E', 'int'), ('step', 'S', 'int'),
            ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'),
            ('batch_reward', 'BR', 'float'), ('actor_loss', 'ALOSS', 'float'),
            ('critic_loss', 'CLOSS', 'float'), ('ae_loss', 'RLOSS', 'float'),
            ('max_rat', 'MR', 'float')
        ],
        'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float')]
    }
}


class AverageMeter(object):
    def __init__(self):
        self._sum = 0
        self._count = 0

    def update(self, value, n=1):
        self._sum += value
        self._count += n

    def value(self):
        return self._sum / max(1, self._count)


class MetersGroup(object):
    def __init__(self, file_name, formating):
        self._file_name = file_name
        if os.path.exists(file_name):
            os.remove(file_name)
        self._formating = formating
        self._meters = defaultdict(AverageMeter)

    def log(self, key, value, n=1):
        self._meters[key].update(value, n)

    def _prime_meters(self):
        data = dict()
        for key, meter in self._meters.items():
            if key.startswith('train'):
                key = key[len('train') + 1:]
            else:
                key = key[len('eval') + 1:]
            key = key.replace('/', '_')
            data[key] = meter.value()
        return data

    def _dump_to_file(self, data):
        with open(self._file_name, 'a') as f:
            f.write(json.dumps(data) + '\n')

    def _format(self, key, value, ty):
        template = '%s: '
        if ty == 'int':
            template += '%d'
        elif ty == 'float':
            template += '%.04f'
        elif ty == 'time':
            template += '%.01f s'
        else:
            raise ValueError('invalid format type: %s' % ty)
        return template % (key, value)

    def _dump_to_console(self, data, prefix):
        prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
        pieces = ['{:5}'.format(prefix)]
        for key, disp_key, ty in self._formating:
            value = data.get(key, 0)
            pieces.append(self._format(disp_key, value, ty))
        print('| %s' % (' | '.join(pieces)))

    def dump(self, step, prefix):
        if len(self._meters) == 0:
            return
        data = self._prime_meters()
        data['step'] = step
        self._dump_to_file(data)
        self._dump_to_console(data, prefix)
        self._meters.clear()


class Logger(object):
    def __init__(self, log_dir, use_tb=True, config='rl'):
        self._log_dir = log_dir
        if use_tb:
            tb_dir = os.path.join(log_dir, 'tb')
            if os.path.exists(tb_dir):
                shutil.rmtree(tb_dir)
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        self._train_mg = MetersGroup(
            os.path.join(log_dir, 'train.log'),
            formating=FORMAT_CONFIG[config]['train']
        )
        self._eval_mg = MetersGroup(
            os.path.join(log_dir, 'eval.log'),
            formating=FORMAT_CONFIG[config]['eval']
        )

    def _try_sw_log(self, key, value, step):
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_image(self, key, image, step):
        if self._sw is not None:
            assert image.dim() == 3
            grid = torchvision.utils.make_grid(image.unsqueeze(1))
            self._sw.add_image(key, grid, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            frames = frames.unsqueeze(0)
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1):
        assert key.startswith('train') or key.startswith('eval')
        if type(value) == torch.Tensor:
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step):
        self.log_histogram(key + '_w', param.weight.data, step)
        if hasattr(param.weight, 'grad') and param.weight.grad is not None:
            self.log_histogram(key + '_w_g', param.weight.grad.data, step)
        if hasattr(param, 'bias'):
            self.log_histogram(key + '_b', param.bias.data, step)
            if hasattr(param.bias, 'grad') and param.bias.grad is not None:
                self.log_histogram(key + '_b_g', param.bias.grad.data, step)

    def log_image(self, key, image, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_image(key, image, step)

    def log_video(self, key, frames, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step):
        self._train_mg.dump(step, 'train')
        self._eval_mg.dump(step, 'eval')
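
A minimal usage sketch of the Logger above (not part of the original commit). The log directory path is made up for illustration; keys must start with 'train' or 'eval', and console/file output follows the FORMAT_CONFIG['rl'] layout, with dump() flushing and clearing the meters.

# Sketch: log a few training scalars without TensorBoard and flush per step.
import os
from logger import Logger

log_dir = './save/example_run'          # hypothetical path
os.makedirs(log_dir, exist_ok=True)     # train.log/eval.log are opened inside it
logger = Logger(log_dir, use_tb=False)
for step in range(3):
    logger.log('train/episode_reward', 10.0 * step, step)
    logger.log('train/critic_loss', 0.5 / (step + 1), step)
    logger.dump(step)
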
19
run_all.sh
Executable file
@ -0,0 +1,19 @@

#!/bin/bash

NOW=$(date +"%m%d%H%M")

./$1 cartpole swingup 2 ${NOW}
./$1 reacher easy 2 ${NOW}
./$1 cheetah run 2 ${NOW}
./$1 finger spin 2 ${NOW}
# ./$1 ball_in_cup catch 2 ${NOW}
./$1 walker walk 2 ${NOW}
./$1 walker stand 2 ${NOW}
./$1 walker run 2 ${NOW}
# ./$1 acrobot swingup 2 ${NOW}
./$1 hopper stand 2 ${NOW}
./$1 hopper hop 2 ${NOW}
# ./$1 manipulator bring_ball 2 ${NOW}
# ./$1 humanoid stand 2 ${NOW}
# ./$1 humanoid walk 2 ${NOW}
# ./$1 humanoid run 2 ${NOW}
89
run_cluster.sh
Executable file
@ -0,0 +1,89 @@

#!/bin/bash

CURDIR=`pwd`
CODEDIR=`mktemp -d -p ${CURDIR}/tmp`

cp ${CURDIR}/*.py ${CODEDIR}
cp -r ${CURDIR}/local_dm_control_suite ${CODEDIR}/
cp -r ${CURDIR}/dmc2gym ${CODEDIR}/
cp -r ${CURDIR}/agent ${CODEDIR}/

DOMAIN=${1:-walker}
TASK=${2:-walk}
ACTION_REPEAT=${3:-2}
NOW=${4:-$(date +"%m%d%H%M")}
ENCODER_TYPE=pixel

DECODER_TYPE=identity
NUM_LAYERS=4
NUM_FILTERS=32
IMG_SOURCE=video
AGENT=bisim
BATCH_SIZE=512
ENCODER_LR=0.001
NUM_FRAMES=100
BISIM_COEF=0.5
CDIR=/checkpoint/${USER}/DBC/${DOMAIN}_${TASK}
mkdir -p ${CDIR}

for NUM_FRAMES in 1000; do
for TRANSITION_MODEL_TYPE in 'ensemble'; do
for SEED in 1 2 3; do
    SUBDIR=${AGENT}_${BISIM_COEF}coef_${TRANSITION_MODEL_TYPE}_frames${NUM_FRAMES}_${IMG_SOURCE}kinetics/seed_${SEED}
    SAVEDIR=${CDIR}/${SUBDIR}
    mkdir -p ${SAVEDIR}
    JOBNAME=${NOW}_${DOMAIN}_${TASK}
    SCRIPT=${SAVEDIR}/run.sh
    SLURM=${SAVEDIR}/run.slrm
    CODEREF=${SAVEDIR}/code
    extra=""
    echo "#!/bin/sh" > ${SCRIPT}
    echo "#!/bin/sh" > ${SLURM}
    echo ${CODEDIR} > ${CODEREF}
    echo "#SBATCH --job-name=${JOBNAME}" >> ${SLURM}
    echo "#SBATCH --output=${SAVEDIR}/stdout" >> ${SLURM}
    echo "#SBATCH --error=${SAVEDIR}/stderr" >> ${SLURM}
    echo "#SBATCH --partition=learnfair" >> ${SLURM}
    echo "#SBATCH --nodes=1" >> ${SLURM}
    echo "#SBATCH --time=4000" >> ${SLURM}
    echo "#SBATCH --ntasks-per-node=1" >> ${SLURM}
    echo "#SBATCH --signal=USR1" >> ${SLURM}
    echo "#SBATCH --gres=gpu:volta:1" >> ${SLURM}
    echo "#SBATCH --mem=500000" >> ${SLURM}
    echo "#SBATCH -c 1" >> ${SLURM}
    echo "srun sh ${SCRIPT}" >> ${SLURM}
    echo "echo \$SLURM_JOB_ID >> ${SAVEDIR}/id" >> ${SCRIPT}
    echo "nvidia-smi" >> ${SCRIPT}
    echo "cd ${CODEDIR}" >> ${SCRIPT}
    echo MUJOCO_GL="osmesa" LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/nvidia-opengl/:$LD_LIBRARY_PATH python train.py \
        --domain_name ${DOMAIN} \
        --task_name ${TASK} \
        --agent ${AGENT} \
        --init_steps 1000 \
        --bisim_coef ${BISIM_COEF} \
        --num_train_steps 1000000 \
        --encoder_type ${ENCODER_TYPE} \
        --decoder_type ${DECODER_TYPE} \
        --encoder_lr ${ENCODER_LR} \
        --action_repeat ${ACTION_REPEAT} \
        --img_source ${IMG_SOURCE} \
        --num_layers ${NUM_LAYERS} \
        --num_filters ${NUM_FILTERS} \
        --resource_files \'/datasets01/kinetics/070618/400/train/driving_car/*.mp4\' \
        --eval_resource_files \'/datasets01/kinetics/070618/400/train/driving_car/*.mp4\' \
        --critic_tau 0.01 \
        --encoder_tau 0.05 \
        --total_frames ${NUM_FRAMES} \
        --decoder_weight_lambda 0.0000001 \
        --hidden_dim 1024 \
        --batch_size ${BATCH_SIZE} \
        --transition_model_type ${TRANSITION_MODEL_TYPE} \
        --init_temperature 0.1 \
        --alpha_lr 1e-4 \
        --alpha_beta 0.5 \
        --work_dir ${SAVEDIR} \
        --seed ${SEED} >> ${SCRIPT}
    sbatch ${SLURM}
done
done
done
82
run_cluster_nobg.sh
Executable file
@ -0,0 +1,82 @@

#!/bin/bash

CURDIR=`pwd`
CODEDIR=`mktemp -d -p ${CURDIR}/tmp`

cp ${CURDIR}/*.py ${CODEDIR}
cp -r ${CURDIR}/local_dm_control_suite ${CODEDIR}/
cp -r ${CURDIR}/dmc2gym ${CODEDIR}/
cp -r ${CURDIR}/agent ${CODEDIR}/

DOMAIN=${1:-walker}
TASK=${2:-walk}
ACTION_REPEAT=${3:-2}
NOW=${4:-$(date +"%m%d%H%M")}
ENCODER_TYPE=pixel

DECODER_TYPE=pixel
NUM_LAYERS=4
NUM_FILTERS=32
IMG_SOURCE=video
AGENT=bisim

CDIR=/checkpoint/${USER}/DBC/${DOMAIN}_${TASK}
mkdir -p ${CDIR}

for TRANSITION_MODEL_TYPE in 'probabilistic'; do
for DECODER_TYPE in 'identity'; do
for SEED in 1 2 3; do
    SUBDIR=${AGENT}_transition${TRANSITION_MODEL_TYPE}_nobg/seed_${SEED}
    SAVEDIR=${CDIR}/${SUBDIR}
    mkdir -p ${SAVEDIR}
    JOBNAME=${NOW}_${DOMAIN}_${TASK}
    SCRIPT=${SAVEDIR}/run.sh
    SLURM=${SAVEDIR}/run.slrm
    CODEREF=${SAVEDIR}/code
    extra=""
    echo "#!/bin/sh" > ${SCRIPT}
    echo "#!/bin/sh" > ${SLURM}
    echo ${CODEDIR} > ${CODEREF}
    echo "#SBATCH --job-name=${JOBNAME}" >> ${SLURM}
    echo "#SBATCH --output=${SAVEDIR}/stdout" >> ${SLURM}
    echo "#SBATCH --error=${SAVEDIR}/stderr" >> ${SLURM}
    echo "#SBATCH --partition=learnfair" >> ${SLURM}
    echo "#SBATCH --nodes=1" >> ${SLURM}
    echo "#SBATCH --time=4000" >> ${SLURM}
    echo "#SBATCH --ntasks-per-node=1" >> ${SLURM}
    echo "#SBATCH --signal=USR1" >> ${SLURM}
    echo "#SBATCH --gres=gpu:volta:1" >> ${SLURM}
    echo "#SBATCH --mem=500000" >> ${SLURM}
    echo "#SBATCH -c 1" >> ${SLURM}
    echo "srun sh ${SCRIPT}" >> ${SLURM}
    echo "echo \$SLURM_JOB_ID >> ${SAVEDIR}/id" >> ${SCRIPT}
    echo "nvidia-smi" >> ${SCRIPT}
    echo "cd ${CODEDIR}" >> ${SCRIPT}
    echo MUJOCO_GL="osmesa" LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/nvidia-opengl/:$LD_LIBRARY_PATH python train.py \
        --domain_name ${DOMAIN} \
        --task_name ${TASK} \
        --agent ${AGENT} \
        --init_steps 1000 \
        --num_train_steps 1000000 \
        --encoder_type ${ENCODER_TYPE} \
        --decoder_type ${DECODER_TYPE} \
        --action_repeat ${ACTION_REPEAT} \
        --resource_files \'/datasets01/kinetics/070618/400/train/driving_car/*.mp4\' \
        --num_layers ${NUM_LAYERS} \
        --num_filters ${NUM_FILTERS} \
        --transition_model_type ${TRANSITION_MODEL_TYPE} \
        --critic_tau 0.01 \
        --encoder_tau 0.05 \
        --decoder_weight_lambda 0.0000001 \
        --hidden_dim 1024 \
        --batch_size 128 \
        --init_temperature 0.1 \
        --alpha_lr 1e-4 \
        --alpha_beta 0.5 \
        --save_model \
        --work_dir ${SAVEDIR} \
        --seed ${SEED} >> ${SCRIPT}
    sbatch ${SLURM}
done
done
done
32
run_local.sh
Executable file
@ -0,0 +1,32 @@

#!/bin/bash

DOMAIN=cartpole
TASK=swingup

SAVEDIR=./save

MUJOCO_GL="osmesa" LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/nvidia-opengl/:$LD_LIBRARY_PATH CUDA_VISIBLE_DEVICES=1 python train.py \
    --domain_name ${DOMAIN} \
    --task_name ${TASK} \
    --agent 'bisim' \
    --init_steps 1000 \
    --num_train_steps 1000000 \
    --encoder_type pixel \
    --decoder_type pixel \
    --img_source video \
    --resource_files 'distractors/*.mp4' \
    --transition_model_type 'probabilistic' \
    --action_repeat 2 \
    --critic_tau 0.01 \
    --encoder_tau 0.05 \
    --decoder_weight_lambda 0.0000001 \
    --hidden_dim 1024 \
    --total_frames 1000 \
    --num_layers 4 \
    --num_filters 32 \
    --batch_size 128 \
    --init_temperature 0.1 \
    --alpha_lr 1e-4 \
    --alpha_beta 0.5 \
    --work_dir ${SAVEDIR}/${DOMAIN}_${TASK} \
    --seed 1 $@
43
run_local_carla096.sh
Executable file
@ -0,0 +1,43 @@

#!/bin/bash

DOMAIN=carla096
TASK=highway
AGENT=deepmdp
SEED=5
DECODER_TYPE=identity
TRANSITION_MODEL=deterministic

SAVEDIR=./save
#SAVEDIR=/checkpoint/${USER}/pixel-pets/carla/${AGENT}_${TRANSITION_MODEL}_currrew_${DECODER_TYPE}/seed_${SEED}
mkdir -p ${SAVEDIR}

CUDA_VISIBLE_DEVICES=1 python train.py \
    --domain_name ${DOMAIN} \
    --task_name ${TASK} \
    --agent ${AGENT} \
    --init_steps 100 \
    --num_train_steps 100000 \
    --encoder_type pixelCarla096 \
    --decoder_type ${DECODER_TYPE} \
    --resource_files 'distractors/*.mp4' \
    --action_repeat 4 \
    --critic_tau 0.01 \
    --encoder_tau 0.05 \
    --encoder_stride 2 \
    --decoder_weight_lambda 0.0000001 \
    --hidden_dim 1024 \
    --replay_buffer_capacity 100000 \
    --total_frames 10000 \
    --num_layers 4 \
    --num_filters 32 \
    --batch_size 128 \
    --init_temperature 0.1 \
    --alpha_lr 1e-4 \
    --alpha_beta 0.5 \
    --work_dir ${SAVEDIR} \
    --transition_model_type ${TRANSITION_MODEL} \
    --seed ${SEED} $@ \
    --frame_stack 3 \
    --image_size 84 \
    --save_model >> ${SAVEDIR}/output.txt \
    # --render
36
run_local_carla098.sh
Executable file
@ -0,0 +1,36 @@

#!/bin/bash

DOMAIN=carla098
TASK=highway

SAVEDIR=./save
mkdir -p ${SAVEDIR}

CUDA_VISIBLE_DEVICES=0 python train.py \
    --domain_name ${DOMAIN} \
    --task_name ${TASK} \
    --agent 'deepmdp' \
    --init_steps 1000 \
    --num_train_steps 1000000 \
    --encoder_type pixelCarla098 \
    --decoder_type pixel \
    --img_source video \
    --resource_files 'distractors/*.mp4' \
    --action_repeat 4 \
    --critic_tau 0.01 \
    --encoder_tau 0.05 \
    --decoder_weight_lambda 0.0000001 \
    --hidden_dim 1024 \
    --total_frames 10000 \
    --num_filters 32 \
    --batch_size 128 \
    --init_temperature 0.1 \
    --alpha_lr 1e-4 \
    --alpha_beta 0.5 \
    --work_dir ${SAVEDIR}/${DOMAIN}_${TASK} \
    --seed 1 $@ \
    --frame_stack 3 \
    --image_size 84 \
    --eval_freq 10000 \
    --num_eval_episodes 25 \
    --render
188
sac_ae.py
Normal file
188
sac_ae.py
Normal file
@ -0,0 +1,188 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils
from encoder import make_encoder

LOG_FREQ = 10000


def gaussian_logprob(noise, log_std):
    """Compute Gaussian log probability."""
    residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
    return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)


def squash(mu, pi, log_pi):
    """Apply squashing function.
    See appendix C from https://arxiv.org/pdf/1812.05905.pdf.
    """
    mu = torch.tanh(mu)
    if pi is not None:
        pi = torch.tanh(pi)
    if log_pi is not None:
        log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
    return mu, pi, log_pi


def weight_init(m):
    """Custom weight init for Conv2D and Linear layers."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        m.bias.data.fill_(0.0)
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


class Actor(nn.Module):
    """MLP actor network."""
    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters, stride
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, stride
        )

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.trunk = nn.Sequential(
            nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2 * action_shape[0])
        )

        self.outputs = dict()
        self.apply(weight_init)

    def forward(
        self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False
    ):
        obs = self.encoder(obs, detach=detach_encoder)

        mu, log_std = self.trunk(obs).chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std = self.log_std_min + 0.5 * (
            self.log_std_max - self.log_std_min
        ) * (log_std + 1)

        self.outputs['mu'] = mu
        self.outputs['std'] = log_std.exp()

        if compute_pi:
            std = log_std.exp()
            noise = torch.randn_like(mu)
            pi = mu + noise * std
        else:
            pi = None
            entropy = None

        if compute_log_pi:
            log_pi = gaussian_logprob(noise, log_std)
        else:
            log_pi = None

        mu, pi, log_pi = squash(mu, pi, log_pi)

        return mu, pi, log_pi, log_std

    def log(self, L, step, log_freq=LOG_FREQ):
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % k, v, step)

        L.log_param('train_actor/fc1', self.trunk[0], step)
        L.log_param('train_actor/fc2', self.trunk[2], step)
        L.log_param('train_actor/fc3', self.trunk[4], step)


class QFunction(nn.Module):
    """MLP for q-function."""
    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        self.trunk = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, obs, action):
        assert obs.size(0) == action.size(0)

        obs_action = torch.cat([obs, action], dim=1)
        return self.trunk(obs_action)


class Critic(nn.Module):
    """Critic network, employs two q-functions."""
    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, num_layers, num_filters, stride
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, stride
        )

        self.Q1 = QFunction(
            self.encoder.feature_dim, action_shape[0], hidden_dim
        )
        self.Q2 = QFunction(
            self.encoder.feature_dim, action_shape[0], hidden_dim
        )

        self.outputs = dict()
        self.apply(weight_init)

    def forward(self, obs, action, detach_encoder=False):
        # detach_encoder allows to stop gradient propagation to encoder
        obs = self.encoder(obs, detach=detach_encoder)

        q1 = self.Q1(obs, action)
        q2 = self.Q2(obs, action)

        self.outputs['q1'] = q1
        self.outputs['q2'] = q2

        return q1, q2

    def log(self, L, step, log_freq=LOG_FREQ):
        if step % log_freq != 0:
            return

        self.encoder.log(L, step, log_freq)

        for k, v in self.outputs.items():
            L.log_histogram('train_critic/%s_hist' % k, v, step)

        for i in range(3):
            L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step)
            L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step)
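A minimal usage sketch for the actor and critic above. The observation/action shapes, the 'pixel' encoder type, and the hyperparameter values are assumptions for illustration; make_encoder is provided by encoder.py elsewhere in this commit.

# illustrative sketch only; shapes and encoder settings are assumed
import torch
from sac_ae import Actor, Critic

obs_shape, action_shape = (9, 84, 84), (2,)   # e.g. 3 stacked RGB frames, 2-D action
actor = Actor(obs_shape, action_shape, hidden_dim=1024, encoder_type='pixel',
              encoder_feature_dim=50, log_std_min=-10, log_std_max=2,
              num_layers=4, num_filters=32, stride=1)
critic = Critic(obs_shape, action_shape, hidden_dim=1024, encoder_type='pixel',
                encoder_feature_dim=50, num_layers=4, num_filters=32, stride=1)

obs = torch.randn(8, *obs_shape)              # fake batch of pixel observations
mu, pi, log_pi, log_std = actor(obs)          # deterministic mean and a sampled action
q1, q2 = critic(obs, pi)                      # twin Q-values for the sampled action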
449 train.py Normal file
@@ -0,0 +1,449 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import argparse
import os
import gym
import time
import json
import dmc2gym

import utils
from logger import Logger
from video import VideoRecorder

from agent.baseline_agent import BaselineAgent
from agent.bisim_agent import BisimAgent
from agent.deepmdp_agent import DeepMDPAgent
from agents.navigation.carla_env import CarlaEnv


def parse_args():
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
    # train
    parser.add_argument('--agent', default='bisim', type=str, choices=['baseline', 'bisim', 'deepmdp'])
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=1000000, type=int)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--hidden_dim', default=256, type=int)
    parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model')
    parser.add_argument('--bisim_coef', default=0.5, type=float, help='coefficient for bisim terms')
    parser.add_argument('--load_encoder', default=None, type=str)
    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.005, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--encoder_stride', default=1, type=int)
    parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--render', default=False, action='store_true')
    parser.add_argument('--port', default=2000, type=int)
    args = parser.parse_args()
    return args


def evaluate(env, agent, video, num_episodes, L, step, device=None, embed_viz_dir=None, do_carla_metrics=None):
    # carla metrics:
    reason_each_episode_ended = []
    distance_driven_each_episode = []
    crash_intensity = 0.
    steer = 0.
    brake = 0.
    count = 0

    # embedding visualization
    obses = []
    values = []
    embeddings = []

    for i in range(num_episodes):
        # carla metrics:
        dist_driven_this_episode = 0.

        obs = env.reset()
        video.init(enabled=(i == 0))
        done = False
        episode_reward = 0
        while not done:
            with utils.eval_mode(agent):
                action = agent.select_action(obs)

            if embed_viz_dir:
                obses.append(obs)
                with torch.no_grad():
                    values.append(min(agent.critic(torch.Tensor(obs).to(device).unsqueeze(0), torch.Tensor(action).to(device).unsqueeze(0))).item())
                    embeddings.append(agent.critic.encoder(torch.Tensor(obs).unsqueeze(0).to(device)).cpu().detach().numpy())

            obs, reward, done, info = env.step(action)

            # metrics:
            if do_carla_metrics:
                dist_driven_this_episode += info['distance']
                crash_intensity += info['crash_intensity']
                steer += abs(info['steer'])
                brake += info['brake']
                count += 1

            video.record(env)
            episode_reward += reward

        # metrics:
        if do_carla_metrics:
            reason_each_episode_ended.append(info['reason_episode_ended'])
            distance_driven_each_episode.append(dist_driven_this_episode)

        video.save('%d.mp4' % step)
        L.log('eval/episode_reward', episode_reward, step)

    if embed_viz_dir:
        dataset = {'obs': obses, 'values': values, 'embeddings': embeddings}
        torch.save(dataset, os.path.join(embed_viz_dir, 'train_dataset_{}.pt'.format(step)))

    L.dump(step)

    if do_carla_metrics:
        print('METRICS--------------------------')
        print("reason_each_episode_ended: {}".format(reason_each_episode_ended))
        print("distance_driven_each_episode: {}".format(distance_driven_each_episode))
        print('crash_intensity: {}'.format(crash_intensity / num_episodes))
        print('steer: {}'.format(steer / count))
        print('brake: {}'.format(brake / count))
        print('---------------------------------')


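# A hedged sketch of calling the evaluation routine above with the CARLA
# metrics enabled. `eval_env`, `agent`, `video`, `L` and `device` are assumed
# to be built as in main() further down; nothing here is part of the source.
# evaluate(eval_env, agent, video, num_episodes=20, L=L, step=0,
#          device=device, do_carla_metrics=True)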
def make_agent(obs_shape, action_shape, args, device):
    if args.agent == 'baseline':
        agent = BaselineAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            hidden_dim=args.hidden_dim,
            discount=args.discount,
            init_temperature=args.init_temperature,
            alpha_lr=args.alpha_lr,
            alpha_beta=args.alpha_beta,
            actor_lr=args.actor_lr,
            actor_beta=args.actor_beta,
            actor_log_std_min=args.actor_log_std_min,
            actor_log_std_max=args.actor_log_std_max,
            actor_update_freq=args.actor_update_freq,
            critic_lr=args.critic_lr,
            critic_beta=args.critic_beta,
            critic_tau=args.critic_tau,
            critic_target_update_freq=args.critic_target_update_freq,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            encoder_lr=args.encoder_lr,
            encoder_tau=args.encoder_tau,
            encoder_stride=args.encoder_stride,
            decoder_type=args.decoder_type,
            decoder_lr=args.decoder_lr,
            decoder_update_freq=args.decoder_update_freq,
            decoder_weight_lambda=args.decoder_weight_lambda,
            transition_model_type=args.transition_model_type,
            num_layers=args.num_layers,
            num_filters=args.num_filters
        )
    elif args.agent == 'bisim':
        agent = BisimAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            hidden_dim=args.hidden_dim,
            discount=args.discount,
            init_temperature=args.init_temperature,
            alpha_lr=args.alpha_lr,
            alpha_beta=args.alpha_beta,
            actor_lr=args.actor_lr,
            actor_beta=args.actor_beta,
            actor_log_std_min=args.actor_log_std_min,
            actor_log_std_max=args.actor_log_std_max,
            actor_update_freq=args.actor_update_freq,
            critic_lr=args.critic_lr,
            critic_beta=args.critic_beta,
            critic_tau=args.critic_tau,
            critic_target_update_freq=args.critic_target_update_freq,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            encoder_lr=args.encoder_lr,
            encoder_tau=args.encoder_tau,
            encoder_stride=args.encoder_stride,
            decoder_type=args.decoder_type,
            decoder_lr=args.decoder_lr,
            decoder_update_freq=args.decoder_update_freq,
            decoder_weight_lambda=args.decoder_weight_lambda,
            transition_model_type=args.transition_model_type,
            num_layers=args.num_layers,
            num_filters=args.num_filters,
            bisim_coef=args.bisim_coef
        )
    elif args.agent == 'deepmdp':
        agent = DeepMDPAgent(
            obs_shape=obs_shape,
            action_shape=action_shape,
            device=device,
            hidden_dim=args.hidden_dim,
            discount=args.discount,
            init_temperature=args.init_temperature,
            alpha_lr=args.alpha_lr,
            alpha_beta=args.alpha_beta,
            actor_lr=args.actor_lr,
            actor_beta=args.actor_beta,
            actor_log_std_min=args.actor_log_std_min,
            actor_log_std_max=args.actor_log_std_max,
            actor_update_freq=args.actor_update_freq,
            encoder_stride=args.encoder_stride,
            critic_lr=args.critic_lr,
            critic_beta=args.critic_beta,
            critic_tau=args.critic_tau,
            critic_target_update_freq=args.critic_target_update_freq,
            encoder_type=args.encoder_type,
            encoder_feature_dim=args.encoder_feature_dim,
            encoder_lr=args.encoder_lr,
            encoder_tau=args.encoder_tau,
            decoder_type=args.decoder_type,
            decoder_lr=args.decoder_lr,
            decoder_update_freq=args.decoder_update_freq,
            decoder_weight_lambda=args.decoder_weight_lambda,
            transition_model_type=args.transition_model_type,
            num_layers=args.num_layers,
            num_filters=args.num_filters
        )

    if args.load_encoder:
        model_dict = agent.actor.encoder.state_dict()
        encoder_dict = torch.load(args.load_encoder)
        encoder_dict = {k[8:]: v for k, v in encoder_dict.items() if 'encoder.' in k}  # hack to remove encoder. string
        agent.actor.encoder.load_state_dict(encoder_dict)
        agent.critic.encoder.load_state_dict(encoder_dict)

    return agent


def main():
    args = parse_args()
    utils.set_seed_everywhere(args.seed)

    if args.domain_name == 'carla':
        env = CarlaEnv(
            render_display=args.render,  # for local debugging only
            display_text=args.render,  # for local debugging only
            changing_weather_speed=0.1,  # [0, +inf)
            rl_image_size=args.image_size,
            max_episode_steps=1000,
            frame_skip=args.action_repeat,
            is_other_cars=True,
            port=args.port
        )
        # TODO: implement env.seed(args.seed) ?

        eval_env = env
    else:
        env = dmc2gym.make(
            domain_name=args.domain_name,
            task_name=args.task_name,
            resource_files=args.resource_files,
            img_source=args.img_source,
            total_frames=args.total_frames,
            seed=args.seed,
            visualize_reward=False,
            from_pixels=(args.encoder_type == 'pixel'),
            height=args.image_size,
            width=args.image_size,
            frame_skip=args.action_repeat
        )
        env.seed(args.seed)

        eval_env = dmc2gym.make(
            domain_name=args.domain_name,
            task_name=args.task_name,
            resource_files=args.eval_resource_files,
            img_source=args.img_source,
            total_frames=args.total_frames,
            seed=args.seed,
            visualize_reward=False,
            from_pixels=(args.encoder_type == 'pixel'),
            height=args.image_size,
            width=args.image_size,
            frame_skip=args.action_repeat
        )

    # stack several consecutive frames together
    if args.encoder_type == 'pixel':
        env = utils.FrameStack(env, k=args.frame_stack)
        eval_env = utils.FrameStack(eval_env, k=args.frame_stack)

    utils.make_dir(args.work_dir)
    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))

    video = VideoRecorder(video_dir if args.save_video else None)

    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the dmc2gym wrapper standardizes actions
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    replay_buffer = utils.ReplayBuffer(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device
    )

    agent = make_agent(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        args=args,
        device=device
    )

    L = Logger(args.work_dir, use_tb=args.save_tb)

    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    for step in range(args.num_train_steps):
        if done:
            if args.decoder_type == 'inverse':
                for i in range(1, args.k):  # fill k_obs with 0s if episode is done
                    replay_buffer.k_obses[replay_buffer.idx - i] = 0
            if step > 0:
                L.log('train/duration', time.time() - start_time, step)
                start_time = time.time()
                L.dump(step)

            # evaluate agent periodically
            if episode % args.eval_freq == 0:
                L.log('eval/episode', episode, step)
                evaluate(eval_env, agent, video, args.num_eval_episodes, L, step)
                if args.save_model:
                    agent.save(model_dir, step)
                if args.save_buffer:
                    replay_buffer.save(buffer_dir)

            L.log('train/episode_reward', episode_reward, step)

            obs = env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1
            reward = 0

            L.log('train/episode', episode, step)

        # sample action for data collection
        if step < args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs)

        # run training update
        if step >= args.init_steps:
            num_updates = args.init_steps if step == args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)

        curr_reward = reward
        next_obs, reward, done, _ = env.step(action)

        # allow infinite bootstrap
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
            done
        )
        episode_reward += reward

        replay_buffer.add(obs, action, curr_reward, reward, next_obs, done_bool)
        np.copyto(replay_buffer.k_obses[replay_buffer.idx - args.k], next_obs)

        obs = next_obs
        episode_step += 1


def collect_data(env, agent, num_rollouts, path_length, checkpoint_path):
    rollouts = []
    for i in range(num_rollouts):
        obses = []
        acs = []
        rews = []
        observation = env.reset()
        for j in range(path_length):
            action = agent.sample_action(observation)
            next_observation, reward, done, _ = env.step(action)
            obses.append(observation)
            acs.append(action)
            rews.append(reward)
            observation = next_observation
        obses.append(next_observation)
        rollouts.append((obses, acs, rews))

    from scipy.io import savemat

    savemat(
        os.path.join(checkpoint_path, "dynamics-data.mat"),
        {
            "trajs": np.array([path[0] for path in rollouts]),
            "acs": np.array([path[1] for path in rollouts]),
            "rews": np.array([path[2] for path in rollouts])
        }
    )


if __name__ == '__main__':
    main()
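The "allow infinite bootstrap" lines in main() mask out terminations that are only due to the episode time limit, so the critic keeps bootstrapping through them. A small sketch of the same rule in isolation (the numbers are illustrative, not from the source):

# time-limit endings are not treated as real terminations
max_episode_steps = 1000
episode_step, done = 999, True   # episode ends exactly at the time limit
done_bool = 0 if episode_step + 1 == max_episode_steps else float(done)
assert done_bool == 0            # keep bootstrapping past the cutoff

episode_step, done = 412, True   # e.g. a crash well before the limit
done_bool = 0 if episode_step + 1 == max_episode_steps else float(done)
assert done_bool == 1.0          # real termination: no bootstrap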
97 train_vae.py Normal file
@@ -0,0 +1,97 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import dmc2gym
import numpy as np
from torch.nn import functional as F

from encoder import make_encoder
from decoder import make_decoder
from sac_ae import weight_init
from train import parse_args
import utils


args = parse_args()
args.domain_name = 'walker'
args.task_name = 'walk'
args.image_size = 84
args.seed = 1
args.agent = 'bisim'
args.encoder_type = 'pixel'
args.action_repeat = 2
args.img_source = 'video'
args.num_layers = 4
args.num_filters = 32
args.hidden_dim = 1024
args.resource_files = '/datasets01/kinetics/070618/400/train/driving_car/*.mp4'
args.total_frames = 5000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class VAE(nn.Module):
    def __init__(self, obs_shape):
        super().__init__()
        self.encoder = make_encoder(
            encoder_type='pixel',
            obs_shape=obs_shape,
            feature_dim=100,
            num_layers=4,
            num_filters=32).to(device)

        self.decoder = make_decoder(
            'pixel', obs_shape, 50, 4, 32).to(device)
        self.decoder.apply(weight_init)

    def train(self, obs):
        h = self.encoder(obs)
        mu, log_var = h[:, :50], h[:, 50:]
        eps = torch.randn_like(mu)
        reparam = mu + torch.exp(log_var / 2) * eps
        rec_obs = torch.sigmoid(self.decoder(reparam))
        BCE = F.binary_cross_entropy(rec_obs, obs / 255, reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = BCE + KLD
        return loss


env = dmc2gym.make(
    domain_name=args.domain_name,
    task_name=args.task_name,
    resource_files=args.resource_files,
    img_source=args.img_source,
    total_frames=10,
    seed=args.seed,
    visualize_reward=False,
    from_pixels=(args.encoder_type == 'pixel'),
    height=args.image_size,
    width=args.image_size,
    frame_skip=args.action_repeat
)
env = utils.FrameStack(env, k=args.frame_stack)
vae = VAE(env.observation_space.shape)
train_dataset = torch.load('train_dataset.pt')
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
train_loader = torch.utils.data.DataLoader(train_dataset['obs'], batch_size=32, shuffle=True)

# training loop
for i in range(100):
    total_loss = []
    for obs_batch in train_loader:
        optimizer.zero_grad()
        loss = vae.train(obs_batch.to(device).float())
        loss.backward()
        optimizer.step()
        total_loss.append(loss.item())

    print(np.mean(total_loss), i)

dataset = torch.load('dataset.pt')
with torch.no_grad():
    embeddings = vae.encoder(torch.FloatTensor(dataset['obs']).to(device)).cpu().numpy()
torch.save(embeddings, 'vae_embeddings.pt')
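A short sketch of the reparameterization used in VAE.train above, on a dummy encoder output. The batch size is illustrative; the 100-dimensional feature split into a 50-dimensional mean and log-variance mirrors the feature_dim=100 encoder and 50-feature decoder above.

# reparameterization trick in isolation: z = mu + sigma * eps with eps ~ N(0, I),
# so gradients flow through mu and log_var
import torch
h = torch.randn(4, 100)                  # stand-in for the encoder output
mu, log_var = h[:, :50], h[:, 50:]
eps = torch.randn_like(mu)
z = mu + torch.exp(log_var / 2) * eps    # shape (4, 50), fed to the decoder
kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())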
104 transition_model.py Normal file
@@ -0,0 +1,104 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random
import torch
import torch.nn as nn


class DeterministicTransitionModel(nn.Module):

    def __init__(self, encoder_feature_dim, action_shape, layer_width):
        super().__init__()
        self.fc = nn.Linear(encoder_feature_dim + action_shape[0], layer_width)
        self.ln = nn.LayerNorm(layer_width)
        self.fc_mu = nn.Linear(layer_width, encoder_feature_dim)
        print("Deterministic transition model chosen.")

    def forward(self, x):
        x = self.fc(x)
        x = self.ln(x)
        x = torch.relu(x)

        mu = self.fc_mu(x)
        sigma = None
        return mu, sigma

    def sample_prediction(self, x):
        mu, sigma = self(x)
        return mu


class ProbabilisticTransitionModel(nn.Module):

    def __init__(self, encoder_feature_dim, action_shape, layer_width, announce=True, max_sigma=1e1, min_sigma=1e-4):
        super().__init__()
        self.fc = nn.Linear(encoder_feature_dim + action_shape[0], layer_width)
        self.ln = nn.LayerNorm(layer_width)
        self.fc_mu = nn.Linear(layer_width, encoder_feature_dim)
        self.fc_sigma = nn.Linear(layer_width, encoder_feature_dim)

        self.max_sigma = max_sigma
        self.min_sigma = min_sigma
        assert(self.max_sigma >= self.min_sigma)
        if announce:
            print("Probabilistic transition model chosen.")

    def forward(self, x):
        x = self.fc(x)
        x = self.ln(x)
        x = torch.relu(x)

        mu = self.fc_mu(x)
        sigma = torch.sigmoid(self.fc_sigma(x))  # range (0, 1.)
        sigma = self.min_sigma + (self.max_sigma - self.min_sigma) * sigma  # scaled range (min_sigma, max_sigma)
        return mu, sigma

    def sample_prediction(self, x):
        mu, sigma = self(x)
        eps = torch.randn_like(sigma)
        return mu + sigma * eps


class EnsembleOfProbabilisticTransitionModels(object):

    def __init__(self, encoder_feature_dim, action_shape, layer_width, ensemble_size=5):
        self.models = [ProbabilisticTransitionModel(encoder_feature_dim, action_shape, layer_width, announce=False)
                       for _ in range(ensemble_size)]
        print("Ensemble of probabilistic transition models chosen.")

    def __call__(self, x):
        mu_sigma_list = [model.forward(x) for model in self.models]
        mus, sigmas = zip(*mu_sigma_list)
        mus, sigmas = torch.stack(mus), torch.stack(sigmas)
        return mus, sigmas

    def sample_prediction(self, x):
        model = random.choice(self.models)
        return model.sample_prediction(x)

    def to(self, device):
        for model in self.models:
            model.to(device)
        return self

    def parameters(self):
        list_of_parameters = [list(model.parameters()) for model in self.models]
        parameters = [p for ps in list_of_parameters for p in ps]
        return parameters


_AVAILABLE_TRANSITION_MODELS = {'': DeterministicTransitionModel,
                                'deterministic': DeterministicTransitionModel,
                                'probabilistic': ProbabilisticTransitionModel,
                                'ensemble': EnsembleOfProbabilisticTransitionModels}


def make_transition_model(transition_model_type, encoder_feature_dim, action_shape, layer_width=512):
    assert transition_model_type in _AVAILABLE_TRANSITION_MODELS
    return _AVAILABLE_TRANSITION_MODELS[transition_model_type](
        encoder_feature_dim, action_shape, layer_width
    )
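A minimal usage sketch of the factory above. The latent size of 50 and the 2-dimensional action are assumptions for illustration; the model consumes the concatenation of a latent state and an action, as in the fc layer's input size.

# illustrative only: latent dim 50, 2-D action
import torch
from transition_model import make_transition_model

model = make_transition_model('probabilistic', encoder_feature_dim=50,
                              action_shape=(2,))
z_and_action = torch.randn(8, 50 + 2)     # concatenated [latent, action] batch
mu, sigma = model(z_and_action)            # both of shape (8, 50)
next_z = model.sample_prediction(z_and_action)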
183 utils.py Normal file
@@ -0,0 +1,183 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np
import torch.nn as nn
import gym
import os
from collections import deque
import random


class eval_mode(object):
    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        return False


def soft_update_params(net, target_net, tau):
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(
            tau * param.data + (1 - tau) * target_param.data
        )


def set_seed_everywhere(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def module_hash(module):
    result = 0
    for tensor in module.state_dict().values():
        result += tensor.sum().item()
    return result


def make_dir(dir_path):
    try:
        os.mkdir(dir_path)
    except OSError:
        pass
    return dir_path


def preprocess_obs(obs, bits=5):
    """Preprocessing image, see https://arxiv.org/abs/1807.03039."""
    bins = 2**bits
    assert obs.dtype == torch.float32
    if bits < 8:
        obs = torch.floor(obs / 2**(8 - bits))
    obs = obs / bins
    obs = obs + torch.rand_like(obs) / bins
    obs = obs - 0.5
    return obs


class ReplayBuffer(object):
    """Buffer to store environment transitions."""
    def __init__(self, obs_shape, action_shape, capacity, batch_size, device):
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device

        # the proprioceptive obs is stored as float32, pixels obs as uint8
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.k_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.curr_rewards = np.empty((capacity, 1), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)

        self.idx = 0
        self.last_save = 0
        self.full = False

    def add(self, obs, action, curr_reward, reward, next_obs, done):
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.curr_rewards[self.idx], curr_reward)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        np.copyto(self.not_dones[self.idx], not done)

        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def sample(self, k=False):
        idxs = np.random.randint(
            0, self.capacity if self.full else self.idx, size=self.batch_size
        )

        obses = torch.as_tensor(self.obses[idxs], device=self.device).float()
        actions = torch.as_tensor(self.actions[idxs], device=self.device)
        curr_rewards = torch.as_tensor(self.curr_rewards[idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
        next_obses = torch.as_tensor(
            self.next_obses[idxs], device=self.device
        ).float()
        not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
        if k:
            return obses, actions, rewards, next_obses, not_dones, torch.as_tensor(self.k_obses[idxs], device=self.device)
        return obses, actions, curr_rewards, rewards, next_obses, not_dones

    def save(self, save_dir):
        if self.idx == self.last_save:
            return
        path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
        payload = [
            self.obses[self.last_save:self.idx],
            self.next_obses[self.last_save:self.idx],
            self.actions[self.last_save:self.idx],
            self.rewards[self.last_save:self.idx],
            self.curr_rewards[self.last_save:self.idx],
            self.not_dones[self.last_save:self.idx]
        ]
        self.last_save = self.idx
        torch.save(payload, path)

    def load(self, save_dir):
        chunks = os.listdir(save_dir)
        chunks = sorted(chunks, key=lambda x: int(x.split('_')[0]))
        for chunk in chunks:
            start, end = [int(x) for x in chunk.split('.')[0].split('_')]
            path = os.path.join(save_dir, chunk)
            payload = torch.load(path)
            assert self.idx == start
            self.obses[start:end] = payload[0]
            self.next_obses[start:end] = payload[1]
            self.actions[start:end] = payload[2]
            self.rewards[start:end] = payload[3]
            self.curr_rewards[start:end] = payload[4]
            self.not_dones[start:end] = payload[5]
            self.idx = end


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self._k = k
        self._frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=((shp[0] * k,) + shp[1:]),
            dtype=env.observation_space.dtype
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        obs = self.env.reset()
        for _ in range(self._k):
            self._frames.append(obs)
        return self._get_obs()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._frames.append(obs)
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        assert len(self._frames) == self._k
        return np.concatenate(list(self._frames), axis=0)
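A hedged sketch of how the replay buffer and the soft target update above are typically used together. The observation/action shapes and tau value are illustrative, and `critic`/`critic_target` stand in for networks built elsewhere in this commit.

# illustrative only; shapes and hyperparameters are assumed
import numpy as np
import torch
from utils import ReplayBuffer, soft_update_params

buffer = ReplayBuffer(obs_shape=(9, 84, 84), action_shape=(2,),
                      capacity=100000, batch_size=128,
                      device=torch.device('cpu'))
obs = np.zeros((9, 84, 84), dtype=np.uint8)
action = np.zeros(2, dtype=np.float32)
buffer.add(obs, action, curr_reward=0.0, reward=1.0, next_obs=obs, done=False)

# after enough transitions have been added, a training step would look like:
# obses, actions, curr_rewards, rewards, next_obses, not_dones = buffer.sample()
# soft_update_params(critic, critic_target, tau=0.005)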
50 video.py Normal file
@@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import imageio
import os
import numpy as np
import glob

from dmc2gym.natural_imgsource import RandomVideoSource


class VideoRecorder(object):
    def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30):
        self.dir_name = dir_name
        self.height = height
        self.width = width
        self.camera_id = camera_id
        self.fps = fps
        self.frames = []
        if resource_files:
            files = glob.glob(os.path.expanduser(resource_files))
            self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000)
        else:
            self._bg_source = None

    def init(self, enabled=True):
        self.frames = []
        self.enabled = self.dir_name is not None and enabled

    def record(self, env):
        if self.enabled:
            frame = env.render(
                mode='rgb_array',
                height=self.height,
                width=self.width,
                camera_id=self.camera_id
            )
            if self._bg_source:
                mask = np.logical_and((frame[:, :, 2] > frame[:, :, 1]), (frame[:, :, 2] > frame[:, :, 0]))  # hardcoded for dmc
                bg = self._bg_source.get_image()
                frame[mask] = bg[mask]
            self.frames.append(frame)

    def save(self, file_name):
        if self.enabled:
            path = os.path.join(self.dir_name, file_name)
            imageio.mimsave(path, self.frames, fps=self.fps)
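A small sketch of the recorder's lifecycle, mirroring how train.py drives it during evaluation. The directory name is illustrative, and passing None as dir_name disables recording entirely.

# illustrative only; 'video_dir' is an assumed, existing directory
from video import VideoRecorder

video = VideoRecorder('video_dir')
video.init(enabled=True)        # start a fresh frame list for this episode
# inside the rollout loop: video.record(env)
video.save('0.mp4')             # writes video_dir/0.mp4 via imageio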