Doing minor changes
This commit is contained in:
parent
c93804a2c4
commit
6a11c9ce61
64
tac_ssl.py
64
tac_ssl.py
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo
|
from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo
|
||||||
@ -43,7 +44,8 @@ class CustomMultiModalDataset(Dataset):
|
|||||||
|
|
||||||
# Initialize augmentation
|
# Initialize augmentation
|
||||||
simple_transforms = transforms.Compose([
|
simple_transforms = transforms.Compose([
|
||||||
transforms.CenterCrop(500),
|
transforms.Resize((275, 275)),
|
||||||
|
#transforms.CenterCrop(500),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
])
|
])
|
||||||
@ -52,38 +54,63 @@ data_transforms = transforms.Compose([
|
|||||||
transforms.RandomApply([transforms.RandomRotation(150)], p=0.50),
|
transforms.RandomApply([transforms.RandomRotation(150)], p=0.50),
|
||||||
transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
|
transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
|
||||||
transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.50),
|
transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.50),
|
||||||
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
|
#transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
|
||||||
transforms.RandomGrayscale(p=0.2),
|
transforms.RandomGrayscale(p=0.2),
|
||||||
transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5),
|
#transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5),
|
||||||
])
|
])
|
||||||
|
|
||||||
# Initialize dataset and dataloader
|
# Initialize dataset and dataloader
|
||||||
vision_folder = "/home/vedant/Downloads/ssvtp_data/images_rgb"
|
vision_folder = "/home/vedant/Downloads/ssvtp_data/images_rgb"
|
||||||
tactile_folder = "/home/vedant/Downloads/ssvtp_data/images_tac"
|
tactile_folder = "/home/vedant/Downloads/ssvtp_data/images_tac"
|
||||||
dataset = CustomMultiModalDataset(vision_folder, tactile_folder, transform=simple_transforms)
|
dataset = CustomMultiModalDataset(vision_folder, tactile_folder, transform=simple_transforms)
|
||||||
#dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
|
|
||||||
|
|
||||||
# Split the dataset into 80-20
|
preload = True
|
||||||
train_size = int(0.8 * len(dataset))
|
if not preload:
|
||||||
test_size = len(dataset) - train_size
|
# Split the dataset into 80-20
|
||||||
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
|
train_size = int(0.8 * len(dataset))
|
||||||
|
test_size = len(dataset) - train_size
|
||||||
|
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
|
||||||
|
|
||||||
# Initialize dataloaders for train and test
|
# Get the indices of the training and test sets
|
||||||
train_dataloader = DataLoader(train_dataset, batch_size=96, shuffle=True)
|
train_indices = train_dataset.indices
|
||||||
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
test_indices = test_dataset.indices
|
||||||
|
|
||||||
|
# Save these indices to disk
|
||||||
|
with open('indices/train_indices.pkl', 'wb') as f:
|
||||||
|
pickle.dump(train_indices, f)
|
||||||
|
|
||||||
|
with open('indices/test_indices.pkl', 'wb') as f:
|
||||||
|
pickle.dump(test_indices, f)
|
||||||
|
|
||||||
|
# Initialize dataloaders for train and test
|
||||||
|
train_dataloader = DataLoader(train_dataset, batch_size=96, shuffle=True)
|
||||||
|
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
||||||
|
else:
|
||||||
|
# Load the indices from disk
|
||||||
|
with open('indices/train_indices.pkl', 'rb') as f:
|
||||||
|
train_indices = pickle.load(f)
|
||||||
|
|
||||||
|
with open('indices/test_indices.pkl', 'rb') as f:
|
||||||
|
test_indices = pickle.load(f)
|
||||||
|
|
||||||
|
# Create subset datasets and DataLoaders
|
||||||
|
train_subset = torch.utils.data.Subset(dataset, train_indices)
|
||||||
|
test_subset = torch.utils.data.Subset(dataset, test_indices)
|
||||||
|
|
||||||
|
train_dataloader = DataLoader(train_subset, batch_size=96, shuffle=True)
|
||||||
|
test_dataloader = DataLoader(test_subset, batch_size=32, shuffle=False)
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
model = MultiModalMoCo(writer, K=4096, m=0.999, T=0.07).to(device)
|
model = MultiModalMoCo(writer, K=4096, m=0.99, T=0.07).to(device)
|
||||||
|
|
||||||
# Initialize optimizer
|
# Initialize optimizer
|
||||||
vision_module = list(model.vision_base_q.parameters()) + list(model.vision_head_intra_q.parameters()) + list(model.vision_head_inter_q.parameters())
|
vision_module = list(model.vision_base_q.parameters()) + list(model.vision_head_intra_q.parameters()) + list(model.vision_head_inter_q.parameters())
|
||||||
tactile_module = list(model.tactile_base_q.parameters()) + list(model.tactile_head_intra_q.parameters()) + list(model.tactile_head_inter_q.parameters())
|
tactile_module = list(model.tactile_base_q.parameters()) + list(model.tactile_head_intra_q.parameters()) + list(model.tactile_head_inter_q.parameters())
|
||||||
optim_vision = optim.Adam(vision_module, lr=0.0001)
|
optim_vision = optim.Adam(vision_module, lr=0.1)
|
||||||
optim_tactile = optim.Adam(tactile_module, lr=0.0001)
|
optim_tactile = optim.Adam(tactile_module, lr=0.1)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
n_epochs = 250 # Number of epochs
|
n_epochs = 500 # Number of epochs
|
||||||
for epoch in range(n_epochs):
|
for epoch in range(n_epochs):
|
||||||
for i, (x_vision, x_tactile) in enumerate(train_dataloader):
|
for i, (x_vision, x_tactile) in enumerate(train_dataloader):
|
||||||
|
|
||||||
@ -110,10 +137,7 @@ for epoch in range(n_epochs):
|
|||||||
writer.add_scalar('training loss', loss.item(), epoch * len(train_dataloader) + i)
|
writer.add_scalar('training loss', loss.item(), epoch * len(train_dataloader) + i)
|
||||||
|
|
||||||
# Evaluate and plot
|
# Evaluate and plot
|
||||||
compute_tsne(model, test_dataloader, writer, epoch)
|
#compute_tsne(model, test_dataloader, writer, epoch)
|
||||||
evaluate_and_plot(model, test_dataloader, epoch, writer, device)
|
#evaluate_and_plot(model, test_dataloader, epoch, writer, device)
|
||||||
if epoch % 10 == 0:
|
if epoch % 10 == 0:
|
||||||
torch.save(model.state_dict(), 'models/model.pth')
|
torch.save(model.state_dict(), 'models/model.pth')
|
||||||
|
|
||||||
|
|
||||||
plt.show()
|
|
||||||
|
@ -11,7 +11,7 @@ from PIL import Image
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from sklearn.manifold import TSNE
|
from sklearn.manifold import TSNE
|
||||||
from train_mm_moco import MultiModalMoCo
|
from train_mm_moco import MultiModalMoCo
|
||||||
from sklearn.neighbors import NearestNeighbors
|
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier
|
||||||
|
|
||||||
|
|
||||||
def denormalize(tensor, mean, std):
|
def denormalize(tensor, mean, std):
|
||||||
@ -34,6 +34,10 @@ def compute_tsne(model, test_dataloader):
|
|||||||
|
|
||||||
vision_base_q = vision_base_q.cpu().numpy()
|
vision_base_q = vision_base_q.cpu().numpy()
|
||||||
tactile_base_q = tactile_base_q.cpu().numpy()
|
tactile_base_q = tactile_base_q.cpu().numpy()
|
||||||
|
combined_data = np.concatenate((vision_base_q, tactile_base_q), axis=0)
|
||||||
|
|
||||||
|
nn_all = find_knn(combined_data, np.asarray(range(1,201)), n=8)
|
||||||
|
plot_images_by_labels(np.concatenate((x_vision_test.cpu().numpy(), x_tactile_test.cpu().numpy()), axis=0), nn_all[0])
|
||||||
|
|
||||||
image_data = np.concatenate((x_vision_test.cpu().numpy(), x_tactile_test.cpu().numpy()), axis=0)
|
image_data = np.concatenate((x_vision_test.cpu().numpy(), x_tactile_test.cpu().numpy()), axis=0)
|
||||||
|
|
||||||
@ -47,13 +51,12 @@ def compute_tsne(model, test_dataloader):
|
|||||||
tsne_data = tsne.fit_transform(data)
|
tsne_data = tsne.fit_transform(data)
|
||||||
nn_all = find_knn(tsne_data, labels)
|
nn_all = find_knn(tsne_data, labels)
|
||||||
plot_images_by_labels(image_data, nn_all[0])
|
plot_images_by_labels(image_data, nn_all[0])
|
||||||
print(nn_all[:5])
|
|
||||||
|
|
||||||
fig = plt.figure(figsize=(10, 10))
|
fig = plt.figure(figsize=(10, 10))
|
||||||
|
|
||||||
for i, (x, y) in enumerate(tsne_data):
|
for i, (x, y) in enumerate(tsne_data):
|
||||||
plt.scatter(x, y, color='blue')
|
plt.scatter(x, y, color='blue' if labels[i] <= 100 else 'red')
|
||||||
plt.text(x, y, f"{labels[i]}", fontsize=12, ha='center', va='bottom')
|
#plt.text(x, y, f"{labels[i]}", fontsize=12, ha='center', va='bottom')
|
||||||
plt.savefig('temp_figure.png')
|
plt.savefig('temp_figure.png')
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
@ -65,18 +68,24 @@ def compute_tsne(model, test_dataloader):
|
|||||||
plt.axis('off')
|
plt.axis('off')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def find_knn(tsne_data, labels):
|
def find_knn(tsne_data, labels, n=6):
|
||||||
neigh = NearestNeighbors(n_neighbors=8)
|
neigh = KNeighborsClassifier(n_neighbors=n, weights='distance')
|
||||||
neigh.fit(tsne_data)
|
X_train = tsne_data[1:,:]
|
||||||
knn = neigh.kneighbors(tsne_data, return_distance=False)
|
y_train = labels[1:]
|
||||||
return labels[knn]
|
neigh.fit(X_train, y_train)
|
||||||
|
_, indices = neigh.kneighbors(tsne_data[0,:].reshape(1, -1))
|
||||||
|
return labels[indices]
|
||||||
|
|
||||||
def plot_images_by_labels(image_data, labels_to_plot):
|
def plot_images_by_labels(image_data, labels_to_plot):
|
||||||
fig, axes = plt.subplots(1, len(labels_to_plot), figsize=(15, 5))
|
fig, axes = plt.subplots(2, len(labels_to_plot)//2, figsize=(15, 5))
|
||||||
|
axes = axes.flatten()
|
||||||
for i, label in enumerate(labels_to_plot):
|
for i, label in enumerate(labels_to_plot):
|
||||||
axes[i].imshow(image_data[label].transpose(1, 2, 0))
|
img = image_data[label]
|
||||||
|
normalized_image_data = (img - np.min(img)) / (np.max(img) - np.min(img))
|
||||||
|
axes[i].imshow(normalized_image_data.transpose(1, 2, 0))
|
||||||
axes[i].set_title(label)
|
axes[i].set_title(label)
|
||||||
axes[i].axis('off')
|
axes[i].axis('off')
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -18,8 +18,8 @@ class MultiModalMoCo(nn.Module):
|
|||||||
self.m = m
|
self.m = m
|
||||||
self.T = T
|
self.T = T
|
||||||
|
|
||||||
self.intra_dim = 64
|
self.intra_dim = 128
|
||||||
self.inter_dim = 64
|
self.inter_dim = 128
|
||||||
|
|
||||||
# Initialize the queue
|
# Initialize the queue
|
||||||
self.queue = torch.zeros((self.K, self.intra_dim), dtype=torch.float).cuda()
|
self.queue = torch.zeros((self.K, self.intra_dim), dtype=torch.float).cuda()
|
||||||
@ -34,6 +34,7 @@ class MultiModalMoCo(nn.Module):
|
|||||||
|
|
||||||
def create_resnet_encoder():
|
def create_resnet_encoder():
|
||||||
resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
|
resnet = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
|
||||||
|
#resnet = models.regnet_x_800mf(weights='RegNet_X_800MF_Weights')
|
||||||
features = list(resnet.children())[:-2]
|
features = list(resnet.children())[:-2]
|
||||||
features.append(nn.AdaptiveAvgPool2d((1, 1)))
|
features.append(nn.AdaptiveAvgPool2d((1, 1)))
|
||||||
features.append(nn.Flatten())
|
features.append(nn.Flatten())
|
||||||
@ -109,7 +110,7 @@ class MultiModalMoCo(nn.Module):
|
|||||||
tactile_vision_inter = self.moco_contrastive_loss(tactile_queries_inter, vision_keys_inter)
|
tactile_vision_inter = self.moco_contrastive_loss(tactile_queries_inter, vision_keys_inter)
|
||||||
|
|
||||||
# Combine losses (you can use different strategies to combine these losses)
|
# Combine losses (you can use different strategies to combine these losses)
|
||||||
weight_inter = 0.1
|
weight_inter = 1
|
||||||
combined_loss = vision_loss_intra + tactile_loss_intra + (vision_tactile_inter + tactile_vision_inter) * weight_inter
|
combined_loss = vision_loss_intra + tactile_loss_intra + (vision_tactile_inter + tactile_vision_inter) * weight_inter
|
||||||
|
|
||||||
if len_train_dataloader != 0:
|
if len_train_dataloader != 0:
|
||||||
@ -160,7 +161,7 @@ def compute_tsne(model, test_dataloader, writer, epoch):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
test_data_list = list(test_dataloader)
|
test_data_list = list(test_dataloader)
|
||||||
x_vision_test, x_tactile_test = random.choice(test_data_list)
|
x_vision_test, x_tactile_test = random.choice(test_data_list)
|
||||||
random_indices = random.sample(range(x_vision_test.shape[0]), 10)
|
random_indices = random.sample(range(x_vision_test.shape[0]), 20)
|
||||||
x_vision_test = x_vision_test[random_indices].to('cuda')
|
x_vision_test = x_vision_test[random_indices].to('cuda')
|
||||||
x_tactile_test = x_tactile_test[random_indices].to('cuda')
|
x_tactile_test = x_tactile_test[random_indices].to('cuda')
|
||||||
vision_base_q = model.vision_base_q(x_vision_test)
|
vision_base_q = model.vision_base_q(x_vision_test)
|
||||||
@ -169,7 +170,7 @@ def compute_tsne(model, test_dataloader, writer, epoch):
|
|||||||
vision_base_q = vision_base_q.cpu().numpy()
|
vision_base_q = vision_base_q.cpu().numpy()
|
||||||
tactile_base_q = tactile_base_q.cpu().numpy()
|
tactile_base_q = tactile_base_q.cpu().numpy()
|
||||||
|
|
||||||
tsne = TSNE(n_components=2, random_state=0, perplexity=5)
|
tsne = TSNE(n_components=2, random_state=0, perplexity=2)
|
||||||
|
|
||||||
# Create pairs of corresponding representations and labels
|
# Create pairs of corresponding representations and labels
|
||||||
num_samples = min(vision_base_q.shape[0], tactile_base_q.shape[0])
|
num_samples = min(vision_base_q.shape[0], tactile_base_q.shape[0])
|
||||||
@ -190,3 +191,16 @@ def compute_tsne(model, test_dataloader, writer, epoch):
|
|||||||
image = np.array(image) # Convert image to a NumPy array
|
image = np.array(image) # Convert image to a NumPy array
|
||||||
image = image[:, :, :3].transpose(2, 0, 1) # Extract RGB channels and change format to CHW
|
image = image[:, :, :3].transpose(2, 0, 1) # Extract RGB channels and change format to CHW
|
||||||
writer.add_image('t-SNE', image, global_step=epoch)
|
writer.add_image('t-SNE', image, global_step=epoch)
|
||||||
|
|
||||||
|
|
||||||
|
def find_knn(query_point, data_points, k=5):
|
||||||
|
# Calculate the Euclidean distances
|
||||||
|
distances = torch.norm(data_points - query_point, dim=1)
|
||||||
|
|
||||||
|
# Find the indices of the k smallest distances
|
||||||
|
knn_indices = torch.topk(distances, k, largest=False, sorted=True)[1]
|
||||||
|
|
||||||
|
# Get the k smallest distances
|
||||||
|
knn_distances = distances[knn_indices]
|
||||||
|
|
||||||
|
return knn_indices, knn_distances
|
Loading…
Reference in New Issue
Block a user