Adding high noise by randomising the frames

This commit is contained in:
Vedant Dave 2023-03-25 17:07:07 +01:00
parent f2aa9baebb
commit 25c2853ba6
6 changed files with 64 additions and 17 deletions

View File

@ -16,7 +16,9 @@ def make(
camera_id=0, camera_id=0,
frame_skip=1, frame_skip=1,
episode_length=1000, episode_length=1000,
environment_kwargs=None environment_kwargs=None,
video_recording=False,
video_recording_dir=None,
): ):
env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed) env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed)
@ -46,6 +48,8 @@ def make(
'width': width, 'width': width,
'camera_id': camera_id, 'camera_id': camera_id,
'frame_skip': frame_skip, 'frame_skip': frame_skip,
'video_recording': video_recording,
'video_recording_dir': video_recording_dir,
}, },
max_episode_steps=max_episode_steps max_episode_steps=max_episode_steps
) )

View File

@ -123,7 +123,7 @@ class RandomImageSource(ImageSource):
class RandomVideoSource(ImageSource): class RandomVideoSource(ImageSource):
def __init__(self, shape, filelist, total_frames=None, grayscale=False): def __init__(self, shape, filelist, total_frames=None, grayscale=False, high_noise=False):
""" """
Args: Args:
shape: [h, w] shape: [h, w]
@ -133,6 +133,7 @@ class RandomVideoSource(ImageSource):
self.total_frames = total_frames self.total_frames = total_frames
self.shape = shape self.shape = shape
self.filelist = filelist self.filelist = filelist
self.high_noise = high_noise
self.build_arr() self.build_arr()
self.current_idx = 0 self.current_idx = 0
self.reset() self.reset()
@ -172,7 +173,10 @@ class RandomVideoSource(ImageSource):
self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0])) self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))
pbar.update(1) pbar.update(1)
total_frame_i += 1 total_frame_i += 1
# Randomize the order of the frames
if self.high_noise:
random.shuffle(self.arr)
def reset(self): def reset(self):
self._loc = np.random.randint(0, self.total_frames) self._loc = np.random.randint(0, self.total_frames)

View File

@ -6,9 +6,16 @@ from dm_env import specs
import numpy as np import numpy as np
import skimage.io import skimage.io
from video import VideoRecorder
from dmc2gym import natural_imgsource from dmc2gym import natural_imgsource
high_noise = False
def set_global_var(set_high_noise):
global high_noise
high_noise = set_high_noise
def _spec_to_box(spec): def _spec_to_box(spec):
def extract_min_max(s): def extract_min_max(s):
assert s.dtype == np.float64 or s.dtype == np.float32 assert s.dtype == np.float64 or s.dtype == np.float32
@ -54,7 +61,9 @@ class DMCWrapper(core.Env):
width=84, width=84,
camera_id=0, camera_id=0,
frame_skip=1, frame_skip=1,
environment_kwargs=None environment_kwargs=None,
video_recording=False,
video_recording_dir=None,
): ):
assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour' assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'
self._from_pixels = from_pixels self._from_pixels = from_pixels
@ -99,6 +108,10 @@ class DMCWrapper(core.Env):
dtype=np.float32 dtype=np.float32
) )
# video recording
if video_recording:
self.video = VideoRecorder(video_recording_dir+"/video", resource_files=resource_files, high_noise=high_noise)
# background # background
if img_source is not None: if img_source is not None:
shape2d = (height, width) shape2d = (height, width)
@ -114,7 +127,7 @@ class DMCWrapper(core.Env):
if img_source == "images": if img_source == "images":
self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames) self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames)
elif img_source == "video": elif img_source == "video":
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames) self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames, high_noise=high_noise)
else: else:
raise Exception("img_source %s not defined." % img_source) raise Exception("img_source %s not defined." % img_source)

View File

@ -13,6 +13,7 @@ from utils import ReplayBuffer, make_env, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
from logger import Logger from logger import Logger
from video import VideoRecorder from video import VideoRecorder
from dmc2gym.wrappers import set_global_var
#from agent.baseline_agent import BaselineAgent #from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent #from agent.bisim_agent import BisimAgent
@ -32,7 +33,8 @@ def parse_args():
parser.add_argument('--resource_files', type=str) parser.add_argument('--resource_files', type=str)
parser.add_argument('--eval_resource_files', type=str) parser.add_argument('--eval_resource_files', type=str)
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none']) parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
parser.add_argument('--total_frames', default=1000, type=int) parser.add_argument('--total_frames', default=10000, type=int)
parser.add_argument('--high_noise', action='store_true')
# replay buffer # replay buffer
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000 parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000
parser.add_argument('--episode_length', default=50, type=int) parser.add_argument('--episode_length', default=50, type=int)
@ -103,6 +105,9 @@ class DPI:
self.args = args self.args = args
# set environment noise
set_global_var(self.args.high_noise)
# environment setup # environment setup
self.env = make_env(self.args) self.env = make_env(self.args)
self.env.seed(self.args.seed) self.env.seed(self.args.seed)
@ -168,9 +173,9 @@ class DPI:
obs = self.env.reset() obs = self.env.reset()
done = False done = False
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
for episode_count in range(episodes): for episode_count in range(episodes):
video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files) self.env.video.init(enabled=True)
video.init(enabled=True)
for i in range(self.args.episode_length): for i in range(self.args.episode_length):
action = self.env.action_space.sample() action = self.env.action_space.sample()
next_obs, _, done, _ = self.env.step(action) next_obs, _, done, _ = self.env.step(action)
@ -178,14 +183,14 @@ class DPI:
self.data_buffer.add(obs, action, next_obs, episode_count+1, done) self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
if args.save_video: if args.save_video:
video.record(self.env) self.env.video.record(self.env)
if done: if done:
obs = self.env.reset() obs = self.env.reset()
done=False done=False
else: else:
obs = next_obs obs = next_obs
video.save('%d.mp4' % episode_count) self.env.video.save('%d.mp4' % episode_count)
print("Collected {} random episodes".format(episode_count+1)) print("Collected {} random episodes".format(episode_count+1))
#if args.save_video: #if args.save_video:
# video.record(self.env) # video.record(self.env)

View File

@ -5,16 +5,18 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import os import os
import torch import random
import numpy as np import numpy as np
from collections import deque
import torch
import torch.nn as nn import torch.nn as nn
import gym import gym
import dmc2gym import dmc2gym
import random import cv2
from PIL import Image from PIL import Image
from collections import deque
class eval_mode(object): class eval_mode(object):
@ -186,7 +188,9 @@ def make_env(args):
from_pixels=(args.encoder_type == 'pixel'), from_pixels=(args.encoder_type == 'pixel'),
height=args.image_size, height=args.image_size,
width=args.image_size, width=args.image_size,
frame_skip=args.action_repeat frame_skip=args.action_repeat,
video_recording=args.save_video,
video_recording_dir=args.work_dir,
) )
return env return env
@ -194,4 +198,21 @@ def save_image(array, filename):
array = array.transpose(1, 2, 0) array = array.transpose(1, 2, 0)
array = (array * 255).astype(np.uint8) array = (array * 255).astype(np.uint8)
image = Image.fromarray(array) image = Image.fromarray(array)
image.save(filename) image.save(filename)
def video_from_array(arr, high_noise, filename):
"""
Save a video from a numpy array of shape (T, H, W, C)
Example:
video_from_array(np.random.rand(100, 64, 64, 1), 'test.mp4')
"""
if arr.shape[-1] == 1:
height, width, channels = arr.shape[1:]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('output.mp4', fourcc, 30.0, (width, height))
for i in range(arr.shape[0]):
frame = arr[i]
frame = np.uint8(frame)
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
out.write(frame)
out.release()

View File

@ -13,7 +13,7 @@ from dmc2gym.natural_imgsource import RandomVideoSource
class VideoRecorder(object): class VideoRecorder(object):
def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30): def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30, high_noise=False):
self.dir_name = dir_name self.dir_name = dir_name
self.height = height self.height = height
self.width = width self.width = width
@ -23,7 +23,7 @@ class VideoRecorder(object):
self.resource_files = resource_files self.resource_files = resource_files
if resource_files: if resource_files:
files = glob.glob(os.path.expanduser(resource_files)) files = glob.glob(os.path.expanduser(resource_files))
self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000) self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000, high_noise=high_noise)
else: else:
self._bg_source = None self._bg_source = None