diff --git a/tac_ssl_test.py b/tac_ssl_test.py new file mode 100644 index 0000000..ce78701 --- /dev/null +++ b/tac_ssl_test.py @@ -0,0 +1,135 @@ +import torch +from torchvision import transforms +from torch.utils.data import random_split +from torch.utils.data import DataLoader, Dataset + +import os +import random +import pickle +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE +from train_mm_moco import MultiModalMoCo +from sklearn.neighbors import NearestNeighbors + + +def denormalize(tensor, mean, std): + for t, m, s in zip(tensor, mean, std): + t.mul_(s).add_(m) + return tensor + +def compute_tsne(model, test_dataloader): + with torch.no_grad(): + test_data_list = list(test_dataloader) + x_vision_test, x_tactile_test = random.choice(test_data_list) + random_indices = random.sample(range(x_vision_test.shape[0]), 100) + x_vision_test = x_vision_test[random_indices].to('cuda') + x_tactile_test = x_tactile_test[random_indices].to('cuda') + vision_base_q = model.vision_base_q(x_vision_test) + tactile_base_q = model.tactile_base_q(x_tactile_test) + + x_vision_test = denormalize(x_vision_test, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + x_tactile_test = denormalize(x_tactile_test, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + + vision_base_q = vision_base_q.cpu().numpy() + tactile_base_q = tactile_base_q.cpu().numpy() + + image_data = np.concatenate((x_vision_test.cpu().numpy(), x_tactile_test.cpu().numpy()), axis=0) + + tsne = TSNE(n_components=2, random_state=0, perplexity=75,n_iter=50000) + + # Create pairs of corresponding representations and labels + num_samples = min(vision_base_q.shape[0], tactile_base_q.shape[0]) + data = np.concatenate((vision_base_q[:num_samples], tactile_base_q[:num_samples]), axis=0) + labels = np.arange(1, 2*(num_samples)+1) + + tsne_data = tsne.fit_transform(data) + nn_all = find_knn(tsne_data, labels) + plot_images_by_labels(image_data, nn_all[0]) + print(nn_all[:5]) + + fig = plt.figure(figsize=(10, 10)) + + for i, (x, y) in enumerate(tsne_data): + plt.scatter(x, y, color='blue') + plt.text(x, y, f"{labels[i]}", fontsize=12, ha='center', va='bottom') + plt.savefig('temp_figure.png') + plt.close(fig) + + image = Image.open('temp_figure.png') + image = np.array(image) # Convert image to a NumPy array + image_rgb = image[:, :, :3] # Extract RGB channels and change format to CHW + plt.imshow(image_rgb) + plt.title('t-SNE plot') + plt.axis('off') + plt.show() + +def find_knn(tsne_data, labels): + neigh = NearestNeighbors(n_neighbors=8) + neigh.fit(tsne_data) + knn = neigh.kneighbors(tsne_data, return_distance=False) + return labels[knn] + +def plot_images_by_labels(image_data, labels_to_plot): + fig, axes = plt.subplots(1, len(labels_to_plot), figsize=(15, 5)) + for i, label in enumerate(labels_to_plot): + axes[i].imshow(image_data[label].transpose(1, 2, 0)) + axes[i].set_title(label) + axes[i].axis('off') + plt.show() + +if __name__ == "__main__": + class CustomMultiModalDataset(Dataset): + def __init__(self, vision_folder, tactile_folder, transform=None): + self.vision_folder = vision_folder + self.tactile_folder = tactile_folder + self.transform = transform + + self.vision_files = sorted(os.listdir(vision_folder)) + self.tactile_files = sorted(os.listdir(tactile_folder)) + + def __len__(self): + return len(self.vision_files) + + def __getitem__(self, idx): + vision_path = os.path.join(self.vision_folder, self.vision_files[idx]) + tactile_path = os.path.join(self.tactile_folder, self.tactile_files[idx]) + + vision_image = Image.open(vision_path).convert("RGB") + tactile_image = Image.open(tactile_path).convert("RGB") + + if self.transform: + vision_image = self.transform(vision_image) + tactile_image = self.transform(tactile_image) + + return vision_image, tactile_image + + # Initialize augmentation + simple_transforms = transforms.Compose([ + transforms.CenterCrop(500), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # Load the indices from disk + with open('indices/train_indices.pkl', 'rb') as f: + train_indices = pickle.load(f) + + with open('indices/test_indices.pkl', 'rb') as f: + test_indices = pickle.load(f) + + # Initialize dataset and dataloader + vision_folder = "/home/vedant/Downloads/ssvtp_data/images_rgb" + tactile_folder = "/home/vedant/Downloads/ssvtp_data/images_tac" + dataset = CustomMultiModalDataset(vision_folder, tactile_folder, transform=simple_transforms) + + # Create subset datasets and DataLoaders + test_subset = torch.utils.data.Subset(dataset, test_indices) + test_dataloader = DataLoader(test_subset, batch_size=150, shuffle=False) + + from torch.utils.tensorboard import SummaryWriter + writer = SummaryWriter('runs/mmssl1') + model = MultiModalMoCo(writer).to('cuda') + model.load_state_dict(torch.load('/home/vedant/TacSSL/models/model.pth')) + compute_tsne(model, test_dataloader) \ No newline at end of file