214 lines
9.1 KiB
Python
214 lines
9.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torchvision import models # For using the ResNet-50 model
|
|
import torch.nn.functional as F
|
|
|
|
import timm
|
|
import wandb
|
|
import random
|
|
import numpy as np
|
|
from PIL import Image
|
|
import matplotlib.pyplot as plt
|
|
from sklearn.manifold import TSNE
|
|
|
|
|
|
class MultiModalMoCo(nn.Module):
    """MoCo-style visuo-tactile contrastive model.

    Holds a query and a key backbone per modality (vision, tactile) and two
    projection heads (``phi_vision``, ``phi_tactile``), each with a query and
    a key copy. Query networks are trained by gradient descent. Key networks
    start as frozen copies of the query networks and afterwards track them
    through an exponential moving average (``_momentum_update_key_encoder``).

    Args:
        m: Momentum of the key-network moving average,
            ``param_k <- m * param_k + (1 - m) * param_q``.
        T: Temperature dividing the similarity logits in the contrastive loss.
        nn_model: Backbone name; must be ``'resnet18'`` or ``'resnet50'``.

    Raises:
        ValueError: If ``nn_model`` is not a supported backbone name.
    """

    # Width of the globally pooled feature vector produced by each backbone.
    _BACKBONE_DIMS = {'resnet18': 512, 'resnet50': 2048}

    def __init__(self, m=0.99, T=1.0, nn_model=None):
        super(MultiModalMoCo, self).__init__()

        # Fail fast: an unsupported name would otherwise surface later as an
        # UnboundLocalError in the encoder factory or as a ``None`` head.
        if nn_model not in self._BACKBONE_DIMS:
            raise ValueError(
                f"nn_model must be one of {sorted(self._BACKBONE_DIMS)}, got {nn_model!r}"
            )

        self.m = m
        self.T = T
        self.nn_model = nn_model

        self.intra_dim = 128
        # NOTE(review): inter_dim is not used below; every head projects to intra_dim.
        self.inter_dim = 128

        feature_dim = self._BACKBONE_DIMS[nn_model]

        def create_mlp_head(output_dim):
            # Two-layer MLP projection head: feature_dim -> 2048 -> output_dim.
            return nn.Sequential(
                nn.Linear(feature_dim, 2048),
                nn.ReLU(),
                nn.Linear(2048, output_dim),
            )

        def create_resnet_encoder():
            # ImageNet-pretrained ResNet with its last two children removed
            # (avgpool and fc in torchvision's ResNet), then global average
            # pooling + flatten, giving a (batch, feature_dim) output.
            if nn_model == 'resnet18':
                resnet = models.resnet18(weights='ResNet18_Weights.IMAGENET1K_V1')
            else:
                resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
            features = list(resnet.children())[:-2]
            features.append(nn.AdaptiveAvgPool2d((1, 1)))
            features.append(nn.Flatten())
            return nn.Sequential(*features)

        # Backbones
        self.vision_base_q = create_resnet_encoder()
        self.vision_base_k = create_resnet_encoder()
        self.tactile_base_q = create_resnet_encoder()
        self.tactile_base_k = create_resnet_encoder()

        # Projection heads
        self.phi_vision_q = create_mlp_head(self.intra_dim)
        self.phi_vision_k = create_mlp_head(self.intra_dim)
        self.phi_tactile_q = create_mlp_head(self.intra_dim)
        self.phi_tactile_k = create_mlp_head(self.intra_dim)

        # Start every key network as an exact copy of its query network.
        # A momentum step with m < 1 would only nudge the randomly
        # initialised key heads toward the query heads, not copy them.
        for net_q, net_k in self._query_key_pairs():
            self._init_key_encoder(net_q, net_k)

    def _query_key_pairs(self):
        """Return the (query, key) module pairs linked by the momentum update."""
        return (
            (self.vision_base_q, self.vision_base_k),
            (self.tactile_base_q, self.tactile_base_k),
            (self.phi_vision_q, self.phi_vision_k),
            (self.phi_tactile_q, self.phi_tactile_k),
        )

    @torch.no_grad()
    def _init_key_encoder(self, base_q, base_k):
        """Copy ``base_q``'s parameters into ``base_k`` and freeze ``base_k``.

        Key parameters are excluded from autograd; they change only through
        ``_momentum_update_key_encoder``.
        """
        for param_q, param_k in zip(base_q.parameters(), base_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

    @torch.no_grad()
    def concat_all_gather(self, tensor):
        """Gather ``tensor`` from every process and concatenate along dim 0.

        Requires an initialised ``torch.distributed`` process group. Runs
        under ``torch.no_grad()``, so the result carries no gradient.
        """
        tensors_gather = [torch.ones_like(tensor)
                          for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
        output = torch.cat(tensors_gather, dim=0)
        return output

    def moco_contrastive_loss(self, q, k):
        """InfoNCE loss where row ``i`` of ``q`` is positive with row ``i`` of ``k``.

        Args:
            q: Query embeddings, shape (batch, dim).
            k: Key embeddings, shape (batch, dim); treated as constants (detached).

        Returns:
            Scalar cross-entropy over cosine-similarity logits divided by
            ``self.T``; every other row of ``k`` in the batch is a negative.
        """
        q = nn.functional.normalize(q, dim=1)
        k = nn.functional.normalize(k, dim=1)
        logits = torch.mm(q, k.T.detach()) / self.T
        # Build targets on the logits' device (not the current CUDA device),
        # so the loss also works on CPU and on non-default GPUs.
        labels = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)
        return nn.CrossEntropyLoss()(logits, labels)

    @torch.no_grad()
    def _momentum_update_key_encoder(self, base_q, base_k):
        """Blend ``base_k`` toward ``base_q`` in place: ``k <- m * k + (1 - m) * q``."""
        for param_q, param_k in zip(base_q.parameters(), base_k.parameters()):
            param_k.data.mul_(self.m).add_(param_q.data, alpha=1. - self.m)

    def forward(self, x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, len_train_dataloader):
        """Compute the combined intra- and inter-modal contrastive loss.

        Args:
            x_vision_q, x_vision_k: Vision inputs for the query and key
                backbones (presumably (B, 3, H, W) image batches -- confirm).
            x_tactile_q, x_tactile_k: Tactile inputs, same convention.
            epoch, i, len_train_dataloader: Used only for the wandb step
                ``epoch * len_train_dataloader + i``; logging is skipped when
                ``len_train_dataloader == 0``.

        Returns:
            Scalar loss: vision intra + tactile intra + ``weight_inter`` times
            the sum of both inter-modal terms.

        Note:
            The key networks get a momentum update on every call, including
            calls made for evaluation.
        """
        # Query branch: gradients flow into these networks.
        vision_base_q = self.vision_base_q(x_vision_q)
        q_vv = self.phi_vision_q(vision_base_q)
        q_vt = self.phi_tactile_q(vision_base_q)

        tactile_base_q = self.tactile_base_q(x_tactile_q)
        q_tv = self.phi_vision_q(tactile_base_q)
        q_tt = self.phi_tactile_q(tactile_base_q)

        # Key branch: moco_contrastive_loss detaches keys anyway, so skip
        # building an autograd graph here (same loss and gradients, less memory).
        with torch.no_grad():
            vision_base_k = self.vision_base_k(x_vision_k)
            k_vv = self.phi_vision_k(vision_base_k)
            k_tv = self.phi_tactile_k(vision_base_k)

            tactile_base_k = self.tactile_base_k(x_tactile_k)
            k_vt = self.phi_vision_k(tactile_base_k)
            k_tt = self.phi_tactile_k(tactile_base_k)

        # Update key encoders
        for net_q, net_k in self._query_key_pairs():
            self._momentum_update_key_encoder(net_q, net_k)

        # Contrastive loss for each (query, key) pair
        vision_vision_intra = self.moco_contrastive_loss(q_vv, k_vv)
        tactile_tactile_intra = self.moco_contrastive_loss(q_tt, k_tt)
        # NOTE(review): each inter-modal pair compares embeddings from different
        # projection heads (e.g. q_vt = phi_tactile_q(vision features) vs
        # k_vt = phi_vision_k(tactile features)); confirm this pairing is intended.
        tactile_vision_inter = self.moco_contrastive_loss(q_vt, k_vt)
        vision_tactile_inter = self.moco_contrastive_loss(q_tv, k_tv)

        weight_inter = 1
        combined_loss = (vision_vision_intra + tactile_tactile_intra
                         + (tactile_vision_inter + vision_tactile_inter) * weight_inter)

        if len_train_dataloader != 0:
            wandb.log({
                'module loss/vision intra loss': vision_vision_intra.item(),
                'module loss/tactile intra loss': tactile_tactile_intra.item(),
                'module loss/vision tactile inter loss': vision_tactile_inter.item() * weight_inter,
                'module loss/tactile vision inter loss': tactile_vision_inter.item() * weight_inter
            }, step=epoch * len_train_dataloader + i)

        return combined_loss
|
|
|
|
|
|
def denormalize(tensor, mean, std):
    """Undo per-channel normalisation in place: ``x * std + mean``.

    Accepts a single channels-first tensor (channels on dim 0, e.g.
    (C, H, W)) or a 4-D batch (N, C, H, W) with channels on dim 1.

    Args:
        tensor: Float tensor to denormalise; modified in place.
        mean: Per-channel means (sequence or tensor of length C).
        std: Per-channel standard deviations (length C).

    Returns:
        The same ``tensor`` object, denormalised.
    """
    # Broadcast mean/std along the channel axis. Zipping them with the first
    # dimension (the old approach) treats a batch's images as channels: each
    # image got one scalar and images past len(mean) were left untouched.
    channel_dim = 1 if tensor.dim() == 4 else 0
    shape = [1] * tensor.dim()
    shape[channel_dim] = -1
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).reshape(shape)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).reshape(shape)
    tensor.mul_(std_t).add_(mean_t)
    return tensor
|
|
|
|
def evaluate_and_plot(model, test_dataloader, epoch, device):
    """Log sample test images and the loss of one random test batch to wandb.

    Picks a random batch from ``test_dataloader``, keeps up to 4 random
    samples, runs ``model`` on them with the same tensors as query and key
    inputs, logs the denormalised vision/tactile images and the loss to
    wandb, and prints the loss. The model's train/eval mode is restored
    before returning.

    Args:
        model: Module called as ``model(x_vision_q, x_vision_k, x_tactile_q,
            x_tactile_k, epoch, i, len_train_dataloader)`` returning a scalar loss.
        test_dataloader: Iterable of ``(x_vision, x_tactile)`` batches that
            supports ``len()``.
        epoch: Current epoch; passed to ``model`` and used for the wandb step.
        device: Device the sampled tensors are moved to.

    Raises:
        RuntimeError: If ``test_dataloader`` yields fewer batches than its length.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Stream to one random batch instead of materialising the whole
            # test set; randrange(n) draws the same index random.choice would.
            target = random.randrange(len(test_dataloader))
            for batch_idx, batch in enumerate(test_dataloader):
                if batch_idx == target:
                    x_vision_test, x_tactile_test = batch
                    break
            else:
                raise RuntimeError("test_dataloader yielded fewer batches than len() reports")

            # Up to 4 samples, so a short final batch does not raise.
            batch_size = x_vision_test.shape[0]
            random_indices = random.sample(range(batch_size), min(4, batch_size))
            x_vision_test = x_vision_test[random_indices].to(device)
            x_tactile_test = x_tactile_test[random_indices].to(device)

            # i=0, len_train_dataloader=0 (for MultiModalMoCo a 0 length
            # skips its internal per-term wandb logging).
            test_loss = model(x_vision_test, x_vision_test, x_tactile_test, x_tactile_test, epoch, 0, 0)
    finally:
        # Restore the caller's mode so training does not silently continue
        # in eval mode (frozen BatchNorm statistics).
        model.train(was_training)

    # Undo input normalisation for display. These are the ImageNet mean/std,
    # presumably matching the dataset transforms -- confirm. Then clip to
    # [0, 1] and move axis 1 (channels) last for wandb.Image.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x_vision_test_denorm = denormalize(x_vision_test.clone(), mean, std).cpu().numpy()
    x_vision_test_denorm = np.clip(x_vision_test_denorm, 0, 1).transpose(0, 2, 3, 1)
    x_tactile_test_denorm = denormalize(x_tactile_test.clone(), mean, std).cpu().numpy()
    x_tactile_test_denorm = np.clip(x_tactile_test_denorm, 0, 1).transpose(0, 2, 3, 1)

    wandb.log({
        "Vision_Images": [wandb.Image(img_tensor) for img_tensor in x_vision_test_denorm],
        "Tactile_Images": [wandb.Image(img_tensor) for img_tensor in x_tactile_test_denorm]
    }, commit=False)
    # NOTE(review): this step is epoch * len(test_dataloader), while training
    # losses are logged at epoch * len_train_dataloader + i; if wandb's current
    # step is already larger, this entry may be dropped -- confirm.
    wandb.log({"testing loss": test_loss.item()}, step=epoch * len(test_dataloader))
    print(f"Test Loss: {test_loss.item():.4f}")
|
|
|
|
def compute_tsne(model, test_dataloader, epoch):
    """Log a 2-D t-SNE plot of paired vision/tactile backbone features to wandb.

    Draws one random test batch and keeps up to 20 random samples. Each
    sample is embedded with ``model.vision_base_q`` and
    ``model.tactile_base_q``, and t-SNE is run on the stacked features.
    Every point is annotated with its sample number, so a vision point and
    its tactile partner carry the same label. Vision points are blue and
    tactile points are red. The plot is logged under ``"t-SNE"`` with
    ``commit=False``.

    Args:
        model: Module exposing ``vision_base_q`` and ``tactile_base_q`` backbones.
        test_dataloader: Iterable of ``(x_vision, x_tactile)`` batches that
            supports ``len()``.
        epoch: Not used by this function.

    Raises:
        RuntimeError: If ``test_dataloader`` yields fewer batches than its length.
    """
    import io

    # Use the model's own device instead of assuming 'cuda'.
    device = next(model.parameters()).device
    was_training = model.training
    # Eval mode: BatchNorm uses (and does not update) running statistics
    # while computing visualisation-only features.
    model.eval()
    try:
        with torch.no_grad():
            # Stream to one random batch instead of materialising the test set.
            target = random.randrange(len(test_dataloader))
            for batch_idx, batch in enumerate(test_dataloader):
                if batch_idx == target:
                    x_vision_test, x_tactile_test = batch
                    break
            else:
                raise RuntimeError("test_dataloader yielded fewer batches than len() reports")

            batch_size = x_vision_test.shape[0]
            random_indices = random.sample(range(batch_size), min(20, batch_size))
            x_vision_test = x_vision_test[random_indices].to(device)
            x_tactile_test = x_tactile_test[random_indices].to(device)
            vision_base_q = model.vision_base_q(x_vision_test).cpu().numpy()
            tactile_base_q = model.tactile_base_q(x_tactile_test).cpu().numpy()
    finally:
        model.train(was_training)

    # Rows are [vision_1..vision_n, tactile_1..tactile_n], so labels must be
    # tiled (1..n, 1..n). np.repeat would give 1,1,2,2,... and mislabel pairs.
    num_samples = min(vision_base_q.shape[0], tactile_base_q.shape[0])
    data = np.concatenate((vision_base_q[:num_samples], tactile_base_q[:num_samples]), axis=0)
    labels = np.tile(np.arange(1, num_samples + 1), 2)

    tsne = TSNE(n_components=2, random_state=0, perplexity=2)
    tsne_data = tsne.fit_transform(data)

    fig = plt.figure(figsize=(10, 10))
    plt.scatter(tsne_data[:num_samples, 0], tsne_data[:num_samples, 1], color='blue', label='vision')
    plt.scatter(tsne_data[num_samples:, 0], tsne_data[num_samples:, 1], color='red', label='tactile')
    for (x, y), label in zip(tsne_data, labels):
        plt.text(x, y, f"{label}", fontsize=12, ha='center', va='bottom')
    plt.legend()

    # Render in memory instead of a shared temp file in the working directory.
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png')
    plt.close(fig)
    buffer.seek(0)
    # Pass an RGB PIL image (drops any alpha channel), a format wandb.Image accepts.
    image = Image.open(buffer).convert('RGB')
    wandb.log({"t-SNE": wandb.Image(image)}, commit=False)
|
|
|
|
def find_knn(query_point, data_points, k=5):
    """Find the ``k`` rows of ``data_points`` closest to ``query_point``.

    Distances are Euclidean norms of ``data_points - query_point`` taken over
    dim 1, so ``query_point`` must broadcast against ``data_points``
    (presumably a (D,) vector against an (N, D) matrix -- confirm).

    Args:
        query_point: The point to search around.
        data_points: Candidate points, one per row.
        k: Number of neighbours to return; ``topk`` raises if ``k`` exceeds
            the number of rows.

    Returns:
        ``(indices, distances)`` of the ``k`` nearest rows, ordered from
        nearest to farthest.
    """
    gaps = (data_points - query_point).norm(dim=1)
    # topk already returns the selected values, so no second gather is needed.
    nearest_distances, nearest_indices = gaps.topk(k, largest=False, sorted=True)
    return nearest_indices, nearest_distances