This commit is contained in:
Vedant Dave 2023-09-12 13:35:14 +00:00
parent 3897ccd505
commit 53fc2efe2f

View File

@ -4,28 +4,32 @@ from torchvision import models # For using the ResNet-50 model
import torch.nn.functional as F import torch.nn.functional as F
import timm import timm
import wandb
import random import random
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from sklearn.manifold import TSNE from sklearn.manifold import TSNE
class MultiModalMoCo(nn.Module): class MultiModalMoCo(nn.Module):
def __init__(self, writer, K=4096, m=0.99, T=1.0): def __init__(self, m=0.99, T=1.0, nn_model=None):
super(MultiModalMoCo, self).__init__() super(MultiModalMoCo, self).__init__()
self.writer = writer
self.K = K
self.m = m self.m = m
self.T = T self.T = T
self.nn_model = nn_model
self.intra_dim = 128 self.intra_dim = 128
self.inter_dim = 128 self.inter_dim = 128
# Initialize the queue
self.queue = torch.zeros((self.K, self.intra_dim), dtype=torch.float).cuda()
self.queue_ptr = 0
def create_mlp_head(output_dim): def create_mlp_head(output_dim):
if self.nn_model == 'resnet18':
return nn.Sequential(
nn.Linear(512, 2048),
nn.ReLU(),
nn.Linear(2048, output_dim)
)
elif self.nn_model == 'resnet50':
return nn.Sequential( return nn.Sequential(
nn.Linear(2048, 2048), nn.Linear(2048, 2048),
nn.ReLU(), nn.ReLU(),
@ -33,6 +37,9 @@ class MultiModalMoCo(nn.Module):
) )
def create_resnet_encoder(): def create_resnet_encoder():
if self.nn_model == 'resnet18':
resnet = models.resnet18(weights='ResNet18_Weights.IMAGENET1K_V1')
elif self.nn_model == 'resnet50':
resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1') resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
#resnet = models.regnet_x_800mf(weights='RegNet_X_800MF_Weights') #resnet = models.regnet_x_800mf(weights='RegNet_X_800MF_Weights')
features = list(resnet.children())[:-2] features = list(resnet.children())[:-2]
@ -42,25 +49,21 @@ class MultiModalMoCo(nn.Module):
# Vision encoders # Vision encoders
self.vision_base_q = create_resnet_encoder() self.vision_base_q = create_resnet_encoder()
self.vision_head_intra_q = create_mlp_head(self.intra_dim)
self.vision_head_inter_q = create_mlp_head(self.inter_dim)
self.vision_base_k = create_resnet_encoder() self.vision_base_k = create_resnet_encoder()
self.vision_head_intra_k = create_mlp_head(self.intra_dim)
self.vision_head_inter_k = create_mlp_head(self.inter_dim)
# Tactile encoders
self.tactile_base_q = create_resnet_encoder() self.tactile_base_q = create_resnet_encoder()
self.tactile_head_intra_q = create_mlp_head(self.intra_dim)
self.tactile_head_inter_q = create_mlp_head(self.inter_dim)
self.tactile_base_k = create_resnet_encoder() self.tactile_base_k = create_resnet_encoder()
self.tactile_head_intra_k = create_mlp_head(self.intra_dim)
self.tactile_head_inter_k = create_mlp_head(self.inter_dim) # Projection heads
self.phi_vision_q = create_mlp_head(self.intra_dim)
self.phi_vision_k = create_mlp_head(self.intra_dim)
self.phi_tactile_q = create_mlp_head(self.intra_dim)
self.phi_tactile_k = create_mlp_head(self.intra_dim)
# Initialize key encoders with query encoder weights # Initialize key encoders with query encoder weights
self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k) self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k) self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k)
self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k)
@torch.no_grad() @torch.no_grad()
def concat_all_gather(self,tensor): def concat_all_gather(self,tensor):
@ -84,41 +87,45 @@ class MultiModalMoCo(nn.Module):
def forward(self, x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, len_train_dataloader): def forward(self, x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, len_train_dataloader):
vision_base_q = self.vision_base_q(x_vision_q) vision_base_q = self.vision_base_q(x_vision_q)
vision_queries_intra = self.vision_head_intra_q(vision_base_q) q_vv = self.phi_vision_q(vision_base_q)
vision_queries_inter = self.vision_head_inter_q(vision_base_q) q_vt = self.phi_tactile_q(vision_base_q)
with torch.no_grad():
self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
vision_base_k = self.vision_base_k(x_vision_k) vision_base_k = self.vision_base_k(x_vision_k)
vision_keys_intra = self.vision_head_intra_k(vision_base_k) k_vv = self.phi_vision_k(vision_base_k)
vision_keys_inter = self.vision_head_inter_k(vision_base_k) k_tv = self.phi_tactile_k(vision_base_k)
tactile_base_q = self.tactile_base_q(x_tactile_q) tactile_base_q = self.tactile_base_q(x_tactile_q)
tactile_queries_intra = self.tactile_head_intra_q(tactile_base_q) q_tv = self.phi_vision_q(tactile_base_q)
tactile_queries_inter = self.tactile_head_inter_q(tactile_base_q) q_tt = self.phi_tactile_q(tactile_base_q)
with torch.no_grad():
self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
tactile_base_k = self.tactile_base_k(x_tactile_k) tactile_base_k = self.tactile_base_k(x_tactile_k)
tactile_keys_intra = self.tactile_head_intra_k(tactile_base_k) k_vt = self.phi_vision_k(tactile_base_k)
tactile_keys_inter = self.tactile_head_inter_k(tactile_base_k) k_tt = self.phi_tactile_k(tactile_base_k)
# Update key encoders
self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k)
self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k)
# Compute the contrastive loss for each pair of queries and keys # Compute the contrastive loss for each pair of queries and keys
vision_loss_intra = self.moco_contrastive_loss(vision_queries_intra, vision_keys_intra) vision_vision_intra = self.moco_contrastive_loss(q_vv, k_vv)
tactile_loss_intra = self.moco_contrastive_loss(tactile_queries_intra, tactile_keys_intra) tactile_tactile_intra = self.moco_contrastive_loss(q_tt, k_tt)
vision_tactile_inter = self.moco_contrastive_loss(vision_queries_inter, tactile_keys_inter) tactile_vision_inter = self.moco_contrastive_loss(q_vt, k_vt)
tactile_vision_inter = self.moco_contrastive_loss(tactile_queries_inter, vision_keys_inter) vision_tactile_inter = self.moco_contrastive_loss(q_tv, k_tv)
# Combine losses (you can use different strategies to combine these losses) # Combine losses (you can use different strategies to combine these losses)
weight_inter = 1 weight_inter = 1
combined_loss = vision_loss_intra + tactile_loss_intra + (vision_tactile_inter + tactile_vision_inter) * weight_inter combined_loss = vision_vision_intra + tactile_tactile_intra + (tactile_vision_inter + vision_tactile_inter) * weight_inter
if len_train_dataloader != 0: if len_train_dataloader != 0:
self.writer.add_scalar('module loss/vision intra loss', vision_loss_intra.item(), epoch * len_train_dataloader + i) wandb.log({
self.writer.add_scalar('module loss/tactile intra loss', tactile_loss_intra.item(), epoch * len_train_dataloader + i) 'module loss/vision intra loss': vision_vision_intra.item(),
self.writer.add_scalar('module loss/vision tactile inter loss', vision_tactile_inter.item() * weight_inter, epoch * len_train_dataloader + i) 'module loss/tactile intra loss': tactile_tactile_intra.item(),
self.writer.add_scalar('module loss/tactile vision inter loss', tactile_vision_inter.item() * weight_inter, epoch * len_train_dataloader + i) 'module loss/vision tactile inter loss': vision_tactile_inter.item() * weight_inter,
'module loss/tactile vision inter loss': tactile_vision_inter.item() * weight_inter
}, step=epoch * len_train_dataloader + i)
return combined_loss return combined_loss
@ -127,7 +134,7 @@ def denormalize(tensor, mean, std):
t.mul_(s).add_(m) t.mul_(s).add_(m)
return tensor return tensor
def evaluate_and_plot(model, test_dataloader, epoch, writer, device): def evaluate_and_plot(model, test_dataloader, epoch, device):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
@ -150,14 +157,16 @@ def evaluate_and_plot(model, test_dataloader, epoch, writer, device):
x_tactile_test_denorm = x_tactile_test_denorm.cpu().numpy() x_tactile_test_denorm = x_tactile_test_denorm.cpu().numpy()
x_tactile_test_denorm = np.clip(x_tactile_test_denorm, 0, 1) x_tactile_test_denorm = np.clip(x_tactile_test_denorm, 0, 1)
writer.add_images('Vision_Images', x_vision_test_denorm, epoch) x_vision_test_denorm = x_vision_test_denorm.transpose(0, 2, 3, 1)
writer.add_images('Tactile_Images', x_tactile_test_denorm, epoch) x_tactile_test_denorm = x_tactile_test_denorm.transpose(0, 2, 3, 1)
wandb.log({
writer.add_scalar('testing loss', test_loss.item(), epoch * len(test_dataloader)) "Vision_Images": [wandb.Image(img_tensor) for img_tensor in x_vision_test_denorm],
"Tactile_Images": [wandb.Image(img_tensor) for img_tensor in x_tactile_test_denorm]
}, commit=False)
wandb.log({"testing loss": test_loss.item()}, step=epoch * len(test_dataloader))
print(f"Test Loss: {test_loss.item():.4f}") print(f"Test Loss: {test_loss.item():.4f}")
def compute_tsne(model, test_dataloader, epoch):
def compute_tsne(model, test_dataloader, writer, epoch):
with torch.no_grad(): with torch.no_grad():
test_data_list = list(test_dataloader) test_data_list = list(test_dataloader)
x_vision_test, x_tactile_test = random.choice(test_data_list) x_vision_test, x_tactile_test = random.choice(test_data_list)
@ -190,8 +199,7 @@ def compute_tsne(model, test_dataloader, writer, epoch):
image = Image.open('temp_figure.png') image = Image.open('temp_figure.png')
image = np.array(image) # Convert image to a NumPy array image = np.array(image) # Convert image to a NumPy array
image = image[:, :, :3].transpose(2, 0, 1) # Extract RGB channels and change format to CHW image = image[:, :, :3].transpose(2, 0, 1) # Extract RGB channels and change format to CHW
writer.add_image('t-SNE', image, global_step=epoch) wandb.log({"t-SNE": wandb.Image(image)}, commit=False)
def find_knn(query_point, data_points, k=5): def find_knn(query_point, data_points, k=5):
# Calculate the Euclidean distances # Calculate the Euclidean distances