From 9fd405c42e6721029c7fa315a075fe05de156ab8 Mon Sep 17 00:00:00 2001 From: Vedant Dave Date: Wed, 13 Sep 2023 08:39:36 +0000 Subject: [PATCH] Update 'train_mm_moco.py' --- train_mm_moco.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/train_mm_moco.py b/train_mm_moco.py index 696c43c..f11a9b4 100755 --- a/train_mm_moco.py +++ b/train_mm_moco.py @@ -100,26 +100,29 @@ class MultiModalMoCo(nn.Module): q_vv = self.phi_vision_q(vision_base_q) q_vt = self.phi_tactile_q(vision_base_q) - vision_base_k = self.vision_base_k(x_vision_k) - k_vv = self.phi_vision_k(vision_base_k) - k_tv = self.phi_tactile_k(vision_base_k) + with torch.no_grad(): # no gradient + # Update key encoders + self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k) + self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k) + self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k) + # Compute key features + vision_base_k = self.vision_base_k(x_vision_k) + k_vv = self.phi_vision_k(vision_base_k) + k_tv = self.phi_tactile_k(vision_base_k) tactile_base_q = self.tactile_base_q(x_tactile_q) q_tv = self.Phi_vision_q(tactile_base_q) q_tt = self.Phi_tactile_q(tactile_base_q) - tactile_base_k = self.tactile_base_k(x_tactile_k) - k_vt = self.Phi_vision_k(tactile_base_k) - k_tt = self.Phi_tactile_k(tactile_base_k) - - # Update key encoders - self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k) - self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k) - self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k) - self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k) - self._momentum_update_key_encoder(self.Phi_vision_q, self.Phi_vision_k) - self._momentum_update_key_encoder(self.Phi_tactile_q, self.Phi_tactile_k) - + with torch.no_grad(): # no gradient + # Update key encoders + self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k) + self._momentum_update_key_encoder(self.Phi_vision_q, self.Phi_vision_k) + self._momentum_update_key_encoder(self.Phi_tactile_q, self.Phi_tactile_k) + # Compute key features + tactile_base_k = self.tactile_base_k(x_tactile_k) + k_vt = self.Phi_vision_k(tactile_base_k) + k_tt = self.Phi_tactile_k(tactile_base_k) # Compute the contrastive loss for each pair of queries and keys vision_vision_intra = self.moco_contrastive_loss(q_vv, k_vv)