Grouping for actions too
This commit is contained in:
parent
38cc645253
commit
a1fe81f018
@ -156,14 +156,16 @@ class ReplayBuffer:
|
|||||||
obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
|
obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
|
||||||
return obs,acs,rews,terms
|
return obs,acs,rews,terms
|
||||||
|
|
||||||
def group_steps(self, buffer, variable):
|
def group_steps(self, buffer, variable, obs=True):
|
||||||
variable = getattr(buffer, variable)
|
variable = getattr(buffer, variable)
|
||||||
non_zero_indices = np.nonzero(buffer.episode_count)[0]
|
non_zero_indices = np.nonzero(buffer.episode_count)[0]
|
||||||
variable = variable[non_zero_indices]
|
variable = variable[non_zero_indices]
|
||||||
|
if obs:
|
||||||
variable = variable.reshape(self.args.episode_length, self.args.batch_size,
|
variable = variable.reshape(self.args.episode_length, self.args.batch_size,
|
||||||
self.args.frame_stack*self.args.channels,
|
self.args.frame_stack*self.args.channels,
|
||||||
self.args.image_size,self.args.image_size)
|
self.args.image_size,self.args.image_size)
|
||||||
|
else:
|
||||||
|
variable = variable.reshape(self.args.episode_length, self.args.batch_size,-1)
|
||||||
return variable
|
return variable
|
||||||
|
|
||||||
def transform_grouped_steps(self, variable):
|
def transform_grouped_steps(self, variable):
|
||||||
@ -227,7 +229,7 @@ class CorruptVideos:
|
|||||||
Check if a video file is corrupt.
|
Check if a video file is corrupt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filepath (str): Path to the video file.
|
dir_path (str): Path to the video file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the video is corrupt, False otherwise.
|
bool: True if the video is corrupt, False otherwise.
|
||||||
|
Loading…
Reference in New Issue
Block a user