Update 'train_mm_moco.py'
This commit is contained in:
parent
ee5ce8cb82
commit
9fd405c42e
@ -100,6 +100,12 @@ class MultiModalMoCo(nn.Module):
|
|||||||
q_vv = self.phi_vision_q(vision_base_q)
|
q_vv = self.phi_vision_q(vision_base_q)
|
||||||
q_vt = self.phi_tactile_q(vision_base_q)
|
q_vt = self.phi_tactile_q(vision_base_q)
|
||||||
|
|
||||||
|
with torch.no_grad(): # no gradient
|
||||||
|
# Update key encoders
|
||||||
|
self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
|
||||||
|
self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
|
||||||
|
self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k)
|
||||||
|
# Compute key features
|
||||||
vision_base_k = self.vision_base_k(x_vision_k)
|
vision_base_k = self.vision_base_k(x_vision_k)
|
||||||
k_vv = self.phi_vision_k(vision_base_k)
|
k_vv = self.phi_vision_k(vision_base_k)
|
||||||
k_tv = self.phi_tactile_k(vision_base_k)
|
k_tv = self.phi_tactile_k(vision_base_k)
|
||||||
@ -108,18 +114,15 @@ class MultiModalMoCo(nn.Module):
|
|||||||
q_tv = self.Phi_vision_q(tactile_base_q)
|
q_tv = self.Phi_vision_q(tactile_base_q)
|
||||||
q_tt = self.Phi_tactile_q(tactile_base_q)
|
q_tt = self.Phi_tactile_q(tactile_base_q)
|
||||||
|
|
||||||
tactile_base_k = self.tactile_base_k(x_tactile_k)
|
with torch.no_grad(): # no gradient
|
||||||
k_vt = self.Phi_vision_k(tactile_base_k)
|
|
||||||
k_tt = self.Phi_tactile_k(tactile_base_k)
|
|
||||||
|
|
||||||
# Update key encoders
|
# Update key encoders
|
||||||
self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
|
|
||||||
self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
|
|
||||||
self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k)
|
|
||||||
self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k)
|
self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k)
|
||||||
self._momentum_update_key_encoder(self.Phi_vision_q, self.Phi_vision_k)
|
self._momentum_update_key_encoder(self.Phi_vision_q, self.Phi_vision_k)
|
||||||
self._momentum_update_key_encoder(self.Phi_tactile_q, self.Phi_tactile_k)
|
self._momentum_update_key_encoder(self.Phi_tactile_q, self.Phi_tactile_k)
|
||||||
|
# Compute key features
|
||||||
|
tactile_base_k = self.tactile_base_k(x_tactile_k)
|
||||||
|
k_vt = self.Phi_vision_k(tactile_base_k)
|
||||||
|
k_tt = self.Phi_tactile_k(tactile_base_k)
|
||||||
|
|
||||||
# Compute the contrastive loss for each pair of queries and keys
|
# Compute the contrastive loss for each pair of queries and keys
|
||||||
vision_vision_intra = self.moco_contrastive_loss(q_vv, k_vv)
|
vision_vision_intra = self.moco_contrastive_loss(q_vv, k_vv)
|
||||||
|
Loading…
Reference in New Issue
Block a user