From ada3cadf0c1e1242114befb688db8e0d5dc0d4b9 Mon Sep 17 00:00:00 2001 From: VedantDave Date: Sun, 2 Apr 2023 18:52:46 +0200 Subject: [PATCH] Adding momentum encoder --- DPI/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/DPI/train.py b/DPI/train.py index 87b928d..219583c 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -155,17 +155,17 @@ class DPI: obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) state_size=self.args.state_size # 128 ) + + self.obs_encoder_momentum = ObservationEncoder( + obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) + state_size=self.args.state_size # 128 + ) self.obs_decoder = ObservationDecoder( state_size=self.args.state_size, # 128 output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (12,84,84) ) - self.obs_encoder_momentum = ObservationEncoder( - obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) - state_size=self.args.state_size # 128 - ) - self.transition_model = TransitionModel( state_size=self.args.state_size, # 128 hidden_size=self.args.hidden_size, # 256