diff --git a/train_mm_moco.py b/train_mm_moco.py index f11a9b4..6e8963b 100755 --- a/train_mm_moco.py +++ b/train_mm_moco.py @@ -103,8 +103,8 @@ class MultiModalMoCo(nn.Module): with torch.no_grad(): # no gradient # Update key encoders self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k) - self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k) self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k) + self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k) # Compute key features vision_base_k = self.vision_base_k(x_vision_k) k_vv = self.phi_vision_k(vision_base_k) @@ -116,7 +116,7 @@ class MultiModalMoCo(nn.Module): with torch.no_grad(): # no gradient # Update key encoders - self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k) + self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k) self._momentum_update_key_encoder(self.Phi_vision_q, self.Phi_vision_k) self._momentum_update_key_encoder(self.Phi_tactile_q, self.Phi_tactile_k) # Compute key features