diff --git a/Dreamer/dmc2gym/natural_imgsource.py b/Dreamer/dmc2gym/natural_imgsource.py
index 42ef62f..6205815 100644
--- a/Dreamer/dmc2gym/natural_imgsource.py
+++ b/Dreamer/dmc2gym/natural_imgsource.py
@@ -1,36 +1,19 @@
+# This module provides the image sources used to generate backgrounds for the natural background setting.
+# A source is held by the environment wrapper and queried each time the env generates an observation.
+# The code is largely based on https://github.com/facebookresearch/deep_bisim4control
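+# A rough usage sketch (the shape and glob pattern below are illustrative, not the wrapper's exact values):
+#   source = RandomVideoSource((64, 64), glob.glob('path/to/videos/*.mp4'), random_bg=False, max_videos=100)
+#   bg = source.get_image()   # one background frame per generated observation
+#   source.reset()            # switch to a new randomly chosen video (e.g. at episode boundaries)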
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import numpy as np
-import cv2
-import skvideo.io
import random
-import tqdm
-class BackgroundMatting(object):
- """
- Produce a mask by masking the given color. This is a simple strategy
- but effective for many games.
- """
- def __init__(self, color):
- """
- Args:
- color: a (r, g, b) tuple or single value for grayscale
- """
- self._color = color
-
- def get_mask(self, img):
- return img == self._color
+import cv2
+import numpy as np
+import skvideo.io
class ImageSource(object):
"""
Source of natural images to be added to a simulated environment.
"""
+
def get_image(self):
"""
Returns:
@@ -43,141 +26,57 @@ class ImageSource(object):
pass
-class FixedColorSource(ImageSource):
- def __init__(self, shape, color):
- """
- Args:
- shape: [h, w]
- color: a 3-tuple
- """
- self.arr = np.zeros((shape[0], shape[1], 3))
- self.arr[:, :] = color
-
- def get_image(self):
- return self.arr
-
-
-class RandomColorSource(ImageSource):
- def __init__(self, shape):
- """
- Args:
- shape: [h, w]
- """
- self.shape = shape
- self.arr = None
- self.reset()
-
- def reset(self):
- self._color = np.random.randint(0, 256, size=(3,))
- self.arr = np.zeros((self.shape[0], self.shape[1], 3))
- self.arr[:, :] = self._color
-
- def get_image(self):
- return self.arr
-
-
-class NoiseSource(ImageSource):
- def __init__(self, shape, strength=255):
- """
- Args:
- shape: [h, w]
- strength (int): the strength of noise, in range [0, 255]
- """
- self.shape = shape
- self.strength = strength
-
- def get_image(self):
- return np.random.randn(self.shape[0], self.shape[1], 3) * self.strength
-
-
-class RandomImageSource(ImageSource):
- def __init__(self, shape, filelist, total_frames=None, grayscale=False):
- """
- Args:
- shape: [h, w]
- filelist: a list of image files
- """
- self.grayscale = grayscale
- self.total_frames = total_frames
- self.shape = shape
- self.filelist = filelist
- self.build_arr()
- self.current_idx = 0
- self.reset()
-
- def build_arr(self):
- self.total_frames = self.total_frames if self.total_frames else len(self.filelist)
- self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
- for i in range(self.total_frames):
- # if i % len(self.filelist) == 0: random.shuffle(self.filelist)
- fname = self.filelist[i % len(self.filelist)]
- if self.grayscale: im = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)[..., None]
- else: im = cv2.imread(fname, cv2.IMREAD_COLOR)
- self.arr[i] = cv2.resize(im, (self.shape[1], self.shape[0])) ## THIS IS NOT A BUG! cv2 uses (width, height)
-
- def reset(self):
- self._loc = np.random.randint(0, self.total_frames)
-
- def get_image(self):
- return self.arr[self._loc]
-
-
class RandomVideoSource(ImageSource):
- def __init__(self, shape, filelist, total_frames=None, grayscale=False):
+ def __init__(self, shape, filelist, random_bg=False, max_videos=100, grayscale=False):
"""
Args:
shape: [h, w]
filelist: a list of video files
"""
self.grayscale = grayscale
- self.total_frames = total_frames
self.shape = shape
self.filelist = filelist
- self.build_arr()
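+        # Keep a random subset of at most max_videos files; videos are loaded lazily, one at a time.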
+ random.shuffle(self.filelist)
+ self.filelist = self.filelist[:max_videos]
+ self.max_videos = max_videos
+ self.random_bg = random_bg
self.current_idx = 0
+ self._current_vid = None
self.reset()
- def build_arr(self):
- if not self.total_frames:
- self.total_frames = 0
- self.arr = None
- random.shuffle(self.filelist)
- for fname in tqdm.tqdm(self.filelist, desc="Loading videos for natural", position=0):
- if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
- else: frames = skvideo.io.vread(fname)
- local_arr = np.zeros((frames.shape[0], self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
- for i in tqdm.tqdm(range(frames.shape[0]), desc="video frames", position=1):
- local_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0])) ## THIS IS NOT A BUG! cv2 uses (width, height)
- if self.arr is None:
- self.arr = local_arr
- else:
- self.arr = np.concatenate([self.arr, local_arr], 0)
- self.total_frames += local_arr.shape[0]
- else:
- self.arr = np.zeros((self.total_frames, self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
- total_frame_i = 0
- file_i = 0
- with tqdm.tqdm(total=self.total_frames, desc="Loading videos for natural") as pbar:
- while total_frame_i < self.total_frames:
- if file_i % len(self.filelist) == 0: random.shuffle(self.filelist)
- file_i += 1
- fname = self.filelist[file_i % len(self.filelist)]
- if self.grayscale: frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
- else: frames = skvideo.io.vread(fname)
- for frame_i in range(frames.shape[0]):
- if total_frame_i >= self.total_frames: break
- if self.grayscale:
- self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))[..., None] ## THIS IS NOT A BUG! cv2 uses (width, height)
- else:
- self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))
- pbar.update(1)
- total_frame_i += 1
+ def load_video(self, vid_id):
+ fname = self.filelist[vid_id]
+ if self.grayscale:
+ frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
+ else:
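+            # Cap decoding at 1000 frames to bound memory for long videos.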
+ frames = skvideo.io.vread(fname, num_frames=1000)
+
+ img_arr = np.zeros((frames.shape[0], self.shape[0], self.shape[1]) + ((3,) if not self.grayscale else (1,)))
+ for i in range(frames.shape[0]):
+ if self.grayscale:
+ img_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0]))[..., None] # THIS IS NOT A BUG! cv2 uses (width, height)
+ else:
+ img_arr[i] = cv2.resize(frames[i], (self.shape[1], self.shape[0]))
+ return img_arr
def reset(self):
- self._loc = np.random.randint(0, self.total_frames)
+        # Release the previously loaded frames before sampling a new video.
+        del self._current_vid
+        # Retry until a video decodes successfully; unreadable files are skipped.
+        while True:
+            try:
+                self._video_id = np.random.randint(0, len(self.filelist))
+                self._current_vid = self.load_video(self._video_id)
+                break
+            except Exception:
+                continue
+        self._loc = np.random.randint(0, len(self._current_vid))
def get_image(self):
- img = self.arr[self._loc % self.total_frames]
- self._loc += 1
- return img
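+        # random_bg samples an independent frame each call; otherwise the video plays sequentially.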
+ if self.random_bg:
+ self._loc = np.random.randint(0, len(self._current_vid))
+ else:
+ self._loc += 1
+ img = self._current_vid[self._loc % len(self._current_vid)]
+ return img
\ No newline at end of file
diff --git a/Dreamer/dmc2gym/wrappers.py b/Dreamer/dmc2gym/wrappers.py
index 7968416..69bfcbd 100644
--- a/Dreamer/dmc2gym/wrappers.py
+++ b/Dreamer/dmc2gym/wrappers.py
@@ -8,6 +8,11 @@ import skimage.io
from dmc2gym import natural_imgsource
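+# Module-level flag for a high-noise background variant; currently it is only referenced by the
+# commented-out per-reset source rebuild in DMCWrapper.reset below.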
+high_noise = False
+
+def set_global_var(set_high_noise):
+ global high_noise
+ high_noise = set_high_noise
def _spec_to_box(spec):
def extract_min_max(s):
@@ -63,7 +68,6 @@ class DMCWrapper(core.Env):
self._camera_id = camera_id
self._frame_skip = frame_skip
self._img_source = img_source
- self._resource_files = resource_files
# create task
self._env = suite.load(
@@ -109,13 +113,16 @@ class DMCWrapper(core.Env):
self._bg_source = natural_imgsource.NoiseSource(shape2d)
else:
files = glob.glob(os.path.expanduser(resource_files))
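+                # Keep these on the wrapper for the (currently commented-out) per-reset source rebuild in reset().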
+ self.files = files
+ self.total_frames = total_frames
+ self.shape2d = shape2d
assert len(files), "Pattern {} does not match any files".format(
resource_files
)
if img_source == "images":
- self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames)
+ self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=False, max_videos=100, random_bg=False)
elif img_source == "video":
- self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames)
+                    self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=False, max_videos=100, random_bg=False)
else:
raise Exception("img_source %s not defined." % img_source)
@@ -136,9 +143,7 @@ class DMCWrapper(core.Env):
mask = np.logical_and((obs[:, :, 2] > obs[:, :, 1]), (obs[:, :, 2] > obs[:, :, 0])) # hardcoded for dmc
bg = self._bg_source.get_image()
obs[mask] = bg[mask]
- # obs = obs.transpose(2, 0, 1).copy()
- # CHW to HWC for tensorflow
- obs = obs.copy()
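+                # Return observations channels-first (CHW), matching the stored replay layout.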
+ obs = obs.transpose(2, 0, 1).copy()
else:
obs = _flatten_obs(time_step.observation)
return obs
@@ -188,6 +193,8 @@ class DMCWrapper(core.Env):
def reset(self):
time_step = self._env.reset()
+ self._bg_source.reset()
+ #self._bg_source = natural_imgsource.RandomVideoSource(self.shape2d, self.files, grayscale=True, total_frames=self.total_frames, high_noise=high_noise)
obs = self._get_obs(time_step)
return obs
@@ -198,4 +205,4 @@ class DMCWrapper(core.Env):
camera_id = camera_id or self._camera_id
return self._env.physics.render(
height=height, width=width, camera_id=camera_id
- )
+ )
\ No newline at end of file
diff --git a/Dreamer/dreamers.py b/Dreamer/dreamers.py
index 62b1b3a..76612fd 100644
--- a/Dreamer/dreamers.py
+++ b/Dreamer/dreamers.py
@@ -96,11 +96,12 @@ class Dreamer(tools.Module):
@tf.function()
def train(self, data, log_images=False):
- self._strategy.experimental_run_v2(
+ self._strategy.run(
self._train, args=(data, log_images))
def _train(self, data, log_images):
with tf.GradientTape() as model_tape:
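+      # Stored observations are [batch, time, C, H, W]; move channels last for the conv encoder.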
+ data["image"] = tf.transpose(data["image"], perm=[0, 1, 3, 4, 2])
embed = self._encode(data)
post, prior = self._dynamics.observe(embed, data['action'])
feat = self._dynamics.get_feat(post)
diff --git a/Dreamer/local_dm_control_suite/cheetah.xml b/Dreamer/local_dm_control_suite/cheetah.xml
index 1952b5e..dbce06c 100755
--- a/Dreamer/local_dm_control_suite/cheetah.xml
+++ b/Dreamer/local_dm_control_suite/cheetah.xml
@@ -21,7 +21,7 @@
-
+
diff --git a/Dreamer/run.py b/Dreamer/run.py
index 842bacb..c8ddc12 100644
--- a/Dreamer/run.py
+++ b/Dreamer/run.py
@@ -62,10 +62,10 @@ def main(method, config):
str(config.logdir), max_queue=1000, flush_millis=20000)
writer.set_as_default()
train_envs = [wrappers.Async(lambda: make_env(
- config, writer, 'train', datadir, config.video_dir, store=True), config.parallel)
+ config, writer, 'train', datadir, config.video_dir_train, store=True), config.parallel)
for _ in range(config.envs)]
test_envs = [wrappers.Async(lambda: make_env(
- config, writer, 'test', datadir, config.video_dir, store=False), config.parallel)
+ config, writer, 'test', datadir, config.video_dir_test, store=False), config.parallel)
for _ in range(config.envs)]
actspace = train_envs[0].action_space
diff --git a/Dreamer/tools.py b/Dreamer/tools.py
index 3bd8491..133f867 100644
--- a/Dreamer/tools.py
+++ b/Dreamer/tools.py
@@ -86,7 +86,11 @@ def video_summary(name, video, step=None, fps=20):
def encode_gif(frames, fps):
from subprocess import Popen, PIPE
+ print(frames[0].shape)
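+  # Frames may arrive channels-first [T, C, H, W]; convert to [T, H, W, C] for ffmpeg's rawvideo input.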
+ if frames[0].shape[-1] != 3:
+ frames = np.transpose(frames, [0, 2, 3, 1])
h, w, c = frames[0].shape
+ print(h,w,c)
pxfmt = {1: 'gray', 3: 'rgb24'}[c]
cmd = ' '.join([
f'ffmpeg -y -f rawvideo -vcodec rawvideo',
@@ -123,6 +127,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
# Step agents.
# if use augmentation, need to modify dreamer.policy or here.
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
+    obs['image'] = tf.transpose(obs['image'], [0, 2, 3, 1])  # CHW -> HWC, matching the layout used in training
action, agent_state = agent(obs, done, agent_state)
action = np.array(action)
assert len(action) == len(envs)
diff --git a/Dreamer/train_configs/dreamer.yaml b/Dreamer/train_configs/dreamer.yaml
index 1c8f68e..9d2b179 100644
--- a/Dreamer/train_configs/dreamer.yaml
+++ b/Dreamer/train_configs/dreamer.yaml
@@ -1,7 +1,8 @@
dmc:
- logdir: ./
- video_dir: ./
+ logdir: /home/vedant/tia/Dreamer/logdir
+ video_dir_train: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/train/
+ video_dir_test: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/test/
debug: False
seed: 0
steps: 1000000.0
diff --git a/Dreamer/train_configs/tia.yaml b/Dreamer/train_configs/tia.yaml
index ec4abc8..1bf3f85 100644
--- a/Dreamer/train_configs/tia.yaml
+++ b/Dreamer/train_configs/tia.yaml
@@ -1,7 +1,8 @@
dmc:
- logdir: ./
- video_dir: ./
+ logdir: /home/vedant/tia/Dreamer/logdir
+ video_dir_train: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/train/
+ video_dir_test: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/test/
debug: False
seed: 0
steps: 1000000.0