Adding correct file
This commit is contained in:
parent
36f629139c
commit
29c80c6d92
138
train_mm_moco.py
Normal file → Executable file
138
train_mm_moco.py
Normal file → Executable file
@ -4,36 +4,43 @@ from torchvision import models # For using the ResNet-50 model
|
||||
import torch.nn.functional as F
|
||||
|
||||
import timm
|
||||
import wandb
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
|
||||
class MultiModalMoCo(nn.Module):
|
||||
def __init__(self, writer, K=4096, m=0.99, T=1.0):
|
||||
def __init__(self, m=0.99, T=1.0, nn_model=None):
|
||||
super(MultiModalMoCo, self).__init__()
|
||||
self.writer = writer
|
||||
self.K = K
|
||||
self.m = m
|
||||
self.T = T
|
||||
self.nn_model = nn_model
|
||||
|
||||
self.intra_dim = 128
|
||||
self.inter_dim = 128
|
||||
|
||||
# Initialize the queue
|
||||
self.queue = torch.zeros((self.K, self.intra_dim), dtype=torch.float).cuda()
|
||||
self.queue_ptr = 0
|
||||
self.inter_dim = 128
|
||||
|
||||
def create_mlp_head(output_dim):
|
||||
return nn.Sequential(
|
||||
nn.Linear(2048, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Linear(2048, output_dim)
|
||||
)
|
||||
if self.nn_model == 'resnet18':
|
||||
return nn.Sequential(
|
||||
nn.Linear(512, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Linear(2048, output_dim)
|
||||
)
|
||||
elif self.nn_model == 'resnet50':
|
||||
return nn.Sequential(
|
||||
nn.Linear(2048, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Linear(2048, output_dim)
|
||||
)
|
||||
|
||||
def create_resnet_encoder():
|
||||
resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
|
||||
if self.nn_model == 'resnet18':
|
||||
resnet = models.resnet18(weights='ResNet18_Weights.IMAGENET1K_V1')
|
||||
elif self.nn_model == 'resnet50':
|
||||
resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
|
||||
#resnet = models.regnet_x_800mf(weights='RegNet_X_800MF_Weights')
|
||||
features = list(resnet.children())[:-2]
|
||||
features.append(nn.AdaptiveAvgPool2d((1, 1)))
|
||||
@ -42,25 +49,31 @@ class MultiModalMoCo(nn.Module):
|
||||
|
||||
# Vision encoders
|
||||
self.vision_base_q = create_resnet_encoder()
|
||||
self.vision_head_intra_q = create_mlp_head(self.intra_dim)
|
||||
self.vision_head_inter_q = create_mlp_head(self.inter_dim)
|
||||
|
||||
self.vision_base_k = create_resnet_encoder()
|
||||
self.vision_head_intra_k = create_mlp_head(self.intra_dim)
|
||||
self.vision_head_inter_k = create_mlp_head(self.inter_dim)
|
||||
|
||||
# Tactile encoders
|
||||
self.tactile_base_q = create_resnet_encoder()
|
||||
self.tactile_head_intra_q = create_mlp_head(self.intra_dim)
|
||||
self.tactile_head_inter_q = create_mlp_head(self.inter_dim)
|
||||
|
||||
self.tactile_base_k = create_resnet_encoder()
|
||||
self.tactile_head_intra_k = create_mlp_head(self.intra_dim)
|
||||
self.tactile_head_inter_k = create_mlp_head(self.inter_dim)
|
||||
|
||||
# Projection heads
|
||||
self.phi_vision_q = create_mlp_head(self.intra_dim)
|
||||
self.phi_tactile_q = create_mlp_head(self.intra_dim)
|
||||
|
||||
self.phi_vision_k = create_mlp_head(self.intra_dim)
|
||||
self.phi_tactile_k = create_mlp_head(self.intra_dim)
|
||||
|
||||
self.Phi_vision_q = create_mlp_head(self.intra_dim)
|
||||
self.Phi_tactile_q = create_mlp_head(self.intra_dim)
|
||||
|
||||
self.Phi_vision_k = create_mlp_head(self.intra_dim)
|
||||
self.Phi_tactile_k = create_mlp_head(self.intra_dim)
|
||||
|
||||
# Initialize key encoders with query encoder weights
|
||||
self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
|
||||
self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
|
||||
self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k)
|
||||
self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k)
|
||||
self._momentum_update_key_encoder(self.Phi_vision_q, self.Phi_vision_k)
|
||||
self._momentum_update_key_encoder(self.Phi_tactile_q, self.Phi_tactile_k)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def concat_all_gather(self,tensor):
|
||||
@ -84,41 +97,45 @@ class MultiModalMoCo(nn.Module):
|
||||
|
||||
def forward(self, x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, len_train_dataloader):
|
||||
vision_base_q = self.vision_base_q(x_vision_q)
|
||||
vision_queries_intra = self.vision_head_intra_q(vision_base_q)
|
||||
vision_queries_inter = self.vision_head_inter_q(vision_base_q)
|
||||
q_vv = self.phi_vision_q(vision_base_q)
|
||||
q_vt = self.phi_tactile_q(vision_base_q)
|
||||
|
||||
with torch.no_grad():
|
||||
self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
|
||||
vision_base_k = self.vision_base_k(x_vision_k)
|
||||
vision_keys_intra = self.vision_head_intra_k(vision_base_k)
|
||||
vision_keys_inter = self.vision_head_inter_k(vision_base_k)
|
||||
vision_base_k = self.vision_base_k(x_vision_k)
|
||||
k_vv = self.phi_vision_k(vision_base_k)
|
||||
k_tv = self.phi_tactile_k(vision_base_k)
|
||||
|
||||
tactile_base_q = self.tactile_base_q(x_tactile_q)
|
||||
tactile_queries_intra = self.tactile_head_intra_q(tactile_base_q)
|
||||
tactile_queries_inter = self.tactile_head_inter_q(tactile_base_q)
|
||||
q_tv = self.phi_vision_q(tactile_base_q)
|
||||
q_tt = self.phi_tactile_q(tactile_base_q)
|
||||
|
||||
with torch.no_grad():
|
||||
self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
|
||||
tactile_base_k = self.tactile_base_k(x_tactile_k)
|
||||
tactile_keys_intra = self.tactile_head_intra_k(tactile_base_k)
|
||||
tactile_keys_inter = self.tactile_head_inter_k(tactile_base_k)
|
||||
tactile_base_k = self.tactile_base_k(x_tactile_k)
|
||||
k_vt = self.phi_vision_k(tactile_base_k)
|
||||
k_tt = self.phi_tactile_k(tactile_base_k)
|
||||
|
||||
# Update key encoders
|
||||
self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k)
|
||||
self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k)
|
||||
self._momentum_update_key_encoder(self.phi_vision_q, self.phi_vision_k)
|
||||
self._momentum_update_key_encoder(self.phi_tactile_q, self.phi_tactile_k)
|
||||
|
||||
# Compute the contrastive loss for each pair of queries and keys
|
||||
vision_loss_intra = self.moco_contrastive_loss(vision_queries_intra, vision_keys_intra)
|
||||
tactile_loss_intra = self.moco_contrastive_loss(tactile_queries_intra, tactile_keys_intra)
|
||||
vision_tactile_inter = self.moco_contrastive_loss(vision_queries_inter, tactile_keys_inter)
|
||||
tactile_vision_inter = self.moco_contrastive_loss(tactile_queries_inter, vision_keys_inter)
|
||||
|
||||
vision_vision_intra = self.moco_contrastive_loss(q_vv, k_vv)
|
||||
tactile_tactile_intra = self.moco_contrastive_loss(q_tt, k_tt)
|
||||
tactile_vision_inter = self.moco_contrastive_loss(q_vt, k_vt)
|
||||
vision_tactile_inter = self.moco_contrastive_loss(q_tv, k_tv)
|
||||
|
||||
# Combine losses (you can use different strategies to combine these losses)
|
||||
weight_inter = 1
|
||||
combined_loss = vision_loss_intra + tactile_loss_intra + (vision_tactile_inter + tactile_vision_inter) * weight_inter
|
||||
combined_loss = vision_vision_intra + tactile_tactile_intra + (tactile_vision_inter + vision_tactile_inter) * weight_inter
|
||||
|
||||
|
||||
if len_train_dataloader != 0:
|
||||
self.writer.add_scalar('module loss/vision intra loss', vision_loss_intra.item(), epoch * len_train_dataloader + i)
|
||||
self.writer.add_scalar('module loss/tactile intra loss', tactile_loss_intra.item(), epoch * len_train_dataloader + i)
|
||||
self.writer.add_scalar('module loss/vision tactile inter loss', vision_tactile_inter.item() * weight_inter, epoch * len_train_dataloader + i)
|
||||
self.writer.add_scalar('module loss/tactile vision inter loss', tactile_vision_inter.item() * weight_inter, epoch * len_train_dataloader + i)
|
||||
|
||||
wandb.log({
|
||||
'module loss/vision intra loss': vision_vision_intra.item(),
|
||||
'module loss/tactile intra loss': tactile_tactile_intra.item(),
|
||||
'module loss/vision tactile inter loss': vision_tactile_inter.item() * weight_inter,
|
||||
'module loss/tactile vision inter loss': tactile_vision_inter.item() * weight_inter
|
||||
}, step=epoch * len_train_dataloader + i)
|
||||
return combined_loss
|
||||
|
||||
|
||||
@ -127,7 +144,7 @@ def denormalize(tensor, mean, std):
|
||||
t.mul_(s).add_(m)
|
||||
return tensor
|
||||
|
||||
def evaluate_and_plot(model, test_dataloader, epoch, writer, device):
|
||||
def evaluate_and_plot(model, test_dataloader, epoch, device):
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
@ -150,14 +167,16 @@ def evaluate_and_plot(model, test_dataloader, epoch, writer, device):
|
||||
x_tactile_test_denorm = x_tactile_test_denorm.cpu().numpy()
|
||||
x_tactile_test_denorm = np.clip(x_tactile_test_denorm, 0, 1)
|
||||
|
||||
writer.add_images('Vision_Images', x_vision_test_denorm, epoch)
|
||||
writer.add_images('Tactile_Images', x_tactile_test_denorm, epoch)
|
||||
|
||||
writer.add_scalar('testing loss', test_loss.item(), epoch * len(test_dataloader))
|
||||
x_vision_test_denorm = x_vision_test_denorm.transpose(0, 2, 3, 1)
|
||||
x_tactile_test_denorm = x_tactile_test_denorm.transpose(0, 2, 3, 1)
|
||||
wandb.log({
|
||||
"Vision_Images": [wandb.Image(img_tensor) for img_tensor in x_vision_test_denorm],
|
||||
"Tactile_Images": [wandb.Image(img_tensor) for img_tensor in x_tactile_test_denorm]
|
||||
}, commit=False)
|
||||
wandb.log({"testing loss": test_loss.item()}, step=epoch * len(test_dataloader))
|
||||
print(f"Test Loss: {test_loss.item():.4f}")
|
||||
|
||||
|
||||
def compute_tsne(model, test_dataloader, writer, epoch):
|
||||
def compute_tsne(model, test_dataloader, epoch):
|
||||
with torch.no_grad():
|
||||
test_data_list = list(test_dataloader)
|
||||
x_vision_test, x_tactile_test = random.choice(test_data_list)
|
||||
@ -190,8 +209,7 @@ def compute_tsne(model, test_dataloader, writer, epoch):
|
||||
image = Image.open('temp_figure.png')
|
||||
image = np.array(image) # Convert image to a NumPy array
|
||||
image = image[:, :, :3].transpose(2, 0, 1) # Extract RGB channels and change format to CHW
|
||||
writer.add_image('t-SNE', image, global_step=epoch)
|
||||
|
||||
wandb.log({"t-SNE": wandb.Image(image)}, commit=False)
|
||||
|
||||
def find_knn(query_point, data_points, k=5):
|
||||
# Calculate the Euclidean distances
|
||||
|
Loading…
Reference in New Issue
Block a user