Changing variable reshaping strategy

This commit is contained in:
Vedant Dave 2023-04-09 18:22:12 +02:00
parent ada3cadf0c
commit 5caea7695a

View File

@ -161,11 +161,11 @@ class ReplayBuffer:
non_zero_indices = np.nonzero(buffer.episode_count)[0] non_zero_indices = np.nonzero(buffer.episode_count)[0]
variable = variable[non_zero_indices] variable = variable[non_zero_indices]
if obs: if obs:
variable = variable.reshape(self.args.episode_length, self.args.batch_size, variable = variable.reshape(self.args.batch_size, self.args.episode_length,
self.args.frame_stack*self.args.channels, self.args.frame_stack*self.args.channels,
self.args.image_size,self.args.image_size) self.args.image_size,self.args.image_size).transpose(1, 0, 2, 3, 4)
else: else:
variable = variable.reshape(self.args.episode_length, self.args.batch_size,-1) variable = variable.reshape(self.args.batch_size, self.args.episode_length, -1).transpose(1, 0, 2)
return variable return variable
def transform_grouped_steps(self, variable): def transform_grouped_steps(self, variable):