Adding high noise by randomising the frames

This commit is contained in:
Vedant Dave 2023-03-25 17:07:07 +01:00
parent f2aa9baebb
commit 25c2853ba6
6 changed files with 64 additions and 17 deletions

View File

@ -16,7 +16,9 @@ def make(
camera_id=0,
frame_skip=1,
episode_length=1000,
environment_kwargs=None
environment_kwargs=None,
video_recording=False,
video_recording_dir=None,
):
env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed)
@ -46,6 +48,8 @@ def make(
'width': width,
'camera_id': camera_id,
'frame_skip': frame_skip,
'video_recording': video_recording,
'video_recording_dir': video_recording_dir,
},
max_episode_steps=max_episode_steps
)

View File

@ -123,7 +123,7 @@ class RandomImageSource(ImageSource):
class RandomVideoSource(ImageSource):
def __init__(self, shape, filelist, total_frames=None, grayscale=False):
def __init__(self, shape, filelist, total_frames=None, grayscale=False, high_noise=False):
"""
Args:
shape: [h, w]
@ -133,6 +133,7 @@ class RandomVideoSource(ImageSource):
self.total_frames = total_frames
self.shape = shape
self.filelist = filelist
self.high_noise = high_noise
self.build_arr()
self.current_idx = 0
self.reset()
@ -173,6 +174,9 @@ class RandomVideoSource(ImageSource):
pbar.update(1)
total_frame_i += 1
# Randomize the order of the frames
if self.high_noise:
random.shuffle(self.arr)
def reset(self):
self._loc = np.random.randint(0, self.total_frames)

View File

@ -6,9 +6,16 @@ from dm_env import specs
import numpy as np
import skimage.io
from video import VideoRecorder
from dmc2gym import natural_imgsource
high_noise = False
def set_global_var(set_high_noise):
global high_noise
high_noise = set_high_noise
def _spec_to_box(spec):
def extract_min_max(s):
assert s.dtype == np.float64 or s.dtype == np.float32
@ -54,7 +61,9 @@ class DMCWrapper(core.Env):
width=84,
camera_id=0,
frame_skip=1,
environment_kwargs=None
environment_kwargs=None,
video_recording=False,
video_recording_dir=None,
):
assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'
self._from_pixels = from_pixels
@ -99,6 +108,10 @@ class DMCWrapper(core.Env):
dtype=np.float32
)
# video recording
if video_recording:
self.video = VideoRecorder(video_recording_dir+"/video", resource_files=resource_files, high_noise=high_noise)
# background
if img_source is not None:
shape2d = (height, width)
@ -114,7 +127,7 @@ class DMCWrapper(core.Env):
if img_source == "images":
self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames)
elif img_source == "video":
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames)
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames, high_noise=high_noise)
else:
raise Exception("img_source %s not defined." % img_source)

View File

@ -13,6 +13,7 @@ from utils import ReplayBuffer, make_env, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
from logger import Logger
from video import VideoRecorder
from dmc2gym.wrappers import set_global_var
#from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent
@ -32,7 +33,8 @@ def parse_args():
parser.add_argument('--resource_files', type=str)
parser.add_argument('--eval_resource_files', type=str)
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
parser.add_argument('--total_frames', default=1000, type=int)
parser.add_argument('--total_frames', default=10000, type=int)
parser.add_argument('--high_noise', action='store_true')
# replay buffer
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000
parser.add_argument('--episode_length', default=50, type=int)
@ -103,6 +105,9 @@ class DPI:
self.args = args
# set environment noise
set_global_var(self.args.high_noise)
# environment setup
self.env = make_env(self.args)
self.env.seed(self.args.seed)
@ -168,9 +173,9 @@ class DPI:
obs = self.env.reset()
done = False
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
for episode_count in range(episodes):
video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
video.init(enabled=True)
self.env.video.init(enabled=True)
for i in range(self.args.episode_length):
action = self.env.action_space.sample()
next_obs, _, done, _ = self.env.step(action)
@ -178,14 +183,14 @@ class DPI:
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
if args.save_video:
video.record(self.env)
self.env.video.record(self.env)
if done:
obs = self.env.reset()
done=False
else:
obs = next_obs
video.save('%d.mp4' % episode_count)
self.env.video.save('%d.mp4' % episode_count)
print("Collected {} random episodes".format(episode_count+1))
#if args.save_video:
# video.record(self.env)

View File

@ -5,16 +5,18 @@
# LICENSE file in the root directory of this source tree.
import os
import torch
import random
import numpy as np
from collections import deque
import torch
import torch.nn as nn
import gym
import dmc2gym
import random
import cv2
from PIL import Image
from collections import deque
class eval_mode(object):
@ -186,7 +188,9 @@ def make_env(args):
from_pixels=(args.encoder_type == 'pixel'),
height=args.image_size,
width=args.image_size,
frame_skip=args.action_repeat
frame_skip=args.action_repeat,
video_recording=args.save_video,
video_recording_dir=args.work_dir,
)
return env
@ -195,3 +199,20 @@ def save_image(array, filename):
array = (array * 255).astype(np.uint8)
image = Image.fromarray(array)
image.save(filename)
def video_from_array(arr, high_noise, filename):
"""
Save a video from a numpy array of shape (T, H, W, C)
Example:
video_from_array(np.random.rand(100, 64, 64, 1), 'test.mp4')
"""
if arr.shape[-1] == 1:
height, width, channels = arr.shape[1:]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('output.mp4', fourcc, 30.0, (width, height))
for i in range(arr.shape[0]):
frame = arr[i]
frame = np.uint8(frame)
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
out.write(frame)
out.release()

View File

@ -13,7 +13,7 @@ from dmc2gym.natural_imgsource import RandomVideoSource
class VideoRecorder(object):
def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30):
def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30, high_noise=False):
self.dir_name = dir_name
self.height = height
self.width = width
@ -23,7 +23,7 @@ class VideoRecorder(object):
self.resource_files = resource_files
if resource_files:
files = glob.glob(os.path.expanduser(resource_files))
self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000)
self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000, high_noise=high_noise)
else:
self._bg_source = None