Adding momentum encoder
This commit is contained in:
parent
d9d350e191
commit
ada3cadf0c
10
DPI/train.py
10
DPI/train.py
@ -155,17 +155,17 @@ class DPI:
|
|||||||
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
|
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
|
||||||
state_size=self.args.state_size # 128
|
state_size=self.args.state_size # 128
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.obs_encoder_momentum = ObservationEncoder(
|
||||||
|
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
|
||||||
|
state_size=self.args.state_size # 128
|
||||||
|
)
|
||||||
|
|
||||||
self.obs_decoder = ObservationDecoder(
|
self.obs_decoder = ObservationDecoder(
|
||||||
state_size=self.args.state_size, # 128
|
state_size=self.args.state_size, # 128
|
||||||
output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (12,84,84)
|
output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (12,84,84)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.obs_encoder_momentum = ObservationEncoder(
|
|
||||||
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
|
|
||||||
state_size=self.args.state_size # 128
|
|
||||||
)
|
|
||||||
|
|
||||||
self.transition_model = TransitionModel(
|
self.transition_model = TransitionModel(
|
||||||
state_size=self.args.state_size, # 128
|
state_size=self.args.state_size, # 128
|
||||||
hidden_size=self.args.hidden_size, # 256
|
hidden_size=self.args.hidden_size, # 256
|
||||||
|
Loading…
Reference in New Issue
Block a user