Changing variable reshaping strategy
This commit is contained in:
parent
ada3cadf0c
commit
5caea7695a
@ -161,11 +161,11 @@ class ReplayBuffer:
|
|||||||
non_zero_indices = np.nonzero(buffer.episode_count)[0]
|
non_zero_indices = np.nonzero(buffer.episode_count)[0]
|
||||||
variable = variable[non_zero_indices]
|
variable = variable[non_zero_indices]
|
||||||
if obs:
|
if obs:
|
||||||
variable = variable.reshape(self.args.episode_length, self.args.batch_size,
|
variable = variable.reshape(self.args.batch_size, self.args.episode_length,
|
||||||
self.args.frame_stack*self.args.channels,
|
self.args.frame_stack*self.args.channels,
|
||||||
self.args.image_size,self.args.image_size)
|
self.args.image_size,self.args.image_size).transpose(1, 0, 2, 3, 4)
|
||||||
else:
|
else:
|
||||||
variable = variable.reshape(self.args.episode_length, self.args.batch_size,-1)
|
variable = variable.reshape(self.args.batch_size, self.args.episode_length, -1).transpose(1, 0, 2)
|
||||||
return variable
|
return variable
|
||||||
|
|
||||||
def transform_grouped_steps(self, variable):
|
def transform_grouped_steps(self, variable):
|
||||||
|
Loading…
Reference in New Issue
Block a user