TVSSL/tac_ssl.py

143 lines
5.4 KiB
Python
Raw Normal View History

2023-08-30 11:40:13 +00:00
import os
2023-09-02 15:26:10 +00:00
import pickle
2023-08-30 11:40:13 +00:00
from PIL import Image
from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
# Run on the first CUDA device when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# TensorBoard event writer shared by the model and the training loop;
# events are written under runs/mmssl.
writer = SummaryWriter(log_dir='runs/mmssl')
# Custom dataset
class CustomMultiModalDataset(Dataset):
    """Paired vision/tactile image dataset.

    Sample ``i`` pairs the ``i``-th file (in sorted filename order) of
    ``vision_folder`` with the ``i``-th file of ``tactile_folder``. Pairing is
    purely positional, so both folders must hold the same number of files
    whose sorted orders correspond one-to-one.

    Args:
        vision_folder: directory containing the vision images.
        tactile_folder: directory containing the tactile images.
        transform: optional callable applied separately to each image after
            it has been loaded and converted to RGB.

    Raises:
        ValueError: if the two folders contain a different number of entries.
    """

    def __init__(self, vision_folder, tactile_folder, transform=None):
        self.vision_folder = vision_folder
        self.tactile_folder = tactile_folder
        self.transform = transform
        # NOTE(review): os.listdir returns every entry (including hidden files
        # or sub-directories); assumes both folders contain only images.
        self.vision_files = sorted(os.listdir(vision_folder))
        self.tactile_files = sorted(os.listdir(tactile_folder))
        # Pairs are matched by index, so a count mismatch would either shift
        # every later pair out of alignment or raise IndexError mid-epoch.
        # Fail fast at construction time instead.
        if len(self.vision_files) != len(self.tactile_files):
            raise ValueError(
                f"vision folder has {len(self.vision_files)} files but tactile "
                f"folder has {len(self.tactile_files)}; samples are paired by "
                "sorted index, so the counts must match"
            )

    def __len__(self):
        return len(self.vision_files)

    def __getitem__(self, idx):
        """Return the ``(vision_image, tactile_image)`` pair at ``idx``."""
        vision_image = self._load_rgb(
            os.path.join(self.vision_folder, self.vision_files[idx]))
        tactile_image = self._load_rgb(
            os.path.join(self.tactile_folder, self.tactile_files[idx]))
        if self.transform:
            vision_image = self.transform(vision_image)
            tactile_image = self.transform(tactile_image)
        return vision_image, tactile_image

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an RGB copy of it and close the file handle."""
        # convert() returns a new, fully loaded image, so the source file can
        # be closed as soon as the with-block exits.
        with Image.open(path) as img:
            return img.convert("RGB")
# Initialize augmentation
# Deterministic preprocessing the dataset applies to every loaded image:
# fixed 275x275 resize -> float tensor -> per-channel normalisation with the
# standard ImageNet mean/std values.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
simple_transforms = transforms.Compose(
    [
        transforms.Resize((275, 275)),
        # transforms.CenterCrop(500),  # disabled
        transforms.ToTensor(),
        transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
    ]
)

# Stochastic augmentations that the training loop applies to whole
# (already normalised) batches to build the query and key views.
# NOTE(review): torchvision's tensor transforms draw one set of random
# parameters per call, so every image in a batch receives the same
# rotation / crop / flip / grayscale decision. Per-sample augmentation would
# require applying these inside the Dataset instead.
# NOTE(review): RandomHorizontalFlip already flips with p=0.5 by default, so
# wrapping it in RandomApply(p=0.50) yields an effective flip rate of 0.25.
_augmentation_steps = [
    # Rotate by an angle drawn from [-150, 150] degrees, half of the time.
    transforms.RandomApply([transforms.RandomRotation(150)], p=0.50),
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.50),
    # Disabled: transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    # Disabled: transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5),
]
data_transforms = transforms.Compose(_augmentation_steps)
# Initialize dataset and dataloader
vision_folder = "/home/vedant/Downloads/ssvtp_data/images_rgb"
tactile_folder = "/home/vedant/Downloads/ssvtp_data/images_tac"
dataset = CustomMultiModalDataset(vision_folder, tactile_folder, transform=simple_transforms)

# Location of the persisted train/test split. Saving the split lets later
# runs with preload=True reuse exactly the same partition.
indices_dir = 'indices'
train_indices_path = os.path.join(indices_dir, 'train_indices.pkl')
test_indices_path = os.path.join(indices_dir, 'test_indices.pkl')

# preload=True: reuse the saved split when one exists.
# preload=False, or no saved split on disk: draw a fresh 80/20 split and save it.
preload = True
have_saved_split = os.path.isfile(train_indices_path) and os.path.isfile(test_indices_path)

if preload and have_saved_split:
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only load index files that this script wrote itself.
    with open(train_indices_path, 'rb') as f:
        train_indices = pickle.load(f)
    with open(test_indices_path, 'rb') as f:
        test_indices = pickle.load(f)
else:
    if preload:
        print(f"No saved split found in '{indices_dir}/'; creating a new 80/20 split.")
    # Split the dataset into 80-20
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_split, test_split = random_split(dataset, [train_size, test_size])
    train_indices = train_split.indices
    test_indices = test_split.indices

    # Create the target directory on first use so the dump cannot fail with
    # FileNotFoundError.
    os.makedirs(indices_dir, exist_ok=True)
    with open(train_indices_path, 'wb') as f:
        pickle.dump(train_indices, f)
    with open(test_indices_path, 'wb') as f:
        pickle.dump(test_indices, f)

# Both branches reduce to the same index lists, so build the subsets and
# loaders in one place.
train_dataset = torch.utils.data.Subset(dataset, train_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)
train_dataloader = DataLoader(train_dataset, batch_size=96, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# Initialize model
# MultiModalMoCo is imported from train_mm_moco. It receives the TensorBoard
# writer plus K, m and T (presumably queue size, momentum and temperature --
# confirm in train_mm_moco), and is then moved to `device`.
model = MultiModalMoCo(writer, K=4096, m=0.99, T=0.07)
model = model.to(device)

# Initialize optimizer
# Only the query-side (*_q) backbone and heads of each modality are optimized
# here; each modality gets its own Adam instance.
vision_module = [
    param
    for submodule in (model.vision_base_q, model.vision_head_intra_q, model.vision_head_inter_q)
    for param in submodule.parameters()
]
tactile_module = [
    param
    for submodule in (model.tactile_base_q, model.tactile_head_intra_q, model.tactile_head_inter_q)
    for param in submodule.parameters()
]
# NOTE(review): 0.1 is very large for Adam (PyTorch's default is 1e-3);
# check this first if training diverges.
learning_rate = 0.1
optim_vision = optim.Adam(vision_module, lr=learning_rate)
optim_tactile = optim.Adam(tactile_module, lr=learning_rate)
# Training loop
n_epochs = 500  # Number of epochs
checkpoint_every = 10  # save every N epochs, plus once after the final epoch
checkpoint_path = os.path.join('models', 'model.pth')
# Create the checkpoint directory up front so the first torch.save (epoch 0)
# cannot fail with FileNotFoundError.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
steps_per_epoch = len(train_dataloader)

for epoch in range(n_epochs):
    for i, (x_vision, x_tactile) in enumerate(train_dataloader):
        # Two independently augmented views (query and key) per modality.
        x_vision_q = data_transforms(x_vision).to(device)
        x_vision_k = data_transforms(x_vision).to(device)
        x_tactile_q = data_transforms(x_tactile).to(device)
        x_tactile_k = data_transforms(x_tactile).to(device)

        # Forward pass to get the loss
        loss = model(x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, steps_per_epoch)

        # Backward pass and optimization
        optim_vision.zero_grad()
        optim_tactile.zero_grad()
        loss.backward()
        optim_vision.step()
        optim_tactile.step()

        # Logging: read the scalar once (a device sync) and reuse it.
        if i % 10 == 0:
            loss_value = loss.item()
            print(f"Epoch [{epoch+1}/{n_epochs}], Step [{i+1}/{steps_per_epoch}], Loss: {loss_value:.4f}")
            writer.add_scalar('training loss', loss_value, epoch * steps_per_epoch + i)

    # Evaluate and plot (currently disabled)
    #compute_tsne(model, test_dataloader, writer, epoch)
    #evaluate_and_plot(model, test_dataloader, epoch, writer, device)

    # Periodic checkpoint. Also save after the last epoch; otherwise epochs
    # after the final multiple of checkpoint_every would never be saved.
    if epoch % checkpoint_every == 0 or epoch == n_epochs - 1:
        torch.save(model.state_dict(), checkpoint_path)

# Flush any pending TensorBoard events and release the event file.
writer.close()