Curiosity/DPI/utils.py

355 lines
12 KiB
Python

import os
import random
import numpy as np
from collections import deque
import torch
import torch.nn as nn
import gym
import dmc2gym
import cv2
from PIL import Image
from typing import Iterable
class eval_mode(object):
    """Context manager that temporarily switches the given models to eval mode.

    On entry, each model's current ``training`` flag is recorded and the model
    is switched with ``train(False)``. On exit, every model is put back into
    the mode that was recorded for it.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        recorded = self.prev_states = []
        # Record-then-switch one model at a time, in the order given.
        for current in self.models:
            recorded.append(current.training)
            current.train(False)

    def __exit__(self, *args):
        for current, was_training in zip(self.models, self.prev_states):
            current.train(was_training)
        # A falsy return value lets exceptions from the ``with`` body propagate.
        return False
def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``.
    Parameters are paired positionally via ``parameters()``.
    """
    paired = zip(net.parameters(), target_net.parameters())
    for source, target in paired:
        blended = tau * source.data + (1 - tau) * target.data
        target.data.copy_(blended)
def set_seed_everywhere(seed):
    """Seed torch (CPU, plus every CUDA device when available), NumPy and ``random``."""
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
    for seeder in (np.random.seed, random.seed):
        seeder(seed)
def module_hash(module):
    """Return the sum of all elements of every tensor in ``module.state_dict()``.

    Returns ``0`` when the state dict is empty.
    """
    return sum(entry.sum().item() for entry in module.state_dict().values())
def make_dir(dir_path):
    """Ensure ``dir_path`` exists as a directory and return it.

    Missing parent directories are created as well, and an already existing
    directory is not an error.

    :param dir_path: path of the directory to create
    :returns: ``dir_path`` unchanged, so the call can be used inline
    :raises OSError: if the directory cannot be created (e.g. permission
        denied, or the path exists but is not a directory)
    """
    # exist_ok=True tolerates only the "already there" case; genuine failures
    # surface instead of handing back a path that does not exist.
    os.makedirs(dir_path, exist_ok=True)
    return dir_path
def preprocess_obs(obs, bits=5):
    """Reduce bit depth, add uniform noise and centre an image tensor.

    See https://arxiv.org/abs/1807.03039.

    NOTE: a later ``preprocess_obs(obs)`` definition in this module rebinds
    the name, so this variant is shadowed once the whole module is loaded.

    :param obs: ``torch.float32`` tensor (checked with an assert); presumably
        holds 8-bit intensities in [0, 255] -- confirm against callers
    :param bits: target bit depth; values below 8 floor-divide the input by
        ``2**(8 - bits)`` first
    :returns: tensor of the same shape, roughly in [-0.5, 0.5]
    """
    num_bins = 2**bits
    assert obs.dtype == torch.float32
    if bits < 8:
        shrink = 2**(8 - bits)
        obs = torch.floor(obs / shrink)
    scaled = obs / num_bins
    # Uniform noise of one bin width.
    noisy = scaled + torch.rand_like(scaled) / num_bins
    return noisy - 0.5
class FrameStack(gym.Wrapper):
    """Gym wrapper whose observations are the last ``k`` frames joined on axis 0.

    The advertised observation space multiplies the first dimension of the
    wrapped env's observation shape by ``k`` and keeps its dtype.
    """

    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self._k = k
        # Bounded deque: appending a new frame drops the oldest one.
        self._frames = deque([], maxlen=k)
        base_space = env.observation_space
        base_shape = base_space.shape
        stacked_shape = (base_shape[0] * k,) + base_shape[1:]
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=stacked_shape,
            dtype=base_space.dtype,
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        """Reset the wrapped env and fill the stack with ``k`` copies of the first frame."""
        first_frame = self.env.reset()
        self._frames.extend([first_frame] * self._k)
        return self._get_obs()

    def step(self, action):
        """Step the wrapped env, push the new frame, and return the stacked observation."""
        frame, reward, done, info = self.env.step(action)
        self._frames.append(frame)
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        """Concatenate the buffered frames along axis 0."""
        assert len(self._frames) == self._k
        stacked = list(self._frames)
        return np.concatenate(stacked, axis=0)
class ActionRepeat:
    """Environment wrapper that applies each action up to ``amount`` times.

    Attributes not defined on the wrapper are looked up on the wrapped env.
    """

    def __init__(self, env, amount):
        """
        :param env: wrapped environment exposing ``step(action)`` returning
            ``(obs, reward, done, info)``
        :param amount: maximum number of inner steps per outer ``step`` call
        :raises ValueError: if ``amount`` is not positive
        """
        # With a non-positive amount ``step`` would never call the env and
        # would have nothing to return, so reject it up front.
        if amount <= 0:
            raise ValueError(f"amount must be positive, got {amount!r}")
        self._env = env
        self._amount = amount

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``_env`` itself
        # is missing (copy/pickle probe a freshly created instance before its
        # state is restored), delegating through ``self._env`` would recurse
        # without end, so report the attribute as absent instead.
        if name == '_env':
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        """Repeat ``action`` until ``amount`` steps are taken or the env reports done.

        :returns: ``(obs, total_reward, done, info)`` where ``obs``, ``done``
            and ``info`` come from the last inner step and ``total_reward``
            is the sum of the inner rewards
        """
        done = False
        total_reward = 0
        current_step = 0
        while current_step < self._amount and not done:
            obs, reward, done, info = self._env.step(action)
            total_reward += reward
            current_step += 1
        return obs, total_reward, done, info
class NormalizeActions:
    """Environment wrapper that exposes bounded action dimensions as [-1, 1].

    Dimensions whose low and high bounds are both finite are rescaled; the
    remaining dimensions are passed through to the wrapped env unchanged.
    Attributes not defined on the wrapper are looked up on the wrapped env.
    """

    def __init__(self, env):
        """
        :param env: wrapped environment; its ``action_space`` must expose
            ``low`` and ``high`` arrays
        """
        self._env = env
        # True where both bounds are finite, i.e. where rescaling is possible.
        self._mask = np.logical_and(
            np.isfinite(env.action_space.low),
            np.isfinite(env.action_space.high))
        # Placeholder bounds of -1 / 1 keep the arithmetic finite for
        # unbounded dimensions; step() discards the result for those anyway.
        self._low = np.where(self._mask, env.action_space.low, -1)
        self._high = np.where(self._mask, env.action_space.high, 1)

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``_env`` itself
        # is missing (copy/pickle probe a freshly created instance before its
        # state is restored), delegating through ``self._env`` would recurse
        # without end, so report the attribute as absent instead.
        if name == '_env':
            raise AttributeError(name)
        return getattr(self._env, name)

    @property
    def action_space(self):
        """Box with bounds -1 / 1 on the rescaled dimensions (built on every access)."""
        low = np.where(self._mask, -np.ones_like(self._low), self._low)
        high = np.where(self._mask, np.ones_like(self._low), self._high)
        return gym.spaces.Box(low, high, dtype=np.float32)

    def step(self, action):
        """Map ``action`` from [-1, 1] back to the env's bounds and step the env.

        :returns: whatever the wrapped env's ``step`` returns
        """
        original = (action + 1) / 2 * (self._high - self._low) + self._low
        # Unbounded dimensions keep the caller's raw value.
        original = np.where(self._mask, original, action)
        return self._env.step(original)
class ReplayBuffer:
    """Fixed-size ring buffer of transitions.

    Supports sampling of contiguous index windows (``sample``) and
    episode-grouped retrieval (``group_steps`` and friends), the latter being
    driven by fields of ``args``.
    """

    def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
        """
        :param size: number of transitions the buffer can hold
        :param obs_shape: shape of a single observation (stored as uint8)
        :param action_size: length of an action vector
        :param seq_len: window length used by ``sample``
        :param batch_size: number of windows drawn by ``sample``
        :param args: config object; the grouping helpers read
            ``episode_length``, ``frame_stack``, ``channels``, ``image_size``
            and ``batch_size`` from it
        """
        self.size = size
        self.obs_shape = obs_shape
        self.action_size = action_size
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.idx = 0  # next slot to write
        self.full = False  # becomes True once the write index has wrapped
        self.args = args
        self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
        self.actions = np.empty((size, action_size), dtype=np.float32)
        self.rewards = np.empty((size, 1), dtype=np.float32)
        self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
        # int64 rather than uint8: group_steps treats 0 as "slot not in use",
        # so episode numbers above 255 must not wrap around to 0.
        self.episode_count = np.zeros((size,), dtype=np.int64)
        self.terminals = np.empty((size,), dtype=np.float32)
        self.steps, self.episodes = 0, 0

    def add(self, obs, ac, next_obs, rew, episode_count, done):
        """Write one transition at the current index and advance the ring."""
        self.observations[self.idx] = obs
        self.actions[self.idx] = ac
        self.next_observations[self.idx] = next_obs
        self.rewards[self.idx] = rew
        self.episode_count[self.idx] = episode_count
        self.terminals[self.idx] = done
        self.idx = (self.idx + 1) % self.size
        self.full = self.full or self.idx == 0
        self.steps += 1
        self.episodes = self.episodes + (1 if done else 0)

    def _sample_idx(self, L):
        """Draw ``L`` consecutive (wrapping) indices that do not cross the write index.

        :raises ValueError: if the buffer holds too few transitions to place a
            window of length ``L``
        """
        upper = self.size if self.full else self.idx - L
        if upper <= 0:
            raise ValueError(
                f"not enough transitions stored to sample a window of length {L}"
            )
        while True:
            start = np.random.randint(0, upper)
            idxs = np.arange(start, start + L) % self.size
            # Reject windows where the write index falls after the first
            # element: such a window would join newest and oldest data.
            if self.idx not in idxs[1:]:
                return idxs

    def _retrieve_batch(self, idxs, n, L):
        """Gather the rows selected by ``idxs`` (shape ``(n, L)``), time-major.

        :returns: ``(observations, actions, next_observations, rewards,
            terminals)``, each with leading dimensions ``(L, n)``
        """
        vec_idxs = idxs.transpose().reshape(-1)  # Unroll indices
        observations = self.observations[vec_idxs]
        next_observations = self.next_observations[vec_idxs]
        return (
            observations.reshape(L, n, *observations.shape[1:]),
            self.actions[vec_idxs].reshape(L, n, -1),
            next_observations.reshape(L, n, *next_observations.shape[1:]),
            self.rewards[vec_idxs].reshape(L, n),
            self.terminals[vec_idxs].reshape(L, n),
        )

    def sample(self):
        """Sample ``batch_size`` windows of length ``seq_len``.

        :returns: ``(observations, actions, rewards, terminals)`` with leading
            dimensions ``(seq_len, batch_size)``
        """
        n = self.batch_size
        seq_len = self.seq_len
        idxs = np.asarray([self._sample_idx(seq_len) for _ in range(n)])
        # _retrieve_batch also yields next observations; this method keeps
        # its four-value return.
        obs, acs, _next_obs, rews, terms = self._retrieve_batch(idxs, n, seq_len)
        return obs, acs, rews, terms

    def group_steps(self, buffer, variable, obs=True):
        """Return ``buffer.<variable>`` grouped by episode, time-major.

        Only rows whose ``buffer.episode_count`` is non-zero are used. The
        result has shape ``(episode_length, num_episodes, ...)``.
        NOTE(review): the reshape assumes the selected rows form whole
        episodes of ``args.episode_length`` steps -- confirm with callers.
        """
        variable = getattr(buffer, variable)
        non_zero_indices = np.nonzero(buffer.episode_count)[0]
        variable = variable[non_zero_indices]
        if obs:
            variable = variable.reshape(
                -1, self.args.episode_length,
                self.args.frame_stack * self.args.channels,
                self.args.image_size, self.args.image_size,
            ).transpose(1, 0, 2, 3, 4)
        else:
            variable = variable.reshape(
                variable.shape[0] // self.args.episode_length,
                self.args.episode_length, -1,
            ).transpose(1, 0, 2)
        return variable

    def transform_grouped_steps(self, variable):
        """Swap the first two axes and flatten them into one of size ``batch_size * episode_length``."""
        variable = variable.transpose((1, 0, 2, 3, 4))
        variable = variable.reshape(
            self.args.batch_size * self.args.episode_length,
            self.args.frame_stack * self.args.channels,
            self.args.image_size, self.args.image_size,
        )
        return variable

    def sample_random_idx(self, buffer_length):
        """Pick ``args.batch_size`` distinct indices from ``range(buffer_length)``."""
        random_indices = random.sample(range(0, buffer_length), self.args.batch_size)
        return random_indices

    def group_and_sample_random_batch(self, buffer, variable_name, device, random_indices, is_obs=True, offset=0):
        """Group a buffer field by episode and select the given episodes as a float tensor.

        With ``offset == 0`` the first ``episode_length - 1`` time steps are
        kept; otherwise the time steps from ``offset`` onward are kept.

        :returns: tensor on ``device`` indexed along axis 1 by ``random_indices``
        """
        grouped = torch.tensor(self.group_steps(buffer, variable_name, is_obs)).float()
        if offset == 0:
            variable_tensor = grouped[:self.args.episode_length - 1].to(device)
        else:
            variable_tensor = grouped[offset:].to(device)
        return variable_tensor[:, random_indices, :, :, :] if is_obs else variable_tensor[:, random_indices, :]
def make_env(args):
    """Create the dmc2gym environment described by ``args`` and return it.

    Pixel observations are requested when ``args.encoder_type == 'pixel'``;
    ``args.image_size`` is used for both height and width, and
    ``args.action_repeat`` is passed as the frame skip.
    """
    # For making ground plane transparent, change rgba to (0, 0, 0, 0) in local_dm_control_suite/{domain_name}.xml,
    # else change to (0.5, 0.5, 0.5, 1.0) for default ground plane color
    # https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-geom
    env_kwargs = dict(
        domain_name=args.domain_name,
        task_name=args.task_name,
        resource_files=args.resource_files,
        img_source=args.img_source,
        total_frames=args.total_frames,
        seed=args.seed,
        visualize_reward=False,
        from_pixels=(args.encoder_type == 'pixel'),
        height=args.image_size,
        width=args.image_size,
        frame_skip=args.action_repeat,
        video_recording=args.save_video,
        video_recording_dir=args.work_dir,
        version=args.version,
    )
    return dmc2gym.make(**env_kwargs)
def preprocess_obs(obs):
    """Map values from [0, 255] to [-0.5, 0.5] via ``obs / 255.0 - 0.5``.

    NOTE: this rebinds ``preprocess_obs``; the earlier definition in this
    module that accepts a ``bits`` argument is shadowed by it.
    """
    unit_scaled = obs / 255.0
    return unit_scaled - 0.5
def soft_update_params(net, target_net, tau):
    """Move ``target_net``'s parameters a fraction ``tau`` toward ``net``'s, in place.

    NOTE: this repeats an identical definition that appears earlier in this
    module.
    """
    for online, target in zip(net.parameters(), target_net.parameters()):
        mixed = tau * online.data + (1 - tau) * target.data
        target.data.copy_(mixed)
def save_image(array, filename):
    """Write a channel-first array to ``filename`` as an image.

    The array is transposed with ``(1, 2, 0)``, multiplied by 255 and cast
    to uint8 before being handed to PIL.
    NOTE(review): presumably ``array`` holds floats in [0, 1]; values are
    not clipped before the cast -- confirm against callers.
    """
    channel_last = array.transpose(1, 2, 0)
    pixels = (channel_last * 255).astype(np.uint8)
    Image.fromarray(pixels).save(filename)
def video_from_array(arr, high_noise, filename):
    """
    Save a single-channel video from a numpy array of shape (T, H, W, 1).

    Frames are cast to uint8 without rescaling and written as 30 fps 'mp4v'
    video. Arrays whose last dimension is not 1 are ignored (nothing is
    written).
    NOTE(review): because of the plain uint8 cast, values are presumably
    expected in [0, 255] already -- confirm against callers.

    :param arr: array of frames, shape (T, H, W, 1)
    :param high_noise: unused; kept so existing call sites keep working
    :param filename: path of the video file to write

    Example:
    video_from_array(np.random.rand(100, 64, 64, 1) * 255, None, 'test.mp4')
    """
    if arr.shape[-1] != 1:
        return
    height, width, _ = arr.shape[1:]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Write to the caller's path rather than a fixed 'output.mp4'.
    out = cv2.VideoWriter(filename, fourcc, 30.0, (width, height))
    try:
        for frame in arr:
            gray = np.uint8(frame)
            out.write(cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR))
    finally:
        # Always finalise the file, even if a frame fails to convert.
        out.release()
class CorruptVideos:
    """Scan a directory for ``.mp4`` files that OpenCV cannot read."""

    def __init__(self, dir_path):
        """
        :param dir_path: directory whose ``.mp4`` files are checked
        """
        self.dir_path = dir_path

    def _is_video_corrupt(self, filepath):
        """
        Check if a video file is corrupt.

        A file counts as corrupt when it cannot be opened or when its first
        frame cannot be read.

        Args:
        filepath (str): Path to the video file.
        Returns:
        bool: True if the video is corrupt, False otherwise.
        """
        cap = cv2.VideoCapture(filepath)
        try:
            if not cap.isOpened():
                return True
            ret, _frame = cap.read()
            return not ret
        finally:
            # Release the capture handle on every path, including the
            # early "corrupt" returns.
            cap.release()

    def _delete_corrupt_video(self, filepath):
        """Remove ``filepath`` from disk."""
        os.remove(filepath)

    def is_video_corrupt(self, delete=False):
        """Report every corrupt ``.mp4`` in ``dir_path``; optionally delete it.

        :param delete: when True, each corrupt file is removed after being
            reported
        """
        for filename in os.listdir(self.dir_path):
            filepath = os.path.join(self.dir_path, filename)
            if not filepath.endswith(".mp4"):
                continue
            if not self._is_video_corrupt(filepath):
                continue
            print(f"{filepath} is corrupt.")
            if delete:
                self._delete_corrupt_video(filepath)
                print(f"Deleted {filepath}")
def get_parameters(modules: Iterable[nn.Module]):
    """
    Collect the parameters of several torch modules into one flat list.

    Parameters appear in module order, and within a module in the order
    yielded by ``parameters()``.
    :param modules: iterable of modules
    :returns: a list of parameters
    """
    return [param for module in modules for param in module.parameters()]
class FreezeParameters:
    """Context manager that sets ``requires_grad = False`` on module parameters.

    The flags recorded at construction time are restored on exit.
    """

    def __init__(self, modules: Iterable[nn.Module]):
        """
        Context manager to locally freeze gradients.
        In some cases this can speed up computation because gradients aren't calculated for these listed modules.
        example:
        ```
        with FreezeParameters([module]):
            output_tensor = module(input_tensor)
        ```
        :param modules: iterable of modules. used to call .parameters() to freeze gradients.
        """
        # Materialise once: ``modules`` may be a one-shot iterator, and it is
        # walked again on enter and on exit.
        self.modules = list(modules)
        self.param_states = [p.requires_grad for p in get_parameters(self.modules)]

    def __enter__(self):
        for param in get_parameters(self.modules):
            param.requires_grad = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Put back the flags recorded at construction time.
        for i, param in enumerate(get_parameters(self.modules)):
            param.requires_grad = self.param_states[i]