diff --git a/DPI/utils.py b/DPI/utils.py index 9c1ef22..c4f2208 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -200,6 +200,7 @@ class ReplayBuffer: variable = getattr(buffer, variable) non_zero_indices = np.nonzero(buffer.episode_count)[0] variable = variable[non_zero_indices] + if obs: variable = variable.reshape(-1, self.args.episode_length, self.args.frame_stack*self.args.channels,