From 5caea7695a7dd80bff129ecdac83ad9a70a97931 Mon Sep 17 00:00:00 2001 From: VedantDave Date: Sun, 9 Apr 2023 18:22:12 +0200 Subject: [PATCH] Changing variable reshaping strategy --- DPI/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/DPI/utils.py b/DPI/utils.py index 91e7b2e..8c2cd14 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -161,11 +161,11 @@ class ReplayBuffer: non_zero_indices = np.nonzero(buffer.episode_count)[0] variable = variable[non_zero_indices] if obs: - variable = variable.reshape(self.args.episode_length, self.args.batch_size, - self.args.frame_stack*self.args.channels, - self.args.image_size,self.args.image_size) + variable = variable.reshape(self.args.batch_size, self.args.episode_length, + self.args.frame_stack*self.args.channels, + self.args.image_size,self.args.image_size).transpose(1, 0, 2, 3, 4) else: - variable = variable.reshape(self.args.episode_length, self.args.batch_size,-1) + variable = variable.reshape(self.args.batch_size, self.args.episode_length, -1).transpose(1, 0, 2) return variable def transform_grouped_steps(self, variable):