Adding files
This commit is contained in:
parent
4d57be91cd
commit
6c82d972f8
@ -1,36 +1,19 @@
|
||||
# This code provides the class that is used to generate backgrounds for the natural background setting
|
||||
# the class is used inside an environment wrapper and will be called each time the env generates an observation
|
||||
# the code is largely based on https://github.com/facebookresearch/deep_bisim4control
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import skvideo.io
|
||||
import random
|
||||
import tqdm
|
||||
|
||||
class BackgroundMatting(object):
|
||||
"""
|
||||
Produce a mask by masking the given color. This is a simple strategy
|
||||
but effective for many games.
|
||||
"""
|
||||
def __init__(self, color):
|
||||
"""
|
||||
Args:
|
||||
color: a (r, g, b) tuple or single value for grayscale
|
||||
"""
|
||||
self._color = color
|
||||
|
||||
def get_mask(self, img):
|
||||
return img == self._color
|
||||
import cv2
|
||||
import numpy as np
|
||||
import skvideo.io
|
||||
|
||||
|
||||
class ImageSource(object):
|
||||
"""
|
||||
Source of natural images to be added to a simulated environment.
|
||||
"""
|
||||
|
||||
def get_image(self):
|
||||
"""
|
||||
Returns:
|
||||
@ -43,141 +26,56 @@ class ImageSource(object):
|
||||
pass
|
||||
|
||||
|
||||
class FixedColorSource(ImageSource):
|
||||
def __init__(self, shape, color):
|
||||
"""
|
||||
Args:
|
||||
shape: [h, w]
|
||||
color: a 3-tuple
|
||||
"""
|
||||
self.arr = np.zeros((shape[0], shape[1], 3))
|
||||
self.arr[:, :] = color
|
||||
|
||||
def get_image(self):
|
||||
return self.arr
|
||||
|
||||
|
||||
class RandomColorSource(ImageSource):
|
||||
def __init__(self, shape):
|
||||
"""
|
||||
Args:
|
||||
shape: [h, w]
|
||||
"""
|
||||
self.shape = shape
|
||||
self.arr = None
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self._color = np.random.randint(0, 256, size=(3,))
|
||||
self.arr = np.zeros((self.shape[0], self.shape[1], 3))
|
||||
self.arr[:, :] = self._color
|
||||
|
||||
def get_image(self):
|
||||
return self.arr
|
||||
|
||||
|
||||
class NoiseSource(ImageSource):
|
||||
def __init__(self, shape, strength=255):
|
||||
"""
|
||||
Args:
|
||||
shape: [h, w]
|
||||
strength (int): the strength of noise, in range [0, 255]
|
||||
"""
|
||||
self.shape = shape
|
||||
self.strength = strength
|
||||
|
||||
def get_image(self):
|
||||
return np.random.randn(self.shape[0], self.shape[1], 3) * self.strength
|
||||
|
||||
|
||||
class RandomImageSource(ImageSource):
|
||||
def __init__(self, shape, filelist, total_frames=None, grayscale=False):
|
||||
"""
|
||||
Args:
|
||||
shape: [h, w]
|
||||
filelist: a list of image files
|
||||
"""
|
||||
self.grayscale = grayscale
|
||||
self.total_frames = total_frames
|
||||
self.shape = shape
|
||||
self.filelist = filelist
|
||||
self.build_arr()
|
||||
self.current_idx = 0
|
||||
self.reset()
|
||||
|
||||
def build_arr(self):
|
||||
self.total_frames = self.total_frames if self.total_frames else len(self.filelist)
|
||||
self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
|
||||
for i in range(self.total_frames):
|
||||
# if i % len(self.filelist) == 0: random.shuffle(self.filelist)
|
||||
fname = self.filelist[i % len(self.filelist)]
|
||||
if self.grayscale: im = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)[..., None]
|
||||
else: im = cv2.imread(fname, cv2.IMREAD_COLOR)
|
||||
self.arr[i] = cv2.resize(im, (self.shape[1], self.shape[0])) ## THIS IS NOT A BUG! cv2 uses (width, height)
|
||||
|
||||
def reset(self):
|
||||
self._loc = np.random.randint(0, self.total_frames)
|
||||
|
||||
def get_image(self):
|
||||
return self.arr[self._loc]
|
||||
|
||||
|
||||
class RandomVideoSource(ImageSource):
|
||||
def __init__(self, shape, filelist, total_frames=None, grayscale=False):
|
||||
def __init__(self, shape, filelist, random_bg=False, max_videos=50, grayscale=False):
|
||||
"""
|
||||
Args:
|
||||
shape: [h, w]
|
||||
filelist: a list of video files
|
||||
"""
|
||||
self.grayscale = grayscale
|
||||
self.total_frames = total_frames
|
||||
self.shape = shape
|
||||
self.filelist = filelist
|
||||
self.build_arr()
|
||||
random.shuffle(self.filelist)
|
||||
self.filelist = self.filelist[:max_videos]
|
||||
self.max_videos = max_videos
|
||||
self.random_bg = random_bg
|
||||
self.current_idx = 0
|
||||
self._current_vid = None
|
||||
self.reset()
|
||||
|
||||
def build_arr(self):
|
||||
if not self.total_frames:
|
||||
self.total_frames = 0
|
||||
self.arr = None
|
||||
random.shuffle(self.filelist)
|
||||
for fname in tqdm.tqdm(self.filelist, desc="Loading videos for natural", position=0):
|
||||
if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
|
||||
else: frames = skvideo.io.vread(fname)
|
||||
local_arr = np.zeros((frames.shape[0], self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
|
||||
for i in tqdm.tqdm(range(frames.shape[0]), desc="video frames", position=1):
|
||||
local_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0])) ## THIS IS NOT A BUG! cv2 uses (width, height)
|
||||
if self.arr is None:
|
||||
self.arr = local_arr
|
||||
else:
|
||||
self.arr = np.concatenate([self.arr, local_arr], 0)
|
||||
self.total_frames += local_arr.shape[0]
|
||||
else:
|
||||
self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
|
||||
total_frame_i = 0
|
||||
file_i = 0
|
||||
with tqdm.tqdm(total=self.total_frames, desc="Loading videos for natural") as pbar:
|
||||
while total_frame_i < self.total_frames:
|
||||
if file_i % len(self.filelist) == 0: random.shuffle(self.filelist)
|
||||
file_i += 1
|
||||
fname = self.filelist[file_i % len(self.filelist)]
|
||||
if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
|
||||
else: frames = skvideo.io.vread(fname)
|
||||
for frame_i in range(frames.shape[0]):
|
||||
if total_frame_i >= self.total_frames: break
|
||||
def load_video(self, vid_id):
|
||||
fname = self.filelist[vid_id]
|
||||
if self.grayscale:
|
||||
self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))[..., None] ## THIS IS NOT A BUG! cv2 uses (width, height)
|
||||
frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
|
||||
else:
|
||||
self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))
|
||||
pbar.update(1)
|
||||
total_frame_i += 1
|
||||
frames = skvideo.io.vread(fname, num_frames=1000)
|
||||
|
||||
img_arr = np.zeros((frames.shape[0], self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
|
||||
for i in range(frames.shape[0]):
|
||||
if self.grayscale:
|
||||
img_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0]))[..., None] # THIS IS NOT A BUG! cv2 uses (width, height)
|
||||
else:
|
||||
img_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0]))
|
||||
return img_arr
|
||||
|
||||
def reset(self):
|
||||
self._loc = np.random.randint(0, self.total_frames)
|
||||
del self._current_vid
|
||||
self._video_id = np.random.randint(0, len(self.filelist))
|
||||
self._current_vid = self.load_video(self._video_id)
|
||||
while True:
|
||||
try:
|
||||
self._video_id = np.random.randint(0, len(self.filelist))
|
||||
self._current_vid = self.load_video(self._video_id)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
self._loc = np.random.randint(0, len(self._current_vid))
|
||||
|
||||
def get_image(self):
|
||||
img = self.arr[self._loc % self.total_frames]
|
||||
if self.random_bg:
|
||||
self._loc = np.random.randint(0, len(self._current_vid))
|
||||
else:
|
||||
self._loc += 1
|
||||
img = self._current_vid[self._loc % len(self._current_vid)]
|
||||
return img
|
@ -8,6 +8,11 @@ import skimage.io
|
||||
|
||||
from dmc2gym import natural_imgsource
|
||||
|
||||
high_noise = False
|
||||
|
||||
def set_global_var(set_high_noise):
|
||||
global high_noise
|
||||
high_noise = set_high_noise
|
||||
|
||||
def _spec_to_box(spec):
|
||||
def extract_min_max(s):
|
||||
@ -108,13 +113,16 @@ class DMCWrapper(core.Env):
|
||||
self._bg_source = natural_imgsource.NoiseSource(shape2d)
|
||||
else:
|
||||
files = glob.glob(os.path.expanduser(resource_files))
|
||||
self.files = files
|
||||
self.total_frames = total_frames
|
||||
self.shape2d = shape2d
|
||||
assert len(files), "Pattern {} does not match any files".format(
|
||||
resource_files
|
||||
)
|
||||
if img_source == "images":
|
||||
self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames)
|
||||
self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=False, max_videos=50, random_bg=False)
|
||||
elif img_source == "video":
|
||||
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames)
|
||||
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=False,max_videos=50, random_bg=False)
|
||||
else:
|
||||
raise Exception("img_source %s not defined." % img_source)
|
||||
|
||||
@ -185,6 +193,8 @@ class DMCWrapper(core.Env):
|
||||
|
||||
def reset(self):
|
||||
time_step = self._env.reset()
|
||||
self._bg_source.reset()
|
||||
#self._bg_source = natural_imgsource.RandomVideoSource(self.shape2d, self.files, grayscale=True, total_frames=self.total_frames, high_noise=high_noise)
|
||||
obs = self._get_obs(time_step)
|
||||
return obs
|
||||
|
||||
|
BIN
local_dm_control_suite/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/acrobot.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/acrobot.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/ball_in_cup.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/ball_in_cup.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/base.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/base.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/cartpole.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/cartpole.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/cheetah.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/cheetah.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/finger.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/finger.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/fish.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/fish.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/hopper.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/hopper.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/humanoid.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/humanoid.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/humanoid_CMU.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/humanoid_CMU.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/lqr.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/lqr.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/manipulator.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/manipulator.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/pendulum.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/pendulum.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/point_mass.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/point_mass.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/quadruped.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/quadruped.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/reacher.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/reacher.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/stacker.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/stacker.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/swimmer.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/swimmer.cpython-37.pyc
Normal file
Binary file not shown.
BIN
local_dm_control_suite/__pycache__/walker.cpython-37.pyc
Normal file
BIN
local_dm_control_suite/__pycache__/walker.cpython-37.pyc
Normal file
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue
Block a user