diff --git a/dmc2gym/natural_imgsource.py b/dmc2gym/natural_imgsource.py index 42ef62f..bc85dd3 100644 --- a/dmc2gym/natural_imgsource.py +++ b/dmc2gym/natural_imgsource.py @@ -1,36 +1,19 @@ +# This code provides the class that is used to generate backgrounds for the natural background setting +# the class is used inside an environment wrapper and will be called each time the env generates an observation +# the code is largely based on https://github.com/facebookresearch/deep_bisim4control -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np -import cv2 -import skvideo.io import random -import tqdm -class BackgroundMatting(object): - """ - Produce a mask by masking the given color. This is a simple strategy - but effective for many games. - """ - def __init__(self, color): - """ - Args: - color: a (r, g, b) tuple or single value for grayscale - """ - self._color = color - - def get_mask(self, img): - return img == self._color +import cv2 +import numpy as np +import skvideo.io class ImageSource(object): """ Source of natural images to be added to a simulated environment. """ + def get_image(self): """ Returns: @@ -43,141 +26,56 @@ class ImageSource(object): pass -class FixedColorSource(ImageSource): - def __init__(self, shape, color): - """ - Args: - shape: [h, w] - color: a 3-tuple - """ - self.arr = np.zeros((shape[0], shape[1], 3)) - self.arr[:, :] = color - - def get_image(self): - return self.arr - - -class RandomColorSource(ImageSource): - def __init__(self, shape): - """ - Args: - shape: [h, w] - """ - self.shape = shape - self.arr = None - self.reset() - - def reset(self): - self._color = np.random.randint(0, 256, size=(3,)) - self.arr = np.zeros((self.shape[0], self.shape[1], 3)) - self.arr[:, :] = self._color - - def get_image(self): - return self.arr - - -class NoiseSource(ImageSource): - def __init__(self, shape, strength=255): - """ - Args: - shape: [h, w] - strength (int): the strength of noise, in range [0, 255] - """ - self.shape = shape - self.strength = strength - - def get_image(self): - return np.random.randn(self.shape[0], self.shape[1], 3) * self.strength - - -class RandomImageSource(ImageSource): - def __init__(self, shape, filelist, total_frames=None, grayscale=False): - """ - Args: - shape: [h, w] - filelist: a list of image files - """ - self.grayscale = grayscale - self.total_frames = total_frames - self.shape = shape - self.filelist = filelist - self.build_arr() - self.current_idx = 0 - self.reset() - - def build_arr(self): - self.total_frames = self.total_frames if self.total_frames else len(self.filelist) - self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,))) - for i in range(self.total_frames): - # if i % len(self.filelist) == 0: random.shuffle(self.filelist) - fname = self.filelist[i % len(self.filelist)] - if self.grayscale: im = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)[..., None] - else: im = cv2.imread(fname, cv2.IMREAD_COLOR) - self.arr[i] = cv2.resize(im, (self.shape[1], self.shape[0])) ## THIS IS NOT A BUG! cv2 uses (width, height) - - def reset(self): - self._loc = np.random.randint(0, self.total_frames) - - def get_image(self): - return self.arr[self._loc] - - class RandomVideoSource(ImageSource): - def __init__(self, shape, filelist, total_frames=None, grayscale=False): + def __init__(self, shape, filelist, random_bg=False, max_videos=50, grayscale=False): """ Args: shape: [h, w] filelist: a list of video files """ self.grayscale = grayscale - self.total_frames = total_frames self.shape = shape self.filelist = filelist - self.build_arr() + random.shuffle(self.filelist) + self.filelist = self.filelist[:max_videos] + self.max_videos = max_videos + self.random_bg = random_bg self.current_idx = 0 + self._current_vid = None self.reset() - def build_arr(self): - if not self.total_frames: - self.total_frames = 0 - self.arr = None - random.shuffle(self.filelist) - for fname in tqdm.tqdm(self.filelist, desc="Loading videos for natural", position=0): - if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"}) - else: frames = skvideo.io.vread(fname) - local_arr = np.zeros((frames.shape[0], self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,))) - for i in tqdm.tqdm(range(frames.shape[0]), desc="video frames", position=1): - local_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0])) ## THIS IS NOT A BUG! cv2 uses (width, height) - if self.arr is None: - self.arr = local_arr - else: - self.arr = np.concatenate([self.arr, local_arr], 0) - self.total_frames += local_arr.shape[0] + def load_video(self, vid_id): + fname = self.filelist[vid_id] + if self.grayscale: + frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"}) else: - self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,))) - total_frame_i = 0 - file_i = 0 - with tqdm.tqdm(total=self.total_frames, desc="Loading videos for natural") as pbar: - while total_frame_i < self.total_frames: - if file_i % len(self.filelist) == 0: random.shuffle(self.filelist) - file_i += 1 - fname = self.filelist[file_i % len(self.filelist)] - if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"}) - else: frames = skvideo.io.vread(fname) - for frame_i in range(frames.shape[0]): - if total_frame_i >= self.total_frames: break - if self.grayscale: - self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))[..., None] ## THIS IS NOT A BUG! cv2 uses (width, height) - else: - self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0])) - pbar.update(1) - total_frame_i += 1 + frames = skvideo.io.vread(fname, num_frames=1000) + img_arr = np.zeros((frames.shape[0], self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,))) + for i in range(frames.shape[0]): + if self.grayscale: + img_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0]))[..., None] # THIS IS NOT A BUG! cv2 uses (width, height) + else: + img_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0])) + return img_arr def reset(self): - self._loc = np.random.randint(0, self.total_frames) + del self._current_vid + self._video_id = np.random.randint(0, len(self.filelist)) + self._current_vid = self.load_video(self._video_id) + while True: + try: + self._video_id = np.random.randint(0, len(self.filelist)) + self._current_vid = self.load_video(self._video_id) + break + except Exception: + continue + self._loc = np.random.randint(0, len(self._current_vid)) def get_image(self): - img = self.arr[self._loc % self.total_frames] - self._loc += 1 - return img + if self.random_bg: + self._loc = np.random.randint(0, len(self._current_vid)) + else: + self._loc += 1 + img = self._current_vid[self._loc % len(self._current_vid)] + return img \ No newline at end of file diff --git a/dmc2gym/wrappers.py b/dmc2gym/wrappers.py index 077f2eb..887b7e6 100644 --- a/dmc2gym/wrappers.py +++ b/dmc2gym/wrappers.py @@ -8,6 +8,11 @@ import skimage.io from dmc2gym import natural_imgsource +high_noise = False + +def set_global_var(set_high_noise): + global high_noise + high_noise = set_high_noise def _spec_to_box(spec): def extract_min_max(s): @@ -108,13 +113,16 @@ class DMCWrapper(core.Env): self._bg_source = natural_imgsource.NoiseSource(shape2d) else: files = glob.glob(os.path.expanduser(resource_files)) + self.files = files + self.total_frames = total_frames + self.shape2d = shape2d assert len(files), "Pattern {} does not match any files".format( resource_files ) if img_source == "images": - self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames) + self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=False, max_videos=50, random_bg=False) elif img_source == "video": - self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames) + self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=False,max_videos=50, random_bg=False) else: raise Exception("img_source %s not defined." % img_source) @@ -185,6 +193,8 @@ class DMCWrapper(core.Env): def reset(self): time_step = self._env.reset() + self._bg_source.reset() + #self._bg_source = natural_imgsource.RandomVideoSource(self.shape2d, self.files, grayscale=True, total_frames=self.total_frames, high_noise=high_noise) obs = self._get_obs(time_step) return obs diff --git a/local_dm_control_suite/__pycache__/__init__.cpython-37.pyc b/local_dm_control_suite/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..c78a7fd Binary files /dev/null and b/local_dm_control_suite/__pycache__/__init__.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/acrobot.cpython-37.pyc b/local_dm_control_suite/__pycache__/acrobot.cpython-37.pyc new file mode 100644 index 0000000..35d1caa Binary files /dev/null and b/local_dm_control_suite/__pycache__/acrobot.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/ball_in_cup.cpython-37.pyc b/local_dm_control_suite/__pycache__/ball_in_cup.cpython-37.pyc new file mode 100644 index 0000000..2fbd4da Binary files /dev/null and b/local_dm_control_suite/__pycache__/ball_in_cup.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/base.cpython-37.pyc b/local_dm_control_suite/__pycache__/base.cpython-37.pyc new file mode 100644 index 0000000..8dcfba2 Binary files /dev/null and b/local_dm_control_suite/__pycache__/base.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/cartpole.cpython-37.pyc b/local_dm_control_suite/__pycache__/cartpole.cpython-37.pyc new file mode 100644 index 0000000..2502c7d Binary files /dev/null and b/local_dm_control_suite/__pycache__/cartpole.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/cheetah.cpython-37.pyc b/local_dm_control_suite/__pycache__/cheetah.cpython-37.pyc new file mode 100644 index 0000000..842360a Binary files /dev/null and b/local_dm_control_suite/__pycache__/cheetah.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/finger.cpython-37.pyc b/local_dm_control_suite/__pycache__/finger.cpython-37.pyc new file mode 100644 index 0000000..ac75337 Binary files /dev/null and b/local_dm_control_suite/__pycache__/finger.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/fish.cpython-37.pyc b/local_dm_control_suite/__pycache__/fish.cpython-37.pyc new file mode 100644 index 0000000..f053cc1 Binary files /dev/null and b/local_dm_control_suite/__pycache__/fish.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/hopper.cpython-37.pyc b/local_dm_control_suite/__pycache__/hopper.cpython-37.pyc new file mode 100644 index 0000000..99bac8a Binary files /dev/null and b/local_dm_control_suite/__pycache__/hopper.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/humanoid.cpython-37.pyc b/local_dm_control_suite/__pycache__/humanoid.cpython-37.pyc new file mode 100644 index 0000000..42dc7c4 Binary files /dev/null and b/local_dm_control_suite/__pycache__/humanoid.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/humanoid_CMU.cpython-37.pyc b/local_dm_control_suite/__pycache__/humanoid_CMU.cpython-37.pyc new file mode 100644 index 0000000..19273d9 Binary files /dev/null and b/local_dm_control_suite/__pycache__/humanoid_CMU.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/lqr.cpython-37.pyc b/local_dm_control_suite/__pycache__/lqr.cpython-37.pyc new file mode 100644 index 0000000..3d39637 Binary files /dev/null and b/local_dm_control_suite/__pycache__/lqr.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/manipulator.cpython-37.pyc b/local_dm_control_suite/__pycache__/manipulator.cpython-37.pyc new file mode 100644 index 0000000..33a144f Binary files /dev/null and b/local_dm_control_suite/__pycache__/manipulator.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/pendulum.cpython-37.pyc b/local_dm_control_suite/__pycache__/pendulum.cpython-37.pyc new file mode 100644 index 0000000..c68f0d6 Binary files /dev/null and b/local_dm_control_suite/__pycache__/pendulum.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/point_mass.cpython-37.pyc b/local_dm_control_suite/__pycache__/point_mass.cpython-37.pyc new file mode 100644 index 0000000..eaa9839 Binary files /dev/null and b/local_dm_control_suite/__pycache__/point_mass.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/quadruped.cpython-37.pyc b/local_dm_control_suite/__pycache__/quadruped.cpython-37.pyc new file mode 100644 index 0000000..533a47b Binary files /dev/null and b/local_dm_control_suite/__pycache__/quadruped.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/reacher.cpython-37.pyc b/local_dm_control_suite/__pycache__/reacher.cpython-37.pyc new file mode 100644 index 0000000..b2012fe Binary files /dev/null and b/local_dm_control_suite/__pycache__/reacher.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/stacker.cpython-37.pyc b/local_dm_control_suite/__pycache__/stacker.cpython-37.pyc new file mode 100644 index 0000000..6bcea80 Binary files /dev/null and b/local_dm_control_suite/__pycache__/stacker.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/swimmer.cpython-37.pyc b/local_dm_control_suite/__pycache__/swimmer.cpython-37.pyc new file mode 100644 index 0000000..499855e Binary files /dev/null and b/local_dm_control_suite/__pycache__/swimmer.cpython-37.pyc differ diff --git a/local_dm_control_suite/__pycache__/walker.cpython-37.pyc b/local_dm_control_suite/__pycache__/walker.cpython-37.pyc new file mode 100644 index 0000000..aba2ec4 Binary files /dev/null and b/local_dm_control_suite/__pycache__/walker.cpython-37.pyc differ diff --git a/local_dm_control_suite/common/__pycache__/__init__.cpython-37.pyc b/local_dm_control_suite/common/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..98acd29 Binary files /dev/null and b/local_dm_control_suite/common/__pycache__/__init__.cpython-37.pyc differ