From c8fdd11d8c74ad5e1bd705d25c134541c51f298d Mon Sep 17 00:00:00 2001 From: VedantDave Date: Wed, 12 Apr 2023 17:30:20 +0200 Subject: [PATCH] Outputting only 3 channels from the decoder --- DPI/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DPI/train.py b/DPI/train.py index 3e58e9d..f49adc4 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -173,7 +173,7 @@ class DPI: self.obs_decoder = ObservationDecoder( state_size=self.args.state_size, # 128 - output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (9,84,84) + output_shape=(self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84) ) self.transition_model = TransitionModel( @@ -362,7 +362,7 @@ class DPI: # decoder loss horizon = np.minimum(50-i, imagine_horizon) obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon]) - decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon])) + decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:])) # reward loss reward_dist = self.reward_model(self.current_states_dict["sample"])