Adding Files
This commit is contained in:
parent
a8b9de1e7e
commit
bd4410e9d0
@ -1,36 +1,19 @@
|
|||||||
|
# This code provides the class that is used to generate backgrounds for the natural background setting
|
||||||
|
# the class is used inside an environment wrapper and will be called each time the env generates an observation
|
||||||
|
# the code is largely based on https://github.com/facebookresearch/deep_bisim4control
|
||||||
|
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
import skvideo.io
|
|
||||||
import random
|
import random
|
||||||
import tqdm
|
|
||||||
|
|
||||||
class BackgroundMatting(object):
|
import cv2
|
||||||
"""
|
import numpy as np
|
||||||
Produce a mask by masking the given color. This is a simple strategy
|
import skvideo.io
|
||||||
but effective for many games.
|
|
||||||
"""
|
|
||||||
def __init__(self, color):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
color: a (r, g, b) tuple or single value for grayscale
|
|
||||||
"""
|
|
||||||
self._color = color
|
|
||||||
|
|
||||||
def get_mask(self, img):
|
|
||||||
return img == self._color
|
|
||||||
|
|
||||||
|
|
||||||
class ImageSource(object):
|
class ImageSource(object):
|
||||||
"""
|
"""
|
||||||
Source of natural images to be added to a simulated environment.
|
Source of natural images to be added to a simulated environment.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_image(self):
|
def get_image(self):
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
@ -43,141 +26,57 @@ class ImageSource(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FixedColorSource(ImageSource):
|
|
||||||
def __init__(self, shape, color):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
shape: [h, w]
|
|
||||||
color: a 3-tuple
|
|
||||||
"""
|
|
||||||
self.arr = np.zeros((shape[0], shape[1], 3))
|
|
||||||
self.arr[:, :] = color
|
|
||||||
|
|
||||||
def get_image(self):
|
|
||||||
return self.arr
|
|
||||||
|
|
||||||
|
|
||||||
class RandomColorSource(ImageSource):
|
|
||||||
def __init__(self, shape):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
shape: [h, w]
|
|
||||||
"""
|
|
||||||
self.shape = shape
|
|
||||||
self.arr = None
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self._color = np.random.randint(0, 256, size=(3,))
|
|
||||||
self.arr = np.zeros((self.shape[0], self.shape[1], 3))
|
|
||||||
self.arr[:, :] = self._color
|
|
||||||
|
|
||||||
def get_image(self):
|
|
||||||
return self.arr
|
|
||||||
|
|
||||||
|
|
||||||
class NoiseSource(ImageSource):
|
|
||||||
def __init__(self, shape, strength=255):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
shape: [h, w]
|
|
||||||
strength (int): the strength of noise, in range [0, 255]
|
|
||||||
"""
|
|
||||||
self.shape = shape
|
|
||||||
self.strength = strength
|
|
||||||
|
|
||||||
def get_image(self):
|
|
||||||
return np.random.randn(self.shape[0], self.shape[1], 3) * self.strength
|
|
||||||
|
|
||||||
|
|
||||||
class RandomImageSource(ImageSource):
|
|
||||||
def __init__(self, shape, filelist, total_frames=None, grayscale=False):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
shape: [h, w]
|
|
||||||
filelist: a list of image files
|
|
||||||
"""
|
|
||||||
self.grayscale = grayscale
|
|
||||||
self.total_frames = total_frames
|
|
||||||
self.shape = shape
|
|
||||||
self.filelist = filelist
|
|
||||||
self.build_arr()
|
|
||||||
self.current_idx = 0
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def build_arr(self):
|
|
||||||
self.total_frames = self.total_frames if self.total_frames else len(self.filelist)
|
|
||||||
self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
|
|
||||||
for i in range(self.total_frames):
|
|
||||||
# if i % len(self.filelist) == 0: random.shuffle(self.filelist)
|
|
||||||
fname = self.filelist[i % len(self.filelist)]
|
|
||||||
if self.grayscale: im = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)[..., None]
|
|
||||||
else: im = cv2.imread(fname, cv2.IMREAD_COLOR)
|
|
||||||
self.arr[i] = cv2.resize(im, (self.shape[1], self.shape[0])) ## THIS IS NOT A BUG! cv2 uses (width, height)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self._loc = np.random.randint(0, self.total_frames)
|
|
||||||
|
|
||||||
def get_image(self):
|
|
||||||
return self.arr[self._loc]
|
|
||||||
|
|
||||||
|
|
||||||
class RandomVideoSource(ImageSource):
|
class RandomVideoSource(ImageSource):
|
||||||
def __init__(self, shape, filelist, total_frames=None, grayscale=False):
|
def __init__(self, shape, filelist, random_bg=False, max_videos=100, grayscale=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
shape: [h, w]
|
shape: [h, w]
|
||||||
filelist: a list of video files
|
filelist: a list of video files
|
||||||
"""
|
"""
|
||||||
self.grayscale = grayscale
|
self.grayscale = grayscale
|
||||||
self.total_frames = total_frames
|
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
self.filelist = filelist
|
self.filelist = filelist
|
||||||
self.build_arr()
|
random.shuffle(self.filelist)
|
||||||
|
self.filelist = self.filelist[:max_videos]
|
||||||
|
self.max_videos = max_videos
|
||||||
|
self.random_bg = random_bg
|
||||||
self.current_idx = 0
|
self.current_idx = 0
|
||||||
|
self._current_vid = None
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def build_arr(self):
|
def load_video(self, vid_id):
|
||||||
if not self.total_frames:
|
fname = self.filelist[vid_id]
|
||||||
self.total_frames = 0
|
|
||||||
self.arr = None
|
|
||||||
random.shuffle(self.filelist)
|
|
||||||
for fname in tqdm.tqdm(self.filelist, desc="Loading videos for natural", position=0):
|
|
||||||
if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
|
|
||||||
else: frames = skvideo.io.vread(fname)
|
|
||||||
local_arr = np.zeros((frames.shape[0], self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
|
|
||||||
for i in tqdm.tqdm(range(frames.shape[0]), desc="video frames", position=1):
|
|
||||||
local_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0])) ## THIS IS NOT A BUG! cv2 uses (width, height)
|
|
||||||
if self.arr is None:
|
|
||||||
self.arr = local_arr
|
|
||||||
else:
|
|
||||||
self.arr = np.concatenate([self.arr, local_arr], 0)
|
|
||||||
self.total_frames += local_arr.shape[0]
|
|
||||||
else:
|
|
||||||
self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
|
|
||||||
total_frame_i = 0
|
|
||||||
file_i = 0
|
|
||||||
with tqdm.tqdm(total=self.total_frames, desc="Loading videos for natural") as pbar:
|
|
||||||
while total_frame_i < self.total_frames:
|
|
||||||
if file_i % len(self.filelist) == 0: random.shuffle(self.filelist)
|
|
||||||
file_i += 1
|
|
||||||
fname = self.filelist[file_i % len(self.filelist)]
|
|
||||||
if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
|
|
||||||
else: frames = skvideo.io.vread(fname)
|
|
||||||
for frame_i in range(frames.shape[0]):
|
|
||||||
if total_frame_i >= self.total_frames: break
|
|
||||||
if self.grayscale:
|
|
||||||
self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))[..., None] ## THIS IS NOT A BUG! cv2 uses (width, height)
|
|
||||||
else:
|
|
||||||
self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))
|
|
||||||
pbar.update(1)
|
|
||||||
total_frame_i += 1
|
|
||||||
|
|
||||||
|
if self.grayscale:
|
||||||
|
frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
|
||||||
|
else:
|
||||||
|
frames = skvideo.io.vread(fname, num_frames=1000)
|
||||||
|
|
||||||
|
img_arr = np.zeros((frames.shape[0], self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
|
||||||
|
for i in range(frames.shape[0]):
|
||||||
|
if self.grayscale:
|
||||||
|
img_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0]))[..., None] # THIS IS NOT A BUG! cv2 uses (width, height)
|
||||||
|
else:
|
||||||
|
img_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0]))
|
||||||
|
return img_arr
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self._loc = np.random.randint(0, self.total_frames)
|
del self._current_vid
|
||||||
|
self._video_id = np.random.randint(0, len(self.filelist))
|
||||||
|
self._current_vid = self.load_video(self._video_id)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
self._video_id = np.random.randint(0, len(self.filelist))
|
||||||
|
self._current_vid = self.load_video(self._video_id)
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
self._loc = np.random.randint(0, len(self._current_vid))
|
||||||
|
|
||||||
def get_image(self):
|
def get_image(self):
|
||||||
img = self.arr[self._loc % self.total_frames]
|
if self.random_bg:
|
||||||
self._loc += 1
|
self._loc = np.random.randint(0, len(self._current_vid))
|
||||||
|
else:
|
||||||
|
self._loc += 1
|
||||||
|
img = self._current_vid[self._loc % len(self._current_vid)]
|
||||||
return img
|
return img
|
@ -8,6 +8,11 @@ import skimage.io
|
|||||||
|
|
||||||
from dmc2gym import natural_imgsource
|
from dmc2gym import natural_imgsource
|
||||||
|
|
||||||
|
high_noise = False
|
||||||
|
|
||||||
|
def set_global_var(set_high_noise):
|
||||||
|
global high_noise
|
||||||
|
high_noise = set_high_noise
|
||||||
|
|
||||||
def _spec_to_box(spec):
|
def _spec_to_box(spec):
|
||||||
def extract_min_max(s):
|
def extract_min_max(s):
|
||||||
@ -63,7 +68,6 @@ class DMCWrapper(core.Env):
|
|||||||
self._camera_id = camera_id
|
self._camera_id = camera_id
|
||||||
self._frame_skip = frame_skip
|
self._frame_skip = frame_skip
|
||||||
self._img_source = img_source
|
self._img_source = img_source
|
||||||
self._resource_files = resource_files
|
|
||||||
|
|
||||||
# create task
|
# create task
|
||||||
self._env = suite.load(
|
self._env = suite.load(
|
||||||
@ -109,13 +113,16 @@ class DMCWrapper(core.Env):
|
|||||||
self._bg_source = natural_imgsource.NoiseSource(shape2d)
|
self._bg_source = natural_imgsource.NoiseSource(shape2d)
|
||||||
else:
|
else:
|
||||||
files = glob.glob(os.path.expanduser(resource_files))
|
files = glob.glob(os.path.expanduser(resource_files))
|
||||||
|
self.files = files
|
||||||
|
self.total_frames = total_frames
|
||||||
|
self.shape2d = shape2d
|
||||||
assert len(files), "Pattern {} does not match any files".format(
|
assert len(files), "Pattern {} does not match any files".format(
|
||||||
resource_files
|
resource_files
|
||||||
)
|
)
|
||||||
if img_source == "images":
|
if img_source == "images":
|
||||||
self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames)
|
self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=False, max_videos=100, random_bg=False)
|
||||||
elif img_source == "video":
|
elif img_source == "video":
|
||||||
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames)
|
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=False,max_videos=100, random_bg=False)
|
||||||
else:
|
else:
|
||||||
raise Exception("img_source %s not defined." % img_source)
|
raise Exception("img_source %s not defined." % img_source)
|
||||||
|
|
||||||
@ -136,9 +143,7 @@ class DMCWrapper(core.Env):
|
|||||||
mask = np.logical_and((obs[:, :, 2] > obs[:, :, 1]), (obs[:, :, 2] > obs[:, :, 0])) # hardcoded for dmc
|
mask = np.logical_and((obs[:, :, 2] > obs[:, :, 1]), (obs[:, :, 2] > obs[:, :, 0])) # hardcoded for dmc
|
||||||
bg = self._bg_source.get_image()
|
bg = self._bg_source.get_image()
|
||||||
obs[mask] = bg[mask]
|
obs[mask] = bg[mask]
|
||||||
# obs = obs.transpose(2, 0, 1).copy()
|
obs = obs.transpose(2, 0, 1).copy()
|
||||||
# CHW to HWC for tensorflow
|
|
||||||
obs = obs.copy()
|
|
||||||
else:
|
else:
|
||||||
obs = _flatten_obs(time_step.observation)
|
obs = _flatten_obs(time_step.observation)
|
||||||
return obs
|
return obs
|
||||||
@ -188,6 +193,8 @@ class DMCWrapper(core.Env):
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
time_step = self._env.reset()
|
time_step = self._env.reset()
|
||||||
|
self._bg_source.reset()
|
||||||
|
#self._bg_source = natural_imgsource.RandomVideoSource(self.shape2d, self.files, grayscale=True, total_frames=self.total_frames, high_noise=high_noise)
|
||||||
obs = self._get_obs(time_step)
|
obs = self._get_obs(time_step)
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
|
@ -96,11 +96,12 @@ class Dreamer(tools.Module):
|
|||||||
|
|
||||||
@tf.function()
|
@tf.function()
|
||||||
def train(self, data, log_images=False):
|
def train(self, data, log_images=False):
|
||||||
self._strategy.experimental_run_v2(
|
self._strategy.run(
|
||||||
self._train, args=(data, log_images))
|
self._train, args=(data, log_images))
|
||||||
|
|
||||||
def _train(self, data, log_images):
|
def _train(self, data, log_images):
|
||||||
with tf.GradientTape() as model_tape:
|
with tf.GradientTape() as model_tape:
|
||||||
|
data["image"] = tf.transpose(data["image"], perm=[0, 1, 3, 4, 2])
|
||||||
embed = self._encode(data)
|
embed = self._encode(data)
|
||||||
post, prior = self._dynamics.observe(embed, data['action'])
|
post, prior = self._dynamics.observe(embed, data['action'])
|
||||||
feat = self._dynamics.get_feat(post)
|
feat = self._dynamics.get_feat(post)
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
<option timestep="0.01"/>
|
<option timestep="0.01"/>
|
||||||
|
|
||||||
<worldbody>
|
<worldbody>
|
||||||
<geom name="ground" type="plane" conaffinity="1" pos="98 0 0" size="100 .8 .5" material="grid"/>
|
<geom name="ground" type="plane" conaffinity="1" pos="98 0 0" size="100 .8 .5" rgba="0.8 0.9 0.8 0" material="grid"/>
|
||||||
<body name="torso" pos="0 0 .7" childclass="cheetah">
|
<body name="torso" pos="0 0 .7" childclass="cheetah">
|
||||||
<light name="light" pos="0 0 2" mode="trackcom"/>
|
<light name="light" pos="0 0 2" mode="trackcom"/>
|
||||||
<camera name="side" pos="0 -3 0" quat="0.707 0.707 0 0" mode="trackcom"/>
|
<camera name="side" pos="0 -3 0" quat="0.707 0.707 0 0" mode="trackcom"/>
|
||||||
|
@ -62,10 +62,10 @@ def main(method, config):
|
|||||||
str(config.logdir), max_queue=1000, flush_millis=20000)
|
str(config.logdir), max_queue=1000, flush_millis=20000)
|
||||||
writer.set_as_default()
|
writer.set_as_default()
|
||||||
train_envs = [wrappers.Async(lambda: make_env(
|
train_envs = [wrappers.Async(lambda: make_env(
|
||||||
config, writer, 'train', datadir, config.video_dir, store=True), config.parallel)
|
config, writer, 'train', datadir, config.video_dir_train, store=True), config.parallel)
|
||||||
for _ in range(config.envs)]
|
for _ in range(config.envs)]
|
||||||
test_envs = [wrappers.Async(lambda: make_env(
|
test_envs = [wrappers.Async(lambda: make_env(
|
||||||
config, writer, 'test', datadir, config.video_dir, store=False), config.parallel)
|
config, writer, 'test', datadir, config.video_dir_test, store=False), config.parallel)
|
||||||
for _ in range(config.envs)]
|
for _ in range(config.envs)]
|
||||||
actspace = train_envs[0].action_space
|
actspace = train_envs[0].action_space
|
||||||
|
|
||||||
|
@ -86,7 +86,11 @@ def video_summary(name, video, step=None, fps=20):
|
|||||||
|
|
||||||
def encode_gif(frames, fps):
|
def encode_gif(frames, fps):
|
||||||
from subprocess import Popen, PIPE
|
from subprocess import Popen, PIPE
|
||||||
|
print(frames[0].shape)
|
||||||
|
if frames[0].shape[-1] != 3:
|
||||||
|
frames = np.transpose(frames, [0, 2, 3, 1])
|
||||||
h, w, c = frames[0].shape
|
h, w, c = frames[0].shape
|
||||||
|
print(h,w,c)
|
||||||
pxfmt = {1: 'gray', 3: 'rgb24'}[c]
|
pxfmt = {1: 'gray', 3: 'rgb24'}[c]
|
||||||
cmd = ' '.join([
|
cmd = ' '.join([
|
||||||
f'ffmpeg -y -f rawvideo -vcodec rawvideo',
|
f'ffmpeg -y -f rawvideo -vcodec rawvideo',
|
||||||
@ -123,6 +127,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
|
|||||||
# Step agents.
|
# Step agents.
|
||||||
# if use augmentation, need to modify dreamer.policy or here.
|
# if use augmentation, need to modify dreamer.policy or here.
|
||||||
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
|
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
|
||||||
|
obs['image'] = tf.transpose(obs['image'], [0, 3, 2, 1])
|
||||||
action, agent_state = agent(obs, done, agent_state)
|
action, agent_state = agent(obs, done, agent_state)
|
||||||
action = np.array(action)
|
action = np.array(action)
|
||||||
assert len(action) == len(envs)
|
assert len(action) == len(envs)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
dmc:
|
dmc:
|
||||||
|
|
||||||
logdir: ./
|
logdir: /home/vedant/tia/Dreamer/logdir
|
||||||
video_dir: ./
|
video_dir_train: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/train/
|
||||||
|
video_dir_test: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/test/
|
||||||
debug: False
|
debug: False
|
||||||
seed: 0
|
seed: 0
|
||||||
steps: 1000000.0
|
steps: 1000000.0
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
dmc:
|
dmc:
|
||||||
|
|
||||||
logdir: ./
|
logdir: /home/vedant/tia/Dreamer/logdir
|
||||||
video_dir: ./
|
video_dir_train: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/train/
|
||||||
|
video_dir_test: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/test/
|
||||||
debug: False
|
debug: False
|
||||||
seed: 0
|
seed: 0
|
||||||
steps: 1000000.0
|
steps: 1000000.0
|
||||||
|
Loading…
Reference in New Issue
Block a user