diff --git a/tac_ssl.py b/tac_ssl.py new file mode 100644 index 0000000..4c7e239 --- /dev/null +++ b/tac_ssl.py @@ -0,0 +1,119 @@ +import os +from PIL import Image + +from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo +import matplotlib.pyplot as plt + +import torch +import torch.optim as optim +from torchvision import transforms +from torch.utils.data import random_split +from torch.utils.data import DataLoader, Dataset +from torch.utils.tensorboard import SummaryWriter + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +writer = SummaryWriter('runs/mmssl') + +# Custom dataset +class CustomMultiModalDataset(Dataset): + def __init__(self, vision_folder, tactile_folder, transform=None): + self.vision_folder = vision_folder + self.tactile_folder = tactile_folder + self.transform = transform + + self.vision_files = sorted(os.listdir(vision_folder)) + self.tactile_files = sorted(os.listdir(tactile_folder)) + + def __len__(self): + return len(self.vision_files) + + def __getitem__(self, idx): + vision_path = os.path.join(self.vision_folder, self.vision_files[idx]) + tactile_path = os.path.join(self.tactile_folder, self.tactile_files[idx]) + + vision_image = Image.open(vision_path).convert("RGB") + tactile_image = Image.open(tactile_path).convert("RGB") + + if self.transform: + vision_image = self.transform(vision_image) + tactile_image = self.transform(tactile_image) + + return vision_image, tactile_image + +# Initialize augmentation +simple_transforms = transforms.Compose([ + transforms.CenterCrop(500), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +]) + +data_transforms = transforms.Compose([ + transforms.RandomApply([transforms.RandomRotation(150)], p=0.50), + transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), + transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.50), + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5), +]) + +# Initialize dataset and dataloader +vision_folder = "/home/vedant/Downloads/ssvtp_data/images_rgb" +tactile_folder = "/home/vedant/Downloads/ssvtp_data/images_tac" +dataset = CustomMultiModalDataset(vision_folder, tactile_folder, transform=simple_transforms) +#dataloader = DataLoader(dataset, batch_size=128, shuffle=True) + +# Split the dataset into 80-20 +train_size = int(0.8 * len(dataset)) +test_size = len(dataset) - train_size +train_dataset, test_dataset = random_split(dataset, [train_size, test_size]) + +# Initialize dataloaders for train and test +train_dataloader = DataLoader(train_dataset, batch_size=96, shuffle=True) +test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False) + + +# Initialize model +model = MultiModalMoCo(writer, K=4096, m=0.999, T=0.07).to(device) + +# Initialize optimizer +vision_module = list(model.vision_base_q.parameters()) + list(model.vision_head_intra_q.parameters()) + list(model.vision_head_inter_q.parameters()) +tactile_module = list(model.tactile_base_q.parameters()) + list(model.tactile_head_intra_q.parameters()) + list(model.tactile_head_inter_q.parameters()) +optim_vision = optim.Adam(vision_module, lr=0.0001) +optim_tactile = optim.Adam(tactile_module, lr=0.0001) + +# Training loop +n_epochs = 250 # Number of epochs +for epoch in range(n_epochs): + for i, (x_vision, x_tactile) in enumerate(train_dataloader): + + # Augment images + x_vision_q = data_transforms(x_vision).to(device) + x_vision_k = data_transforms(x_vision).to(device) + + x_tactile_q = data_transforms(x_tactile).to(device) + x_tactile_k = data_transforms(x_tactile).to(device) + + # Forward pass to get the loss + loss = model(x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, len(train_dataloader)) + + # Backward pass and optimization + optim_vision.zero_grad() + optim_tactile.zero_grad() + loss.backward() + optim_vision.step() + optim_tactile.step() + + # Logging + if i % 10 == 0: + print(f"Epoch [{epoch+1}/{n_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}") + writer.add_scalar('training loss', loss.item(), epoch * len(train_dataloader) + i) + + # Evaluate and plot + compute_tsne(model, test_dataloader, writer, epoch) + evaluate_and_plot(model, test_dataloader, epoch, writer, device) + if epoch % 10 == 0: + torch.save(model.state_dict(), 'models/model.pth') + + +plt.show() diff --git a/train_mm_moco.py b/train_mm_moco.py new file mode 100644 index 0000000..38a3442 --- /dev/null +++ b/train_mm_moco.py @@ -0,0 +1,192 @@ +import torch +import torch.nn as nn +from torchvision import models # For using the ResNet-50 model +import torch.nn.functional as F + +import timm +import random +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE + +class MultiModalMoCo(nn.Module): + def __init__(self, writer, K=4096, m=0.99, T=1.0): + super(MultiModalMoCo, self).__init__() + self.writer = writer + self.K = K + self.m = m + self.T = T + + self.intra_dim = 64 + self.inter_dim = 64 + + # Initialize the queue + self.queue = torch.zeros((self.K, self.intra_dim), dtype=torch.float).cuda() + self.queue_ptr = 0 + + def create_mlp_head(output_dim): + return nn.Sequential( + nn.Linear(2048, 2048), + nn.ReLU(), + nn.Linear(2048, output_dim) + ) + + def create_resnet_encoder(): + resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1') + features = list(resnet.children())[:-2] + features.append(nn.AdaptiveAvgPool2d((1, 1))) + features.append(nn.Flatten()) + return nn.Sequential(*features) + + # Vision encoders + self.vision_base_q = create_resnet_encoder() + self.vision_head_intra_q = create_mlp_head(self.intra_dim) + self.vision_head_inter_q = create_mlp_head(self.inter_dim) + + self.vision_base_k = create_resnet_encoder() + self.vision_head_intra_k = create_mlp_head(self.intra_dim) + self.vision_head_inter_k = create_mlp_head(self.inter_dim) + + # Tactile encoders + self.tactile_base_q = create_resnet_encoder() + self.tactile_head_intra_q = create_mlp_head(self.intra_dim) + self.tactile_head_inter_q = create_mlp_head(self.inter_dim) + + self.tactile_base_k = create_resnet_encoder() + self.tactile_head_intra_k = create_mlp_head(self.intra_dim) + self.tactile_head_inter_k = create_mlp_head(self.inter_dim) + + # Initialize key encoders with query encoder weights + self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k) + self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k) + + @torch.no_grad() + def concat_all_gather(self,tensor): + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + output = torch.cat(tensors_gather, dim=0) + return output + + def moco_contrastive_loss(self, q, k): + q = nn.functional.normalize(q, dim=1) + k = nn.functional.normalize(k, dim=1) + logits = torch.mm(q, k.T.detach()) / self.T + labels = torch.arange(logits.shape[0], dtype=torch.long).cuda() + return nn.CrossEntropyLoss()(logits, labels) + + @torch.no_grad() + def _momentum_update_key_encoder(self, base_q, base_k): + for param_q, param_k in zip(base_q.parameters(), base_k.parameters()): + param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) + + def forward(self, x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, len_train_dataloader): + vision_base_q = self.vision_base_q(x_vision_q) + vision_queries_intra = self.vision_head_intra_q(vision_base_q) + vision_queries_inter = self.vision_head_inter_q(vision_base_q) + + with torch.no_grad(): + self._momentum_update_key_encoder(self.vision_base_q, self.vision_base_k) + vision_base_k = self.vision_base_k(x_vision_k) + vision_keys_intra = self.vision_head_intra_k(vision_base_k) + vision_keys_inter = self.vision_head_inter_k(vision_base_k) + + tactile_base_q = self.tactile_base_q(x_tactile_q) + tactile_queries_intra = self.tactile_head_intra_q(tactile_base_q) + tactile_queries_inter = self.tactile_head_inter_q(tactile_base_q) + + with torch.no_grad(): + self._momentum_update_key_encoder(self.tactile_base_q, self.tactile_base_k) + tactile_base_k = self.tactile_base_k(x_tactile_k) + tactile_keys_intra = self.tactile_head_intra_k(tactile_base_k) + tactile_keys_inter = self.tactile_head_inter_k(tactile_base_k) + + # Compute the contrastive loss for each pair of queries and keys + vision_loss_intra = self.moco_contrastive_loss(vision_queries_intra, vision_keys_intra) + tactile_loss_intra = self.moco_contrastive_loss(tactile_queries_intra, tactile_keys_intra) + vision_tactile_inter = self.moco_contrastive_loss(vision_queries_inter, tactile_keys_inter) + tactile_vision_inter = self.moco_contrastive_loss(tactile_queries_inter, vision_keys_inter) + + # Combine losses (you can use different strategies to combine these losses) + weight_inter = 0.1 + combined_loss = vision_loss_intra + tactile_loss_intra + (vision_tactile_inter + tactile_vision_inter) * weight_inter + + if len_train_dataloader != 0: + self.writer.add_scalar('module loss/vision intra loss', vision_loss_intra.item(), epoch * len_train_dataloader + i) + self.writer.add_scalar('module loss/tactile intra loss', tactile_loss_intra.item(), epoch * len_train_dataloader + i) + self.writer.add_scalar('module loss/vision tactile inter loss', vision_tactile_inter.item() * weight_inter, epoch * len_train_dataloader + i) + self.writer.add_scalar('module loss/tactile vision inter loss', tactile_vision_inter.item() * weight_inter, epoch * len_train_dataloader + i) + + return combined_loss + + +def denormalize(tensor, mean, std): + for t, m, s in zip(tensor, mean, std): + t.mul_(s).add_(m) + return tensor + +def evaluate_and_plot(model, test_dataloader, epoch, writer, device): + model.eval() + + with torch.no_grad(): + test_data_list = list(test_dataloader) + x_vision_test, x_tactile_test = random.choice(test_data_list) + random_indices = random.sample(range(x_vision_test.shape[0]), 4) + x_vision_test = x_vision_test[random_indices].to(device) + x_tactile_test = x_tactile_test[random_indices].to(device) + + with torch.no_grad(): + test_loss = model(x_vision_test, x_vision_test, x_tactile_test, x_tactile_test, epoch, 0, 0) + + # Denormalize vision images + x_vision_test_denorm = denormalize(x_vision_test.clone(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + x_vision_test_denorm = x_vision_test_denorm.cpu().numpy() + x_vision_test_denorm = np.clip(x_vision_test_denorm, 0, 1) + + # Denormalize tactile images + x_tactile_test_denorm = denormalize(x_tactile_test.clone(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + x_tactile_test_denorm = x_tactile_test_denorm.cpu().numpy() + x_tactile_test_denorm = np.clip(x_tactile_test_denorm, 0, 1) + + writer.add_images('Vision_Images', x_vision_test_denorm, epoch) + writer.add_images('Tactile_Images', x_tactile_test_denorm, epoch) + + writer.add_scalar('testing loss', test_loss.item(), epoch * len(test_dataloader)) + print(f"Test Loss: {test_loss.item():.4f}") + + +def compute_tsne(model, test_dataloader, writer, epoch): + with torch.no_grad(): + test_data_list = list(test_dataloader) + x_vision_test, x_tactile_test = random.choice(test_data_list) + random_indices = random.sample(range(x_vision_test.shape[0]), 10) + x_vision_test = x_vision_test[random_indices].to('cuda') + x_tactile_test = x_tactile_test[random_indices].to('cuda') + vision_base_q = model.vision_base_q(x_vision_test) + tactile_base_q = model.tactile_base_q(x_tactile_test) + + vision_base_q = vision_base_q.cpu().numpy() + tactile_base_q = tactile_base_q.cpu().numpy() + + tsne = TSNE(n_components=2, random_state=0, perplexity=5) + + # Create pairs of corresponding representations and labels + num_samples = min(vision_base_q.shape[0], tactile_base_q.shape[0]) + data = np.concatenate((vision_base_q[:num_samples], tactile_base_q[:num_samples]), axis=0) + labels = np.arange(1, num_samples+1).repeat(2) + + tsne_data = tsne.fit_transform(data) + + fig = plt.figure(figsize=(10, 10)) + + for i, (x, y) in enumerate(tsne_data): + plt.scatter(x, y, color='blue') + plt.text(x, y, f"{labels[i]}", fontsize=12, ha='center', va='bottom') + plt.savefig('temp_figure.png') + plt.close(fig) + + image = Image.open('temp_figure.png') + image = np.array(image) # Convert image to a NumPy array + image = image[:, :, :3].transpose(2, 0, 1) # Extract RGB channels and change format to CHW + writer.add_image('t-SNE', image, global_step=epoch) \ No newline at end of file