Adding all files
This commit is contained in:
parent
b7a00908e4
commit
499414a262
119
tac_ssl.py
Normal file
119
tac_ssl.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
from torchvision import transforms
|
||||||
|
from torch.utils.data import random_split
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TensorBoard writer shared by the training loop, the model and the
# evaluation helpers; events are written under 'runs/mmssl'.
writer = SummaryWriter('runs/mmssl')
|
||||||
|
|
||||||
|
# Custom dataset
|
||||||
|
class CustomMultiModalDataset(Dataset):
    """Paired vision/tactile image dataset read from two folders.

    The file names of each folder are sorted and paired by position, so the
    i-th vision file is returned together with the i-th tactile file.

    Args:
        vision_folder: Directory containing the vision images.
        tactile_folder: Directory containing the tactile images.
        transform: Optional callable applied to both images of a pair.
    """

    def __init__(self, vision_folder, tactile_folder, transform=None):
        self.vision_folder = vision_folder
        self.tactile_folder = tactile_folder
        self.transform = transform

        self.vision_files = sorted(os.listdir(vision_folder))
        self.tactile_files = sorted(os.listdir(tactile_folder))

    def __len__(self):
        """Return the number of complete (vision, tactile) pairs."""
        # Only indices present in both listings form a pair; using the vision
        # count alone lets __getitem__ run past the end of the tactile list.
        return min(len(self.vision_files), len(self.tactile_files))

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return it converted to RGB."""
        # convert() returns a new image, so the opened file can be released
        # as soon as the block exits.
        with Image.open(path) as image:
            return image.convert("RGB")

    def __getitem__(self, idx):
        """Return the ``(vision_image, tactile_image)`` pair at ``idx``."""
        vision_path = os.path.join(self.vision_folder, self.vision_files[idx])
        tactile_path = os.path.join(self.tactile_folder, self.tactile_files[idx])

        vision_image = self._load_rgb(vision_path)
        tactile_image = self._load_rgb(tactile_path)

        if self.transform is not None:
            vision_image = self.transform(vision_image)
            tactile_image = self.transform(tactile_image)

        return vision_image, tactile_image
|
||||||
|
|
||||||
|
# Initialize augmentation
# Deterministic preprocessing handed to the dataset: centre-crop to 500x500,
# convert to a tensor, then normalize each channel with the mean/std below.
simple_transforms = transforms.Compose([
    transforms.CenterCrop(500),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Random augmentation pipeline used in the training loop to build the query
# and key views of each batch.
# NOTE(review): the loop calls this on whole, already-normalized batches, so
# each call presumably draws one set of random parameters per batch rather
# than per image - confirm this is intended.
data_transforms = transforms.Compose([
    transforms.RandomApply([transforms.RandomRotation(150)], p=0.50),
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    # NOTE(review): RandomHorizontalFlip has its own flip probability, so
    # wrapping it in RandomApply(p=0.50) presumably lowers the effective flip
    # rate below 0.5 - confirm.
    transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.50),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5),
])
|
||||||
|
|
||||||
|
# Initialize dataset and dataloader
# Both folders are hard-coded absolute paths.
vision_folder = "/home/vedant/Downloads/ssvtp_data/images_rgb"
tactile_folder = "/home/vedant/Downloads/ssvtp_data/images_tac"
# simple_transforms is applied to both images of every pair when it is loaded.
dataset = CustomMultiModalDataset(vision_folder, tactile_folder, transform=simple_transforms)
#dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Split the dataset into 80-20
train_size = int(0.8 * len(dataset))
# The remainder goes to the test split, so the two sizes sum to len(dataset).
test_size = len(dataset) - train_size
# NOTE(review): no generator/seed is passed, so the split presumably differs
# from run to run - confirm whether reproducibility is needed.
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Initialize dataloaders for train and test
# Training batches are shuffled; the test loader keeps its order.
train_dataloader = DataLoader(train_dataset, batch_size=96, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize model
# K, m and T are forwarded to MultiModalMoCo; the writer is passed in because
# the model logs its per-module losses itself.
model = MultiModalMoCo(writer, K=4096, m=0.999, T=0.07).to(device)

# Initialize optimizer
# Only the query-side modules (base encoder plus both projection heads) of each
# modality are handed to an optimizer; the key-side modules are not.
vision_module = list(model.vision_base_q.parameters()) + list(model.vision_head_intra_q.parameters()) + list(model.vision_head_inter_q.parameters())
tactile_module = list(model.tactile_base_q.parameters()) + list(model.tactile_head_intra_q.parameters()) + list(model.tactile_head_inter_q.parameters())
# One Adam optimizer per modality, both with the same learning rate.
optim_vision = optim.Adam(vision_module, lr=0.0001)
optim_tactile = optim.Adam(tactile_module, lr=0.0001)
|
||||||
|
|
||||||
|
# Training loop
n_epochs = 250  # Number of epochs

# torch.save() does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save.
os.makedirs('models', exist_ok=True)

for epoch in range(n_epochs):
    # The per-epoch evaluation below may leave the model in eval mode
    # (evaluate_and_plot calls model.eval()), so re-enable training mode at
    # the start of every epoch.
    model.train()

    for i, (x_vision, x_tactile) in enumerate(train_dataloader):

        # Augment images: two independently augmented views per modality,
        # one for the query encoder and one for the key encoder.
        x_vision_q = data_transforms(x_vision).to(device)
        x_vision_k = data_transforms(x_vision).to(device)

        x_tactile_q = data_transforms(x_tactile).to(device)
        x_tactile_k = data_transforms(x_tactile).to(device)

        # Forward pass to get the loss
        loss = model(x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, len(train_dataloader))

        # Backward pass and optimization
        optim_vision.zero_grad()
        optim_tactile.zero_grad()
        loss.backward()
        optim_vision.step()
        optim_tactile.step()

        # Logging
        if i % 10 == 0:
            print(f"Epoch [{epoch+1}/{n_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}")
            writer.add_scalar('training loss', loss.item(), epoch * len(train_dataloader) + i)

    # Evaluate and plot
    compute_tsne(model, test_dataloader, writer, epoch)
    evaluate_and_plot(model, test_dataloader, epoch, writer, device)
    # Checkpoint every 10th epoch; the same file is overwritten each time.
    if epoch % 10 == 0:
        torch.save(model.state_dict(), 'models/model.pth')


plt.show()
|
192
train_mm_moco.py
Normal file
192
train_mm_moco.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision import models # For using the ResNet-50 model
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import timm
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from sklearn.manifold import TSNE
|
||||||
|
|
||||||
|
class MultiModalMoCo(nn.Module):
    """MoCo-style contrastive model over paired vision and tactile images.

    Each modality owns a query and a key branch. A branch is a ResNet-50
    backbone (final pooling replaced by adaptive average pooling + flatten)
    followed by two MLP heads: an "intra" head used for the within-modality
    loss and an "inter" head used for the cross-modality loss.

    Args:
        writer: Object with an ``add_scalar(tag, value, step)`` method used to
            log the individual loss terms.
        K: Number of rows of the (currently unused) feature queue.
        m: Momentum coefficient of the key-branch update.
        T: Softmax temperature applied to the contrastive logits.
    """

    def __init__(self, writer, K=4096, m=0.99, T=1.0):
        super().__init__()
        self.writer = writer
        self.K = K
        self.m = m
        self.T = T

        self.intra_dim = 64
        self.inter_dim = 64

        # Initialize the queue.
        # Registered as a non-persistent buffer so that it follows the module
        # across devices (no hard-coded .cuda()) while staying out of the
        # state_dict, exactly like the plain attribute it replaces.
        self.register_buffer(
            "queue",
            torch.zeros((self.K, self.intra_dim), dtype=torch.float),
            persistent=False,
        )
        self.queue_ptr = 0

        def create_mlp_head(output_dim):
            # Two-layer projection head on top of the 2048-d backbone features.
            return nn.Sequential(
                nn.Linear(2048, 2048),
                nn.ReLU(),
                nn.Linear(2048, output_dim)
            )

        def create_resnet_encoder():
            # Drop the original pooling and classifier layers, then pool to
            # 1x1 and flatten to one feature vector per image.
            resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
            features = list(resnet.children())[:-2]
            features.append(nn.AdaptiveAvgPool2d((1, 1)))
            features.append(nn.Flatten())
            return nn.Sequential(*features)

        # Vision encoders
        self.vision_base_q = create_resnet_encoder()
        self.vision_head_intra_q = create_mlp_head(self.intra_dim)
        self.vision_head_inter_q = create_mlp_head(self.inter_dim)

        self.vision_base_k = create_resnet_encoder()
        self.vision_head_intra_k = create_mlp_head(self.intra_dim)
        self.vision_head_inter_k = create_mlp_head(self.inter_dim)

        # Tactile encoders
        self.tactile_base_q = create_resnet_encoder()
        self.tactile_head_intra_q = create_mlp_head(self.intra_dim)
        self.tactile_head_inter_q = create_mlp_head(self.inter_dim)

        self.tactile_base_k = create_resnet_encoder()
        self.tactile_head_intra_k = create_mlp_head(self.intra_dim)
        self.tactile_head_inter_k = create_mlp_head(self.inter_dim)

        # Initialize key encoders with query encoder weights.
        # This is a real copy (a momentum step would only blend the weights),
        # and it covers the projection heads too; otherwise the key heads stay
        # at an unrelated random initialisation.
        for modality in ("vision", "tactile"):
            for module_q, module_k in self._encoder_pairs(modality):
                self._copy_query_to_key(module_q, module_k)

    def _encoder_pairs(self, modality):
        """Return the (query, key) module pairs of ``modality``.

        The order is: backbone, intra head, inter head.
        """
        return [
            (getattr(self, f"{modality}_{part}_q"), getattr(self, f"{modality}_{part}_k"))
            for part in ("base", "head_intra", "head_inter")
        ]

    @staticmethod
    @torch.no_grad()
    def _copy_query_to_key(module_q, module_k):
        """Overwrite the parameters of ``module_k`` with those of ``module_q``."""
        for param_q, param_k in zip(module_q.parameters(), module_k.parameters()):
            param_k.data.copy_(param_q.data)

    @torch.no_grad()
    def concat_all_gather(self, tensor):
        """Gather ``tensor`` from every process and concatenate along dim 0.

        Uses ``torch.distributed.get_world_size()`` and ``all_gather``, so the
        distributed package must already be initialised by the caller.
        """
        tensors_gather = [torch.ones_like(tensor)
                          for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
        output = torch.cat(tensors_gather, dim=0)
        return output

    def moco_contrastive_loss(self, q, k):
        """Return the in-batch contrastive loss between queries and keys.

        Both inputs are L2-normalised along dim 1. Row ``i`` of ``q`` is
        treated as the positive of row ``i`` of ``k``; all other rows of ``k``
        act as negatives. No gradient flows into ``k``.
        """
        q = F.normalize(q, dim=1)
        k = F.normalize(k, dim=1)
        logits = torch.mm(q, k.T.detach()) / self.T
        # Build the targets on the logits' device instead of assuming CUDA.
        labels = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)
        return F.cross_entropy(logits, labels)

    @torch.no_grad()
    def _momentum_update_key_encoder(self, base_q, base_k):
        """Move ``base_k`` towards ``base_q``: ``k = m * k + (1 - m) * q``."""
        for param_q, param_k in zip(base_q.parameters(), base_k.parameters()):
            param_k.data.mul_(self.m).add_(param_q.data, alpha=1. - self.m)

    def _encode_modality(self, modality, x_q, x_k):
        """Run the query and key branches of one modality.

        Returns:
            ``(queries_intra, queries_inter, keys_intra, keys_inter)``. The
            key outputs are computed without gradient tracking.
        """
        pairs = self._encoder_pairs(modality)
        (base_q, base_k), (intra_q, intra_k), (inter_q, inter_k) = pairs

        features_q = base_q(x_q)
        queries_intra = intra_q(features_q)
        queries_inter = inter_q(features_q)

        with torch.no_grad():
            # The whole key branch (backbone and both heads) follows the query
            # branch; the key side is never handed to an optimizer.
            for module_q, module_k in pairs:
                self._momentum_update_key_encoder(module_q, module_k)
            features_k = base_k(x_k)
            keys_intra = intra_k(features_k)
            keys_inter = inter_k(features_k)

        return queries_intra, queries_inter, keys_intra, keys_inter

    def forward(self, x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, len_train_dataloader):
        """Compute the combined intra- and inter-modality contrastive loss.

        Args:
            x_vision_q, x_vision_k: Query and key views of the vision batch.
            x_tactile_q, x_tactile_k: Query and key views of the tactile batch.
            epoch: Current epoch index, used for the logging step.
            i: Batch index within the epoch, used for the logging step.
            len_train_dataloader: Number of batches per epoch; pass 0 to
                disable the per-term logging.

        Returns:
            Scalar tensor: both intra losses plus 0.1 times the sum of the two
            cross-modality losses.

        Note:
            The key branches receive a momentum update on every call,
            whichever mode the module is in.
        """
        (vision_queries_intra, vision_queries_inter,
         vision_keys_intra, vision_keys_inter) = self._encode_modality("vision", x_vision_q, x_vision_k)
        (tactile_queries_intra, tactile_queries_inter,
         tactile_keys_intra, tactile_keys_inter) = self._encode_modality("tactile", x_tactile_q, x_tactile_k)

        # Compute the contrastive loss for each pair of queries and keys
        vision_loss_intra = self.moco_contrastive_loss(vision_queries_intra, vision_keys_intra)
        tactile_loss_intra = self.moco_contrastive_loss(tactile_queries_intra, tactile_keys_intra)
        vision_tactile_inter = self.moco_contrastive_loss(vision_queries_inter, tactile_keys_inter)
        tactile_vision_inter = self.moco_contrastive_loss(tactile_queries_inter, vision_keys_inter)

        # Combine losses (you can use different strategies to combine these losses)
        weight_inter = 0.1
        combined_loss = vision_loss_intra + tactile_loss_intra + (vision_tactile_inter + tactile_vision_inter) * weight_inter

        if len_train_dataloader != 0:
            step = epoch * len_train_dataloader + i
            self.writer.add_scalar('module loss/vision intra loss', vision_loss_intra.item(), step)
            self.writer.add_scalar('module loss/tactile intra loss', tactile_loss_intra.item(), step)
            self.writer.add_scalar('module loss/vision tactile inter loss', vision_tactile_inter.item() * weight_inter, step)
            self.writer.add_scalar('module loss/tactile vision inter loss', tactile_vision_inter.item() * weight_inter, step)

        return combined_loss
|
||||||
|
|
||||||
|
|
||||||
|
def denormalize(tensor, mean, std):
    """Undo a per-channel normalization in place and return the tensor.

    Each channel ``c`` is replaced by ``channel * std[c] + mean[c]``.

    Args:
        tensor: Either a single image whose first dimension is the channel
            axis, or a batch with four or more dimensions whose second
            dimension is the channel axis (e.g. ``(B, C, H, W)``).
        mean: Per-channel means.
        std: Per-channel standard deviations.

    Returns:
        The same ``tensor`` object, modified in place.
    """
    # For a batch, iterating dim 0 would walk over samples rather than
    # channels, so move the channel axis to the front first. transpose()
    # returns a view, so the in-place updates below still reach `tensor`.
    channels = tensor if tensor.dim() < 4 else tensor.transpose(0, 1)
    for t, m, s in zip(channels, mean, std):
        t.mul_(s).add_(m)
    return tensor
|
||||||
|
|
||||||
|
def evaluate_and_plot(model, test_dataloader, epoch, writer, device):
    """Log the loss and the input images of a few random test samples.

    One batch is picked at random from ``test_dataloader`` and up to four of
    its samples are passed through ``model`` (the same tensor serves as query
    and key view). The de-normalized vision and tactile images and the
    resulting loss are written to ``writer``; the loss is also printed.

    Args:
        model: Model called as ``model(xv, xv, xt, xt, epoch, 0, 0)``.
        test_dataloader: Sized iterable of ``(vision, tactile)`` batches.
        epoch: Current epoch, used as the logging step.
        writer: Object providing ``add_images`` and ``add_scalar``.
        device: Device the sampled tensors are moved to.
    """
    import itertools  # local import: only this helper needs it

    num_batches = len(test_dataloader)
    if num_batches == 0:
        print("Test Loss: skipped (empty test dataloader)")
        return

    # Remember the caller's mode so that evaluating does not silently leave
    # the model in eval mode for the rest of training.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Advance to one random batch instead of materialising the whole
            # test set in memory.
            batch_index = random.randrange(num_batches)
            x_vision_test, x_tactile_test = next(
                itertools.islice(iter(test_dataloader), batch_index, None)
            )
            # The last batch can hold fewer than four samples.
            batch_size = x_vision_test.shape[0]
            random_indices = random.sample(range(batch_size), min(4, batch_size))
            x_vision_test = x_vision_test[random_indices].to(device)
            x_tactile_test = x_tactile_test[random_indices].to(device)

            # Passing 0 as the dataloader length disables the model's own
            # per-term logging for this call.
            test_loss = model(x_vision_test, x_vision_test, x_tactile_test, x_tactile_test, epoch, 0, 0)
    finally:
        model.train(was_training)

    # Denormalize vision images
    x_vision_test_denorm = denormalize(x_vision_test.clone(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    x_vision_test_denorm = x_vision_test_denorm.cpu().numpy()
    x_vision_test_denorm = np.clip(x_vision_test_denorm, 0, 1)

    # Denormalize tactile images
    x_tactile_test_denorm = denormalize(x_tactile_test.clone(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    x_tactile_test_denorm = x_tactile_test_denorm.cpu().numpy()
    x_tactile_test_denorm = np.clip(x_tactile_test_denorm, 0, 1)

    writer.add_images('Vision_Images', x_vision_test_denorm, epoch)
    writer.add_images('Tactile_Images', x_tactile_test_denorm, epoch)

    writer.add_scalar('testing loss', test_loss.item(), epoch * num_batches)
    print(f"Test Loss: {test_loss.item():.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
def compute_tsne(model, test_dataloader, writer, epoch):
    """Log a 2-D t-SNE plot of vision and tactile backbone features.

    One batch is picked at random from ``test_dataloader``; up to ten of its
    samples are encoded with ``model.vision_base_q`` and
    ``model.tactile_base_q``. The two feature sets are embedded together and
    every point is annotated with its 1-based sample number, so the vision and
    tactile points of the same sample carry the same label. The figure is
    saved to 'temp_figure.png' and logged to ``writer`` under the tag 't-SNE'.

    Args:
        model: Module exposing ``vision_base_q`` and ``tactile_base_q``.
        test_dataloader: Sized iterable of ``(vision, tactile)`` batches.
        writer: Object providing ``add_image``.
        epoch: Current epoch, used as the logging step.
    """
    import itertools  # local import: only this helper needs it

    num_batches = len(test_dataloader)
    if num_batches == 0:
        return

    # Run on whatever device the model lives on instead of assuming CUDA.
    model_device = next(model.parameters()).device

    # Extract features in eval mode, then hand the model back in the mode the
    # caller left it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Advance to one random batch instead of materialising the whole
            # test set in memory.
            batch_index = random.randrange(num_batches)
            x_vision_test, x_tactile_test = next(
                itertools.islice(iter(test_dataloader), batch_index, None)
            )
            # The last batch can hold fewer than ten samples.
            batch_size = x_vision_test.shape[0]
            random_indices = random.sample(range(batch_size), min(10, batch_size))
            x_vision_test = x_vision_test[random_indices].to(model_device)
            x_tactile_test = x_tactile_test[random_indices].to(model_device)
            vision_base_q = model.vision_base_q(x_vision_test)
            tactile_base_q = model.tactile_base_q(x_tactile_test)
    finally:
        model.train(was_training)

    vision_base_q = vision_base_q.cpu().numpy()
    tactile_base_q = tactile_base_q.cpu().numpy()

    # Create pairs of corresponding representations and labels
    num_samples = min(vision_base_q.shape[0], tactile_base_q.shape[0])
    if num_samples < 2:
        # Too few points for a meaningful embedding.
        return
    data = np.concatenate((vision_base_q[:num_samples], tactile_base_q[:num_samples]), axis=0)
    # `data` is laid out as [vision 1..n, tactile 1..n], so the labels must be
    # tiled (1..n, 1..n) rather than repeated (1, 1, 2, 2, ...).
    labels = np.tile(np.arange(1, num_samples + 1), 2)

    # t-SNE requires the perplexity to be smaller than the number of points.
    perplexity = min(5, data.shape[0] - 1)
    tsne = TSNE(n_components=2, random_state=0, perplexity=perplexity)
    tsne_data = tsne.fit_transform(data)

    fig = plt.figure(figsize=(10, 10))

    for i, (x, y) in enumerate(tsne_data):
        plt.scatter(x, y, color='blue')
        plt.text(x, y, f"{labels[i]}", fontsize=12, ha='center', va='bottom')
    plt.savefig('temp_figure.png')
    plt.close(fig)

    # Read the rendered figure back; the context manager closes the file once
    # the pixels have been copied into the array.
    with Image.open('temp_figure.png') as figure_image:
        image = np.array(figure_image)  # Convert image to a NumPy array
    image = image[:, :, :3].transpose(2, 0, 1)  # Extract RGB channels and change format to CHW
    writer.add_image('t-SNE', image, global_step=epoch)
|
Loading…
Reference in New Issue
Block a user