# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import numpy as np
from collections import deque

import torch
import torch.nn as nn
import gym
import dmc2gym
import cv2
from PIL import Image


class eval_mode(object):
    """Context manager that switches the given modules to eval mode.

    On exit, each module's previous ``training`` flag is restored.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        return False  # never suppress exceptions raised inside the block


def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * param + (1 - tau) * target``.
    """
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(
            tau * param.data + (1 - tau) * target_param.data
        )


def set_seed_everywhere(seed):
    """Seed torch (and CUDA when available), numpy and Python's ``random``."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def module_hash(module):
    """Return the sum of every tensor in ``module.state_dict()``.

    This is a cheap fingerprint for spotting parameter changes.
    """
    result = 0
    for tensor in module.state_dict().values():
        result += tensor.sum().item()
    return result


def make_dir(dir_path):
    """Best-effort ``os.mkdir`` that ignores any OSError.

    The error may be "already exists", a missing parent, and so on.
    Returns ``dir_path`` unchanged.
    """
    try:
        os.mkdir(dir_path)
    except OSError:
        pass
    return dir_path


def preprocess_obs(obs, bits=5):
    """Preprocessing image, see https://arxiv.org/abs/1807.03039.

    Quantises a float32 tensor to ``bits`` bits and scales it by ``1 / 2**bits``.
    It then adds uniform dequantisation noise and shifts by -0.5.
    """
    bins = 2**bits
    assert obs.dtype == torch.float32
    if bits < 8:
        # Assumes obs holds 8-bit pixel values (0-255) — drop low-order bits.
        obs = torch.floor(obs / 2**(8 - bits))
    obs = obs / bins
    obs = obs + torch.rand_like(obs) / bins
    obs = obs - 0.5
    return obs


class FrameStack(gym.Wrapper):
    """Gym wrapper that returns the last ``k`` observations concatenated on axis 0."""

    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self._k = k
        self._frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=((shp[0] * k,) + shp[1:]),
            dtype=env.observation_space.dtype
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        obs = self.env.reset()
        # Fill the whole stack with the first frame so its length is always k.
        for _ in range(self._k):
            self._frames.append(obs)
        return self._get_obs()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._frames.append(obs)
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        assert len(self._frames) == self._k
        return np.concatenate(list(self._frames), axis=0)


class ReplayBuffer:
    """Fixed-size circular transition buffer that samples sequences of consecutive steps."""

    def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
        self.size = size
        self.obs_shape = obs_shape
        self.action_size = action_size
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.idx = 0  # next write position
        self.full = False  # True once the buffer has wrapped around
        self.args = args
        self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
        self.actions = np.empty((size, action_size), dtype=np.float32)
        self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
        # NOTE(review): uint8 only represents counts 0-255; confirm episode
        # counts never exceed that range.
        self.episode_count = np.zeros((size,), dtype=np.uint8)
        # _retrieve_batch() reads rewards, so they must be stored.
        # Zero-initialised, so steps added without a reward read back as 0.0.
        self.rewards = np.zeros((size,), dtype=np.float32)
        self.terminals = np.empty((size,), dtype=np.float32)
        self.steps, self.episodes = 0, 0

    def add(self, obs, ac, next_obs, episode_count, done, reward=0.0):
        """Write one transition at the current position and advance it (wrapping).

        ``reward`` is optional so existing callers keep working. It defaults to 0.0.
        """
        self.observations[self.idx] = obs
        self.actions[self.idx] = ac
        self.next_observations[self.idx] = next_obs
        self.episode_count[self.idx] = episode_count
        self.terminals[self.idx] = done
        self.rewards[self.idx] = reward
        self.idx = (self.idx + 1) % self.size
        self.full = self.full or self.idx == 0
        self.steps += 1
        self.episodes = self.episodes + (1 if done else 0)

    def _sample_idx(self, L):
        """Return ``L`` consecutive (wrapping) indices that do not cross the write head.

        The write head is the current position ``self.idx``.
        Raises ValueError (from numpy) while fewer than ``L + 1`` steps are stored.
        """
        valid_idx = False
        while not valid_idx:
            idx = np.random.randint(0, self.size if self.full else self.idx - L)
            idxs = np.arange(idx, idx + L) % self.size
            # The write head may start the sequence but must not lie inside it.
            # Otherwise the newest and oldest data would be spliced together.
            valid_idx = self.idx not in idxs[1:]
        return idxs

    def _retrieve_batch(self, idxs, n, L):
        """Gather the transitions for an ``(n, L)`` index matrix.

        Returns ``(observations, actions, next_observations, rewards, terminals)``.
        Each is time-major, with leading dimensions ``(L, n)``.
        """
        vec_idxs = idxs.transpose().reshape(-1)  # (n, L) -> (L, n) -> flat
        observations = self.observations[vec_idxs]
        next_observations = self.next_observations[vec_idxs]
        return (
            observations.reshape(L, n, *observations.shape[1:]),
            self.actions[vec_idxs].reshape(L, n, -1),
            next_observations.reshape(L, n, *next_observations.shape[1:]),
            self.rewards[vec_idxs].reshape(L, n),
            self.terminals[vec_idxs].reshape(L, n),
        )

    def sample(self):
        """Sample ``batch_size`` sequences of ``seq_len`` steps.

        Returns ``(obs, actions, rewards, terminals)``, each with leading dims
        ``(seq_len, batch_size)``.
        """
        n = self.batch_size
        seq_len = self.seq_len
        idxs = np.asarray([self._sample_idx(seq_len) for _ in range(n)])
        obs, acs, _next_obs, rews, terms = self._retrieve_batch(idxs, n, seq_len)
        return obs, acs, rews, terms

    def group_steps(self, buffer, variable, obs=True):
        """Fetch attribute ``variable`` from ``buffer`` and reshape it.

        Only rows whose ``episode_count`` is non-zero are kept.
        The result has leading dims ``(episode_length, batch_size)``.
        """
        variable = getattr(buffer, variable)
        non_zero_indices = np.nonzero(buffer.episode_count)[0]
        variable = variable[non_zero_indices]
        # NOTE(review): this reshape treats the stored rows as step-major
        # (all envs' step t, then step t+1, ...). Confirm the storage order.
        if obs:
            variable = variable.reshape(
                self.args.episode_length,
                self.args.batch_size,
                self.args.frame_stack * self.args.channels,
                self.args.image_size,
                self.args.image_size,
            )
        else:
            variable = variable.reshape(self.args.episode_length, self.args.batch_size, -1)
        return variable

    def transform_grouped_steps(self, variable):
        """Swap the first two axes of a grouped observation array and flatten them.

        Produces shape ``(batch_size * episode_length, C, H, W)``.
        """
        variable = variable.transpose((1, 0, 2, 3, 4))
        variable = variable.reshape(
            self.args.batch_size * self.args.episode_length,
            self.args.frame_stack * self.args.channels,
            self.args.image_size,
            self.args.image_size,
        )
        return variable


def make_env(args):
    """Build a dmc2gym environment from the parsed command-line ``args``."""
    # For making ground plane transparent, change rgba to (0, 0, 0, 0) in local_dm_control_suite/{domain_name}.xml,
    # else change to (0.5, 0.5, 0.5, 1.0) for default ground plane color
    # https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-geom
    env = dmc2gym.make(
        domain_name=args.domain_name,
        task_name=args.task_name,
        resource_files=args.resource_files,
        img_source=args.img_source,
        total_frames=args.total_frames,
        seed=args.seed,
        visualize_reward=False,
        from_pixels=(args.encoder_type == 'pixel'),
        height=args.image_size,
        width=args.image_size,
        frame_skip=args.action_repeat,
        video_recording=args.save_video,
        video_recording_dir=args.work_dir,
    )
    return env


def save_image(array, filename):
    """Save a channel-first array as an image file.

    The array is assumed to hold floats in [0, 1].
    """
    array = array.transpose(1, 2, 0)
    array = (array * 255).astype(np.uint8)
    image = Image.fromarray(array)
    image.save(filename)


def video_from_array(arr, high_noise, filename):
    """Save a single-channel video from a numpy array of shape (T, H, W, 1) to ``filename``.

    Frames are cast with ``np.uint8``, so values are expected in [0, 255].
    Arrays whose last dimension is not 1 are ignored and nothing is written.
    ``high_noise`` is currently unused.

    Example: video_from_array(np.random.randint(0, 256, (100, 64, 64, 1)), False, 'test.mp4')
    """
    if arr.shape[-1] != 1:
        return
    height, width, _channels = arr.shape[1:]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(filename, fourcc, 30.0, (width, height))
    try:
        for frame in arr:
            frame = np.uint8(frame)
            out.write(cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))
    finally:
        # Release the writer even if a frame conversion/write raises.
        out.release()


class CorruptVideos:
    """Find (and optionally delete) unreadable ``.mp4`` files in a directory."""

    def __init__(self, dir_path):
        self.dir_path = dir_path

    def _is_video_corrupt(self, filepath):
        """Check if a video file is corrupt.

        Args:
            filepath (str): Path to the video file.

        Returns:
            bool: True if the file cannot be opened or its first frame cannot
            be read, False otherwise.
        """
        cap = cv2.VideoCapture(filepath)
        try:
            if not cap.isOpened():
                return True
            ret, _frame = cap.read()
            return not ret
        finally:
            # Release on every path so early returns do not leak the handle.
            cap.release()

    def _delete_corrupt_video(self, filepath):
        os.remove(filepath)

    def is_video_corrupt(self, delete=False):
        """Print every corrupt ``.mp4`` in ``dir_path``, optionally deleting it.

        Returns:
            list[str]: Paths of the corrupt videos found.
        """
        corrupt = []
        for filename in os.listdir(self.dir_path):
            filepath = os.path.join(self.dir_path, filename)
            if not filepath.endswith(".mp4"):
                continue
            if self._is_video_corrupt(filepath):
                corrupt.append(filepath)
                print(f"{filepath} is corrupt.")
                if delete:
                    self._delete_corrupt_video(filepath)
                    print(f"Deleted {filepath}")
        return corrupt