Update 'train_mm_moco.py'

This commit is contained in:
Vedant Dave 2023-09-13 08:20:00 +00:00
parent fd69f00137
commit ee5ce8cb82

View File

@ -105,18 +105,21 @@ class MultiModalMoCo(nn.Module):
k_tv = self.phi_tactile_k(vision_base_k)
tactile_base_q = self.tactile_base_q(x_tactile_q)
q_tv = self.phi_vision_q(tactile_base_q)
q_tt = self.phi_tactile_q(tactile_base_q)
q_tv = self.Phi_vision_q(tactile_base_q)
q_tt = self.Phi_tactile_q(tactile_base_q)
tactile_base_k = self.tactile_base_k(x_tactile_k)
k_vt = self.phi_vision_k(tactile_base_k)
k_tt = self.phi_tactile_k(tactile_base_k)
k_vt = self.Phi_vision_k(tactile_base_k)
k_tt = self.Phi_tactile_k(tactile_base_k)
# Update key encoders
self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k)
self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k)
self._momentum_update_key_encoder(self.Phi_vision_q, self.Phi_vision_k)
self._momentum_update_key_encoder(self.Phi_tactile_q, self.Phi_tactile_k)
# Compute the contrastive loss for each pair of queries and keys
vision_vision_intra = self.moco_contrastive_loss(q_vv, k_vv)