Adding high noise by randomising the frames
This commit is contained in:
parent
f2aa9baebb
commit
25c2853ba6
@ -16,7 +16,9 @@ def make(
|
|||||||
camera_id=0,
|
camera_id=0,
|
||||||
frame_skip=1,
|
frame_skip=1,
|
||||||
episode_length=1000,
|
episode_length=1000,
|
||||||
environment_kwargs=None
|
environment_kwargs=None,
|
||||||
|
video_recording=False,
|
||||||
|
video_recording_dir=None,
|
||||||
):
|
):
|
||||||
env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed)
|
env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed)
|
||||||
|
|
||||||
@ -46,6 +48,8 @@ def make(
|
|||||||
'width': width,
|
'width': width,
|
||||||
'camera_id': camera_id,
|
'camera_id': camera_id,
|
||||||
'frame_skip': frame_skip,
|
'frame_skip': frame_skip,
|
||||||
|
'video_recording': video_recording,
|
||||||
|
'video_recording_dir': video_recording_dir,
|
||||||
},
|
},
|
||||||
max_episode_steps=max_episode_steps
|
max_episode_steps=max_episode_steps
|
||||||
)
|
)
|
||||||
|
@ -123,7 +123,7 @@ class RandomImageSource(ImageSource):
|
|||||||
|
|
||||||
|
|
||||||
class RandomVideoSource(ImageSource):
|
class RandomVideoSource(ImageSource):
|
||||||
def __init__(self, shape, filelist, total_frames=None, grayscale=False):
|
def __init__(self, shape, filelist, total_frames=None, grayscale=False, high_noise=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
shape: [h, w]
|
shape: [h, w]
|
||||||
@ -133,6 +133,7 @@ class RandomVideoSource(ImageSource):
|
|||||||
self.total_frames = total_frames
|
self.total_frames = total_frames
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
self.filelist = filelist
|
self.filelist = filelist
|
||||||
|
self.high_noise = high_noise
|
||||||
self.build_arr()
|
self.build_arr()
|
||||||
self.current_idx = 0
|
self.current_idx = 0
|
||||||
self.reset()
|
self.reset()
|
||||||
@ -173,6 +174,9 @@ class RandomVideoSource(ImageSource):
|
|||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
total_frame_i += 1
|
total_frame_i += 1
|
||||||
|
|
||||||
|
# Randomize the order of the frames
|
||||||
|
if self.high_noise:
|
||||||
|
random.shuffle(self.arr)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self._loc = np.random.randint(0, self.total_frames)
|
self._loc = np.random.randint(0, self.total_frames)
|
||||||
|
@ -6,9 +6,16 @@ from dm_env import specs
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import skimage.io
|
import skimage.io
|
||||||
|
|
||||||
|
from video import VideoRecorder
|
||||||
from dmc2gym import natural_imgsource
|
from dmc2gym import natural_imgsource
|
||||||
|
|
||||||
|
|
||||||
|
high_noise = False
|
||||||
|
|
||||||
|
def set_global_var(set_high_noise):
|
||||||
|
global high_noise
|
||||||
|
high_noise = set_high_noise
|
||||||
|
|
||||||
def _spec_to_box(spec):
|
def _spec_to_box(spec):
|
||||||
def extract_min_max(s):
|
def extract_min_max(s):
|
||||||
assert s.dtype == np.float64 or s.dtype == np.float32
|
assert s.dtype == np.float64 or s.dtype == np.float32
|
||||||
@ -54,7 +61,9 @@ class DMCWrapper(core.Env):
|
|||||||
width=84,
|
width=84,
|
||||||
camera_id=0,
|
camera_id=0,
|
||||||
frame_skip=1,
|
frame_skip=1,
|
||||||
environment_kwargs=None
|
environment_kwargs=None,
|
||||||
|
video_recording=False,
|
||||||
|
video_recording_dir=None,
|
||||||
):
|
):
|
||||||
assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'
|
assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'
|
||||||
self._from_pixels = from_pixels
|
self._from_pixels = from_pixels
|
||||||
@ -99,6 +108,10 @@ class DMCWrapper(core.Env):
|
|||||||
dtype=np.float32
|
dtype=np.float32
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# video recording
|
||||||
|
if video_recording:
|
||||||
|
self.video = VideoRecorder(video_recording_dir+"/video", resource_files=resource_files, high_noise=high_noise)
|
||||||
|
|
||||||
# background
|
# background
|
||||||
if img_source is not None:
|
if img_source is not None:
|
||||||
shape2d = (height, width)
|
shape2d = (height, width)
|
||||||
@ -114,7 +127,7 @@ class DMCWrapper(core.Env):
|
|||||||
if img_source == "images":
|
if img_source == "images":
|
||||||
self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames)
|
self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=True, total_frames=total_frames)
|
||||||
elif img_source == "video":
|
elif img_source == "video":
|
||||||
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames)
|
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=True, total_frames=total_frames, high_noise=high_noise)
|
||||||
else:
|
else:
|
||||||
raise Exception("img_source %s not defined." % img_source)
|
raise Exception("img_source %s not defined." % img_source)
|
||||||
|
|
||||||
|
15
DPI/train.py
15
DPI/train.py
@ -13,6 +13,7 @@ from utils import ReplayBuffer, make_env, save_image
|
|||||||
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
|
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
|
||||||
from logger import Logger
|
from logger import Logger
|
||||||
from video import VideoRecorder
|
from video import VideoRecorder
|
||||||
|
from dmc2gym.wrappers import set_global_var
|
||||||
|
|
||||||
#from agent.baseline_agent import BaselineAgent
|
#from agent.baseline_agent import BaselineAgent
|
||||||
#from agent.bisim_agent import BisimAgent
|
#from agent.bisim_agent import BisimAgent
|
||||||
@ -32,7 +33,8 @@ def parse_args():
|
|||||||
parser.add_argument('--resource_files', type=str)
|
parser.add_argument('--resource_files', type=str)
|
||||||
parser.add_argument('--eval_resource_files', type=str)
|
parser.add_argument('--eval_resource_files', type=str)
|
||||||
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
||||||
parser.add_argument('--total_frames', default=1000, type=int)
|
parser.add_argument('--total_frames', default=10000, type=int)
|
||||||
|
parser.add_argument('--high_noise', action='store_true')
|
||||||
# replay buffer
|
# replay buffer
|
||||||
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000
|
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000
|
||||||
parser.add_argument('--episode_length', default=50, type=int)
|
parser.add_argument('--episode_length', default=50, type=int)
|
||||||
@ -103,6 +105,9 @@ class DPI:
|
|||||||
|
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
|
# set environment noise
|
||||||
|
set_global_var(self.args.high_noise)
|
||||||
|
|
||||||
# environment setup
|
# environment setup
|
||||||
self.env = make_env(self.args)
|
self.env = make_env(self.args)
|
||||||
self.env.seed(self.args.seed)
|
self.env.seed(self.args.seed)
|
||||||
@ -168,9 +173,9 @@ class DPI:
|
|||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
|
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
|
||||||
for episode_count in range(episodes):
|
for episode_count in range(episodes):
|
||||||
video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
|
self.env.video.init(enabled=True)
|
||||||
video.init(enabled=True)
|
|
||||||
for i in range(self.args.episode_length):
|
for i in range(self.args.episode_length):
|
||||||
action = self.env.action_space.sample()
|
action = self.env.action_space.sample()
|
||||||
next_obs, _, done, _ = self.env.step(action)
|
next_obs, _, done, _ = self.env.step(action)
|
||||||
@ -178,14 +183,14 @@ class DPI:
|
|||||||
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
|
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
|
||||||
|
|
||||||
if args.save_video:
|
if args.save_video:
|
||||||
video.record(self.env)
|
self.env.video.record(self.env)
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
done=False
|
done=False
|
||||||
else:
|
else:
|
||||||
obs = next_obs
|
obs = next_obs
|
||||||
video.save('%d.mp4' % episode_count)
|
self.env.video.save('%d.mp4' % episode_count)
|
||||||
print("Collected {} random episodes".format(episode_count+1))
|
print("Collected {} random episodes".format(episode_count+1))
|
||||||
#if args.save_video:
|
#if args.save_video:
|
||||||
# video.record(self.env)
|
# video.record(self.env)
|
||||||
|
29
DPI/utils.py
29
DPI/utils.py
@ -5,16 +5,18 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import torch
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import dmc2gym
|
import dmc2gym
|
||||||
|
|
||||||
import random
|
import cv2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
|
|
||||||
class eval_mode(object):
|
class eval_mode(object):
|
||||||
@ -186,7 +188,9 @@ def make_env(args):
|
|||||||
from_pixels=(args.encoder_type == 'pixel'),
|
from_pixels=(args.encoder_type == 'pixel'),
|
||||||
height=args.image_size,
|
height=args.image_size,
|
||||||
width=args.image_size,
|
width=args.image_size,
|
||||||
frame_skip=args.action_repeat
|
frame_skip=args.action_repeat,
|
||||||
|
video_recording=args.save_video,
|
||||||
|
video_recording_dir=args.work_dir,
|
||||||
)
|
)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
@ -195,3 +199,20 @@ def save_image(array, filename):
|
|||||||
array = (array * 255).astype(np.uint8)
|
array = (array * 255).astype(np.uint8)
|
||||||
image = Image.fromarray(array)
|
image = Image.fromarray(array)
|
||||||
image.save(filename)
|
image.save(filename)
|
||||||
|
|
||||||
|
def video_from_array(arr, high_noise, filename):
|
||||||
|
"""
|
||||||
|
Save a video from a numpy array of shape (T, H, W, C)
|
||||||
|
Example:
|
||||||
|
video_from_array(np.random.rand(100, 64, 64, 1), 'test.mp4')
|
||||||
|
"""
|
||||||
|
if arr.shape[-1] == 1:
|
||||||
|
height, width, channels = arr.shape[1:]
|
||||||
|
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||||
|
out = cv2.VideoWriter('output.mp4', fourcc, 30.0, (width, height))
|
||||||
|
for i in range(arr.shape[0]):
|
||||||
|
frame = arr[i]
|
||||||
|
frame = np.uint8(frame)
|
||||||
|
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
|
||||||
|
out.write(frame)
|
||||||
|
out.release()
|
@ -13,7 +13,7 @@ from dmc2gym.natural_imgsource import RandomVideoSource
|
|||||||
|
|
||||||
|
|
||||||
class VideoRecorder(object):
|
class VideoRecorder(object):
|
||||||
def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30):
|
def __init__(self, dir_name, resource_files=None, height=256, width=256, camera_id=0, fps=30, high_noise=False):
|
||||||
self.dir_name = dir_name
|
self.dir_name = dir_name
|
||||||
self.height = height
|
self.height = height
|
||||||
self.width = width
|
self.width = width
|
||||||
@ -23,7 +23,7 @@ class VideoRecorder(object):
|
|||||||
self.resource_files = resource_files
|
self.resource_files = resource_files
|
||||||
if resource_files:
|
if resource_files:
|
||||||
files = glob.glob(os.path.expanduser(resource_files))
|
files = glob.glob(os.path.expanduser(resource_files))
|
||||||
self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000)
|
self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000, high_noise=high_noise)
|
||||||
else:
|
else:
|
||||||
self._bg_source = None
|
self._bg_source = None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user