2023-08-31 10:42:13 +00:00
|
|
|
import torch
|
|
|
|
from torchvision import transforms
|
|
|
|
from torch.utils.data import random_split
|
|
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
|
|
|
|
import os
|
|
|
|
import random
|
|
|
|
import pickle
|
|
|
|
import numpy as np
|
|
|
|
from PIL import Image
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
from sklearn.manifold import TSNE
|
|
|
|
from train_mm_moco import MultiModalMoCo
|
2023-09-02 15:26:10 +00:00
|
|
|
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier
|
2023-08-31 10:42:13 +00:00
|
|
|
|
|
|
|
|
|
|
|
def denormalize(tensor, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalisation, in place.

    The channel axis is taken to be the third axis from the end, so both a
    single image ``(C, H, W)`` and a batch ``(N, C, H, W)`` are handled.
    Tensors with fewer than three dimensions use their first axis as the
    channel axis.

    Args:
        tensor: Tensor to denormalise. It is modified in place.
        mean: Sequence of per-channel means that were subtracted.
        std: Sequence of per-channel standard deviations that were divided by.

    Returns:
        The same ``tensor`` object, after ``x * std + mean`` has been applied
        to each channel.
    """
    channel_dim = max(tensor.dim() - 3, 0)

    # Stop at the shortest of the three inputs, the same way zip() would.
    channels = min(tensor.shape[channel_dim], len(mean), len(std))

    # Keep the statistics in the tensor's own floating dtype so the in-place
    # update does not change precision; integer tensors still fail loudly.
    stat_dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
    mean_t = torch.as_tensor(mean, dtype=stat_dtype, device=tensor.device)[:channels]
    std_t = torch.as_tensor(std, dtype=stat_dtype, device=tensor.device)[:channels]

    # Shape the statistics so they broadcast along the channel axis only.
    broadcast_shape = [1] * tensor.dim()
    broadcast_shape[channel_dim] = channels

    # narrow() returns a view, so the in-place ops write into ``tensor``.
    selected = tensor.narrow(channel_dim, 0, channels)
    selected.mul_(std_t.reshape(broadcast_shape)).add_(mean_t.reshape(broadcast_shape))
    return tensor
|
|
|
|
|
|
|
|
def compute_tsne(model, test_dataloader):
    """Visualise vision and tactile embeddings of one random test batch.

    One batch is drawn at random from ``test_dataloader`` and at most 100 of
    its vision/tactile pairs are passed through ``model.vision_base_q`` and
    ``model.tactile_base_q``. The function then:

    1. plots the 8 nearest neighbours of the first vision sample in the
       embedding space,
    2. projects all embeddings to 2-D with t-SNE and plots the nearest
       neighbours of the same sample in the projected space,
    3. saves a scatter plot of the projection to ``temp_figure.png``
       (vision points blue, tactile points red) and displays it.

    Args:
        model: Object exposing ``vision_base_q`` and ``tactile_base_q``
            callables. Inputs are moved to ``'cuda'`` before they are called.
        test_dataloader: Iterable yielding ``(vision_batch, tactile_batch)``
            tensor pairs.

    Raises:
        ValueError: If the chosen batch has too few samples for the
            8-neighbour query.
    """
    max_samples = 100
    feature_neighbors = 8
    # Same statistics as the Normalize transform used when the dataset is built.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    # NOTE(review): the model is not switched to eval mode here; if it contains
    # dropout or batch-norm layers the caller probably wants model.eval() first.
    with torch.no_grad():
        test_data_list = list(test_dataloader)
        x_vision_test, x_tactile_test = random.choice(test_data_list)

        # The chosen batch may be smaller than max_samples (e.g. the last,
        # partial batch), so never ask random.sample for more than it holds.
        batch_size = x_vision_test.shape[0]
        random_indices = random.sample(range(batch_size), min(max_samples, batch_size))
        x_vision_test = x_vision_test[random_indices].to('cuda')
        x_tactile_test = x_tactile_test[random_indices].to('cuda')

        vision_base_q = model.vision_base_q(x_vision_test)
        tactile_base_q = model.tactile_base_q(x_tactile_test)

        # Denormalise only after the forward pass: the model saw the
        # normalised inputs, the plots get the denormalised ones.
        x_vision_test = denormalize(x_vision_test, norm_mean, norm_std)
        x_tactile_test = denormalize(x_tactile_test, norm_mean, norm_std)

        vision_base_q = vision_base_q.cpu().numpy()
        tactile_base_q = tactile_base_q.cpu().numpy()
        vision_images = x_vision_test.cpu().numpy()
        tactile_images = x_tactile_test.cpu().numpy()

    # Create pairs of corresponding representations and labels
    num_samples = min(vision_base_q.shape[0], tactile_base_q.shape[0])
    total_points = 2 * num_samples
    if total_points - 1 < feature_neighbors:
        raise ValueError(
            f"need at least {feature_neighbors + 1} embeddings for the "
            f"neighbour query, got {total_points}"
        )

    # Rows 0..num_samples-1 are vision, the remaining rows are tactile; the
    # image array uses the same row order so neighbour results index both.
    data = np.concatenate((vision_base_q[:num_samples], tactile_base_q[:num_samples]), axis=0)
    image_data = np.concatenate((vision_images[:num_samples], tactile_images[:num_samples]), axis=0)
    labels = np.arange(1, total_points + 1)

    # Nearest neighbours in the original embedding space.
    nn_all = find_knn(data, labels, n=feature_neighbors)
    plot_images_by_labels(image_data, nn_all[0])

    # t-SNE requires perplexity < number of points, so clamp it for small batches.
    # NOTE(review): recent scikit-learn releases rename ``n_iter`` to
    # ``max_iter`` -- confirm against the installed version.
    tsne = TSNE(
        n_components=2,
        random_state=0,
        perplexity=min(75, total_points - 1),
        n_iter=50000,
    )
    tsne_data = tsne.fit_transform(data)

    # Nearest neighbours in the 2-D projection.
    nn_all = find_knn(tsne_data, labels)
    plot_images_by_labels(image_data, nn_all[0])

    # Scatter plot of the projection: vision rows in blue, tactile rows in red.
    fig = plt.figure(figsize=(10, 10))
    point_colors = ['blue' if label <= num_samples else 'red' for label in labels]
    plt.scatter(tsne_data[:, 0], tsne_data[:, 1], c=point_colors)
    plt.savefig('temp_figure.png')
    plt.close(fig)

    # Re-load the saved figure and display it as an image.
    image = np.array(Image.open('temp_figure.png'))
    image_rgb = image[:, :, :3]  # keep the first three (RGB) channels only
    plt.imshow(image_rgb)
    plt.title('t-SNE plot')
    plt.axis('off')
    plt.show()
|
|
|
|
|
2023-09-02 15:26:10 +00:00
|
|
|
def find_knn(tsne_data, labels, n=6):
    """Find the ``n`` nearest neighbours of the first row of ``tsne_data``.

    Row 0 is the query point; every other row forms the search pool.

    Args:
        tsne_data: 2-D array of points, one per row.
        labels: Array with one entry per row of ``tsne_data``.
        n: Number of neighbours to look up.

    Returns:
        ``labels`` indexed by the neighbour index array from ``kneighbors``
        for the single query row. Those indices are positions within the pool
        ``tsne_data[1:]`` but are applied to the full ``labels`` array, so
        pool position ``p`` yields ``labels[p]``.
    """
    query_point = tsne_data[0, :].reshape(1, -1)
    pool_points = tsne_data[1:, :]
    pool_labels = labels[1:]

    # Only the neighbour search of the fitted classifier is used; no
    # prediction is made with it.
    knn_model = KNeighborsClassifier(n_neighbors=n, weights='distance')
    knn_model.fit(pool_points, pool_labels)

    neighbour_positions = knn_model.kneighbors(query_point, return_distance=False)
    return labels[neighbour_positions]
|
2023-08-31 10:42:13 +00:00
|
|
|
|
|
|
|
def plot_images_by_labels(image_data, labels_to_plot):
    """Show the images selected by ``labels_to_plot`` in a two-row grid.

    Each selected image is min-max rescaled to ``[0, 1]`` on its own before
    display, and its axis is titled with the value used to select it.

    Args:
        image_data: Indexable collection of images; each ``image_data[label]``
            must be a 3-D array whose first axis is moved last for display
            (presumably channels-first ``(C, H, W)`` -- confirm with caller).
        labels_to_plot: Sequence of indices into ``image_data``. Any length is
            accepted; nothing is drawn when it is empty.
    """
    count = len(labels_to_plot)
    if count == 0:
        return

    # Round the column count up so an odd number of images still fits.
    columns = (count + 1) // 2
    _, axes = plt.subplots(2, columns, figsize=(15, 5), squeeze=False)
    axes = axes.flatten()

    # Hide every axis up front so a leftover empty cell stays blank.
    for ax in axes:
        ax.axis('off')

    for ax, label in zip(axes, labels_to_plot):
        img = image_data[label]
        lowest = np.min(img)
        value_range = np.max(img) - lowest
        if value_range > 0:
            normalized_image_data = (img - lowest) / value_range
        else:
            # A constant image has no range to stretch; show it as all zeros
            # instead of dividing by zero.
            normalized_image_data = np.zeros_like(img, dtype=float)
        ax.imshow(normalized_image_data.transpose(1, 2, 0))
        ax.set_title(label)

    plt.show()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    class CustomMultiModalDataset(Dataset):
        """Dataset of vision/tactile image pairs stored in two folders.

        The file names of each folder are sorted and paired by position:
        item ``idx`` is the ``idx``-th vision file together with the
        ``idx``-th tactile file. The dataset length is the number of vision
        files.
        """

        def __init__(self, vision_folder, tactile_folder, transform=None):
            self.vision_folder = vision_folder
            self.tactile_folder = tactile_folder
            self.transform = transform

            self.vision_files = sorted(os.listdir(vision_folder))
            self.tactile_files = sorted(os.listdir(tactile_folder))

        def __len__(self):
            return len(self.vision_files)

        def __getitem__(self, idx):
            # Resolve both paths first, then load, then transform -- vision
            # always before tactile.
            paths = (
                os.path.join(self.vision_folder, self.vision_files[idx]),
                os.path.join(self.tactile_folder, self.tactile_files[idx]),
            )
            images = [Image.open(path).convert("RGB") for path in paths]
            if self.transform:
                images = [self.transform(picture) for picture in images]

            vision_image, tactile_image = images
            return vision_image, tactile_image

    # Initialize augmentation: deterministic crop, tensor conversion and
    # per-channel normalisation.
    simple_transforms = transforms.Compose(
        [
            transforms.CenterCrop(500),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )

    # Load the indices from disk.
    # NOTE: pickle must only be used on trusted files.
    with open('indices/train_indices.pkl', 'rb') as f:
        train_indices = pickle.load(f)

    with open('indices/test_indices.pkl', 'rb') as f:
        test_indices = pickle.load(f)

    # Initialize dataset and dataloader
    vision_folder = "/home/vedant/Downloads/ssvtp_data/images_rgb"
    tactile_folder = "/home/vedant/Downloads/ssvtp_data/images_tac"
    dataset = CustomMultiModalDataset(
        vision_folder,
        tactile_folder,
        transform=simple_transforms,
    )

    # Create subset datasets and DataLoaders (only the test split is used).
    test_subset = torch.utils.data.Subset(dataset, test_indices)
    test_dataloader = DataLoader(test_subset, batch_size=150, shuffle=False)

    # The model constructor takes a TensorBoard writer.
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('runs/mmssl1')
    model = MultiModalMoCo(writer).to('cuda')
    model.load_state_dict(torch.load('/home/vedant/TacSSL/models/model.pth'))

    compute_tsne(model, test_dataloader)
|