Outputting only 3 channels from the decoder
This commit is contained in:
parent
a83149f61e
commit
c8fdd11d8c
@ -173,7 +173,7 @@ class DPI:
|
|||||||
|
|
||||||
self.obs_decoder = ObservationDecoder(
|
self.obs_decoder = ObservationDecoder(
|
||||||
state_size=self.args.state_size, # 128
|
state_size=self.args.state_size, # 128
|
||||||
output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (9,84,84)
|
output_shape=(self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.transition_model = TransitionModel(
|
self.transition_model = TransitionModel(
|
||||||
@ -362,7 +362,7 @@ class DPI:
|
|||||||
# decoder loss
|
# decoder loss
|
||||||
horizon = np.minimum(50-i, imagine_horizon)
|
horizon = np.minimum(50-i, imagine_horizon)
|
||||||
obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon])
|
obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon])
|
||||||
decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon]))
|
decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:]))
|
||||||
|
|
||||||
# reward loss
|
# reward loss
|
||||||
reward_dist = self.reward_model(self.current_states_dict["sample"])
|
reward_dist = self.reward_model(self.current_states_dict["sample"])
|
||||||
|
Loading…
Reference in New Issue
Block a user