import os
from PIL import Image
from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter('runs/mmssl')


# Custom dataset
class CustomMultiModalDataset(Dataset):
    """Dataset of (vision image, tactile image) pairs read from two folders.

    Both folders are listed in sorted filename order and the i-th vision file
    is paired with the i-th tactile file. The folders must therefore hold the
    same number of files, with names that sort into matching order.

    Args:
        vision_folder: Directory containing the vision images.
        tactile_folder: Directory containing the tactile images.
        transform: Optional callable applied independently to each RGB PIL
            image (vision and tactile alike).

    Raises:
        ValueError: If the two folders contain different numbers of files.
    """

    def __init__(self, vision_folder, tactile_folder, transform=None):
        self.vision_folder = vision_folder
        self.tactile_folder = tactile_folder
        self.transform = transform
        self.vision_files = sorted(os.listdir(vision_folder))
        self.tactile_files = sorted(os.listdir(tactile_folder))
        # Pairing is purely positional. A count mismatch would silently
        # misalign pairs, or raise IndexError mid-epoch, so fail fast instead.
        if len(self.vision_files) != len(self.tactile_files):
            raise ValueError(
                f"vision folder {vision_folder!r} has {len(self.vision_files)} files "
                f"but tactile folder {tactile_folder!r} has {len(self.tactile_files)}; "
                "images are paired by sorted position, so the counts must match"
            )

    def __len__(self):
        return len(self.vision_files)

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path*, convert it to RGB, and close the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        """Return the (vision, tactile) pair at *idx*, transformed if configured."""
        vision_image = self._load_rgb(os.path.join(self.vision_folder, self.vision_files[idx]))
        tactile_image = self._load_rgb(os.path.join(self.tactile_folder, self.tactile_files[idx]))

        if self.transform:
            vision_image = self.transform(vision_image)
            tactile_image = self.transform(tactile_image)

        return vision_image, tactile_image


# Per-channel normalization statistics (the standard ImageNet values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing applied by the dataset:
# center crop -> tensor in [0, 1] -> normalize.
simple_transforms = transforms.Compose([
    transforms.CenterCrop(500),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Contrastive augmentations. They receive tensors that simple_transforms has
# already normalized. The photometric ops (ColorJitter clamps float tensors to
# [0, 1]; hue and grayscale assume RGB in [0, 1]) would corrupt normalized
# values, so undo the normalization first and re-apply it last. Inputs and
# outputs are both normalized tensors, as before.
data_transforms = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
        std=[1.0 / s for s in IMAGENET_STD],
    ),
    transforms.RandomApply([transforms.RandomRotation(150)], p=0.50),
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.50),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Initialize dataset and dataloader
vision_folder = "/home/vedant/Downloads/ssvtp_data/images_rgb"
tactile_folder = "/home/vedant/Downloads/ssvtp_data/images_tac"
dataset = CustomMultiModalDataset(vision_folder, tactile_folder, transform=simple_transforms)

# Split the dataset 80/20 into train and test. Seeding the split keeps the test
# set identical across runs, so a saved checkpoint is never evaluated on images
# it was trained on in an earlier run.
SPLIT_SEED = 0
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Initialize dataloaders for train and test
train_dataloader = DataLoader(train_dataset, batch_size=96, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Initialize model
model = MultiModalMoCo(writer, K=4096, m=0.999, T=0.07).to(device)

# One Adam optimizer per modality over its `*_q` modules (backbone plus intra-
# and inter-modal heads). These are presumably the MoCo query encoders; the
# key-side modules are not handed to any optimizer here.
vision_module = (
    list(model.vision_base_q.parameters())
    + list(model.vision_head_intra_q.parameters())
    + list(model.vision_head_inter_q.parameters())
)
tactile_module = (
    list(model.tactile_base_q.parameters())
    + list(model.tactile_head_intra_q.parameters())
    + list(model.tactile_head_inter_q.parameters())
)
optim_vision = optim.Adam(vision_module, lr=0.0001)
optim_tactile = optim.Adam(tactile_module, lr=0.0001)


def _augment_batch(batch):
    """Return ``batch`` with ``data_transforms`` sampled independently per image.

    Calling a torchvision transform on a whole batched tensor draws ONE set of
    random parameters and applies it to every image. All samples would then
    share the same crop, rotation and jitter, so transform each image separately
    and re-stack the results.
    """
    return torch.stack([data_transforms(image) for image in batch])


CHECKPOINT_PATH = 'models/model.pth'
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

# Training loop
n_epochs = 250  # Number of epochs
steps_per_epoch = len(train_dataloader)
for epoch in range(n_epochs):
    # Restore training mode every epoch in case the evaluation helpers called
    # at the end of the previous epoch switched the model to eval mode.
    model.train()

    for i, (x_vision, x_tactile) in enumerate(train_dataloader):
        # Two independently augmented views (query / key) per modality.
        x_vision_q = _augment_batch(x_vision).to(device)
        x_vision_k = _augment_batch(x_vision).to(device)
        x_tactile_q = _augment_batch(x_tactile).to(device)
        x_tactile_k = _augment_batch(x_tactile).to(device)

        # Forward pass to get the loss
        loss = model(x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, steps_per_epoch)

        # Backward pass and optimization
        optim_vision.zero_grad()
        optim_tactile.zero_grad()
        loss.backward()
        optim_vision.step()
        optim_tactile.step()

        # Logging
        if i % 10 == 0:
            print(f"Epoch [{epoch+1}/{n_epochs}], Step [{i+1}/{steps_per_epoch}], Loss: {loss.item():.4f}")
            writer.add_scalar('training loss', loss.item(), epoch * steps_per_epoch + i)

    # Evaluate and plot
    compute_tsne(model, test_dataloader, writer, epoch)
    evaluate_and_plot(model, test_dataloader, epoch, writer, device)

    # Checkpoint every 10 epochs, and always after the final epoch so the
    # end-of-training weights are not lost.
    if epoch % 10 == 0 or epoch == n_epochs - 1:
        torch.save(model.state_dict(), CHECKPOINT_PATH)

# Flush pending TensorBoard events before blocking on the plot windows.
writer.close()
plt.show()