diff --git a/DPI/utils.py b/DPI/utils.py index a237e6f..83e9b9d 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -120,15 +120,17 @@ class ReplayBuffer: self.args = args self.observations = np.empty((size, *obs_shape), dtype=np.uint8) self.actions = np.empty((size, action_size), dtype=np.float32) + self.rewards = np.empty((size,1), dtype=np.float32) self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8) self.episode_count = np.zeros((size,), dtype=np.uint8) self.terminals = np.empty((size,), dtype=np.float32) self.steps, self.episodes = 0, 0 - def add(self, obs, ac, next_obs, episode_count, done): + def add(self, obs, ac, next_obs, rew, episode_count, done): self.observations[self.idx] = obs self.actions[self.idx] = ac self.next_observations[self.idx] = next_obs + self.rewards[self.idx] = rew self.episode_count[self.idx] = episode_count self.terminals[self.idx] = done self.idx = (self.idx + 1) % self.size