DBC/dmc2gym/natural_imgsource.py

81 lines
2.7 KiB
Python
Raw Normal View History

2023-06-01 14:23:18 +00:00
# This module provides the image-source classes used to generate backgrounds for the natural-background setting.
# An image source is used inside an environment wrapper and is called each time the env generates an observation.
# The code is largely based on https://github.com/facebookresearch/deep_bisim4control
2020-10-12 22:39:25 +00:00
2023-06-01 14:23:18 +00:00
import random
2020-10-12 22:39:25 +00:00
import cv2
2023-06-01 14:23:18 +00:00
import numpy as np
2020-10-12 22:39:25 +00:00
import skvideo.io
class ImageSource:
    """
    Source of natural images to be added to a simulated environment.

    This base class only fixes the interface: both methods are no-ops that
    return None, and concrete sources are expected to override them.
    """

    def get_image(self):
        """
        Produce the next background image.

        Returns:
            an RGB image of [h, w, 3] with a fixed shape.
            (The base implementation itself returns None.)
        """
        return None

    def reset(self):
        """ Called when an episode ends. The base implementation does nothing. """
        return None
class RandomVideoSource(ImageSource):
    """
    Image source that serves frames from randomly chosen video files.

    Exactly one decoded video is kept in memory at a time. ``reset`` replaces
    it with another randomly selected file and ``get_image`` returns one frame
    of the currently loaded video.
    """

    def __init__(self, shape, filelist, random_bg=False, max_videos=50, grayscale=False):
        """
        Args:
            shape: [h, w] of the frames returned by ``get_image``
            filelist: a list of video files
            random_bg: if True, ``get_image`` returns a uniformly random frame
                of the current video; otherwise frames are served in order
            max_videos: number of (shuffled) files that are kept as candidates
            grayscale: if True, videos are decoded as single-channel frames

        Raises:
            ValueError: if ``filelist`` is empty.
            RuntimeError: if none of the candidate videos can be loaded.
        """
        self.grayscale = grayscale
        self.shape = shape
        # Work on a private copy: shuffling the argument in place would
        # silently reorder a list that the caller may still be using.
        self.filelist = list(filelist)
        random.shuffle(self.filelist)
        self.filelist = self.filelist[:max_videos]
        self.max_videos = max_videos
        self.random_bg = random_bg
        self.current_idx = 0
        self._current_vid = None
        self.reset()

    def load_video(self, vid_id):
        """
        Decode one video and resize every frame to ``self.shape``.

        Args:
            vid_id: index into ``self.filelist``

        Returns:
            an array of shape [num_frames, h, w, c] with c == 1 for grayscale
            and c == 3 otherwise (numpy's default float dtype).
        """
        fname = self.filelist[vid_id]
        if self.grayscale:
            # NOTE(review): unlike the color branch, no num_frames limit is
            # applied here, so the whole video is decoded - confirm intended.
            frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
        else:
            frames = skvideo.io.vread(fname, num_frames=1000)

        channels = 1 if self.grayscale else 3
        img_arr = np.zeros((frames.shape[0], self.shape[0], self.shape[1], channels))
        # THIS IS NOT A BUG! cv2 uses (width, height)
        target_size = (self.shape[1], self.shape[0])
        for i, frame in enumerate(frames):
            resized = cv2.resize(frame, target_size)
            # In grayscale mode a trailing channel axis is re-added so the
            # resized frame matches the [h, w, 1] slot of img_arr.
            img_arr[i] = resized[..., None] if self.grayscale else resized
        return img_arr

    def reset(self):
        """
        Called when an episode ends: load a new random video and jump to a
        random frame in it.

        Every candidate file is tried at most once, in random order. Files
        that fail to decode or contain no frames are skipped.

        Raises:
            ValueError: if the file list is empty.
            RuntimeError: if no candidate video could be loaded.
        """
        if not self.filelist:
            raise ValueError("RandomVideoSource requires at least one video file")

        # Release the previous video before decoding the next one so that two
        # full videos are never held in memory at the same time.
        self._current_vid = None

        last_error = None
        for candidate in np.random.permutation(len(self.filelist)):
            video_id = int(candidate)
            try:
                video = self.load_video(video_id)
            except Exception as err:  # best effort: skip unreadable files
                last_error = err
                continue
            if len(video) == 0:
                last_error = ValueError(
                    "video contains no frames: {}".format(self.filelist[video_id])
                )
                continue
            self._video_id = video_id
            self._current_vid = video
            break
        else:
            raise RuntimeError(
                "none of the {} candidate videos could be loaded".format(len(self.filelist))
            ) from last_error

        self._loc = np.random.randint(0, len(self._current_vid))

    def get_image(self):
        """
        Returns:
            one frame of the current video, of shape [h, w, c]: a random
            frame if ``random_bg`` is set, otherwise the frame following the
            previously returned one (wrapping around at the end).
        """
        num_frames = len(self._current_vid)
        if self.random_bg:
            self._loc = np.random.randint(0, num_frames)
        else:
            self._loc += 1
        return self._current_vid[self._loc % num_frames]