From 856ddfc6b83bc8fe9442512cff89eb08b9d76760 Mon Sep 17 00:00:00 2001 From: Vedant Dave Date: Wed, 13 Sep 2023 08:45:47 +0000 Subject: [PATCH] Update 'train_mm_moco.py' --- train_mm_moco.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_mm_moco.py b/train_mm_moco.py index f11a9b4..6e8963b 100755 --- a/train_mm_moco.py +++ b/train_mm_moco.py @@ -103,8 +103,8 @@ class MultiModalMoCo(nn.Module): with torch.no_grad(): # no gradient # Update key encoders self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k) - self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k) self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k) + self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k) # Compute key features vision_base_k = self.vision_base_k(x_vision_k) k_vv = self.phi_vision_k(vision_base_k) @@ -116,7 +116,7 @@ class MultiModalMoCo(nn.Module): with torch.no_grad(): # no gradient # Update key encoders - self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k) + self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k) self._momentum_update_key_encoder(self.Phi_vision_q, self.Phi_vision_k) self._momentum_update_key_encoder(self.Phi_tactile_q, self.Phi_tactile_k) # Compute key features