Update 'train_mm_moco.py'

This commit is contained in:
Vedant Dave 2023-09-13 08:45:47 +00:00
parent 9fd405c42e
commit 856ddfc6b8

View File

@ -103,8 +103,8 @@ class MultiModalMoCo(nn.Module):
with torch.no_grad(): # no gradient with torch.no_grad(): # no gradient
# Update key encoders # Update key encoders
self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k) self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k) self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k)
self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k)
# Compute key features # Compute key features
vision_base_k = self.vision_base_k(x_vision_k) vision_base_k = self.vision_base_k(x_vision_k)
k_vv = self.phi_vision_k(vision_base_k) k_vv = self.phi_vision_k(vision_base_k)
@ -116,7 +116,7 @@ class MultiModalMoCo(nn.Module):
with torch.no_grad(): # no gradient with torch.no_grad(): # no gradient
# Update key encoders # Update key encoders
self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k) self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
self._momentum_update_key_encoder(self.Phi_vision_q, self.Phi_vision_k) self._momentum_update_key_encoder(self.Phi_vision_q, self.Phi_vision_k)
self._momentum_update_key_encoder(self.Phi_tactile_q, self.Phi_tactile_k) self._momentum_update_key_encoder(self.Phi_tactile_q, self.Phi_tactile_k)
# Compute key features # Compute key features