diff --git a/DPI/train.py b/DPI/train.py index 3e58e9d..f49adc4 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -173,7 +173,7 @@ class DPI: self.obs_decoder = ObservationDecoder( state_size=self.args.state_size, # 128 - output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (9,84,84) + output_shape=(self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84) ) self.transition_model = TransitionModel( @@ -362,7 +362,7 @@ class DPI: # decoder loss horizon = np.minimum(50-i, imagine_horizon) obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon]) - decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon])) + decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:])) # reward loss reward_dist = self.reward_model(self.current_states_dict["sample"])