209 lines
6.9 KiB
Python
209 lines
6.9 KiB
Python
from gym import core, spaces
|
|
import glob
|
|
import os
|
|
import local_dm_control_suite as suite
|
|
from dm_env import specs
|
|
import numpy as np
|
|
import skimage.io
|
|
|
|
from dmc2gym import natural_imgsource
|
|
|
|
# Module-wide "high noise" switch, written only through ``set_global_var``.
# NOTE(review): no live code in this module reads this flag -- confirm it is
# still consumed somewhere before relying on it.
high_noise = False


def set_global_var(set_high_noise):
    """Replace the module-level ``high_noise`` flag with ``set_high_noise``.

    The value is stored as given (no conversion to ``bool``).
    """
    global high_noise
    high_noise = set_high_noise
|
|
|
|
def _spec_to_box(spec):
    """Build a float32 ``gym`` Box covering a sequence of ``dm_env`` specs.

    Each spec is flattened (C order) and the per-spec bounds are concatenated
    in iteration order.

    Args:
        spec: iterable of ``specs.Array`` / ``specs.BoundedArray`` objects
            whose dtype is float32 or float64.

    Returns:
        ``spaces.Box`` with dtype float32. Unbounded ``Array`` specs contribute
        ``-inf`` / ``+inf``; ``BoundedArray`` specs contribute their
        ``minimum`` / ``maximum`` broadcast to the spec's shape.

    Raises:
        AssertionError: if a spec's dtype is not float32/float64.
        TypeError: if a spec is neither an ``Array`` nor a ``BoundedArray``.
    """

    def extract_min_max(s):
        assert s.dtype == np.float64 or s.dtype == np.float32
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` replaces it.
        dim = int(np.prod(s.shape))
        # ``BoundedArray`` subclasses ``Array``, so it must be checked first.
        if isinstance(s, specs.BoundedArray):
            # Bounds may be scalars or full arrays; broadcasting to the spec
            # shape before flattening handles both, including multi-dim specs.
            low = np.broadcast_to(s.minimum, s.shape).ravel()
            high = np.broadcast_to(s.maximum, s.shape).ravel()
            return low, high
        if isinstance(s, specs.Array):
            bound = np.inf * np.ones(dim, dtype=np.float32)
            return -bound, bound
        raise TypeError('Unsupported spec type: {}'.format(type(s)))

    mins, maxs = [], []
    for s in spec:
        mn, mx = extract_min_max(s)
        mins.append(mn)
        maxs.append(mx)
    low = np.concatenate(mins, axis=0)
    high = np.concatenate(maxs, axis=0)
    assert low.shape == high.shape
    return spaces.Box(low, high, dtype=np.float32)
|
|
|
|
|
|
def _flatten_obs(obs):
|
|
obs_pieces = []
|
|
for v in obs.values():
|
|
flat = np.array([v]) if np.isscalar(v) else v.ravel()
|
|
obs_pieces.append(flat)
|
|
return np.concatenate(obs_pieces, axis=0)
|
|
|
|
|
|
class DMCWrapper(core.Env):
    """Gym ``Env`` adapter around a ``local_dm_control_suite`` task.

    Provides:
      * a normalized ``[-1, 1]`` action space that is rescaled to the task's
        true action bounds before stepping;
      * either flattened state observations or channel-first uint8 pixel
        observations, optionally with background pixels replaced;
      * action repeat (``frame_skip``).

    Attributes not defined on this class are looked up on the wrapped
    dm_control environment.
    """

    def __init__(
        self,
        domain_name,
        task_name,
        resource_files,
        img_source,
        total_frames,
        task_kwargs=None,
        visualize_reward=False,
        from_pixels=False,
        height=84,
        width=84,
        camera_id=0,
        frame_skip=1,
        environment_kwargs=None
    ):
        """Load the task and build the action/observation spaces.

        Args:
            domain_name: dm_control domain passed to ``suite.load``.
            task_name: task within ``domain_name``.
            resource_files: glob pattern (``~`` is expanded) for background
                files; only used when ``img_source`` is not ``None``,
                ``"color"`` or ``"noise"``.
            img_source: ``None`` (no background replacement), ``"color"``,
                ``"noise"``, ``"images"`` or ``"video"``.
            total_frames: stored as ``self.total_frames`` for file-based
                sources; not otherwise used by this class.
            task_kwargs: forwarded to ``suite.load``; must contain a
                ``'random'`` entry, which is also used to seed the spaces.
            visualize_reward: forwarded to ``suite.load``. The previous
                default was ``{}``, a falsy mutable object; ``False`` keeps
                the falsy meaning without a shared mutable default.
            from_pixels: if true, observations are rendered images of shape
                ``(3, height, width)``; otherwise flattened state vectors.
            height: rendered image height in pixels.
            width: rendered image width in pixels.
            camera_id: camera used for pixel observations and the default
                camera for ``render``.
            frame_skip: number of times ``step`` repeats each action.
            environment_kwargs: forwarded to ``suite.load``.

        Raises:
            AssertionError: if ``task_kwargs`` lacks a ``'random'`` entry, or
                a file-based ``img_source`` pattern matches no files.
            Exception: if ``img_source`` is not a supported value.
        """
        # Check for None explicitly: ``'random' in None`` would otherwise fail
        # with an unrelated TypeError instead of this message.
        assert task_kwargs is not None and 'random' in task_kwargs, \
            'please specify a seed, for deterministic behaviour'

        self._from_pixels = from_pixels
        self._height = height
        self._width = width
        self._camera_id = camera_id
        self._frame_skip = frame_skip
        self._img_source = img_source

        # create task
        self._env = suite.load(
            domain_name=domain_name,
            task_name=task_name,
            task_kwargs=task_kwargs,
            visualize_reward=visualize_reward,
            environment_kwargs=environment_kwargs
        )

        # true and normalized action spaces
        self._true_action_space = _spec_to_box([self._env.action_spec()])
        self._norm_action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=self._true_action_space.shape,
            dtype=np.float32
        )

        # create observation space
        if from_pixels:
            self._observation_space = spaces.Box(
                low=0, high=255, shape=[3, height, width], dtype=np.uint8
            )
        else:
            self._observation_space = _spec_to_box(
                self._env.observation_spec().values()
            )

        # Unbounded space matching the physics state vector returned in
        # ``step``'s ``extra['internal_state']``.
        self._internal_state_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=self._env.physics.get_state().shape,
            dtype=np.float32
        )

        # background
        if img_source is not None:
            shape2d = (height, width)
            if img_source == "color":
                self._bg_source = natural_imgsource.RandomColorSource(shape2d)
            elif img_source == "noise":
                self._bg_source = natural_imgsource.NoiseSource(shape2d)
            else:
                files = glob.glob(os.path.expanduser(resource_files))
                self.files = files
                self.total_frames = total_frames
                self.shape2d = shape2d
                assert len(files), "Pattern {} does not match any files".format(
                    resource_files
                )
                if img_source == "images":
                    self._bg_source = natural_imgsource.RandomImageSource(
                        shape2d, files, grayscale=False, max_videos=50, random_bg=False
                    )
                elif img_source == "video":
                    self._bg_source = natural_imgsource.RandomVideoSource(
                        shape2d, files, grayscale=False, max_videos=50, random_bg=False
                    )
                else:
                    raise Exception("img_source %s not defined." % img_source)

        # set seed ('random' is guaranteed present by the assert above)
        self.seed(seed=task_kwargs['random'])

    def __getattr__(self, name):
        """Delegate unknown attribute lookups to the wrapped dm_control env."""
        # ``__getattr__`` only runs when normal lookup fails. If ``_env`` itself
        # is missing (e.g. ``suite.load`` raised in ``__init__``, or the object
        # is being unpickled), ``self._env`` below would recurse forever.
        if name == '_env':
            raise AttributeError(name)
        return getattr(self._env, name)

    def _get_obs(self, time_step):
        """Build the observation returned by ``reset`` / ``step``.

        With ``from_pixels``, renders an ``(height, width, channels)`` image,
        optionally overwrites background pixels, and returns a channel-first
        copy. Otherwise returns ``time_step.observation`` flattened to 1-D.
        """
        if self._from_pixels:
            obs = self.render(
                height=self._height,
                width=self._width,
                camera_id=self._camera_id
            )
            if self._img_source is not None:
                # Pixels whose channel 2 strictly exceeds channels 0 and 1 are
                # treated as background ("hardcoded for dmc").
                # NOTE(review): assumes RGB order and that get_image() returns
                # an array matching obs's (H, W, C) shape -- confirm.
                mask = np.logical_and(
                    obs[:, :, 2] > obs[:, :, 1], obs[:, :, 2] > obs[:, :, 0]
                )
                bg = self._bg_source.get_image()
                obs[mask] = bg[mask]
            # HWC -> CHW; copy() yields a contiguous array.
            obs = obs.transpose(2, 0, 1).copy()
        else:
            obs = _flatten_obs(time_step.observation)
        return obs

    def _convert_action(self, action):
        """Linearly map ``action`` from the normalized space to the true bounds.

        The arithmetic is done in float64 and the result is cast to float32.
        """
        action = action.astype(np.float64)
        true_delta = self._true_action_space.high - self._true_action_space.low
        norm_delta = self._norm_action_space.high - self._norm_action_space.low
        action = (action - self._norm_action_space.low) / norm_delta
        action = action * true_delta + self._true_action_space.low
        action = action.astype(np.float32)
        return action

    @property
    def observation_space(self):
        """Pixel (uint8, CHW) or flattened-state (float32) observation space."""
        return self._observation_space

    @property
    def internal_state_space(self):
        """Unbounded space shaped like the physics state vector."""
        return self._internal_state_space

    @property
    def action_space(self):
        """Normalized ``[-1, 1]`` action space exposed to agents."""
        return self._norm_action_space

    def seed(self, seed):
        """Seed the true/normalized action spaces and the observation space.

        The wrapped environment itself is seeded via ``task_kwargs['random']``
        at construction time, not here.
        """
        self._true_action_space.seed(seed)
        self._norm_action_space.seed(seed)
        self._observation_space.seed(seed)

    def step(self, action):
        """Apply ``action`` for up to ``frame_skip`` environment steps.

        Args:
            action: array contained in the normalized ``[-1, 1]`` action space.

        Returns:
            ``(obs, reward, done, extra)``. ``reward`` is the sum of per-step
            rewards (a ``None`` reward counts as 0). ``done`` reports whether
            the last executed time step was terminal; the loop stops early on
            it. ``extra`` holds ``'internal_state'`` (a copy of the physics
            state taken *before* the action is applied) and ``'discount'``
            (the discount of the last executed time step).
        """
        assert self._norm_action_space.contains(action)
        action = self._convert_action(action)
        assert self._true_action_space.contains(action)
        reward = 0
        extra = {'internal_state': self._env.physics.get_state().copy()}

        for _ in range(self._frame_skip):
            time_step = self._env.step(action)
            reward += time_step.reward or 0
            done = time_step.last()
            if done:
                break
        obs = self._get_obs(time_step)
        extra['discount'] = time_step.discount
        return obs, reward, done, extra

    def reset(self):
        """Reset the task (and the background source, if any); return the first obs."""
        time_step = self._env.reset()
        # ``_bg_source`` exists only when an ``img_source`` was configured;
        # calling it unconditionally made reset() raise for img_source=None.
        if self._img_source is not None:
            self._bg_source.reset()
        obs = self._get_obs(time_step)
        return obs

    def render(self, mode='rgb_array', height=None, width=None, camera_id=None):
        """Render the physics as an RGB array.

        Args:
            mode: only ``'rgb_array'`` is supported.
            height: image height; falls back to the constructor's ``height``
                when falsy.
            width: image width; falls back to the constructor's ``width``
                when falsy.
            camera_id: camera to render; ``None`` selects the constructor's
                ``camera_id``.
        """
        assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
        height = height or self._height
        width = width or self._width
        # ``camera_id or self._camera_id`` would silently replace an explicit
        # camera 0 with the default camera, so fall back only on ``None``.
        if camera_id is None:
            camera_id = self._camera_id
        return self._env.physics.render(
            height=height, width=width, camera_id=camera_id
        )
|