Adding high noise by randomising the frames
This commit is contained in:
parent
f2aa9baebb
commit
25c2853ba6
@ -16,7 +16,9 @@ def make(
|
||||
camera_id=0,
|
||||
frame_skip=1,
|
||||
episode_length=1000,
|
||||
environment_kwargs=None
|
||||
environment_kwargs=None,
|
||||
video_recording=False,
|
||||
video_recording_dir=None,
|
||||
):
|
||||
env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed)
|
||||
|
||||
@ -46,6 +48,8 @@ def make(
|
||||
'width': width,
|
||||
'camera_id': camera_id,
|
||||
'frame_skip': frame_skip,
|
||||
'video_recording': video_recording,
|
||||
'video_recording_dir': video_recording_dir,
|
||||
},
|
||||
max_episode_steps=max_episode_steps
|
||||
)
|
||||
|
@ -123,7 +123,7 @@ class RandomImageSource(ImageSource):
|
||||
|
||||
|
||||
class RandomVideoSource(ImageSource):
|
||||
def __init__(self, shape, filelist, total_frames=None, grayscale=False):
|
||||
def __init__(self, shape, filelist, total_frames=None, grayscale=False, high_noise=False):
|
||||
"""
|
||||
Args:
|
||||
shape: [h, w]
|
||||
@ -133,6 +133,7 @@ class RandomVideoSource(ImageSource):
|
||||
self.total_frames = total_frames
|
||||
self.shape = shape
|
||||
self.filelist = filelist
|
||||
self.high_noise = high_noise
|
||||
self.build_arr()
|
||||
self.current_idx = 0
|
||||
self.reset()
|
||||
@ -172,7 +173,10 @@ class RandomVideoSource(ImageSource):
|
||||
self.arr[total_frame_i] = cv2.resize(frames[frame_i], (self.shape[1], self.shape[0]))
|
||||
pbar.update(1)
|
||||
total_frame_i += 1
|
||||
|
||||
|
||||
# Randomize the order of the frames
|
||||
if self.high_noise:
|
||||
random.shuffle(self.arr)
|
||||
|
||||
def reset(self):
|
||||
self._loc = np.random.randint(0, self.total_frames)
|
||||
|
@ -6,9 +6,16 @@ from dm_env import specs
|
||||
import numpy as np
|
||||
import skimage.io
|
||||
|
||||
from video import VideoRecorder
|
||||
from dmc2gym import natural_imgsource
|
||||
|
||||
|
||||
high_noise = False
|
||||
|
||||
def set_global_var(set_high_noise):
|
||||
global high_noise
|
||||
high_noise = set_high_noise
|
||||
|
||||
def _spec_to_box(spec):
|
||||
def extract_min_max(s):
|
||||
assert s.dtype == np.float64 or s.dtype == np.float32
|
||||
@ -54,7 +61,9 @@ class DMCWrapper(core.Env):
|
||||
width=84,
|
||||
camera_id=0,
|
||||
frame_skip=1,
|
||||
environment_kwargs=None
|
||||
environment_kwargs=None,
|
||||
video_recording=False,
|
||||
video_recording_dir=None,
|
||||
):
|
||||
assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'
|
||||
self._from_pixels = from_pixels
|
||||
@ -99,6 +108,10 @@ class DMCWrapper(core.Env):
|
||||
dtype=np.float32
|
||||
)
|
||||
|
||||
# video recording
|
||||
if video_recording:
|
||||
self.video = VideoRecorder(video_recording_dir+"/video", resource_files=resource_files, high_noise=high_noise)
|
||||
|
||||
# background
|
||||
if img_source is not None:
|
||||
shape2d = (height, width)
|
||||
@ -114,7 +127,7 @@ class DMCWrapper(core.Env):
|
||||
if img_source == "images":
|
||||
self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames)
|
||||
elif img_source == "video":
|
||||
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames)
|
||||
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames, high_noise=high_noise)
|
||||
else:
|
||||
raise Exception("img_source %s not defined." % img_source)
|
||||
|
||||
|
15
DPI/train.py
15
DPI/train.py
@ -13,6 +13,7 @@ from utils import ReplayBuffer, make_env, save_image
|
||||
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
|
||||
from logger import Logger
|
||||
from video import VideoRecorder
|
||||
from dmc2gym.wrappers import set_global_var
|
||||
|
||||
#from agent.baseline_agent import BaselineAgent
|
||||
#from agent.bisim_agent import BisimAgent
|
||||
@ -32,7 +33,8 @@ def parse_args():
|
||||
parser.add_argument('--resource_files', type=str)
|
||||
parser.add_argument('--eval_resource_files', type=str)
|
||||
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
||||
parser.add_argument('--total_frames', default=1000, type=int)
|
||||
parser.add_argument('--total_frames', default=10000, type=int)
|
||||
parser.add_argument('--high_noise', action='store_true')
|
||||
# replay buffer
|
||||
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000
|
||||
parser.add_argument('--episode_length', default=50, type=int)
|
||||
@ -103,6 +105,9 @@ class DPI:
|
||||
|
||||
self.args = args
|
||||
|
||||
# set environment noise
|
||||
set_global_var(self.args.high_noise)
|
||||
|
||||
# environment setup
|
||||
self.env = make_env(self.args)
|
||||
self.env.seed(self.args.seed)
|
||||
@ -168,9 +173,9 @@ class DPI:
|
||||
obs = self.env.reset()
|
||||
done = False
|
||||
|
||||
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
|
||||
for episode_count in range(episodes):
|
||||
video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
|
||||
video.init(enabled=True)
|
||||
self.env.video.init(enabled=True)
|
||||
for i in range(self.args.episode_length):
|
||||
action = self.env.action_space.sample()
|
||||
next_obs, _, done, _ = self.env.step(action)
|
||||
@ -178,14 +183,14 @@ class DPI:
|
||||
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
|
||||
|
||||
if args.save_video:
|
||||
video.record(self.env)
|
||||
self.env.video.record(self.env)
|
||||
|
||||
if done:
|
||||
obs = self.env.reset()
|
||||
done=False
|
||||
else:
|
||||
obs = next_obs
|
||||
video.save('%d.mp4' % episode_count)
|
||||
self.env.video.save('%d.mp4' % episode_count)
|
||||
print("Collected {} random episodes".format(episode_count+1))
|
||||
#if args.save_video:
|
||||
# video.record(self.env)
|
||||
|
31
DPI/utils.py
31
DPI/utils.py
@ -5,16 +5,18 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import gym
|
||||
import dmc2gym
|
||||
|
||||
import random
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from collections import deque
|
||||
|
||||
|
||||
class eval_mode(object):
|
||||
@ -186,7 +188,9 @@ def make_env(args):
|
||||
from_pixels=(args.encoder_type == 'pixel'),
|
||||
height=args.image_size,
|
||||
width=args.image_size,
|
||||
frame_skip=args.action_repeat
|
||||
frame_skip=args.action_repeat,
|
||||
video_recording=args.save_video,
|
||||
video_recording_dir=args.work_dir,
|
||||
)
|
||||
return env
|
||||
|
||||
@ -194,4 +198,21 @@ def save_image(array, filename):
|
||||
array = array.transpose(1, 2, 0)
|
||||
array = (array * 255).astype(np.uint8)
|
||||
image = Image.fromarray(array)
|
||||
image.save(filename)
|
||||
image.save(filename)
|
||||
|
||||
def video_from_array(arr, high_noise, filename):
|
||||
"""
|
||||
Save a video from a numpy array of shape (T, H, W, C)
|
||||
Example:
|
||||
video_from_array(np.random.rand(100, 64, 64, 1), 'test.mp4')
|
||||
"""
|
||||
if arr.shape[-1] == 1:
|
||||
height, width, channels = arr.shape[1:]
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
out = cv2.VideoWriter('output.mp4', fourcc, 30.0, (width, height))
|
||||
for i in range(arr.shape[0]):
|
||||
frame = arr[i]
|
||||
frame = np.uint8(frame)
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
|
||||
out.write(frame)
|
||||
out.release()
|
@ -13,7 +13,7 @@ from dmc2gym.natural_imgsource import RandomVideoSource
|
||||
|
||||
|
||||
class VideoRecorder(object):
|
||||
def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30):
|
||||
def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30, high_noise=False):
|
||||
self.dir_name = dir_name
|
||||
self.height = height
|
||||
self.width = width
|
||||
@ -23,7 +23,7 @@ class VideoRecorder(object):
|
||||
self.resource_files = resource_files
|
||||
if resource_files:
|
||||
files = glob.glob(os.path.expanduser(resource_files))
|
||||
self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000)
|
||||
self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000, high_noise=high_noise)
|
||||
else:
|
||||
self._bg_source = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user