Adding Rewards
This commit is contained in:
parent
ac714e3495
commit
cc48b0b0f8
@@ -120,15 +120,17 @@ class ReplayBuffer:
|
|||||||
self.args = args
|
self.args = args
|
||||||
self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
|
self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
|
||||||
self.actions = np.empty((size, action_size), dtype=np.float32)
|
self.actions = np.empty((size, action_size), dtype=np.float32)
|
||||||
|
self.rewards = np.empty((size,1), dtype=np.float32)
|
||||||
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
|
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
|
||||||
self.episode_count = np.zeros((size,), dtype=np.uint8)
|
self.episode_count = np.zeros((size,), dtype=np.uint8)
|
||||||
self.terminals = np.empty((size,), dtype=np.float32)
|
self.terminals = np.empty((size,), dtype=np.float32)
|
||||||
self.steps, self.episodes = 0, 0
|
self.steps, self.episodes = 0, 0
|
||||||
|
|
||||||
def add(self, obs, ac, next_obs, episode_count, done):
|
def add(self, obs, ac, next_obs, rew, episode_count, done):
|
||||||
self.observations[self.idx] = obs
|
self.observations[self.idx] = obs
|
||||||
self.actions[self.idx] = ac
|
self.actions[self.idx] = ac
|
||||||
self.next_observations[self.idx] = next_obs
|
self.next_observations[self.idx] = next_obs
|
||||||
|
self.rewards[self.idx] = rew
|
||||||
self.episode_count[self.idx] = episode_count
|
self.episode_count[self.idx] = episode_count
|
||||||
self.terminals[self.idx] = done
|
self.terminals[self.idx] = done
|
||||||
self.idx = (self.idx + 1) % self.size
|
self.idx = (self.idx + 1) % self.size
|
||||||
|
Loading…
Reference in New Issue
Block a user