import os import random import pickle import numpy as np from collections import deque import torch import torch.nn as nn from torch.utils.tensorboard import SummaryWriter import gym import dmc2gym import cv2 from PIL import Image from typing import Iterable class eval_mode(object): def __init__(self, *models): self.models = models def __enter__(self): self.prev_states = [] for model in self.models: self.prev_states.append(model.training) model.train(False) def __exit__(self, *args): for model, state in zip(self.models, self.prev_states): model.train(state) return False def soft_update_params(net, target_net, tau): for param, target_param in zip(net.parameters(), target_net.parameters()): target_param.data.copy_( tau * param.data + (1 - tau) * target_param.data ) def set_seed_everywhere(seed): torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) def module_hash(module): result = 0 for tensor in module.state_dict().values(): result += tensor.sum().item() return result def make_dir(dir_path): try: os.mkdir(dir_path) except OSError: pass return dir_path def preprocess_obs(obs, bits=5): """Preprocessing image, see https://arxiv.org/abs/1807.03039.""" bins = 2**bits assert obs.dtype == torch.float32 if bits < 8: obs = torch.floor(obs / 2**(8 - bits)) obs = obs / bins obs = obs + torch.rand_like(obs) / bins obs = obs - 0.5 return obs class FrameStack(gym.Wrapper): def __init__(self, env, k): gym.Wrapper.__init__(self, env) self._k = k self._frames = deque([], maxlen=k) shp = env.observation_space.shape self.observation_space = gym.spaces.Box( low=0, high=1, shape=((shp[0] * k,) + shp[1:]), dtype=env.observation_space.dtype ) self._max_episode_steps = env._max_episode_steps def reset(self): obs = self.env.reset() for _ in range(self._k): self._frames.append(obs) return self._get_obs() def step(self, action): obs, reward, done, info = self.env.step(action) self._frames.append(obs) return self._get_obs(), reward, done, info def _get_obs(self): assert len(self._frames) == self._k return np.concatenate(list(self._frames), axis=0) class ActionRepeat: def __init__(self, env, amount): self._env = env self._amount = amount def __getattr__(self, name): return getattr(self._env, name) def step(self, action): done = False total_reward = 0 current_step = 0 while current_step < self._amount and not done: obs, reward, done, info = self._env.step(action) total_reward += reward current_step += 1 return obs, total_reward, done, info class NormalizeActions: def __init__(self, env): self._env = env self._mask = np.logical_and( np.isfinite(env.action_space.low), np.isfinite(env.action_space.high)) self._low = np.where(self._mask, env.action_space.low, -1) self._high = np.where(self._mask, env.action_space.high, 1) def __getattr__(self, name): return getattr(self._env, name) @property def action_space(self): low = np.where(self._mask, -np.ones_like(self._low), self._low) high = np.where(self._mask, np.ones_like(self._low), self._high) return gym.spaces.Box(low, high, dtype=np.float32) def step(self, action): original = (action + 1) / 2 * (self._high - self._low) + self._low original = np.where(self._mask, original, action) return self._env.step(original) class TimeLimit: def __init__(self, env, duration): self._env = env self._duration = duration self._step = None def __getattr__(self, name): return getattr(self._env, name) def step(self, action): assert self._step is not None, 'Must reset environment.' obs, reward, done, info = self._env.step(action) self._step += 1 if self._step >= self._duration: done = True if 'discount' not in info: info['discount'] = np.array(1.0).astype(np.float32) self._step = None return obs, reward, done, info def reset(self): self._step = 0 return self._env.reset() class ReplayBuffer: def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args): self.size = size self.obs_shape = obs_shape self.action_size = action_size self.seq_len = seq_len self.batch_size = batch_size self.idx = 0 self.full = False self.observations = np.empty((size, *obs_shape), dtype=np.uint8) self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8) self.actions = np.empty((size, action_size), dtype=np.float32) self.rewards = np.empty((size,), dtype=np.float32) self.terminals = np.empty((size,), dtype=np.float32) self.steps, self.episodes = 0, 0 self.episode_count = np.zeros((size,), dtype=np.int32) def add(self, obs, ac, next_obs, rew, done, episode_count): self.observations[self.idx] = obs self.next_observations[self.idx] = next_obs self.actions[self.idx] = ac self.rewards[self.idx] = rew self.terminals[self.idx] = done self.full = self.full or self.idx == 0 self.steps += 1 self.episodes = self.episodes + (1 if done else 0) self.episode_count[self.idx] = episode_count self.idx = (self.idx + 1) % self.size def _sample_idx(self, L): valid_idx = False while not valid_idx: idx = np.random.randint(0, self.size if self.full else self.idx - L) idxs = np.arange(idx, idx + L) % self.size valid_idx = not self.idx in idxs[1:] return idxs def _retrieve_batch(self, idxs, n, L): vec_idxs = idxs.transpose().reshape(-1) # Unroll indices observations = self.observations[vec_idxs] next_obs = self.next_observations[vec_idxs] obs = observations.reshape(L, n, *observations.shape[1:]) next_obs = next_obs.reshape(L, n, *next_obs.shape[1:]) acs = self.actions[vec_idxs].reshape(L, n, -1) rew = self.rewards[vec_idxs].reshape(L, n) term = self.terminals[vec_idxs].reshape(L, n) return obs, acs, next_obs, rew, term def sample(self): n = self.batch_size l = self.seq_len obs,acs,next_obs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l) return obs,acs,next_obs,rews,terms class ReplayBuffer1: def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args): self.size = size self.obs_shape = obs_shape self.action_size = action_size self.seq_len = seq_len self.batch_size = batch_size self.idx = 0 self.full = False self.args = args self.observations = np.empty((size, *obs_shape), dtype=np.uint8) self.actions = np.empty((size, action_size), dtype=np.float32) self.rewards = np.empty((size,1), dtype=np.float32) self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8) self.episode_count = np.zeros((size,), dtype=np.uint8) self.terminals = np.empty((size,), dtype=np.float32) self.steps, self.episodes = 0, 0 def add(self, obs, ac, next_obs, rew, episode_count, done): self.observations[self.idx] = obs self.actions[self.idx] = ac self.next_observations[self.idx] = next_obs self.rewards[self.idx] = rew self.episode_count[self.idx] = episode_count self.terminals[self.idx] = done self.idx = (self.idx + 1) % self.size self.full = self.full or self.idx == 0 self.steps += 1 self.episodes = self.episodes + (1 if done else 0) def _sample_idx(self, L): valid_idx = False while not valid_idx: idx = np.random.randint(0, self.size if self.full else self.idx - L) idxs = np.arange(idx, idx + L) % self.size valid_idx = not self.idx in idxs[1:] return idxs def _retrieve_batch(self, idxs, n, L): vec_idxs = idxs.transpose().reshape(-1) # Unroll indices observations = self.observations[vec_idxs] next_observations = self.next_observations[vec_idxs] return observations.reshape(L, n, *observations.shape[1:]), self.actions[vec_idxs].reshape(L, n, -1), observations.reshape(L, n, *next_observations.shape[1:]), \ self.rewards[vec_idxs].reshape(L, n), self.terminals[vec_idxs].reshape(L, n) def sample(self): n = self.batch_size l = self.seq_len obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l) return obs,acs,rews,terms def group_steps(self, buffer, variable, obs=True): variable = getattr(buffer, variable) non_zero_indices = np.nonzero(buffer.episode_count)[0] print(buffer.episode_count) variable = variable[non_zero_indices] print(variable.shape) exit() if obs: variable = variable.reshape(-1, self.args.episode_length, self.args.frame_stack*self.args.channels, self.args.image_size,self.args.image_size).transpose(1, 0, 2, 3, 4) else: variable = variable.reshape(variable.shape[0]//self.args.episode_length, self.args.episode_length, -1).transpose(1, 0, 2) return variable def transform_grouped_steps(self, variable): variable = variable.transpose((1, 0, 2, 3, 4)) variable = variable.reshape(self.args.batch_size*self.args.episode_length,self.args.frame_stack*self.args.channels, self.args.image_size,self.args.image_size) return variable def sample_random_idx(self, buffer_length, last=False): init = 0 if last else buffer_length - self.args.batch_size random_indices = random.sample(range(init, buffer_length), self.args.batch_size) return random_indices def group_and_sample_random_batch(self, buffer, variable_name, device, random_indices, is_obs=True, offset=0): if offset == 0: variable_tensor = torch.tensor(self.group_steps(buffer,variable_name, is_obs)).float()[:self.args.episode_length-1].to(device) else: variable_tensor = torch.tensor(self.group_steps(buffer,variable_name, is_obs)).float()[offset:].to(device) return variable_tensor[:,random_indices,:,:,:] if is_obs else variable_tensor[:,random_indices,:] def make_env(args): # For making ground plane transparent, change rgba to (0, 0, 0, 0) in local_dm_control_suite/{domain_name}.xml, # else change to (0.5, 0.5, 0.5, 1.0) for default ground plane color # https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-geom env = dmc2gym.make( domain_name=args.domain_name, task_name=args.task_name, resource_files=args.resource_files, img_source=args.img_source, total_frames=args.total_frames, seed=args.seed, visualize_reward=False, from_pixels=(args.encoder_type == 'pixel'), height=args.image_size, width=args.image_size, frame_skip=args.action_repeat, video_recording=args.save_video, video_recording_dir=args.work_dir, version=args.version, ) return env def shuffle_along_axis(a, axis): idx = np.random.rand(*a.shape).argsort(axis=axis) return np.take_along_axis(a,idx,axis=axis) def preprocess_obs(obs): obs = (obs/255.0) - 0.5 return obs def soft_update_params(net, target_net, tau): for param, target_param in zip(net.parameters(), target_net.parameters()): target_param.data.copy_( tau * param.detach().data + (1 - tau) * target_param.data ) def save_image(array, filename): array = array.transpose(1, 2, 0) array = ((array+0.5) * 255).astype(np.uint8) image = Image.fromarray(array) image.save(filename) def video_from_array(arr, high_noise, filename): """ Save a video from a numpy array of shape (T, H, W, C) Example: video_from_array(np.random.rand(100, 64, 64, 1), 'test.mp4') """ if arr.shape[-1] == 1: height, width, channels = arr.shape[1:] fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter('output.mp4', fourcc, 30.0, (width, height)) for i in range(arr.shape[0]): frame = arr[i] frame = np.uint8(frame) frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR) out.write(frame) out.release() class CorruptVideos: def __init__(self, dir_path): self.dir_path = dir_path def _is_video_corrupt(self,filepath): """ Check if a video file is corrupt. Args: dir_path (str): Path to the video file. Returns: bool: True if the video is corrupt, False otherwise. """ # Open the video file cap = cv2.VideoCapture(filepath) if not cap.isOpened(): return True ret, frame = cap.read() if not ret: return True cap.release() return False def _delete_corrupt_video(self, filepath): os.remove(filepath) def is_video_corrupt(self, delete=False): for filename in os.listdir(self.dir_path): filepath = os.path.join(self.dir_path, filename) if filepath.endswith(".mp4"): if self._is_video_corrupt(filepath): print(f"{filepath} is corrupt.") if delete: self._delete_corrupt_video(filepath) print(f"Deleted {filepath}") def get_parameters(modules: Iterable[nn.Module]): """ Given a list of torch modules, returns a list of their parameters. :param modules: iterable of modules :returns: a list of parameters """ model_parameters = [] for module in modules: model_parameters += list(module.parameters()) return model_parameters class FreezeParameters: def __init__(self, modules: Iterable[nn.Module]): """ Context manager to locally freeze gradients. In some cases with can speed up computation because gradients aren't calculated for these listed modules. example: ``` with FreezeParameters([module]): output_tensor = module(input_tensor) ``` :param modules: iterable of modules. used to call .parameters() to freeze gradients. """ self.modules = modules self.param_states = [p.requires_grad for p in get_parameters(self.modules)] def __enter__(self): for param in get_parameters(self.modules): param.requires_grad = False def __exit__(self, exc_type, exc_val, exc_tb): for i, param in enumerate(get_parameters(self.modules)): param.requires_grad = self.param_states[i] class Logger: def __init__(self, log_dir, n_logged_samples=10, summary_writer=None): self._log_dir = log_dir print('########################') print('logging outputs to ', log_dir) print('########################') self._n_logged_samples = n_logged_samples self._summ_writer = SummaryWriter(log_dir, flush_secs=1, max_queue=1) def log_scalar(self, scalar, name, step_): self._summ_writer.add_scalar('{}'.format(name), scalar, step_) def log_scalars(self, scalar_dict, step): for key, value in scalar_dict.items(): print('{} : {}'.format(key, value)) self.log_scalar(value, key, step) self.dump_scalars_to_pickle(scalar_dict, step) def log_videos(self, videos, step, max_videos_to_save=1, fps=20, video_title='video'): # max rollout length max_videos_to_save = np.min([max_videos_to_save, videos.shape[0]]) max_length = videos[0].shape[0] for i in range(max_videos_to_save): if videos[i].shape[0]>max_length: max_length = videos[i].shape[0] # pad rollouts to all be same length for i in range(max_videos_to_save): if videos[i].shape[0]