Adding momentum encoder

This commit is contained in:
Vedant Dave 2023-04-02 18:52:46 +02:00
parent d9d350e191
commit ada3cadf0c

View File

@ -155,17 +155,17 @@ class DPI:
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
state_size=self.args.state_size # 128 state_size=self.args.state_size # 128
) )
self.obs_encoder_momentum = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
state_size=self.args.state_size # 128
)
self.obs_decoder = ObservationDecoder( self.obs_decoder = ObservationDecoder(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (12,84,84) output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (12,84,84)
) )
self.obs_encoder_momentum = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
state_size=self.args.state_size # 128
)
self.transition_model = TransitionModel( self.transition_model = TransitionModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256