import io
import random

import torch
import torch.nn as nn
from torchvision import models  # For using the ResNet-50 model
import torch.nn.functional as F
import timm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Normalisation statistics used when converting image batches back to [0, 1]
# for display. The same values are applied to vision and tactile images.
_DISPLAY_MEAN = [0.485, 0.456, 0.406]
_DISPLAY_STD = [0.229, 0.224, 0.225]


class MultiModalMoCo(nn.Module):
    """Momentum-contrast model over paired vision and tactile images.

    Each modality has a query encoder (ResNet-50 backbone plus an "intra"
    and an "inter" MLP projection head) and a matching key encoder. The key
    encoder is a momentum-averaged copy of the query encoder and receives no
    gradients.

    Args:
        writer: Object exposing ``add_scalar(tag, value, step)``; used by
            ``forward`` to log the individual loss terms.
        K: Number of rows allocated for the feature queue.
        m: Momentum coefficient for the key-encoder update.
        T: Softmax temperature applied to the similarity logits.
    """

    def __init__(self, writer, K=4096, m=0.99, T=1.0):
        super().__init__()
        self.writer = writer
        self.K = K
        self.m = m
        self.T = T
        self.intra_dim = 64
        self.inter_dim = 64

        # Feature queue. Registered as a non-persistent buffer so that it
        # follows ``model.to(device)`` (instead of requiring CUDA at
        # construction time) without adding a key to ``state_dict``.
        # NOTE(review): nothing in this file reads or writes the queue yet.
        self.register_buffer(
            "queue",
            torch.zeros((self.K, self.intra_dim), dtype=torch.float),
            persistent=False,
        )
        self.queue_ptr = 0

        # Vision encoders
        self.vision_base_q = self._create_resnet_encoder()
        self.vision_head_intra_q = self._create_mlp_head(self.intra_dim)
        self.vision_head_inter_q = self._create_mlp_head(self.inter_dim)

        self.vision_base_k = self._create_resnet_encoder()
        self.vision_head_intra_k = self._create_mlp_head(self.intra_dim)
        self.vision_head_inter_k = self._create_mlp_head(self.inter_dim)

        # Tactile encoders
        self.tactile_base_q = self._create_resnet_encoder()
        self.tactile_head_intra_q = self._create_mlp_head(self.intra_dim)
        self.tactile_head_inter_q = self._create_mlp_head(self.inter_dim)

        self.tactile_base_k = self._create_resnet_encoder()
        self.tactile_head_intra_k = self._create_mlp_head(self.intra_dim)
        self.tactile_head_inter_k = self._create_mlp_head(self.inter_dim)

        # Initialize key encoders with query encoder weights. This has to be
        # an exact copy that includes the projection heads: a single momentum
        # step would leave the key heads at an unrelated random
        # initialisation.
        for modality in ("vision", "tactile"):
            for query_module, key_module in self._encoder_pairs(modality):
                self._copy_query_to_key(query_module, key_module)

    @staticmethod
    def _create_mlp_head(output_dim):
        """Return a two-layer MLP mapping 2048 features to ``output_dim``."""
        return nn.Sequential(
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, output_dim),
        )

    @staticmethod
    def _create_resnet_encoder():
        """Return an ImageNet-pretrained ResNet-50 without its last two
        children, followed by global average pooling and flattening."""
        resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
        features = list(resnet.children())[:-2]
        features.append(nn.AdaptiveAvgPool2d((1, 1)))
        features.append(nn.Flatten())
        return nn.Sequential(*features)

    def _encoder_pairs(self, modality):
        """Return ``(query_module, key_module)`` pairs for ``modality``.

        ``modality`` is ``"vision"`` or ``"tactile"``; the pairs cover the
        backbone and both projection heads.
        """
        return tuple(
            (
                getattr(self, f"{modality}_{part}_q"),
                getattr(self, f"{modality}_{part}_k"),
            )
            for part in ("base", "head_intra", "head_inter")
        )

    @staticmethod
    @torch.no_grad()
    def _copy_query_to_key(query_module, key_module):
        """Copy query parameters into the key module and freeze the latter."""
        for param_q, param_k in zip(query_module.parameters(),
                                    key_module.parameters()):
            param_k.data.copy_(param_q.data)
            # Key encoders are only ever changed by the momentum update.
            param_k.requires_grad = False

    @torch.no_grad()
    def concat_all_gather(self, tensor):
        """Gather ``tensor`` from every distributed process and concatenate
        the results along dimension 0.

        Uses ``torch.distributed.all_gather``; no gradient flows through the
        returned tensor.
        """
        tensors_gather = [
            torch.ones_like(tensor)
            for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
        return torch.cat(tensors_gather, dim=0)

    def moco_contrastive_loss(self, q, k):
        """Cross-entropy loss over in-batch query/key similarities.

        Both inputs are L2-normalised along dim 1. Row ``j`` of ``q`` is
        treated as matching row ``j`` of ``k``; all other rows of ``k`` act
        as negatives. Keys are detached so gradients reach only the queries.

        Args:
            q: Query embeddings, one row per sample.
            k: Key embeddings with the same number of rows as ``q``.

        Returns:
            Scalar loss tensor.
        """
        q = F.normalize(q, dim=1)
        k = F.normalize(k, dim=1)
        logits = torch.mm(q, k.T.detach()) / self.T
        # Build the targets on the logits' device rather than forcing CUDA.
        labels = torch.arange(
            logits.shape[0], dtype=torch.long, device=logits.device
        )
        return F.cross_entropy(logits, labels)

    @torch.no_grad()
    def _momentum_update_key_encoder(self, base_q, base_k):
        """Move ``base_k``'s parameters towards ``base_q``'s:
        ``k = m * k + (1 - m) * q``."""
        for param_q, param_k in zip(base_q.parameters(), base_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    def forward(self, x_vision_q, x_vision_k, x_tactile_q, x_tactile_k,
                epoch, i, len_train_dataloader):
        """Compute the combined intra- and inter-modal contrastive loss.

        Args:
            x_vision_q: Vision batch fed to the query encoder.
            x_vision_k: Vision batch fed to the key encoder.
            x_tactile_q: Tactile batch fed to the query encoder.
            x_tactile_k: Tactile batch fed to the key encoder.
            epoch: Current epoch index (used for the logging step only).
            i: Batch index within the epoch (used for the logging step only).
            len_train_dataloader: Batches per epoch; pass ``0`` to skip
                scalar logging.

        Returns:
            Scalar tensor: both intra-modal losses plus the two cross-modal
            losses scaled by 0.1.
        """
        vision_feat_q = self.vision_base_q(x_vision_q)
        vision_queries_intra = self.vision_head_intra_q(vision_feat_q)
        vision_queries_inter = self.vision_head_inter_q(vision_feat_q)

        with torch.no_grad():
            # Only move the key encoder while training, so that evaluation
            # never changes the weights. Heads are updated with the backbone.
            if self.training:
                for query_module, key_module in self._encoder_pairs("vision"):
                    self._momentum_update_key_encoder(query_module, key_module)
            vision_feat_k = self.vision_base_k(x_vision_k)
            vision_keys_intra = self.vision_head_intra_k(vision_feat_k)
            vision_keys_inter = self.vision_head_inter_k(vision_feat_k)

        tactile_feat_q = self.tactile_base_q(x_tactile_q)
        tactile_queries_intra = self.tactile_head_intra_q(tactile_feat_q)
        tactile_queries_inter = self.tactile_head_inter_q(tactile_feat_q)

        with torch.no_grad():
            if self.training:
                for query_module, key_module in self._encoder_pairs("tactile"):
                    self._momentum_update_key_encoder(query_module, key_module)
            tactile_feat_k = self.tactile_base_k(x_tactile_k)
            tactile_keys_intra = self.tactile_head_intra_k(tactile_feat_k)
            tactile_keys_inter = self.tactile_head_inter_k(tactile_feat_k)

        # Compute the contrastive loss for each pair of queries and keys
        vision_loss_intra = self.moco_contrastive_loss(
            vision_queries_intra, vision_keys_intra)
        tactile_loss_intra = self.moco_contrastive_loss(
            tactile_queries_intra, tactile_keys_intra)
        vision_tactile_inter = self.moco_contrastive_loss(
            vision_queries_inter, tactile_keys_inter)
        tactile_vision_inter = self.moco_contrastive_loss(
            tactile_queries_inter, vision_keys_inter)

        # Combine losses (you can use different strategies to combine these
        # losses)
        weight_inter = 0.1
        combined_loss = (
            vision_loss_intra
            + tactile_loss_intra
            + (vision_tactile_inter + tactile_vision_inter) * weight_inter
        )

        if len_train_dataloader != 0:
            global_step = epoch * len_train_dataloader + i
            self.writer.add_scalar(
                'module loss/vision intra loss',
                vision_loss_intra.item(), global_step)
            self.writer.add_scalar(
                'module loss/tactile intra loss',
                tactile_loss_intra.item(), global_step)
            self.writer.add_scalar(
                'module loss/vision tactile inter loss',
                vision_tactile_inter.item() * weight_inter, global_step)
            self.writer.add_scalar(
                'module loss/tactile vision inter loss',
                tactile_vision_inter.item() * weight_inter, global_step)

        return combined_loss


def denormalize(tensor, mean, std):
    """Undo per-channel normalisation in place: ``x = x * std + mean``.

    Args:
        tensor: Floating-point image tensor. For 4-D input the channel axis
            is dim 1 (``N, C, H, W``); otherwise it is dim 0 (e.g.
            ``C, H, W``).
        mean: Per-channel means, one entry per channel.
        std: Per-channel standard deviations, one entry per channel.

    Returns:
        The same ``tensor`` object, modified in place.
    """
    # Iterating the tensor directly would walk the batch axis for NCHW input
    # and apply channel statistics to whole samples, so broadcast instead.
    channel_dim = 1 if tensor.dim() == 4 else 0
    stat_shape = [1] * tensor.dim()
    stat_shape[channel_dim] = -1
    mean_t = torch.as_tensor(
        mean, dtype=tensor.dtype, device=tensor.device).view(stat_shape)
    std_t = torch.as_tensor(
        std, dtype=tensor.dtype, device=tensor.device).view(stat_shape)
    tensor.mul_(std_t).add_(mean_t)
    return tensor


def evaluate_and_plot(model, test_dataloader, epoch, writer, device):
    """Evaluate the loss on a few random test samples and log them.

    A random batch is drawn from ``test_dataloader`` and up to four samples
    from it are passed through ``model`` (the same images serve as query and
    key). The denormalised images and the loss are sent to ``writer`` and
    the loss is printed. The model's train/eval mode is restored on exit.

    Args:
        model: Model callable as ``model(vq, vk, tq, tk, epoch, i, n)``.
        test_dataloader: Sized iterable yielding ``(vision, tactile)``
            batches.
        epoch: Current epoch index, used as the logging step.
        writer: Object exposing ``add_images`` and ``add_scalar``.
        device: Device the sampled batch is moved to.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x_vision_test, x_tactile_test = random.choice(
                list(test_dataloader))
            batch_size = x_vision_test.shape[0]
            # Clamp so that a short final batch does not raise.
            random_indices = random.sample(
                range(batch_size), min(4, batch_size))
            x_vision_test = x_vision_test[random_indices].to(device)
            x_tactile_test = x_tactile_test[random_indices].to(device)

            # A dataloader length of 0 disables the model's own logging.
            test_loss = model(x_vision_test, x_vision_test,
                              x_tactile_test, x_tactile_test, epoch, 0, 0)
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    # Denormalize vision images
    x_vision_test_denorm = denormalize(
        x_vision_test.clone(), _DISPLAY_MEAN, _DISPLAY_STD)
    x_vision_test_denorm = np.clip(x_vision_test_denorm.cpu().numpy(), 0, 1)

    # Denormalize tactile images
    x_tactile_test_denorm = denormalize(
        x_tactile_test.clone(), _DISPLAY_MEAN, _DISPLAY_STD)
    x_tactile_test_denorm = np.clip(x_tactile_test_denorm.cpu().numpy(), 0, 1)

    writer.add_images('Vision_Images', x_vision_test_denorm, epoch)
    writer.add_images('Tactile_Images', x_tactile_test_denorm, epoch)

    writer.add_scalar('testing loss', test_loss.item(),
                      epoch * len(test_dataloader))
    print(f"Test Loss: {test_loss.item():.4f}")


def compute_tsne(model, test_dataloader, writer, epoch):
    """Log a 2-D t-SNE plot of paired vision/tactile backbone features.

    Up to ten samples are drawn from a random test batch and encoded with
    ``model.vision_base_q`` and ``model.tactile_base_q``. Each point is
    annotated with its sample number, so a vision point and its tactile
    counterpart carry the same label. The figure is rendered in memory and
    sent to ``writer.add_image`` under the tag ``'t-SNE'``.

    Args:
        model: Model exposing ``vision_base_q`` and ``tactile_base_q``.
        test_dataloader: Iterable yielding ``(vision, tactile)`` batches.
        writer: Object exposing ``add_image``.
        epoch: Current epoch index, used as the logging step.
    """
    # Follow the model's device instead of assuming the default CUDA device.
    device = next(model.parameters()).device
    was_training = model.training
    # Use inference behaviour for the embeddings; restored below.
    model.eval()
    try:
        with torch.no_grad():
            x_vision_test, x_tactile_test = random.choice(
                list(test_dataloader))
            batch_size = x_vision_test.shape[0]
            random_indices = random.sample(
                range(batch_size), min(10, batch_size))
            x_vision_test = x_vision_test[random_indices].to(device)
            x_tactile_test = x_tactile_test[random_indices].to(device)

            vision_features = model.vision_base_q(x_vision_test).cpu().numpy()
            tactile_features = (
                model.tactile_base_q(x_tactile_test).cpu().numpy())
    finally:
        model.train(was_training)

    # Create pairs of corresponding representations and labels
    num_samples = min(vision_features.shape[0], tactile_features.shape[0])
    data = np.concatenate(
        (vision_features[:num_samples], tactile_features[:num_samples]),
        axis=0)
    # ``data`` holds all vision rows followed by all tactile rows, so the
    # label sequence must be tiled (1..n, 1..n) for pairs to share a label.
    labels = np.tile(np.arange(1, num_samples + 1), 2)

    # Perplexity must stay below the number of points.
    perplexity = min(5, max(1, data.shape[0] - 1))
    tsne = TSNE(n_components=2, random_state=0, perplexity=perplexity)
    tsne_data = tsne.fit_transform(data)

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.scatter(tsne_data[:, 0], tsne_data[:, 1], color='blue')
        for (x, y), label in zip(tsne_data, labels):
            ax.text(x, y, f"{label}", fontsize=12, ha='center', va='bottom')
        # Render to memory rather than a scratch file in the working
        # directory.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png')
    finally:
        plt.close(fig)

    buffer.seek(0)
    with Image.open(buffer) as png:
        image = np.array(png)  # Convert image to a NumPy array
    # Extract RGB channels and change format to CHW
    image = image[:, :, :3].transpose(2, 0, 1)

    writer.add_image('t-SNE', image, global_step=epoch)