From 9a2e9f420b2b0281f02080eb1db0d6a46052ec6d Mon Sep 17 00:00:00 2001 From: VedantDave Date: Sat, 15 Apr 2023 15:54:09 +0200 Subject: [PATCH] Checking branch push --- DPI/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/DPI/utils.py b/DPI/utils.py index 9c1ef22..c4f2208 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -200,6 +200,7 @@ class ReplayBuffer: variable = getattr(buffer, variable) non_zero_indices = np.nonzero(buffer.episode_count)[0] variable = variable[non_zero_indices] + if obs: variable = variable.reshape(-1, self.args.episode_length, self.args.frame_stack*self.args.channels,