Grouping for actions too

This commit is contained in:
Vedant Dave 2023-03-27 19:22:47 +02:00
parent 38cc645253
commit a1fe81f018

View File

@ -156,14 +156,16 @@ class ReplayBuffer:
obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
return obs,acs,rews,terms
def group_steps(self, buffer, variable):
def group_steps(self, buffer, variable, obs=True):
variable = getattr(buffer, variable)
non_zero_indices = np.nonzero(buffer.episode_count)[0]
variable = variable[non_zero_indices]
variable = variable.reshape(self.args.episode_length, self.args.batch_size,
self.args.frame_stack*self.args.channels,
self.args.image_size,self.args.image_size)
if obs:
variable = variable.reshape(self.args.episode_length, self.args.batch_size,
self.args.frame_stack*self.args.channels,
self.args.image_size,self.args.image_size)
else:
variable = variable.reshape(self.args.episode_length, self.args.batch_size,-1)
return variable
def transform_grouped_steps(self, variable):
@ -227,7 +229,7 @@ class CorruptVideos:
Check if a video file is corrupt.
Args:
filepath (str): Path to the video file.
dir_path (str): Path to the video file.
Returns:
bool: True if the video is corrupt, False otherwise.