diff --git a/train_mm_moco.py b/train_mm_moco.py
old mode 100644
new mode 100755
index 47c02f1..597bc08
--- a/train_mm_moco.py
+++ b/train_mm_moco.py
@@ -4,36 +4,43 @@
 from torchvision import models # For using the ResNet-50 model
 import torch.nn.functional as F
 import timm
+import wandb
 import random
 import numpy as np
 from PIL import Image
 import matplotlib.pyplot as plt
 from sklearn.manifold import TSNE
+
 class MultiModalMoCo(nn.Module):
-    def __init__(self, writer, K=4096, m=0.99, T=1.0):
+    def __init__(self, m=0.99, T=1.0, nn_model=None):
         super(MultiModalMoCo, self).__init__()
-        self.writer = writer
-        self.K = K
         self.m = m
         self.T = T
+        self.nn_model = nn_model
         self.intra_dim = 128
-        self.inter_dim = 128
-
-        # Initialize the queue
-        self.queue = torch.zeros((self.K, self.intra_dim), dtype=torch.float).cuda()
-        self.queue_ptr = 0
+        self.inter_dim = 128

         def create_mlp_head(output_dim):
-            return nn.Sequential(
-                nn.Linear(2048, 2048),
-                nn.ReLU(),
-                nn.Linear(2048, output_dim)
-            )
+            if self.nn_model == 'resnet18':
+                return nn.Sequential(
+                    nn.Linear(512, 2048),
+                    nn.ReLU(),
+                    nn.Linear(2048, output_dim)
+                )
+            elif self.nn_model == 'resnet50':
+                return nn.Sequential(
+                    nn.Linear(2048, 2048),
+                    nn.ReLU(),
+                    nn.Linear(2048, output_dim)
+                )

         def create_resnet_encoder():
-            resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
+            if self.nn_model == 'resnet18':
+                resnet = models.resnet18(weights='ResNet18_Weights.IMAGENET1K_V1')
+            elif self.nn_model == 'resnet50':
+                resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
             #resnet = models.regnet_x_800mf(weights='RegNet_X_800MF_Weights')
             features = list(resnet.children())[:-2]
             features.append(nn.AdaptiveAvgPool2d((1, 1)))
@@ -42,25 +49,31 @@ class MultiModalMoCo(nn.Module):

         # Vision encoders
         self.vision_base_q = create_resnet_encoder()
-        self.vision_head_intra_q = create_mlp_head(self.intra_dim)
-        self.vision_head_inter_q = create_mlp_head(self.inter_dim)
-
         self.vision_base_k = create_resnet_encoder()
-        self.vision_head_intra_k = create_mlp_head(self.intra_dim)
-        self.vision_head_inter_k = create_mlp_head(self.inter_dim)
-
-        # Tactile encoders
         self.tactile_base_q = create_resnet_encoder()
-        self.tactile_head_intra_q = create_mlp_head(self.intra_dim)
-        self.tactile_head_inter_q = create_mlp_head(self.inter_dim)
-
         self.tactile_base_k = create_resnet_encoder()
-        self.tactile_head_intra_k = create_mlp_head(self.intra_dim)
-        self.tactile_head_inter_k = create_mlp_head(self.inter_dim)
+
+        # Projection heads
+        self.phi_vision_q = create_mlp_head(self.intra_dim)
+        self.phi_tactile_q = create_mlp_head(self.intra_dim)
+
+        self.phi_vision_k = create_mlp_head(self.intra_dim)
+        self.phi_tactile_k = create_mlp_head(self.intra_dim)
+
+        self.Phi_vision_q = create_mlp_head(self.intra_dim)
+        self.Phi_tactile_q = create_mlp_head(self.intra_dim)
+
+        self.Phi_vision_k = create_mlp_head(self.intra_dim)
+        self.Phi_tactile_k = create_mlp_head(self.intra_dim)

         # Initialize key encoders with query encoder weights
         self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
         self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
+        self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k)
+        self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k)
+        self._momentum_update_key_encoder(self.Phi_vision_q, self.Phi_vision_k)
+        self._momentum_update_key_encoder(self.Phi_tactile_q, self.Phi_tactile_k)
+

     @torch.no_grad()
     def concat_all_gather(self,tensor):
@@ -84,41 +97,45 @@ class MultiModalMoCo(nn.Module):
     def forward(self, x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, len_train_dataloader):

         vision_base_q = self.vision_base_q(x_vision_q)
-        vision_queries_intra = self.vision_head_intra_q(vision_base_q)
-        vision_queries_inter = self.vision_head_inter_q(vision_base_q)
+        q_vv = self.phi_vision_q(vision_base_q)
+        q_vt = self.phi_tactile_q(vision_base_q)

-        with torch.no_grad():
-            self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
-            vision_base_k = self.vision_base_k(x_vision_k)
-            vision_keys_intra = self.vision_head_intra_k(vision_base_k)
-            vision_keys_inter = self.vision_head_inter_k(vision_base_k)
+        vision_base_k = self.vision_base_k(x_vision_k)
+        k_vv = self.phi_vision_k(vision_base_k)
+        k_tv = self.phi_tactile_k(vision_base_k)

         tactile_base_q = self.tactile_base_q(x_tactile_q)
-        tactile_queries_intra = self.tactile_head_intra_q(tactile_base_q)
-        tactile_queries_inter = self.tactile_head_inter_q(tactile_base_q)
+        q_tv = self.phi_vision_q(tactile_base_q)
+        q_tt = self.phi_tactile_q(tactile_base_q)

-        with torch.no_grad():
-            self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
-            tactile_base_k = self.tactile_base_k(x_tactile_k)
-            tactile_keys_intra = self.tactile_head_intra_k(tactile_base_k)
-            tactile_keys_inter = self.tactile_head_inter_k(tactile_base_k)
+        tactile_base_k = self.tactile_base_k(x_tactile_k)
+        k_vt = self.phi_vision_k(tactile_base_k)
+        k_tt = self.phi_tactile_k(tactile_base_k)
+
+        # Update key encoders
+        self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
+        self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
+        self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k)
+        self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k)

         # Compute the contrastive loss for each pair of queries and keys
-        vision_loss_intra = self.moco_contrastive_loss(vision_queries_intra, vision_keys_intra)
-        tactile_loss_intra = self.moco_contrastive_loss(tactile_queries_intra, tactile_keys_intra)
-        vision_tactile_inter = self.moco_contrastive_loss(vision_queries_inter, tactile_keys_inter)
-        tactile_vision_inter = self.moco_contrastive_loss(tactile_queries_inter, vision_keys_inter)
-
+        vision_vision_intra = self.moco_contrastive_loss(q_vv, k_vv)
+        tactile_tactile_intra = self.moco_contrastive_loss(q_tt, k_tt)
+        tactile_vision_inter = self.moco_contrastive_loss(q_vt, k_vt)
+        vision_tactile_inter = self.moco_contrastive_loss(q_tv, k_tv)
+
         # Combine losses (you can use different strategies to combine these losses)
         weight_inter = 1
-        combined_loss = vision_loss_intra + tactile_loss_intra + (vision_tactile_inter + tactile_vision_inter) * weight_inter
+        combined_loss = vision_vision_intra + tactile_tactile_intra + (tactile_vision_inter + vision_tactile_inter) * weight_inter
+
         if len_train_dataloader != 0:
-            self.writer.add_scalar('module loss/vision intra loss', vision_loss_intra.item(), epoch * len_train_dataloader + i)
-            self.writer.add_scalar('module loss/tactile intra loss', tactile_loss_intra.item(), epoch * len_train_dataloader + i)
-            self.writer.add_scalar('module loss/vision tactile inter loss', vision_tactile_inter.item() * weight_inter, epoch * len_train_dataloader + i)
-            self.writer.add_scalar('module loss/tactile vision inter loss', tactile_vision_inter.item() * weight_inter, epoch * len_train_dataloader + i)
-
+            wandb.log({
+                'module loss/vision intra loss': vision_vision_intra.item(),
+                'module loss/tactile intra loss': tactile_tactile_intra.item(),
+                'module loss/vision tactile inter loss': vision_tactile_inter.item() * weight_inter,
+                'module loss/tactile vision inter loss': tactile_vision_inter.item() * weight_inter
+            }, step=epoch * len_train_dataloader + i)

         return combined_loss
@@ -127,7 +144,7 @@ def denormalize(tensor, mean, std):
         t.mul_(s).add_(m)
     return tensor

-def evaluate_and_plot(model, test_dataloader, epoch, writer, device):
+def evaluate_and_plot(model, test_dataloader, epoch, device):
     model.eval()

     with torch.no_grad():
@@ -150,14 +167,16 @@ def evaluate_and_plot(model, test_dataloader, epoch, writer, device):
         x_tactile_test_denorm = x_tactile_test_denorm.cpu().numpy()
         x_tactile_test_denorm = np.clip(x_tactile_test_denorm, 0, 1)

-        writer.add_images('Vision_Images', x_vision_test_denorm, epoch)
-        writer.add_images('Tactile_Images', x_tactile_test_denorm, epoch)
-
-        writer.add_scalar('testing loss', test_loss.item(), epoch * len(test_dataloader))
+        x_vision_test_denorm = x_vision_test_denorm.transpose(0, 2, 3, 1)
+        x_tactile_test_denorm = x_tactile_test_denorm.transpose(0, 2, 3, 1)
+        wandb.log({
+            "Vision_Images": [wandb.Image(img_tensor) for img_tensor in x_vision_test_denorm],
+            "Tactile_Images": [wandb.Image(img_tensor) for img_tensor in x_tactile_test_denorm]
+        }, commit=False)
+        wandb.log({"testing loss": test_loss.item()}, step=epoch * len(test_dataloader))
         print(f"Test Loss: {test_loss.item():.4f}")

-
-def compute_tsne(model, test_dataloader, writer, epoch):
+def compute_tsne(model, test_dataloader, epoch):
     with torch.no_grad():
         test_data_list = list(test_dataloader)
         x_vision_test, x_tactile_test = random.choice(test_data_list)
@@ -190,8 +209,7 @@ def compute_tsne(model, test_dataloader, writer, epoch):
         image = Image.open('temp_figure.png')
         image = np.array(image) # Convert image to a NumPy array
         image = image[:, :, :3].transpose(2, 0, 1) # Extract RGB channels and change format to CHW
-        writer.add_image('t-SNE', image, global_step=epoch)
-
+        wandb.log({"t-SNE": wandb.Image(image)}, commit=False)

 def find_knn(query_point, data_points, k=5):
     # Calculate the Euclidean distances
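
Notes on the patch follow.

The patch calls _momentum_update_key_encoder in several new places, but its body falls outside these hunks. For review context, here is a minimal sketch of what that helper conventionally looks like in MoCo-style code, assuming the standard exponential-moving-average update driven by the self.m coefficient from __init__ (this is an assumption, not the file's confirmed implementation):

    @torch.no_grad()
    def _momentum_update_key_encoder(self, model_q, model_k):
        # Assumed EMA update: theta_k <- m * theta_k + (1 - m) * theta_q
        for param_q, param_k in zip(model_q.parameters(), model_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

If the helper does have this EMA form, the __init__ calls under "Initialize key encoders with query encoder weights" move the key weights only 1% of the way toward the query weights (m=0.99); the usual MoCo initialization is a verbatim copy, e.g. param_k.data.copy_(param_q.data).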
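moco_contrastive_loss is likewise referenced but not shown. Since the patch deletes the memory queue (self.K, self.queue, self.queue_ptr), negatives presumably now come from within the batch. The following is one plausible InfoNCE shape under that assumption, reusing the file's existing torch / F imports; treat it as a sketch, not the actual implementation:

    def moco_contrastive_loss(self, q, k):
        # L2-normalize so dot products become cosine similarities
        q = F.normalize(q, dim=1)
        k = F.normalize(k, dim=1).detach()  # keys carry no gradient, as in MoCo
        # (N, N) similarity matrix scaled by the temperature self.T
        logits = q @ k.t() / self.T
        # Sample i's key is its positive, so targets sit on the diagonal
        labels = torch.arange(q.size(0), device=q.device)
        return F.cross_entropy(logits, labels)

The .detach() matters here because the patch also drops the torch.no_grad() blocks around the key forward passes in forward; without one or the other, gradients would flow into the key encoders that the momentum update is supposed to own.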
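Finally, a hypothetical call site for the new constructor signature (the wandb project name is illustrative). Note that create_mlp_head and create_resnet_encoder have no fallback branch, so nn_model must be exactly 'resnet18' or 'resnet50'; the default nn_model=None falls through both helpers and fails at construction time:

    import wandb

    wandb.init(project="mm-moco")  # illustrative project name
    model = MultiModalMoCo(m=0.99, T=1.0, nn_model='resnet50').cuda()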