import os
import random
from collections import deque

import gym
import numpy as np
import torch
import torch.nn as nn


class eval_mode(object):
    """Context manager that switches modules to eval mode and restores them.

    On entry, each module's current ``training`` flag is recorded and the
    module is put into ``train(False)``. On exit, each module is restored to
    its recorded flag. Exceptions raised in the ``with`` body are not
    suppressed.

    Args:
        *models: objects exposing a ``training`` attribute and a
            ``train(mode)`` method (e.g. ``torch.nn.Module``).
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        # Returning False lets any exception from the with-body propagate.
        return False


def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    For each parameter pair: ``target = tau * param + (1 - tau) * target``.

    Args:
        net: source module.
        target_net: module whose parameters are overwritten in place.
        tau: interpolation weight given to ``net``'s parameters.

    NOTE(review): parameters are paired positionally with ``zip``, so this
    assumes both modules share the same architecture; a mismatch in parameter
    count is silently truncated rather than reported.
    """
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(
            tau * param.data + (1 - tau) * target_param.data
        )


def set_seed_everywhere(seed):
    """Seed torch (and all CUDA devices, if available), NumPy and ``random``."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def module_hash(module):
    """Return the sum of every tensor in ``module.state_dict()``.

    This is a cheap fingerprint for spotting changed weights. It is not a
    real hash: different weights can produce the same sum.
    """
    result = 0
    for tensor in module.state_dict().values():
        result += tensor.sum().item()
    return result


def make_dir(dir_path):
    """Create the single directory ``dir_path`` and return the path.

    Any ``OSError`` from ``os.mkdir`` is ignored. That includes "already
    exists", but also e.g. a missing parent directory or a permission
    error, so the returned path is not guaranteed to exist.
    """
    try:
        os.mkdir(dir_path)
    except OSError:
        pass
    return dir_path


def preprocess_obs(obs, bits=5):
    """Preprocessing image, see https://arxiv.org/abs/1807.03039.

    The steps are:

    - When ``bits < 8``, reduce the bit depth via
      ``floor(obs / 2**(8 - bits))``.
    - Scale by ``1 / 2**bits``.
    - Add uniform noise in ``[0, 1 / 2**bits)``.
    - Shift by ``-0.5``.

    Presumably ``obs`` holds raw 8-bit pixel values in ``[0, 255]`` stored
    as float32 (confirm against callers).

    Args:
        obs: float32 tensor.
        bits: target bit depth.

    Returns:
        A new tensor of the same shape as ``obs``.
    """
    bins = 2**bits
    assert obs.dtype == torch.float32
    if bits < 8:
        obs = torch.floor(obs / 2**(8 - bits))
    obs = obs / bins
    obs = obs + torch.rand_like(obs) / bins
    obs = obs - 0.5
    return obs


class ReplayBuffer(object):
    """Buffer to store environment transitions.

    Each slot holds two consecutive steps:

    - the previous step: ``last_obs``, ``last_action``, ``last_reward``,
      ``last_done``;
    - the current step: ``curr_obs``, ``action``, ``reward``, ``next_obs``,
      ``done``.

    Done flags are stored inverted, as float ``not_done`` values. The buffer
    is circular: once ``capacity`` transitions have been added, new ones
    overwrite the oldest.
    """

    def __init__(self, obs_shape, action_shape, capacity, batch_size, device):
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device

        # the proprioceptive obs is stored as float32, pixels obs as uint8
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        self.last_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.curr_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.last_actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.last_rewards = np.empty((capacity, 1), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.last_not_dones = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)

        self.idx = 0        # next slot to write
        self.last_save = 0  # first slot not yet written to disk by save()
        self.full = False   # True once idx has wrapped around at least once

    def add(self, last_obs, last_action, last_reward, curr_obs, last_done,
            action, reward, next_obs, done):
        """Copy one transition into slot ``idx``, then advance ``idx`` circularly."""
        np.copyto(self.last_obses[self.idx], last_obs)
        np.copyto(self.last_actions[self.idx], last_action)
        np.copyto(self.last_rewards[self.idx], last_reward)
        np.copyto(self.curr_obses[self.idx], curr_obs)
        np.copyto(self.last_not_dones[self.idx], not last_done)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        np.copyto(self.not_dones[self.idx], not done)

        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def sample(self):
        """Sample ``batch_size`` transitions uniformly, with replacement.

        Indices are drawn from ``[0, capacity)`` once the buffer is full,
        otherwise from ``[0, idx)``. Observation arrays are cast to float;
        all tensors are placed on ``self.device``.

        Returns:
            A 9-tuple ``(last_obses, last_actions, last_rewards, curr_obses,
            last_not_dones, actions, rewards, next_obses, not_dones)``.

        Raises:
            ValueError: from ``np.random.randint`` if the buffer is empty.
        """
        idxs = np.random.randint(
            0, self.capacity if self.full else self.idx, size=self.batch_size
        )

        last_obses = torch.as_tensor(self.last_obses[idxs], device=self.device).float()
        last_actions = torch.as_tensor(self.last_actions[idxs], device=self.device)
        last_rewards = torch.as_tensor(self.last_rewards[idxs], device=self.device)
        curr_obses = torch.as_tensor(self.curr_obses[idxs], device=self.device).float()
        last_not_dones = torch.as_tensor(self.last_not_dones[idxs], device=self.device)
        actions = torch.as_tensor(self.actions[idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
        next_obses = torch.as_tensor(self.next_obses[idxs], device=self.device).float()
        not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)

        return (last_obses, last_actions, last_rewards, curr_obses,
                last_not_dones, actions, rewards, next_obses, not_dones)

    def save(self, save_dir):
        """Write transitions added since the previous save to ``save_dir``.

        The slice ``[last_save:idx]`` of every storage array is written with
        ``torch.save``. The data goes to ``<last_save>_<idx>.pt`` as a
        9-element list, in the order listed below; ``load`` relies on that
        order. Nothing is written if ``idx == last_save``.

        NOTE(review): if ``idx`` has wrapped around since the previous save
        (``idx < last_save``), the slices are empty and those transitions
        are not persisted.
        """
        if self.idx == self.last_save:
            return
        path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
        payload = [
            self.last_obses[self.last_save:self.idx],      # 0
            self.last_actions[self.last_save:self.idx],    # 1
            self.last_rewards[self.last_save:self.idx],    # 2
            self.curr_obses[self.last_save:self.idx],      # 3
            self.last_not_dones[self.last_save:self.idx],  # 4
            self.actions[self.last_save:self.idx],         # 5
            self.rewards[self.last_save:self.idx],         # 6
            self.next_obses[self.last_save:self.idx],      # 7
            self.not_dones[self.last_save:self.idx],       # 8
        ]
        self.last_save = self.idx
        torch.save(payload, path)

    def load(self, save_dir):
        """Restore chunks previously written by ``save`` from ``save_dir``.

        Every file name must be ``<start>_<end>.pt``. Chunks are loaded in
        ascending ``start`` order, and each one must begin exactly at the
        buffer's current ``idx``.

        Raises:
            AssertionError: if the chunks are not contiguous from ``idx``.
            ValueError: if ``save_dir`` holds a file not named like a chunk.
        """
        chunks = sorted(os.listdir(save_dir), key=lambda x: int(x.split('_')[0]))
        for chunk in chunks:
            start, end = [int(x) for x in chunk.split('.')[0].split('_')]
            path = os.path.join(save_dir, chunk)
            # NOTE(review): torch.load unpickles the file; only load
            # directories you trust.
            payload = torch.load(path)
            assert self.idx == start
            # Indices must mirror the payload order written by save().
            self.last_obses[start:end] = payload[0]
            self.last_actions[start:end] = payload[1]
            self.last_rewards[start:end] = payload[2]
            self.curr_obses[start:end] = payload[3]
            self.last_not_dones[start:end] = payload[4]
            self.actions[start:end] = payload[5]
            self.rewards[start:end] = payload[6]
            self.next_obses[start:end] = payload[7]
            self.not_dones[start:end] = payload[8]
            self.idx = end
            # This data is already on disk. Without advancing last_save, the
            # next save() would rewrite [0:idx] as a chunk overlapping the
            # existing ones, and a later load() would fail its assert.
            self.last_save = end


class FrameStack(gym.Wrapper):
    """Gym wrapper returning the last ``k`` observations concatenated on axis 0.

    On reset, the stack is filled with ``k`` copies of the initial
    observation. This wrapper uses the 4-tuple ``(obs, reward, done, info)``
    step API.

    The wrapped env must expose ``_max_episode_steps``.
    """

    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self._k = k
        self._frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        # NOTE(review): bounds are hard-coded to [0, 1] rather than copied
        # from the wrapped env's observation_space - confirm this is intended.
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=((shp[0] * k,) + shp[1:]),
            dtype=env.observation_space.dtype
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        obs = self.env.reset()
        for _ in range(self._k):
            self._frames.append(obs)
        return self._get_obs()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._frames.append(obs)
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        assert len(self._frames) == self._k
        return np.concatenate(list(self._frames), axis=0)