import os
import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import gym
import dmc2gym
from PIL import Image


class eval_mode(object):
    """Context manager that puts models in eval mode and restores their previous mode on exit."""

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        return False


def soft_update_params(net, target_net, tau):
    """Polyak-average the parameters of `net` into `target_net`."""
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(
            tau * param.data + (1 - tau) * target_param.data
        )


def set_seed_everywhere(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def module_hash(module):
    """Cheap fingerprint of a module's weights: the sum over all state-dict tensors."""
    result = 0
    for tensor in module.state_dict().values():
        result += tensor.sum().item()
    return result


def make_dir(dir_path):
    os.makedirs(dir_path, exist_ok=True)
    return dir_path


def preprocess_obs(obs, bits=5):
    """Quantize an image observation to `bits` bits and add uniform
    dequantization noise, see https://arxiv.org/abs/1807.03039."""
    bins = 2**bits
    assert obs.dtype == torch.float32
    if bits < 8:
        obs = torch.floor(obs / 2**(8 - bits))
    obs = obs / bins
    obs = obs + torch.rand_like(obs) / bins
    obs = obs - 0.5
    return obs


class FrameStack(gym.Wrapper):
    """Stack the last `k` frames along the channel dimension."""

    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self._k = k
        self._frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=((shp[0] * k,) + shp[1:]),
            dtype=env.observation_space.dtype
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        obs = self.env.reset()
        for _ in range(self._k):
            self._frames.append(obs)
        return self._get_obs()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._frames.append(obs)
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        assert len(self._frames) == self._k
        return np.concatenate(list(self._frames), axis=0)
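
# Usage sketch (an assumption for illustration, not part of the original
# module): `agent.actor`, `agent.critic`, and `agent.critic_target` are
# hypothetical attribute names. `eval_mode` temporarily switches modules out
# of training mode (affecting e.g. dropout and batch norm) and restores their
# previous state on exit; `soft_update_params` is the Polyak target-network
# update used by SAC-style agents.
#
#     with eval_mode(agent.actor, agent.critic):
#         action = agent.actor(obs)
#     soft_update_params(agent.critic, agent.critic_target, tau=0.005)
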

class ReplayBuffer:
    def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
        self.size = size
        self.obs_shape = obs_shape
        self.action_size = action_size
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.idx = 0
        self.full = False
        self.args = args
        self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
        self.actions = np.empty((size, action_size), dtype=np.float32)
        self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
        self.rewards = np.empty((size,), dtype=np.float32)
        self.episode_count = np.zeros((size,), dtype=np.uint8)
        self.terminals = np.empty((size,), dtype=np.float32)
        self.steps, self.episodes = 0, 0

    def add(self, obs, ac, next_obs, reward, episode_count, done):
        self.observations[self.idx] = obs
        self.actions[self.idx] = ac
        self.next_observations[self.idx] = next_obs
        self.rewards[self.idx] = reward
        self.episode_count[self.idx] = episode_count
        self.terminals[self.idx] = done
        self.idx = (self.idx + 1) % self.size
        self.full = self.full or self.idx == 0
        self.steps += 1
        self.episodes = self.episodes + (1 if done else 0)

    def _sample_idx(self, L):
        # Rejection-sample a window of L consecutive indices that does not
        # straddle the current write pointer.
        valid_idx = False
        while not valid_idx:
            idx = np.random.randint(0, self.size if self.full else self.idx - L)
            idxs = np.arange(idx, idx + L) % self.size
            valid_idx = self.idx not in idxs[1:]
        return idxs

    def _retrieve_batch(self, idxs, n, L):
        vec_idxs = idxs.transpose().reshape(-1)  # unroll indices
        observations = self.observations[vec_idxs]
        next_observations = self.next_observations[vec_idxs]
        return (observations.reshape(L, n, *observations.shape[1:]),
                self.actions[vec_idxs].reshape(L, n, -1),
                next_observations.reshape(L, n, *next_observations.shape[1:]),
                self.rewards[vec_idxs].reshape(L, n),
                self.terminals[vec_idxs].reshape(L, n))

    def sample(self):
        n = self.batch_size
        l = self.seq_len
        obs, acs, next_obs, rews, terms = self._retrieve_batch(
            np.asarray([self._sample_idx(l) for _ in range(n)]), n, l
        )
        return obs, acs, next_obs, rews, terms

    def group_steps(self, buffer, variable):
        # Reshape a flat buffer field into (episode_length, batch_size, C, H, W),
        # keeping only entries that belong to a recorded episode.
        variable = getattr(buffer, variable)
        non_zero_indices = np.nonzero(buffer.episode_count)[0]
        variable = variable[non_zero_indices]
        variable = variable.reshape(
            self.args.episode_length, self.args.batch_size,
            self.args.frame_stack * self.args.channels,
            self.args.image_size, self.args.image_size
        )
        return variable

    def transform_grouped_steps(self, variable):
        # (T, B, C, H, W) -> (B*T, C, H, W)
        variable = variable.transpose((1, 0, 2, 3, 4))
        variable = variable.reshape(
            self.args.batch_size * self.args.episode_length,
            self.args.frame_stack * self.args.channels,
            self.args.image_size, self.args.image_size
        )
        return variable


def make_env(args):
    env = dmc2gym.make(
        domain_name=args.domain_name,
        task_name=args.task_name,
        resource_files=args.resource_files,
        img_source=args.img_source,
        total_frames=args.total_frames,
        seed=args.seed,
        visualize_reward=False,
        from_pixels=(args.encoder_type == 'pixel'),
        height=args.image_size,
        width=args.image_size,
        frame_skip=args.action_repeat
    )
    return env


def save_image(array, filename):
    array = array.transpose(1, 2, 0)  # CHW -> HWC
    array = (array * 255).astype(np.uint8)
    image = Image.fromarray(array)
    image.save(filename)
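
# Minimal smoke-test sketch (an assumption, not part of the original training
# pipeline): fill the replay buffer with random frames and draw one batch.
# The shape (9, 84, 84) stands for 3 stacked RGB frames at 84x84 and is
# illustrative only; `_Args` is a hypothetical stand-in for the argparse
# namespace, sufficient here because the args-dependent helpers
# (group_steps, make_env) are not exercised.
if __name__ == '__main__':
    class _Args:
        pass

    buf = ReplayBuffer(size=100, obs_shape=(9, 84, 84), action_size=6,
                       seq_len=10, batch_size=4, args=_Args())
    for _ in range(50):
        obs = np.random.randint(0, 256, size=(9, 84, 84), dtype=np.uint8)
        next_obs = np.random.randint(0, 256, size=(9, 84, 84), dtype=np.uint8)
        action = np.random.uniform(-1.0, 1.0, size=6).astype(np.float32)
        buf.add(obs, action, next_obs, reward=0.0, episode_count=1, done=False)

    obs, acs, next_obs, rews, terms = buf.sample()
    # Expected: (10, 4, 9, 84, 84) (10, 4, 6) (10, 4)
    print(obs.shape, acs.shape, rews.shape)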