sac_ae_if/dmc2gym/wrappers.py
2023-05-16 12:40:47 +02:00

199 lines
6.5 KiB
Python

from gym import core, spaces
import glob
import os
import local_dm_control_suite as suite
from dm_env import specs
import numpy as np
import skimage.io
from dmc2gym import natural_imgsource
def _spec_to_box(spec):
    """Build a float32 ``spaces.Box`` covering a sequence of dm_env specs.

    Every spec is flattened to 1-D and the per-spec bounds are concatenated
    in iteration order, so the Box layout matches flattening and
    concatenating the corresponding values.

    Args:
        spec: Iterable of ``specs.Array`` / ``specs.BoundedArray`` instances
            (subclasses allowed) with a float32 or float64 dtype.

    Returns:
        ``spaces.Box`` with dtype float32 and shape ``(total_size,)``.

    Raises:
        AssertionError: If a spec does not have a float32/float64 dtype.
        TypeError: If an element is not a ``specs.Array``.
    """
    def extract_min_max(s):
        assert s.dtype == np.float64 or s.dtype == np.float32
        # Builtin int(): the np.int alias was removed in NumPy 1.24.
        dim = int(np.prod(s.shape))
        # BoundedArray subclasses Array, so it must be checked first.
        if isinstance(s, specs.BoundedArray):
            zeros = np.zeros(dim, dtype=np.float32)
            # Broadcast the bounds to the spec shape, then flatten, so both
            # scalar and per-element (possibly multi-dim) bounds work.
            low = np.broadcast_to(s.minimum, s.shape).reshape(dim) + zeros
            high = np.broadcast_to(s.maximum, s.shape).reshape(dim) + zeros
            return low, high
        if isinstance(s, specs.Array):
            bound = np.inf * np.ones(dim, dtype=np.float32)
            return -bound, bound
        raise TypeError('Unsupported spec type: {}'.format(type(s)))

    mins, maxs = [], []
    for s in spec:
        mn, mx = extract_min_max(s)
        mins.append(mn)
        maxs.append(mx)
    low = np.concatenate(mins, axis=0)
    high = np.concatenate(maxs, axis=0)
    assert low.shape == high.shape
    return spaces.Box(low, high, dtype=np.float32)
def _flatten_obs(obs):
obs_pieces = []
for v in obs.values():
flat = np.array([v]) if np.isscalar(v) else v.ravel()
obs_pieces.append(flat)
return np.concatenate(obs_pieces, axis=0)
class DMCWrapper(core.Env):
    """Gym ``core.Env`` wrapper around a ``local_dm_control_suite`` task.

    Actions are accepted in a [-1, 1] box and linearly rescaled to the
    task's true action bounds. Observations are either the flattened state
    vector or a channel-first ``(3, height, width)`` rendering. When
    ``img_source`` is set, pixels classified as background are replaced with
    frames from a ``natural_imgsource`` generator. Attributes not found on
    the wrapper are forwarded to the wrapped environment.
    """

    def __init__(
        self,
        domain_name,
        task_name,
        resource_files,
        img_source,
        total_frames,
        task_kwargs=None,
        visualize_reward=None,
        from_pixels=False,
        height=84,
        width=84,
        camera_id=0,
        frame_skip=1,
        environment_kwargs=None
    ):
        """Load the task and build the spaces and optional background source.

        Args:
            domain_name: Passed to ``suite.load``.
            task_name: Passed to ``suite.load``.
            resource_files: Glob pattern (``~`` is expanded) for background
                files; only used when ``img_source`` is ``"images"`` or
                ``"video"``.
            img_source: ``None``, ``"color"``, ``"noise"``, ``"images"`` or
                ``"video"``.
            total_frames: Forwarded to the image/video background source.
            task_kwargs: Passed to ``suite.load``. Must contain ``'random'``;
                that value also seeds the Gym spaces.
            visualize_reward: Passed to ``suite.load``. ``None`` is mapped to
                ``{}``, the value previously used as the default.
            from_pixels: Return rendered images instead of state vectors.
            height: Rendered image height.
            width: Rendered image width.
            camera_id: Camera used for pixel observations and ``render()``.
            frame_skip: Physics steps per ``step()`` call; must be >= 1.
            environment_kwargs: Passed to ``suite.load``.

        Raises:
            AssertionError: If ``task_kwargs`` lacks ``'random'`` or the
                resource pattern matches no files.
            ValueError: If ``frame_skip < 1`` or ``img_source`` is unknown.
        """
        # task_kwargs defaults to None: test that explicitly so the caller
        # sees the seed message rather than "NoneType is not iterable".
        assert task_kwargs is not None and 'random' in task_kwargs, \
            'please specify a seed, for deterministic behaviour'
        if frame_skip < 1:
            # With zero iterations step() would use an unbound time_step.
            raise ValueError(
                'frame_skip must be >= 1, got {}'.format(frame_skip)
            )
        if visualize_reward is None:
            # Fresh dict per instance instead of a shared mutable default.
            visualize_reward = {}

        self._from_pixels = from_pixels
        self._height = height
        self._width = width
        self._camera_id = camera_id
        self._frame_skip = frame_skip
        self._img_source = img_source

        # create task
        self._env = suite.load(
            domain_name=domain_name,
            task_name=task_name,
            task_kwargs=task_kwargs,
            visualize_reward=visualize_reward,
            environment_kwargs=environment_kwargs
        )

        # true and normalized action spaces
        self._true_action_space = _spec_to_box([self._env.action_spec()])
        self._norm_action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=self._true_action_space.shape,
            dtype=np.float32
        )

        # create observation space
        if from_pixels:
            self._observation_space = spaces.Box(
                low=0, high=255, shape=[3, height, width], dtype=np.uint8
            )
        else:
            self._observation_space = _spec_to_box(
                self._env.observation_spec().values()
            )

        self._internal_state_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=self._env.physics.get_state().shape,
            dtype=np.float32
        )

        # background
        if img_source is not None:
            shape2d = (height, width)
            if img_source == "color":
                self._bg_source = natural_imgsource.RandomColorSource(shape2d)
            elif img_source == "noise":
                self._bg_source = natural_imgsource.NoiseSource(shape2d)
            elif img_source in ("images", "video"):
                # Only file-backed sources need the resource glob; an
                # unknown img_source is reported before touching the disk.
                files = glob.glob(os.path.expanduser(resource_files))
                assert len(files), "Pattern {} does not match any files".format(
                    resource_files
                )
                if img_source == "images":
                    source_cls = natural_imgsource.RandomImageSource
                else:
                    source_cls = natural_imgsource.RandomVideoSource
                self._bg_source = source_cls(
                    shape2d, files, grayscale=True, total_frames=total_frames
                )
            else:
                raise ValueError("img_source %s not defined." % img_source)

        # set seed ('random' is guaranteed present by the check above)
        self.seed(seed=task_kwargs['random'])

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped environment."""
        # __getattr__ only runs when normal lookup fails. If _env itself is
        # missing (partially-built, copied or unpickled instance), raising
        # here avoids infinite recursion through self._env.
        if name == '_env':
            raise AttributeError(
                "'{}' object has no attribute '_env'".format(type(self).__name__)
            )
        return getattr(self._env, name)

    def _get_obs(self, time_step):
        """Return the observation for ``time_step``.

        In pixel mode, this renders the current frame and optionally swaps
        background pixels for the background source image. It returns a
        channel-first array, matching the observation space shape.
        Otherwise it returns the flattened state observation.
        """
        if self._from_pixels:
            obs = self.render(
                height=self._height,
                width=self._width,
                camera_id=self._camera_id
            )
            if self._img_source is not None:
                # Pixels whose blue channel exceeds both red and green are
                # treated as background (hardcoded heuristic for dmc scenes).
                mask = np.logical_and(
                    (obs[:, :, 2] > obs[:, :, 1]),
                    (obs[:, :, 2] > obs[:, :, 0])
                )
                bg = self._bg_source.get_image()
                obs[mask] = bg[mask]
            # HWC -> CHW; copy() yields a contiguous array.
            obs = obs.transpose(2, 0, 1).copy()
        else:
            obs = _flatten_obs(time_step.observation)
        return obs

    def _convert_action(self, action):
        """Linearly map ``action`` from the normalized box to the true box."""
        # asarray also accepts lists/tuples; computation runs in float64.
        action = np.asarray(action, dtype=np.float64)
        true_delta = self._true_action_space.high - self._true_action_space.low
        norm_delta = self._norm_action_space.high - self._norm_action_space.low
        action = (action - self._norm_action_space.low) / norm_delta
        action = action * true_delta + self._true_action_space.low
        action = action.astype(np.float32)
        return action

    @property
    def observation_space(self):
        """Pixel (uint8, CHW) or flat state observation space."""
        return self._observation_space

    @property
    def internal_state_space(self):
        """Unbounded box shaped like ``physics.get_state()``."""
        return self._internal_state_space

    @property
    def action_space(self):
        """Normalized [-1, 1] action space exposed to agents."""
        return self._norm_action_space

    def seed(self, seed):
        """Seed the true/normalized action spaces and the observation space."""
        self._true_action_space.seed(seed)
        self._norm_action_space.seed(seed)
        self._observation_space.seed(seed)

    def step(self, action):
        """Apply ``action`` for up to ``frame_skip`` physics steps.

        Returns:
            ``(obs, reward, done, extra)`` where ``reward`` is summed over
            the executed steps (``None`` rewards count as 0). ``extra`` holds
            ``'internal_state'``, the physics state captured *before*
            stepping, and ``'discount'`` from the last time step.
        """
        assert self._norm_action_space.contains(action)
        action = self._convert_action(action)
        assert self._true_action_space.contains(action)
        reward = 0
        extra = {'internal_state': self._env.physics.get_state().copy()}

        for _ in range(self._frame_skip):
            time_step = self._env.step(action)
            reward += time_step.reward or 0
            done = time_step.last()
            if done:
                break
        obs = self._get_obs(time_step)
        extra['discount'] = time_step.discount
        return obs, reward, done, extra

    def reset(self):
        """Reset the wrapped environment and return the first observation."""
        time_step = self._env.reset()
        obs = self._get_obs(time_step)
        return obs

    def render(self, mode='rgb_array', height=None, width=None, camera_id=None):
        """Render the scene via ``physics.render``.

        ``height``/``width`` default to the wrapper's configured size, and
        ``camera_id`` defaults to the wrapper's camera. An explicit
        ``camera_id=0`` is honoured.
        """
        assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
        height = height or self._height
        width = width or self._width
        # `is None` rather than `or`: camera 0 is a valid explicit choice.
        if camera_id is None:
            camera_id = self._camera_id
        return self._env.physics.render(
            height=height, width=width, camera_id=camera_id
        )