From 6a11c9ce615c52bc9c514280d94d2157e18ac09c Mon Sep 17 00:00:00 2001 From: Vedant Dave Date: Sat, 2 Sep 2023 17:26:10 +0200 Subject: [PATCH] Doing minor changes --- tac_ssl.py | 68 ++++++++++++++++++++++++++++++++---------------- tac_ssl_test.py | 31 ++++++++++++++-------- train_mm_moco.py | 26 +++++++++++++----- 3 files changed, 86 insertions(+), 39 deletions(-) diff --git a/tac_ssl.py b/tac_ssl.py index 4c7e239..bf59f87 100644 --- a/tac_ssl.py +++ b/tac_ssl.py @@ -1,4 +1,5 @@ import os +import pickle from PIL import Image from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo @@ -43,7 +44,8 @@ class CustomMultiModalDataset(Dataset): # Initialize augmentation simple_transforms = transforms.Compose([ - transforms.CenterCrop(500), + transforms.Resize((275, 275)), + #transforms.CenterCrop(500), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) @@ -52,41 +54,66 @@ data_transforms = transforms.Compose([ transforms.RandomApply([transforms.RandomRotation(150)], p=0.50), transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.50), - transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), + #transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), transforms.RandomGrayscale(p=0.2), - transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5), + #transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5), ]) # Initialize dataset and dataloader vision_folder = "/home/vedant/Downloads/ssvtp_data/images_rgb" tactile_folder = "/home/vedant/Downloads/ssvtp_data/images_tac" dataset = CustomMultiModalDataset(vision_folder, tactile_folder, transform=simple_transforms) -#dataloader = DataLoader(dataset, batch_size=128, shuffle=True) -# Split the dataset into 80-20 -train_size = int(0.8 * len(dataset)) -test_size = len(dataset) - train_size -train_dataset, test_dataset = random_split(dataset, [train_size, test_size]) +preload = True +if not preload: + # Split the dataset into 80-20 + train_size = int(0.8 * len(dataset)) + test_size = len(dataset) - train_size + train_dataset, test_dataset = random_split(dataset, [train_size, test_size]) -# Initialize dataloaders for train and test -train_dataloader = DataLoader(train_dataset, batch_size=96, shuffle=True) -test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False) + # Get the indices of the training and test sets + train_indices = train_dataset.indices + test_indices = test_dataset.indices + # Save these indices to disk + with open('indices/train_indices.pkl', 'wb') as f: + pickle.dump(train_indices, f) + + with open('indices/test_indices.pkl', 'wb') as f: + pickle.dump(test_indices, f) + + # Initialize dataloaders for train and test + train_dataloader = DataLoader(train_dataset, batch_size=96, shuffle=True) + test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False) +else: + # Load the indices from disk + with open('indices/train_indices.pkl', 'rb') as f: + train_indices = pickle.load(f) + + with open('indices/test_indices.pkl', 'rb') as f: + test_indices = pickle.load(f) + + # Create subset datasets and DataLoaders + train_subset = torch.utils.data.Subset(dataset, train_indices) + test_subset = torch.utils.data.Subset(dataset, test_indices) + + train_dataloader = DataLoader(train_subset, batch_size=96, shuffle=True) + test_dataloader = DataLoader(test_subset, batch_size=32, shuffle=False) # Initialize model -model = MultiModalMoCo(writer, K=4096, m=0.999, T=0.07).to(device) +model = MultiModalMoCo(writer, K=4096, m=0.99, T=0.07).to(device) # Initialize optimizer vision_module = list(model.vision_base_q.parameters()) + list(model.vision_head_intra_q.parameters()) + list(model.vision_head_inter_q.parameters()) tactile_module = list(model.tactile_base_q.parameters()) + list(model.tactile_head_intra_q.parameters()) + list(model.tactile_head_inter_q.parameters()) -optim_vision = optim.Adam(vision_module, lr=0.0001) -optim_tactile = optim.Adam(tactile_module, lr=0.0001) +optim_vision = optim.Adam(vision_module, lr=0.1) +optim_tactile = optim.Adam(tactile_module, lr=0.1) # Training loop -n_epochs = 250 # Number of epochs +n_epochs = 500 # Number of epochs for epoch in range(n_epochs): for i, (x_vision, x_tactile) in enumerate(train_dataloader): - + # Augment images x_vision_q = data_transforms(x_vision).to(device) x_vision_k = data_transforms(x_vision).to(device) @@ -110,10 +137,7 @@ for epoch in range(n_epochs): writer.add_scalar('training loss', loss.item(), epoch * len(train_dataloader) + i) # Evaluate and plot - compute_tsne(model, test_dataloader, writer, epoch) - evaluate_and_plot(model, test_dataloader, epoch, writer, device) + #compute_tsne(model, test_dataloader, writer, epoch) + #evaluate_and_plot(model, test_dataloader, epoch, writer, device) if epoch % 10 == 0: - torch.save(model.state_dict(), 'models/model.pth') - - -plt.show() + torch.save(model.state_dict(), 'models/model.pth') \ No newline at end of file diff --git a/tac_ssl_test.py b/tac_ssl_test.py index ce78701..d10be3a 100644 --- a/tac_ssl_test.py +++ b/tac_ssl_test.py @@ -11,7 +11,7 @@ from PIL import Image import matplotlib.pyplot as plt from sklearn.manifold import TSNE from train_mm_moco import MultiModalMoCo -from sklearn.neighbors import NearestNeighbors +from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier def denormalize(tensor, mean, std): @@ -34,7 +34,11 @@ def compute_tsne(model, test_dataloader): vision_base_q = vision_base_q.cpu().numpy() tactile_base_q = tactile_base_q.cpu().numpy() + combined_data = np.concatenate((vision_base_q, tactile_base_q), axis=0) + nn_all = find_knn(combined_data, np.asarray(range(1,201)), n=8) + plot_images_by_labels(np.concatenate((x_vision_test.cpu().numpy(), x_tactile_test.cpu().numpy()), axis=0), nn_all[0]) + image_data = np.concatenate((x_vision_test.cpu().numpy(), x_tactile_test.cpu().numpy()), axis=0) tsne = TSNE(n_components=2, random_state=0, perplexity=75,n_iter=50000) @@ -47,13 +51,12 @@ def compute_tsne(model, test_dataloader): tsne_data = tsne.fit_transform(data) nn_all = find_knn(tsne_data, labels) plot_images_by_labels(image_data, nn_all[0]) - print(nn_all[:5]) fig = plt.figure(figsize=(10, 10)) for i, (x, y) in enumerate(tsne_data): - plt.scatter(x, y, color='blue') - plt.text(x, y, f"{labels[i]}", fontsize=12, ha='center', va='bottom') + plt.scatter(x, y, color='blue' if labels[i] <= 100 else 'red') + #plt.text(x, y, f"{labels[i]}", fontsize=12, ha='center', va='bottom') plt.savefig('temp_figure.png') plt.close(fig) @@ -65,18 +68,24 @@ def compute_tsne(model, test_dataloader): plt.axis('off') plt.show() -def find_knn(tsne_data, labels): - neigh = NearestNeighbors(n_neighbors=8) - neigh.fit(tsne_data) - knn = neigh.kneighbors(tsne_data, return_distance=False) - return labels[knn] +def find_knn(tsne_data, labels, n=6): + neigh = KNeighborsClassifier(n_neighbors=n, weights='distance') + X_train = tsne_data[1:,:] + y_train = labels[1:] + neigh.fit(X_train, y_train) + _, indices = neigh.kneighbors(tsne_data[0,:].reshape(1, -1)) + return labels[indices] def plot_images_by_labels(image_data, labels_to_plot): - fig, axes = plt.subplots(1, len(labels_to_plot), figsize=(15, 5)) + fig, axes = plt.subplots(2, len(labels_to_plot)//2, figsize=(15, 5)) + axes = axes.flatten() for i, label in enumerate(labels_to_plot): - axes[i].imshow(image_data[label].transpose(1, 2, 0)) + img = image_data[label] + normalized_image_data = (img - np.min(img)) / (np.max(img) - np.min(img)) + axes[i].imshow(normalized_image_data.transpose(1, 2, 0)) axes[i].set_title(label) axes[i].axis('off') + plt.show() if __name__ == "__main__": diff --git a/train_mm_moco.py b/train_mm_moco.py index 38a3442..47c02f1 100644 --- a/train_mm_moco.py +++ b/train_mm_moco.py @@ -18,8 +18,8 @@ class MultiModalMoCo(nn.Module): self.m = m self.T = T - self.intra_dim = 64 - self.inter_dim = 64 + self.intra_dim = 128 + self.inter_dim = 128 # Initialize the queue self.queue = torch.zeros((self.K, self.intra_dim), dtype=torch.float).cuda() @@ -34,6 +34,7 @@ class MultiModalMoCo(nn.Module): def create_resnet_encoder(): resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1') + #resnet = models.regnet_x_800mf(weights='RegNet_X_800MF_Weights') features = list(resnet.children())[:-2] features.append(nn.AdaptiveAvgPool2d((1, 1))) features.append(nn.Flatten()) @@ -109,7 +110,7 @@ class MultiModalMoCo(nn.Module): tactile_vision_inter = self.moco_contrastive_loss(tactile_queries_inter, vision_keys_inter) # Combine losses (you can use different strategies to combine these losses) - weight_inter = 0.1 + weight_inter = 1 combined_loss = vision_loss_intra + tactile_loss_intra + (vision_tactile_inter + tactile_vision_inter) * weight_inter if len_train_dataloader != 0: @@ -160,7 +161,7 @@ def compute_tsne(model, test_dataloader, writer, epoch): with torch.no_grad(): test_data_list = list(test_dataloader) x_vision_test, x_tactile_test = random.choice(test_data_list) - random_indices = random.sample(range(x_vision_test.shape[0]), 10) + random_indices = random.sample(range(x_vision_test.shape[0]), 20) x_vision_test = x_vision_test[random_indices].to('cuda') x_tactile_test = x_tactile_test[random_indices].to('cuda') vision_base_q = model.vision_base_q(x_vision_test) @@ -169,7 +170,7 @@ def compute_tsne(model, test_dataloader, writer, epoch): vision_base_q = vision_base_q.cpu().numpy() tactile_base_q = tactile_base_q.cpu().numpy() - tsne = TSNE(n_components=2, random_state=0, perplexity=5) + tsne = TSNE(n_components=2, random_state=0, perplexity=2) # Create pairs of corresponding representations and labels num_samples = min(vision_base_q.shape[0], tactile_base_q.shape[0]) @@ -189,4 +190,17 @@ def compute_tsne(model, test_dataloader, writer, epoch): image = Image.open('temp_figure.png') image = np.array(image) # Convert image to a NumPy array image = image[:, :, :3].transpose(2, 0, 1) # Extract RGB channels and change format to CHW - writer.add_image('t-SNE', image, global_step=epoch) \ No newline at end of file + writer.add_image('t-SNE', image, global_step=epoch) + + +def find_knn(query_point, data_points, k=5): + # Calculate the Euclidean distances + distances = torch.norm(data_points - query_point, dim=1) + + # Find the indices of the k smallest distances + knn_indices = torch.topk(distances, k, largest=False, sorted=True)[1] + + # Get the k smallest distances + knn_distances = distances[knn_indices] + + return knn_indices, knn_distances \ No newline at end of file