Grouping for actions too
This commit is contained in:
parent
38cc645253
commit
a1fe81f018
@ -156,14 +156,16 @@ class ReplayBuffer:
|
||||
obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
|
||||
return obs,acs,rews,terms
|
||||
|
||||
def group_steps(self, buffer, variable):
|
||||
def group_steps(self, buffer, variable, obs=True):
|
||||
variable = getattr(buffer, variable)
|
||||
non_zero_indices = np.nonzero(buffer.episode_count)[0]
|
||||
variable = variable[non_zero_indices]
|
||||
|
||||
if obs:
|
||||
variable = variable.reshape(self.args.episode_length, self.args.batch_size,
|
||||
self.args.frame_stack*self.args.channels,
|
||||
self.args.image_size,self.args.image_size)
|
||||
else:
|
||||
variable = variable.reshape(self.args.episode_length, self.args.batch_size,-1)
|
||||
return variable
|
||||
|
||||
def transform_grouped_steps(self, variable):
|
||||
@ -227,7 +229,7 @@ class CorruptVideos:
|
||||
Check if a video file is corrupt.
|
||||
|
||||
Args:
|
||||
filepath (str): Path to the video file.
|
||||
dir_path (str): Path to the video file.
|
||||
|
||||
Returns:
|
||||
bool: True if the video is corrupt, False otherwise.
|
||||
|
Loading…
Reference in New Issue
Block a user