TVSSL/train_mm.py

191 lines
8.9 KiB
Python
Raw Normal View History

2023-08-30 11:39:44 +00:00
import torch
import torch.nn as nn
from torchvision import models # For using the ResNet-50 model
import torch.nn.functional as F
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
class MultiModalMoCo(nn.Module):
    """Vision/tactile contrastive model with momentum-updated key encoders.

    Each modality (vision, tactile) has a query encoder trained by
    backpropagation and a key encoder that follows it as an exponential
    moving average.  An encoder is a ResNet-50 trunk (pooled to one 2048-d
    vector per image) followed by two MLP heads: an "intra" head used to
    contrast a modality with itself and an "inter" head used to contrast
    vision with tactile.  Logits are bilinear scores ``q @ W @ k.T`` against
    the other samples of the same batch; no negative queue is kept.

    Args:
        writer: logger exposing ``add_scalar(tag, value, step)`` (presumably a
            TensorBoard ``SummaryWriter`` -- confirm against the caller).
        K: stored as ``self.K``; not used by this class.
        m: EMA momentum for the key encoders (``k <- m * k + (1 - m) * q``).
        T: stored as ``self.T``; not used by this class (logits are not
            divided by a temperature).
    """

    def __init__(self, writer, K=4096, m=0.99, T=0.07):
        super(MultiModalMoCo, self).__init__()
        self.writer = writer
        self.K = K  # not used by this class: no negative queue is kept
        self.m = m
        self.T = T  # not used by this class: logits are not temperature-scaled
        self.intra_dim = 64
        self.inter_dim = 64
        # Learned bilinear matrices W in logits = q @ W @ k.T, one per loss term.
        self.W_vision_intra = nn.Parameter(torch.randn(self.intra_dim, self.intra_dim))
        self.W_tactile_intra = nn.Parameter(torch.randn(self.intra_dim, self.intra_dim))
        self.W_vision_inter = nn.Parameter(torch.randn(self.inter_dim, self.inter_dim))
        self.W_tactile_inter = nn.Parameter(torch.randn(self.inter_dim, self.inter_dim))

        def create_mlp_head(output_dim):
            # Projection MLP on top of the 2048-d pooled ResNet-50 feature.
            return nn.Sequential(
                nn.Linear(2048, 2048),
                nn.ReLU(),
                nn.Linear(2048, output_dim),
            )

        def create_resnet_encoder():
            # ImageNet-pretrained ResNet-50 with its avgpool/fc replaced by
            # global average pooling + flatten: one 2048-d vector per image.
            resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
            features = list(resnet.children())[:-2]
            features.append(nn.AdaptiveAvgPool2d((1, 1)))
            features.append(nn.Flatten())
            return nn.Sequential(*features)

        # Vision encoders
        self.vision_base_q = create_resnet_encoder()
        self.vision_head_intra_q = create_mlp_head(self.intra_dim)
        self.vision_head_inter_q = create_mlp_head(self.inter_dim)
        self.vision_base_k = create_resnet_encoder()
        self.vision_head_intra_k = create_mlp_head(self.intra_dim)
        self.vision_head_inter_k = create_mlp_head(self.inter_dim)
        # Tactile encoders
        self.tactile_base_q = create_resnet_encoder()
        self.tactile_head_intra_q = create_mlp_head(self.intra_dim)
        self.tactile_head_inter_q = create_mlp_head(self.inter_dim)
        self.tactile_base_k = create_resnet_encoder()
        self.tactile_head_intra_k = create_mlp_head(self.intra_dim)
        self.tactile_head_inter_k = create_mlp_head(self.inter_dim)

        # Every key module (trunk AND both heads) starts as an exact copy of
        # its query counterpart; afterwards it only moves by the momentum
        # update, never by gradients.
        for modality in ('vision', 'tactile'):
            for module_q, module_k in self._encoder_pairs(modality):
                self._copy_key_encoder(module_q, module_k)

    def _encoder_pairs(self, modality):
        """Return the ``(query, key)`` module pairs (trunk, intra head, inter head) of one modality."""
        if modality == 'vision':
            return (
                (self.vision_base_q, self.vision_base_k),
                (self.vision_head_intra_q, self.vision_head_intra_k),
                (self.vision_head_inter_q, self.vision_head_inter_k),
            )
        if modality == 'tactile':
            return (
                (self.tactile_base_q, self.tactile_base_k),
                (self.tactile_head_intra_q, self.tactile_head_intra_k),
                (self.tactile_head_inter_q, self.tactile_head_inter_k),
            )
        raise ValueError(f"unknown modality: {modality!r}")

    @torch.no_grad()
    def _copy_key_encoder(self, module_q, module_k):
        """Copy ``module_q``'s parameters into ``module_k`` and stop gradients on ``module_k``."""
        for param_q, param_k in zip(module_q.parameters(), module_k.parameters()):
            param_k.copy_(param_q)
            param_k.requires_grad_(False)

    @torch.no_grad()
    def _momentum_update_key_encoder(self, base_q, base_k):
        """EMA step ``param_k <- m * param_k + (1 - m) * param_q`` over all parameter pairs (in place)."""
        for param_q, param_k in zip(base_q.parameters(), base_k.parameters()):
            param_k.mul_(self.m).add_(param_q.detach(), alpha=1.0 - self.m)

    def compute_logits(self, z_a, z_pos, W):
        """Return bilinear logits ``z_a @ W @ z_pos.T`` with each row shifted so its max is 0.

        Row ``i`` scores anchor ``i`` against every key in the batch.  The
        per-row shift leaves softmax / cross-entropy unchanged and only keeps
        the values numerically small.
        """
        logits = torch.matmul(z_a, torch.matmul(W, z_pos.T))  # (B_a, B_pos)
        return logits - logits.max(dim=1, keepdim=True).values

    def forward(self, x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, len_train_dataloader):
        """Compute the combined contrastive loss for one batch.

        Sample ``j`` of a query batch is treated as the positive for sample
        ``j`` of the corresponding key batch; all other samples in the batch
        are negatives.

        Args:
            x_vision_q, x_vision_k: vision inputs for the query / key encoders.
            x_tactile_q, x_tactile_k: tactile inputs for the query / key encoders.
            epoch, i: epoch and batch index, used only for the logging step.
            len_train_dataloader: batches per epoch; the logging step is
                ``epoch * len_train_dataloader + i``.  Pass 0 to skip logging.

        Returns:
            Scalar loss: both intra-modal losses plus 0.1 times the sum of the
            two cross-modal losses.
        """
        vision_feat_q = self.vision_base_q(x_vision_q)
        vision_queries_intra = self.vision_head_intra_q(vision_feat_q)
        vision_queries_inter = self.vision_head_inter_q(vision_feat_q)
        with torch.no_grad():
            # Advance the EMA only while training, so evaluation passes do not
            # modify the key encoders.
            if self.training:
                for module_q, module_k in self._encoder_pairs('vision'):
                    self._momentum_update_key_encoder(module_q, module_k)
            vision_feat_k = self.vision_base_k(x_vision_k)
            vision_keys_intra = self.vision_head_intra_k(vision_feat_k)
            vision_keys_inter = self.vision_head_inter_k(vision_feat_k)

        tactile_feat_q = self.tactile_base_q(x_tactile_q)
        tactile_queries_intra = self.tactile_head_intra_q(tactile_feat_q)
        tactile_queries_inter = self.tactile_head_inter_q(tactile_feat_q)
        with torch.no_grad():
            if self.training:
                for module_q, module_k in self._encoder_pairs('tactile'):
                    self._momentum_update_key_encoder(module_q, module_k)
            tactile_feat_k = self.tactile_base_k(x_tactile_k)
            tactile_keys_intra = self.tactile_head_intra_k(tactile_feat_k)
            tactile_keys_inter = self.tactile_head_inter_k(tactile_feat_k)

        # The positive for row j is column j: targets are simply 0..B-1.
        vision_targets = torch.arange(x_vision_q.size(0), device=x_vision_q.device)
        tactile_targets = torch.arange(x_tactile_q.size(0), device=x_tactile_q.device)

        vision_loss_intra = F.cross_entropy(
            self.compute_logits(vision_queries_intra, vision_keys_intra, self.W_vision_intra), vision_targets)
        tactile_loss_intra = F.cross_entropy(
            self.compute_logits(tactile_queries_intra, tactile_keys_intra, self.W_tactile_intra), tactile_targets)
        vision_tactile_inter = F.cross_entropy(
            self.compute_logits(vision_queries_inter, tactile_keys_inter, self.W_vision_inter), vision_targets)
        tactile_vision_inter = F.cross_entropy(
            self.compute_logits(tactile_queries_inter, vision_keys_inter, self.W_tactile_inter), tactile_targets)

        weight_inter = 0.1
        combined_loss = vision_loss_intra + tactile_loss_intra + (vision_tactile_inter + tactile_vision_inter) * weight_inter

        if len_train_dataloader != 0:
            step = epoch * len_train_dataloader + i
            self.writer.add_scalar('module loss/vision intra loss', vision_loss_intra.item(), step)
            self.writer.add_scalar('module loss/tactile intra loss', tactile_loss_intra.item(), step)
            self.writer.add_scalar('module loss/vision tactile inter loss', vision_tactile_inter.item() * weight_inter, step)
            self.writer.add_scalar('module loss/tactile vision inter loss', tactile_vision_inter.item() * weight_inter, step)
        return combined_loss
def denormalize(tensor, mean, std):
    """Undo ``(x - mean) / std`` channel-wise, in place, and return ``tensor``.

    Accepts a single channel-first image ``(C, H, W)`` or a batch
    ``(N, C, H, W)``.  ``mean`` and ``std`` are per-channel sequences.  They
    are broadcast over the channel axis, which is the third dimension from
    the end for 3-D/4-D input and the first dimension for 1-D/2-D input.
    """
    # Reshape the statistics to (C, 1, 1) (or shorter for low-rank input) so
    # they broadcast over the channel axis.  Iterating the tensor instead
    # would walk the batch axis for 4-D input.
    channel_shape = (-1,) + (1,) * (min(tensor.dim(), 3) - 1)
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(channel_shape)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(channel_shape)
    return tensor.mul_(std_t).add_(mean_t)
def evaluate_and_plot(model, test_dataloader, epoch, writer, device, num_images=4):
    """Log a few denormalized test images and the model's loss on them.

    One random batch of ``test_dataloader`` is drawn and up to
    ``num_images`` samples are picked from it.  The loss is computed with
    each input used as both the query and the key view.  The images are
    denormalized with the standard ImageNet mean/std (presumably the values
    the data pipeline normalized with -- confirm) and clipped to [0, 1].
    The images and the loss are then written to ``writer``.

    The model is switched to eval mode for the evaluation and restored to
    its previous train/eval mode afterwards.

    Args:
        model: module called as ``model(xv_q, xv_k, xt_q, xt_k, epoch, i, len_train_dataloader)``
            and returning a scalar loss tensor.
        test_dataloader: sized iterable yielding ``(vision, tactile)`` batches.
        epoch: current epoch, used as the logging step.
        writer: logger exposing ``add_images`` and ``add_scalar``.
        device: device the sampled images are moved to.
        num_images: how many samples to log; capped at the batch size.
    """
    import itertools

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Fetch a single random batch without materializing the whole test set.
            batch_idx = random.randrange(len(test_dataloader))
            x_vision_test, x_tactile_test = next(itertools.islice(test_dataloader, batch_idx, None))
            batch_size = x_vision_test.shape[0]
            picked = random.sample(range(batch_size), min(num_images, batch_size))
            x_vision_test = x_vision_test[picked].to(device)
            x_tactile_test = x_tactile_test[picked].to(device)
            # len_train_dataloader=0: MultiModalMoCo.forward skips its per-term logging.
            test_loss = model(x_vision_test, x_vision_test, x_tactile_test, x_tactile_test, epoch, 0, 0)

            vision_images = np.clip(
                denormalize(x_vision_test.clone(), imagenet_mean, imagenet_std).cpu().numpy(), 0, 1)
            tactile_images = np.clip(
                denormalize(x_tactile_test.clone(), imagenet_mean, imagenet_std).cpu().numpy(), 0, 1)
    finally:
        model.train(was_training)

    writer.add_images('Vision_Images', vision_images, epoch)
    writer.add_images('Tactile_Images', tactile_images, epoch)
    writer.add_scalar('testing loss', test_loss.item(), epoch * len(test_dataloader))
    print(f"Test Loss: {test_loss.item():.4f}")
def compute_tsne(model, test_dataloader, writer, epoch, device=None, num_samples=10, perplexity=5):
    """Log a t-SNE plot of paired vision/tactile trunk features from one test batch.

    Up to ``num_samples`` samples are drawn from one random test batch.  Each
    is encoded with ``model.vision_base_q`` and ``model.tactile_base_q``, and
    all 2 * n features are embedded jointly in 2-D with t-SNE.  Vision points
    are drawn in blue and tactile points in red.  Both points of sample ``k``
    carry the text label ``k + 1``, so matching pairs can be read off the
    plot.  The figure is logged with ``writer.add_image('t-SNE', ...)`` in
    CHW layout.

    The model is switched to eval mode while extracting features, so
    BatchNorm running statistics are neither used from the batch nor
    updated.  Its previous mode is restored afterwards.

    Args:
        model: module exposing ``vision_base_q`` and ``tactile_base_q``.
        test_dataloader: sized iterable yielding ``(vision, tactile)`` batches.
        writer: logger exposing ``add_image``.
        epoch: logging step.
        device: device for the inputs; defaults to the model's parameter device.
        num_samples: samples to embed; capped at the batch size.
        perplexity: t-SNE perplexity; capped below the number of points.
    """
    import io
    import itertools

    if device is None:
        device = next(model.parameters()).device

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Fetch a single random batch without materializing the whole test set.
            batch_idx = random.randrange(len(test_dataloader))
            x_vision_test, x_tactile_test = next(itertools.islice(test_dataloader, batch_idx, None))
            batch_size = min(x_vision_test.shape[0], x_tactile_test.shape[0])
            n = min(num_samples, batch_size)
            picked = random.sample(range(batch_size), n)
            vision_feats = model.vision_base_q(x_vision_test[picked].to(device)).cpu().numpy()
            tactile_feats = model.tactile_base_q(x_tactile_test[picked].to(device)).cpu().numpy()
    finally:
        model.train(was_training)

    # Rows 0..n-1 are vision samples and rows n..2n-1 the matching tactile
    # samples, so labels run 1..n in each half (not 1,1,2,2,...).
    data = np.concatenate((vision_feats, tactile_feats), axis=0)
    labels = np.tile(np.arange(1, n + 1), 2)

    # t-SNE requires perplexity < number of points.
    tsne = TSNE(n_components=2, random_state=0, perplexity=min(perplexity, len(data) - 1))
    tsne_data = tsne.fit_transform(data)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.scatter(tsne_data[:n, 0], tsne_data[:n, 1], color='blue', label='vision')
    ax.scatter(tsne_data[n:, 0], tsne_data[n:, 1], color='red', label='tactile')
    for (x, y), label in zip(tsne_data, labels):
        ax.text(x, y, f"{label}", fontsize=12, ha='center', va='bottom')
    ax.legend()

    # Render in memory instead of round-tripping through a file in the CWD.
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png')
    plt.close(fig)
    buffer.seek(0)
    with Image.open(buffer) as png:
        image = np.array(png.convert('RGB')).transpose(2, 0, 1)  # HWC -> CHW
    writer.add_image('t-SNE', image, global_step=epoch)