diff --git a/DPI/train.py b/DPI/train.py index 87b928d..219583c 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -155,17 +155,17 @@ class DPI: obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) state_size=self.args.state_size # 128 ) + + self.obs_encoder_momentum = ObservationEncoder( + obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) + state_size=self.args.state_size # 128 + ) self.obs_decoder = ObservationDecoder( state_size=self.args.state_size, # 128 output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (12,84,84) ) - self.obs_encoder_momentum = ObservationEncoder( - obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) - state_size=self.args.state_size # 128 - ) - self.transition_model = TransitionModel( state_size=self.args.state_size, # 128 hidden_size=self.args.hidden_size, # 256