Adding Rewards

This commit is contained in:
Vedant Dave 2023-04-12 09:33:19 +02:00
parent ac714e3495
commit cc48b0b0f8

View File

@@ -120,15 +120,17 @@ class ReplayBuffer:
self.args = args self.args = args
self.observations = np.empty((size, *obs_shape), dtype=np.uint8) self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.actions = np.empty((size, action_size), dtype=np.float32) self.actions = np.empty((size, action_size), dtype=np.float32)
self.rewards = np.empty((size,1), dtype=np.float32)
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8) self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.episode_count = np.zeros((size,), dtype=np.uint8) self.episode_count = np.zeros((size,), dtype=np.uint8)
self.terminals = np.empty((size,), dtype=np.float32) self.terminals = np.empty((size,), dtype=np.float32)
self.steps, self.episodes = 0, 0 self.steps, self.episodes = 0, 0
def add(self, obs, ac, next_obs, episode_count, done): def add(self, obs, ac, next_obs, rew, episode_count, done):
self.observations[self.idx] = obs self.observations[self.idx] = obs
self.actions[self.idx] = ac self.actions[self.idx] = ac
self.next_observations[self.idx] = next_obs self.next_observations[self.idx] = next_obs
self.rewards[self.idx] = rew
self.episode_count[self.idx] = episode_count self.episode_count[self.idx] = episode_count
self.terminals[self.idx] = done self.terminals[self.idx] = done
self.idx = (self.idx + 1) % self.size self.idx = (self.idx + 1) % self.size