diff --git a/DPI/dmc2gym/__init__.py b/DPI/dmc2gym/__init__.py index 7c1d277..ed7b897 100644 --- a/DPI/dmc2gym/__init__.py +++ b/DPI/dmc2gym/__init__.py @@ -16,7 +16,9 @@ def make( camera_id=0, frame_skip=1, episode_length=1000, - environment_kwargs=None + environment_kwargs=None, + video_recording=False, + video_recording_dir=None, ): env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed) @@ -46,6 +48,8 @@ def make( 'width': width, 'camera_id': camera_id, 'frame_skip': frame_skip, + 'video_recording': video_recording, + 'video_recording_dir': video_recording_dir, }, max_episode_steps=max_episode_steps ) diff --git a/DPI/dmc2gym/natural_imgsource.py b/DPI/dmc2gym/natural_imgsource.py index 42ef62f..a911262 100644 --- a/DPI/dmc2gym/natural_imgsource.py +++ b/DPI/dmc2gym/natural_imgsource.py @@ -123,7 +123,7 @@ class RandomImageSource(ImageSource): class RandomVideoSource(ImageSource): - def __init__(self, shape, filelist, total_frames=None, grayscale=False): + def __init__(self, shape, filelist, total_frames=None, grayscale=False, high_noise=False): """ Args: shape: [h, w] @@ -133,6 +133,7 @@ class RandomVideoSource(ImageSource): self.total_frames = total_frames self.shape = shape self.filelist = filelist + self.high_noise = high_noise self.build_arr() self.current_idx = 0 self.reset() @@ -172,7 +173,10 @@ class RandomVideoSource(ImageSource): self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0])) pbar.update(1) total_frame_i += 1 - + + # Randomize the order of the frames + if self.high_noise: + random.shuffle(self.arr) def reset(self): self._loc = np.random.randint(0, self.total_frames) diff --git a/DPI/dmc2gym/wrappers.py b/DPI/dmc2gym/wrappers.py index 077f2eb..7ba08c7 100644 --- a/DPI/dmc2gym/wrappers.py +++ b/DPI/dmc2gym/wrappers.py @@ -6,9 +6,16 @@ from dm_env import specs import numpy as np import skimage.io +from video import VideoRecorder from dmc2gym import natural_imgsource +high_noise = False + +def set_global_var(set_high_noise): + global high_noise + high_noise = set_high_noise + def _spec_to_box(spec): def extract_min_max(s): assert s.dtype == np.float64 or s.dtype == np.float32 @@ -54,7 +61,9 @@ class DMCWrapper(core.Env): width=84, camera_id=0, frame_skip=1, - environment_kwargs=None + environment_kwargs=None, + video_recording=False, + video_recording_dir=None, ): assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour' self._from_pixels = from_pixels @@ -99,6 +108,10 @@ class DMCWrapper(core.Env): dtype=np.float32 ) + # video recording + if video_recording: + self.video = VideoRecorder(video_recording_dir+"/video", resource_files=resource_files, high_noise=high_noise) + # background if img_source is not None: shape2d = (height, width) @@ -114,7 +127,7 @@ class DMCWrapper(core.Env): if img_source == "images": self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames) elif img_source == "video": - self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames) + self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames, high_noise=high_noise) else: raise Exception("img_source %s not defined." % img_source) diff --git a/DPI/train.py b/DPI/train.py index 1b9c16f..622c0ad 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -13,6 +13,7 @@ from utils import ReplayBuffer, make_env, save_image from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample from logger import Logger from video import VideoRecorder +from dmc2gym.wrappers import set_global_var #from agent.baseline_agent import BaselineAgent #from agent.bisim_agent import BisimAgent @@ -32,7 +33,8 @@ def parse_args(): parser.add_argument('--resource_files', type=str) parser.add_argument('--eval_resource_files', type=str) parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none']) - parser.add_argument('--total_frames', default=1000, type=int) + parser.add_argument('--total_frames', default=10000, type=int) + parser.add_argument('--high_noise', action='store_true') # replay buffer parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000 parser.add_argument('--episode_length', default=50, type=int) @@ -103,6 +105,9 @@ class DPI: self.args = args + # set environment noise + set_global_var(self.args.high_noise) + # environment setup self.env = make_env(self.args) self.env.seed(self.args.seed) @@ -168,9 +173,9 @@ class DPI: obs = self.env.reset() done = False + #video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files) for episode_count in range(episodes): - video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files) - video.init(enabled=True) + self.env.video.init(enabled=True) for i in range(self.args.episode_length): action = self.env.action_space.sample() next_obs, _, done, _ = self.env.step(action) @@ -178,14 +183,14 @@ class DPI: self.data_buffer.add(obs, action, next_obs, episode_count+1, done) if args.save_video: - video.record(self.env) + self.env.video.record(self.env) if done: obs = self.env.reset() done=False else: obs = next_obs - video.save('%d.mp4' % episode_count) + self.env.video.save('%d.mp4' % episode_count) print("Collected {} random episodes".format(episode_count+1)) #if args.save_video: # video.record(self.env) diff --git a/DPI/utils.py b/DPI/utils.py index de28442..8cf759b 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -5,16 +5,18 @@ # LICENSE file in the root directory of this source tree. import os -import torch +import random import numpy as np +from collections import deque + +import torch import torch.nn as nn import gym import dmc2gym -import random +import cv2 from PIL import Image -from collections import deque class eval_mode(object): @@ -186,7 +188,9 @@ def make_env(args): from_pixels=(args.encoder_type == 'pixel'), height=args.image_size, width=args.image_size, - frame_skip=args.action_repeat + frame_skip=args.action_repeat, + video_recording=args.save_video, + video_recording_dir=args.work_dir, ) return env @@ -194,4 +198,21 @@ def save_image(array, filename): array = array.transpose(1, 2, 0) array = (array * 255).astype(np.uint8) image = Image.fromarray(array) - image.save(filename) \ No newline at end of file + image.save(filename) + +def video_from_array(arr, high_noise, filename): + """ + Save a video from a numpy array of shape (T, H, W, C) + Example: + video_from_array(np.random.rand(100, 64, 64, 1), 'test.mp4') + """ + if arr.shape[-1] == 1: + height, width, channels = arr.shape[1:] + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + out = cv2.VideoWriter('output.mp4', fourcc, 30.0, (width, height)) + for i in range(arr.shape[0]): + frame = arr[i] + frame = np.uint8(frame) + frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR) + out.write(frame) + out.release() \ No newline at end of file diff --git a/DPI/video.py b/DPI/video.py index c372207..745add6 100644 --- a/DPI/video.py +++ b/DPI/video.py @@ -13,7 +13,7 @@ from dmc2gym.natural_imgsource import RandomVideoSource class VideoRecorder(object): - def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30): + def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30, high_noise=False): self.dir_name = dir_name self.height = height self.width = width @@ -23,7 +23,7 @@ class VideoRecorder(object): self.resource_files = resource_files if resource_files: files = glob.glob(os.path.expanduser(resource_files)) - self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000) + self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000, high_noise=high_noise) else: self._bg_source = None