diff --git a/DPI/utils.py b/DPI/utils.py index 9892ed8..7867eeb 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -156,14 +156,16 @@ class ReplayBuffer: obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l) return obs,acs,rews,terms - def group_steps(self, buffer, variable): + def group_steps(self, buffer, variable, obs=True): variable = getattr(buffer, variable) non_zero_indices = np.nonzero(buffer.episode_count)[0] variable = variable[non_zero_indices] - - variable = variable.reshape(self.args.episode_length, self.args.batch_size, - self.args.frame_stack*self.args.channels, - self.args.image_size,self.args.image_size) + if obs: + variable = variable.reshape(self.args.episode_length, self.args.batch_size, + self.args.frame_stack*self.args.channels, + self.args.image_size,self.args.image_size) + else: + variable = variable.reshape(self.args.episode_length, self.args.batch_size,-1) return variable def transform_grouped_steps(self, variable): @@ -227,7 +229,7 @@ class CorruptVideos: Check if a video file is corrupt. Args: - filepath (str): Path to the video file. + dir_path (str): Path to the video file. Returns: bool: True if the video is corrupt, False otherwise.