# Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import imageio import os import numpy as np import glob from dmc2gym.natural_imgsource import RandomVideoSource class VideoRecorder(object): def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30): self.dir_name = dir_name self.height = height self.width = width self.camera_id = camera_id self.fps = fps self.frames = [] self.resource_files = resource_files if resource_files: files = glob.glob(os.path.expanduser(resource_files)) self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000) else: self._bg_source = None def init(self, enabled=True): self.frames = [] self.enabled = self.dir_name is not None and enabled def record(self, env): if self.enabled: frame = env.render( mode='rgb_array', height=self.height, width=self.width, camera_id=self.camera_id ) if self._bg_source: mask = np.logical_and((frame[:, :, 2] > frame[:, :, 1]), (frame[:, :, 2] > frame[:, :, 0])) # hardcoded for dmc bg = self._bg_source.get_image() frame[mask] = bg[mask] self.frames.append(frame) def record_from_obs(self, obs): if self.enabled: frame = obs if self._bg_source: mask = np.logical_and((frame[:, :, 2] > frame[:, :, 1]), (frame[:, :, 2] > frame[:, :, 0])) bg = self._bg_source.get_image() frame[mask] = bg[mask] self.frames.append(frame.transpose(1, 2, 0)) def save(self, file_name): if self.enabled: path = os.path.join(self.dir_name, file_name) imageio.mimsave(path, self.frames, fps=self.fps)