Commit cc48b0b0f8 — "Adding Rewards" (parent commit: ac714e3495).
@ -120,15 +120,17 @@ class ReplayBuffer:
|
||||
self.args = args
|
||||
self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
|
||||
self.actions = np.empty((size, action_size), dtype=np.float32)
|
||||
self.rewards = np.empty((size,1), dtype=np.float32)
|
||||
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
|
||||
self.episode_count = np.zeros((size,), dtype=np.uint8)
|
||||
self.terminals = np.empty((size,), dtype=np.float32)
|
||||
self.steps, self.episodes = 0, 0
|
||||
|
||||
def add(self, obs, ac, next_obs, rew, episode_count, done):
    """Store one transition at the current write position, then advance it.

    The buffer is circular: once the cursor reaches the end it wraps back
    to index 0 and old transitions are overwritten.

    Args:
        obs: observation at this step (written into ``self.observations``).
        ac: action taken (written into ``self.actions``).
        next_obs: observation after the action (``self.next_observations``).
        rew: scalar reward (``self.rewards``).
        episode_count: episode counter for this step (``self.episode_count``).
        done: terminal flag; stored as float in ``self.terminals``.
    """
    # NOTE(review): self.idx and self.size are assumed to be set elsewhere
    # (not visible in this diff hunk) — verify they exist in __init__.
    cursor = self.idx
    # Independent slot writes; order among them does not matter.
    self.observations[cursor] = obs
    self.next_observations[cursor] = next_obs
    self.actions[cursor] = ac
    self.rewards[cursor] = rew
    self.episode_count[cursor] = episode_count
    self.terminals[cursor] = done
    # Advance the write cursor, wrapping around at the buffer capacity.
    self.idx = (cursor + 1) % self.size
|
Loading…
Reference in New Issue
Block a user